mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Add multiple sequential TFs
This commit is contained in:
parent
caf5fad470
commit
c6dd26c29a
204 changed files with 54 additions and 35 deletions
@@ -326,12 +326,14 @@ class Data_augV4(nn.Module): # Transformations with mask
         self._TF = list(self._mag_fct.keys())
         self._nb_tf = len(self._TF)
 
+        self._N_TF = N_TF
+
         self._fixed_mag = 5 #[0, PARAMETER_MAX]
         self._params = nn.ParameterDict({
             "prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), # Uniform probability distribution
         })
 
-        self._sample = []
+        self._samples = []
 
         self._mix_dist = False
         if mix_dist != 0.0:
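For context, the "prob" entry registered in the nn.ParameterDict is a learnable probability vector over the TF set, initialised uniform. A minimal standalone sketch of that pattern, using a hypothetical size rather than the repository's actual configuration:

import torch
import torch.nn as nn

nb_tf = 3  # hypothetical: 3 transformations in the TF set
params = nn.ParameterDict({
    # One learnable weight per TF; starting at 1/nb_tf means training
    # begins with no preference between transformations.
    "prob": nn.Parameter(torch.ones(nb_tf) / nb_tf),
})
print(params["prob"])  # tensor([0.3333, 0.3333, 0.3333], requires_grad=True)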
@@ -342,23 +344,26 @@ class Data_augV4(nn.Module): # Transformations with mask
         if self._data_augmentation:
             device = x.device
             batch_size, h, w = x.shape[0], x.shape[2], x.shape[3]
 
 
-            ## Sampling ##
-            uniforme_dist = torch.ones(1, self._nb_tf, device=device).softmax(dim=1)
-
-            if not self._mix_dist:
-                self._distrib = uniforme_dist
-            else:
-                self._distrib = (self._mix_factor*self._params["prob"]+(1-self._mix_factor)*uniforme_dist).softmax(dim=1) # Mix learned / uniform distributions with mix_factor
-            print(self._distrib.shape)
-
-            cat_distrib = Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*self._distrib)
-            self._sample = cat_distrib.sample()
-
-            ## Transformations ##
-            x = copy.deepcopy(x) # Avoid modifying the samples by reference (problematic for parallel use)
-            x = self.apply_TF(x, self._sample)
+            self._samples = []
+
+            for _ in range(self._N_TF):
+                ## Sampling ##
+                uniforme_dist = torch.ones(1, self._nb_tf, device=device).softmax(dim=1)
+
+                if not self._mix_dist:
+                    self._distrib = uniforme_dist
+                else:
+                    self._distrib = (self._mix_factor*self._params["prob"]+(1-self._mix_factor)*uniforme_dist).softmax(dim=1) # Mix learned / uniform distributions with mix_factor
+                print(self._distrib.shape)
+
+                cat_distrib = Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*self._distrib)
+                sample = cat_distrib.sample()
+                self._samples.append(sample)
+
+                ## Transformations ##
+                x = self.apply_TF(x, sample)
         return x
     '''
     def compute_TF_matrix(self, magnitude=None, sample_info=None):
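The new forward pass draws N_TF rounds of per-image TF indices from a Categorical distribution and applies them sequentially, recording every draw in self._samples for the loss weighting below. A minimal sketch of that sampling pattern, with hypothetical sizes standing in for the real batch:

import torch
from torch.distributions import Categorical

batch_size, nb_tf, N_TF = 4, 3, 2              # hypothetical sizes
distrib = torch.ones(1, nb_tf).softmax(dim=1)  # uniform over the TF set

samples = []
for _ in range(N_TF):
    # Broadcasting the (1, nb_tf) distribution against a (batch_size, nb_tf)
    # tensor of ones yields one categorical distribution per image.
    cat = Categorical(probs=torch.ones(batch_size, nb_tf) * distrib)
    sample = cat.sample()                      # shape (batch_size,): one TF index per image
    samples.append(sample)

print([s.shape for s in samples])              # N_TF tensors of shape (4,)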
@@ -482,12 +487,25 @@ class Data_augV4(nn.Module): # Transformations with mask
         self._params['prob'].data = self._params['prob']/sum(self._params['prob']) # Constraint: sum(p)=1
 
     def loss_weight(self):
-        w_loss = torch.zeros((self._sample.shape[0], self._nb_tf), device=self._sample.device)
-        w_loss.scatter_(1, self._sample.view(-1,1), 1)
+        # Single TF only
+        #self._sample = self._samples[-1]
+        #w_loss = torch.zeros((self._sample.shape[0], self._nb_tf), device=self._sample.device)
+        #w_loss.scatter_(dim=1, index=self._sample.view(-1,1), value=1)
+        #w_loss = w_loss * self._params["prob"]/self._distrib # Weight by the probabilities (divided by the distribution so the loss is not shrunk)
+        #w_loss = torch.sum(w_loss, dim=1)
 
+        # Multiple sequential TFs
+        w_loss = torch.zeros((self._samples[0].shape[0], self._nb_tf), device=self._samples[0].device)
+        for sample in self._samples:
+            tmp_w = torch.zeros(w_loss.size(), device=w_loss.device)
+            tmp_w.scatter_(dim=1, index=sample.view(-1,1), value=1/self._N_TF)
+            w_loss += tmp_w
+
+        w_loss = w_loss * self._params["prob"]/self._distrib # Weight by the probabilities (divided by the distribution so the loss is not shrunk)
+        w_loss = torch.sum(w_loss, dim=1)
         return w_loss
 
 
     def train(self, mode=None):
         if mode is None:
             mode = self._data_augmentation
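loss_weight turns the recorded draws into per-image weights: each of the N_TF samples contributes a one-hot row scaled by 1/N_TF so the rows sum to 1, and the result is then reweighted by the learned probabilities and divided by the sampling distribution. A self-contained sketch with random stand-in samples in place of the real ones:

import torch

batch_size, nb_tf, N_TF = 4, 3, 2
prob = torch.ones(nb_tf) / nb_tf                 # stand-in for the learned probabilities
distrib = torch.ones(1, nb_tf).softmax(dim=1)    # stand-in for the sampling distribution
samples = [torch.randint(nb_tf, (batch_size,)) for _ in range(N_TF)]

# Accumulate 1/N_TF at each sampled TF index, so every image's row sums to 1.
w_loss = torch.zeros(batch_size, nb_tf)
for sample in samples:
    tmp_w = torch.zeros_like(w_loss)
    tmp_w.scatter_(dim=1, index=sample.view(-1, 1), value=1 / N_TF)
    w_loss += tmp_w

# Reweight by prob and divide by the sampling distribution so the expected
# loss magnitude is not shrunk by the weighting.
w_loss = (w_loss * prob / distrib).sum(dim=1)
print(w_loss)  # one weight per image; all 1.0 here since prob == distrib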
@@ -505,9 +523,9 @@ class Data_augV4(nn.Module): # Transformations with mask
 
     def __str__(self):
         if not self._mix_dist:
-            return "Data_augV4(Uniform-%d TF)" % self._nb_tf
+            return "Data_augV4(Uniform-%d TF x %d)" % (self._nb_tf, self._N_TF)
         else:
-            return "Data_augV4(Mix %.1f-%d TF)" % (self._mix_factor, self._nb_tf)
+            return "Data_augV4(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_TF)
 
 class Augmented_model(nn.Module):
     def __init__(self, data_augmenter, model):