Mirror of https://github.com/AntoineHX/smart_augmentation.git (synced 2025-05-04 12:10:45 +02:00)

Commit e291bc2e44 — "Brutus" (parent 53bd421670)
9 changed files with 55 additions and 44 deletions
@@ -2,13 +2,12 @@ from utils import *

 if __name__ == "__main__":

-    '''
+    #'''
     files=[
-        #"res/good_TF_tests/log/Aug_mod(Data_augV5(Mix0.5-14TFx2-MagFxSh)-LeNet)-100 epochs (dataug:0)- 0 in_it.json",
-        #"res/good_TF_tests/log/Aug_mod(Data_augV5(Uniform-14TFx2-MagFxSh)-LeNet)-100 epochs (dataug:0)- 0 in_it.json",
-        "res/brutus-tests/log/Aug_mod(Data_augV5(Uniform-14TFx3-MagFxSh)-LeNet)-150epochs(dataug:0)-10in_it-0.json",
-        "res/brutus-tests/log/Aug_mod(Data_augV5(Uniform-14TFx3-MagFxSh)-LeNet)-150epochs(dataug:0)-10in_it-1.json",
-        "res/brutus-tests/log/Aug_mod(Data_augV5(Uniform-14TFx3-MagFxSh)-LeNet)-150epochs(dataug:0)-10in_it-2.json",
+        "res/log/Aug_mod(Data_augV5(Mix0.8-23TFx4-Mag)-LeNet)-100 epochs (dataug:0)- 1 in_it.json",
+        #"res/brutus-tests/log/Aug_mod(Data_augV5(Uniform-14TFx3-MagFxSh)-LeNet)-150epochs(dataug:0)-10in_it-0.json",
+        #"res/brutus-tests/log/Aug_mod(Data_augV5(Uniform-14TFx3-MagFxSh)-LeNet)-150epochs(dataug:0)-10in_it-1.json",
+        #"res/brutus-tests/log/Aug_mod(Data_augV5(Uniform-14TFx3-MagFxSh)-LeNet)-150epochs(dataug:0)-10in_it-2.json",
         #"res/log/Aug_mod(RandAugUDA(18TFx2-Mag1)-LeNet)-100 epochs (dataug:0)- 0 in_it.json",
     ]

@@ -18,7 +17,7 @@ if __name__ == "__main__":
             data = json.load(json_file)
         plot_resV2(data['Log'], fig_name=file.replace('.json','').replace('log/',''), param_names=data['Param_names'])
         #plot_TF_influence(data['Log'], param_names=data['Param_names'])
-    '''
+    #'''
     ## Loss , Acc, Proba = f(epoch) ##
     #plot_compare(filenames=files, fig_name="res/compare")

@@ -78,7 +77,7 @@ if __name__ == "__main__":
     '''

     #Res print
-    #'''
+    '''
     nb_run=3
     accs = []
     times = []
@@ -93,4 +92,4 @@ if __name__ == "__main__":
             print(idx, data['Accuracy'])

     print(files[0], np.mean(accs), np.std(accs), np.mean(times))
-    #'''
+    '''
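Note on the four hunks above: they toggle which block of the plotting script is active, enabling the per-file plotting block (''' becomes #''') and commenting out the multi-run aggregation block (#''' becomes '''). For context, a minimal sketch of what the disabled aggregation block computes, with made-up placeholder log dicts standing in for the JSON files the real script reads:

import numpy as np

# hypothetical in-memory stand-ins for three run logs (the script loads these from JSON);
# 'Time' layout assumed from the training scripts: (mean epoch time, std, total)
logs = [
    {'Accuracy': 0.0, 'Time': (0.0, 0.0, 0.0)},
    {'Accuracy': 0.0, 'Time': (0.0, 0.0, 0.0)},
    {'Accuracy': 0.0, 'Time': (0.0, 0.0, 0.0)},
]

accs, times = [], []
for idx, data in enumerate(logs):
    accs.append(data['Accuracy'])   # best accuracy of each run
    times.append(data['Time'][0])   # mean epoch time of each run
    print(idx, data['Accuracy'])

# summary over the nb_run repetitions of one configuration
print(np.mean(accs), np.std(accs), np.mean(times))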
@@ -531,7 +531,7 @@ class Data_augV4(nn.Module): #Transformations with mask
         return "Data_augV4(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_seqTF)

 class Data_augV5(nn.Module): #Joint optimization (mag, proba)
-    def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True):
+    def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True, ):
         super(Data_augV5, self).__init__()
         assert len(TF_dict)>0

@@ -548,8 +548,8 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)
         #self._fixed_mag=5 #[0, PARAMETER_MAX]
         self._params = nn.ParameterDict({
             "prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Uniform probability distribution
-            "mag" : nn.Parameter(torch.tensor(float(TF.PARAMETER_MAX)) if self._shared_mag
-                else torch.tensor(float(TF.PARAMETER_MAX)).expand(self._nb_tf)), #[0, PARAMETER_MAX]
+            "mag" : nn.Parameter(torch.tensor(float(TF.PARAMETER_MAX)/2) if self._shared_mag
+                else torch.tensor(float(TF.PARAMETER_MAX)/2).expand(self._nb_tf)), #[0, PARAMETER_MAX]
         })

         #for t in TF.TF_no_mag: self._params['mag'][self._TF.index(t)].data-=self._params['mag'][self._TF.index(t)].data #Mag useless for ignore_mag TFs
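Note: this hunk changes the magnitude initialization from the maximum of the range to mid-range (TF.PARAMETER_MAX/2), which leaves the optimizer room to move magnitudes in either direction before the clamp saturates them. A minimal sketch of the new initialization, assuming PARAMETER_MAX = 10 and 14 transformations purely for illustration:

import torch
import torch.nn as nn

PARAMETER_MAX = 10   # assumed upper bound, per the [0, PARAMETER_MAX] comment
nb_tf = 14           # assumed number of transformations
shared_mag = True    # one shared magnitude vs. one per transformation

params = nn.ParameterDict({
    # uniform probability over the transformations
    "prob": nn.Parameter(torch.ones(nb_tf) / nb_tf),
    # start magnitudes at mid-range rather than at the maximum
    "mag": nn.Parameter(torch.tensor(PARAMETER_MAX / 2.0) if shared_mag
                        else torch.full((nb_tf,), PARAMETER_MAX / 2.0)),
})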
@@ -633,7 +633,7 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)
             self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Constraint sum(p)=1

             if not self._fixed_mag:
-                self._params['mag'].data = self._params['mag'].data.clamp(min=TF.PARAMETER_MIN, max=TF.PARAMETER_MAX) #Sticks once it reaches an extreme
+                self._params['mag'].data = self._params['mag'].data.clamp(min=TF.PARAMETER_MIN, max=TF.PARAMETER_MAX)

                 #self._params['mag'].data = F.relu(self._params['mag'].data) - F.relu(self._params['mag'].data - TF.PARAMETER_MAX)

     def loss_weight(self):
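The hunk above only drops a comment, but the surrounding logic is the interesting part: after each optimizer step, probabilities are renormalized to sum to 1 and learnable magnitudes are clamped back into [PARAMETER_MIN, PARAMETER_MAX]. A hedged sketch of that projection step, as standalone functions rather than the class's actual method:

import torch

def project_params(prob: torch.Tensor, mag: torch.Tensor,
                   mag_min: float = 0.0, mag_max: float = 10.0,
                   fixed_mag: bool = False):
    # constraint sum(p) = 1: simple renormalization, not a true simplex projection
    prob = prob / prob.sum()
    if not fixed_mag:
        # keep magnitudes inside the valid range; clamp zeroes the gradient once a
        # value sticks at an extreme (the commented F.relu variant in the diff is
        # one way around that)
        mag = mag.clamp(min=mag_min, max=mag_max)
    return prob, mag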
@@ -93,15 +93,15 @@ if __name__ == "__main__":
             json.dump(out, f, indent=True)
         print('Log :\"',f.name, '\" saved !')
     '''
-    res_folder="res/brutus-tests/"
+    res_folder="res/brutus-tests2/"
     epochs= 150
     inner_its = [1]
     dist_mix = [0.0, 0.5, 0.8, 1.0]
     dataug_epoch_starts= [0]
     tf_dict = {k: TF.TF_dict[k] for k in tf_names}
     TF_nb = [len(tf_dict)] #range(10,len(TF.TF_dict)+1) #[len(TF.TF_dict)]
-    N_seq_TF= [2, 3]
-    mag_setup = [(True,True), (False, False)]
+    N_seq_TF= [2, 3, 4]
+    mag_setup = [(True,True), (False, False)] #(Fixed, Shared)
     #prob_setup = [True, False]
     nb_run= 3

@@ -118,12 +118,14 @@ if __name__ == "__main__":
                 #for i in TF_nb:
                 for m_setup in mag_setup:
                     #for p_setup in prob_setup:
-                    p_setup=True
+                    p_setup=False
                     for run in range(nb_run):
-                        if n_inner_iter == 0 and (m_setup!=(True,True) and p_setup!=True): continue #Other setups are useless without meta-optimization
+                        if (n_inner_iter == 0 and (m_setup!=(True,True) and p_setup!=True)) or (p_setup and dist!=0.0): continue #Other setups are useless without meta-optimization
                         #keys = list(TF.TF_dict.keys())[0:i]
                         #ntf_dict = {k: TF.TF_dict[k] for k in keys}

+                        t0 = time.process_time()
+
                         aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=n_tf, mix_dist=dist, fixed_prob=p_setup, fixed_mag=m_setup[0], shared_mag=m_setup[1]), model).to(device)
                         #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)

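Taken together, the two "brutus" hunks above enlarge the grid search (N_seq_TF gains a length-4 setting, p_setup flips to learned probabilities, and a new skip rule drops redundant configurations). A sketch of the resulting experiment grid, with names mirroring the diff but assembled here only for illustration:

import itertools

# grid values copied from the hunks above
inner_its = [1]
dist_mix = [0.0, 0.5, 0.8, 1.0]
N_seq_TF = [2, 3, 4]
mag_setup = [(True, True), (False, False)]  # (fixed_mag, shared_mag)
nb_run = 3
p_setup = False  # probabilities are learned in every run

runs = []
for n_inner_iter, dist, n_tf, m_setup, run in itertools.product(
        inner_its, dist_mix, N_seq_TF, mag_setup, range(nb_run)):
    # skip setups that are pointless without meta-optimization, and
    # fixed-probability setups with a non-zero mix distance
    if (n_inner_iter == 0 and (m_setup != (True, True) and p_setup != True)) \
            or (p_setup and dist != 0.0):
        continue
    runs.append((n_inner_iter, dist, n_tf, m_setup, run))

print(len(runs), "configurations")  # 4 * 3 * 2 * 3 = 72 with these values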
@@ -143,9 +145,9 @@ if __name__ == "__main__":
                         times = [x["time"] for x in log]
                         out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), 'Optimizer': optim_param, "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
                         print(str(aug_model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
-                        filename = "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)
+                        filename = "{}-{} epochs (dataug:{})- {} in_it-{}".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter, run)
-                        with open("res/log/%s.json" % filename, "w+") as f:
+                        with open(res_folder+"log/%s.json" % filename, "w+") as f:
                             json.dump(out, f, indent=True)
                             print('Log :\"',f.name, '\" saved !')
                         print('-'*9)
-    '''
+    #'''
@@ -19,8 +19,8 @@ tf_names = [
     'Color',
     'Brightness',
     'Sharpness',
-    'Posterize',
-    'Solarize', #=>Image in [0,1] #Not optimized for batches
+    #'Posterize',
+    #'Solarize', #=>Image in [0,1] #Not optimized for batches

     #Color TF (Common mag scale)
     #'+Contrast',
@@ -66,7 +66,7 @@ if __name__ == "__main__":
         #'aug_dataset',
         'aug_model'
     }
-    n_inner_iter = 10
+    n_inner_iter = 1
     epochs = 100
     dataug_epoch_start=0
     optim_param={
@@ -168,7 +168,7 @@ if __name__ == "__main__":
         t0 = time.process_time()

         tf_dict = {k: TF.TF_dict[k] for k in tf_names}
-        aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=2, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), model).to(device)
+        aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.8, fixed_prob=False, fixed_mag=False, shared_mag=False), model).to(device)
         #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)

         print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter))
@@ -187,7 +187,7 @@ if __name__ == "__main__":
         times = [x["time"] for x in log]
         out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), 'Optimizer': optim_param, "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
         print(str(aug_model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
-        filename = "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)
+        filename = "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)+"demi_mag"
         with open("res/log/%s.json" % filename, "w+") as f:
             json.dump(out, f, indent=True)
             print('Log :\"',f.name, '\" saved !')
@@ -90,6 +90,7 @@ def plot_resV2(log, fig_name='res', param_names=None):

         ax[0, 2].set_title('Mag =f(epoch)')
         ax[0, 2].stackplot(epochs, mag, labels=param_names)
+        #ax[0, 2].plot(epochs, np.array(mag).T, label=param_names)
         ax[0, 2].legend(param_names, loc='center left', bbox_to_anchor=(1, 0.5))

         ax[1, 2].set_title('Mag =f(TF)')
Binary file not shown.
@@ -31,12 +31,12 @@ tf_names = [
     'ShearY',

     ## Color TF (Expect image in the range of [0, 1]) ##
-    'Contrast',
-    'Color',
-    'Brightness',
-    'Sharpness',
-    'Posterize',
-    'Solarize', #=>Image in [0,1] #Not optimized for batches
+    #'Contrast',
+    #'Color',
+    #'Brightness',
+    #'Sharpness',
+    #'Posterize',
+    #'Solarize', #=>Image in [0,1] #Not optimized for batches
 ]

 class Lambda(nn.Module):
@@ -95,6 +95,7 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, mas

             unsupp_coeff = 1
             loss = sup_loss + (aug_loss + KL_loss) * unsupp_coeff
+            #print(sup_loss.item(), (aug_loss + KL_loss).item())

             optimizer.zero_grad()
             loss.backward()
@@ -210,7 +211,7 @@ def get_train_valid_loader(args, augment, random_seed, valid_size=0.1, shuffle=T
     split = int(np.floor(valid_size * num_train))

     if shuffle:
-        #np.random.seed(random_seed)
+        np.random.seed(random_seed)
         np.random.shuffle(indices)

     train_idx, valid_idx = indices[split:], indices[:split]
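Uncommenting the seed call makes the train/validation split reproducible: the index permutation is now deterministic across runs. A tiny illustration of the mechanism (values chosen arbitrarily here):

import numpy as np

indices = list(range(10))
np.random.seed(42)          # fixed seed => identical shuffle every run
np.random.shuffle(indices)
split = int(np.floor(0.1 * len(indices)))
train_idx, valid_idx = indices[split:], indices[:split]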
@@ -277,6 +278,8 @@ def main(args):
         model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)

         if args.augment=='RandKL': Kldiv=True
+
+        model['data_aug']['mag'].data = model['data_aug']['mag'].data * args.magnitude
         print("Augmodel")

     # model.fc = nn.Linear(model.fc.in_features, 2)
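The added line rescales the data-augmentation magnitudes by the new --magnitude argument once, right after the model is wrapped. A minimal sketch of the idea on a bare parameter (the Augmented_model indexing is assumed from the diff):

import torch

mag = torch.nn.Parameter(torch.tensor(5.0))  # stand-in for model['data_aug']['mag']
magnitude = 0.5                              # stand-in for args.magnitude

# in-place rescaling of the initial magnitude, applied once before training
mag.data = mag.data * magnitude
print(mag.item())  # 2.5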
@@ -294,7 +297,7 @@ def main(args):
         optimizer,
         lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9)

-    es = utils.EarlyStopping()
+    es = utils.EarlyStopping() if not (args.augment=='Rand' or args.augment=='RandKL') else utils.EarlyStopping(augmented_model=True)

     if args.test_only:
         model.load_state_dict(torch.load('checkpoint.pt', map_location=lambda storage, loc: storage))
@@ -324,8 +327,8 @@ def main(args):

         # print('Train')
         # print(train_confmat)
-        print('Valid')
-        print(valid_confmat)
+        #print('Valid')
+        #print(valid_confmat)

         # if es.early_stop:
         #     break
@@ -339,9 +342,9 @@ def parse_args():
     import argparse
     parser = argparse.ArgumentParser(description='PyTorch Classification Training')

-    parser.add_argument('--data-path', default='/Salvador', help='dataset')
+    parser.add_argument('--data-path', default='/github/smart_augmentation/salvador/data', help='dataset')
     parser.add_argument('--model', default='resnet18', help='model') #'resnet18'
-    parser.add_argument('--device', default='cuda:1', help='device')
+    parser.add_argument('--device', default='cuda:0', help='device')
     parser.add_argument('-b', '--batch-size', default=8, type=int)
     parser.add_argument('--epochs', default=3, type=int, metavar='N',
                         help='number of total epochs to run')
@@ -364,6 +367,10 @@ def parse_args():
     parser.add_argument('-a', '--augment', default='None', type=str,
                         metavar='N', help='Data augment',
                         dest='augment')
+    parser.add_argument('-m', '--magnitude', default=1.0, type=float,
+                        metavar='N', help='Augmentation magnitude',
+                        dest='magnitude')
+

     args = parser.parse_args()

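With the new flag, a run might be launched as follows (script name and argument values hypothetical; only -m/--magnitude is new in this commit):

python train.py -a RandKL -m 0.5 --epochs 3 --device cuda:0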
@@ -549,10 +549,10 @@ def parse_args():
     import argparse
     parser = argparse.ArgumentParser(description='PyTorch Classification Training')

-    parser.add_argument('--data-path', default='/Salvador', help='dataset')
-    parser.add_argument('--model', default='resnet50', help='model')
-    parser.add_argument('--device', default='cuda:1', help='device')
-    parser.add_argument('-b', '--batch-size', default=4, type=int)
+    parser.add_argument('--data-path', default='/github/smart_augmentation/salvador/data', help='dataset')
+    parser.add_argument('--model', default='resnet18', help='model') #'resnet18'
+    parser.add_argument('--device', default='cuda:0', help='device')
+    parser.add_argument('-b', '--batch-size', default=8, type=int)
     parser.add_argument('--epochs', default=3, type=int, metavar='N',
                         help='number of total epochs to run')
     parser.add_argument('-j', '--workers', default=0, type=int, metavar='N',
@@ -157,7 +157,7 @@ def accuracy(output, target, topk=(1,)):

 class EarlyStopping:
     """Early stops the training if validation loss doesn't improve after a given patience."""
-    def __init__(self, patience=7, verbose=False, delta=0):
+    def __init__(self, patience=7, verbose=False, delta=0, augmented_model=False):
         """
         Args:
             patience (int): How long to wait after last time validation loss improved.
@@ -175,6 +175,8 @@ class EarlyStopping:
         self.val_loss_min = np.Inf
         self.delta = delta

+        self.augmented_model = augmented_model
+
     def __call__(self, val_loss, model):

         score = -val_loss
@@ -196,5 +198,5 @@ class EarlyStopping:
         '''Saves model when validation loss decrease.'''
         if self.verbose:
             print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
-        torch.save(model.state_dict(), 'checkpoint.pt')
+        torch.save(model.state_dict(), 'checkpoint.pt') if not self.augmented_model else torch.save(model['model'].state_dict(), 'checkpoint.pt')
         self.val_loss_min = val_loss
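Taken together, the three utils.py hunks thread an augmented_model flag through EarlyStopping so that only the inner network's weights are checkpointed when the model is wrapped in an Augmented_model. A condensed, hedged sketch of the changed class showing only the touched parts (the patience/counter logic of the real class is elided):

import numpy as np
import torch

class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, patience=7, verbose=False, delta=0, augmented_model=False):
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.val_loss_min = np.inf
        self.augmented_model = augmented_model  # new: wrapper-aware checkpointing

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decreases.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
        # save only the inner network when model is an Augmented_model wrapper
        state = model['model'].state_dict() if self.augmented_model else model.state_dict()
        torch.save(state, 'checkpoint.pt')
        self.val_loss_min = val_loss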