Borne mag + Regularisation mag

This commit is contained in:
Harle, Antoine (Contracteur) 2019-11-19 15:37:29 -05:00
parent f4bdd9bca5
commit 64282bda3a
10 changed files with 43 additions and 228 deletions

View file

@ -114,7 +114,7 @@ class Data_augV2(nn.Module): #Methode exacte
return kornia.warp_affine(x, M, dsize=(x.shape[2], x.shape[3])) #dsize=(h, w)
def adjust_prob(self): #Detach from gradient ?
def adjust_param(self): #Detach from gradient ?
self._params['prob'].data = self._params['prob'].clamp(min=0.0,max=1.0)
#print('proba',self._params['prob'])
self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1
@ -262,7 +262,7 @@ class Data_augV3(nn.Module): #Echantillonage uniforme/Mixte
# warp the original image by the found transform
return kornia.warp_perspective(x, M, dsize=(h, w))
def adjust_prob(self, soft=False): #Detach from gradient ?
def adjust_param(self, soft=False): #Detach from gradient ?
if soft :
self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Trop 'soft', bloque en dist uniforme si lr trop faible
@ -478,7 +478,7 @@ class Data_augV4(nn.Module): #Transformations avec mask
'''
return x
def adjust_prob(self, soft=False): #Detach from gradient ?
def adjust_param(self, soft=False): #Detach from gradient ?
if soft :
self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Trop 'soft', bloque en dist uniforme si lr trop faible
@ -549,15 +549,22 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
self._params = nn.ParameterDict({
"prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Distribution prob uniforme
"mag" : nn.Parameter(torch.tensor(0.5) if self._shared_mag
else torch.tensor(0.5).expand(self._nb_tf)), #[0, PARAMETER_MAX]/10
else torch.tensor(0.5).expand(self._nb_tf)), #[0, PARAMETER_MAX]
})
self._samples = []
#Distribution
self._samples = []
self._mix_dist = False
if mix_dist != 0.0:
self._mix_dist = True
self._mix_factor = max(min(mix_dist, 1.0), 0.0)
#Mag regularisation
if not self._fixed_mag:
ignore={'Identity', 'FlipUD', 'FlipLR', 'Solarize', 'Posterize'}
self._reg_mask=[self._TF.index(t) for t in self._TF if t not in ignore]
self._reg_tgt = torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage amplitude max
def forward(self, x):
if self._data_augmentation:
device = x.device
@ -610,18 +617,17 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
return x
def adjust_prob(self, soft=False): #Detach from gradient ?
def adjust_param(self, soft=False): #Detach from gradient ?
if soft :
self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Trop 'soft', bloque en dist uniforme si lr trop faible
else:
#self._params['prob'].clamp(min=0.0,max=1.0)
self._params['prob'].data = F.relu(self._params['prob'].data)
#self._params['prob'].data = self._params['prob'].clamp(min=0.0,max=1.0)
self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1
#self._params['mag'].data = self._params['mag'].data.clamp(min=0.0,max=TF.PARAMETER_MAX) #Bloque une fois au extreme
self._params['mag'].data = F.relu(self._params['mag'].data) - F.relu(self._params['mag'].data - TF.PARAMETER_MAX)
def loss_weight(self):
# 1 seule TF
@ -642,6 +648,9 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
w_loss = torch.sum(w_loss,dim=1)
return w_loss
def reg_loss(self, reg_factor=0.005):
#return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean')
return reg_factor * F.mse_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt.to(self._params['mag'].device), reduction='mean')
def train(self, mode=None):
if mode is None :