Mirror of https://github.com/AntoineHX/smart_augmentation.git, synced 2025-05-04 12:10:45 +02:00
dist_dataugv3 changes (fewer copies / faster) + slight TF change
commit 75901b69b4
parent e291bc2e44
6 changed files with 198 additions and 83 deletions
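For context, the core change below swaps the explicit `higher` workflow (monkeypatch the model, build a differentiable optimizer over it, then copy weights and optimizer state back) for a differentiable optimizer obtained directly from the wrapped model, which avoids the copies the commit title refers to. A minimal sketch of the pattern being replaced, using only the `higher` calls visible in the hunks; the toy model, data and hyperparameters are placeholders, not values from the repository:

import higher
import torch
import torch.nn.functional as F

# Toy stand-ins (placeholders, not the repo's wrapped data-augmentation model).
model = torch.nn.Linear(10, 2)
inner_opt = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)

# Patch the module into a functional copy and wrap the optimizer so that
# .step(loss) performs a differentiable update (no zero_grad/backward/step).
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),
                                      fmodel=fmodel, track_higher_grads=True)

xs, ys = torch.randn(8, 10), torch.randint(0, 2, (8,))
loss = F.cross_entropy(fmodel(xs), ys)
diffopt.step(loss)  # inner update recorded in the autograd graph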
@@ -654,7 +654,7 @@ def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
     model.train()
 
     fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
-    diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel,track_higher_grads=high_grad_track)
+    diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
 
     for epoch in range(1, epochs+1):
         #print_torch_mem("Start epoch "+str(epoch))
@@ -742,9 +742,8 @@ def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
                 model_copy(src=fmodel, dst=model)
                 optim_copy(dopt=diffopt, opt=inner_opt)
 
-                torch.nn.utils.clip_grad_norm_(model['data_aug']['prob'], max_norm=10, norm_type=2) #Prevent exploding grad with RNN
-                torch.nn.utils.clip_grad_norm_(model['data_aug']['mag'], max_norm=10, norm_type=2) #Prevent exploding grad with RNN
+                torch.nn.utils.clip_grad_norm_(model['data_aug'].parameters(), max_norm=10, norm_type=2) #Prevent exploding grad with RNN
 
                 #if epoch>50:
                 meta_opt.step()
                 model['data_aug'].adjust_param(soft=False) #Constraint: sum(proba)=1
@@ -835,7 +834,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
 
     #if inner_it!=0:
     meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=opt_param['Meta']['lr']) #lr=1e-2
-    inner_opt = torch.optim.SGD(model['model'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
+    inner_opt = torch.optim.SGD(model['model']['original'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
 
     high_grad_track = True
     if inner_it == 0:
@@ -853,12 +852,17 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
 
     #fmodel = higher.patch.monkeypatch(model['model'], device=None, copy_initial_weights=True)
     #diffopt = higher.optim.get_diff_optim(inner_opt, model['model'].parameters(),fmodel=fmodel,track_higher_grads=high_grad_track)
-    fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
-    diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel,track_higher_grads=high_grad_track)
+    #fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
+    #diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel,track_higher_grads=high_grad_track)
+
+    diffopt = model['model'].get_diffopt(
+        inner_opt,
+        grad_callback=(lambda grads: clip_norm(grads, max_norm=10)),
+        track_higher_grads=high_grad_track)
 
     #meta_opt = torch.optim.Adam(fmodel['data_aug'].parameters(), lr=opt_param['Meta']['lr']) #lr=1e-2
 
-    print(len(fmodel._fast_params))
+    #print(len(model['model']['functional']._fast_params))
 
     for epoch in range(1, epochs+1):
         #print_torch_mem("Start epoch "+str(epoch))
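The new get_diffopt call clips gradients inside the differentiable inner step through a grad_callback. The clip_norm helper is defined elsewhere in the repository and is not shown in this diff; a hedged sketch of what such a callback could look like (an assumption, not the repo's implementation):

import torch

def clip_norm(grads, max_norm, norm_type=2):
    # Hypothetical helper: rescale a list of gradient tensors so their global
    # norm stays below max_norm, and return the (possibly scaled) gradients,
    # since a grad_callback must hand the gradients back to the optimizer.
    total_norm = torch.norm(torch.stack(
        [torch.norm(g, norm_type) for g in grads if g is not None]), norm_type)
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        grads = [g * clip_coef if g is not None else g for g in grads]
    return grads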
@@ -871,30 +875,30 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
 
             if(not KLdiv):
                 #Uniform method
-                logits = fmodel(xs) # modified `params` can also be passed as a kwarg
+                logits = model(xs) # modified `params` can also be passed as a kwarg
                 loss = F.cross_entropy(F.log_softmax(logits, dim=1), ys, reduction='none') # no need to call loss.backwards()
 
-                if fmodel._data_augmentation: #Weight loss
-                    w_loss = fmodel['data_aug'].loss_weight()#.to(device)
+                if model._data_augmentation: #Weight loss
+                    w_loss = model['data_aug'].loss_weight()#.to(device)
                     loss = loss * w_loss
                 loss = loss.mean()
 
             else:
                 #KL-div method
-                if fmodel._data_augmentation :
-                    fmodel.augment(mode=False)
-                    sup_logits = fmodel(xs)
-                    fmodel.augment(mode=True)
+                if model._data_augmentation :
+                    model.augment(mode=False)
+                    sup_logits = model(xs)
+                    model.augment(mode=True)
                 else:
-                    sup_logits = fmodel(xs)
+                    sup_logits = model(xs)
                 log_sup=F.log_softmax(sup_logits, dim=1)
                 loss = F.cross_entropy(log_sup, ys)
 
-                if fmodel._data_augmentation:
-                    aug_logits = fmodel(xs)
+                if model._data_augmentation:
+                    aug_logits = model(xs)
                     log_aug=F.log_softmax(aug_logits, dim=1)
                     aug_loss=0
-                    w_loss = fmodel['data_aug'].loss_weight() #Weight loss
+                    w_loss = model['data_aug'].loss_weight() #Weight loss
 
                     #if epoch>50: #delayed start?
                     #KL div w/ logits - similarity of predictions (distributions)
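In the "uniform method" branch above, the per-sample cross-entropy is weighted by the augmentation module's loss_weight() output before being averaged. A minimal standalone sketch of just that weighting step; the logits and w_loss tensors are placeholders for model(xs) and model['data_aug'].loss_weight():

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)        # placeholder for model(xs)
ys = torch.randint(0, 10, (4,))
w_loss = torch.rand(4)             # placeholder for model['data_aug'].loss_weight()

# Per-sample loss (the log_softmax/cross_entropy composition is kept exactly
# as written in the diff), weighted per sample, then reduced to a scalar.
loss = F.cross_entropy(F.log_softmax(logits, dim=1), ys, reduction='none')
loss = (loss * w_loss).mean()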
@@ -915,75 +919,36 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
             #print(fmodel['model']._params['b4'].grad)
             #print('prob grad', fmodel['data_aug']['prob'].grad)
 
-            #for _, p in fmodel['data_aug'].named_parameters():
-            #    p.requires_grad = False
             t = time.process_time()
             diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)
-            print(len(fmodel._fast_params),"step", time.process_time()-t)
+            print(len(model['model']['functional']._fast_params),"step", time.process_time()-t)
 
-            #for _, p in fmodel['data_aug'].named_parameters():
-            #    p.requires_grad = True
-
             if(high_grad_track and i>0 and i%inner_it==0): #Perform Meta step
                 #print("meta")
 
-                val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val) + fmodel['data_aug'].reg_loss()
+                val_loss = compute_vaLoss(model=model, dl_it=dl_val_it, dl=dl_val) + model['data_aug'].reg_loss()
                 #print_graph(val_loss)
 
                 val_loss.backward()
 
-                print('proba grad',fmodel['data_aug']['prob'].grad)
-                #countcopy+=1
-                #model_copy(src=fmodel, dst=model)
-                #optim_copy(dopt=diffopt, opt=inner_opt)
-
-                torch.nn.utils.clip_grad_norm_(fmodel['data_aug']['prob'], max_norm=10, norm_type=2) #Prevent exploding grad with RNN
-                torch.nn.utils.clip_grad_norm_(fmodel['data_aug']['mag'], max_norm=10, norm_type=2) #Prevent exploding grad with RNN
-
-                for paramName, paramValue, in fmodel['data_aug'].named_parameters():
-                    for netCopyName, netCopyValue, in model['data_aug'].named_parameters():
-                        if paramName == netCopyName:
-                            netCopyValue.grad = paramValue.grad
-
-                #del meta_opt.param_groups[0]
-                #meta_opt.add_param_group({'params' : [p for p in fmodel['data_aug'].parameters()]})
+                torch.nn.utils.clip_grad_norm_(model['data_aug'].parameters(), max_norm=10, norm_type=2) #Prevent exploding grad with RNN
 
                 meta_opt.step()
-                fmodel['data_aug'].load_state_dict(model['data_aug'].state_dict())
-                fmodel['data_aug'].adjust_param(soft=False) #Constraint: sum(proba)=1
-                #model['data_aug'].next_TF_set()
-
                 model['data_aug'].adjust_param(soft=False) #Constraint: sum(proba)=1
 
-                #fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
-                #diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
-
-
-                #fmodel.fast_params=[higher.utils._copy_tensor(t,safe_copy=True) if isinstance(t, torch.Tensor) else t for t in fmodel.parameters()]
                 diffopt.detach_()
-                tmp = fmodel.fast_params
-                fmodel._fast_params=[]
-                fmodel.update_params(tmp)
-                for p in fmodel.fast_params:
-                    p.detach_().requires_grad_()
-                print(len(fmodel._fast_params))
-
-                print('TF Proba :', fmodel['data_aug']['prob'].data)
+                model['model'].detach_()
 
         tf = time.process_time()
 
         #viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
         #viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
 
-        #model_copy(src=fmodel, dst=model)
-
 
         if(not high_grad_track):
             #countcopy+=1
             #model_copy(src=fmodel, dst=model)
             optim_copy(dopt=diffopt, opt=inner_opt)
-            val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val)
+            val_loss = compute_vaLoss(model=model, dl_it=dl_val_it, dl=dl_val)
 
             #Needed to reset higher (fast_params accumulate even with track_higher_grads = False)
             fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
             diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
 
         accuracy, test_loss =test(model)
         model.train()
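The deleted lines above, together with the translated comment about fast_params accumulating even with track_higher_grads = False, show why a reset is needed: each differentiable step registers a new set of fast weights, so the autograd graph and fmodel._fast_params keep growing. After this commit the reset goes through the wrapper's own diffopt.detach_() and model['model'].detach_() calls; below is a small self-contained illustration of the accumulation and of the manual truncation the old code performed, with a toy model standing in for the repository's wrapped model:

import higher
import torch
import torch.nn.functional as F

model = torch.nn.Linear(10, 2)                      # toy stand-in
inner_opt = torch.optim.SGD(model.parameters(), lr=1e-2)
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),
                                      fmodel=fmodel, track_higher_grads=True)

xs, ys = torch.randn(8, 10), torch.randint(0, 2, (8,))
for _ in range(3):
    diffopt.step(F.cross_entropy(fmodel(xs), ys))
    # the accumulation the deleted comment describes: this list keeps growing
    print(len(fmodel._fast_params))

# Manual truncation as in the deleted block: keep only the current weights,
# cut their autograd history, and leave them trainable for the next steps.
tmp = fmodel.fast_params
fmodel._fast_params = []
fmodel.update_params(tmp)
for p in fmodel.fast_params:
    p.detach_().requires_grad_()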