diff --git a/higher/dataug.py b/higher/dataug.py
index 3f731ac..67009f2 100755
--- a/higher/dataug.py
+++ b/higher/dataug.py
@@ -531,7 +531,7 @@ class Data_augV4(nn.Module): #Transformations avec mask
         return "Data_augV4(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_seqTF)
 
 class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
-    def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True, ):
+    def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True):
         super(Data_augV5, self).__init__()
         assert len(TF_dict)>0
 
@@ -545,13 +545,15 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
         self._shared_mag = shared_mag
         self._fixed_mag = fixed_mag
 
-        #self._fixed_mag=5 #[0, PARAMETER_MAX]
+        init_mag = float(TF.PARAMETER_MAX) if self._fixed_mag else float(TF.PARAMETER_MAX)/2
         self._params = nn.ParameterDict({
            "prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Distribution prob uniforme
-           "mag" : nn.Parameter(torch.tensor(float(TF.PARAMETER_MAX)/2) if self._shared_mag
-                   else torch.tensor(float(TF.PARAMETER_MAX)/2).expand(self._nb_tf)), #[0, PARAMETER_MAX]
+           "mag" : nn.Parameter(torch.tensor(init_mag) if self._shared_mag
+                   else torch.tensor(init_mag).repeat(self._nb_tf)), #[0, PARAMETER_MAX]
         })
+        for tf in TF.TF_no_grad :
+            if tf in self._TF: self._params['mag'].data[self._TF.index(tf)]=float(TF.PARAMETER_MAX) #TF fixe a max parameter
         #for t in TF.TF_no_mag: self._params['mag'][self._TF.index(t)].data-=self._params['mag'][self._TF.index(t)].data #Mag inutile pour les TF ignore_mag
 
 
         #Distribution
@@ -1094,8 +1096,8 @@ class Augmented_model(nn.Module):
 
         self.augment(mode=True)
 
-    def initialize(self):
-        self._mods['model'].initialize()
+    #def initialize(self):
+    #    self._mods['model'].initialize()
 
     def forward(self, x):
         return self._mods['model'](self._mods['data_aug'](x))
@@ -1136,4 +1138,81 @@ class Augmented_model(nn.Module):
         return self._mods[key]
 
     def __str__(self):
-        return "Aug_mod("+str(self._mods['data_aug'])+"-"+str(self._mods['model'])+")"
\ No newline at end of file
+        return "Aug_mod("+str(self._mods['data_aug'])+"-"+str(self._mods['model'])+")"
+
+'''
+import higher
+class Augmented_model2(nn.Module):
+    def __init__(self, data_augmenter, model):
+        super(Augmented_model2, self).__init__()
+
+        self._mods = nn.ModuleDict({
+            'data_aug': data_augmenter,
+            'model': model,
+            'fmodel': None
+            })
+
+        self.augment(mode=True)
+
+    def initialize(self):
+        self._mods['model'].initialize()
+
+    def forward(self, x):
+        if self._mods['fmodel']:
+            return self._mods['fmodel'](self._mods['data_aug'](x))
+        else:
+            return self._mods['model'](self._mods['data_aug'](x))
+
+    def functional(self, opt, track_higher_grads=True):
+        self._mods['fmodel'] = higher.patch.monkeypatch(self._mods['model'], device=None, copy_initial_weights=True)
+
+        return higher.optim.get_diff_optim(opt,
+            self._mods['model'].parameters(),
+            fmodel=self._mods['fmodel'],
+            track_higher_grads=track_higher_grads)
+
+    def detach_(self):
+        tmp = self._mods['fmodel'].fast_params
+        self._mods['fmodel']._fast_params=[]
+        self._mods['fmodel'].update_params(tmp)
+        for p in self._mods['fmodel'].fast_params:
+            p.detach_().requires_grad_()
+
+    def augment(self, mode=True):
+        self._data_augmentation=mode
+        self._mods['data_aug'].augment(mode)
+
+    def train(self, mode=None):
+        if mode is None :
+            mode=self._data_augmentation
+        self._mods['data_aug'].augment(mode)
+        super(Augmented_model2, self).train(mode)
+        return self
+
+    def eval(self):
+        return self.train(mode=False)
+        #super(Augmented_model, self).eval()
+
+    def items(self):
+        """Return an iterable of the ModuleDict key/value pairs.
+        """
+        return self._mods.items()
+
+    def update(self, modules):
+        self._mods.update(modules)
+
+    def is_augmenting(self):
+        return self._data_augmentation
+
+    def TF_names(self):
+        try:
+            return self._mods['data_aug']._TF
+        except:
+            return None
+
+    def __getitem__(self, key):
+        return self._mods[key]
+
+    def __str__(self):
+        return "Aug_mod("+str(self._mods['data_aug'])+"-"+str(self._mods['model'])+")"
+'''
\ No newline at end of file
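
Note on the Data_augV5 change above: magnitudes now start at TF.PARAMETER_MAX when they are fixed and at half the scale when they are learned, per-TF magnitudes are materialised with .repeat() instead of .expand(), and the magnitude of every transformation listed in TF.TF_no_grad is pinned to the maximum. The .expand() call only builds a stride-0 view over a single scalar, so the per-TF values could never diverge during training; .repeat() allocates one real value per transformation. A minimal, self-contained sketch of the difference (PARAMETER_MAX and the TF names below are illustrative stand-ins for the real transformations module):

import torch
import torch.nn as nn

PARAMETER_MAX = 10  # stand-in for TF.PARAMETER_MAX

# ".expand" builds a stride-0 view over one scalar: every entry aliases the same value.
shared_view = torch.tensor(float(PARAMETER_MAX) / 2).expand(3)
print(shared_view.stride())   # (0,) -> no independent storage per TF

# ".repeat" copies the scalar, so each TF gets its own trainable magnitude.
mag = nn.Parameter(torch.tensor(float(PARAMETER_MAX) / 2).repeat(3))
print(mag.stride())           # (1,) -> three independent values

# Pin the magnitude of non-differentiable TFs to the maximum, mirroring the loop
# added to Data_augV5.__init__ (names here are illustrative).
TF_list = ['Identity', 'Solarize', 'Contrast']
TF_no_grad = {'Solarize'}
for tf in TF_no_grad:
    if tf in TF_list:
        mag.data[TF_list.index(tf)] = float(PARAMETER_MAX)
print(mag)  # tensor([ 5., 10.,  5.], requires_grad=True)
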
diff --git a/higher/model.py b/higher/model.py
index 794aefd..84afeff 100755
--- a/higher/model.py
+++ b/higher/model.py
@@ -3,6 +3,40 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+
+import higher
+class Higher_model(nn.Module):
+    def __init__(self, model):
+        super(Higher_model, self).__init__()
+
+        self._mods = nn.ModuleDict({
+            'original': model,
+            'functional': higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
+            })
+
+    def get_diffopt(self, opt, grad_callback=None, track_higher_grads=True):
+        return higher.optim.get_diff_optim(opt,
+            self._mods['original'].parameters(),
+            fmodel=self._mods['functional'],
+            grad_callback=grad_callback,
+            track_higher_grads=track_higher_grads)
+
+    def forward(self, x):
+        return self._mods['functional'](x)
+
+    def detach_(self):
+        tmp = self._mods['functional'].fast_params
+        self._mods['functional']._fast_params=[]
+        self._mods['functional'].update_params(tmp)
+        for p in self._mods['functional'].fast_params:
+            p.detach_().requires_grad_()
+
+    def __getitem__(self, key):
+        return self._mods[key]
+
+    def __str__(self):
+        return self._mods['original'].__str__()
+
 ## Basic CNN ##
 class LeNet_F(nn.Module):
     def __init__(self, num_inp, num_out):
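
The new Higher_model wrapper keeps two synchronized views of the network: the 'original' nn.Module, whose parameters the inner optimizer owns, and a 'functional' copy produced by higher.patch.monkeypatch, which carries the differentiable fast weights. get_diffopt builds the differentiable optimizer once (optionally with a grad_callback), and detach_() truncates the fast-weight history so the graph and memory stop growing across inner iterations. A rough usage sketch (the toy network and data are made up, and the import assumes higher/model.py is on sys.path):

import torch
import torch.nn as nn
import torch.nn.functional as F
from model import Higher_model   # wrapper added above

net = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 2))
model = Higher_model(net)
inner_opt = torch.optim.SGD(model['original'].parameters(), lr=1e-2, momentum=0.9)
diffopt = model.get_diffopt(inner_opt, track_higher_grads=True)

x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))
for _ in range(3):
    loss = F.cross_entropy(model(x), y)   # forward runs through the functional copy
    diffopt.step(loss)                    # differentiable update: appends a new set of fast weights

print(len(model['functional']._fast_params))   # the counter printed during training; grows with each inner step

model.detach_()   # cut the unrolled history (run_dist_dataugV3 also calls diffopt.detach_())

Keeping the patched copy inside the wrapper is what lets the training loop below drop the per-epoch monkeypatch/get_diff_optim rebuild.
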
diff --git a/higher/test_dataug.py b/higher/test_dataug.py
index 8278431..4f08fa9 100755
--- a/higher/test_dataug.py
+++ b/higher/test_dataug.py
@@ -19,8 +19,8 @@ tf_names = [
     'Color',
     'Brightness',
     'Sharpness',
-    #'Posterize',
-    #'Solarize', #=>Image entre [0,1] #Pas opti pour des batch
+    'Posterize',
+    'Solarize', #=>Image entre [0,1] #Pas opti pour des batch
 
     #Color TF (Common mag scale)
     #'+Contrast',
@@ -67,7 +67,7 @@ if __name__ == "__main__":
         'aug_model'
     }
     n_inner_iter = 1
-    epochs = 100
+    epochs = 15
    dataug_epoch_start=0
    optim_param={
        'Meta':{
@@ -81,11 +81,13 @@ if __name__ == "__main__":
        }
    }
 
-    model = LeNet(3,10)
+    #model = LeNet(3,10)
     #model = MobileNetV2(num_classes=10)
-    #model = ResNet(num_classes=10)
+    model = ResNet(num_classes=10)
     #model = WideResNet(num_classes=10, wrn_size=32)
 
+    model = Higher_model(model) #run_dist_dataugV3
+
     #### Classic ####
     if 'classic' in tasks:
         t0 = time.process_time()
@@ -172,12 +174,12 @@ if __name__ == "__main__":
         #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)
         print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter))
 
-        log= run_dist_dataugV2(model=aug_model,
+        log= run_dist_dataugV3(model=aug_model,
             epochs=epochs,
             inner_it=n_inner_iter,
             dataug_epoch_start=dataug_epoch_start,
             opt_param=optim_param,
-            print_freq=10,
+            print_freq=1,
             KLdiv=True,
             loss_patience=None)
 
@@ -187,7 +189,7 @@ if __name__ == "__main__":
         times = [x["time"] for x in log]
         out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), 'Optimizer': optim_param, "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
         print(str(aug_model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
-        filename = "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)+"demi_mag"
+        filename = "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)#+"demi_mag"
         with open("res/log/%s.json" % filename, "w+") as f:
             json.dump(out, f, indent=True)
             print('Log :\"',f.name, '\" saved !')
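
In test_dataug.py, Posterize and Solarize are re-enabled (which lines up with the new TF_no_grad handling that pins their magnitude), the run is shortened to 15 epochs with print_freq=1, the backbone switches from LeNet to ResNet, and the network is wrapped in Higher_model before being handed to Augmented_model and run_dist_dataugV3. A sketch of how the pieces are now composed; it reuses the script's existing variables (tf_names, epochs, n_inner_iter, dataug_epoch_start, optim_param, device), while the import lines, the tf_dict construction and the Data_augV5 keyword values are illustrative assumptions:

# imports as in test_dataug.py (module layout assumed)
import transformations as TF
from dataug import Data_augV5, Augmented_model
from model import Higher_model, ResNet
from train_utils import run_dist_dataugV3

tf_dict = {name: TF.TF_dict[name] for name in tf_names}      # subset of TFs selected above
model = Higher_model(ResNet(num_classes=10))                  # functional wrapper first
aug_model = Augmented_model(
    Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0,
               fixed_prob=False, fixed_mag=False, shared_mag=False),
    model).to(device)

log = run_dist_dataugV3(model=aug_model, epochs=epochs, inner_it=n_inner_iter,
                        dataug_epoch_start=dataug_epoch_start, opt_param=optim_param,
                        print_freq=1, KLdiv=True, loss_patience=None)
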
diff --git a/higher/train_utils.py b/higher/train_utils.py
index d22fb31..8a64c25 100755
--- a/higher/train_utils.py
+++ b/higher/train_utils.py
@@ -654,7 +654,7 @@ def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
     model.train()
 
     fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
-    diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel,track_higher_grads=high_grad_track)
+    diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
 
     for epoch in range(1, epochs+1):
         #print_torch_mem("Start epoch "+str(epoch))
@@ -742,9 +742,8 @@ def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
                 model_copy(src=fmodel, dst=model)
                 optim_copy(dopt=diffopt, opt=inner_opt)
 
-                torch.nn.utils.clip_grad_norm_(model['data_aug']['prob'], max_norm=10, norm_type=2) #Prevent exploding grad with RNN
-                torch.nn.utils.clip_grad_norm_(model['data_aug']['mag'], max_norm=10, norm_type=2) #Prevent exploding grad with RNN
-
+                torch.nn.utils.clip_grad_norm_(model['data_aug'].parameters(), max_norm=10, norm_type=2) #Prevent exploding grad with RNN
+
                 #if epoch>50:
                 meta_opt.step()
                 model['data_aug'].adjust_param(soft=False) #Contrainte sum(proba)=1
@@ -835,7 +834,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
 
     #if inner_it!=0:
     meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=opt_param['Meta']['lr']) #lr=1e-2
-    inner_opt = torch.optim.SGD(model['model'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
+    inner_opt = torch.optim.SGD(model['model']['original'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
 
     high_grad_track = True
     if inner_it == 0:
@@ -853,12 +852,17 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
 
     #fmodel = higher.patch.monkeypatch(model['model'], device=None, copy_initial_weights=True)
     #diffopt = higher.optim.get_diff_optim(inner_opt, model['model'].parameters(),fmodel=fmodel,track_higher_grads=high_grad_track)
 
-    fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
-    diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel,track_higher_grads=high_grad_track)
+    #fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
+    #diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel,track_higher_grads=high_grad_track)
+
+    diffopt = model['model'].get_diffopt(
+        inner_opt,
+        grad_callback=(lambda grads: clip_norm(grads, max_norm=10)),
+        track_higher_grads=high_grad_track)
 
     #meta_opt = torch.optim.Adam(fmodel['data_aug'].parameters(), lr=opt_param['Meta']['lr']) #lr=1e-2
-    print(len(fmodel._fast_params))
+    #print(len(model['model']['functional']._fast_params))
 
     for epoch in range(1, epochs+1):
         #print_torch_mem("Start epoch "+str(epoch))
@@ -871,30 +875,30 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
            if(not KLdiv):
                #Methode uniforme
-               logits = fmodel(xs)  # modified `params` can also be passed as a kwarg
+               logits = model(xs)  # modified `params` can also be passed as a kwarg
                loss = F.cross_entropy(F.log_softmax(logits, dim=1), ys, reduction='none')  # no need to call loss.backwards()
 
-               if fmodel._data_augmentation: #Weight loss
-                   w_loss = fmodel['data_aug'].loss_weight()#.to(device)
+               if model._data_augmentation: #Weight loss
+                   w_loss = model['data_aug'].loss_weight()#.to(device)
                    loss = loss * w_loss
                loss = loss.mean()
 
            else:
                #Methode KL div
-               if fmodel._data_augmentation :
-                   fmodel.augment(mode=False)
-                   sup_logits = fmodel(xs)
-                   fmodel.augment(mode=True)
+               if model._data_augmentation :
+                   model.augment(mode=False)
+                   sup_logits = model(xs)
+                   model.augment(mode=True)
                else:
-                   sup_logits = fmodel(xs)
+                   sup_logits = model(xs)
                log_sup=F.log_softmax(sup_logits, dim=1)
                loss = F.cross_entropy(log_sup, ys)
 
-               if fmodel._data_augmentation:
-                   aug_logits = fmodel(xs)
+               if model._data_augmentation:
+                   aug_logits = model(xs)
                    log_aug=F.log_softmax(aug_logits, dim=1)
                    aug_loss=0
-                   w_loss = fmodel['data_aug'].loss_weight() #Weight loss
+                   w_loss = model['data_aug'].loss_weight() #Weight loss
 
                    #if epoch>50: #debut differe ?
 
                    #KL div w/ logits - Similarite predictions (distributions)
@@ -915,75 +919,36 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
            #print(fmodel['model']._params['b4'].grad)
            #print('prob grad', fmodel['data_aug']['prob'].grad)
 
-           #for _, p in fmodel['data_aug'].named_parameters():
-           #    p.requires_grad = False
            t = time.process_time()
            diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)
-           print(len(fmodel._fast_params),"step", time.process_time()-t)
+           print(len(model['model']['functional']._fast_params),"step", time.process_time()-t)
 
-           #for _, p in fmodel['data_aug'].named_parameters():
-           #    p.requires_grad = True
            if(high_grad_track and i>0 and i%inner_it==0): #Perform Meta step
                #print("meta")
-               val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val) + fmodel['data_aug'].reg_loss()
+               val_loss = compute_vaLoss(model=model, dl_it=dl_val_it, dl=dl_val) + model['data_aug'].reg_loss()
                #print_graph(val_loss)
 
                val_loss.backward()
 
-               print('proba grad',fmodel['data_aug']['prob'].grad)
-               #countcopy+=1
-               #model_copy(src=fmodel, dst=model)
-               #optim_copy(dopt=diffopt, opt=inner_opt)
-
-               torch.nn.utils.clip_grad_norm_(fmodel['data_aug']['prob'], max_norm=10, norm_type=2) #Prevent exploding grad with RNN
-               torch.nn.utils.clip_grad_norm_(fmodel['data_aug']['mag'], max_norm=10, norm_type=2) #Prevent exploding grad with RNN
-
-               for paramName, paramValue, in fmodel['data_aug'].named_parameters():
-                   for netCopyName, netCopyValue, in model['data_aug'].named_parameters():
-                       if paramName == netCopyName:
-                           netCopyValue.grad = paramValue.grad
-
-               #del meta_opt.param_groups[0]
-               #meta_opt.add_param_group({'params' : [p for p in fmodel['data_aug'].parameters()]})
+               torch.nn.utils.clip_grad_norm_(model['data_aug'].parameters(), max_norm=10, norm_type=2) #Prevent exploding grad with RNN
 
                meta_opt.step()
-               fmodel['data_aug'].load_state_dict(model['data_aug'].state_dict())
-               fmodel['data_aug'].adjust_param(soft=False) #Contrainte sum(proba)=1
-               #model['data_aug'].next_TF_set()
-
+               model['data_aug'].adjust_param(soft=False) #Contrainte sum(proba)=1
 
-               #fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
-               #diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
-
-
-               #fmodel.fast_params=[higher.utils._copy_tensor(t,safe_copy=True) if isinstance(t, torch.Tensor) else t for t in fmodel.parameters()]
               diffopt.detach_()
-               tmp = fmodel.fast_params
-               fmodel._fast_params=[]
-               fmodel.update_params(tmp)
-               for p in fmodel.fast_params:
-                   p.detach_().requires_grad_()
-               print(len(fmodel._fast_params))
-
-           print('TF Proba :', fmodel['data_aug']['prob'].data)
+               model['model'].detach_()
+
        tf = time.process_time()
 
        #viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
        #viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
-
-       #model_copy(src=fmodel, dst=model)
+
        if(not high_grad_track):
-           #countcopy+=1
-           #model_copy(src=fmodel, dst=model)
-           optim_copy(dopt=diffopt, opt=inner_opt)
-
-           val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val)
+           val_loss = compute_vaLoss(model=model, dl_it=dl_val_it, dl=dl_val)
 
-           #Necessaire pour reset higher (Accumule les fast_param meme avec track_higher_grads = False)
-           fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
-           diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
 
        accuracy, test_loss =test(model)
        model.train()
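
The reworked run_dist_dataugV3 no longer re-patches the whole Augmented_model with higher at every reset: the inner optimizer is built on model['model']['original'], the differentiable optimizer comes from Higher_model.get_diffopt with a grad_callback that clips the inner-step gradients (the clip_norm helper added in utils.py below), the meta step back-propagates a validation loss plus Data_augV5's reg_loss directly into the augmentation parameters, and the unrolled history is cut with diffopt.detach_() / model['model'].detach_() instead of rebuilding fmodel. A toy sketch of that inner/meta alternation, with a plain linear layer standing in for the network and a single parameter vector standing in for the augmentation parameters (momentum is left at 0 here so no optimizer state links successive unrolls; the real code keeps momentum=0.9 and relies on diffopt.detach_() for that):

import torch
import torch.nn as nn
import torch.nn.functional as F
import higher

net = nn.Linear(4, 2)                              # stands in for model['model']['original']
aug_w = nn.Parameter(torch.ones(3) / 3)            # stands in for the Data_augV5 parameters
inner_opt = torch.optim.SGD(net.parameters(), lr=1e-2)
meta_opt = torch.optim.Adam([aug_w], lr=1e-2)

fnet = higher.patch.monkeypatch(net, device=None, copy_initial_weights=True)
diffopt = higher.optim.get_diff_optim(
    inner_opt, net.parameters(), fmodel=fnet,
    grad_callback=lambda grads: [g.clamp(-10, 10) for g in grads],  # stand-in for clip_norm
    track_higher_grads=True)

x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))
inner_it = 1
for i in range(1, 4):
    # Inner step: training loss weighted by a function of aug_w, so the unrolled
    # fast weights depend on the augmentation parameters.
    w = F.softmax(aug_w, dim=0).mean()
    loss = (F.cross_entropy(fnet(x), y, reduction='none') * w).mean()
    diffopt.step(loss)

    if i % inner_it == 0:
        # Meta step: validation loss (plus a regularisation term in the real code),
        # clipped, then applied to the augmentation parameters.
        val_loss = F.cross_entropy(fnet(x), y)
        val_loss.backward()
        torch.nn.utils.clip_grad_norm_([aug_w], max_norm=10)
        meta_opt.step()
        meta_opt.zero_grad()

        # Truncate the unrolled history instead of re-patching the model,
        # mirroring Higher_model.detach_().
        last = fnet.fast_params
        fnet._fast_params = []
        fnet.update_params(last)
        for p in fnet.fast_params:
            p.detach_().requires_grad_()
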
diff --git a/higher/transformations.py b/higher/transformations.py
index 0410a78..c4f4175 100755
--- a/higher/transformations.py
+++ b/higher/transformations.py
@@ -103,7 +103,8 @@ TF_dict={ #Dataugv5
 }
 
 TF_no_mag={'Identity', 'FlipUD', 'FlipLR', 'Random', 'RandBlend'}
-TF_ignore_mag= TF_no_mag | {'Solarize', 'Posterize', '=Solarize', '=Posterize'}
+TF_no_grad={'Solarize', 'Posterize', '=Solarize', '=Posterize'}
+TF_ignore_mag= TF_no_mag | TF_no_grad
 
 def int_image(float_image): #ATTENTION : legere perte d'info (granularite : 1/256 = 0.0039)
     return (float_image*255.).type(torch.uint8)
diff --git a/higher/utils.py b/higher/utils.py
index 6e5675d..02fc1eb 100755
--- a/higher/utils.py
+++ b/higher/utils.py
@@ -314,4 +314,38 @@ class loss_monitor(): #Voir https://github.com/pytorch/ignite
            return False
 
    def reset(self):
-        self.__init__(self.patience, self.end_train)
\ No newline at end of file
+        self.__init__(self.patience, self.end_train)
+
+### https://github.com/facebookresearch/higher/issues/18 ####
+from torch._six import inf
+
+def clip_norm(tensors, max_norm, norm_type=2):
+    r"""Clips norm of passed tensors.
+    The norm is computed over all tensors together, as if they were
+    concatenated into a single vector. Clipped tensors are returned.
+    Arguments:
+        tensors (Iterable[Tensor]): an iterable of Tensors or a
+            single Tensor to be normalized.
+        max_norm (float or int): max norm of the gradients
+        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
+            infinity norm.
+    Returns:
+        Clipped (List[Tensor]) tensors.
+    """
+    if isinstance(tensors, torch.Tensor):
+        tensors = [tensors]
+    tensors = list(tensors)
+    max_norm = float(max_norm)
+    norm_type = float(norm_type)
+    if norm_type == inf:
+        total_norm = max(t.abs().max() for t in tensors)
+    else:
+        total_norm = 0
+        for t in tensors:
+            param_norm = t.norm(norm_type)
+            total_norm += param_norm.item() ** norm_type
+        total_norm = total_norm ** (1. / norm_type)
+    clip_coef = max_norm / (total_norm + 1e-6)
+    if clip_coef >= 1:
+        return tensors
+    return [t.mul(clip_coef) for t in tensors]
\ No newline at end of file
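
clip_norm is the out-of-place counterpart of torch.nn.utils.clip_grad_norm_: it returns rescaled copies of the tensors instead of mutating .grad, which is what a higher grad_callback needs. Two things worth noting: the scaling coefficient is built from .item() calls, so the clipping factor itself is treated as a constant by autograd, and torch._six has been removed in recent PyTorch releases (math.inf is a drop-in replacement there). A quick self-contained check of the behaviour (the import assumes higher/utils.py is on sys.path; the numbers are just an example):

import torch
from utils import clip_norm   # helper defined above

grads = [torch.full((2, 2), 3.0), torch.full((3,), 4.0)]   # total L2 norm = sqrt(36 + 48) ~ 9.17
clipped = clip_norm(grads, max_norm=1.0)

total = torch.sqrt(sum(g.pow(2).sum() for g in clipped))
print(total)   # ~1.0; the original tensors in `grads` are left untouched

# Plugged into the differentiable optimizer, as in run_dist_dataugV3:
# diffopt = model['model'].get_diffopt(inner_opt,
#     grad_callback=(lambda grads: clip_norm(grads, max_norm=10)),
#     track_higher_grads=high_grad_track)
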