From 3ec99bf729f9726d9a65f0345b585fbe3c4448c3 Mon Sep 17 00:00:00 2001 From: "Harle, Antoine (Contracteur)" Date: Mon, 2 Dec 2019 06:37:19 -0500 Subject: [PATCH] Modif Dataugv6 --- higher/compare_res.py | 12 +-- higher/datasets.py | 104 +++++++++++++++++++- higher/dataug.py | 202 ++++++++++++++++++++++++++++++++++++++ higher/test_dataug.py | 40 ++++---- higher/train_utils.py | 3 +- higher/transformations.py | 9 +- 6 files changed, 334 insertions(+), 36 deletions(-) diff --git a/higher/compare_res.py b/higher/compare_res.py index 02dfd0c..6b910b1 100644 --- a/higher/compare_res.py +++ b/higher/compare_res.py @@ -2,12 +2,12 @@ from utils import * if __name__ == "__main__": - #''' + ''' files=[ #"res/good_TF_tests/log/Aug_mod(Data_augV5(Mix0.5-14TFx2-MagFxSh)-LeNet)-100 epochs (dataug:0)- 0 in_it.json", #"res/good_TF_tests/log/Aug_mod(Data_augV5(Uniform-14TFx2-MagFxSh)-LeNet)-100 epochs (dataug:0)- 0 in_it.json", #"res/brutus-tests/log/Aug_mod(Data_augV5(Mix0.5-14TFx1-Mag)-LeNet)-150epochs(dataug:0)-1in_it-0.json", - "res/log/Aug_mod(RandAugUDA(18TFx2-Mag1)-LeNet)-100 epochs (dataug:0)- 0 in_it.json", + #"res/log/Aug_mod(RandAugUDA(18TFx2-Mag1)-LeNet)-100 epochs (dataug:0)- 0 in_it.json", ] for idx, file in enumerate(files): @@ -16,7 +16,7 @@ if __name__ == "__main__": data = json.load(json_file) plot_resV2(data['Log'], fig_name=file.replace('.json','').replace('log/',''), param_names=data['Param_names']) #plot_TF_influence(data['Log'], param_names=data['Param_names']) - #''' + ''' ## Loss , Acc, Proba = f(epoch) ## #plot_compare(filenames=files, fig_name="res/compare") @@ -76,11 +76,11 @@ if __name__ == "__main__": ''' #Res print - ''' + #''' nb_run=3 accs = [] times = [] - files = ["res/brutus-tests/log/Aug_mod(Data_augV5(Mix1.0-14TFx2-MagFxSh)-LeNet)-150epochs(dataug:0)-10in_it-%s.json"%str(run) for run in range(nb_run)] + files = ["res/brutus-tests/log/Aug_mod(Data_augV5(Mix1-14TFx4-Mag)-LeNet)-150epochs(dataug:0)-1in_it-%s.json"%str(run) for run in range(nb_run)] for idx, file in enumerate(files): #legend+=str(idx)+'-'+file+'\n' @@ -90,4 +90,4 @@ if __name__ == "__main__": times.append(data['Time'][0]) print(files[0], np.mean(accs), np.std(accs), np.mean(times)) - ''' \ No newline at end of file + #''' \ No newline at end of file diff --git a/higher/datasets.py b/higher/datasets.py index 4eb834b..78ab2e2 100644 --- a/higher/datasets.py +++ b/higher/datasets.py @@ -28,9 +28,105 @@ data_test = torchvision.datasets.MNIST( "./data", train=False, download=True, transform=torchvision.transforms.ToTensor() ) ''' -data_train = torchvision.datasets.CIFAR10( - "./data", train=True, download=True, transform=transform -) + +from torchvision.datasets.vision import VisionDataset +from PIL import Image +import augmentation_transforms +import numpy as np + +class AugmentedDataset(VisionDataset): + def __init__(self, root, train=True, transform=None, target_transform=None, download=False): + + super(AugmentedDataset, self).__init__(root, transform=transform, target_transform=target_transform) + + supervised_dataset = torchvision.datasets.CIFAR10(root, train=train, download=download, transform=transform) + + self.sup_data = supervised_dataset.data + self.sup_targets = supervised_dataset.targets + + for idx, img in enumerate(self.sup_data): + self.sup_data[idx]= Image.fromarray(img) #to PIL Image + + self.unsup_data=[] + self.unsup_targets=[] + + self.data= self.sup_data + self.targets= self.sup_targets + + + self._TF = [ + 'Invert', 'Cutout', 'Sharpness', 'AutoContrast', 'Posterize', + 'ShearX', 'TranslateX', 'TranslateY', 'ShearY', 'Rotate', + 'Equalize', 'Contrast', 'Color', 'Solarize', 'Brightness'] + self._op_list =[] + self.prob=0.5 + for tf in self._TF: + for mag in range(1, 10): + self._op_list+=[(tf, self.prob, mag)] + self._nb_op = len(self._op_list) + + def __getitem__(self, index): + """ + Args: + index (int): Index + + Returns: + tuple: (image, target) where target is index of the target class. + """ + img, target = self.data[index], self.targets[index] + + # doing this so that it is consistent with all other datasets + # to return a PIL Image + #img = Image.fromarray(img) + + if self.transform is not None: + img = self.transform(img) + + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target + + def augement_data(self, aug_copy=1): + + policies = [] + for op_1 in self._op_list: + for op_2 in self._op_list: + policies += [[op_1, op_2]] + + for idx, image in enumerate(self.sup_data): + for _ in range(aug_copy): + chosen_policy = policies[np.random.choice(len(policies))] + aug_image = augmentation_transforms.apply_policy(chosen_policy, image) + #aug_image = augmentation_transforms.cutout_numpy(aug_image) + + self.unsup_data+=[aug_image] + self.unsup_targets+=[self.sup_targets[idx]] + + print(type(self.data), type(self.sup_data), type(self.unsup_data)) + print(len(self.data), len(self.sup_data), len(self.unsup_data)) + #self.data= self.sup_data+self.unsup_data + self.data= np.concatenate((self.sup_data, self.unsup_data), axis=0) + print(len(self.data)) + self.targets= self.sup_targets+self.unsup_targets + + + def len_supervised(self): + return len(self.sup_data) + + def len_unsupervised(self): + return len(self.unsup_data) + + def __len__(self): + return len(self.data) + + +data_train = torchvision.datasets.CIFAR10("./data", train=True, download=True, transform=transform) +#print(len(data_train)) +#data_train = AugmentedDataset("./data", train=True, download=True, transform=transform) +#print(len(data_train), data_train.len_supervised(), data_train.len_unsupervised()) +#data_train.augement_data() +#print(len(data_train), data_train.len_supervised(), data_train.len_unsupervised()) #data_val = torchvision.datasets.CIFAR10( # "./data", train=True, download=True, transform=transform #) @@ -45,4 +141,4 @@ val_subset_indices=range(int(len(data_train)/2),len(data_train)) dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices)) dl_val = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(val_subset_indices)) -dl_test = torch.utils.data.DataLoader(data_test, batch_size=TEST_SIZE, shuffle=False) \ No newline at end of file +dl_test = torch.utils.data.DataLoader(data_test, batch_size=TEST_SIZE, shuffle=False) diff --git a/higher/dataug.py b/higher/dataug.py index 538f4ab..2e9b99c 100644 --- a/higher/dataug.py +++ b/higher/dataug.py @@ -692,6 +692,208 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba) else: return "Data_augV5(Mix%.1f%s-%dTFx%d-%s)" % (self._mix_factor,dist_param, self._nb_tf, self._N_seqTF, mag_param) +import numpy as np +class Data_augV6(nn.Module): #Optimisation sequentielle + def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True): + super(Data_augV6, self).__init__() + assert len(TF_dict)>0 + + self._data_augmentation = True + + self._TF_dict = TF_dict + self._TF= list(self._TF_dict.keys()) + self._nb_tf= len(self._TF) + + self._N_seqTF = N_TF + self._shared_mag = shared_mag + self._fixed_mag = fixed_mag + + self._TF_set_size=3 + #if self._TF_set_size>self._nb_tf: + # print("Warning : TF sets size higher than number of TF. Reducing set size to %d"%self._nb_tf) + # self._TF_set_size=self._nb_tf + assert self._nb_tf>=self._TF_set_size + self._TF_sets=[] + for i in range(1,self._nb_tf): + for j in range(i,self._nb_tf): + if i!=j: + self._TF_sets+=[torch.tensor([0, i, j])] + #print(self._TF_sets) + #self._TF_sets=[torch.tensor([0, i, j]) for i in range(1,self._nb_tf)] #All VS Identity + self._TF_schedule = [list(range(len(self._TF_sets))) for _ in range(self._N_seqTF)] + for n_tf in range(self._N_seqTF) : + TF.random.shuffle(self._TF_schedule[n_tf]) + #print(self._TF_schedule) + self._current_TF_idx=0 #random.randint + self._start_prob = 1/self._TF_set_size + + + self._params = nn.ParameterDict({ + "prob": nn.Parameter(torch.tensor(self._start_prob).expand(self._nb_tf)), #Proba independantes + "mag" : nn.Parameter(torch.tensor(float(TF.PARAMETER_MAX)) if self._shared_mag + else torch.tensor(float(TF.PARAMETER_MAX)).expand(self._nb_tf)), #[0, PARAMETER_MAX] + }) + + #for t in TF.TF_no_mag: self._params['mag'][self._TF.index(t)].data-=self._params['mag'][self._TF.index(t)].data #Mag inutile pour les TF ignore_mag + + #Distribution + self._fixed_prob=fixed_prob + self._samples = [] + self._mix_dist = False + if mix_dist != 0.0: + self._mix_dist = True + self._mix_factor = max(min(mix_dist, 1.0), 0.0) + + #Mag regularisation + if not self._fixed_mag: + if self._shared_mag : + self._reg_tgt = torch.tensor(TF.PARAMETER_MAX, dtype=torch.float) #Encourage amplitude max + else: + self._reg_mask=[self._TF.index(t) for t in self._TF if t not in TF.TF_ignore_mag] + self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage amplitude max + + def forward(self, x): + self._samples = [] + if self._data_augmentation:# and TF.random.random() < 0.5: + device = x.device + batch_size, h, w = x.shape[0], x.shape[2], x.shape[3] + + x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles) + + for n_tf in range(self._N_seqTF): + + tf_set = self._TF_sets[self._TF_schedule[n_tf][self._current_TF_idx]].to(device) + #print(n_tf, tf_set) + ## Echantillonage ## + uniforme_dist = torch.ones(1,len(tf_set),device=device).softmax(dim=1) + + if not self._mix_dist: + self._distrib = uniforme_dist + else: + prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"] + curr_prob = torch.index_select(prob, 0, tf_set) + curr_prob = curr_prob /sum(curr_prob) #Contrainte sum(p)=1 + self._distrib = (self._mix_factor*curr_prob+(1-self._mix_factor)*uniforme_dist).softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor + + cat_distrib= Categorical(probs=torch.ones((batch_size, len(tf_set)), device=device)*self._distrib) + sample = cat_distrib.sample() + self._samples.append(sample) + + ## Transformations ## + x = self.apply_TF(x, sample) + return x + + def apply_TF(self, x, sampled_TF): + device = x.device + batch_size, channels, h, w = x.shape + smps_x=[] + + for sel_idx, tf_idx in enumerate(self._TF_sets[self._current_TF_idx]): + mask = sampled_TF==sel_idx #Create selection mask + smp_x = x[mask] #torch.masked_select() ? (NEcessite d'expand le mask au meme dim) + + if smp_x.shape[0]!=0: #if there's data to TF + magnitude=self._params["mag"] if self._shared_mag else self._params["mag"][tf_idx] + if self._fixed_mag: magnitude=magnitude.detach() #Fmodel tente systematiquement de tracker les gradient de tout les param + + tf=self._TF[tf_idx] + #print(magnitude) + + #In place + #x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude) + + #Out of place + smp_x = self._TF_dict[tf](x=smp_x, mag=magnitude) + idx= mask.nonzero() + idx= idx.expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ... + x=x.scatter(dim=0, index=idx, src=smp_x) + + return x + + def adjust_param(self, soft=False): #Detach from gradient ? + if not self._fixed_prob: + if soft : + self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Trop 'soft', bloque en dist uniforme si lr trop faible + else: + self._params['prob'].data = F.relu(self._params['prob'].data) + #self._params['prob'].data = self._params['prob'].clamp(min=0.0,max=1.0) + #self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1 + + self._params['prob'].data[0]=self._start_prob #Fixe p identite + + if not self._fixed_mag: + #self._params['mag'].data = self._params['mag'].data.clamp(min=0.0,max=TF.PARAMETER_MAX) #Bloque une fois au extreme + self._params['mag'].data = F.relu(self._params['mag'].data) - F.relu(self._params['mag'].data - TF.PARAMETER_MAX) + + def loss_weight(self): #A verifier + if len(self._samples)==0 : return 1 #Pas d'echantillon = pas de ponderation + + prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"] + + #Plusieurs TF sequentielles (Attention ne prend pas en compte ordre !) + w_loss = torch.zeros((self._samples[0].shape[0],self._TF_set_size), device=self._samples[0].device) + for n_tf in range(self._N_seqTF): + tmp_w = torch.zeros(w_loss.size(),device=w_loss.device) + tmp_w.scatter_(dim=1, index=self._samples[n_tf].view(-1,1), value=1/self._N_seqTF) + + tf_set = self._TF_sets[self._TF_schedule[n_tf][self._current_TF_idx]].to(prob.device) + curr_prob = torch.index_select(prob, 0, tf_set) + curr_prob = curr_prob /sum(curr_prob) #Contrainte sum(p)=1 + + #ATTENTION DISTRIB DIFFERENTE AVEC MIX + assert not self._mix_dist + w_loss += tmp_w * curr_prob /self._distrib #Ponderation par les proba (divisee par la distrib pour pas diminuer la loss) + + w_loss = torch.sum(w_loss,dim=1) + return w_loss + + def reg_loss(self, reg_factor=0.005): + if self._fixed_mag: # or self._fixed_prob: #Pas de regularisation si trop peu de DOF + return torch.tensor(0) + else: + #return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean') + params = self._params['mag'] if self._params['mag'].shape==torch.Size([]) else self._params['mag'][self._reg_mask] + return reg_factor * F.mse_loss(params, target=self._reg_tgt.to(params.device), reduction='mean') + + def next_TF_set(self, idx=None): + if idx: + self._current_TF_idx=idx + else: + self._current_TF_idx+=1 + if self._current_TF_idx== len(self._TF_schedule[0]): + self._current_TF_idx=0 + #for n_tf in range(self._N_seqTF) : + # TF.random.shuffle(self._TF_schedule[n_tf]) + #print(self._TF_schedule) + #print("Current TF :",self._TF_sets[self._current_TF_idx]) + + def train(self, mode=None): + if mode is None : + mode=self._data_augmentation + self.augment(mode=mode) #Inutile si mode=None + super(Data_augV6, self).train(mode) + + def eval(self): + self.train(mode=False) + + def augment(self, mode=True): + self._data_augmentation=mode + + def __getitem__(self, key): + return self._params[key] + + def __str__(self): + dist_param='' + if self._fixed_prob: dist_param+='Fx' + mag_param='Mag' + if self._fixed_mag: mag_param+= 'Fx' + if self._shared_mag: mag_param+= 'Sh' + if not self._mix_dist: + return "Data_augV6(Uniform%s-%dTF(%d)x%d-%s)" % (dist_param, self._nb_tf, self._TF_set_size, self._N_seqTF, mag_param) + else: + return "Data_augV6(Mix%.1f%s-%dTF(%d)x%d-%s)" % (self._mix_factor,dist_param, self._nb_tf, self._TF_set_size, self._N_seqTF, mag_param) + + class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + rapide def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mag=TF.PARAMETER_MAX): super(RandAug, self).__init__() diff --git a/higher/test_dataug.py b/higher/test_dataug.py index d443aa0..12a5653 100644 --- a/higher/test_dataug.py +++ b/higher/test_dataug.py @@ -5,21 +5,21 @@ from train_utils import * tf_names = [ ## Geometric TF ## - #'Identity', - #'FlipUD', - #'FlipLR', - #'Rotate', - #'TranslateX', - #'TranslateY', - #'ShearX', - #'ShearY', + 'Identity', + 'FlipUD', + 'FlipLR', + 'Rotate', + 'TranslateX', + 'TranslateY', + 'ShearX', + 'ShearY', ## Color TF (Expect image in the range of [0, 1]) ## - #'Contrast', - #'Color', - #'Brightness', - #'Sharpness', - #'Posterize', + 'Contrast', + 'Color', + 'Brightness', + 'Sharpness', + 'Posterize', 'Solarize', #=>Image entre [0,1] #Pas opti pour des batch #Color TF (Common mag scale) @@ -44,10 +44,10 @@ tf_names = [ #'BadTranslateY', #'BadTranslateY_neg', - #'BadColor', - #'BadSharpness', - #'BadContrast', - #'BadBrightness', + 'BadColor', + 'BadSharpness', + 'BadContrast', + 'BadBrightness', #Non fonctionnel #'Auto_Contrast', #Pas opti pour des batch (Super lent) @@ -65,7 +65,7 @@ else: if __name__ == "__main__": n_inner_iter = 10 - epochs = 1 + epochs = 100 dataug_epoch_start=0 #### Classic #### @@ -95,12 +95,12 @@ if __name__ == "__main__": t0 = time.process_time() tf_dict = {k: TF.TF_dict[k] for k in tf_names} #tf_dict = TF.TF_dict - aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=1, mix_dist=0.0, fixed_prob=True, fixed_mag=False, shared_mag=True), LeNet(3,10)).to(device) + aug_model = Augmented_model(Data_augV6(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True), LeNet(3,10)).to(device) #aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=2, mix_dist=0.5, fixed_mag=True, shared_mag=True), WideResNet(num_classes=10, wrn_size=160)).to(device) #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), LeNet(3,10)).to(device) print(str(aug_model), 'on', device_name) #run_simple_dataug(inner_it=n_inner_iter, epochs=epochs) - log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=1, loss_patience=None) + log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, loss_patience=None) #### print('-'*9) diff --git a/higher/train_utils.py b/higher/train_utils.py index 0c3c750..f5d6a39 100644 --- a/higher/train_utils.py +++ b/higher/train_utils.py @@ -618,6 +618,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f meta_opt.step() model['data_aug'].adjust_param(soft=False) #Contrainte sum(proba)=1 + model['data_aug'].next_TF_set() fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True) diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track) @@ -651,7 +652,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f print('TF Proba :', model['data_aug']['prob'].data) #print('proba grad',model['data_aug']['prob'].grad) print('TF Mag :', model['data_aug']['mag'].data) - print('Mag grad',model['data_aug']['mag'].grad) + #print('Mag grad',model['data_aug']['mag'].grad) #print('Reg loss:', model['data_aug'].reg_loss().item()) ############# #### Log #### diff --git a/higher/transformations.py b/higher/transformations.py index cefe253..82a8d9e 100644 --- a/higher/transformations.py +++ b/higher/transformations.py @@ -46,7 +46,7 @@ TF_dict={ #Dataugv5 #AutoAugment 'Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))), 'Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))), 'Posterize': (lambda x, mag: posterize(x, bits=rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))),#Perte du gradient - 'Solarize': (lambda x, mag: solarize(x, thresholds=rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Perte du gradient #=>Image entre [0,1] #Pas opti pour des batch + 'Solarize': (lambda x, mag: solarize(x, thresholds=rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Perte du gradient #=>Image entre [0,1] #Non fonctionnel #'Auto_Contrast': (lambda mag: None), #Pas opti pour des batch (Super lent) @@ -70,7 +70,7 @@ TF_dict={ #Dataugv5 'Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))), 'Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))), 'Posterize': (lambda x, mag: posterize(x, bits=rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))),#Perte du gradient - 'Solarize': (lambda x, mag: solarize(x, thresholds=rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Perte du gradient #=>Image entre [0,1] #Pas opti pour des batch + 'Solarize': (lambda x, mag: solarize(x, thresholds=rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Perte du gradient #=>Image entre [0,1] #Color TF (Common mag scale) '+Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))), @@ -82,7 +82,7 @@ TF_dict={ #Dataugv5 '-Brightness':(lambda x, mag: brightness(x, brightness_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))), '-Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))), '=Posterize': (lambda x, mag: posterize(x, bits=invScale_rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))),#Perte du gradient - '=Solarize': (lambda x, mag: solarize(x, thresholds=invScale_rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Perte du gradient #=>Image entre [0,1] #Pas opti pour des batch + '=Solarize': (lambda x, mag: solarize(x, thresholds=invScale_rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Perte du gradient #=>Image entre [0,1] 'BRotate': (lambda x, mag: rotate(x, angle=rand_floats(size=x.shape[0], mag=mag, maxval=30*3))), @@ -295,8 +295,7 @@ def equalize(x): #PAS OPTIMISE POUR DES BATCH return float_image(x) -def solarize(x, thresholds): #PAS OPTIMISE POUR DES BATCH - # Optimisation : Mask direct sur toute les donnees (Mask = (B,C,H,W)> (B)) +def solarize(x, thresholds): batch_size, channels, h, w = x.shape #imgs=[] #for idx, t in enumerate(thresholds): #Operation par image