diff --git a/higher/smart_aug/dataug.py b/higher/smart_aug/dataug.py
index 688d8c8..1859717 100755
--- a/higher/smart_aug/dataug.py
+++ b/higher/smart_aug/dataug.py
@@ -919,25 +919,51 @@ class Augmented_model(nn.Module):
         self._data_augmentation=mode
         self._mods['data_aug'].augment(mode)
 
+    #### Encapsulation Meta Opt ####
     def start_bilevel_opt(self, inner_it, hp_list, opt_param, dl_val):
+        """ Set up the Augmented Model for bi-level optimisation.
+
+        Create and keep in the Augmented Model the objects necessary for meta-optimisation.
+        This allows an almost transparent use by hiding the bi-level optimisation (see ``run_dist_dataugV3``) behind ::
+
+            model.step(loss)
+
+        See ``run_simple_smartaug`` for a complete example.
+
+        Args:
+            inner_it (int): Number of inner iterations before a meta-step. 0 inner iterations means there is no meta-step.
+            hp_list (list): List of hyper-parameters to be learned.
+            opt_param (dict): Dictionary containing the optimizers' parameters.
+            dl_val (DataLoader): Data loader of validation data.
+        """
+
+        self._it_count=0
+        self._in_it=inner_it
+
+        self._opt_param=opt_param
+        #Inner Opt
+        inner_opt = torch.optim.SGD(self._mods['model']['original'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
+
+        #Validation data
+        self._dl_val=dl_val
+        self._dl_val_it=iter(dl_val)
+        self._val_loss=0.
 
         if inner_it==0 or len(hp_list)==0: #No meta-opt
             print("No meta optimization")
 
-            self._diffopt = model['model'].get_diffopt(
+            #Inner Opt
+            self._diffopt = self._mods['model'].get_diffopt(
                 inner_opt,
                 grad_callback=(lambda grads: clip_norm(grads, max_norm=10)),
                 track_higher_grads=False)
 
+            self._meta_opt=None
+
         else: #Bi-level opt
             print("Bi-Level optimization")
 
-            self._it_count=0
-            self._in_it=inner_it
-
-            self._opt_param=opt_param
             #Inner Opt
-            inner_opt = torch.optim.SGD(self._mods['model']['original'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
-
             self._diffopt = self._mods['model'].get_diffopt(
                 inner_opt,
                 grad_callback=(lambda grads: clip_norm(grads, max_norm=10)),
@@ -945,15 +971,34 @@ class Augmented_model(nn.Module):
             #Meta Opt
             self._meta_opt = torch.optim.Adam(hp_list, lr=opt_param['Meta']['lr'])
-
-            self._dl_val=dl_val
-            self._dl_val_it=iter(dl_val)
-            self._val_loss=0.
-
             self._meta_opt.zero_grad()
 
     def step(self, loss):
+        """ Perform a model update.
+        The ``start_bilevel_opt`` method needs to be called once before using this method.
+
+        Performs a step of inner optimisation and, if needed, a step of meta optimisation.
+        Replaces ::
+
+            opt.zero_grad()
+            loss.backward()
+            opt.step()
+
+            val_loss=...
+            val_loss.backward()
+            meta_opt.step()
+            adjust_param()
+            detach()
+            meta_opt.zero_grad()
+
+        with ::
+
+            model.step(loss)
+
+        Args:
+            loss (Tensor): The training loss tensor.
+        """
         self._it_count+=1
         self._diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)
@@ -982,6 +1027,22 @@ class Augmented_model(nn.Module):
 
         self._it_count=0
 
+    def val_loss(self):
+        """ Get the validation loss.
+
+        Compute the validation loss, if needed, and return it.
+
+        The ``start_bilevel_opt`` method needs to be called once before using this method.
+
+        Returns:
+            (Tensor) Validation loss on a single batch of data.
+        """
+        if(self._meta_opt): #Bi-level opt: val loss is already computed during step()
+            return self._val_loss
+        else:
+            return compute_vaLoss(model=self._mods['model'], dl_it=self._dl_val_it, dl=self._dl_val)
+
+    ##########################
     def train(self, mode=True):
         """ Set the module training mode.
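For reference, the training loop this encapsulation enables looks roughly like the sketch below. It is a minimal illustration, not the project's exact code: `aug_model`, `dl_train`, `dl_val`, `optim_param`, `device` and `epochs` are assumed to be set up as in `smart_aug_example.py`, and cross-entropy is only one possible choice of loss.

    # Minimal sketch of the encapsulated bi-level loop (assumed names:
    # aug_model is an Augmented_model wrapping a Higher_model,
    # dl_train/dl_val are DataLoaders, optim_param follows the
    # {'Meta': {...}, 'Inner': {...}} layout used elsewhere in this patch).
    import torch.nn.functional as F

    hyper_param = list(aug_model['data_aug'].parameters())
    aug_model.start_bilevel_opt(inner_it=1, hp_list=hyper_param,
                                opt_param=optim_param, dl_val=dl_val)

    aug_model.train()
    aug_model.augment(mode=True)  # enable data augmentation (see Augmented_model.augment)
    for epoch in range(1, epochs+1):
        for xs, ys in dl_train:
            xs, ys = xs.to(device), ys.to(device)
            loss = F.cross_entropy(aug_model(xs), ys)  # loss on augmented inputs
            aug_model.step(loss)  # inner step + meta-step every inner_it iterations
        print('Epoch %d - val loss: %f' % (epoch, aug_model.val_loss().item()))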
diff --git a/higher/smart_aug/smart_aug_example.py b/higher/smart_aug/smart_aug_example.py
new file mode 100644
index 0000000..3358de0
--- /dev/null
+++ b/higher/smart_aug/smart_aug_example.py
@@ -0,0 +1,77 @@
+""" Example use of smart augmentation.
+
+"""
+
+from model import *
+from dataug import *
+from train_utils import *
+
+# Use available TF (see transformations.py)
+tf_names = [
+    ## Geometric TF ##
+    'Identity',
+    'FlipUD',
+    'FlipLR',
+    'Rotate',
+    'TranslateX',
+    'TranslateY',
+    'ShearX',
+    'ShearY',
+
+    ## Color TF (Expect images in the range of [0, 1]) ##
+    'Contrast',
+    'Color',
+    'Brightness',
+    'Sharpness',
+    'Posterize',
+    'Solarize', #=> Expects images in [0, 1]. Not optimised for batches.
+]
+
+
+device = torch.device('cuda') #Select device to use
+
+if device == torch.device('cpu'):
+    device_name = 'CPU'
+else:
+    device_name = torch.cuda.get_device_name(device)
+
+##########################################
+if __name__ == "__main__":
+
+    #Parameters
+    n_inner_iter = 1
+    epochs = 150
+    optim_param={
+        'Meta':{
+            'optim':'Adam',
+            'lr':1e-2, #1e-2
+        },
+        'Inner':{
+            'optim': 'SGD',
+            'lr':1e-2, #1e-2
+            'momentum':0.9, #0.9
+        }
+    }
+
+    #Models
+    model = LeNet(3,10)
+    #model = ResNet(num_classes=10)
+    #model = MobileNetV2(num_classes=10)
+    #model = WideResNet(num_classes=10, wrn_size=32)
+
+    #Smart_aug initialisation
+    tf_dict = {k: TF.TF_dict[k] for k in tf_names}
+    model = Higher_model(model) #run_dist_dataugV3
+    aug_model = Augmented_model(
+        Data_augV5(TF_dict=tf_dict,
+            N_TF=3,
+            mix_dist=0.8,
+            fixed_prob=False,
+            fixed_mag=False,
+            shared_mag=False),
+        model).to(device)
+
+    print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter))
+
+    # Training (run_simple_smartaug returns the trained network's state_dict)
+    trained_model = run_simple_smartaug(model=aug_model, epochs=epochs, inner_it=n_inner_iter, opt_param=optim_param)
diff --git a/higher/smart_aug/test_dataug.py b/higher/smart_aug/test_dataug.py
index ce4c864..9d269d0 100755
--- a/higher/smart_aug/test_dataug.py
+++ b/higher/smart_aug/test_dataug.py
@@ -73,12 +73,12 @@ if __name__ == "__main__":
     #Task to perform
     tasks={
         #'classic',
-        #'aug_dataset', #Moved to old code
         'aug_model'
+        #'aug_dataset', #Moved to old code
     }
     #Parameters
     n_inner_iter = 1
-    epochs = 200
+    epochs = 1
     dataug_epoch_start=0
     optim_param={
         'Meta':{
@@ -123,7 +123,47 @@ if __name__ == "__main__":
         print('Execution Time : %.00f '%(exec_time))
         print('-'*9)
-
+
+    #### Augmented Model ####
+    if 'aug_model' in tasks:
+        t0 = time.process_time()
+
+        tf_dict = {k: TF.TF_dict[k] for k in tf_names}
+        model = Higher_model(model) #run_dist_dataugV3
+        aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.8, fixed_prob=False, fixed_mag=False, shared_mag=False), model).to(device)
+        #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)
+
+        print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter))
+        #log= run_simple_smartaug(model=aug_model, epochs=epochs, inner_it=n_inner_iter, opt_param=optim_param) #Disabled: returns a state_dict (not a log) and would train the model a first time
+        log= run_dist_dataugV3(model=aug_model,
+                epochs=epochs,
+                inner_it=n_inner_iter,
+                dataug_epoch_start=dataug_epoch_start,
+                opt_param=optim_param,
+                print_freq=1,
+                unsup_loss=1,
+                hp_opt=False)
+
+        exec_time=time.process_time() - t0
+        ####
+        print('-'*9)
+        times = [x["time"] for x in log]
+        out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), 'Optimizer': optim_param, "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
+        print(str(aug_model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
+        filename = "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)
+        with open("../res/log/%s.json" % filename, "w+") as f:
+            try:
+                json.dump(out, f, indent=True)
+                print('Log :\"',f.name, '\" saved !')
+            except:
+                print("Failed to save logs :",f.name)
+        try:
+            plot_resV2(log, fig_name="../res/"+filename, param_names=aug_model.TF_names())
+        except:
+            print("Failed to plot res")
+
+        print('Execution Time : %.00f '%(exec_time))
+        print('-'*9)
 
     #### Augmented Dataset ####
     '''
@@ -175,45 +215,4 @@ if __name__ == "__main__":
         print('Execution Time : %.00f '%(exec_time))
         print('-'*9)
-    '''
-
-    #### Augmented Model ####
-    if 'aug_model' in tasks:
-        t0 = time.process_time()
-
-        tf_dict = {k: TF.TF_dict[k] for k in tf_names}
-        model = Higher_model(model) #run_dist_dataugV3
-        aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.8, fixed_prob=False, fixed_mag=False, shared_mag=False), model).to(device)
-        #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)
-
-        print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter))
-        log= run_simple_smartaug(model=aug_model, opt_param=optim_param)
-        log= run_dist_dataugV3(model=aug_model,
-                epochs=epochs,
-                inner_it=n_inner_iter,
-                dataug_epoch_start=dataug_epoch_start,
-                opt_param=optim_param,
-                print_freq=1,
-                unsup_loss=1,
-                hp_opt=False)
-
-        exec_time=time.process_time() - t0
-        ####
-        print('-'*9)
-        times = [x["time"] for x in log]
-        out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), 'Optimizer': optim_param, "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
-        print(str(aug_model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
-        filename = "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)
-        with open("res/log/%s.json" % filename, "w+") as f:
-            try:
-                json.dump(out, f, indent=True)
-                print('Log :\"',f.name, '\" saved !')
-            except:
-                print("Failed to save logs :",f.name)
-        try:
-            plot_resV2(log, fig_name="res/"+filename, param_names=aug_model.TF_names())
-        except:
-            print("Failed to plot res")
-
-        print('Execution Time : %.00f '%(exec_time))
-        print('-'*9)
\ No newline at end of file
+    '''
\ No newline at end of file
diff --git a/higher/smart_aug/train_utils.py b/higher/smart_aug/train_utils.py
index c578034..483ef71 100755
--- a/higher/smart_aug/train_utils.py
+++ b/higher/smart_aug/train_utils.py
@@ -364,7 +364,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
 
     return log
 
-def run_simple_smartaug(model, opt_param, epochs=1, inner_it=1, print_freq=1, unsup_loss=1, save_sample_freq=None):
+def run_simple_smartaug(model, opt_param, epochs=1, inner_it=1, print_freq=1, unsup_loss=1):
     """Simple training of an augmented model with higher.
 
     This function is intended to be used with Augmented_model containing an Higher_model (see dataug.py).
@@ -380,13 +380,11 @@ def run_simple_smartaug(model, opt_param, epochs=1, inner_it=1, print_freq=1, un
         inner_it (int): Number of inner iteration before a meta-step. 0 inner iteration means there's no meta-step. (default: 1)
         print_freq (int): Number of epoch between display of the state of training. If set to None, no display will be done. (default:1)
         unsup_loss (float): Proportion of the unsup_loss loss added to the supervised loss. If set to 0, the loss is only computed on augmented inputs. (default: 1)
-        save_sample_freq (int): Number of epochs between saves of samples of data. If set to None, only one save would be done at the end of the training. (default: None)
-
+
     Returns:
-        (list) Logs of training. Each items is a dict containing results of an epoch.
+        (dict) A dictionary containing the whole state of the trained network (``state_dict``).
     """
     device = next(model.parameters()).device
-    log = []
 
     ## Optimizers ##
     hyper_param = list(model['data_aug'].parameters())
@@ -407,55 +405,15 @@ def run_simple_smartaug(model, opt_param, epochs=1, inner_it=1, print_freq=1, un
 
         tf = time.process_time()
 
-        if (save_sample_freq and epoch%save_sample_freq==0): #Data sample saving
-            try:
-                viz_sample_data(imgs=xs, labels=ys, fig_name='../samples/data_sample_epoch{}_noTF'.format(epoch))
-                viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='../samples/data_sample_epoch{}'.format(epoch))
-            except:
-                print("Couldn't save samples epoch"+epoch)
-                pass
-
-        val_loss = model._val_loss
-
-        # Test model
-        accuracy, test_loss =test(model)
-        model.train()
-
-        #### Log ####
-        param = [{'p': p.item(), 'm':model['data_aug']['mag'].item()} for p in model['data_aug']['prob']] if model['data_aug']._shared_mag else [{'p': p.item(), 'm': m.item()} for p, m in zip(model['data_aug']['prob'], model['data_aug']['mag'])]
-        data={
-            "epoch": epoch,
-            "train_loss": loss.item(),
-            "val_loss": val_loss.item(),
-            "acc": accuracy,
-            "time": tf - t0,
-
-            "mix_dist": model['data_aug']['mix_dist'].item(),
-            "param": param,
-        }
-        log.append(data)
-        #############
 
         #### Print ####
         if(print_freq and epoch%print_freq==0):
             print('-'*9)
             print('Epoch : %d/%d'%(epoch,epochs))
             print('Time : %.00f'%(tf - t0))
-            print('Train loss :',loss.item(), '/ val loss', val_loss.item())
-            print('Accuracy :', max([x["acc"] for x in log]))
-            print('Data Augmention : {} (Epoch {})'.format(model._data_augmentation, 0))
+            print('Train loss :',loss.item(), '/ val loss', model.val_loss().item())
             if not model['data_aug']._fixed_prob: print('TF Proba :', model['data_aug']['prob'].data)
-            #print('proba grad',model['data_aug']['prob'].grad)
             if not model['data_aug']._fixed_mag: print('TF Mag :', model['data_aug']['mag'].data)
-            #print('Mag grad',model['data_aug']['mag'].grad)
             if not model['data_aug']._fixed_mix: print('Mix:', model['data_aug']['mix_dist'].item())
-            #print('Reg loss:', model['data_aug'].reg_loss().item())
         #############
 
-    #Data sample saving
-    try:
-        viz_sample_data(imgs=xs, labels=ys, fig_name='../samples/data_sample_epoch{}_noTF'.format(epoch))
-        viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='../samples/data_sample_epoch{}'.format(epoch))
-    except:
-        print("Couldn't save finals samples")
-        pass
-
-    return log
+    return model['model'].state_dict()
\ No newline at end of file
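Since `run_simple_smartaug` now returns the trained network's `state_dict` instead of a log, the result can be persisted and restored directly. A minimal sketch under the same assumptions as `smart_aug_example.py` (the save path is hypothetical):

    # run_simple_smartaug now returns model['model'].state_dict()
    state = run_simple_smartaug(model=aug_model, epochs=epochs,
                                inner_it=n_inner_iter, opt_param=optim_param)
    torch.save(state, '../res/trained_model.pt')  # hypothetical path

    # Restore into a freshly built network of the same architecture,
    # wrapped the same way so the state_dict keys line up.
    new_model = Higher_model(LeNet(3, 10))
    new_model.load_state_dict(torch.load('../res/trained_model.pt'))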