class Data_augV5(nn.Module):
    """Mask-based batch data augmentation with a learnable TF distribution.

    For each of the ``N_TF`` sequential augmentation steps, one transformation
    (TF) index is sampled per batch element from a categorical distribution,
    and the matching callable of ``TF_dict`` is applied to the selected
    elements only (selection is done with boolean masks).

    Args:
        TF_dict: mapping ``name -> callable``. Each callable is invoked as
            ``f(x=<selected sub-batch>, mag=<magnitude>)`` and must return the
            transformed sub-batch (same convention as ``Data_augV4``).
        N_TF: number of transformations applied sequentially to each sample.
        mix_dist: mixing factor between the learned probabilities and the
            uniform distribution; ``0.0`` means purely uniform sampling.
            Values are clamped to ``[0, 1]``.
    """

    def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0):
        super(Data_augV5, self).__init__()
        assert len(TF_dict) > 0

        self._data_augmentation = True

        # Name -> callable(x, mag); the key order defines the TF indices.
        self._TF_dict = TF_dict
        self._TF = list(self._TF_dict.keys())
        self._nb_tf = len(self._TF)

        self._N_TF = N_TF

        self._fixed_mag = 5  # [0, PARAMETER_MAX]
        self._params = nn.ParameterDict({
            # Learnable TF probabilities, initialised to the uniform distribution.
            "prob": nn.Parameter(torch.ones(self._nb_tf) / self._nb_tf),
        })

        # TF indices sampled during the last forward pass (one tensor per step).
        self._samples = []

        self._mix_dist = False
        if mix_dist != 0.0:
            self._mix_dist = True
            self._mix_factor = max(min(mix_dist, 1.0), 0.0)

    def forward(self, x):
        """Return an augmented copy of ``x`` (or ``x`` itself if augmentation is off)."""
        if self._data_augmentation:
            device = x.device
            batch_size = x.shape[0]

            # Work on a copy so the caller's samples are never modified by
            # reference (problematic for parallel uses).
            x = copy.deepcopy(x)
            self._samples = []

            # The sampling distribution does not change between the N_TF steps,
            # so it is built once, outside of the loop.
            uniforme_dist = torch.ones(1, self._nb_tf, device=device).softmax(dim=1)
            if not self._mix_dist:
                self._distrib = uniforme_dist
            else:
                # Mix of the learned / uniform distributions, weighted by mix_factor.
                self._distrib = (self._mix_factor * self._params["prob"]
                                 + (1 - self._mix_factor) * uniforme_dist).softmax(dim=1)

            cat_distrib = Categorical(
                probs=torch.ones((batch_size, self._nb_tf), device=device) * self._distrib)

            for _ in range(self._N_TF):
                ## Sampling ##
                sample = cat_distrib.sample()
                self._samples.append(sample)

                ## Transformations ##
                x = self.apply_TF(x, sample)
        return x

    def apply_TF(self, x, sampled_TF):
        """Apply to each element of ``x`` the TF whose index is in ``sampled_TF``.

        ``x`` is modified in place (through mask assignment) and returned.
        """
        for tf_idx, tf in enumerate(self._TF):
            mask = sampled_TF == tf_idx  # Selection mask of the elements using this TF
            smp_x = x[mask]

            if smp_x.shape[0] != 0:  # Only if there is data to transform
                # Dispatch through the TF dictionary, whose callables take
                # (x, mag), exactly as Data_augV4 does.
                x[mask] = self._TF_dict[tf](x=smp_x, mag=self._fixed_mag)
        return x

    def adjust_prob(self, soft=False):
        """Project the learned probabilities back onto a valid distribution.

        Args:
            soft: if True use a softmax (too 'soft': stays stuck on the uniform
                distribution when the learning rate is too small); otherwise
                clip negative values and renormalise so that sum(p) = 1.
        """
        prob = self._params['prob']
        if soft:
            prob.data = F.softmax(prob.data, dim=0)
        else:
            clipped = F.relu(prob.data)
            # Normalise on .data so no autograd graph is built for this projection.
            prob.data = clipped / clipped.sum()

    def loss_weight(self):
        """Return the per-sample loss weights of the last forward pass.

        Every sampled TF contributes ``1/N_TF`` to its column (the order of
        the sequential TFs is assumed negligible); the result is weighted by
        the learned probabilities and divided by the sampling distribution so
        that the loss is not scaled down.
        """
        w_loss = torch.zeros((self._samples[0].shape[0], self._nb_tf),
                             device=self._samples[0].device)
        for sample in self._samples:
            tmp_w = torch.zeros(w_loss.size(), device=w_loss.device)
            tmp_w.scatter_(dim=1, index=sample.view(-1, 1), value=1 / self._N_TF)
            w_loss += tmp_w

        w_loss = w_loss * self._params["prob"] / self._distrib
        w_loss = torch.sum(w_loss, dim=1)
        return w_loss

    def train(self, mode=None):
        """Set the training mode; ``mode=None`` keeps the current augmentation state.

        Returns ``self``, like ``nn.Module.train``.
        """
        if mode is None:
            mode = self._data_augmentation
        self.augment(mode=mode)  # No-op when mode was None
        return super(Data_augV5, self).train(mode)

    def eval(self):
        """Switch to evaluation mode (augmentation disabled) and return ``self``."""
        return self.train(mode=False)

    def augment(self, mode=True):
        """Enable or disable the data augmentation."""
        self._data_augmentation = mode

    def __getitem__(self, key):
        return self._params[key]

    def __str__(self):
        if not self._mix_dist:
            return "Data_augV5(Uniform-%d TF x %d)" % (self._nb_tf, self._N_TF)
        else:
            return "Data_augV5(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_TF)
