mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Fix etat Train/Eval pour augmentation differee (Retester !)
This commit is contained in:
parent
2d6d2f7397
commit
d21a6bbf5c
4 changed files with 38 additions and 29 deletions
|
@ -747,21 +747,22 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
|||
max_mag_reg = reg_factor * F.mse_loss(mags, target=self._reg_tgt.to(mags.device), reduction='mean')
|
||||
return max_mag_reg
|
||||
|
||||
def train(self, mode=None):
|
||||
def train(self, mode=True):
|
||||
""" Set the module training mode.
|
||||
|
||||
Args:
|
||||
mode (bool): Wether to learn the parameter of the module. None would not change mode. (default: None)
|
||||
"""
|
||||
if mode is None :
|
||||
mode=self._data_augmentation
|
||||
#if mode is None :
|
||||
# mode=self._data_augmentation
|
||||
self.augment(mode=mode) #Inutile si mode=None
|
||||
super(Data_augV5, self).train(mode)
|
||||
return self
|
||||
|
||||
def eval(self):
|
||||
""" Set the module to evaluation mode.
|
||||
"""
|
||||
self.train(mode=False)
|
||||
return self.train(mode=False)
|
||||
|
||||
def augment(self, mode=True):
|
||||
""" Set the augmentation mode.
|
||||
|
@ -1266,7 +1267,7 @@ class Augmented_model(nn.Module):
|
|||
|
||||
Attributes:
|
||||
_mods (nn.ModuleDict): A dictionary containing the modules.
|
||||
_data_augmentation (bool): Wether data augmentation is used.
|
||||
_data_augmentation (bool): Wether data augmentation should be used.
|
||||
"""
|
||||
def __init__(self, data_augmenter, model):
|
||||
"""Init Augmented Model.
|
||||
|
@ -1308,22 +1309,25 @@ class Augmented_model(nn.Module):
|
|||
self._data_augmentation=mode
|
||||
self._mods['data_aug'].augment(mode)
|
||||
|
||||
def train(self, mode=None):
|
||||
def train(self, mode=True):
|
||||
""" Set the module training mode.
|
||||
|
||||
Args:
|
||||
mode (bool): Wether to learn the parameter of the module. None would not change mode. (default: None)
|
||||
mode (bool): Wether to learn the parameter of the module. (default: None)
|
||||
"""
|
||||
if mode is None :
|
||||
mode=self._data_augmentation
|
||||
self._mods['data_aug'].augment(mode)
|
||||
#if mode is None :
|
||||
# mode=self._data_augmentation
|
||||
super(Augmented_model, self).train(mode)
|
||||
self._mods['data_aug'].augment(mode=self._data_augmentation) #Restart if needed data augmentation
|
||||
return self
|
||||
|
||||
def eval(self):
|
||||
""" Set the module to evaluation mode.
|
||||
"""
|
||||
return self.train(mode=False)
|
||||
#return self.train(mode=False)
|
||||
super(Augmented_model, self).train(mode=False)
|
||||
self._mods['data_aug'].augment(mode=False)
|
||||
return self
|
||||
|
||||
def items(self):
|
||||
"""Return an iterable of the ModuleDict key/value pairs.
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue