Minor improvement (RandAug)

Harle, Antoine (Contracteur) 2020-01-30 11:21:25 -05:00
parent 6bba069d8a
commit 561b71b30a
5 changed files with 50 additions and 179 deletions


@@ -187,11 +187,11 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)
Ensure that the parameter values stay in the right intervals. This should be called after each update of those parameters.
Args:
soft (bool): Whether to use a softmax function for TF probabilities. Not recommended as it tends to lock the probabilities, preventing them from being learned. (default: False)
soft (bool): Whether to use a softmax function for TF probabilities. Tends to lock the probabilities if the learning rate is low, preventing them from being learned. (default: False)
"""
if not self._fixed_prob:
if soft :
self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Too 'soft': locks into a uniform distribution if the lr is too low
self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0)
else:
self._params['prob'].data = self._params['prob'].data.clamp(min=1/(self._nb_tf*100),max=1.0)
self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Constraint: sum(p)=1
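For context, the clamp branch that survives this change projects the transform probabilities back onto a valid distribution: every entry is floored at 1/(nb_tf*100) so no transform can be switched off entirely, then the vector is renormalized to sum to 1. A minimal standalone sketch of the two branches (the `project_probs` helper is illustrative, not code from the repository):

```python
import torch
import torch.nn.functional as F

def project_probs(prob: torch.Tensor, nb_tf: int, soft: bool = False) -> torch.Tensor:
    """Project raw probability parameters back onto a valid distribution.

    Mirrors the two branches above: softmax (can lock into a uniform
    distribution when the learning rate is low) vs clamp + renormalize.
    """
    if soft:
        return F.softmax(prob, dim=0)
    # The floor keeps every transform probability strictly positive.
    prob = prob.clamp(min=1 / (nb_tf * 100), max=1.0)
    return prob / prob.sum()  # enforce sum(p) == 1

# Example: the third probability drifted negative during an update.
p = torch.tensor([0.7, 0.3, -0.1])
print(project_probs(p, nb_tf=3))  # strictly positive entries summing to 1
```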
@@ -269,6 +269,14 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)
"""
self._data_augmentation=mode
def is_augmenting(self):
""" Return wether data augmentation is applied.
Returns:
bool : True if data augmentation is applied.
"""
return self._data_augmentation
def __getitem__(self, key):
"""Access to the learnable parameters
Args:
@@ -588,6 +596,14 @@ class Data_augV7(nn.Module): #Sequential probabilities
"""
self._data_augmentation=mode
def is_augmenting(self):
""" Return wether data augmentation is applied.
Returns:
bool : True if data augmentation is applied.
"""
return self._data_augmentation
def __getitem__(self, key):
"""Access to the learnable parameters
Args:
@@ -659,6 +675,8 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + fast
})
self._shared_mag = True
self._fixed_mag = True
self._fixed_prob=True
self._fixed_mix=True
self._params['mag'].data = self._params['mag'].data.clamp(min=TF.PARAMETER_MIN, max=TF.PARAMETER_MAX)
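The two new flags mark RandAugment's probability and mixing parameters as non-learnable, matching the existing `_fixed_mag`/`_shared_mag` convention; `adjust_param` above already skips the probability projection when `_fixed_prob` is set. A hedged sketch of how such flags could gate which parameters are handed to an optimizer (this helper and the 'mix_dist' key are assumptions, not code from this commit):

```python
import torch.nn as nn

def trainable_aug_params(aug: nn.Module):
    """Collect only the augmentation parameters meant to be learned.

    Hypothetical helper: the _fixed_* flag names come from the diff, but
    this function and the 'mix_dist' key are illustrative assumptions.
    """
    selected = []
    if not getattr(aug, '_fixed_prob', True):
        selected.append(aug._params['prob'])
    if not getattr(aug, '_fixed_mag', True):
        selected.append(aug._params['mag'])
    if not getattr(aug, '_fixed_mix', True):
        selected.append(aug._params['mix_dist'])
    return selected

# With RandAug all three flags are True, so the list is empty and no
# augmentation parameter receives gradient updates.
```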
@@ -753,6 +771,14 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + fast
"""
self._data_augmentation=mode
def is_augmenting(self):
""" Return wether data augmentation is applied.
Returns:
bool : True if data augmentation is applied.
"""
return self._data_augmentation
def __getitem__(self, key):
"""Access to the learnable parameters
Args:
@@ -796,7 +822,7 @@ class Higher_model(nn.Module):
"""
super(Higher_model, self).__init__()
self._name = model.__str__()
self._name = model.__class__.__name__ #model.__str__()
self._mods = nn.ModuleDict({
'original': model,
'functional': higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
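Using `__class__.__name__` stores a compact class identifier instead of the full multi-line module repr returned by `__str__()`. A quick illustration of the difference:

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 2), nn.ReLU())
print(model.__str__())           # multi-line repr of the whole architecture
print(model.__class__.__name__)  # just 'Sequential'
```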