Mirror of https://github.com/AntoineHX/smart_augmentation.git, synced 2025-05-04 12:10:45 +02:00
Initial implementation of Dataugv5
This commit is contained in:
parent 103277fadd
commit 05f81787d6
3 changed files with 31 additions and 58 deletions
@@ -327,7 +327,7 @@ class Data_augV4(nn.Module): #Transformations with mask
         self._TF= list(self._TF_dict.keys())
         self._nb_tf= len(self._TF)
 
-        self._N_TF = N_TF
+        self._N_seqTF = N_TF
 
         self._fixed_mag=5 #[0, PARAMETER_MAX]
         self._params = nn.ParameterDict({
@@ -349,7 +349,7 @@ class Data_augV4(nn.Module): #Transformations with mask
         x = copy.deepcopy(x) #Avoid modifying the samples by reference (problematic for parallel use)
         self._samples = []
 
-        for _ in range(self._N_TF):
+        for _ in range(self._N_seqTF):
             ## Sampling ##
             uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1)
 
@@ -501,7 +501,7 @@ class Data_augV4(nn.Module): #Transformations with mask
         w_loss = torch.zeros((self._samples[0].shape[0],self._nb_tf), device=self._samples[0].device)
         for sample in self._samples:
             tmp_w = torch.zeros(w_loss.size(),device=w_loss.device)
-            tmp_w.scatter_(dim=1, index=sample.view(-1,1), value=1/self._N_TF)
+            tmp_w.scatter_(dim=1, index=sample.view(-1,1), value=1/self._N_seqTF)
             w_loss += tmp_w
 
         w_loss = w_loss * self._params["prob"]/self._distrib #Weight by the probabilities (divided by the distribution so the loss is not reduced)
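How the scatter_ weighting above works: each entry of self._samples holds, for every image in the batch, the index of the TF drawn at one sequence step, and scatter_ writes 1/N_seqTF into that column. After summing over the sequence steps, each row of w_loss is the average one-hot encoding of the TFs applied to that image. A minimal self-contained sketch with hypothetical sample values, separate from the repository code:

import torch

N_seqTF, nb_tf, batch = 2, 5, 3
samples = [torch.tensor([0, 2, 4]),   # TF index per image, sequence step 1
           torch.tensor([2, 2, 1])]   # TF index per image, sequence step 2

w_loss = torch.zeros(batch, nb_tf)
for sample in samples:
    tmp_w = torch.zeros_like(w_loss)
    # Write 1/N_seqTF at the column of the sampled TF, per row
    tmp_w.scatter_(dim=1, index=sample.view(-1, 1), value=1 / N_seqTF)
    w_loss += tmp_w

print(w_loss)
# tensor([[0.5000, 0.0000, 0.5000, 0.0000, 0.0000],
#         [0.0000, 0.0000, 1.0000, 0.0000, 0.0000],
#         [0.0000, 0.5000, 0.0000, 0.0000, 0.5000]])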
@@ -526,28 +526,27 @@ class Data_augV4(nn.Module): #Transformations with mask
 
     def __str__(self):
         if not self._mix_dist:
-            return "Data_augV4(Uniform-%d TF x %d)" % (self._nb_tf, self._N_TF)
+            return "Data_augV4(Uniform-%d TF x %d)" % (self._nb_tf, self._N_seqTF)
         else:
-            return "Data_augV4(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_TF)
+            return "Data_augV4(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_seqTF)
 
-class Data_augV5(nn.Module): #Transformations with mask
-    def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0):
+class Data_augV5(nn.Module): #Joint optimization (mag, prob)
+    def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, glob_mag=True):
         super(Data_augV5, self).__init__()
         assert len(TF_dict)>0
 
         self._data_augmentation = True
 
-        #self._TF_matrix={}
-        #self._input_info={'h':0, 'w':0, 'device':None} #Input associated with TF_matrix
-        self._mag_fct = TF_dict
-        self._TF=list(self._mag_fct.keys())
+        self._TF_dict = TF_dict
+        self._TF= list(self._TF_dict.keys())
         self._nb_tf= len(self._TF)
 
-        self._N_TF = N_TF
+        self._N_seqTF = N_TF
 
-        self._fixed_mag=5 #[0, PARAMETER_MAX]
+        #self._fixed_mag=5 #[0, PARAMETER_MAX]
         self._params = nn.ParameterDict({
             "prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Uniform prob distribution
+            "mag" : nn.Parameter(torch.tensor(0.5).expand(self._nb_tf) if glob_mag else torch.tensor(0.5).repeat(self._nb_tf)) #[0, PARAMETER_MAX]/10
        })
 
         self._samples = []
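The glob_mag switch in the new "mag" parameter leans on a PyTorch detail: torch.tensor(0.5).expand(n) returns a stride-0 view in which all n entries share one underlying scalar, while torch.tensor(0.5).repeat(n) materializes n independent values. A standalone sketch of the difference, separate from the repository code:

import torch

n = 4
glob  = torch.tensor(0.5).expand(n)   # view, stride (0,): one shared scalar behind all entries
indep = torch.tensor(0.5).repeat(n)   # contiguous copy, stride (1,): n independent scalars

print(glob.stride(), indep.stride())  # (0,) (1,)
indep[0] = 0.9                        # changes only one entry of the copy
print(indep)                          # tensor([0.9000, 0.5000, 0.5000, 0.5000])

Wrapped in nn.Parameter as in the diff, the expanded form should give every TF the same, jointly updated global magnitude, whereas the repeated form gives each TF its own learnable magnitude.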
@@ -565,7 +564,7 @@ class Data_augV5(nn.Module): #Transformations with mask
         x = copy.deepcopy(x) #Avoid modifying the samples by reference (problematic for parallel use)
         self._samples = []
 
-        for _ in range(self._N_TF):
+        for _ in range(self._N_seqTF):
             ## Sampling ##
             uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1)
 
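In forward, uniforme_dist is the uniform baseline over the self._nb_tf transformations (a softmax over a constant vector is uniform). The blending of this baseline with the learned self._params["prob"] that mix_dist and self._mix_factor imply is not visible in this hunk; a hedged sketch of what such a draw could look like, with placeholder values and an assumed mixing rule:

import torch

nb_tf, batch, mix_factor = 5, 4, 0.5
prob = torch.softmax(torch.randn(nb_tf), dim=0)       # stand-in for the learned distribution
uniforme_dist = torch.ones(1, nb_tf).softmax(dim=1)   # uniform baseline, as in the diff

# Assumed mixing rule (not shown in this hunk): convex blend of learned and uniform.
mixed = mix_factor * prob + (1 - mix_factor) * uniforme_dist.squeeze(0)
sample = torch.distributions.Categorical(probs=mixed).sample((batch,))  # one TF index per image
print(sample)   # e.g. tensor([3, 0, 2, 2])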
@@ -581,7 +580,7 @@ class Data_augV5(nn.Module): #Transformations with mask
             ## Transformations ##
             x = self.apply_TF(x, sample)
         return x
 
     def apply_TF(self, x, sampled_TF):
         device = x.device
         smps_x=[]
@@ -591,44 +590,12 @@ class Data_augV5(nn.Module): #Transformations with mask
             smp_x = x[mask] #torch.masked_select() ?
 
             if smp_x.shape[0]!=0: #if there's data to TF
-                magnitude=self._fixed_mag
+                magnitude=self._params["mag"][tf_idx]*10
                 tf=self._TF[tf_idx]
+                #print(magnitude)
 
-                ## Geometric TF ##
-                if tf=='Identity':
-                    pass
-                elif tf=='FlipLR':
-                    smp_x = TF.flipLR(smp_x)
-                elif tf=='FlipUD':
-                    smp_x = TF.flipUD(smp_x)
-                elif tf=='Rotate':
-                    smp_x = TF.rotate(smp_x, angle=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
-                elif tf=='TranslateX' or tf=='TranslateY':
-                    smp_x = TF.translate(smp_x, translation=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
-                elif tf=='ShearX' or tf=='ShearY' :
-                    smp_x = TF.shear(smp_x, shear=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
-
-                ## Color TF (Expect image in the range of [0, 1]) ##
-                elif tf=='Contrast':
-                    smp_x = TF.contrast(smp_x, contrast_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
-                elif tf=='Color':
-                    smp_x = TF.color(smp_x, color_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
-                elif tf=='Brightness':
-                    smp_x = TF.brightness(smp_x, brightness_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
-                elif tf=='Sharpness':
-                    smp_x = TF.sharpeness(smp_x, sharpness_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
-                elif tf=='Posterize':
-                    smp_x = TF.posterize(smp_x, bits=torch.tensor([1 for _ in smp_x], device=device))
-                elif tf=='Solarize':
-                    smp_x = TF.solarize(smp_x, thresholds=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
-                elif tf=='Equalize':
-                    smp_x = TF.equalize(smp_x)
-                elif tf=='Auto_Contrast':
-                    smp_x = TF.auto_contrast(smp_x)
-                else:
-                    raise Exception("Invalid TF requested : ", tf)
-
-                x[mask]=smp_x # Re-merge; avoids in-place modification of x[mask]
+                x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude) # Re-merge; avoids in-place modification of x[mask]
 
         return x
 
     def adjust_prob(self, soft=False): #Detach from gradient ?
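The change above replaces the per-transformation if/elif chain with a lookup in self._TF_dict, whose values are expected to be callables taking the batch x and a magnitude mag. A minimal sketch of this dispatch pattern, with hypothetical transforms rather than the repository's TF module:

import torch

# Hypothetical transform table following the TF_dict pattern: each value is a
# callable applied to a whole batch with a magnitude argument.
TF_dict = {
    'Identity': (lambda x, mag: x),
    'Brighten': (lambda x, mag: torch.clamp(x + 0.05 * mag, 0.0, 1.0)),  # illustrative only
    'Darken'  : (lambda x, mag: torch.clamp(x - 0.05 * mag, 0.0, 1.0)),  # illustrative only
}

def apply_tf(x, name, mag):
    if name not in TF_dict:
        raise Exception("Invalid TF requested : ", name)   # same error style as the old chain
    return TF_dict[name](x=x, mag=mag)

x = torch.rand(2, 3, 8, 8)        # batch of images in [0, 1]
y = apply_tf(x, 'Brighten', mag=5.0)

Besides shrinking apply_TF, this keeps the set of transformations and their magnitude handling in one place (the dict passed to __init__), so adding a TF no longer touches the dispatch code.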
@@ -636,9 +603,14 @@ class Data_augV5(nn.Module): #Transformations with mask
         if soft :
             self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Too 'soft': gets stuck at a uniform dist if the lr is too low
         else:
-            #self._params['prob'].clamp(min=0.0,max=1.0)
             self._params['prob'].data = F.relu(self._params['prob'].data)
+            #self._params['prob'].data = self._params['prob'].clamp(min=0.0,max=1.0)
 
             self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Constraint: sum(p)=1
 
+
+
     def loss_weight(self):
+        # Single TF
+        #self._sample = self._samples[-1]
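adjust_prob keeps self._params["prob"] a valid distribution after unconstrained gradient steps: the hard branch clips negatives with relu and renormalizes to sum to 1, while the soft branch applies a softmax. A small numeric sketch of the hard branch, outside the repository code:

import torch
import torch.nn.functional as F

p = torch.tensor([0.6, -0.1, 0.3, 0.2])   # a gradient step can leave the simplex
p = F.relu(p)                              # clip negative probabilities to 0
p = p / p.sum()                            # renormalize so that sum(p) = 1
print(p)                                   # tensor([0.5455, 0.0000, 0.2727, 0.1818])

One edge case of the hard projection: if every entry were clipped to zero, the renormalization would divide by zero, so in practice it relies on at least one probability staying positive.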
@@ -647,11 +619,11 @@ class Data_augV5(nn.Module): #Transformations with mask
         #w_loss = w_loss * self._params["prob"]/self._distrib #Weight by the probabilities (divided by the distribution so the loss is not reduced)
         #w_loss = torch.sum(w_loss,dim=1)
 
-        #Several sequential TF (assumption: order negligible)
+        #Several sequential TF
         w_loss = torch.zeros((self._samples[0].shape[0],self._nb_tf), device=self._samples[0].device)
         for sample in self._samples:
             tmp_w = torch.zeros(w_loss.size(),device=w_loss.device)
-            tmp_w.scatter_(dim=1, index=sample.view(-1,1), value=1/self._N_TF)
+            tmp_w.scatter_(dim=1, index=sample.view(-1,1), value=1/self._N_seqTF)
             w_loss += tmp_w
 
         w_loss = w_loss * self._params["prob"]/self._distrib #Weight by the probabilities (divided by the distribution so the loss is not reduced)
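The final line weights each image's loss contribution by the learned probability of the TFs it actually received, divided by the distribution they were sampled from: dividing by self._distrib is an importance-weighting correction so the reweighting does not shrink the loss on average. A numeric sketch with assumed shapes (w_loss of shape (batch, nb_tf); prob and distrib of shape (nb_tf,)); the final reduction over TFs is not shown in this hunk:

import torch

w_loss  = torch.tensor([[0.5, 0.0, 0.5]])  # one image, two TF applied out of three
prob    = torch.tensor([0.2, 0.3, 0.5])    # placeholder learned probabilities
distrib = torch.tensor([1/3, 1/3, 1/3])    # distribution the TFs were drawn from (uniform here)

# Weight each applied TF by its learned probability, corrected by the sampling
# distribution (importance weighting), then sum over TFs per image.
w = (w_loss * prob / distrib).sum(dim=1)
print(w)                                   # tensor([1.0500])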
@@ -676,9 +648,9 @@ class Data_augV5(nn.Module): #Transformations with mask
 
     def __str__(self):
         if not self._mix_dist:
-            return "Data_augV5(Uniform-%d TF x %d)" % (self._nb_tf, self._N_TF)
+            return "Data_augV5(Uniform-%d TF x %d)" % (self._nb_tf, self._N_seqTF)
         else:
-            return "Data_augV5(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_TF)
+            return "Data_augV5(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_seqTF)
 
 
 class Augmented_model(nn.Module):