mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Modifs dist_dataugv3 (-copy/+rapide) + Legere modif TF
This commit is contained in:
parent
e291bc2e44
commit
75901b69b4
6 changed files with 198 additions and 83 deletions
|
@ -531,7 +531,7 @@ class Data_augV4(nn.Module): #Transformations avec mask
|
|||
return "Data_augV4(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_seqTF)
|
||||
|
||||
class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
||||
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True, ):
|
||||
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True):
|
||||
super(Data_augV5, self).__init__()
|
||||
assert len(TF_dict)>0
|
||||
|
||||
|
@ -545,13 +545,15 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
|||
self._shared_mag = shared_mag
|
||||
self._fixed_mag = fixed_mag
|
||||
|
||||
#self._fixed_mag=5 #[0, PARAMETER_MAX]
|
||||
init_mag = float(TF.PARAMETER_MAX) if self._fixed_mag else float(TF.PARAMETER_MAX)/2
|
||||
self._params = nn.ParameterDict({
|
||||
"prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Distribution prob uniforme
|
||||
"mag" : nn.Parameter(torch.tensor(float(TF.PARAMETER_MAX)/2) if self._shared_mag
|
||||
else torch.tensor(float(TF.PARAMETER_MAX)/2).expand(self._nb_tf)), #[0, PARAMETER_MAX]
|
||||
"mag" : nn.Parameter(torch.tensor(init_mag) if self._shared_mag
|
||||
else torch.tensor(init_mag).repeat(self._nb_tf)), #[0, PARAMETER_MAX]
|
||||
})
|
||||
|
||||
for tf in TF.TF_no_grad :
|
||||
if tf in self._TF: self._params['mag'].data[self._TF.index(tf)]=float(TF.PARAMETER_MAX) #TF fixe a max parameter
|
||||
#for t in TF.TF_no_mag: self._params['mag'][self._TF.index(t)].data-=self._params['mag'][self._TF.index(t)].data #Mag inutile pour les TF ignore_mag
|
||||
|
||||
#Distribution
|
||||
|
@ -1094,8 +1096,8 @@ class Augmented_model(nn.Module):
|
|||
|
||||
self.augment(mode=True)
|
||||
|
||||
def initialize(self):
|
||||
self._mods['model'].initialize()
|
||||
#def initialize(self):
|
||||
# self._mods['model'].initialize()
|
||||
|
||||
def forward(self, x):
|
||||
return self._mods['model'](self._mods['data_aug'](x))
|
||||
|
@ -1136,4 +1138,81 @@ class Augmented_model(nn.Module):
|
|||
return self._mods[key]
|
||||
|
||||
def __str__(self):
|
||||
return "Aug_mod("+str(self._mods['data_aug'])+"-"+str(self._mods['model'])+")"
|
||||
return "Aug_mod("+str(self._mods['data_aug'])+"-"+str(self._mods['model'])+")"
|
||||
|
||||
'''
|
||||
import higher
|
||||
class Augmented_model2(nn.Module):
|
||||
def __init__(self, data_augmenter, model):
|
||||
super(Augmented_model2, self).__init__()
|
||||
|
||||
self._mods = nn.ModuleDict({
|
||||
'data_aug': data_augmenter,
|
||||
'model': model,
|
||||
'fmodel': None
|
||||
})
|
||||
|
||||
self.augment(mode=True)
|
||||
|
||||
def initialize(self):
|
||||
self._mods['model'].initialize()
|
||||
|
||||
def forward(self, x):
|
||||
if self._mods['fmodel']:
|
||||
return self._mods['fmodel'](self._mods['data_aug'](x))
|
||||
else:
|
||||
return self._mods['model'](self._mods['data_aug'](x))
|
||||
|
||||
def functional(self, opt, track_higher_grads=True):
|
||||
self._mods['fmodel'] = higher.patch.monkeypatch(self._mods['model'], device=None, copy_initial_weights=True)
|
||||
|
||||
return higher.optim.get_diff_optim(opt,
|
||||
self._mods['model'].parameters(),
|
||||
fmodel=self._mods['fmodel'],
|
||||
track_higher_grads=track_higher_grads)
|
||||
|
||||
def detach_(self):
|
||||
tmp = self._mods['fmodel'].fast_params
|
||||
self._mods['fmodel']._fast_params=[]
|
||||
self._mods['fmodel'].update_params(tmp)
|
||||
for p in self._mods['fmodel'].fast_params:
|
||||
p.detach_().requires_grad_()
|
||||
|
||||
def augment(self, mode=True):
|
||||
self._data_augmentation=mode
|
||||
self._mods['data_aug'].augment(mode)
|
||||
|
||||
def train(self, mode=None):
|
||||
if mode is None :
|
||||
mode=self._data_augmentation
|
||||
self._mods['data_aug'].augment(mode)
|
||||
super(Augmented_model2, self).train(mode)
|
||||
return self
|
||||
|
||||
def eval(self):
|
||||
return self.train(mode=False)
|
||||
#super(Augmented_model, self).eval()
|
||||
|
||||
def items(self):
|
||||
"""Return an iterable of the ModuleDict key/value pairs.
|
||||
"""
|
||||
return self._mods.items()
|
||||
|
||||
def update(self, modules):
|
||||
self._mods.update(modules)
|
||||
|
||||
def is_augmenting(self):
|
||||
return self._data_augmentation
|
||||
|
||||
def TF_names(self):
|
||||
try:
|
||||
return self._mods['data_aug']._TF
|
||||
except:
|
||||
return None
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._mods[key]
|
||||
|
||||
def __str__(self):
|
||||
return "Aug_mod("+str(self._mods['data_aug'])+"-"+str(self._mods['model'])+")"
|
||||
'''
|
Loading…
Add table
Add a link
Reference in a new issue