dist_dataugv3 changes (-copy/+faster) + minor TF change

Harle, Antoine (Contractor) 2020-01-15 16:55:03 -05:00
parent e291bc2e44
commit 75901b69b4
6 changed files with 198 additions and 83 deletions
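Summary of the change: instead of re-creating the higher functional model (fmodel) and its differentiable optimizer around every meta step, which copies weights each time, run_dist_dataugV3 now gets one diffopt from the Higher_model wrapper (with gradient clipping built in via grad_callback) and truncates the unrolled graph in place. For reference, a minimal, self-contained sketch of the underlying higher pattern, using a toy network and optimizer rather than the repo's model:

import higher
import torch
import torch.nn.functional as F

net = torch.nn.Linear(10, 2)
opt = torch.optim.SGD(net.parameters(), lr=0.1)

# Functional copy of the model; its weights live in fmodel._fast_params.
fmodel = higher.patch.monkeypatch(net, device=None, copy_initial_weights=True)
# Differentiable optimizer: each .step() appends a new set of fast params,
# keeping the whole unrolled graph alive for meta-gradients.
diffopt = higher.optim.get_diff_optim(opt, net.parameters(), fmodel=fmodel,
                                      track_higher_grads=True)

x, y = torch.randn(4, 10), torch.randint(0, 2, (4,))
loss = F.cross_entropy(fmodel(x), y)
diffopt.step(loss)  # replaces opt.zero_grad() / loss.backward() / opt.step()
print(len(fmodel._fast_params))  # grows by one per differentiable step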


@@ -654,7 +654,7 @@ def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
model.train()
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
- diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel,track_higher_grads=high_grad_track)
+ diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
for epoch in range(1, epochs+1):
#print_torch_mem("Start epoch "+str(epoch))
@@ -742,9 +742,8 @@ def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
model_copy(src=fmodel, dst=model)
optim_copy(dopt=diffopt, opt=inner_opt)
- torch.nn.utils.clip_grad_norm_(model['data_aug']['prob'], max_norm=10, norm_type=2) #Prevent exploding grad with RNN
- torch.nn.utils.clip_grad_norm_(model['data_aug']['mag'], max_norm=10, norm_type=2) #Prevent exploding grad with RNN
+ torch.nn.utils.clip_grad_norm_(model['data_aug'].parameters(), max_norm=10, norm_type=2) #Prevent exploding grad with RNN
#if epoch>50:
meta_opt.step()
model['data_aug'].adjust_param(soft=False) #Constraint: sum(proba)=1
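Side note on the V2 change above: folding the two clips into one call is not a pure refactor, because clip_grad_norm_ computes a single global norm over everything it is given. A toy illustration (placeholder tensors standing in for model['data_aug']['prob'] and ['mag']):

import torch

prob = torch.zeros(3, requires_grad=True)
mag = torch.zeros(3, requires_grad=True)
prob.grad = torch.full((3,), 8.0)   # per-tensor norm: 8*sqrt(3) ~ 13.86
mag.grad = torch.full((3,), 8.0)

# Old behaviour: clip each tensor against its own norm.
# torch.nn.utils.clip_grad_norm_([prob], max_norm=10)
# torch.nn.utils.clip_grad_norm_([mag], max_norm=10)

# New behaviour: one call, one global norm (sqrt(384) ~ 19.6), so both
# gradients are rescaled by the same shared factor and keep their ratio.
torch.nn.utils.clip_grad_norm_([prob, mag], max_norm=10)
print(prob.grad.norm())  # ~7.07 here, vs 10.0 with per-tensor clipping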
@@ -835,7 +834,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
#if inner_it!=0:
meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=opt_param['Meta']['lr']) #lr=1e-2
- inner_opt = torch.optim.SGD(model['model'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
+ inner_opt = torch.optim.SGD(model['model']['original'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
high_grad_track = True
if inner_it == 0:
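The inner optimizer now has to target model['model']['original'].parameters(): with the Higher_model wrapper, 'original' is presumably the real nn.Module that SGD updates, while a 'functional' copy serves the unrolled inner loop. A hypothetical toy sketch of that indexing (ToyHigherModel is invented here, not the repo's class):

import torch
import torch.nn as nn

class ToyHigherModel(nn.Module):
    """Invented stand-in: holds the real module plus a functional copy."""
    def __init__(self, net: nn.Module):
        super().__init__()
        self._mods = {'original': net, 'functional': None}  # copy filled later
    def __getitem__(self, key):
        return self._mods[key]

net = nn.Linear(4, 2)
model = {'model': ToyHigherModel(net)}  # the outer wrapper indexes like this
inner_opt = torch.optim.SGD(model['model']['original'].parameters(),
                            lr=1e-2, momentum=0.9)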
@@ -853,12 +852,17 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
#fmodel = higher.patch.monkeypatch(model['model'], device=None, copy_initial_weights=True)
#diffopt = higher.optim.get_diff_optim(inner_opt, model['model'].parameters(),fmodel=fmodel,track_higher_grads=high_grad_track)
- fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
- diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel,track_higher_grads=high_grad_track)
+ #fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
+ #diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel,track_higher_grads=high_grad_track)
+ diffopt = model['model'].get_diffopt(
+     inner_opt,
+     grad_callback=(lambda grads: clip_norm(grads, max_norm=10)),
+     track_higher_grads=high_grad_track)
#meta_opt = torch.optim.Adam(fmodel['data_aug'].parameters(), lr=opt_param['Meta']['lr']) #lr=1e-2
- print(len(fmodel._fast_params))
+ #print(len(model['model']['functional']._fast_params))
for epoch in range(1, epochs+1):
#print_torch_mem("Start epoch "+str(epoch))
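The grad_callback passed to get_diffopt clips gradients inside the differentiable update itself, replacing the manual clip-and-copy at meta-step time. clip_norm is defined elsewhere in the repo; the sketch below only illustrates the contract higher expects from such a callback (it is called with the list of freshly computed gradients and its return value is used in their place, so clipping stays inside the unrolled graph). Assumes no gradient in the list is None:

import torch
from typing import List

def clip_norm(grads: List[torch.Tensor], max_norm: float,
              norm_type: float = 2.0) -> List[torch.Tensor]:
    """Differentiable analogue of clip_grad_norm_ for higher's grad_callback:
    scale all gradients by a shared factor if their global norm is too big."""
    total_norm = torch.norm(
        torch.stack([torch.norm(g, norm_type) for g in grads]), norm_type)
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef >= 1:
        return grads
    return [g * clip_coef for g in grads]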
@@ -871,30 +875,30 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
if(not KLdiv):
#Uniform method
- logits = fmodel(xs) # modified `params` can also be passed as a kwarg
+ logits = model(xs) # modified `params` can also be passed as a kwarg
loss = F.cross_entropy(F.log_softmax(logits, dim=1), ys, reduction='none') # no need to call loss.backwards()
- if fmodel._data_augmentation: #Weight loss
- w_loss = fmodel['data_aug'].loss_weight()#.to(device)
+ if model._data_augmentation: #Weight loss
+ w_loss = model['data_aug'].loss_weight()#.to(device)
loss = loss * w_loss
loss = loss.mean()
else:
#KL-div method
- if fmodel._data_augmentation :
- fmodel.augment(mode=False)
- sup_logits = fmodel(xs)
- fmodel.augment(mode=True)
+ if model._data_augmentation :
+ model.augment(mode=False)
+ sup_logits = model(xs)
+ model.augment(mode=True)
else:
- sup_logits = fmodel(xs)
+ sup_logits = model(xs)
log_sup=F.log_softmax(sup_logits, dim=1)
loss = F.cross_entropy(log_sup, ys)
- if fmodel._data_augmentation:
- aug_logits = fmodel(xs)
+ if model._data_augmentation:
+ aug_logits = model(xs)
log_aug=F.log_softmax(aug_logits, dim=1)
aug_loss=0
- w_loss = fmodel['data_aug'].loss_weight() #Weight loss
+ w_loss = model['data_aug'].loss_weight() #Weight loss
#if epoch>50: #delayed start?
#KL div w/ logits - similarity of predictions (distributions)
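For context, the KLdiv branch rewritten above computes a supervised cross-entropy on clean (un-augmented) logits plus a consistency term between clean and augmented predictions, weighted per sample by loss_weight(). A condensed sketch of that pattern; lambda_aug and the exact KL direction and reduction are assumptions, not the repo's exact formula:

import torch
import torch.nn.functional as F

def kl_consistency_loss(sup_logits, aug_logits, ys, w_loss, lambda_aug=1.0):
    # Supervised loss on the un-augmented prediction.
    sup_loss = F.cross_entropy(sup_logits, ys)
    # Per-sample KL between the clean distribution and the augmented one
    # (reduction='none' so w_loss can weight each sample before averaging).
    kl = F.kl_div(F.log_softmax(aug_logits, dim=1),
                  F.softmax(sup_logits, dim=1),
                  reduction='none').sum(dim=1)
    return sup_loss + lambda_aug * (w_loss * kl).mean()

# Toy usage: batch of 4, 3 classes, uniform per-sample weights.
sup, aug = torch.randn(4, 3), torch.randn(4, 3)
ys, w = torch.randint(0, 3, (4,)), torch.ones(4)
print(kl_consistency_loss(sup, aug, ys, w))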
@@ -915,75 +919,36 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
#print(fmodel['model']._params['b4'].grad)
#print('prob grad', fmodel['data_aug']['prob'].grad)
#for _, p in fmodel['data_aug'].named_parameters():
# p.requires_grad = False
t = time.process_time()
diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)
- print(len(fmodel._fast_params),"step", time.process_time()-t)
+ print(len(model['model']['functional']._fast_params),"step", time.process_time()-t)
#for _, p in fmodel['data_aug'].named_parameters():
# p.requires_grad = True
if(high_grad_track and i>0 and i%inner_it==0): #Perform Meta step
#print("meta")
- val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val) + fmodel['data_aug'].reg_loss()
+ val_loss = compute_vaLoss(model=model, dl_it=dl_val_it, dl=dl_val) + model['data_aug'].reg_loss()
#print_graph(val_loss)
val_loss.backward()
- print('proba grad',fmodel['data_aug']['prob'].grad)
- #countcopy+=1
- #model_copy(src=fmodel, dst=model)
- #optim_copy(dopt=diffopt, opt=inner_opt)
- torch.nn.utils.clip_grad_norm_(fmodel['data_aug']['prob'], max_norm=10, norm_type=2) #Prevent exploding grad with RNN
- torch.nn.utils.clip_grad_norm_(fmodel['data_aug']['mag'], max_norm=10, norm_type=2) #Prevent exploding grad with RNN
- for paramName, paramValue, in fmodel['data_aug'].named_parameters():
-     for netCopyName, netCopyValue, in model['data_aug'].named_parameters():
-         if paramName == netCopyName:
-             netCopyValue.grad = paramValue.grad
- #del meta_opt.param_groups[0]
- #meta_opt.add_param_group({'params' : [p for p in fmodel['data_aug'].parameters()]})
+ torch.nn.utils.clip_grad_norm_(model['data_aug'].parameters(), max_norm=10, norm_type=2) #Prevent exploding grad with RNN
meta_opt.step()
- fmodel['data_aug'].load_state_dict(model['data_aug'].state_dict())
- fmodel['data_aug'].adjust_param(soft=False) #Constraint: sum(proba)=1
#model['data_aug'].next_TF_set()
+ model['data_aug'].adjust_param(soft=False) #Constraint: sum(proba)=1
- #fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
- #diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
- #fmodel.fast_params=[higher.utils._copy_tensor(t,safe_copy=True) if isinstance(t, torch.Tensor) else t for t in fmodel.parameters()]
- diffopt.detach_()
- tmp = fmodel.fast_params
- fmodel._fast_params=[]
- fmodel.update_params(tmp)
- for p in fmodel.fast_params:
-     p.detach_().requires_grad_()
- print(len(fmodel._fast_params))
- print('TF Proba :', fmodel['data_aug']['prob'].data)
+ model['model'].detach_()
tf = time.process_time()
#viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
#viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
#model_copy(src=fmodel, dst=model)
if(not high_grad_track):
#countcopy+=1
#model_copy(src=fmodel, dst=model)
optim_copy(dopt=diffopt, opt=inner_opt)
- val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val)
+ val_loss = compute_vaLoss(model=model, dl_it=dl_val_it, dl=dl_val)
- #Needed to reset higher (fast_params accumulate even with track_higher_grads = False)
- fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
- diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
accuracy, test_loss =test(model)
model.train()
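The block removed after the meta step (diffopt.detach_() plus the fast-param rebuild) is what model['model'].detach_() now appears to encapsulate: the unrolled graph is cut in place instead of re-monkeypatching the model, which re-copies weights; hence the "-copy/+faster" in the commit message. A sketch of that reset on a raw higher fmodel/diffopt pair, mirroring the removed lines (detach_() on the diffopt is assumed to exist, as it does in this codebase):

def detach_unroll(fmodel, diffopt):
    """In-place graph truncation after a meta step (sketch of the removed
    lines, which the new Higher_model.detach_() seems to wrap).

    fmodel: a higher.patch.monkeypatch'd model; diffopt: its differentiable
    optimizer."""
    diffopt.detach_()                 # drop the optimizer's graph history
    params = fmodel.fast_params       # keep only the newest weights...
    fmodel._fast_params = []          # ...and forget the per-step history
    fmodel.update_params(params)
    for p in fmodel.fast_params:
        p.detach_().requires_grad_()  # fresh graph leaves, still trainable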