From cd4b0405b989e7601c8d628fcdfc9c1aab12546a Mon Sep 17 00:00:00 2001
From: "Harle, Antoine (Contracteur)"
Date: Thu, 16 Jan 2020 16:38:15 -0500
Subject: [PATCH] Add optimizer hyper-parameter learning + mix dist

---
 higher/dataug.py      | 21 ++++++++++++---
 higher/test_dataug.py | 27 ++++++++++++--------
 higher/train_utils.py | 59 ++++++++++++++++++++++++++-----------------
 3 files changed, 70 insertions(+), 37 deletions(-)

diff --git a/higher/dataug.py b/higher/dataug.py
index 67009f2..466a8e5 100755
--- a/higher/dataug.py
+++ b/higher/dataug.py
@@ -545,11 +545,17 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
         self._shared_mag = shared_mag
         self._fixed_mag = fixed_mag
+        self._fixed_mix=True
+        if mix_dist is None: #Learn Mix dist
+            self._fixed_mix = False
+            mix_dist=0.5
+
         init_mag = float(TF.PARAMETER_MAX) if self._fixed_mag else float(TF.PARAMETER_MAX)/2
         self._params = nn.ParameterDict({
             "prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Distribution prob uniforme
             "mag" : nn.Parameter(torch.tensor(init_mag) if self._shared_mag else torch.tensor(init_mag).repeat(self._nb_tf)), #[0, PARAMETER_MAX]
+            "mix_dist": nn.Parameter(torch.tensor(mix_dist).clamp(min=0.0,max=0.999))
         })

         for tf in TF.TF_no_grad :
@@ -560,9 +566,9 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
         self._fixed_prob=fixed_prob
         self._samples = []
         self._mix_dist = False
-        if mix_dist != 0.0:
+        if mix_dist != 0.0: #Mix dist
             self._mix_dist = True
-            self._mix_factor = max(min(mix_dist, 0.999), 0.0)
+            #self._mix_factor = max(min(mix_dist, 0.999), 0.0)

         #Mag regularisation
         if not self._fixed_mag:
@@ -588,7 +594,9 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
                 self._distrib = uniforme_dist
             else:
                 prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
-                self._distrib = (self._mix_factor*prob+(1-self._mix_factor)*uniforme_dist)#.softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor
+                mix_dist = self._params["mix_dist"].detach() if self._fixed_mix else self._params["mix_dist"]
+                #self._distrib = (self._mix_factor*prob+(1-self._mix_factor)*uniforme_dist)#.softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor
+                self._distrib = (mix_dist*prob+(1-mix_dist)*uniforme_dist)#.softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor

             cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*self._distrib)
             sample = cat_distrib.sample()
@@ -638,6 +646,9 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
             self._params['mag'].data = self._params['mag'].data.clamp(min=TF.PARAMETER_MIN, max=TF.PARAMETER_MAX)
             #self._params['mag'].data = F.relu(self._params['mag'].data) - F.relu(self._params['mag'].data - TF.PARAMETER_MAX)

+            if not self._fixed_mix:
+                self._params['mix_dist'].data = self._params['mix_dist'].data.clamp(min=0.0, max=0.999)
+
     def loss_weight(self):
         if len(self._samples)==0 : return 1 #Pas d'echantillon = pas de ponderation
@@ -692,8 +703,10 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
         if self._shared_mag: mag_param+= 'Sh'
         if not self._mix_dist:
             return "Data_augV5(Uniform%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
+        elif self._fixed_mix:
+            return "Data_augV5(Mix%.1f%s-%dTFx%d-%s)" % (self._params['mix_dist'].item(),dist_param, self._nb_tf, self._N_seqTF, mag_param)
         else:
-            return "Data_augV5(Mix%.1f%s-%dTFx%d-%s)" % (self._mix_factor,dist_param, self._nb_tf, self._N_seqTF, mag_param)
+            return "Data_augV5(Mix%s-%dTFx%d-%s)" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)

 class Data_augV6(nn.Module): #Optimisation sequentielle #Mauvais resultats
diff --git a/higher/test_dataug.py b/higher/test_dataug.py
index 4f08fa9..97a03d4 100755
--- a/higher/test_dataug.py
+++ b/higher/test_dataug.py
@@ -67,7 +67,7 @@ if __name__ == "__main__":
         'aug_model'
     }
     n_inner_iter = 1
-    epochs = 15
+    epochs = 150
     dataug_epoch_start=0
     optim_param={
         'Meta':{
@@ -81,9 +81,10 @@
         }
     }

-    #model = LeNet(3,10)
+    model = LeNet(3,10)
+    #model = ResNet(num_classes=10)
+    #Lents
     #model = MobileNetV2(num_classes=10)
-    model = ResNet(num_classes=10)
     #model = WideResNet(num_classes=10, wrn_size=32)

     model = Higher_model(model) #run_dist_dataugV3
@@ -94,8 +95,8 @@
         model = model.to(device)
         print("{} on {} for {} epochs".format(str(model), device_name, epochs))

-        #log= train_classic(model=model, opt_param=optim_param, epochs=epochs, print_freq=10)
-        log= train_classic_higher(model=model, epochs=epochs)
+        log= train_classic(model=model, opt_param=optim_param, epochs=epochs, print_freq=1)
+        #log= train_classic_higher(model=model, epochs=epochs)

         exec_time=time.process_time() - t0
         ####
@@ -181,6 +182,7 @@
             opt_param=optim_param,
             print_freq=1,
             KLdiv=True,
+            hp_opt=True,
             loss_patience=None)

         exec_time=time.process_time() - t0
@@ -189,12 +191,17 @@
         times = [x["time"] for x in log]
         out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), 'Optimizer': optim_param, "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
         print(str(aug_model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
-        filename = "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)#+"demi_mag"
+        filename = "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)+"-opt_hp"
         with open("res/log/%s.json" % filename, "w+") as f:
-            json.dump(out, f, indent=True)
-            print('Log :\"',f.name, '\" saved !')
-
-        plot_resV2(log, fig_name="res/"+filename, param_names=aug_model.TF_names())
+            try:
+                json.dump(out, f, indent=True)
+                print('Log :\"',f.name, '\" saved !')
+            except:
+                print("Failed to save logs :",f.name)
+        try:
+            plot_resV2(log, fig_name="res/"+filename, param_names=aug_model.TF_names())
+        except:
+            print("Failed to plot res")

         print('Execution Time : %.00f '%(exec_time))
         print('-'*9)
\ No newline at end of file
diff --git a/higher/train_utils.py b/higher/train_utils.py
index 8a64c25..fd9b9da 100755
--- a/higher/train_utils.py
+++ b/higher/train_utils.py
@@ -70,14 +70,8 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):
         #### Tests ####
         tf = time.process_time()
-        try:
-            xs_val, ys_val = next(dl_val_it)
-        except StopIteration: #Fin epoch val
-            dl_val_it = iter(dl_val)
-            xs_val, ys_val = next(dl_val_it)
-        xs_val, ys_val = xs_val.to(device), ys_val.to(device)

-        val_loss = F.cross_entropy(model(xs_val), ys_val)
+        val_loss = compute_vaLoss(model=model, dl_it=dl_val_it, dl=dl_val)
         accuracy, _ =test(model)
         model.train()
@@ -656,6 +650,8 @@ def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
     fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
     diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)

+    meta_opt.zero_grad()
+
     for epoch in range(1, epochs+1):
         #print_torch_mem("Start epoch "+str(epoch))
         #print(high_grad_track, fmodel._data_augmentation, len(fmodel._fast_params))
@@ -755,6 +751,8 @@ def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
                 fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
                 diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)

+                meta_opt.zero_grad()
+
         tf = time.process_time()

         #viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
@@ -825,17 +823,13 @@ def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
     #print("Copy ", countcopy)
     return log

-def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start=0, print_freq=1, KLdiv=False, loss_patience=None, save_sample=False):
+def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start=0, print_freq=1, KLdiv=False, hp_opt=False, loss_patience=None, save_sample=False):
     device = next(model.parameters()).device
     log = []
     countcopy=0
     val_loss=torch.tensor(0) #Necessaire si pas de metastep sur une epoch
     dl_val_it = iter(dl_val)

-    #if inner_it!=0:
-    meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=opt_param['Meta']['lr']) #lr=1e-2
-    inner_opt = torch.optim.SGD(model['model']['original'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
-
     high_grad_track = True
     if inner_it == 0:
         high_grad_track=False
@@ -848,22 +842,28 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
     if dataug_epoch_start==-1: val_loss_monitor = loss_monitor(patience=loss_patience, end_train=2) #1st limit = dataug start
     else: val_loss_monitor = loss_monitor(patience=loss_patience) #Val loss monitor (Not on val data : used by Dataug... => Test data)
-    model.train()
-
-    #fmodel = higher.patch.monkeypatch(model['model'], device=None, copy_initial_weights=True)
-    #diffopt = higher.optim.get_diff_optim(inner_opt, model['model'].parameters(),fmodel=fmodel,track_higher_grads=high_grad_track)
-    #fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
-    #diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel,track_higher_grads=high_grad_track)
+    ## Optimizers ##
+    #Inner Opt
+    inner_opt = torch.optim.SGD(model['model']['original'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
     diffopt = model['model'].get_diffopt(
         inner_opt,
         grad_callback=(lambda grads: clip_norm(grads, max_norm=10)),
         track_higher_grads=high_grad_track)

-    #meta_opt = torch.optim.Adam(fmodel['data_aug'].parameters(), lr=opt_param['Meta']['lr']) #lr=1e-2
-
+    #Meta Opt
+    hyper_param = list(model['data_aug'].parameters())
+    if hp_opt :
+        for param_group in diffopt.param_groups:
+            for param in list(opt_param['Inner'].keys())[1:]:
+                param_group[param]=torch.tensor(param_group[param]).to(device).requires_grad_()
+                hyper_param += [param_group[param]]
+    meta_opt = torch.optim.Adam(hyper_param, lr=opt_param['Meta']['lr']) #lr=1e-2

     #print(len(model['model']['functional']._fast_params))

+    model.train()
+    meta_opt.zero_grad()
+
     for epoch in range(1, epochs+1):
         #print_torch_mem("Start epoch "+str(epoch))
         #print(high_grad_track, fmodel._data_augmentation, len(fmodel._fast_params))
@@ -919,9 +919,9 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
             #print(fmodel['model']._params['b4'].grad)
             #print('prob grad', fmodel['data_aug']['prob'].grad)

-            t = time.process_time()
+            #t = time.process_time()
             diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)
-            print(len(model['model']['functional']._fast_params),"step", time.process_time()-t)
+            #print(len(model['model']['functional']._fast_params),"step", time.process_time()-t)

             if(high_grad_track and i>0 and i%inner_it==0): #Perform Meta step
@@ -937,8 +937,15 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
                 meta_opt.step()
                 model['data_aug'].adjust_param(soft=False) #Contrainte sum(proba)=1

+                if hp_opt:
+                    for param_group in diffopt.param_groups:
+                        for param in list(opt_param['Inner'].keys())[1:]:
+                            param_group[param].data = param_group[param].data.clamp(min=1e-4)
+
                 diffopt.detach_()
                 model['model'].detach_()
+
+                meta_opt.zero_grad()

         tf = time.process_time()
@@ -963,9 +970,11 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
             "acc": accuracy,
             "time": tf - t0,

-            "param": param #if isinstance(model['data_aug'], Data_augV5)
+            "mix_dist": model['data_aug']['mix_dist'].item(),
+            "param": param, #if isinstance(model['data_aug'], Data_augV5)
             #else [p.item() for p in model['data_aug']['prob']],
         }
+        if hp_opt : data["opt_param"]=[{'lr': p_grp['lr'].item(), 'momentum': p_grp['momentum'].item()} for p_grp in diffopt.param_groups]
         log.append(data)
         #############
         #### Print ####
@@ -980,8 +989,12 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
             #print('proba grad',model['data_aug']['prob'].grad)
             print('TF Mag :', model['data_aug']['mag'].data)
             #print('Mag grad',model['data_aug']['mag'].grad)
+            print('Mix:', model['data_aug']['mix_dist'].data)
             #print('Reg loss:', model['data_aug'].reg_loss().item())
             #print('Aug loss', aug_loss.item())
+            if hp_opt :
+                for param_group in diffopt.param_groups:
+                    print('Opt param - lr:', param_group['lr'].item(),'- momentum:', param_group['momentum'].item())
         #############
         if val_loss_monitor :
             model.eval()
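
The dataug.py part of this patch turns the mixing coefficient between the learned transformation distribution and the uniform one into a learnable "mix_dist" nn.Parameter: the sampling distribution becomes mix_dist*prob + (1-mix_dist)*uniform, and mix_dist is clamped back into [0, 0.999] right after the 'mag' clamp. The snippet below is a minimal standalone sketch of that idea in plain PyTorch, not the Data_augV5 code itself; nb_tf, the batch size of 8 and the variable names are illustrative only.

import torch
from torch.distributions.categorical import Categorical

nb_tf = 5
prob = torch.nn.Parameter(torch.ones(nb_tf) / nb_tf)   # learned distribution over transformations
mix_dist = torch.nn.Parameter(torch.tensor(0.5))        # learnable mixing coefficient

uniform = torch.ones(nb_tf) / nb_tf
distrib = mix_dist * prob + (1 - mix_dist) * uniform     # mixed sampling distribution
sample = Categorical(probs=distrib).sample((8,))          # one transformation index per image

# After an optimizer step, project the coefficient back into its valid range,
# mirroring the clamp(min=0.0, max=0.999) added by the patch.
with torch.no_grad():
    mix_dist.clamp_(min=0.0, max=0.999)
print(distrib, sample)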
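
The train_utils.py part (hp_opt=True in run_dist_dataugV3) makes the inner optimizer's hyper-parameters learnable: each param_group entry (lr, momentum) is replaced by a tensor with requires_grad, added to the Adam meta-optimizer together with the augmentation parameters, and clamped to at least 1e-4 after every meta step, so the unrolled inner steps also yield gradients for the hyper-parameters. The sketch below illustrates the same principle without `higher`, using a single hand-unrolled SGD step on a toy least-squares problem; the data and the names (w, lr, meta_opt) are made up for the example and are not the repository's API.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
X_train, y_train = torch.randn(64, 3), torch.randn(64, 1)
X_val,   y_val   = torch.randn(64, 3), torch.randn(64, 1)

w  = torch.zeros(3, 1, requires_grad=True)      # model weights
lr = torch.tensor(1e-2, requires_grad=True)     # learnable inner learning rate
meta_opt = torch.optim.Adam([lr], lr=1e-2)      # meta-optimizer over hyper-parameters

for step in range(100):
    # Inner step written functionally so it stays differentiable w.r.t. lr
    train_loss = F.mse_loss(X_train @ w, y_train)
    g, = torch.autograd.grad(train_loss, w, create_graph=True)
    w_new = w - lr * g                           # one unrolled SGD step

    # Meta step: validation loss of the updated weights drives the hyper-parameter
    val_loss = F.mse_loss(X_val @ w_new, y_val)
    meta_opt.zero_grad()
    val_loss.backward()
    meta_opt.step()

    with torch.no_grad():
        lr.clamp_(min=1e-4)                      # same kind of projection as the patch's clamp(min=1e-4)
        w.copy_(w_new.detach())                  # commit the inner update
        w.grad = None                            # w itself is not meta-optimized here

print("learned lr:", lr.item())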