(dataug:0)- 10 in_it.png b/higher/res/Aug_mod(Data_augV4(Uniform-14 TF x 2)-LeNet)-2 epochs (dataug:0)- 10 in_it.png deleted file mode 100644 index 0e548f2..0000000 Binary files a/higher/res/Aug_mod(Data_augV4(Uniform-14 TF x 2)-LeNet)-2 epochs (dataug:0)- 10 in_it.png and /dev/null differ diff --git a/higher/res/Aug_mod(Data_augV4(Uniform-14 TF x 2)-LeNet)-3 epochs (dataug:0)- 0 in_it.png b/higher/res/Aug_mod(Data_augV4(Uniform-14 TF x 2)-LeNet)-3 epochs (dataug:0)- 0 in_it.png deleted file mode 100644 index 824c495..0000000 Binary files a/higher/res/Aug_mod(Data_augV4(Uniform-14 TF x 2)-LeNet)-3 epochs (dataug:0)- 0 in_it.png and /dev/null differ diff --git a/higher/res/LeNet-3 epochs.png b/higher/res/LeNet-3 epochs.png deleted file mode 100644 index e3ddcf5..0000000 Binary files a/higher/res/LeNet-3 epochs.png and /dev/null differ diff --git a/higher/test_dataug.py b/higher/test_dataug.py index bd0c023..f2b88ab 100644 --- a/higher/test_dataug.py +++ b/higher/test_dataug.py @@ -64,18 +64,18 @@ if __name__ == "__main__": print('-'*9) ''' #### Augmented Model #### - ''' + #''' t0 = time.process_time() tf_dict = {k: TF.TF_dict[k] for k in tf_names} #tf_dict = TF.TF_dict - #aug_model = Augmented_model(Data_augV4(TF_dict=tf_dict, N_TF=2, mix_dist=0.0), LeNet(3,10)).to(device) - aug_model = Augmented_model(Data_augV4(TF_dict=tf_dict, N_TF=2, mix_dist=0.0), WideResNet(num_classes=10, wrn_size=160)).to(device) + aug_model = Augmented_model(Data_augV4(TF_dict=tf_dict, N_TF=2, mix_dist=0.5), LeNet(3,10)).to(device) + #aug_model = Augmented_model(Data_augV4(TF_dict=tf_dict, N_TF=2, mix_dist=0.0), WideResNet(num_classes=10, wrn_size=160)).to(device) print(str(aug_model), 'on', device_name) #run_simple_dataug(inner_it=n_inner_iter, epochs=epochs) log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=1, loss_patience=10) #### - plot_res(log, fig_name="res/{}-{} epochs (dataug:{})- {} 
in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter), param_names=tf_names) + plot_resV2(log, fig_name="res/{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter), param_names=tf_names) print('-'*9) times = [x["time"] for x in log] out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times)), "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log} @@ -86,12 +86,13 @@ if __name__ == "__main__": print('Execution Time : %.00f (s?)'%(time.process_time() - t0)) print('-'*9) - ''' + #''' #### TF number tests #### #''' res_folder="res/TF_nb_tests/" epochs= 100 inner_its = [0, 1, 10] + dist_mix = [0.0, 0.5] dataug_epoch_starts= [0] TF_nb = [len(TF.TF_dict)] #range(10,len(TF.TF_dict)+1) #[len(TF.TF_dict)] N_seq_TF= [2, 3, 4, 6] diff --git a/higher/train_utils.py b/higher/train_utils.py index bf970a6..6ed1e09 100644 --- a/higher/train_utils.py +++ b/higher/train_utils.py @@ -586,8 +586,8 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f logits = fmodel(xs) # modified `params` can also be passed as a kwarg loss = F.cross_entropy(logits, ys, reduction='none') # no need to call loss.backwards() - #PAS PONDERE LOSS POUR DIST MIX - if fmodel._data_augmentation: # and not fmodel['data_aug']._mix_dist: #Weight loss + + if fmodel._data_augmentation: #Weight loss w_loss = fmodel['data_aug'].loss_weight().to(device) loss = loss * w_loss loss = loss.mean() diff --git a/higher/transformations.py b/higher/transformations.py index ec6e29b..cd21eff 100644 --- a/higher/transformations.py +++ b/higher/transformations.py @@ -3,30 +3,54 @@ import kornia import random ### Available TF for Dataug ### +''' TF_dict={ #f(mag_normalise)=mag_reelle ## Geometric TF ## 'Identity' : (lambda mag: None), 'FlipUD' : (lambda mag: None), 'FlipLR' : (lambda mag: None), - 'Rotate': (lambda mag: random.randint(-int_parameter(mag, maxval=30), int_parameter(mag, maxval=30))), - 
'TranslateX': (lambda mag: [random.randint(-int_parameter(mag, maxval=20), int_parameter(mag, maxval=20)), 0]), - 'TranslateY': (lambda mag: [0, random.randint(-int_parameter(mag, maxval=20), int_parameter(mag, maxval=20))]), - 'ShearX': (lambda mag: [random.uniform(-float_parameter(mag, maxval=0.3), float_parameter(mag, maxval=0.3)), 0]), - 'ShearY': (lambda mag: [0, random.uniform(-float_parameter(mag, maxval=0.3), float_parameter(mag, maxval=0.3))]), + 'Rotate': (lambda mag: rand_int(mag,maxval=30)), + 'TranslateX': (lambda mag: [rand_int(mag,maxval=20), 0]), + 'TranslateY': (lambda mag: [0, rand_int(mag,maxval=20)]), + 'ShearX': (lambda mag: [rand_float(mag, maxval=0.3), 0]), + 'ShearY': (lambda mag: [0, rand_float(mag, maxval=0.3)]), ## Color TF (Expect image in the range of [0, 1]) ## - 'Contrast': (lambda mag: random.uniform(0.1, float_parameter(mag, maxval=1.9))), - 'Color':(lambda mag: random.uniform(0.1, float_parameter(mag, maxval=1.9))), - 'Brightness':(lambda mag: random.uniform(1., float_parameter(mag, maxval=1.9))), - 'Sharpness':(lambda mag: random.uniform(0.1, float_parameter(mag, maxval=1.9))), - 'Posterize': (lambda mag: random.randint(4, int_parameter(mag, maxval=8))), - 'Solarize': (lambda mag: random.randint(1, int_parameter(mag, maxval=256))/256.), #=>Image entre [0,1] #Pas opti pour des batch + 'Contrast': (lambda mag: rand_float(mag,minval=0.1, maxval=1.9)), + 'Color':(lambda mag: rand_float(mag,minval=0.1, maxval=1.9)), + 'Brightness':(lambda mag: rand_float(mag,minval=1., maxval=1.9)), + 'Sharpness':(lambda mag: rand_float(mag,minval=0.1, maxval=1.9)), + 'Posterize': (lambda mag: rand_int(mag,minval=4, maxval=8)), + 'Solarize': (lambda mag: rand_int(mag,minval=1, maxval=256)/256.), #=>Image entre [0,1] #Pas opti pour des batch #Non fonctionnel #'Auto_Contrast': (lambda mag: None), #Pas opti pour des batch (Super lent) #'Equalize': (lambda mag: None), } +''' +TF_dict={ + ## Geometric TF ## + 'Identity' : (lambda x, mag: x), + 'FlipUD' : 
(lambda x, mag: flipUD(x)), + 'FlipLR' : (lambda x, mag: flipLR(x)), + 'Rotate': (lambda x, mag: rotate(x, angle=torch.tensor([rand_int(mag, maxval=30)for _ in x], device=x.device))), + 'TranslateX': (lambda x, mag: translate(x, translation=torch.tensor([[rand_int(mag, maxval=20), 0] for _ in x], device=x.device))), + 'TranslateY': (lambda x, mag: translate(x, translation=torch.tensor([[0, rand_int(mag, maxval=20)] for _ in x], device=x.device))), + 'ShearX': (lambda x, mag: shear(x, shear=torch.tensor([[rand_float(mag, maxval=0.3), 0] for _ in x], device=x.device))), + 'ShearY': (lambda x, mag: shear(x, shear=torch.tensor([[0, rand_float(mag, maxval=0.3)] for _ in x], device=x.device))), + ## Color TF (Expect image in the range of [0, 1]) ## + 'Contrast': (lambda x, mag: contrast(x, contrast_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))), + 'Color':(lambda x, mag: color(x, color_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))), + 'Brightness':(lambda x, mag: brightness(x, brightness_factor=torch.tensor([rand_float(mag, minval=1., maxval=1.9) for _ in x], device=x.device))), + 'Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))), + 'Posterize': (lambda x, mag: posterize(x, bits=torch.tensor([rand_int(mag, minval=4, maxval=8) for _ in x], device=x.device))), + 'Solarize': (lambda x, mag: solarize(x, thresholds=torch.tensor([rand_int(mag,minval=1, maxval=256)/256. 
def rand_int(mag, maxval, minval=None):
    """Draw a random integer whose upper bound is scaled by the magnitude.

    Args:
        mag: normalised magnitude, forwarded to ``int_parameter``.
        maxval: value passed to ``int_parameter`` to obtain the upper bound.
        minval: lower bound of the range. When omitted, the range is
            symmetric: ``[-scaled_max, scaled_max]``.

    Returns:
        An integer drawn with ``random.randint(minval, scaled_max)``.

    Raises:
        ValueError: raised by ``random.randint`` if ``minval`` is greater than
            the scaled upper bound.
    """
    real_max = int_parameter(mag, maxval=maxval)
    # `is None` (not truthiness) so that an explicit minval=0 is honoured.
    if minval is None:
        minval = -real_max
    return random.randint(minval, real_max)


