mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Debut implementation Dataugv5
This commit is contained in:
parent
103277fadd
commit
05f81787d6
3 changed files with 31 additions and 58 deletions
|
@ -327,7 +327,7 @@ class Data_augV4(nn.Module): #Transformations avec mask
|
||||||
self._TF= list(self._TF_dict.keys())
|
self._TF= list(self._TF_dict.keys())
|
||||||
self._nb_tf= len(self._TF)
|
self._nb_tf= len(self._TF)
|
||||||
|
|
||||||
self._N_TF = N_TF
|
self._N_seqTF = N_TF
|
||||||
|
|
||||||
self._fixed_mag=5 #[0, PARAMETER_MAX]
|
self._fixed_mag=5 #[0, PARAMETER_MAX]
|
||||||
self._params = nn.ParameterDict({
|
self._params = nn.ParameterDict({
|
||||||
|
@ -349,7 +349,7 @@ class Data_augV4(nn.Module): #Transformations avec mask
|
||||||
x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles)
|
x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles)
|
||||||
self._samples = []
|
self._samples = []
|
||||||
|
|
||||||
for _ in range(self._N_TF):
|
for _ in range(self._N_seqTF):
|
||||||
## Echantillonage ##
|
## Echantillonage ##
|
||||||
uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1)
|
uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1)
|
||||||
|
|
||||||
|
@ -501,7 +501,7 @@ class Data_augV4(nn.Module): #Transformations avec mask
|
||||||
w_loss = torch.zeros((self._samples[0].shape[0],self._nb_tf), device=self._samples[0].device)
|
w_loss = torch.zeros((self._samples[0].shape[0],self._nb_tf), device=self._samples[0].device)
|
||||||
for sample in self._samples:
|
for sample in self._samples:
|
||||||
tmp_w = torch.zeros(w_loss.size(),device=w_loss.device)
|
tmp_w = torch.zeros(w_loss.size(),device=w_loss.device)
|
||||||
tmp_w.scatter_(dim=1, index=sample.view(-1,1), value=1/self._N_TF)
|
tmp_w.scatter_(dim=1, index=sample.view(-1,1), value=1/self._N_seqTF)
|
||||||
w_loss += tmp_w
|
w_loss += tmp_w
|
||||||
|
|
||||||
w_loss = w_loss * self._params["prob"]/self._distrib #Ponderation par les proba (divisee par la distrib pour pas diminuer la loss)
|
w_loss = w_loss * self._params["prob"]/self._distrib #Ponderation par les proba (divisee par la distrib pour pas diminuer la loss)
|
||||||
|
@ -526,28 +526,27 @@ class Data_augV4(nn.Module): #Transformations avec mask
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
if not self._mix_dist:
|
if not self._mix_dist:
|
||||||
return "Data_augV4(Uniform-%d TF x %d)" % (self._nb_tf, self._N_TF)
|
return "Data_augV4(Uniform-%d TF x %d)" % (self._nb_tf, self._N_seqTF)
|
||||||
else:
|
else:
|
||||||
return "Data_augV4(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_TF)
|
return "Data_augV4(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_seqTF)
|
||||||
|
|
||||||
class Data_augV5(nn.Module): #Transformations avec mask
|
class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
||||||
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0):
|
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, glob_mag=True):
|
||||||
super(Data_augV5, self).__init__()
|
super(Data_augV5, self).__init__()
|
||||||
assert len(TF_dict)>0
|
assert len(TF_dict)>0
|
||||||
|
|
||||||
self._data_augmentation = True
|
self._data_augmentation = True
|
||||||
|
|
||||||
#self._TF_matrix={}
|
self._TF_dict = TF_dict
|
||||||
#self._input_info={'h':0, 'w':0, 'device':None} #Input associe a TF_matrix
|
self._TF= list(self._TF_dict.keys())
|
||||||
self._mag_fct = TF_dict
|
|
||||||
self._TF=list(self._mag_fct.keys())
|
|
||||||
self._nb_tf= len(self._TF)
|
self._nb_tf= len(self._TF)
|
||||||
|
|
||||||
self._N_TF = N_TF
|
self._N_seqTF = N_TF
|
||||||
|
|
||||||
self._fixed_mag=5 #[0, PARAMETER_MAX]
|
#self._fixed_mag=5 #[0, PARAMETER_MAX]
|
||||||
self._params = nn.ParameterDict({
|
self._params = nn.ParameterDict({
|
||||||
"prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Distribution prob uniforme
|
"prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Distribution prob uniforme
|
||||||
|
"mag" : nn.Parameter(torch.tensor(0.5).expand(self._nb_tf) if glob_mag else torch.tensor(0.5).repeat(self._nb_tf)) #[0, PARAMETER_MAX]/10
|
||||||
})
|
})
|
||||||
|
|
||||||
self._samples = []
|
self._samples = []
|
||||||
|
@ -565,7 +564,7 @@ class Data_augV5(nn.Module): #Transformations avec mask
|
||||||
x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles)
|
x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles)
|
||||||
self._samples = []
|
self._samples = []
|
||||||
|
|
||||||
for _ in range(self._N_TF):
|
for _ in range(self._N_seqTF):
|
||||||
## Echantillonage ##
|
## Echantillonage ##
|
||||||
uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1)
|
uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1)
|
||||||
|
|
||||||
|
@ -581,7 +580,7 @@ class Data_augV5(nn.Module): #Transformations avec mask
|
||||||
## Transformations ##
|
## Transformations ##
|
||||||
x = self.apply_TF(x, sample)
|
x = self.apply_TF(x, sample)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def apply_TF(self, x, sampled_TF):
|
def apply_TF(self, x, sampled_TF):
|
||||||
device = x.device
|
device = x.device
|
||||||
smps_x=[]
|
smps_x=[]
|
||||||
|
@ -591,44 +590,12 @@ class Data_augV5(nn.Module): #Transformations avec mask
|
||||||
smp_x = x[mask] #torch.masked_select() ?
|
smp_x = x[mask] #torch.masked_select() ?
|
||||||
|
|
||||||
if smp_x.shape[0]!=0: #if there's data to TF
|
if smp_x.shape[0]!=0: #if there's data to TF
|
||||||
magnitude=self._fixed_mag
|
magnitude=self._params["mag"][tf_idx]*10
|
||||||
tf=self._TF[tf_idx]
|
tf=self._TF[tf_idx]
|
||||||
|
#print(magnitude)
|
||||||
|
|
||||||
## Geometric TF ##
|
x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude) # Refusionner eviter x[mask] : in place
|
||||||
if tf=='Identity':
|
|
||||||
pass
|
|
||||||
elif tf=='FlipLR':
|
|
||||||
smp_x = TF.flipLR(smp_x)
|
|
||||||
elif tf=='FlipUD':
|
|
||||||
smp_x = TF.flipUD(smp_x)
|
|
||||||
elif tf=='Rotate':
|
|
||||||
smp_x = TF.rotate(smp_x, angle=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
|
|
||||||
elif tf=='TranslateX' or tf=='TranslateY':
|
|
||||||
smp_x = TF.translate(smp_x, translation=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
|
|
||||||
elif tf=='ShearX' or tf=='ShearY' :
|
|
||||||
smp_x = TF.shear(smp_x, shear=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
|
|
||||||
|
|
||||||
## Color TF (Expect image in the range of [0, 1]) ##
|
|
||||||
elif tf=='Contrast':
|
|
||||||
smp_x = TF.contrast(smp_x, contrast_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
|
|
||||||
elif tf=='Color':
|
|
||||||
smp_x = TF.color(smp_x, color_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
|
|
||||||
elif tf=='Brightness':
|
|
||||||
smp_x = TF.brightness(smp_x, brightness_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
|
|
||||||
elif tf=='Sharpness':
|
|
||||||
smp_x = TF.sharpeness(smp_x, sharpness_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
|
|
||||||
elif tf=='Posterize':
|
|
||||||
smp_x = TF.posterize(smp_x, bits=torch.tensor([1 for _ in smp_x], device=device))
|
|
||||||
elif tf=='Solarize':
|
|
||||||
smp_x = TF.solarize(smp_x, thresholds=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
|
|
||||||
elif tf=='Equalize':
|
|
||||||
smp_x = TF.equalize(smp_x)
|
|
||||||
elif tf=='Auto_Contrast':
|
|
||||||
smp_x = TF.auto_contrast(smp_x)
|
|
||||||
else:
|
|
||||||
raise Exception("Invalid TF requested : ", tf)
|
|
||||||
|
|
||||||
x[mask]=smp_x # Refusionner eviter x[mask] : in place
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def adjust_prob(self, soft=False): #Detach from gradient ?
|
def adjust_prob(self, soft=False): #Detach from gradient ?
|
||||||
|
@ -636,9 +603,14 @@ class Data_augV5(nn.Module): #Transformations avec mask
|
||||||
if soft :
|
if soft :
|
||||||
self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Trop 'soft', bloque en dist uniforme si lr trop faible
|
self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Trop 'soft', bloque en dist uniforme si lr trop faible
|
||||||
else:
|
else:
|
||||||
|
#self._params['prob'].clamp(min=0.0,max=1.0)
|
||||||
self._params['prob'].data = F.relu(self._params['prob'].data)
|
self._params['prob'].data = F.relu(self._params['prob'].data)
|
||||||
|
#self._params['prob'].data = self._params['prob'].clamp(min=0.0,max=1.0)
|
||||||
|
|
||||||
self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1
|
self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def loss_weight(self):
|
def loss_weight(self):
|
||||||
# 1 seule TF
|
# 1 seule TF
|
||||||
#self._sample = self._samples[-1]
|
#self._sample = self._samples[-1]
|
||||||
|
@ -647,11 +619,11 @@ class Data_augV5(nn.Module): #Transformations avec mask
|
||||||
#w_loss = w_loss * self._params["prob"]/self._distrib #Ponderation par les proba (divisee par la distrib pour pas diminuer la loss)
|
#w_loss = w_loss * self._params["prob"]/self._distrib #Ponderation par les proba (divisee par la distrib pour pas diminuer la loss)
|
||||||
#w_loss = torch.sum(w_loss,dim=1)
|
#w_loss = torch.sum(w_loss,dim=1)
|
||||||
|
|
||||||
#Plusieurs TF sequentielles (Hypothese ordre negligeable)
|
#Plusieurs TF sequentielles
|
||||||
w_loss = torch.zeros((self._samples[0].shape[0],self._nb_tf), device=self._samples[0].device)
|
w_loss = torch.zeros((self._samples[0].shape[0],self._nb_tf), device=self._samples[0].device)
|
||||||
for sample in self._samples:
|
for sample in self._samples:
|
||||||
tmp_w = torch.zeros(w_loss.size(),device=w_loss.device)
|
tmp_w = torch.zeros(w_loss.size(),device=w_loss.device)
|
||||||
tmp_w.scatter_(dim=1, index=sample.view(-1,1), value=1/self._N_TF)
|
tmp_w.scatter_(dim=1, index=sample.view(-1,1), value=1/self._N_seqTF)
|
||||||
w_loss += tmp_w
|
w_loss += tmp_w
|
||||||
|
|
||||||
w_loss = w_loss * self._params["prob"]/self._distrib #Ponderation par les proba (divisee par la distrib pour pas diminuer la loss)
|
w_loss = w_loss * self._params["prob"]/self._distrib #Ponderation par les proba (divisee par la distrib pour pas diminuer la loss)
|
||||||
|
@ -676,9 +648,9 @@ class Data_augV5(nn.Module): #Transformations avec mask
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
if not self._mix_dist:
|
if not self._mix_dist:
|
||||||
return "Data_augV5(Uniform-%d TF x %d)" % (self._nb_tf, self._N_TF)
|
return "Data_augV5(Uniform-%d TF x %d)" % (self._nb_tf, self._N_seqTF)
|
||||||
else:
|
else:
|
||||||
return "Data_augV5(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_TF)
|
return "Data_augV5(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_seqTF)
|
||||||
|
|
||||||
|
|
||||||
class Augmented_model(nn.Module):
|
class Augmented_model(nn.Module):
|
||||||
|
|
|
@ -68,7 +68,7 @@ if __name__ == "__main__":
|
||||||
t0 = time.process_time()
|
t0 = time.process_time()
|
||||||
tf_dict = {k: TF.TF_dict[k] for k in tf_names}
|
tf_dict = {k: TF.TF_dict[k] for k in tf_names}
|
||||||
#tf_dict = TF.TF_dict
|
#tf_dict = TF.TF_dict
|
||||||
aug_model = Augmented_model(Data_augV4(TF_dict=tf_dict, N_TF=2, mix_dist=0.5), LeNet(3,10)).to(device)
|
aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=2, mix_dist=0.5), LeNet(3,10)).to(device)
|
||||||
#aug_model = Augmented_model(Data_augV4(TF_dict=tf_dict, N_TF=2, mix_dist=0.0), WideResNet(num_classes=10, wrn_size=160)).to(device)
|
#aug_model = Augmented_model(Data_augV4(TF_dict=tf_dict, N_TF=2, mix_dist=0.0), WideResNet(num_classes=10, wrn_size=160)).to(device)
|
||||||
print(str(aug_model), 'on', device_name)
|
print(str(aug_model), 'on', device_name)
|
||||||
#run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
|
#run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
|
||||||
|
@ -88,7 +88,7 @@ if __name__ == "__main__":
|
||||||
print('-'*9)
|
print('-'*9)
|
||||||
#'''
|
#'''
|
||||||
#### TF number tests ####
|
#### TF number tests ####
|
||||||
#'''
|
'''
|
||||||
res_folder="res/TF_nb_tests/"
|
res_folder="res/TF_nb_tests/"
|
||||||
epochs= 100
|
epochs= 100
|
||||||
inner_its = [0, 1, 10]
|
inner_its = [0, 1, 10]
|
||||||
|
@ -128,4 +128,4 @@ if __name__ == "__main__":
|
||||||
print('Log :\"',f.name, '\" saved !')
|
print('Log :\"',f.name, '\" saved !')
|
||||||
print('-'*9)
|
print('-'*9)
|
||||||
|
|
||||||
#'''
|
'''
|
|
@ -649,6 +649,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
|
||||||
print('Data Augmention : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start))
|
print('Data Augmention : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start))
|
||||||
print('TF Proba :', model['data_aug']['prob'].data)
|
print('TF Proba :', model['data_aug']['prob'].data)
|
||||||
#print('proba grad',aug_model['data_aug']['prob'].grad)
|
#print('proba grad',aug_model['data_aug']['prob'].grad)
|
||||||
|
print('TF Mag :', model['data_aug']['mag'].data)
|
||||||
#############
|
#############
|
||||||
#### Log ####
|
#### Log ####
|
||||||
data={
|
data={
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue