Mirror of https://github.com/AntoineHX/smart_augmentation.git
Synced 2025-05-04 20:20:46 +02:00
Modif interface Data_augv4 (interface rework of Data_augV4)
This commit is contained in:
parent 3ae3e02e59
commit 0066da2e4d

2 changed files with 77 additions and 98 deletions

higher/dataug.py (169 lines changed)
@@ -320,31 +320,6 @@ class Data_augV4(nn.Module): #Transformations with mask

         #self._TF_matrix={}
         #self._input_info={'h':0, 'w':0, 'device':None} #Input associated with TF_matrix
-        '''
-        self._mag_fct={ #f(normalized_mag) = actual_mag
-            ## Geometric TF ##
-            'Identity' : (lambda mag: None),
-            'FlipUD' : (lambda mag: None),
-            'FlipLR' : (lambda mag: None),
-            'Rotate': (lambda mag: random.randint(-int_parameter(mag, maxval=30), int_parameter(mag, maxval=30))),
-            'TranslateX': (lambda mag: [random.randint(-int_parameter(mag, maxval=20), int_parameter(mag, maxval=20)), 0]),
-            'TranslateY': (lambda mag: [0, random.randint(-int_parameter(mag, maxval=20), int_parameter(mag, maxval=20))]),
-            'ShearX': (lambda mag: [random.uniform(-float_parameter(mag, maxval=0.3), float_parameter(mag, maxval=0.3)), 0]),
-            'ShearY': (lambda mag: [0, random.uniform(-float_parameter(mag, maxval=0.3), float_parameter(mag, maxval=0.3))]),
-
-            ## Color TF (Expect image in the range of [0, 1]) ##
-            'Contrast': (lambda mag: random.uniform(0.1, float_parameter(mag, maxval=1.9))),
-            'Color': (lambda mag: random.uniform(0.1, float_parameter(mag, maxval=1.9))),
-            'Brightness': (lambda mag: random.uniform(1., float_parameter(mag, maxval=1.9))),
-            'Sharpness': (lambda mag: random.uniform(0.1, float_parameter(mag, maxval=1.9))),
-            'Posterize': (lambda mag: random.randint(4, int_parameter(mag, maxval=8))),
-            'Solarize': (lambda mag: random.randint(1, int_parameter(mag, maxval=256))/256.), #=> Image in [0,1] #Not optimized for batches
-
-            #Not functional
-            'Auto_Contrast': (lambda mag: None), #Not optimized for batches (very slow)
-            #'Equalize': (lambda mag: None),
-        }
-        '''
         self._mag_fct = TF_dict
         self._TF = list(self._mag_fct.keys())
         self._nb_tf = len(self._TF)
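Aside: the inline magnitude dictionary deleted above is replaced by the TF_dict imported from the transformations module (aliased TF in this file). A minimal sketch of the structure TF_dict is assumed to keep, reconstructed from the deleted entries; int_parameter below is a hypothetical stand-in for the scaling helper the original code relies on:

    import random

    def int_parameter(level, maxval, max_level=10):
        # Hypothetical stand-in for the original scaling helper (assumed AutoAugment-style).
        return int(level * maxval / max_level)

    # Each entry maps a TF name to f(normalized_mag) -> actual magnitude,
    # mirroring the inline dict removed by this commit.
    TF_dict = {
        'Identity': (lambda mag: None),
        'FlipLR': (lambda mag: None),
        'Rotate': (lambda mag: random.randint(-int_parameter(mag, maxval=30), int_parameter(mag, maxval=30))),
        # ... remaining entries follow the same pattern as the deleted block above.
    }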
@@ -380,77 +355,8 @@ class Data_augV4(nn.Module): #Transformations with mask
         self._sample = cat_distrib.sample()

         ## Transformations ##
-        #'''
         x = copy.deepcopy(x) #Avoid modifying the samples by reference (problematic for parallel use)
-        smps_x=[]
+        x = self.apply_TF(x, self._sample)
-        masks=[]
-        for tf_idx in range(self._nb_tf):
-            mask = self._sample==tf_idx #Create selection mask
-            smp_x = x[mask] #torch.masked_select() ?
-
-            if smp_x.shape[0]!=0: #if there's data to TF
-                magnitude=self._fixed_mag
-                tf=self._TF[tf_idx]
-
-                ## Geometric TF ##
-                if tf=='Identity':
-                    pass
-                elif tf=='FlipLR':
-                    smp_x = TF.flipLR(smp_x)
-                elif tf=='FlipUD':
-                    smp_x = TF.flipUD(smp_x)
-                elif tf=='Rotate':
-                    smp_x = TF.rotate(smp_x, angle=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
-                elif tf=='TranslateX' or tf=='TranslateY':
-                    smp_x = TF.translate(smp_x, translation=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
-                elif tf=='ShearX' or tf=='ShearY' :
-                    smp_x = TF.shear(smp_x, shear=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
-
-                ## Color TF (Expect image in the range of [0, 1]) ##
-                elif tf=='Contrast':
-                    smp_x = TF.contrast(smp_x, contrast_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
-                elif tf=='Color':
-                    smp_x = TF.color(smp_x, color_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
-                elif tf=='Brightness':
-                    smp_x = TF.brightness(smp_x, brightness_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
-                elif tf=='Sharpness':
-                    smp_x = TF.sharpeness(smp_x, sharpness_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
-                elif tf=='Posterize':
-                    smp_x = TF.posterize(smp_x, bits=torch.tensor([1 for _ in smp_x], device=device))
-                elif tf=='Solarize':
-                    smp_x = TF.solarize(smp_x, thresholds=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
-                elif tf=='Equalize':
-                    smp_x = TF.equalize(smp_x)
-                elif tf=='Auto_Contrast':
-                    smp_x = TF.auto_contrast(smp_x)
-                else:
-                    raise Exception("Invalid TF requested : ", tf)
-
-                x[mask]=smp_x # Merge back, avoiding x[mask] : in place
-
-            #idx= mask.nonzero()
-            #print('-'*8)
-            #print(idx[0], tf_idx)
-            #print(smp_x[0,])
-            #x=x.view(-1,3*32*32)
-            #x=x.scatter(dim=0, index=idx, src=smp_x.view(-1,3*32*32)) #Tensors change but it is not visible in the visualization...
-            #x=x.view(-1,3,32,32)
-            #print(x[0,])
-
-        '''
-        if len(self._TF_matrix)==0 or self._input_info['h']!=h or self._input_info['w']!=w or self._input_info['device']!=device: #Different device: no need to recompute everything
-            self.compute_TF_matrix(sample_info={'h': x.shape[2],
-                                                'w': x.shape[3],
-                                                'device': x.device})
-
-        TF_matrix = torch.zeros(batch_size, 3, 3, device=device) #All geom TF
-
-        for tf_idx in range(self._nb_tf):
-            mask = self._sample==tf_idx #Create selection mask
-            TF_matrix[mask,]=self._TF_matrix[self._TF[tf_idx]]
-
-        x=kornia.warp_perspective(x, TF_matrix, dsize=(h, w))
-        '''
         return x
     '''
     def compute_TF_matrix(self, magnitude=None, sample_info= None):
@@ -489,6 +395,79 @@ class Data_augV4(nn.Module): #Transformations with mask
         else:
             raise Exception("Invalid TF requested")
     '''
+    def apply_TF(self, x, sampled_TF):
+        device = x.device
+        smps_x=[]
+        masks=[]
+        for tf_idx in range(self._nb_tf):
+            mask = sampled_TF==tf_idx #Create selection mask
+            smp_x = x[mask] #torch.masked_select() ?
+
+            if smp_x.shape[0]!=0: #if there's data to TF
+                magnitude=self._fixed_mag
+                tf=self._TF[tf_idx]
+
+                ## Geometric TF ##
+                if tf=='Identity':
+                    pass
+                elif tf=='FlipLR':
+                    smp_x = TF.flipLR(smp_x)
+                elif tf=='FlipUD':
+                    smp_x = TF.flipUD(smp_x)
+                elif tf=='Rotate':
+                    smp_x = TF.rotate(smp_x, angle=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
+                elif tf=='TranslateX' or tf=='TranslateY':
+                    smp_x = TF.translate(smp_x, translation=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
+                elif tf=='ShearX' or tf=='ShearY' :
+                    smp_x = TF.shear(smp_x, shear=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
+
+                ## Color TF (Expect image in the range of [0, 1]) ##
+                elif tf=='Contrast':
+                    smp_x = TF.contrast(smp_x, contrast_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
+                elif tf=='Color':
+                    smp_x = TF.color(smp_x, color_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
+                elif tf=='Brightness':
+                    smp_x = TF.brightness(smp_x, brightness_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
+                elif tf=='Sharpness':
+                    smp_x = TF.sharpeness(smp_x, sharpness_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
+                elif tf=='Posterize':
+                    smp_x = TF.posterize(smp_x, bits=torch.tensor([1 for _ in smp_x], device=device))
+                elif tf=='Solarize':
+                    smp_x = TF.solarize(smp_x, thresholds=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
+                elif tf=='Equalize':
+                    smp_x = TF.equalize(smp_x)
+                elif tf=='Auto_Contrast':
+                    smp_x = TF.auto_contrast(smp_x)
+                else:
+                    raise Exception("Invalid TF requested : ", tf)
+
+                x[mask]=smp_x # Merge back, avoiding x[mask] : in place
+
+            #idx= mask.nonzero()
+            #print('-'*8)
+            #print(idx[0], tf_idx)
+            #print(smp_x[0,])
+            #x=x.view(-1,3*32*32)
+            #x=x.scatter(dim=0, index=idx, src=smp_x.view(-1,3*32*32)) #Tensors change but it is not visible in the visualization...
+            #x=x.view(-1,3,32,32)
+            #print(x[0,])
+
+        '''
+        if len(self._TF_matrix)==0 or self._input_info['h']!=h or self._input_info['w']!=w or self._input_info['device']!=device: #Different device: no need to recompute everything
+            self.compute_TF_matrix(sample_info={'h': x.shape[2],
+                                                'w': x.shape[3],
+                                                'device': x.device})
+
+        TF_matrix = torch.zeros(batch_size, 3, 3, device=device) #All geom TF
+
+        for tf_idx in range(self._nb_tf):
+            mask = self._sample==tf_idx #Create selection mask
+            TF_matrix[mask,]=self._TF_matrix[self._TF[tf_idx]]
+
+        x=kornia.warp_perspective(x, TF_matrix, dsize=(h, w))
+        '''
+        return x

     def adjust_prob(self, soft=False): #Detach from gradient ?

         if soft :
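The apply_TF(x, sampled_TF) method added above takes a batch and one sampled TF index per image, selects each sub-batch with a boolean mask, applies the corresponding transformation, and writes the result back. A minimal usage sketch, assuming the repository layout (higher/dataug.py, a transformations module aliased TF) and using an explicit uniform Categorical in place of the distribution built in forward():

    import torch
    import transformations as TF        # assumed module providing TF_dict
    from dataug import Data_augV4       # assumed import path (higher/dataug.py)

    aug = Data_augV4(TF_dict=TF.TF_dict, mix_dist=0.0)
    x = torch.rand(8, 3, 32, 32)        # images in [0, 1], as the color TF expect
    probs = torch.ones(aug._nb_tf) / aug._nb_tf
    sampled_TF = torch.distributions.Categorical(probs=probs).sample((x.shape[0],))
    x_aug = aug.apply_TF(x, sampled_TF) # one sampled TF applied to each image of the batch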
@@ -646,8 +646,8 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
             tf = time.process_time()

-            #viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
+            viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
-            #viz_sample_data(imgs=aug_model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
+            viz_sample_data(imgs=aug_model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))

             if(not high_grad_track):
                 countcopy+=1
@ -732,7 +732,7 @@ if __name__ == "__main__":
|
||||||
aug_model = Augmented_model(Data_augV4(TF_dict=TF.TF_dict, mix_dist=0.0), LeNet(3,10)).to(device)
|
aug_model = Augmented_model(Data_augV4(TF_dict=TF.TF_dict, mix_dist=0.0), LeNet(3,10)).to(device)
|
||||||
print(str(aug_model), 'on', device_name)
|
print(str(aug_model), 'on', device_name)
|
||||||
#run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
|
#run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
|
||||||
log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, loss_patience=10)
|
log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=1, loss_patience=10)
|
||||||
|
|
||||||
####
|
####
|
||||||
plot_res(log, fig_name="res/{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter))
|
plot_res(log, fig_name="res/{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter))
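For context, a minimal sketch of exercising the augmented model above on a dummy batch, outside run_dist_dataugV2 (import paths, the LeNet location, and the behaviour of Augmented_model's own forward are assumptions; only the constructor call and the ['data_aug'] indexing are taken from this diff):

    import torch
    from dataug import Data_augV4, Augmented_model  # assumed import path (higher/dataug.py)
    from model import LeNet                          # assumed location of LeNet
    import transformations as TF

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    aug_model = Augmented_model(Data_augV4(TF_dict=TF.TF_dict, mix_dist=0.0), LeNet(3, 10)).to(device)

    xs = torch.rand(8, 3, 32, 32, device=device)     # dummy CIFAR-sized batch
    xs_aug = aug_model['data_aug'](xs)               # augmentation alone, as used by viz_sample_data
    out = aug_model(xs)                              # assumed: forward chains augmentation and LeNet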