diff --git a/higher/datasets.py b/higher/datasets.py
index 17be0ff..7d0589f 100644
--- a/higher/datasets.py
+++ b/higher/datasets.py
@@ -38,8 +38,8 @@ data_test = torchvision.datasets.CIFAR10(
     "./data", train=False, download=True, transform=transform
 )
 #'''
-#train_subset_indices=range(int(len(data_train)/2))
-train_subset_indices=range(BATCH_SIZE*10)
+train_subset_indices=range(int(len(data_train)/2))
+#train_subset_indices=range(BATCH_SIZE*10)
 
 val_subset_indices=range(int(len(data_train)/2),len(data_train))
 dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices))
diff --git a/higher/dataug.py b/higher/dataug.py
index 1726936..b80de85 100644
--- a/higher/dataug.py
+++ b/higher/dataug.py
@@ -114,7 +114,7 @@ class Data_augV2(nn.Module): #Exact method
 
         return kornia.warp_affine(x, M, dsize=(x.shape[2], x.shape[3])) #dsize=(h, w)
 
-    def adjust_prob(self): #Detach from gradient ?
+    def adjust_param(self): #Detach from gradient ?
         self._params['prob'].data = self._params['prob'].clamp(min=0.0,max=1.0)
         #print('proba',self._params['prob'])
         self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Constraint: sum(p)=1
@@ -262,7 +262,7 @@ class Data_augV3(nn.Module): #Uniform/mixed sampling
 
         # warp the original image by the found transform
         return kornia.warp_perspective(x, M, dsize=(h, w))
 
-    def adjust_prob(self, soft=False): #Detach from gradient ?
+    def adjust_param(self, soft=False): #Detach from gradient ?
 
         if soft :
             self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Too 'soft', gets stuck in a uniform dist if lr is too low
@@ -478,7 +478,7 @@ class Data_augV4(nn.Module): #Transformations with mask
         '''
         return x
 
-    def adjust_prob(self, soft=False): #Detach from gradient ?
+    def adjust_param(self, soft=False): #Detach from gradient ?
 
         if soft :
             self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Too 'soft', gets stuck in a uniform dist if lr is too low
@@ -549,15 +549,22 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)
         self._params = nn.ParameterDict({
             "prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Uniform prob distribution
             "mag" : nn.Parameter(torch.tensor(0.5) if self._shared_mag
-                else torch.tensor(0.5).expand(self._nb_tf)), #[0, PARAMETER_MAX]/10
+                else torch.tensor(0.5).expand(self._nb_tf)), #[0, PARAMETER_MAX]
         })
 
-        self._samples = []
+        #Distribution
+        self._samples = []
         self._mix_dist = False
         if mix_dist != 0.0:
             self._mix_dist = True
             self._mix_factor = max(min(mix_dist, 1.0), 0.0)
 
+        #Mag regularization
+        if not self._fixed_mag:
+            ignore={'Identity', 'FlipUD', 'FlipLR', 'Solarize', 'Posterize'}
+            self._reg_mask=[self._TF.index(t) for t in self._TF if t not in ignore]
+            self._reg_tgt = torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage max amplitude
+
     def forward(self, x):
         if self._data_augmentation:
             device = x.device
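Note: the adjust_prob -> adjust_param rename reflects that the method now re-projects the magnitudes as well as the probabilities. The hard (non-soft) branch is a simple projection back onto the probability simplex; a minimal standalone sketch of that projection (the helper name project_prob is hypothetical):

    import torch
    import torch.nn.functional as F

    def project_prob(prob, soft=False):
        #Standalone version of the adjust_param() projection above
        if soft:
            #Softmax re-projection: smooth, but can stall near uniform when lr is small
            return F.softmax(prob, dim=0)
        #Hard re-projection: zero out negatives, then rescale so sum(p)=1
        prob = F.relu(prob)
        return prob / prob.sum()

    p = torch.tensor([0.4, -0.1, 0.7])
    print(project_prob(p))             #tensor([0.3636, 0.0000, 0.6364])
    print(project_prob(p, soft=True))  #smoother distribution, all entries > 0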
@@ -610,18 +617,17 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)
 
         return x
 
-    def adjust_prob(self, soft=False): #Detach from gradient ?
+    def adjust_param(self, soft=False): #Detach from gradient ?
         if soft :
             self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Too 'soft', gets stuck in a uniform dist if lr is too low
         else:
-            #self._params['prob'].clamp(min=0.0,max=1.0)
             self._params['prob'].data = F.relu(self._params['prob'].data)
             #self._params['prob'].data = self._params['prob'].clamp(min=0.0,max=1.0)
-            self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Constraint: sum(p)=1
-
+            #self._params['mag'].data = self._params['mag'].data.clamp(min=0.0,max=TF.PARAMETER_MAX) #Gets stuck once it hits an extreme
+            self._params['mag'].data = F.relu(self._params['mag'].data) - F.relu(self._params['mag'].data - TF.PARAMETER_MAX)
 
     def loss_weight(self): # single TF only
@@ -642,6 +648,9 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)
             w_loss = torch.sum(w_loss,dim=1)
         return w_loss
 
+    def reg_loss(self, reg_factor=0.005):
+        #return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean')
+        return reg_factor * F.mse_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt.to(self._params['mag'].device), reduction='mean')
 
     def train(self, mode=None):
         if mode is None :
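Note: both updates act on .data, so they only reset values and do not shape gradients. Value-wise, the double ReLU added above is identical to clamp(min=0.0, max=TF.PARAMETER_MAX), and reg_loss() is a plain MSE pull of the masked magnitudes toward PARAMETER_MAX. A sketch checking the identity and the regularizer (PARAMETER_MAX=1.0 and the reg_mask indices are illustrative assumptions):

    import torch
    import torch.nn.functional as F

    PARAMETER_MAX = 1.0  #assumed value; see TF.PARAMETER_MAX in the repo

    mag = torch.tensor([-0.2, 0.3, 1.5])
    #The double ReLU equals a hard clamp to [0, PARAMETER_MAX] element-wise
    assert torch.equal(F.relu(mag) - F.relu(mag - PARAMETER_MAX),
                       mag.clamp(min=0.0, max=PARAMETER_MAX))

    #MSE regularizer on a subset of magnitudes, mirroring _reg_mask/_reg_tgt
    reg_mask = torch.tensor([1, 2])  #hypothetical indices of magnitude-sensitive TFs
    reg_tgt = torch.full((len(reg_mask),), PARAMETER_MAX)
    reg = 0.005 * F.mse_loss(mag[reg_mask], target=reg_tgt, reduction='mean')
    print(reg)  #penalizes any distance of the selected magnitudes from PARAMETER_MAX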
diff --git a/higher/res/Aug_mod(Data_augV4(Uniform-1 TF)-LeNet)-2 epochs (dataug:0)- 0 in_it.png b/higher/res/Aug_mod(Data_augV4(Uniform-1 TF)-LeNet)-2 epochs (dataug:0)- 0 in_it.png
deleted file mode 100644
index 27234e7..0000000
Binary files a/higher/res/Aug_mod(Data_augV4(Uniform-1 TF)-LeNet)-2 epochs (dataug:0)- 0 in_it.png and /dev/null differ
diff --git a/higher/res/Aug_mod(Data_augV4(Uniform-2 TF)-LeNet)-2 epochs (dataug:0)- 0 in_it.png b/higher/res/Aug_mod(Data_augV4(Uniform-2 TF)-LeNet)-2 epochs (dataug:0)- 0 in_it.png
deleted file mode 100644
index cd401f8..0000000
Binary files a/higher/res/Aug_mod(Data_augV4(Uniform-2 TF)-LeNet)-2 epochs (dataug:0)- 0 in_it.png and /dev/null differ
diff --git a/higher/res/log/Aug_mod(Data_augV4(Uniform-14 TF x 2)-LeNet)-2 epochs (dataug:0)- 10 in_it.json b/higher/res/log/Aug_mod(Data_augV4(Uniform-14 TF x 2)-LeNet)-2 epochs (dataug:0)- 10 in_it.json
deleted file mode 100644
index 334d4ce..0000000
--- a/higher/res/log/Aug_mod(Data_augV4(Uniform-14 TF x 2)-LeNet)-2 epochs (dataug:0)- 10 in_it.json
+++ /dev/null
@@ -1,72 +0,0 @@
-{
- "Accuracy": 20.8,
- "Time": [
-  51.4427050715,
-  0.4778038694999971
- ],
- "Device": "TITAN RTX",
- "Param_names": [
-  "Identity",
-  "FlipUD",
-  "FlipLR",
-  "Rotate",
-  "TranslateX",
-  "TranslateY",
-  "ShearX",
-  "ShearY",
-  "Contrast",
-  "Color",
-  "Brightness",
-  "Sharpness",
-  "Posterize",
-  "Solarize"
- ],
- "Log": [
-  {
-   "epoch": 1,
-   "train_loss": 2.3032476902008057,
-   "val_loss": 2.2924728393554688,
-   "acc": 11.1,
-   "time": 51.920508941,
-   "param": [
-    0.07925213128328323,
-    0.08312409371137619,
-    0.08779778331518173,
-    0.0853320062160492,
-    0.08577536046504974,
-    0.057290591299533844,
-    0.0774931013584137,
-    0.08246791362762451,
-    0.047001805156469345,
-    0.07887403666973114,
-    0.05897113308310509,
-    0.05021947622299194,
-    0.07581018656492233,
-    0.050590354949235916
-   ]
-  },
-  {
-   "epoch": 2,
-   "train_loss": 2.171858787536621,
-   "val_loss": 2.078795909881592,
-   "acc": 20.8,
-   "time": 50.96490120200001,
-   "param": [
-    0.07892196625471115,
-    0.07488056272268295,
-    0.08041033148765564,
-    0.09144628793001175,
-    0.09114645421504974,
-    0.055715303868055344,
-    0.0672164335846901,
-    0.07994510233402252,
-    0.05105787515640259,
-    0.09191103279590607,
-    0.07849953323602676,
-    0.07014491409063339,
-    0.07624118775129318,
-    0.012463102117180824
-   ]
-  }
- ]
-}
\ No newline at end of file
diff --git a/higher/res/log/Aug_mod(Data_augV4(Uniform-14 TF x 2)-LeNet)-3 epochs (dataug:0)- 0 in_it.json b/higher/res/log/Aug_mod(Data_augV4(Uniform-14 TF x 2)-LeNet)-3 epochs (dataug:0)- 0 in_it.json
deleted file mode 100644
index 17859fd..0000000
--- a/higher/res/log/Aug_mod(Data_augV4(Uniform-14 TF x 2)-LeNet)-3 epochs (dataug:0)- 0 in_it.json
+++ /dev/null
@@ -1,95 +0,0 @@
-{
- "Accuracy": 31.369999999999997,
- "Time": [
-  38.67262149066667,
-  0.4140408795968137
- ],
- "Device": "TITAN RTX",
- "Param_names": [
-  "Identity",
-  "FlipUD",
-  "FlipLR",
-  "Rotate",
-  "TranslateX",
-  "TranslateY",
-  "ShearX",
-  "ShearY",
-  "Contrast",
-  "Color",
-  "Brightness",
-  "Sharpness",
-  "Posterize",
-  "Solarize"
- ],
- "Log": [
-  {
-   "epoch": 1,
-   "train_loss": 2.2571041584014893,
-   "val_loss": 2.212921142578125,
-   "acc": 20.169999999999998,
-   "time": 38.788926192000005,
-   "param": [
-    0.0714285746216774,
-    0.0714285746216774,
-    0.0714285746216774,
-    0.0714285746216774,
-    0.0714285746216774,
-    0.0714285746216774,
-    0.0714285746216774,
-    0.0714285746216774,
-    0.0714285746216774,
-    0.0714285746216774,
-    0.0714285746216774,
-    0.0714285746216774,
-    0.0714285746216774,
-    0.0714285746216774
-   ]
-  },
-  {
-   "epoch": 2,
-   "train_loss": 2.212834358215332,
-   "val_loss": 2.043567180633545,
-   "acc": 25.009999999999998,
-   "time": 38.117478509,
-   "param": [
-    0.0714285746216774,
-    0.0714285746216774,
-    0.0714285746216774,
-    0.0714285746216774,
-    0.0714285746216774,
-    0.0714285746216774,
-    0.0714285746216774,
-    0.0714285746216774,
-    0.0714285746216774,
-    0.0714285746216774,
-    0.0714285746216774,
-    0.0714285746216774,
-    0.0714285746216774,
-    0.0714285746216774
-   ]
-  },
-  {
-   "epoch": 3,
-   "train_loss": 2.091825008392334,
-   "val_loss": 1.9359350204467773,
-   "acc": 31.369999999999997,
-   "time": 39.111459771,
-   "param": [
-    0.0714285746216774,
-    0.0714285746216774,
-    0.0714285746216774,
-    0.0714285746216774,
-    0.0714285746216774,
-    0.0714285746216774,
-    0.0714285746216774,
-    0.0714285746216774,
-    0.0714285746216774,
-    0.0714285746216774,
-    0.0714285746216774,
-    0.0714285746216774,
-    0.0714285746216774,
-    0.0714285746216774
-   ]
-  }
- ]
-}
\ No newline at end of file
diff --git a/higher/res/log/LeNet-3 epochs.json b/higher/res/log/LeNet-3 epochs.json
deleted file mode 100644
index c5addf9..0000000
--- a/higher/res/log/LeNet-3 epochs.json
+++ /dev/null
@@ -1,34 +0,0 @@
-{
- "Accuracy": 39.2,
- "Time": [
-  3.9452463850000012,
-  0.2891758564900622
- ],
- "Device": "TITAN RTX",
- "Log": [
-  {
-   "epoch": 0,
-   "train_loss": 2.109266757965088,
-   "val_loss": 2.1106348037719727,
-   "acc": 22.3,
-   "time": 4.312489993,
-   "param": null
-  },
-  {
-   "epoch": 1,
-   "train_loss": 1.7782783508300781,
-   "val_loss": 1.8776130676269531,
-   "acc": 33.76,
-   "time": 3.605794182000002,
-   "param": null
-  },
-  {
-   "epoch": 2,
-   "train_loss": 1.8152618408203125,
-   "val_loss": 1.6963396072387695,
-   "acc": 39.2,
-   "time": 3.9174549800000023,
-   "param": null
-  }
- ]
-}
\ No newline at end of file
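Note: the files removed above are per-run result logs written by test_dataug.py (json.dump(out, f, indent=True)), presumably stale outputs of the previous configuration. For reference, a minimal sketch of reading such a log back (the path is illustrative):

    import json
    import numpy as np

    with open("res/log/LeNet-3 epochs.json") as f:
        run = json.load(f)

    log = run["Log"]
    print("best acc:", max(entry["acc"] for entry in log))  #matches run["Accuracy"]
    times = [entry["time"] for entry in log]
    print("epoch time: %.2f +/- %.2f" % (np.mean(times), np.std(times)))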
diff --git a/higher/test_dataug.py b/higher/test_dataug.py
index 66e9665..56c31a6 100644
--- a/higher/test_dataug.py
+++ b/higher/test_dataug.py
@@ -5,9 +5,9 @@ from train_utils import *
 
 tf_names = [
     ## Geometric TF ##
-    #'Identity',
-    #'FlipUD',
-    #'FlipLR',
+    'Identity',
+    'FlipUD',
+    'FlipLR',
     'Rotate',
     'TranslateX',
     'TranslateY',
@@ -37,8 +37,8 @@ else:
 ##########################################
 if __name__ == "__main__":
 
-    n_inner_iter = 1
-    epochs = 2
+    n_inner_iter = 10
+    epochs = 200
     dataug_epoch_start=0
 
     #### Classic ####
@@ -57,7 +57,7 @@ if __name__ == "__main__":
         print('-'*9)
         times = [x["time"] for x in log]
         out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times)), "Device": device_name, "Log": log}
-        print(str(model),": acc", out["Accuracy"], "in (ms):", out["Time"][0], "+/-", out["Time"][1])
+        print(str(model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
         with open("res/log/%s.json" % "{}-{} epochs".format(str(model),epochs), "w+") as f:
             json.dump(out, f, indent=True)
             print('Log :\"',f.name, '\" saved !')
@@ -68,7 +68,7 @@ if __name__ == "__main__":
         t0 = time.process_time()
         tf_dict = {k: TF.TF_dict[k] for k in tf_names}
         #tf_dict = TF.TF_dict
-        aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=1, mix_dist=0.5, fixed_mag=False, shared_mag=True), LeNet(3,10)).to(device)
+        aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=2, mix_dist=0.5, fixed_mag=False, shared_mag=False), LeNet(3,10)).to(device)
         #aug_model = Augmented_model(Data_augV4(TF_dict=tf_dict, N_TF=2, mix_dist=0.0), WideResNet(num_classes=10, wrn_size=160)).to(device)
         print(str(aug_model), 'on', device_name)
         #run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
@@ -79,12 +79,13 @@ if __name__ == "__main__":
         print('-'*9)
         times = [x["time"] for x in log]
         out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times)), "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
-        print(str(aug_model),": acc", out["Accuracy"], "in (s?):", out["Time"][0], "+/-", out["Time"][1])
+        print(str(aug_model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
         with open("res/log/%s.json" % "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter), "w+") as f:
             json.dump(out, f, indent=True)
             print('Log :\"',f.name, '\" saved !')
 
-        print('Execution Time : %.00f (s?)'%(time.process_time() - t0))
+        print('TF influence', TF_influence(log))
+        print('Execution Time : %.00f '%(time.process_time() - t0))
         print('-'*9)
     #'''
     #### TF number tests ####
diff --git a/higher/train_utils.py b/higher/train_utils.py
index df07326..81b4fd3 100644
--- a/higher/train_utils.py
+++ b/higher/train_utils.py
@@ -528,7 +528,7 @@ def run_dist_dataug(model, epochs=1, inner_it=1, dataug_epoch_start=0):
 
                 optim_copy(dopt=diffopt, opt=inner_opt)
                 meta_opt.step()
-                model['data_aug'].adjust_prob() #Constraint: sum(proba)=1
+                model['data_aug'].adjust_param() #Constraint: sum(proba)=1
 
     print("Copy ", countcopy)
     return log
@@ -588,7 +588,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
                 loss = F.cross_entropy(logits, ys, reduction='none')  # no need to call loss.backwards()
 
                 if fmodel._data_augmentation: #Weight loss
-                    w_loss = fmodel['data_aug'].loss_weight().to(device)
+                    w_loss = fmodel['data_aug'].loss_weight()#.to(device)
                     loss = loss * w_loss
                 loss = loss.mean()
                 #'''
@@ -605,7 +605,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
 
                 if(high_grad_track and i%inner_it==0): #Perform Meta step
                     #print("meta") #Of little use if high_grad_track = False
-                    val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val)
+                    val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val) + fmodel['data_aug'].reg_loss()
 
                     #print_graph(val_loss)
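Note: with this change the meta step differentiates val_loss + reg_loss(), so the pull toward PARAMETER_MAX enters through the outer (meta) gradient on 'mag' rather than the inner updates. A toy sketch of the effect, with a quadratic standing in for compute_vaLoss (values are illustrative):

    import torch
    import torch.nn.functional as F

    PARAMETER_MAX = 1.0  #assumed
    mag = torch.tensor([0.5, 0.5], requires_grad=True)
    val_loss = (mag ** 2).sum()  #stand-in for compute_vaLoss(...)
    reg = 0.005 * F.mse_loss(mag, torch.full_like(mag, PARAMETER_MAX))
    (val_loss + reg).backward()
    print(mag.grad)  #validation gradient 2*mag plus a small -0.0025 pull upward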
@@ -616,15 +616,15 @@
                     optim_copy(dopt=diffopt, opt=inner_opt)
                     meta_opt.step()
 
-                    model['data_aug'].adjust_prob(soft=False) #Constraint: sum(proba)=1
+                    model['data_aug'].adjust_param(soft=False) #Constraint: sum(proba)=1
 
                     fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
                     diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
 
         tf = time.process_time()
 
-        viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
-        viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
+        #viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
+        #viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
 
         if(not high_grad_track):
             countcopy+=1
@@ -643,7 +643,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
         if(print_freq and epoch%print_freq==0):
             print('-'*9)
             print('Epoch : %d/%d'%(epoch,epochs))
-            print('Time : %.00f s'%(tf - t0))
+            print('Time : %.00f'%(tf - t0))
             print('Train loss :',loss.item(), '/ val loss', val_loss.item())
             print('Accuracy :', accuracy)
             print('Data Augmentation : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start))
@@ -651,6 +651,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
             #print('proba grad',model['data_aug']['prob'].grad)
             print('TF Mag :', model['data_aug']['mag'].data)
             #print('Mag grad',model['data_aug']['mag'].grad)
+            print('Reg loss:', model['data_aug'].reg_loss().item())
         #############
         #### Log ####
         #print(type(model['data_aug']) is dataug.Data_augV5)
diff --git a/higher/utils.py b/higher/utils.py
index f1a9630..7df741a 100644
--- a/higher/utils.py
+++ b/higher/utils.py
@@ -254,6 +254,11 @@ def print_torch_mem(add_info=''):
         torch.cuda.max_memory_cached()/ mega_bytes)
     print(string)
 
+def TF_influence(log):
+    proba=[[x["param"][idx]['p'] for x in log] for idx, _ in enumerate(log[0]["param"])]
+    mag=[[x["param"][idx]['m'] for x in log] for idx, _ in enumerate(log[0]["param"])]
+
+    return np.mean(proba, axis=1)*np.mean(mag, axis=1) #Could be interesting to multiply before the mean
 
 class loss_monitor(): #See https://github.com/pytorch/ignite
     def __init__(self, patience, end_train=1):
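Note: TF_influence summarizes each transformation as mean probability x mean magnitude over the run. It assumes the newer log format where each "param" entry is a dict with 'p' and 'm' keys, unlike the plain probability lists in the logs deleted above. A small worked example with toy values:

    import numpy as np

    log = [
        {"param": [{'p': 0.6, 'm': 0.5}, {'p': 0.4, 'm': 0.9}]},  #epoch 1
        {"param": [{'p': 0.7, 'm': 0.7}, {'p': 0.3, 'm': 0.9}]},  #epoch 2
    ]
    proba = [[x["param"][idx]['p'] for x in log] for idx, _ in enumerate(log[0]["param"])]
    mag   = [[x["param"][idx]['m'] for x in log] for idx, _ in enumerate(log[0]["param"])]
    print(np.mean(proba, axis=1) * np.mean(mag, axis=1))  #[0.39  0.315]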