Initial Dataugv5 implementation

Harle, Antoine (Contracteur) 2019-11-14 21:42:00 -05:00
parent 103277fadd
commit 05f81787d6
3 changed files with 31 additions and 58 deletions

View file

@@ -327,7 +327,7 @@ class Data_augV4(nn.Module): #Transformations with mask
         self._TF= list(self._TF_dict.keys())
         self._nb_tf= len(self._TF)
-        self._N_TF = N_TF
+        self._N_seqTF = N_TF
         self._fixed_mag=5 #[0, PARAMETER_MAX]
         self._params = nn.ParameterDict({
@@ -349,7 +349,7 @@ class Data_augV4(nn.Module): #Transformations with mask
         x = copy.deepcopy(x) #Avoid modifying the samples by reference (problematic for parallel use)
         self._samples = []
-        for _ in range(self._N_TF):
+        for _ in range(self._N_seqTF):
             ## Sampling ##
             uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1)
@@ -501,7 +501,7 @@ class Data_augV4(nn.Module): #Transformations with mask
         w_loss = torch.zeros((self._samples[0].shape[0],self._nb_tf), device=self._samples[0].device)
         for sample in self._samples:
             tmp_w = torch.zeros(w_loss.size(),device=w_loss.device)
-            tmp_w.scatter_(dim=1, index=sample.view(-1,1), value=1/self._N_TF)
+            tmp_w.scatter_(dim=1, index=sample.view(-1,1), value=1/self._N_seqTF)
             w_loss += tmp_w
         w_loss = w_loss * self._params["prob"]/self._distrib #Weighted by the probabilities (divided by the sampling distribution so as not to shrink the loss)
@@ -526,28 +526,27 @@ class Data_augV4(nn.Module): #Transformations with mask
     def __str__(self):
         if not self._mix_dist:
-            return "Data_augV4(Uniform-%d TF x %d)" % (self._nb_tf, self._N_TF)
+            return "Data_augV4(Uniform-%d TF x %d)" % (self._nb_tf, self._N_seqTF)
         else:
-            return "Data_augV4(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_TF)
+            return "Data_augV4(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_seqTF)

-class Data_augV5(nn.Module): #Transformations with mask
-    def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0):
+class Data_augV5(nn.Module): #Joint optimisation (mag, prob)
+    def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, glob_mag=True):
         super(Data_augV5, self).__init__()
         assert len(TF_dict)>0
         self._data_augmentation = True
-        #self._TF_matrix={}
-        #self._input_info={'h':0, 'w':0, 'device':None} #Input associated with TF_matrix
-        self._mag_fct = TF_dict
-        self._TF=list(self._mag_fct.keys())
+        self._TF_dict = TF_dict
+        self._TF= list(self._TF_dict.keys())
         self._nb_tf= len(self._TF)
-        self._N_TF = N_TF
+        self._N_seqTF = N_TF
-        self._fixed_mag=5 #[0, PARAMETER_MAX]
+        #self._fixed_mag=5 #[0, PARAMETER_MAX]
         self._params = nn.ParameterDict({
             "prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Uniform prob distribution
+            "mag" : nn.Parameter(torch.tensor(0.5).expand(self._nb_tf) if glob_mag else torch.tensor(0.5).repeat(self._nb_tf)) #[0, PARAMETER_MAX]/10
         })
         self._samples = []
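Note on the new "mag" parameter: with glob_mag=True the tensor is built with expand, which returns a view whose entries all alias a single underlying scalar (one magnitude shared by every TF), while repeat allocates independent per-TF values. A minimal standalone sketch of that storage difference, using plain tensors outside nn.Parameter; how gradients and optimizer updates behave once the expanded view is wrapped in nn.Parameter is not shown in this diff.

    import torch

    n_tf = 3
    shared = torch.tensor(0.5).expand(n_tf)   # view with stride (0,): all entries alias one scalar
    per_tf = torch.tensor(0.5).repeat(n_tf)   # real copy: n_tf independent values

    per_tf[1] = 0.9
    print(per_tf)            # tensor([0.5000, 0.9000, 0.5000])
    print(shared.stride())   # (0,) -> a single shared element behind the 3 entries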
@@ -565,7 +564,7 @@ class Data_augV5(nn.Module): #Transformations with mask
         x = copy.deepcopy(x) #Avoid modifying the samples by reference (problematic for parallel use)
         self._samples = []
-        for _ in range(self._N_TF):
+        for _ in range(self._N_seqTF):
             ## Sampling ##
             uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1)
@@ -591,44 +590,12 @@ class Data_augV5(nn.Module): #Transformations with mask
                 smp_x = x[mask] #torch.masked_select() ?
                 if smp_x.shape[0]!=0: #if there's data to TF
-                    magnitude=self._fixed_mag
+                    magnitude=self._params["mag"][tf_idx]*10
                     tf=self._TF[tf_idx]
-                    #print(magnitude)
-                    ## Geometric TF ##
-                    if tf=='Identity':
-                        pass
-                    elif tf=='FlipLR':
-                        smp_x = TF.flipLR(smp_x)
-                    elif tf=='FlipUD':
-                        smp_x = TF.flipUD(smp_x)
-                    elif tf=='Rotate':
-                        smp_x = TF.rotate(smp_x, angle=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
-                    elif tf=='TranslateX' or tf=='TranslateY':
-                        smp_x = TF.translate(smp_x, translation=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
-                    elif tf=='ShearX' or tf=='ShearY' :
-                        smp_x = TF.shear(smp_x, shear=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
-                    ## Color TF (Expect image in the range of [0, 1]) ##
-                    elif tf=='Contrast':
-                        smp_x = TF.contrast(smp_x, contrast_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
-                    elif tf=='Color':
-                        smp_x = TF.color(smp_x, color_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
-                    elif tf=='Brightness':
-                        smp_x = TF.brightness(smp_x, brightness_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
-                    elif tf=='Sharpness':
-                        smp_x = TF.sharpeness(smp_x, sharpness_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
-                    elif tf=='Posterize':
-                        smp_x = TF.posterize(smp_x, bits=torch.tensor([1 for _ in smp_x], device=device))
-                    elif tf=='Solarize':
-                        smp_x = TF.solarize(smp_x, thresholds=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
-                    elif tf=='Equalize':
-                        smp_x = TF.equalize(smp_x)
-                    elif tf=='Auto_Contrast':
-                        smp_x = TF.auto_contrast(smp_x)
-                    else:
-                        raise Exception("Invalid TF requested : ", tf)
-                    x[mask]=smp_x # Merge back (avoid x[mask]: in place)
+                    x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude) # Merge back (avoid x[mask]: in place)
         return x
     def adjust_prob(self, soft=False): #Detach from gradient ?
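The per-name if/elif dispatch is replaced by a single dictionary lookup: self._TF_dict is expected to map a TF name to a callable taking the selected samples and a magnitude. The learned magnitude self._params["mag"][tf_idx] is multiplied by 10, which matches the #[0, PARAMETER_MAX]/10 comment and suggests PARAMETER_MAX is 10 here, though that constant is not shown in this diff. Below is a minimal sketch of the dispatch pattern with hypothetical callables; the real TF module in this repository defines its own transformation functions and magnitude scaling.

    import torch

    def identity(x, mag):
        return x

    def flip_lr(x, mag):
        return x.flip(-1)          # flip along the width dimension

    TF_dict = {'Identity': identity, 'FlipLR': flip_lr}

    x = torch.rand(4, 3, 32, 32)   # toy batch of images
    tf = 'FlipLR'
    x = TF_dict[tf](x=x, mag=0.5)  # one lookup replaces the whole if/elif chain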
@@ -636,9 +603,14 @@ class Data_augV5(nn.Module): #Transformations with mask
         if soft :
             self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Too 'soft': gets stuck at the uniform distribution if the lr is too small
         else:
-            #self._params['prob'].clamp(min=0.0,max=1.0)
             self._params['prob'].data = F.relu(self._params['prob'].data)
+            #self._params['prob'].data = self._params['prob'].clamp(min=0.0,max=1.0)
             self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Constraint: sum(p)=1
     def loss_weight(self):
         # Single TF only
         #self._sample = self._samples[-1]
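adjust_prob keeps the learned probabilities a valid distribution after an optimizer step: either a softmax (soft) or a ReLU followed by renormalisation so that sum(p)=1. A standalone sketch of the two projections on a raw probability vector:

    import torch
    import torch.nn.functional as F

    prob = torch.tensor([0.7, -0.1, 0.4])   # raw values after a gradient step

    soft = F.softmax(prob, dim=0)           # smooth: always strictly positive

    hard = F.relu(prob)                     # clip negative probabilities to 0
    hard = hard / hard.sum()                # renormalise so the entries sum to 1

    print(soft, soft.sum())                 # sums to 1
    print(hard, hard.sum())                 # tensor([0.6364, 0.0000, 0.3636]) tensor(1.)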
@@ -647,11 +619,11 @@ class Data_augV5(nn.Module): #Transformations with mask
         #w_loss = w_loss * self._params["prob"]/self._distrib #Weighted by the probabilities (divided by the sampling distribution so as not to shrink the loss)
         #w_loss = torch.sum(w_loss,dim=1)
-        #Several sequential TFs (assumes their order is negligible)
+        #Several sequential TFs
         w_loss = torch.zeros((self._samples[0].shape[0],self._nb_tf), device=self._samples[0].device)
         for sample in self._samples:
             tmp_w = torch.zeros(w_loss.size(),device=w_loss.device)
-            tmp_w.scatter_(dim=1, index=sample.view(-1,1), value=1/self._N_TF)
+            tmp_w.scatter_(dim=1, index=sample.view(-1,1), value=1/self._N_seqTF)
             w_loss += tmp_w
         w_loss = w_loss * self._params["prob"]/self._distrib #Weighted by the probabilities (divided by the sampling distribution so as not to shrink the loss)
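loss_weight builds a (batch, nb_tf) weight matrix: for each of the N_seqTF sampled steps, scatter_ writes 1/N_seqTF into the column of the TF drawn for each image, so each row ends up holding the fraction of that image's TF sequence spent in each transformation. A small self-contained example of the accumulation with toy sample indices; the final weighting by prob/self._distrib is left out since _distrib is not defined in this diff.

    import torch

    batch, nb_tf, N_seqTF = 4, 3, 2
    samples = [torch.tensor([0, 1, 2, 1]),    # TF index drawn per image, step 1
               torch.tensor([2, 1, 0, 1])]    # TF index drawn per image, step 2

    w_loss = torch.zeros(batch, nb_tf)
    for sample in samples:
        tmp_w = torch.zeros_like(w_loss)
        # put 1/N_seqTF in the column of the TF sampled for each image
        tmp_w.scatter_(dim=1, index=sample.view(-1, 1), value=1 / N_seqTF)
        w_loss += tmp_w

    print(w_loss)   # each row sums to 1; the last image drew TF 1 twice -> [0., 1., 0.]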
@@ -676,9 +648,9 @@ class Data_augV5(nn.Module): #Transformations with mask
     def __str__(self):
         if not self._mix_dist:
-            return "Data_augV5(Uniform-%d TF x %d)" % (self._nb_tf, self._N_TF)
+            return "Data_augV5(Uniform-%d TF x %d)" % (self._nb_tf, self._N_seqTF)
         else:
-            return "Data_augV5(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_TF)
+            return "Data_augV5(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_seqTF)

 class Augmented_model(nn.Module):

View file

@@ -68,7 +68,7 @@ if __name__ == "__main__":
     t0 = time.process_time()
     tf_dict = {k: TF.TF_dict[k] for k in tf_names}
     #tf_dict = TF.TF_dict
-    aug_model = Augmented_model(Data_augV4(TF_dict=tf_dict, N_TF=2, mix_dist=0.5), LeNet(3,10)).to(device)
+    aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=2, mix_dist=0.5), LeNet(3,10)).to(device)
     #aug_model = Augmented_model(Data_augV4(TF_dict=tf_dict, N_TF=2, mix_dist=0.0), WideResNet(num_classes=10, wrn_size=160)).to(device)
     print(str(aug_model), 'on', device_name)
     #run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
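Augmented_model is only used here, not defined in this diff; from its usage it wraps the augmentation module and a classifier, is printable, and is indexable by name (model['data_aug']['prob'] appears later in the training code). A purely hypothetical sketch of what such a wrapper could look like, not the repository's actual implementation:

    import torch.nn as nn

    class AugmentedModelSketch(nn.Module):
        """Hypothetical wrapper: augment a batch, then feed it to the classifier."""
        def __init__(self, data_aug, model):
            super().__init__()
            self._mods = nn.ModuleDict({'data_aug': data_aug, 'model': model})

        def forward(self, x):
            return self._mods['model'](self._mods['data_aug'](x))

        def __getitem__(self, key):
            # enables model['data_aug'] style access; the inner ['prob'] lookup seen in
            # the logs suggests the augmentation module exposes its parameters the same way
            return self._mods[key]

        def __str__(self):
            return "Augmented_model(%s, %s)" % (self._mods['data_aug'], self._mods['model'])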
@@ -88,7 +88,7 @@ if __name__ == "__main__":
     print('-'*9)
     #'''
     #### TF number tests ####
-    #'''
+    '''
     res_folder="res/TF_nb_tests/"
     epochs= 100
     inner_its = [0, 1, 10]
@@ -128,4 +128,4 @@ if __name__ == "__main__":
         print('Log :\"',f.name, '\" saved !')
         print('-'*9)
-    #'''
+    '''
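Changing the opening and closing #''' into ''' turns the whole "TF number tests" block into an unassigned triple-quoted string, which Python evaluates and discards, so the block is disabled without deleting it; restoring the # re-enables it. The pattern in isolation (the call inside is only illustrative and never runs):

    '''
    # Everything between the unassigned triple quotes is a string literal, not executed code.
    run_tf_number_tests()
    '''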

View file

@@ -649,6 +649,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
         print('Data Augmention : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start))
         print('TF Proba :', model['data_aug']['prob'].data)
         #print('proba grad',aug_model['data_aug']['prob'].grad)
+        print('TF Mag :', model['data_aug']['mag'].data)
         #############
         #### Log ####
         data={