Fix etat Train/Eval pour augmentation differee (Retester !)

This commit is contained in:
Harle, Antoine (Contracteur) 2020-01-20 17:09:31 -05:00
parent 2d6d2f7397
commit d21a6bbf5c
4 changed files with 38 additions and 29 deletions

View file

@ -747,21 +747,22 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
max_mag_reg = reg_factor * F.mse_loss(mags, target=self._reg_tgt.to(mags.device), reduction='mean')
return max_mag_reg
def train(self, mode=None):
def train(self, mode=True):
""" Set the module training mode.
Args:
mode (bool): Wether to learn the parameter of the module. None would not change mode. (default: None)
"""
if mode is None :
mode=self._data_augmentation
#if mode is None :
# mode=self._data_augmentation
self.augment(mode=mode) #Inutile si mode=None
super(Data_augV5, self).train(mode)
return self
def eval(self):
""" Set the module to evaluation mode.
"""
self.train(mode=False)
return self.train(mode=False)
def augment(self, mode=True):
""" Set the augmentation mode.
@ -1266,7 +1267,7 @@ class Augmented_model(nn.Module):
Attributes:
_mods (nn.ModuleDict): A dictionary containing the modules.
_data_augmentation (bool): Wether data augmentation is used.
_data_augmentation (bool): Wether data augmentation should be used.
"""
def __init__(self, data_augmenter, model):
"""Init Augmented Model.
@ -1308,22 +1309,25 @@ class Augmented_model(nn.Module):
self._data_augmentation=mode
self._mods['data_aug'].augment(mode)
def train(self, mode=None):
def train(self, mode=True):
""" Set the module training mode.
Args:
mode (bool): Wether to learn the parameter of the module. None would not change mode. (default: None)
mode (bool): Wether to learn the parameter of the module. (default: None)
"""
if mode is None :
mode=self._data_augmentation
self._mods['data_aug'].augment(mode)
#if mode is None :
# mode=self._data_augmentation
super(Augmented_model, self).train(mode)
self._mods['data_aug'].augment(mode=self._data_augmentation) #Restart if needed data augmentation
return self
def eval(self):
""" Set the module to evaluation mode.
"""
return self.train(mode=False)
#return self.train(mode=False)
super(Augmented_model, self).train(mode=False)
self._mods['data_aug'].augment(mode=False)
return self
def items(self):
"""Return an iterable of the ModuleDict key/value pairs.