Mirror of https://github.com/AntoineHX/smart_augmentation.git (synced 2025-05-04 12:10:45 +02:00)
Refactoring of TF_dict
This commit is contained in:
  parent fd4dcdb392
  commit 103277fadd

8 changed files with 245 additions and 23 deletions

higher/dataug.py (160 lines changed)
@@ -322,8 +322,9 @@ class Data_augV4(nn.Module): #Transformations with mask
         #self._TF_matrix={}
         #self._input_info={'h':0, 'w':0, 'device':None} #Input associated with TF_matrix
-        self._mag_fct = TF_dict
-        self._TF=list(self._mag_fct.keys())
+        #self._mag_fct = TF_dict
+        self._TF_dict = TF_dict
+        self._TF= list(self._TF_dict.keys())
         self._nb_tf= len(self._TF)

         self._N_TF = N_TF

@@ -356,7 +357,6 @@ class Data_augV4(nn.Module): #Transformations with mask
                     self._distrib = uniforme_dist
                 else:
                     self._distrib = (self._mix_factor*self._params["prob"]+(1-self._mix_factor)*uniforme_dist).softmax(dim=1) #Mix real/uniform distributions with mix_factor
-                    print(self.distrib.shape)

                 cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*self._distrib)
                 sample = cat_distrib.sample()
@@ -414,6 +414,7 @@ class Data_augV4(nn.Module): #Transformations with mask
                 magnitude=self._fixed_mag
                 tf=self._TF[tf_idx]

+                '''
                 ## Geometric TF ##
                 if tf=='Identity':
                     pass
@@ -449,6 +450,8 @@ class Data_augV4(nn.Module): #Transformations with mask
                     raise Exception("Invalid TF requested : ", tf)

                 x[mask]=smp_x # Merge back; avoids x[mask] : in place
+                '''
+                x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude) # Merge back; avoids x[mask] : in place

                 #idx= mask.nonzero()
                 #print('-'*8)
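The added lines above are the point of the refactor: the whole if/elif dispatch (now fenced off by the added triple quotes) collapses into one lookup, self._TF_dict[tf](x=smp_x, mag=magnitude). Below is a minimal sketch of the calling convention this implies; the dict entries are hypothetical stand-ins, not the repo's actual transformations:

```python
import torch

# Hypothetical TF_dict entries (illustrative only, not the repo's transformations):
# each value is a callable taking a batch tensor x and a magnitude, returning the
# transformed batch, so dispatch becomes a plain dict lookup.
TF_dict = {
    'Identity': (lambda x, mag: x),
    'FlipLR':   (lambda x, mag: torch.flip(x, dims=[-1])),  # flip along width
    'FlipUD':   (lambda x, mag: torch.flip(x, dims=[-2])),  # flip along height
}

x = torch.rand(8, 3, 32, 32)      # batch of 8 RGB 32x32 images in [0, 1]
tf = 'FlipLR'
out = TF_dict[tf](x=x, mag=5)     # same call shape as self._TF_dict[tf](x=smp_x, mag=magnitude)
assert out.shape == x.shape
```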
@@ -527,6 +530,157 @@ class Data_augV4(nn.Module): #Transformations with mask
         else:
             return "Data_augV4(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_TF)

+class Data_augV5(nn.Module): #Transformations with mask
+    def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0):
+        super(Data_augV5, self).__init__()
+        assert len(TF_dict)>0
+
+        self._data_augmentation = True
+
+        #self._TF_matrix={}
+        #self._input_info={'h':0, 'w':0, 'device':None} #Input associated with TF_matrix
+        self._mag_fct = TF_dict
+        self._TF=list(self._mag_fct.keys())
+        self._nb_tf= len(self._TF)
+
+        self._N_TF = N_TF
+
+        self._fixed_mag=5 #[0, PARAMETER_MAX]
+        self._params = nn.ParameterDict({
+            "prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Uniform probability distribution
+        })
+
+        self._samples = []
+
+        self._mix_dist = False
+        if mix_dist != 0.0:
+            self._mix_dist = True
+            self._mix_factor = max(min(mix_dist, 1.0), 0.0)
+
+    def forward(self, x):
+        if self._data_augmentation:
+            device = x.device
+            batch_size, h, w = x.shape[0], x.shape[2], x.shape[3]
+
+            x = copy.deepcopy(x) #Avoid modifying the samples by reference (problematic for parallel use)
+            self._samples = []
+
+            for _ in range(self._N_TF):
+                ## Sampling ##
+                uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1)
+
+                if not self._mix_dist:
+                    self._distrib = uniforme_dist
+                else:
+                    self._distrib = (self._mix_factor*self._params["prob"]+(1-self._mix_factor)*uniforme_dist).softmax(dim=1) #Mix real/uniform distributions with mix_factor
+
+                cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*self._distrib)
+                sample = cat_distrib.sample()
+                self._samples.append(sample)
+
+                ## Transformations ##
+                x = self.apply_TF(x, sample)
+        return x
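The sampling step above draws, for each of the N_TF rounds, one TF index per image from a Categorical distribution whose probabilities blend the learned self._params["prob"] with a uniform distribution. A self-contained sketch of that step, with made-up sizes and a stand-in prob parameter:

```python
import torch
from torch.distributions import Categorical

nb_tf, batch_size, mix_factor = 3, 5, 0.5      # made-up sizes for illustration
prob = torch.ones(nb_tf) / nb_tf               # stand-in for self._params["prob"]

uniform = torch.ones(1, nb_tf).softmax(dim=1)  # softmax of ones = uniform over the TFs
# Blend learned and uniform probabilities, then renormalize via softmax,
# mirroring the mix_dist branch of Data_augV5.forward.
distrib = (mix_factor * prob + (1 - mix_factor) * uniform).softmax(dim=1)

# Broadcasting distrib against a (batch_size, nb_tf) tensor of ones gives one
# identical categorical distribution per image; sample() returns one TF index each.
cat = Categorical(probs=torch.ones(batch_size, nb_tf) * distrib)
sample = cat.sample()                          # shape (batch_size,), values in [0, nb_tf)
print(sample)
```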
+
+    def apply_TF(self, x, sampled_TF):
+        device = x.device
+        smps_x=[]
+        masks=[]
+        for tf_idx in range(self._nb_tf):
+            mask = sampled_TF==tf_idx #Create selection mask
+            smp_x = x[mask] #torch.masked_select() ?
+
+            if smp_x.shape[0]!=0: #if there's data to TF
+                magnitude=self._fixed_mag
+                tf=self._TF[tf_idx]
+
+                ## Geometric TF ##
+                if tf=='Identity':
+                    pass
+                elif tf=='FlipLR':
+                    smp_x = TF.flipLR(smp_x)
+                elif tf=='FlipUD':
+                    smp_x = TF.flipUD(smp_x)
+                elif tf=='Rotate':
+                    smp_x = TF.rotate(smp_x, angle=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
+                elif tf=='TranslateX' or tf=='TranslateY':
+                    smp_x = TF.translate(smp_x, translation=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
+                elif tf=='ShearX' or tf=='ShearY' :
+                    smp_x = TF.shear(smp_x, shear=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
+
+                ## Color TF (Expect image in the range of [0, 1]) ##
+                elif tf=='Contrast':
+                    smp_x = TF.contrast(smp_x, contrast_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
+                elif tf=='Color':
+                    smp_x = TF.color(smp_x, color_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
+                elif tf=='Brightness':
+                    smp_x = TF.brightness(smp_x, brightness_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
+                elif tf=='Sharpness':
+                    smp_x = TF.sharpeness(smp_x, sharpness_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
+                elif tf=='Posterize':
+                    smp_x = TF.posterize(smp_x, bits=torch.tensor([1 for _ in smp_x], device=device))
+                elif tf=='Solarize':
+                    smp_x = TF.solarize(smp_x, thresholds=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
+                elif tf=='Equalize':
+                    smp_x = TF.equalize(smp_x)
+                elif tf=='Auto_Contrast':
+                    smp_x = TF.auto_contrast(smp_x)
+                else:
+                    raise Exception("Invalid TF requested : ", tf)
+
+                x[mask]=smp_x # Merge back; avoids x[mask] : in place
+        return x
+
+    def adjust_prob(self, soft=False): #Detach from gradient ?
+
+        if soft :
+            self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Too 'soft': gets stuck at the uniform dist if the lr is too low
+        else:
+            self._params['prob'].data = F.relu(self._params['prob'].data)
+            self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Constraint sum(p)=1
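adjust_prob re-projects the learned probabilities onto the simplex after an optimizer step: either a softmax (soft) or clamping negatives to zero and renormalizing (hard). A tiny illustration of the hard branch with a stand-in tensor:

```python
import torch
import torch.nn.functional as F

p = torch.tensor([0.5, -0.2, 0.7])   # stand-in for self._params['prob'].data after a step
p = F.relu(p)                        # clamp negative probabilities to 0
p = p / p.sum()                      # renormalize so sum(p) = 1
print(p)                             # tensor([0.4167, 0.0000, 0.5833])
```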
+
+    def loss_weight(self):
+        # Single TF only
+        #self._sample = self._samples[-1]
+        #w_loss = torch.zeros((self._sample.shape[0],self._nb_tf), device=self._sample.device)
+        #w_loss.scatter_(dim=1, index=self._sample.view(-1,1), value=1)
+        #w_loss = w_loss * self._params["prob"]/self._distrib #Weight by the probabilities (divided by the distrib so the loss isn't lowered)
+        #w_loss = torch.sum(w_loss,dim=1)
+
+        #Several sequential TFs (assumes order is negligible)
+        w_loss = torch.zeros((self._samples[0].shape[0],self._nb_tf), device=self._samples[0].device)
+        for sample in self._samples:
+            tmp_w = torch.zeros(w_loss.size(),device=w_loss.device)
+            tmp_w.scatter_(dim=1, index=sample.view(-1,1), value=1/self._N_TF)
+            w_loss += tmp_w
+
+        w_loss = w_loss * self._params["prob"]/self._distrib #Weight by the probabilities (divided by the distrib so the loss isn't lowered)
+        w_loss = torch.sum(w_loss,dim=1)
+        return w_loss
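loss_weight builds one importance weight per image: each sampling round scatters 1/N_TF into the column of the sampled TF, and the accumulated one-hot rows are rescaled by prob/distrib so the learned probabilities receive gradient through the weighted loss. A worked example with made-up values (N_TF=2, nb_tf=3, batch of 4):

```python
import torch

N_TF, nb_tf, batch_size = 2, 3, 4                  # made-up sizes for illustration
samples = [torch.tensor([0, 2, 1, 1]),             # TF index per image, round 1
           torch.tensor([2, 2, 0, 1])]             # TF index per image, round 2

w_loss = torch.zeros(batch_size, nb_tf)
for sample in samples:
    tmp_w = torch.zeros_like(w_loss)
    # Put 1/N_TF in each image's sampled-TF column (scaled one-hot row)
    tmp_w.scatter_(dim=1, index=sample.view(-1, 1), value=1 / N_TF)
    w_loss += tmp_w

prob = torch.tensor([0.2, 0.3, 0.5])               # stand-in learned probabilities
distrib = torch.ones(1, nb_tf) / nb_tf             # stand-in sampling distribution
w = (w_loss * prob / distrib).sum(dim=1)           # one weight per image
print(w)                                           # last image drew TF 1 twice -> weight 0.9
```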
+
+    def train(self, mode=None):
+        if mode is None :
+            mode=self._data_augmentation
+        self.augment(mode=mode) #Useless if mode=None
+        super(Data_augV5, self).train(mode)
+
+    def eval(self):
+        self.train(mode=False)
+
+    def augment(self, mode=True):
+        self._data_augmentation=mode
+
+    def __getitem__(self, key):
+        return self._params[key]
+
+    def __str__(self):
+        if not self._mix_dist:
+            return "Data_augV5(Uniform-%d TF x %d)" % (self._nb_tf, self._N_TF)
+        else:
+            return "Data_augV5(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_TF)
+
+
 class Augmented_model(nn.Module):
     def __init__(self, data_augmenter, model):
         super(Augmented_model, self).__init__()