Changes to dist_dataugV3 (copy removed / faster) + slight TF change
This commit is contained in:
parent e291bc2e44
commit 75901b69b4

6 changed files with 198 additions and 83 deletions
@@ -531,7 +531,7 @@ class Data_augV4(nn.Module): #Transformations avec mask
         return "Data_augV4(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_seqTF)

 class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
-    def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True, ):
+    def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True):
         super(Data_augV5, self).__init__()

         assert len(TF_dict)>0
@@ -545,13 +545,15 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
         self._shared_mag = shared_mag
         self._fixed_mag = fixed_mag

-        #self._fixed_mag=5 #[0, PARAMETER_MAX]
+        init_mag = float(TF.PARAMETER_MAX) if self._fixed_mag else float(TF.PARAMETER_MAX)/2
         self._params = nn.ParameterDict({
             "prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Distribution prob uniforme
-            "mag" : nn.Parameter(torch.tensor(float(TF.PARAMETER_MAX)/2) if self._shared_mag
-                else torch.tensor(float(TF.PARAMETER_MAX)/2).expand(self._nb_tf)), #[0, PARAMETER_MAX]
+            "mag" : nn.Parameter(torch.tensor(init_mag) if self._shared_mag
+                else torch.tensor(init_mag).repeat(self._nb_tf)), #[0, PARAMETER_MAX]
         })

+        for tf in TF.TF_no_grad :
+            if tf in self._TF: self._params['mag'].data[self._TF.index(tf)]=float(TF.PARAMETER_MAX) #TF fixe a max parameter
         #for t in TF.TF_no_mag: self._params['mag'][self._TF.index(t)].data-=self._params['mag'][self._TF.index(t)].data #Mag inutile pour les TF ignore_mag

         #Distribution
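A likely reason for the .expand(self._nb_tf) to .repeat(self._nb_tf) swap (the commit message does not state it, so treat this as an assumption): expand returns a view in which every per-TF magnitude aliases the same storage location, while repeat materialises independent copies, which is what per-transformation magnitudes need. A minimal, self-contained illustration:

import torch

base = torch.tensor(5.0)
shared = base.expand(3)   # view: the three entries alias a single storage location
copies = base.repeat(3)   # copy: three independent values

copies[0] = 1.0           # fine: only the first magnitude changes
# shared[0] = 1.0         # RuntimeError: several elements of the view map to one memory location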
@@ -1094,8 +1096,8 @@ class Augmented_model(nn.Module):

         self.augment(mode=True)

-    def initialize(self):
-        self._mods['model'].initialize()
+    #def initialize(self):
+    #    self._mods['model'].initialize()

     def forward(self, x):
         return self._mods['model'](self._mods['data_aug'](x))
@@ -1137,3 +1139,80 @@ class Augmented_model(nn.Module):

     def __str__(self):
         return "Aug_mod("+str(self._mods['data_aug'])+"-"+str(self._mods['model'])+")"
+
+'''
+import higher
+class Augmented_model2(nn.Module):
+    def __init__(self, data_augmenter, model):
+        super(Augmented_model2, self).__init__()
+
+        self._mods = nn.ModuleDict({
+            'data_aug': data_augmenter,
+            'model': model,
+            'fmodel': None
+        })
+
+        self.augment(mode=True)
+
+    def initialize(self):
+        self._mods['model'].initialize()
+
+    def forward(self, x):
+        if self._mods['fmodel']:
+            return self._mods['fmodel'](self._mods['data_aug'](x))
+        else:
+            return self._mods['model'](self._mods['data_aug'](x))
+
+    def functional(self, opt, track_higher_grads=True):
+        self._mods['fmodel'] = higher.patch.monkeypatch(self._mods['model'], device=None, copy_initial_weights=True)
+
+        return higher.optim.get_diff_optim(opt,
+            self._mods['model'].parameters(),
+            fmodel=self._mods['fmodel'],
+            track_higher_grads=track_higher_grads)
+
+    def detach_(self):
+        tmp = self._mods['fmodel'].fast_params
+        self._mods['fmodel']._fast_params=[]
+        self._mods['fmodel'].update_params(tmp)
+        for p in self._mods['fmodel'].fast_params:
+            p.detach_().requires_grad_()
+
+    def augment(self, mode=True):
+        self._data_augmentation=mode
+        self._mods['data_aug'].augment(mode)
+
+    def train(self, mode=None):
+        if mode is None :
+            mode=self._data_augmentation
+        self._mods['data_aug'].augment(mode)
+        super(Augmented_model2, self).train(mode)
+        return self
+
+    def eval(self):
+        return self.train(mode=False)
+        #super(Augmented_model, self).eval()
+
+    def items(self):
+        """Return an iterable of the ModuleDict key/value pairs.
+        """
+        return self._mods.items()
+
+    def update(self, modules):
+        self._mods.update(modules)
+
+    def is_augmenting(self):
+        return self._data_augmentation
+
+    def TF_names(self):
+        try:
+            return self._mods['data_aug']._TF
+        except:
+            return None
+
+    def __getitem__(self, key):
+        return self._mods[key]
+
+    def __str__(self):
+        return "Aug_mod("+str(self._mods['data_aug'])+"-"+str(self._mods['model'])+")"
+'''
@@ -3,6 +3,40 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

+
+import higher
+class Higher_model(nn.Module):
+    def __init__(self, model):
+        super(Higher_model, self).__init__()
+
+        self._mods = nn.ModuleDict({
+            'original': model,
+            'functional': higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
+        })
+
+    def get_diffopt(self, opt, grad_callback=None, track_higher_grads=True):
+        return higher.optim.get_diff_optim(opt,
+            self._mods['original'].parameters(),
+            fmodel=self._mods['functional'],
+            grad_callback=grad_callback,
+            track_higher_grads=track_higher_grads)
+
+    def forward(self, x):
+        return self._mods['functional'](x)
+
+    def detach_(self):
+        tmp = self._mods['functional'].fast_params
+        self._mods['functional']._fast_params=[]
+        self._mods['functional'].update_params(tmp)
+        for p in self._mods['functional'].fast_params:
+            p.detach_().requires_grad_()
+
+    def __getitem__(self, key):
+        return self._mods[key]
+
+    def __str__(self):
+        return self._mods['original'].__str__()
+
 ## Basic CNN ##
 class LeNet_F(nn.Module):
     def __init__(self, num_inp, num_out):
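The new Higher_model wrapper keeps the original nn.Module next to a functional copy built by higher.patch.monkeypatch, and get_diffopt turns a regular optimizer into a differentiable one, so a loss computed after one or more inner steps can be backpropagated through those steps. A minimal usage sketch; the toy network, data and learning rate are illustrative assumptions, only Higher_model itself comes from the hunk above:

import torch
import torch.nn.functional as F

net = torch.nn.Linear(4, 2)                                # toy network (assumption)
inner_opt = torch.optim.SGD(net.parameters(), lr=0.1)

hmodel = Higher_model(net)                                 # original + monkeypatched functional copy
diffopt = hmodel.get_diffopt(inner_opt, track_higher_grads=True)

x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))
loss = F.cross_entropy(hmodel(x), y)
diffopt.step(loss)                                         # differentiable update of the fast weights

val_loss = F.cross_entropy(hmodel(x), y)
val_loss.backward()                                        # gradients flow back through the inner step
hmodel.detach_()                                           # cut the unrolled history to bound memory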
@@ -19,8 +19,8 @@ tf_names = [
     'Color',
     'Brightness',
     'Sharpness',
-    #'Posterize',
-    #'Solarize', #=>Image entre [0,1] #Pas opti pour des batch
+    'Posterize',
+    'Solarize', #=>Image entre [0,1] #Pas opti pour des batch

     #Color TF (Common mag scale)
     #'+Contrast',
@@ -67,7 +67,7 @@ if __name__ == "__main__":
         'aug_model'
     }
     n_inner_iter = 1
-    epochs = 100
+    epochs = 15
     dataug_epoch_start=0
     optim_param={
         'Meta':{
@@ -81,11 +81,13 @@ if __name__ == "__main__":
         }
     }

-    model = LeNet(3,10)
+    #model = LeNet(3,10)
     #model = MobileNetV2(num_classes=10)
-    #model = ResNet(num_classes=10)
+    model = ResNet(num_classes=10)
     #model = WideResNet(num_classes=10, wrn_size=32)

+    model = Higher_model(model) #run_dist_dataugV3
+
     #### Classic ####
     if 'classic' in tasks:
         t0 = time.process_time()
@@ -172,12 +174,12 @@ if __name__ == "__main__":
         #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)

         print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter))
-        log= run_dist_dataugV2(model=aug_model,
+        log= run_dist_dataugV3(model=aug_model,
                 epochs=epochs,
                 inner_it=n_inner_iter,
                 dataug_epoch_start=dataug_epoch_start,
                 opt_param=optim_param,
-                print_freq=10,
+                print_freq=1,
                 KLdiv=True,
                 loss_patience=None)

@@ -187,7 +189,7 @@ if __name__ == "__main__":
         times = [x["time"] for x in log]
         out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), 'Optimizer': optim_param, "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
         print(str(aug_model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
-        filename = "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)+"demi_mag"
+        filename = "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)#+"demi_mag"
         with open("res/log/%s.json" % filename, "w+") as f:
             json.dump(out, f, indent=True)
             print('Log :\"',f.name, '\" saved !')
@@ -654,7 +654,7 @@ def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
     model.train()

     fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
-    diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel,track_higher_grads=high_grad_track)
+    diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)

     for epoch in range(1, epochs+1):
         #print_torch_mem("Start epoch "+str(epoch))
@@ -742,8 +742,7 @@ def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
                 model_copy(src=fmodel, dst=model)
                 optim_copy(dopt=diffopt, opt=inner_opt)

-                torch.nn.utils.clip_grad_norm_(model['data_aug']['prob'], max_norm=10, norm_type=2) #Prevent exploding grad with RNN
-                torch.nn.utils.clip_grad_norm_(model['data_aug']['mag'], max_norm=10, norm_type=2) #Prevent exploding grad with RNN
+                torch.nn.utils.clip_grad_norm_(model['data_aug'].parameters(), max_norm=10, norm_type=2) #Prevent exploding grad with RNN

                 #if epoch>50:
                 meta_opt.step()
@@ -835,7 +834,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start

     #if inner_it!=0:
     meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=opt_param['Meta']['lr']) #lr=1e-2
-    inner_opt = torch.optim.SGD(model['model'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
+    inner_opt = torch.optim.SGD(model['model']['original'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9

     high_grad_track = True
     if inner_it == 0:
@@ -853,12 +852,17 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start

     #fmodel = higher.patch.monkeypatch(model['model'], device=None, copy_initial_weights=True)
     #diffopt = higher.optim.get_diff_optim(inner_opt, model['model'].parameters(),fmodel=fmodel,track_higher_grads=high_grad_track)
-    fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
-    diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel,track_higher_grads=high_grad_track)
+    #fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
+    #diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel,track_higher_grads=high_grad_track)
+
+    diffopt = model['model'].get_diffopt(
+        inner_opt,
+        grad_callback=(lambda grads: clip_norm(grads, max_norm=10)),
+        track_higher_grads=high_grad_track)

     #meta_opt = torch.optim.Adam(fmodel['data_aug'].parameters(), lr=opt_param['Meta']['lr']) #lr=1e-2

-    print(len(fmodel._fast_params))
+    #print(len(model['model']['functional']._fast_params))

     for epoch in range(1, epochs+1):
         #print_torch_mem("Start epoch "+str(epoch))
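Passing grad_callback to get_diffopt lets higher transform the inner gradients inside every diffopt.step, which is what replaces the explicit clip_grad_norm_ calls used in run_dist_dataugV2. A sketch of an equivalent standalone callback (the name clip_cb is hypothetical; clip_norm is the helper added at the bottom of this commit):

def clip_cb(grads):
    # 'grads' is the list of gradients computed for the fast weights;
    # the list returned here is the one the differentiable step applies.
    return clip_norm(grads, max_norm=10)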
@@ -871,30 +875,30 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start

             if(not KLdiv):
                 #Methode uniforme
-                logits = fmodel(xs) # modified `params` can also be passed as a kwarg
+                logits = model(xs) # modified `params` can also be passed as a kwarg
                 loss = F.cross_entropy(F.log_softmax(logits, dim=1), ys, reduction='none') # no need to call loss.backwards()

-                if fmodel._data_augmentation: #Weight loss
-                    w_loss = fmodel['data_aug'].loss_weight()#.to(device)
+                if model._data_augmentation: #Weight loss
+                    w_loss = model['data_aug'].loss_weight()#.to(device)
                     loss = loss * w_loss
                 loss = loss.mean()

             else:
                 #Methode KL div
-                if fmodel._data_augmentation :
-                    fmodel.augment(mode=False)
-                    sup_logits = fmodel(xs)
-                    fmodel.augment(mode=True)
+                if model._data_augmentation :
+                    model.augment(mode=False)
+                    sup_logits = model(xs)
+                    model.augment(mode=True)
                 else:
-                    sup_logits = fmodel(xs)
+                    sup_logits = model(xs)
                 log_sup=F.log_softmax(sup_logits, dim=1)
                 loss = F.cross_entropy(log_sup, ys)

-                if fmodel._data_augmentation:
-                    aug_logits = fmodel(xs)
+                if model._data_augmentation:
+                    aug_logits = model(xs)
                     log_aug=F.log_softmax(aug_logits, dim=1)
                     aug_loss=0
-                    w_loss = fmodel['data_aug'].loss_weight() #Weight loss
+                    w_loss = model['data_aug'].loss_weight() #Weight loss

                     #if epoch>50: #debut differe ?
                     #KL div w/ logits - Similarite predictions (distributions)
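The hunk ends just before the KL term itself, so the following is only a generic sketch of the consistency idea this branch implements (supervised loss on the clean pass plus a divergence pulling augmented predictions toward clean ones), not the repository's exact code:

import torch
import torch.nn.functional as F

sup_logits = torch.randn(8, 10)                        # stand-in for the clean (augmentation off) forward pass
aug_logits = sup_logits + 0.1*torch.randn(8, 10)       # stand-in for the augmented forward pass
ys = torch.randint(0, 10, (8,))

log_sup = F.log_softmax(sup_logits, dim=1)
log_aug = F.log_softmax(aug_logits, dim=1)
loss = F.nll_loss(log_sup, ys)                                          # supervised term, clean pass
loss = loss + F.kl_div(log_aug, log_sup.exp(), reduction='batchmean')   # KL(clean || augmented) consistency term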
@@ -915,75 +919,36 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
             #print(fmodel['model']._params['b4'].grad)
             #print('prob grad', fmodel['data_aug']['prob'].grad)

-            #for _, p in fmodel['data_aug'].named_parameters():
-            #    p.requires_grad = False
             t = time.process_time()
             diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)
-            print(len(fmodel._fast_params),"step", time.process_time()-t)
+            print(len(model['model']['functional']._fast_params),"step", time.process_time()-t)

-            #for _, p in fmodel['data_aug'].named_parameters():
-            #    p.requires_grad = True

             if(high_grad_track and i>0 and i%inner_it==0): #Perform Meta step
                 #print("meta")

-                val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val) + fmodel['data_aug'].reg_loss()
+                val_loss = compute_vaLoss(model=model, dl_it=dl_val_it, dl=dl_val) + model['data_aug'].reg_loss()
                 #print_graph(val_loss)

                 val_loss.backward()

-                print('proba grad',fmodel['data_aug']['prob'].grad)
-                #countcopy+=1
-                #model_copy(src=fmodel, dst=model)
-                #optim_copy(dopt=diffopt, opt=inner_opt)
-
-                torch.nn.utils.clip_grad_norm_(fmodel['data_aug']['prob'], max_norm=10, norm_type=2) #Prevent exploding grad with RNN
-                torch.nn.utils.clip_grad_norm_(fmodel['data_aug']['mag'], max_norm=10, norm_type=2) #Prevent exploding grad with RNN
-
-                for paramName, paramValue, in fmodel['data_aug'].named_parameters():
-                    for netCopyName, netCopyValue, in model['data_aug'].named_parameters():
-                        if paramName == netCopyName:
-                            netCopyValue.grad = paramValue.grad
-
-                #del meta_opt.param_groups[0]
-                #meta_opt.add_param_group({'params' : [p for p in fmodel['data_aug'].parameters()]})
+                torch.nn.utils.clip_grad_norm_(model['data_aug'].parameters(), max_norm=10, norm_type=2) #Prevent exploding grad with RNN

                 meta_opt.step()
-                fmodel['data_aug'].load_state_dict(model['data_aug'].state_dict())
-                fmodel['data_aug'].adjust_param(soft=False) #Contrainte sum(proba)=1
-                #model['data_aug'].next_TF_set()
-
-                #fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
-                #diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
-
-                #fmodel.fast_params=[higher.utils._copy_tensor(t,safe_copy=True) if isinstance(t, torch.Tensor) else t for t in fmodel.parameters()]
+                model['data_aug'].adjust_param(soft=False) #Contrainte sum(proba)=1
+
                 diffopt.detach_()
-                tmp = fmodel.fast_params
-                fmodel._fast_params=[]
-                fmodel.update_params(tmp)
-                for p in fmodel.fast_params:
-                    p.detach_().requires_grad_()
-                print(len(fmodel._fast_params))
+                model['model'].detach_()

-        print('TF Proba :', fmodel['data_aug']['prob'].data)
         tf = time.process_time()

         #viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
         #viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))

-        #model_copy(src=fmodel, dst=model)
-
         if(not high_grad_track):
-            #countcopy+=1
-            #model_copy(src=fmodel, dst=model)
-            optim_copy(dopt=diffopt, opt=inner_opt)
-            val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val)
-
-            #Necessaire pour reset higher (Accumule les fast_param meme avec track_higher_grads = False)
-            fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
-            diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
+            val_loss = compute_vaLoss(model=model, dl_it=dl_val_it, dl=dl_val)

         accuracy, test_loss =test(model)
         model.train()
@@ -103,7 +103,8 @@ TF_dict={ #Dataugv5
 }

 TF_no_mag={'Identity', 'FlipUD', 'FlipLR', 'Random', 'RandBlend'}
-TF_ignore_mag= TF_no_mag | {'Solarize', 'Posterize', '=Solarize', '=Posterize'}
+TF_no_grad={'Solarize', 'Posterize', '=Solarize', '=Posterize'}
+TF_ignore_mag= TF_no_mag | TF_no_grad

 def int_image(float_image): #ATTENTION : legere perte d'info (granularite : 1/256 = 0.0039)
     return (float_image*255.).type(torch.uint8)
@@ -315,3 +315,37 @@ class loss_monitor(): #Voir https://github.com/pytorch/ignite

     def reset(self):
         self.__init__(self.patience, self.end_train)
+
+### https://github.com/facebookresearch/higher/issues/18 ####
+from torch._six import inf
+
+def clip_norm(tensors, max_norm, norm_type=2):
+    r"""Clips norm of passed tensors.
+    The norm is computed over all tensors together, as if they were
+    concatenated into a single vector. Clipped tensors are returned.
+    Arguments:
+        tensors (Iterable[Tensor]): an iterable of Tensors or a
+            single Tensor to be normalized.
+        max_norm (float or int): max norm of the gradients
+        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
+            infinity norm.
+    Returns:
+        Clipped (List[Tensor]) tensors.
+    """
+    if isinstance(tensors, torch.Tensor):
+        tensors = [tensors]
+    tensors = list(tensors)
+    max_norm = float(max_norm)
+    norm_type = float(norm_type)
+    if norm_type == inf:
+        total_norm = max(t.abs().max() for t in tensors)
+    else:
+        total_norm = 0
+        for t in tensors:
+            param_norm = t.norm(norm_type)
+            total_norm += param_norm.item() ** norm_type
+        total_norm = total_norm ** (1. / norm_type)
+    clip_coef = max_norm / (total_norm + 1e-6)
+    if clip_coef >= 1:
+        return tensors
+    return [t.mul(clip_coef) for t in tensors]
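A quick sanity check of the clip_norm helper above, on toy tensors (assuming it is imported from this module):

import torch

grads = [torch.randn(3, 3) * 100., torch.randn(3) * 100.]
clipped = clip_norm(grads, max_norm=10)
total = sum(g.norm(2).item() ** 2 for g in clipped) ** 0.5
print(total)  # ~10: the joint 2-norm of the clipped tensors is scaled down to max_norm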