Modif interface Data_augv4

This commit is contained in:
Harle, Antoine (Contracteur) 2019-11-08 11:43:11 -05:00
parent 3ae3e02e59
commit 0066da2e4d
2 changed files with 77 additions and 98 deletions

View file

@ -320,31 +320,6 @@ class Data_augV4(nn.Module): #Transformations avec mask
#self._TF_matrix={}
#self._input_info={'h':0, 'w':0, 'device':None} #Input associe a TF_matrix
'''
self._mag_fct={ #f(mag_normalise)=mag_reelle
## Geometric TF ##
'Identity' : (lambda mag: None),
'FlipUD' : (lambda mag: None),
'FlipLR' : (lambda mag: None),
'Rotate': (lambda mag: random.randint(-int_parameter(mag, maxval=30), int_parameter(mag, maxval=30))),
'TranslateX': (lambda mag: [random.randint(-int_parameter(mag, maxval=20), int_parameter(mag, maxval=20)), 0]),
'TranslateY': (lambda mag: [0, random.randint(-int_parameter(mag, maxval=20), int_parameter(mag, maxval=20))]),
'ShearX': (lambda mag: [random.uniform(-float_parameter(mag, maxval=0.3), float_parameter(mag, maxval=0.3)), 0]),
'ShearY': (lambda mag: [0, random.uniform(-float_parameter(mag, maxval=0.3), float_parameter(mag, maxval=0.3))]),
## Color TF (Expect image in the range of [0, 1]) ##
'Contrast': (lambda mag: random.uniform(0.1, float_parameter(mag, maxval=1.9))),
'Color':(lambda mag: random.uniform(0.1, float_parameter(mag, maxval=1.9))),
'Brightness':(lambda mag: random.uniform(1., float_parameter(mag, maxval=1.9))),
'Sharpness':(lambda mag: random.uniform(0.1, float_parameter(mag, maxval=1.9))),
'Posterize': (lambda mag: random.randint(4, int_parameter(mag, maxval=8))),
'Solarize': (lambda mag: random.randint(1, int_parameter(mag, maxval=256))/256.), #=>Image entre [0,1] #Pas opti pour des batch
#Non fonctionnel
'Auto_Contrast': (lambda mag: None), #Pas opti pour des batch (Super lent)
#'Equalize': (lambda mag: None),
}
'''
self._mag_fct = TF_dict
self._TF=list(self._mag_fct.keys())
self._nb_tf= len(self._TF)
@ -380,12 +355,52 @@ class Data_augV4(nn.Module): #Transformations avec mask
self._sample = cat_distrib.sample()
## Transformations ##
#'''
x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles)
x = self.apply_TF(x, self._sample)
return x
'''
def compute_TF_matrix(self, magnitude=None, sample_info= None):
print('Computing TF_matrix...')
if not magnitude :
magnitude=self._fixed_mag
if sample_info:
self._input_info['h']= sample_info['h']
self._input_info['w']= sample_info['w']
self._input_info['device'] = sample_info['device']
h, w, device= self._input_info['h'], self._input_info['w'], self._input_info['device']
self._TF_matrix={}
for tf in self._TF :
if tf=='Id':
self._TF_matrix[tf]=torch.tensor([[[ 1., 0., 0.],
[ 0., 1., 0.],
[ 0., 0., 1.]]], device=device)
elif tf=='Rot':
center = torch.ones(1, 2, device=device)
center[0, 0] = w / 2 # x
center[0, 1] = h / 2 # y
scale = torch.ones(1, device=device)
angle = self._mag_fct[tf](magnitude) * torch.ones(1, device=device)
R = kornia.get_rotation_matrix2d(center, angle, scale) #Rotation matrix (1,2,3)
self._TF_matrix[tf]=torch.cat((R,torch.tensor([[[ 0., 0., 1.]]], device=device)), dim=1) #TF matrix (1,3,3)
elif tf=='FlipLR':
self._TF_matrix[tf]=torch.tensor([[[-1., 0., w-1],
[ 0., 1., 0.],
[ 0., 0., 1.]]], device=device)
elif tf=='FlipUD':
self._TF_matrix[tf]=torch.tensor([[[ 1., 0., 0.],
[ 0., -1., h-1],
[ 0., 0., 1.]]], device=device)
else:
raise Exception("Invalid TF requested")
'''
def apply_TF(self, x, sampled_TF):
device = x.device
smps_x=[]
masks=[]
for tf_idx in range(self._nb_tf):
mask = self._sample==tf_idx #Create selection mask
mask = sampled_TF==tf_idx #Create selection mask
smp_x = x[mask] #torch.masked_select() ?
if smp_x.shape[0]!=0: #if there's data to TF
@ -452,43 +467,7 @@ class Data_augV4(nn.Module): #Transformations avec mask
x=kornia.warp_perspective(x, TF_matrix, dsize=(h, w))
'''
return x
'''
def compute_TF_matrix(self, magnitude=None, sample_info= None):
print('Computing TF_matrix...')
if not magnitude :
magnitude=self._fixed_mag
if sample_info:
self._input_info['h']= sample_info['h']
self._input_info['w']= sample_info['w']
self._input_info['device'] = sample_info['device']
h, w, device= self._input_info['h'], self._input_info['w'], self._input_info['device']
self._TF_matrix={}
for tf in self._TF :
if tf=='Id':
self._TF_matrix[tf]=torch.tensor([[[ 1., 0., 0.],
[ 0., 1., 0.],
[ 0., 0., 1.]]], device=device)
elif tf=='Rot':
center = torch.ones(1, 2, device=device)
center[0, 0] = w / 2 # x
center[0, 1] = h / 2 # y
scale = torch.ones(1, device=device)
angle = self._mag_fct[tf](magnitude) * torch.ones(1, device=device)
R = kornia.get_rotation_matrix2d(center, angle, scale) #Rotation matrix (1,2,3)
self._TF_matrix[tf]=torch.cat((R,torch.tensor([[[ 0., 0., 1.]]], device=device)), dim=1) #TF matrix (1,3,3)
elif tf=='FlipLR':
self._TF_matrix[tf]=torch.tensor([[[-1., 0., w-1],
[ 0., 1., 0.],
[ 0., 0., 1.]]], device=device)
elif tf=='FlipUD':
self._TF_matrix[tf]=torch.tensor([[[ 1., 0., 0.],
[ 0., -1., h-1],
[ 0., 0., 1.]]], device=device)
else:
raise Exception("Invalid TF requested")
'''
def adjust_prob(self, soft=False): #Detach from gradient ?
if soft :

View file

@ -646,8 +646,8 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
tf = time.process_time()
#viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
#viz_sample_data(imgs=aug_model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
viz_sample_data(imgs=aug_model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
if(not high_grad_track):
countcopy+=1
@ -732,7 +732,7 @@ if __name__ == "__main__":
aug_model = Augmented_model(Data_augV4(TF_dict=TF.TF_dict, mix_dist=0.0), LeNet(3,10)).to(device)
print(str(aug_model), 'on', device_name)
#run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, loss_patience=10)
log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=1, loss_patience=10)
####
plot_res(log, fig_name="res/{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter))