mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Borne mag + Regularisation mag
This commit is contained in:
parent
f4bdd9bca5
commit
64282bda3a
10 changed files with 43 additions and 228 deletions
|
@ -528,7 +528,7 @@ def run_dist_dataug(model, epochs=1, inner_it=1, dataug_epoch_start=0):
|
|||
optim_copy(dopt=diffopt, opt=inner_opt)
|
||||
|
||||
meta_opt.step()
|
||||
model['data_aug'].adjust_prob() #Contrainte sum(proba)=1
|
||||
model['data_aug'].adjust_param() #Contrainte sum(proba)=1
|
||||
|
||||
print("Copy ", countcopy)
|
||||
return log
|
||||
|
@ -588,7 +588,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
|
|||
loss = F.cross_entropy(logits, ys, reduction='none') # no need to call loss.backwards()
|
||||
|
||||
if fmodel._data_augmentation: #Weight loss
|
||||
w_loss = fmodel['data_aug'].loss_weight().to(device)
|
||||
w_loss = fmodel['data_aug'].loss_weight()#.to(device)
|
||||
loss = loss * w_loss
|
||||
loss = loss.mean()
|
||||
#'''
|
||||
|
@ -605,7 +605,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
|
|||
if(high_grad_track and i%inner_it==0): #Perform Meta step
|
||||
#print("meta")
|
||||
#Peu utile si high_grad_track = False
|
||||
val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val)
|
||||
val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val) + fmodel['data_aug'].reg_loss()
|
||||
|
||||
#print_graph(val_loss)
|
||||
|
||||
|
@ -616,15 +616,15 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
|
|||
optim_copy(dopt=diffopt, opt=inner_opt)
|
||||
|
||||
meta_opt.step()
|
||||
model['data_aug'].adjust_prob(soft=False) #Contrainte sum(proba)=1
|
||||
model['data_aug'].adjust_param(soft=False) #Contrainte sum(proba)=1
|
||||
|
||||
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
|
||||
diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
|
||||
|
||||
tf = time.process_time()
|
||||
|
||||
viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
|
||||
viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
|
||||
#viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
|
||||
#viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
|
||||
|
||||
if(not high_grad_track):
|
||||
countcopy+=1
|
||||
|
@ -643,7 +643,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
|
|||
if(print_freq and epoch%print_freq==0):
|
||||
print('-'*9)
|
||||
print('Epoch : %d/%d'%(epoch,epochs))
|
||||
print('Time : %.00f s'%(tf - t0))
|
||||
print('Time : %.00f'%(tf - t0))
|
||||
print('Train loss :',loss.item(), '/ val loss', val_loss.item())
|
||||
print('Accuracy :', accuracy)
|
||||
print('Data Augmention : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start))
|
||||
|
@ -651,6 +651,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
|
|||
#print('proba grad',model['data_aug']['prob'].grad)
|
||||
print('TF Mag :', model['data_aug']['mag'].data)
|
||||
#print('Mag grad',model['data_aug']['mag'].grad)
|
||||
print('Reg loss:', model['data_aug'].reg_loss().item())
|
||||
#############
|
||||
#### Log ####
|
||||
#print(type(model['data_aug']) is dataug.Data_augV5)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue