dist_dataugv3 changes (-copy / +fast) + minor TF change

This commit is contained in:
Harle, Antoine (Contracteur) 2020-01-15 16:55:03 -05:00
parent e291bc2e44
commit 75901b69b4
6 changed files with 198 additions and 83 deletions


@@ -531,7 +531,7 @@ class Data_augV4(nn.Module): # Transformations with mask
return "Data_augV4(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_seqTF)
class Data_augV5(nn.Module): # Joint optimization (mag, prob)
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True, ):
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True):
super(Data_augV5, self).__init__()
assert len(TF_dict)>0
@@ -545,13 +545,15 @@ class Data_augV5(nn.Module): # Joint optimization (mag, prob)
self._shared_mag = shared_mag
self._fixed_mag = fixed_mag
#self._fixed_mag=5 #[0, PARAMETER_MAX]
init_mag = float(TF.PARAMETER_MAX) if self._fixed_mag else float(TF.PARAMETER_MAX)/2
self._params = nn.ParameterDict({
"prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Distribution prob uniforme
"mag" : nn.Parameter(torch.tensor(float(TF.PARAMETER_MAX)/2) if self._shared_mag
else torch.tensor(float(TF.PARAMETER_MAX)/2).expand(self._nb_tf)), #[0, PARAMETER_MAX]
"mag" : nn.Parameter(torch.tensor(init_mag) if self._shared_mag
else torch.tensor(init_mag).repeat(self._nb_tf)), #[0, PARAMETER_MAX]
})
for tf in TF.TF_no_grad :
if tf in self._TF: self._params['mag'].data[self._TF.index(tf)]=float(TF.PARAMETER_MAX) # TF fixed at max parameter
#for t in TF.TF_no_mag: self._params['mag'][self._TF.index(t)].data-=self._params['mag'][self._TF.index(t)].data # Mag is useless for the ignore_mag TFs
#Distribution
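
The switch from expand to repeat when filling the per-TF magnitudes matters because the new TF_no_grad loop above writes individual entries of 'mag': expand returns a stride-0 view whose entries all alias a single storage element, while repeat allocates an independent element per TF. A minimal illustrative sketch of the difference (plain PyTorch, arbitrary values):

import torch

base = torch.tensor(0.5)      # 0-dim tensor, like init_mag above
aliased = base.expand(3)      # view with stride (0,): the 3 entries share one storage element
independent = base.repeat(3)  # real copy: 3 separately addressable elements

print(aliased.stride(), independent.stride())      # (0,) vs (1,)
print(aliased.data_ptr() == base.data_ptr())       # True: shares storage with base
print(independent.data_ptr() == base.data_ptr())   # False: owns its storage

# With the expanded view, the entries of 'mag' could not be adjusted independently
# (they all map to the same element); repeat gives each TF its own magnitude.
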
@@ -1094,8 +1096,8 @@ class Augmented_model(nn.Module):
self.augment(mode=True)
def initialize(self):
self._mods['model'].initialize()
#def initialize(self):
# self._mods['model'].initialize()
def forward(self, x):
return self._mods['model'](self._mods['data_aug'](x))
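
For context, forward simply chains the two submodules, model(data_aug(x)). A rough usage sketch; the stand-in classifier, the constructor arguments and the batch shape are assumptions for illustration, not code from this repository:

import torch

aug = Data_augV5(N_TF=2, mix_dist=0.5, fixed_mag=False, shared_mag=False)
net = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3*32*32, 10))  # placeholder classifier
model = Augmented_model(aug, net)

model.augment(mode=True)        # enable the learned augmentations
x = torch.rand(8, 3, 32, 32)    # illustrative batch shape
logits = model(x)               # equivalent to net(aug(x)), as in forward() above
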
@@ -1136,4 +1138,81 @@ class Augmented_model(nn.Module):
return self._mods[key]
def __str__(self):
return "Aug_mod("+str(self._mods['data_aug'])+"-"+str(self._mods['model'])+")"
return "Aug_mod("+str(self._mods['data_aug'])+"-"+str(self._mods['model'])+")"
'''
import higher
class Augmented_model2(nn.Module):
def __init__(self, data_augmenter, model):
super(Augmented_model2, self).__init__()
self._mods = nn.ModuleDict({
'data_aug': data_augmenter,
'model': model,
'fmodel': None
})
self.augment(mode=True)
def initialize(self):
self._mods['model'].initialize()
def forward(self, x):
if self._mods['fmodel']:
return self._mods['fmodel'](self._mods['data_aug'](x))
else:
return self._mods['model'](self._mods['data_aug'](x))
def functional(self, opt, track_higher_grads=True):
self._mods['fmodel'] = higher.patch.monkeypatch(self._mods['model'], device=None, copy_initial_weights=True)
return higher.optim.get_diff_optim(opt,
self._mods['model'].parameters(),
fmodel=self._mods['fmodel'],
track_higher_grads=track_higher_grads)
def detach_(self):
tmp = self._mods['fmodel'].fast_params
self._mods['fmodel']._fast_params=[]
self._mods['fmodel'].update_params(tmp)
for p in self._mods['fmodel'].fast_params:
p.detach_().requires_grad_()
def augment(self, mode=True):
self._data_augmentation=mode
self._mods['data_aug'].augment(mode)
def train(self, mode=None):
if mode is None :
mode=self._data_augmentation
self._mods['data_aug'].augment(mode)
super(Augmented_model2, self).train(mode)
return self
def eval(self):
return self.train(mode=False)
#super(Augmented_model, self).eval()
def items(self):
"""Return an iterable of the ModuleDict key/value pairs.
"""
return self._mods.items()
def update(self, modules):
self._mods.update(modules)
def is_augmenting(self):
return self._data_augmentation
def TF_names(self):
try:
return self._mods['data_aug']._TF
except:
return None
def __getitem__(self, key):
return self._mods[key]
def __str__(self):
return "Aug_mod("+str(self._mods['data_aug'])+"-"+str(self._mods['model'])+")"
'''
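
The commented-out Augmented_model2 keeps a monkeypatched, functional copy of the model so that an inner optimization step stays differentiable and gradients can reach the data_aug parameters. A rough sketch of how functional() / detach_() could be driven with the higher library, assuming the block above were uncommented; the classifier, the synthetic batches and the losses are placeholders, not code from this repository:

import torch
import torch.nn.functional as F

net = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3*32*32, 10))  # placeholder classifier
amodel = Augmented_model2(Data_augV5(N_TF=2, fixed_mag=False, shared_mag=False), net)

inner_opt = torch.optim.SGD(amodel['model'].parameters(), lr=1e-2)
meta_opt = torch.optim.Adam(amodel['data_aug'].parameters(), lr=1e-3)
diffopt = amodel.functional(inner_opt, track_higher_grads=True)   # sets up 'fmodel'

# Synthetic train/validation batches, for illustration only
train_batches = [(torch.rand(8, 3, 32, 32), torch.randint(0, 10, (8,))) for _ in range(4)]
val_batches = [(torch.rand(8, 3, 32, 32), torch.randint(0, 10, (8,))) for _ in range(4)]

for (xs, ys), (xv, yv) in zip(train_batches, val_batches):
    inner_loss = F.cross_entropy(amodel(xs), ys)   # forward goes through fmodel
    diffopt.step(inner_loss)                       # differentiable weight update

    meta_loss = F.cross_entropy(amodel(xv), yv)    # validation loss on the updated weights
    meta_opt.zero_grad()
    meta_loss.backward()                           # gradients flow back to data_aug parameters
    meta_opt.step()
    amodel.detach_()                               # truncate the unrolled graph between steps
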