Test brutus suite

Mirror of https://github.com/AntoineHX/smart_augmentation.git
Commit 61dad1ad78 (parent a772d13b83)
4 changed files with 90 additions and 32 deletions

@@ -531,7 +531,7 @@ class Data_augV4(nn.Module): #Transformations with mask
         return "Data_augV4(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_seqTF)


 class Data_augV5(nn.Module): #Joint optimization (mag, prob)
-    def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_mag=True, shared_mag=True):
+    def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True):
         super(Data_augV5, self).__init__()
         assert len(TF_dict)>0
@@ -555,6 +555,7 @@ class Data_augV5(nn.Module): #Joint optimization (mag, prob)
         #for t in TF.TF_no_mag: self._params['mag'][self._TF.index(t)].data-=self._params['mag'][self._TF.index(t)].data #Mag useless for the ignore_mag TF

         #Distribution
+        self._fixed_prob=fixed_prob
         self._samples = []
         self._mix_dist = False
         if mix_dist != 0.0:
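
Note: the new fixed_prob flag stored here (and read in forward, adjust_param and loss_weight below) makes it possible to freeze the learned sampling probabilities while the magnitudes keep training. A minimal usage sketch, assuming the file's usual `import transformations as TF` alias and whatever transformations TF_dict happens to contain:

    # Hypothetical usage -- parameter values chosen for illustration only
    aug = Data_augV5(TF_dict=TF.TF_dict, N_TF=2, mix_dist=0.0,
                     fixed_prob=True,   # probabilities stay at their init values
                     fixed_mag=False,   # magnitudes remain trainable
                     shared_mag=False)
    print(aug)  # e.g. "Data_augV5(UniformFx-<nb_tf>TFx2-Mag)rand0.5"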
@@ -570,12 +571,12 @@ class Data_augV5(nn.Module): #Joint optimization (mag, prob)
             self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage max amplitude

     def forward(self, x):
-        if self._data_augmentation:
+        self._samples = []
+        if self._data_augmentation and TF.random.random() < 0.5:
             device = x.device
             batch_size, h, w = x.shape[0], x.shape[2], x.shape[3]

             x = copy.deepcopy(x) #Avoid modifying the samples by reference (problematic for parallel uses)
-            self._samples = []

             for _ in range(self._N_seqTF):
                 ## Sampling ##
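
forward now resets self._samples on every call, so stale samples cannot leak into loss_weight when augmentation is skipped, and it applies the TF sequence to the whole batch only half the time: TF.random.random() < 0.5 flips one coin per batch, not one per image. A self-contained sketch of this batch-level gating (names are illustrative, not from the repo):

    import random
    import torch

    def gated_augment(x, augment_fn, p=0.5):
        # One coin flip for the whole batch: either every image is
        # augmented, or none is.
        if random.random() < p:
            return augment_fn(x)
        return x

    batch = torch.randn(8, 3, 32, 32)
    out = gated_augment(batch, lambda b: torch.flip(b, dims=[3]))  # horizontal flip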
@@ -584,7 +585,8 @@ class Data_augV5(nn.Module): #Joint optimization (mag, prob)
                 if not self._mix_dist:
                     self._distrib = uniforme_dist
                 else:
-                    self._distrib = (self._mix_factor*self._params["prob"]+(1-self._mix_factor)*uniforme_dist).softmax(dim=1) #Mix the real / uniform distribs with mix_factor
+                    prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
+                    self._distrib = (self._mix_factor*prob+(1-self._mix_factor)*uniforme_dist).softmax(dim=1) #Mix the real / uniform distribs with mix_factor

                 cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*self._distrib)
                 sample = cat_distrib.sample()
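
When mix_dist is active, the sampling distribution blends the learned probabilities with the uniform one and renormalizes through a softmax; detaching prob when fixed_prob is set keeps the sampling behavior identical but cuts the gradient path into the probabilities. A standalone sketch of this sampling step, with made-up sizes:

    import torch
    from torch.distributions import Categorical

    batch_size, nb_tf = 8, 5
    prob = torch.rand(1, nb_tf, requires_grad=True)   # learned TF probabilities
    uniforme_dist = torch.ones(1, nb_tf) / nb_tf
    mix_factor, fixed_prob = 0.5, False

    p = prob.detach() if fixed_prob else prob         # optionally cut gradients
    distrib = (mix_factor*p + (1-mix_factor)*uniforme_dist).softmax(dim=1)

    # One categorical row per image, all rows sharing the same distrib
    cat = Categorical(probs=torch.ones(batch_size, nb_tf) * distrib)
    sample = cat.sample()                             # shape: (batch_size,)

Categorical.sample() itself is not differentiable, which is why the probabilities only receive gradients through the loss_weight correction further down.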
@@ -622,18 +624,22 @@ class Data_augV5(nn.Module): #Joint optimization (mag, prob)
         return x

     def adjust_param(self, soft=False): #Detach from gradient ?
-        if soft :
-            self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Too 'soft': gets stuck on the uniform dist if lr is too small
-        else:
-            self._params['prob'].data = F.relu(self._params['prob'].data)
-            #self._params['prob'].data = self._params['prob'].clamp(min=0.0,max=1.0)
-            self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Constraint sum(p)=1
+        if not self._fixed_prob:
+            if soft :
+                self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Too 'soft': gets stuck on the uniform dist if lr is too small
+            else:
+                self._params['prob'].data = F.relu(self._params['prob'].data)
+                #self._params['prob'].data = self._params['prob'].clamp(min=0.0,max=1.0)
+                self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Constraint sum(p)=1

-        #self._params['mag'].data = self._params['mag'].data.clamp(min=0.0,max=TF.PARAMETER_MAX) #Gets stuck once at an extreme
-        self._params['mag'].data = F.relu(self._params['mag'].data) - F.relu(self._params['mag'].data - TF.PARAMETER_MAX)
+        if not self._fixed_mag:
+            #self._params['mag'].data = self._params['mag'].data.clamp(min=0.0,max=TF.PARAMETER_MAX) #Gets stuck once at an extreme
+            self._params['mag'].data = F.relu(self._params['mag'].data) - F.relu(self._params['mag'].data - TF.PARAMETER_MAX)

     def loss_weight(self):
         if len(self._samples)==0 : return 1 #No samples = no weighting

+        prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
         # single TF only
         #self._sample = self._samples[-1]
         #w_loss = torch.zeros((self._sample.shape[0],self._nb_tf), device=self._sample.device)
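
adjust_param now leaves frozen parameters untouched. The projections themselves are unchanged: probabilities are pushed back onto the simplex (ReLU to zero out negatives, then divide by the sum so that sum(p)=1), and magnitudes are clamped into [0, TF.PARAMETER_MAX] via the double-ReLU identity F.relu(x) - F.relu(x - MAX), which matches clamp(x, 0, MAX) value for value. A quick check, assuming PARAMETER_MAX=1.0 purely for illustration:

    import torch
    import torch.nn.functional as F

    PARAMETER_MAX = 1.0   # stand-in for TF.PARAMETER_MAX

    x = torch.tensor([-0.3, 0.2, 0.9, 1.7])
    double_relu = F.relu(x) - F.relu(x - PARAMETER_MAX)
    assert torch.allclose(double_relu, x.clamp(min=0.0, max=PARAMETER_MAX))

    p = torch.tensor([0.5, -0.2, 0.7])
    p = F.relu(p)      # zero out negative probabilities
    p = p / p.sum()    # renormalize so the probabilities sum to 1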
@@ -648,7 +654,7 @@ class Data_augV5(nn.Module): #Joint optimization (mag, prob)
             tmp_w.scatter_(dim=1, index=sample.view(-1,1), value=1/self._N_seqTF)
             w_loss += tmp_w

-        w_loss = w_loss * self._params["prob"]/self._distrib #Weight by the probs (divided by the distrib so as not to shrink the loss)
+        w_loss = w_loss * prob/self._distrib #Weight by the probs (divided by the distrib so as not to shrink the loss)
         w_loss = torch.sum(w_loss,dim=1)
         return w_loss

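
This is the importance-sampling correction: the transformations were drawn from self._distrib (the mixed distribution), so weighting each image's loss by prob/distrib makes the expected loss equal to what sampling directly from the learned probabilities would give, while letting gradients reach prob. Switching to the locally defined prob means the weight becomes a constant when fixed_prob is set. A worked sketch with made-up numbers:

    import torch

    nb_tf, batch_size = 3, 4
    prob = torch.tensor([0.6, 0.3, 0.1], requires_grad=True)  # learned probs
    distrib = torch.tensor([0.4, 0.35, 0.25])                 # actually sampled from
    sample = torch.tensor([0, 2, 0, 1])                       # drawn TF per image

    onehot = torch.zeros(batch_size, nb_tf).scatter_(1, sample.view(-1, 1), 1.0)
    w_loss = (onehot * prob / distrib).sum(dim=1)             # weight per image
    # E_{t~distrib}[(prob_t/distrib_t) * loss_t] = E_{t~prob}[loss_t]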
@@ -676,13 +682,15 @@ class Data_augV5(nn.Module): #Joint optimization (mag, prob)
         return self._params[key]

     def __str__(self):
+        dist_param=''
+        if self._fixed_prob: dist_param+='Fx'
         mag_param='Mag'
         if self._fixed_mag: mag_param+= 'Fx'
         if self._shared_mag: mag_param+= 'Sh'
         if not self._mix_dist:
-            return "Data_augV5(Uniform-%dTFx%d-%s)" % (self._nb_tf, self._N_seqTF, mag_param)
+            return "Data_augV5(Uniform%s-%dTFx%d-%s)rand0.5" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
         else:
-            return "Data_augV5(Mix%.1f-%dTFx%d-%s)" % (self._mix_factor, self._nb_tf, self._N_seqTF, mag_param)
+            return "Data_augV5(Mix%.1f%s-%dTFx%d-%s)" % (self._mix_factor,dist_param, self._nb_tf, self._N_seqTF, mag_param)


 class Augmented_model(nn.Module):