mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Minor improvement (RandAug)
This commit is contained in:
parent
6bba069d8a
commit
561b71b30a
5 changed files with 50 additions and 179 deletions
|
@ -187,11 +187,11 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
|||
Ensure that the parameters value stays in the right intevals. This should be called after each update of those parameters.
|
||||
|
||||
Args:
|
||||
soft (bool): Wether to use a softmax function for TF probabilites. Not Recommended as it tends to lock the probabilities, preventing them to be learned. (default: False)
|
||||
soft (bool): Wether to use a softmax function for TF probabilites. Tends to lock the probabilities if the learning rate is low, preventing them to be learned. (default: False)
|
||||
"""
|
||||
if not self._fixed_prob:
|
||||
if soft :
|
||||
self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Trop 'soft', bloque en dist uniforme si lr trop faible
|
||||
self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0)
|
||||
else:
|
||||
self._params['prob'].data = self._params['prob'].data.clamp(min=1/(self._nb_tf*100),max=1.0)
|
||||
self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1
|
||||
|
@ -269,6 +269,14 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
|||
"""
|
||||
self._data_augmentation=mode
|
||||
|
||||
def is_augmenting(self):
|
||||
""" Return wether data augmentation is applied.
|
||||
|
||||
Returns:
|
||||
bool : True if data augmentation is applied.
|
||||
"""
|
||||
return self._data_augmentation
|
||||
|
||||
def __getitem__(self, key):
|
||||
"""Access to the learnable parameters
|
||||
Args:
|
||||
|
@ -588,6 +596,14 @@ class Data_augV7(nn.Module): #Proba sequentielles
|
|||
"""
|
||||
self._data_augmentation=mode
|
||||
|
||||
def is_augmenting(self):
|
||||
""" Return wether data augmentation is applied.
|
||||
|
||||
Returns:
|
||||
bool : True if data augmentation is applied.
|
||||
"""
|
||||
return self._data_augmentation
|
||||
|
||||
def __getitem__(self, key):
|
||||
"""Access to the learnable parameters
|
||||
Args:
|
||||
|
@ -659,6 +675,8 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + rapide
|
|||
})
|
||||
self._shared_mag = True
|
||||
self._fixed_mag = True
|
||||
self._fixed_prob=True
|
||||
self._fixed_mix=True
|
||||
|
||||
self._params['mag'].data = self._params['mag'].data.clamp(min=TF.PARAMETER_MIN, max=TF.PARAMETER_MAX)
|
||||
|
||||
|
@ -753,6 +771,14 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + rapide
|
|||
"""
|
||||
self._data_augmentation=mode
|
||||
|
||||
def is_augmenting(self):
|
||||
""" Return wether data augmentation is applied.
|
||||
|
||||
Returns:
|
||||
bool : True if data augmentation is applied.
|
||||
"""
|
||||
return self._data_augmentation
|
||||
|
||||
def __getitem__(self, key):
|
||||
"""Access to the learnable parameters
|
||||
Args:
|
||||
|
@ -796,7 +822,7 @@ class Higher_model(nn.Module):
|
|||
"""
|
||||
super(Higher_model, self).__init__()
|
||||
|
||||
self._name = model.__str__()
|
||||
self._name = model.__class__.__name__ #model.__str__()
|
||||
self._mods = nn.ModuleDict({
|
||||
'original': model,
|
||||
'functional': higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue