From 994d657a2808efa5a1a718026982179d4722d81b Mon Sep 17 00:00:00 2001
From: "Harle, Antoine (Contracteur)"
Date: Mon, 18 Nov 2019 12:53:23 -0500
Subject: [PATCH] Dataugv5 - Modify the TFs to propagate the gradient (mag)

---
 higher/dataug.py          | 22 +++++++++---
 higher/test_dataug.py     | 10 +++---
 higher/train_utils.py     |  7 ++--
 higher/transformations.py | 74 ++++++++++++++++++++++++++++++++++-----
 higher/utils.py           |  2 +-
 5 files changed, 94 insertions(+), 21 deletions(-)

diff --git a/higher/dataug.py b/higher/dataug.py
index 63ddaf2..f2b1d6a 100644
--- a/higher/dataug.py
+++ b/higher/dataug.py
@@ -583,19 +583,33 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)
     def apply_TF(self, x, sampled_TF):
         device = x.device
+        batch_size, channels, h, w = x.shape
         smps_x=[]
-        masks=[]
+
         for tf_idx in range(self._nb_tf):
             mask = sampled_TF==tf_idx #Create selection mask
-            smp_x = x[mask] #torch.masked_select() ?
+            smp_x = x[mask] #torch.masked_select()? (Requires expanding the mask to the same dims)
 
             if smp_x.shape[0]!=0: #if there's data to TF
                 magnitude=self._params["mag"][tf_idx]*10
                 tf=self._TF[tf_idx]
                 #print(magnitude)
 
-                x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude) # Re-merge; avoid in-place x[mask]
-
+                #x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude) # Re-merge; avoid in-place x[mask]
+                smp_x = self._TF_dict[tf](x=smp_x, mag=magnitude)
+
+                idx = mask.nonzero()
+                #print('-'*8)
+                idx = idx.expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #There must be a simpler way...
+                #print(idx.shape, smp_x.shape)
+                #print(idx[0], tf_idx)
+                #print(smp_x[0,])
+                #x=x.view(-1,3*32*32)
+                #smp_x=smp_x.view(-1,3*32*32)
+                x = x.scatter(dim=0, index=idx, src=smp_x)
+                #x=x.view(-1,3,32,32)
+                #print(x[0,])
+
         return x
 
     def adjust_prob(self, soft=False): #Detach from gradient ?
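The `scatter` above is the core of this patch: transformed samples are merged back out-of-place, so the autograd graph from `self._params["mag"]` survives, whereas the old in-place `x[mask] = ...` assignment cut it. A minimal standalone sketch of the same index-expansion trick (toy shapes, and a multiply standing in for a real TF; not the repo's code):

```python
import torch

batch_size, channels, h, w = 4, 3, 8, 8
x = torch.rand(batch_size, channels, h, w)
mag = torch.tensor(0.5, requires_grad=True)   # learnable magnitude

sampled_TF = torch.tensor([0, 1, 0, 1])       # one TF index per sample
mask = sampled_TF == 1                        # samples assigned to TF 1
smp_x = x[mask] * mag                         # stand-in for a differentiable TF

# Expand the (N,1) batch indices of the selected samples to (N,C,H,W)
idx = mask.nonzero()
idx = idx.expand(-1, channels).unsqueeze(dim=2).expand(-1, channels, h)
idx = idx.unsqueeze(dim=3).expand(-1, channels, h, w)

out = x.scatter(dim=0, index=idx, src=smp_x)  # out-of-place merge, autograd-safe
out.sum().backward()
print(mag.grad)                               # non-None: the gradient reached mag
```

A broadcast mask with `torch.where(mask_bcast, transformed, x)` would likely achieve the same merge with less index bookkeeping, which may be the "simpler way" the comment asks for.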
diff --git a/higher/test_dataug.py b/higher/test_dataug.py
index 3485f35..1fb9849 100644
--- a/higher/test_dataug.py
+++ b/higher/test_dataug.py
@@ -5,9 +5,9 @@ from train_utils import *
 
 tf_names = [
     ## Geometric TF ##
-    'Identity',
-    'FlipUD',
-    'FlipLR',
+    #'Identity',
+    #'FlipUD',
+    #'FlipLR',
     'Rotate',
     'TranslateX',
     'TranslateY',
@@ -37,7 +37,7 @@ else:
 ##########################################
 if __name__ == "__main__":
 
-    n_inner_iter = 10
+    n_inner_iter = 1
     epochs = 2
     dataug_epoch_start=0
 
@@ -68,7 +68,7 @@ if __name__ == "__main__":
     t0 = time.process_time()
 
     tf_dict = {k: TF.TF_dict[k] for k in tf_names}
     #tf_dict = TF.TF_dict
-    aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=2, mix_dist=0.5), LeNet(3,10)).to(device)
+    aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=1, mix_dist=0.5, glob_mag=False), LeNet(3,10)).to(device)
     #aug_model = Augmented_model(Data_augV4(TF_dict=tf_dict, N_TF=2, mix_dist=0.0), WideResNet(num_classes=10, wrn_size=160)).to(device)
     print(str(aug_model), 'on', device_name)
     #run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
diff --git a/higher/train_utils.py b/higher/train_utils.py
index 97c2c81..e6e1a94 100644
--- a/higher/train_utils.py
+++ b/higher/train_utils.py
@@ -623,8 +623,8 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
         tf = time.process_time()
 
-        #viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
-        #viz_sample_data(imgs=aug_model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
+        viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
+        viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
 
         if(not high_grad_track):
             countcopy+=1
@@ -648,8 +648,9 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
         print('Accuracy :', accuracy)
         print('Data Augmention : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start))
         print('TF Proba :', model['data_aug']['prob'].data)
-        #print('proba grad',aug_model['data_aug']['prob'].grad)
+        #print('proba grad',model['data_aug']['prob'].grad)
         print('TF Mag :', model['data_aug']['mag'].data)
+        print('Mag grad',model['data_aug']['mag'].grad)
         #############
         #### Log ####
         data={
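The new `print('Mag grad', ...)` line is the sanity check this patch is built around: `mag.grad` stays `None` unless every TF in the pipeline keeps the graph alive. A tiny standalone version of that check (the parameter name and the toy brightness-style TF are illustrative, not the repo's API):

```python
import torch
import torch.nn as nn

mag = nn.Parameter(torch.tensor(0.5))  # learnable magnitude, as in Data_augV5
x = torch.rand(8, 3, 32, 32)

out = x * (1.0 + mag)                  # stand-in for a differentiable color TF
out.mean().backward()

print('Mag grad', mag.grad)            # None here would mean the graph was cut
```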
diff --git a/higher/transformations.py b/higher/transformations.py
index cd21eff..bf17cfe 100644
--- a/higher/transformations.py
+++ b/higher/transformations.py
@@ -28,6 +28,7 @@ TF_dict={ #f(normalized_mag)=real_mag
     #'Equalize': (lambda mag: None),
 }
 '''
+'''
 TF_dict={
     ## Geometric TF ##
     'Identity' : (lambda x, mag: x),
@@ -42,7 +43,7 @@ TF_dict={
     ## Color TF (Expect image in the range of [0, 1]) ##
     'Contrast': (lambda x, mag: contrast(x, contrast_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))),
     'Color':(lambda x, mag: color(x, color_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))),
-    'Brightness':(lambda x, mag: brightness(x, brightness_factor=torch.tensor([rand_float(mag, minval=1., maxval=1.9) for _ in x], device=x.device))),
+    'Brightness':(lambda x, mag: brightness(x, brightness_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))),
     'Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))),
     'Posterize': (lambda x, mag: posterize(x, bits=torch.tensor([rand_int(mag, minval=4, maxval=8) for _ in x], device=x.device))),
     'Solarize': (lambda x, mag: solarize(x, thresholds=torch.tensor([rand_int(mag,minval=1, maxval=256)/256. for _ in x], device=x.device))) , #=>Image in [0,1] #Not optimized for batches
@@ -51,6 +52,27 @@ TF_dict={
     #'Auto_Contrast': (lambda mag: None), #Not optimized for batches (very slow)
     #'Equalize': (lambda mag: None),
 }
+'''
+TF_dict={
+    ## Geometric TF ##
+    'Identity' : (lambda x, mag: x),
+    'FlipUD' : (lambda x, mag: flipUD(x)),
+    'FlipLR' : (lambda x, mag: flipLR(x)),
+    'Rotate': (lambda x, mag: rotate(x, angle=rand_float(size=x.shape[0], mag=mag, maxval=30))),
+    'TranslateX': (lambda x, mag: translate(x, translation=zero_stack(rand_float(size=(x.shape[0],), mag=mag, maxval=20), zero_pos=0))),
+    'TranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_float(size=(x.shape[0],), mag=mag, maxval=20), zero_pos=1))),
+    'ShearX': (lambda x, mag: shear(x, shear=zero_stack(rand_float(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=0))),
+    'ShearY': (lambda x, mag: shear(x, shear=zero_stack(rand_float(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=1))),
+
+    ## Color TF (Expect image in the range of [0, 1]) ##
+    'Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_float(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
+    'Color':(lambda x, mag: color(x, color_factor=rand_float(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
+    'Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_float(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
+    'Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_float(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
+    'Posterize': (lambda x, mag: posterize(x, bits=rand_float(size=x.shape[0], mag=mag, minval=4., maxval=8.))), #Loses the gradient
+    'Solarize': (lambda x, mag: solarize(x, thresholds=rand_float(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Loses the gradient #=>Image in [0,1] #Not optimized for batches
+
+}
 
 def int_image(float_image): #WARNING: slight loss of info (granularity: 1/256 = 0.0039)
     return (float_image*255.).type(torch.uint8)
@@ -71,6 +93,19 @@ def rand_float(mag, maxval, minval=None): #[(-maxval,minval), maxval]
     if not minval : minval = -real_max
     return random.uniform(minval, real_max)
 
+def rand_float(size, mag, maxval, minval=None): #[(-maxval,minval), maxval]
+    real_max = float_parameter(mag, maxval=maxval)
+    if not minval : minval = -real_max
+    #return random.uniform(minval, real_max)
+    return minval +(real_max-minval) * torch.rand(size, device=mag.device)
+
+def zero_stack(tensor, zero_pos):
+    if zero_pos==0:
+        return torch.stack((tensor, torch.zeros((tensor.shape[0],), device=tensor.device)), dim=1)
+    if zero_pos==1:
+        return torch.stack((torch.zeros((tensor.shape[0],), device=tensor.device), tensor), dim=1)
+    else:
+        raise Exception("Invalid zero_pos : ", zero_pos)
+
 #https://github.com/tensorflow/models/blob/fc2056bce6ab17eabdc139061fef8f4f2ee763ec/research/autoaugment/augmentation_transforms.py#L137
 PARAMETER_MAX = 10  # What is the max 'level' a transform could be predicted
@@ -83,7 +118,9 @@ def float_parameter(level, maxval):
   Returns:
     A float that results from scaling `maxval` according to `level`.
   """
-  return float(level) * maxval / PARAMETER_MAX
+
+  #return float(level) * maxval / PARAMETER_MAX
+  return (level * maxval / PARAMETER_MAX)#.to(torch.float32)
 
 def int_parameter(level, maxval):
   """Helper function to scale `val` between 0 and maxval .
@@ -94,7 +131,11 @@ def int_parameter(level, maxval):
   Returns:
     An int that results from scaling `maxval` according to `level`.
   """
-  return int(level * maxval / PARAMETER_MAX)
+  #return int(level * maxval / PARAMETER_MAX)
+  print(level)
+  res = (level * maxval / PARAMETER_MAX).to(torch.int8).requires_grad_()#.type(torch.int8)
+  print(res)
+  return res
 
 def flipLR(x):
     device = x.device
@@ -119,10 +160,11 @@ def flipUD(x):
     return kornia.warp_perspective(x, M, dsize=(h, w))
 
 def rotate(x, angle):
-    return kornia.rotate(x, angle=angle.type(torch.float32)) #Kornia does not support ints
+    return kornia.rotate(x, angle=angle)#.type(torch.float32)) #Kornia does not support ints
 
 def translate(x, translation):
-    return kornia.translate(x, translation=translation.type(torch.float32)) #Kornia does not support ints
+    #print(translation)
+    return kornia.translate(x, translation=translation)#.type(torch.float32)) #Kornia does not support ints
 
 def shear(x, shear):
     return kornia.shear(x, shear=shear)
@@ -156,6 +198,7 @@ def sharpeness(x, sharpness_factor):
 
 #https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py
 def posterize(x, bits):
+    bits = bits.type(torch.uint8) #Loses the gradient
     x = int_image(x) #Expect image in the range of [0, 1]
 
     mask = ~(2 ** (8 - bits) - 1).type(torch.uint8)
@@ -217,10 +260,25 @@ def equalize(x): #NOT OPTIMIZED FOR BATCHES
 
 def solarize(x, thresholds): #NOT OPTIMIZED FOR BATCHES
     # Optimization: mask all the data at once (Mask = (B,C,H,W) > (B))
+    batch_size, channels, h, w = x.shape
+    imgs=[]
     for idx, t in enumerate(thresholds): #Per-image operation
-        mask = x[idx] > t.item()
-        inv_x = 1-x[idx][mask]
-        x[idx][mask]=inv_x
+        mask = x[idx] > t #Loses the gradient
+        #In place
+        #inv_x = 1-x[idx][mask]
+        #x[idx][mask]=inv_x
+        #
+
+        #Out of place
+        im = x[idx]
+        inv_x = 1-im[mask]
+
+        imgs.append(im.masked_scatter(mask,inv_x))
+
+    idxs=torch.tensor(range(x.shape[0]), device=x.device)
+    idxs=idxs.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #There must be a simpler way...
+    x=x.scatter(dim=0, index=idxs, src=torch.stack(imgs))
+    #
 
     return x
 
#https://github.com/python-pillow/Pillow/blob/9c78c3f97291bd681bc8637922d6a2fa9415916c/src/PIL/Image.py#L2818
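The batched `rand_float` is what makes the geometric TFs differentiable in `mag`: it keeps the scaled bound as a tensor instead of calling `random.uniform`, and `zero_stack` turns the samples into `(dx, 0)` / `(0, dy)` pairs for `kornia.translate` and `kornia.shear`. A standalone sketch of the same logic (both helpers re-implemented here for the demo; the patch's `rand_float` goes through `float_parameter`):

```python
import torch

PARAMETER_MAX = 10  # max 'level' a transform can take, as in the AutoAugment code

def rand_float(size, mag, maxval, minval=None):  # [(-maxval|minval), maxval]
    real_max = mag * maxval / PARAMETER_MAX      # float_parameter, kept as a tensor
    if minval is None:
        minval = -real_max
    return minval + (real_max - minval) * torch.rand(size, device=mag.device)

def zero_stack(tensor, zero_pos):
    zeros = torch.zeros_like(tensor)
    pair = (tensor, zeros) if zero_pos == 0 else (zeros, tensor)
    return torch.stack(pair, dim=1)              # (N, 2): (dx, 0) or (0, dy)

mag = torch.tensor(5.0, requires_grad=True)
translation = zero_stack(rand_float(size=(4,), mag=mag, maxval=20), zero_pos=0)
translation.sum().backward()
print(mag.grad)  # the sampled translations carry a gradient back to mag
```

`posterize` and `solarize` still break this chain, as the `#Loses the gradient` comments record: the `uint8` cast and the threshold comparison are not differentiable.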
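For `solarize`, the same in-place problem is solved per image with `masked_scatter`. A reduced sketch (fixed thresholds, toy shapes), with `torch.stack` standing in for the final scatter since the loop already visits every image:

```python
import torch

x = torch.rand(2, 3, 4, 4, requires_grad=True)
thresholds = torch.tensor([0.3, 0.7])   # one threshold per image

imgs = []
for idx, t in enumerate(thresholds):    # per-image operation
    im = x[idx]
    mask = im > t                        # the comparison itself has no gradient
    imgs.append(im.masked_scatter(mask, 1 - im[mask]))  # out-of-place inversion

out = torch.stack(imgs)
out.sum().backward()                     # succeeds: x was never modified in place
print(x.grad.shape)
```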
diff --git a/higher/utils.py b/higher/utils.py
index b0973d9..02f89be 100644
--- a/higher/utils.py
+++ b/higher/utils.py
@@ -170,7 +170,7 @@ def viz_sample_data(imgs, labels, fig_name='data_sample'):
         plt.xticks([])
         plt.yticks([])
         plt.grid(False)
-        plt.imshow(sample[i,], cmap=plt.cm.binary)
+        plt.imshow(sample[i,].detach().numpy(), cmap=plt.cm.binary)
         plt.xlabel(labels[i].item())
 
     plt.savefig(fig_name)
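The `utils.py` change follows from the rest of the patch: once augmented batches carry an autograd graph, matplotlib needs an explicit `detach()` (and, if the tensors can live on GPU, likely a `.cpu()` as well) before conversion to NumPy. A minimal illustration:

```python
import torch
import matplotlib.pyplot as plt

# CHW image that is part of a graph, as the augmented samples now are
sample = torch.rand(3, 32, 32, requires_grad=True).permute(1, 2, 0)

# sample.numpy() would raise: can't call numpy() on a tensor that requires grad
plt.imshow(sample.detach().numpy(), cmap=plt.cm.binary)
plt.savefig('data_sample')
```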