def rand_float(mag, maxval, minval=None):
    """Draw a random float whose upper bound is scaled by the magnitude.

    Args:
        mag: normalised magnitude, forwarded to ``float_parameter``.
        maxval: value passed to ``float_parameter`` to obtain the upper bound.
        minval: lower bound of the range. When omitted, the range is
            symmetric: ``[-scaled_max, scaled_max]``.

    Returns:
        A float drawn with ``random.uniform(minval, scaled_max)``.
    """
    real_max = float_parameter(mag, maxval=maxval)
    # `is None` (not truthiness) so that an explicit minval=0 is honoured.
    if minval is None:
        minval = -real_max
    return random.uniform(minval, real_max)
def plot_resV2(log, fig_name='res', param_names=None):
    """Plot a training log on a 2x2 figure and save it to ``fig_name``.

    The figure shows the train/val losses, the accuracy and, when the log
    entries carry a ``"param"`` value, a stack plot of the parameters over
    the epochs plus a bar chart of their mean (with std as error bars).

    Args:
        log: sequence of dicts with the keys ``"epoch"``, ``"train_loss"``,
            ``"val_loss"``, ``"acc"`` and ``"param"``.
        fig_name: output path; every ``.`` is replaced by ``,`` before saving.
        param_names: labels of the parameters; defaults to ``P0, P1, ...``.
    """
    epochs = [x["epoch"] for x in log]

    fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(15, 15))

    ax[0, 0].set_title('Loss')
    ax[0, 0].plot(epochs, [x["train_loss"] for x in log], label='Train')
    ax[0, 0].plot(epochs, [x["val_loss"] for x in log], label='Val')
    ax[0, 0].legend()

    ax[0, 1].set_title('Acc')
    ax[0, 1].plot(epochs, [x["acc"] for x in log])

    # Identity test: `!= None` would be an elementwise comparison if "param"
    # happens to be an array.
    if log[0]["param"] is not None:
        params = [x["param"] for x in log]  # One entry per epoch
        if not param_names:
            param_names = ['P' + str(idx) for idx, _ in enumerate(params[0])]

        ax[1, 1].set_title('Prob')
        # Transpose to one series per parameter, as expected by stackplot.
        proba = [[p[idx] for p in params] for idx, _ in enumerate(params[0])]
        ax[1, 1].stackplot(epochs, proba, labels=param_names)
        ax[1, 1].legend(param_names, loc='center left', bbox_to_anchor=(1, 0.5))

        ax[1, 0].set_title('Mean prob')
        mean = np.mean(params, axis=0)
        std = np.std(params, axis=0)
        ax[1, 0].bar(param_names, mean, yerr=std)
        plt.sca(ax[1, 0])
        plt.xticks(rotation=90)

    fig_name = fig_name.replace('.', ',')
    plt.savefig(fig_name, bbox_inches='tight')
    plt.close()