Mirror of https://github.com/AntoineHX/smart_augmentation.git
Mag bound + mag regularization
parent f4bdd9bca5
commit 64282bda3a
10 changed files with 43 additions and 228 deletions
@@ -114,7 +114,7 @@ class Data_augV2(nn.Module): #Exact method
         return kornia.warp_affine(x, M, dsize=(x.shape[2], x.shape[3])) #dsize=(h, w)
 
 
-    def adjust_prob(self): #Detach from gradient ?
+    def adjust_param(self): #Detach from gradient ?
         self._params['prob'].data = self._params['prob'].clamp(min=0.0,max=1.0)
         #print('proba',self._params['prob'])
         self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Constraint: sum(p)=1
@@ -262,7 +262,7 @@ class Data_augV3(nn.Module): #Uniform/mixed sampling
         # warp the original image by the found transform
         return kornia.warp_perspective(x, M, dsize=(h, w))
 
-    def adjust_prob(self, soft=False): #Detach from gradient ?
+    def adjust_param(self, soft=False): #Detach from gradient ?
 
         if soft :
             self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Too 'soft': stays stuck at a uniform distribution if the lr is too small
@@ -478,7 +478,7 @@ class Data_augV4(nn.Module): #Transformations with mask
         '''
         return x
 
-    def adjust_prob(self, soft=False): #Detach from gradient ?
+    def adjust_param(self, soft=False): #Detach from gradient ?
 
         if soft :
             self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Too 'soft': stays stuck at a uniform distribution if the lr is too small
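
The change in Data_augV2/V3/V4 is the same rename, adjust_prob to adjust_param, presumably for consistency with Data_augV5, whose version now also adjusts 'mag'. The two branches give two ways of projecting the raw 'prob' vector back onto a probability distribution. A minimal standalone sketch of both branches, assuming plain PyTorch tensors (the function name is illustrative, not from the repo):

import torch
import torch.nn.functional as F

def adjust_prob_demo(p, soft=False):
    # Project a raw parameter vector back onto the probability simplex.
    if soft:
        # Softmax keeps every entry strictly positive, but moves slowly
        # away from uniform when the learning rate is small.
        return F.softmax(p, dim=0)
    # Hard variant: zero out negatives, then renormalize so the sum is 1.
    p = F.relu(p)
    return p / p.sum()

p = torch.tensor([0.7, -0.1, 0.4])
print(adjust_prob_demo(p))             # tensor([0.6364, 0.0000, 0.3636])
print(adjust_prob_demo(p, soft=True))  # tensor([0.4566, 0.2052, 0.3383])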
@@ -549,15 +549,22 @@ class Data_augV5(nn.Module): #Joint optimization (mag, prob)
         self._params = nn.ParameterDict({
             "prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Uniform prob distribution
             "mag" : nn.Parameter(torch.tensor(0.5) if self._shared_mag
-                else torch.tensor(0.5).expand(self._nb_tf)), #[0, PARAMETER_MAX]/10
+                else torch.tensor(0.5).expand(self._nb_tf)), #[0, PARAMETER_MAX]
         })
-        self._samples = []
 
+        #Distribution
+        self._samples = []
         self._mix_dist = False
         if mix_dist != 0.0:
             self._mix_dist = True
             self._mix_factor = max(min(mix_dist, 1.0), 0.0)
 
+        #Mag regularization
+        if not self._fixed_mag:
+            ignore={'Identity', 'FlipUD', 'FlipLR', 'Solarize', 'Posterize'}
+            self._reg_mask=[self._TF.index(t) for t in self._TF if t not in ignore]
+            self._reg_tgt = torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage max amplitude
+
     def forward(self, x):
         if self._data_augmentation:
             device = x.device
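
The added mag-regularization block precomputes, once at construction, which transforms have a magnitude worth regularizing: self._reg_mask holds the indices of every TF not in the ignore set, and self._reg_tgt is a matching vector filled with TF.PARAMETER_MAX that later serves as the regression target. A self-contained sketch of the same construction; the transform list and PARAMETER_MAX value below are illustrative assumptions:

import torch

PARAMETER_MAX = 1.0  # illustrative stand-in for TF.PARAMETER_MAX
tf_list = ['Identity', 'FlipLR', 'Rotate', 'Contrast', 'Solarize']

# TF whose magnitude the regularization should leave alone.
ignore = {'Identity', 'FlipUD', 'FlipLR', 'Solarize', 'Posterize'}
reg_mask = [tf_list.index(t) for t in tf_list if t not in ignore]
reg_tgt = torch.full(size=(len(reg_mask),), fill_value=PARAMETER_MAX)

print(reg_mask)  # [2, 3] -> 'Rotate', 'Contrast'
print(reg_tgt)   # tensor([1., 1.])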
@@ -610,18 +617,17 @@ class Data_augV5(nn.Module): #Joint optimization (mag, prob)
 
         return x
 
-    def adjust_prob(self, soft=False): #Detach from gradient ?
+    def adjust_param(self, soft=False): #Detach from gradient ?
 
         if soft :
             self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Too 'soft': stays stuck at a uniform distribution if the lr is too small
         else:
             #self._params['prob'].clamp(min=0.0,max=1.0)
             self._params['prob'].data = F.relu(self._params['prob'].data)
             #self._params['prob'].data = self._params['prob'].clamp(min=0.0,max=1.0)
 
         self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Constraint: sum(p)=1
 
+        #self._params['mag'].data = self._params['mag'].data.clamp(min=0.0,max=TF.PARAMETER_MAX) #Gets stuck once it reaches an extreme
+        self._params['mag'].data = F.relu(self._params['mag'].data) - F.relu(self._params['mag'].data - TF.PARAMETER_MAX)
 
     def loss_weight(self):
         # only 1 TF
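
The new 'mag' update is the bound from the commit title: it folds the magnitudes back into [0, TF.PARAMETER_MAX] with two ReLUs instead of the commented-out clamp. Pointwise, relu(x) - relu(x - M) equals clamp(x, 0, M) exactly, which a throwaway check confirms (MAX is an illustrative stand-in for TF.PARAMETER_MAX):

import torch
import torch.nn.functional as F

MAX = 1.0  # illustrative stand-in for TF.PARAMETER_MAX
mag = torch.tensor([-0.3, 0.2, 0.9, 1.7])

# Double-ReLU bound: identical to clamping into [0, MAX].
bounded = F.relu(mag) - F.relu(mag - MAX)
assert torch.allclose(bounded, mag.clamp(min=0.0, max=MAX))
print(bounded)  # tensor([0.0000, 0.2000, 0.9000, 1.0000])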
@@ -642,6 +648,9 @@ class Data_augV5(nn.Module): #Joint optimization (mag, prob)
         w_loss = torch.sum(w_loss,dim=1)
         return w_loss
 
+    def reg_loss(self, reg_factor=0.005):
+        #return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean')
+        return reg_factor * F.mse_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt.to(self._params['mag'].device), reduction='mean')
 
     def train(self, mode=None):
         if mode is None :
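
The new reg_loss() is the regularization from the commit title: a scaled MSE pulling the masked magnitudes toward their PARAMETER_MAX target, with an L1 variant kept commented out. A hedged sketch of how such a penalty composes with a task loss; apart from the reg_loss logic itself, every name and value below is an assumed stand-in:

import torch
import torch.nn.functional as F

mag = torch.nn.Parameter(torch.tensor([0.5, 0.5, 0.5]))
reg_mask = [1, 2]                            # indices of regularized TF
reg_tgt = torch.full((len(reg_mask),), 1.0)  # illustrative PARAMETER_MAX = 1

def reg_loss(reg_factor=0.005):
    # Scaled MSE pulling the masked magnitudes toward the max target.
    return reg_factor * F.mse_loss(mag[reg_mask], reg_tgt, reduction='mean')

task_loss = torch.tensor(0.0, requires_grad=True)  # placeholder task loss
(task_loss + reg_loss()).backward()
print(mag.grad)  # tensor([ 0.0000, -0.0025, -0.0025]); descent raises mag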