From 561b71b30a8f7ae3b29e98629b8ff710449ba6ac Mon Sep 17 00:00:00 2001
From: "Harle, Antoine (Contracteur)"
Date: Thu, 30 Jan 2020 11:21:25 -0500
Subject: [PATCH] Minor improvement (RandAug)

---
 higher/smart_aug/dataug.py      |  32 ++++++-
 higher/smart_aug/test_brutus.py | 163 --------------------------------
 higher/smart_aug/test_dataug.py |  25 ++---
 higher/smart_aug/train_utils.py |   8 +-
 higher/smart_aug/utils.py       |   1 +
 5 files changed, 50 insertions(+), 179 deletions(-)
 delete mode 100755 higher/smart_aug/test_brutus.py

diff --git a/higher/smart_aug/dataug.py b/higher/smart_aug/dataug.py
index 1859717..246a336 100755
--- a/higher/smart_aug/dataug.py
+++ b/higher/smart_aug/dataug.py
@@ -187,11 +187,11 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)
         Ensure that the parameter values stay within the right intervals. This should be called after each update of those parameters.

         Args:
-            soft (bool): Whether to use a softmax function for TF probabilities. Not recommended, as it tends to lock the probabilities, preventing them from being learned. (default: False)
+            soft (bool): Whether to use a softmax function for TF probabilities. Tends to lock the probabilities if the learning rate is low, preventing them from being learned. (default: False)
         """
         if not self._fixed_prob:
             if soft :
-                self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Too 'soft': locks into a uniform dist if the lr is too low
+                self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0)
             else:
                 self._params['prob'].data = self._params['prob'].data.clamp(min=1/(self._nb_tf*100),max=1.0)
                 self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Constraint: sum(p)=1
@@ -269,6 +269,14 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)
         """
         self._data_augmentation=mode

+    def is_augmenting(self):
+        """ Return whether data augmentation is applied.
+
+            Returns:
+                bool : True if data augmentation is applied.
+        """
+        return self._data_augmentation
+
     def __getitem__(self, key):
         """Access to the learnable parameters
         Args:
@@ -588,6 +596,14 @@ class Data_augV7(nn.Module): #Sequential probabilities
         """
         self._data_augmentation=mode

+    def is_augmenting(self):
+        """ Return whether data augmentation is applied.
+
+            Returns:
+                bool : True if data augmentation is applied.
+        """
+        return self._data_augmentation
+
     def __getitem__(self, key):
         """Access to the learnable parameters
         Args:
@@ -659,6 +675,8 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + fast
         })
         self._shared_mag = True
         self._fixed_mag = True
+        self._fixed_prob=True
+        self._fixed_mix=True

         self._params['mag'].data = self._params['mag'].data.clamp(min=TF.PARAMETER_MIN, max=TF.PARAMETER_MAX)

@@ -753,6 +771,14 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + fast
         """
         self._data_augmentation=mode

+    def is_augmenting(self):
+        """ Return whether data augmentation is applied.
+
+            Returns:
+                bool : True if data augmentation is applied.
+        """
+        return self._data_augmentation
+
     def __getitem__(self, key):
         """Access to the learnable parameters
         Args:
@@ -796,7 +822,7 @@ class Higher_model(nn.Module):
         """
         super(Higher_model, self).__init__()

-        self._name = model.__str__()
+        self._name = model.__class__.__name__ #model.__str__()
         self._mods = nn.ModuleDict({
             'original': model,
             'functional': higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
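For reference, a minimal usage sketch of the new is_augmenting() accessor and the _fixed_prob/_fixed_mix flags that RandAug now sets. This is a sketch only: the augment(mode=...) toggle and the "transformations" module name are assumed from the surrounding code, not shown in this patch.

# Hypothetical sketch; assumes dataug.py is importable and that the
# augmentation modules expose an augment(mode=...) toggle that sets
# _data_augmentation, as the methods above suggest.
import transformations as TF
from dataug import RandAug

tf_dict = {k: TF.TF_dict[k] for k in ['Identity', 'FlipLR', 'Rotate']}
aug = RandAug(TF_dict=tf_dict, N_TF=2)

aug.augment(mode=True)       # enable augmentation
print(aug.is_augmenting())   # True: accessor replaces reading the private flag

# RandAug is meant to stay un-learned: a meta-optimizer can now test these
# flags and skip 'prob' and 'mix_dist' entirely.
print(aug._fixed_prob, aug._fixed_mix)  # True True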
+ """ + return self._data_augmentation + def __getitem__(self, key): """Access to the learnable parameters Args: @@ -796,7 +822,7 @@ class Higher_model(nn.Module): """ super(Higher_model, self).__init__() - self._name = model.__str__() + self._name = model.__class__.__name__ #model.__str__() self._mods = nn.ModuleDict({ 'original': model, 'functional': higher.patch.monkeypatch(model, device=None, copy_initial_weights=True) diff --git a/higher/smart_aug/test_brutus.py b/higher/smart_aug/test_brutus.py deleted file mode 100755 index 3916a94..0000000 --- a/higher/smart_aug/test_brutus.py +++ /dev/null @@ -1,163 +0,0 @@ -from model import * -from dataug import * -#from utils import * -from train_utils import * - -tf_names = [ - ## Geometric TF ## - 'Identity', - 'FlipUD', - 'FlipLR', - 'Rotate', - 'TranslateX', - 'TranslateY', - 'ShearX', - 'ShearY', - - ## Color TF (Expect image in the range of [0, 1]) ## - 'Contrast', - 'Color', - 'Brightness', - 'Sharpness', - 'Posterize', - 'Solarize', #=>Image entre [0,1] #Pas opti pour des batch -] - -device = torch.device('cuda') - -if device == torch.device('cpu'): - device_name = 'CPU' -else: - device_name = torch.cuda.get_device_name(device) - -########################################## -if __name__ == "__main__": - - - n_inner_iter = 1 - epochs = 150 - dataug_epoch_start=0 - optim_param={ - 'Meta':{ - 'optim':'Adam', - 'lr':1e-2, #1e-2 - }, - 'Inner':{ - 'optim': 'SGD', - 'lr':1e-1, #1e-2 - 'momentum':0.9, #0.9 - } - } - - #model = LeNet(3,10) - #model = ResNet(num_classes=10) - #model = MobileNetV2(num_classes=10) - #model = WideResNet(num_classes=10, wrn_size=32) - - tf_dict = {k: TF.TF_dict[k] for k in tf_names} - - #### - ''' - t0 = time.process_time() - - aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device) - - print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter)) - log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, KLdiv=True, loss_patience=None) - - exec_time=time.process_time() - t0 - #### - times = [x["time"] for x in log] - out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log} - filename = "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter) - with open("res/log/%s.json" % filename, "w+") as f: - json.dump(out, f, indent=True) - print('Log :\"',f.name, '\" saved !') - ''' - - #### - ''' - t0 = time.process_time() - - aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), model).to(device) - - print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter)) - log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, KLdiv=True, loss_patience=None) - - exec_time=time.process_time() - t0 - #### - times = [x["time"] for x in log] - out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log} - filename = "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter) - with open("res/log/%s.json" % filename, "w+") as f: - json.dump(out, f, indent=True) - print('Log :\"',f.name, '\" saved !') - ''' 
diff --git a/higher/smart_aug/test_dataug.py b/higher/smart_aug/test_dataug.py
index 9d269d0..b7b339a 100755
--- a/higher/smart_aug/test_dataug.py
+++ b/higher/smart_aug/test_dataug.py
@@ -53,10 +53,6 @@ tf_names = [

     #'Random',
     #'RandBlend'
-
-    #Not functional
-    #'Auto_Contrast', #Not optimized for batches (very slow)
-    #'Equalize',
 ]


@@ -67,6 +63,12 @@
 else:
     device_name = torch.cuda.get_device_name(device)

+torch.backends.cudnn.benchmark = True #Faster if same input size #Not recommended for reproducibility
+
+#Improve reproducibility
+torch.manual_seed(0)
+np.random.seed(0)
+
 ##########################################
 if __name__ == "__main__":

@@ -78,7 +80,7 @@ if __name__ == "__main__":
     }
     #Parameters
     n_inner_iter = 1
-    epochs = 1
+    epochs = 150
     dataug_epoch_start=0
     optim_param={
         'Meta':{
@@ -95,9 +97,8 @@ if __name__ == "__main__":
     #Models
     model = LeNet(3,10)
     #model = ResNet(num_classes=10)
-    #Slow
-    #model = MobileNetV2(num_classes=10)
-    #model = WideResNet(num_classes=10, wrn_size=32)
+    #import torchvision.models as models
+    #model=models.resnet18()

     #### Classic ####
     if 'classic' in tasks:
@@ -105,7 +106,7 @@ if __name__ == "__main__":
         model = model.to(device)

         print("{} on {} for {} epochs".format(str(model), device_name, epochs))
-        log= train_classic(model=model, opt_param=optim_param, epochs=epochs, print_freq=1)
+        log= train_classic(model=model, opt_param=optim_param, epochs=epochs, print_freq=20)
         #log= train_classic_higher(model=model, epochs=epochs)

         exec_time=time.process_time() - t0
@@ -130,11 +131,10 @@ if __name__ == "__main__":
         tf_dict = {k: TF.TF_dict[k] for k in tf_names}
         model = Higher_model(model) #run_dist_dataugV3
-        aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.8, fixed_prob=False, fixed_mag=False, shared_mag=False), model).to(device)
+        aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=2, mix_dist=0.8, fixed_prob=False, fixed_mag=False, shared_mag=False), model).to(device)
         #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)

         print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter))
-        log= run_simple_smartaug(model=aug_model, epochs=epochs, inner_it=n_inner_iter, opt_param=optim_param)
         log= run_dist_dataugV3(model=aug_model,
              epochs=epochs,
              inner_it=n_inner_iter,
@@ -142,7 +142,8 @@
              opt_param=optim_param,
              print_freq=1,
              unsup_loss=1,
-             hp_opt=False)
+             hp_opt=False,
+             save_sample_freq=None)

         exec_time=time.process_time() - t0
         ####
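The new seeding block makes runs repeatable while deliberately keeping cudnn.benchmark on for speed, so results are repeatable in distribution rather than bit-identical. For strictly deterministic runs one would also flip the cudnn flags; a small sketch using standard PyTorch flags (not part of this patch):

import random
import numpy as np
import torch

# What the patch does: seed the main RNGs, keep autotuning on.
torch.manual_seed(0)
np.random.seed(0)

# Stricter alternative (slower): also fix Python's RNG and cudnn behaviour.
random.seed(0)
torch.backends.cudnn.benchmark = False     # no autotuning of conv algorithms
torch.backends.cudnn.deterministic = True  # force deterministic kernels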
diff --git a/higher/smart_aug/train_utils.py b/higher/smart_aug/train_utils.py
index 483ef71..31d42e3 100755
--- a/higher/smart_aug/train_utils.py
+++ b/higher/smart_aug/train_utils.py
@@ -287,13 +287,19 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
                 diffopt.detach_()
                 model['model'].detach_()
                 meta_opt.zero_grad()
+
+            elif not high_grad_track:
+                diffopt.detach_()
+                model['model'].detach_()

         tf = time.process_time()

         if (save_sample_freq and epoch%save_sample_freq==0): #Data sample saving
             try:
                 viz_sample_data(imgs=xs, labels=ys, fig_name='../samples/data_sample_epoch{}_noTF'.format(epoch))
+                model.train()
                 viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='../samples/data_sample_epoch{}'.format(epoch))
+                model.eval()
             except:
                 print("Couldn't save samples epoch "+str(epoch))
                 pass
@@ -315,9 +321,9 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
             "acc": accuracy,
             "time": tf - t0,

-            "mix_dist": model['data_aug']['mix_dist'].item(),
             "param": param,
         }
+        if not model['data_aug']._fixed_mix: data["mix_dist"]=model['data_aug']['mix_dist'].item()
         if hp_opt : data["opt_param"]=[{'lr': p_grp['lr'].item(), 'momentum': p_grp['momentum'].item()} for p_grp in diffopt.param_groups]
         log.append(data)
         #############
diff --git a/higher/smart_aug/utils.py b/higher/smart_aug/utils.py
index 37a21c5..c622338 100755
--- a/higher/smart_aug/utils.py
+++ b/higher/smart_aug/utils.py
@@ -131,6 +131,7 @@ def viz_sample_data(imgs, labels, fig_name='data_sample', weight_labels=None):
         fig_name (string): Relative path where to save the graph. (default: data_sample)
         weight_labels (Tensor): Weights associated with each label. (default: None)
     """
+    sample = imgs[0:25,].permute(0, 2, 3, 1).squeeze().cpu()

     plt.figure(figsize=(10,10))
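Since "mix_dist" is now only logged when the mixing distance is actually learned, code that consumes the JSON logs should treat that key as optional. A small sketch (the "Accuracy"/"Time"/"Log" layout comes from the scripts above; the filename is hypothetical, as real names are built from the run configuration):

import json

# Hypothetical path: actual filenames are built from str(aug_model), epochs, etc.
with open("../res/log/my_run.json") as f:
    res = json.load(f)

print(res["Accuracy"], res["Time"])
for entry in res["Log"]:
    # 'mix_dist' may be absent when data_aug._fixed_mix is True (e.g. RandAug)
    print(entry["acc"], entry["time"], entry.get("mix_dist"))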