mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Tests consomation memoire/temps + methode KL divergence (UDA)
This commit is contained in:
parent
b60610d9a7
commit
d68034eec1
5 changed files with 214 additions and 37 deletions
|
@ -63,7 +63,8 @@ def train_classic(model, epochs=1, print_freq=1):
|
|||
features,labels = features.to(device), labels.to(device)
|
||||
|
||||
optim.zero_grad()
|
||||
pred = model.forward(features)
|
||||
logits = model.forward(features)
|
||||
pred = F.log_softmax(logits, dim=1)
|
||||
loss = F.cross_entropy(pred,labels)
|
||||
loss.backward()
|
||||
optim.step()
|
||||
|
@ -125,7 +126,8 @@ def train_classic_higher(model, epochs=1):
|
|||
features,labels = features.to(device), labels.to(device)
|
||||
|
||||
#optim.zero_grad()
|
||||
pred = fmodel.forward(features)
|
||||
logits = model.forward(features)
|
||||
pred = F.log_softmax(logits, dim=1)
|
||||
loss = F.cross_entropy(pred,labels)
|
||||
#.backward()
|
||||
#optim.step()
|
||||
|
@ -550,7 +552,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
|
|||
dl_val_it = iter(dl_val)
|
||||
|
||||
#if inner_it!=0:
|
||||
meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=1e-2)
|
||||
meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=1e-2) #lr=1e-2
|
||||
inner_opt = torch.optim.SGD(model['model'].parameters(), lr=1e-2, momentum=0.9)
|
||||
|
||||
high_grad_track = True
|
||||
|
@ -589,11 +591,10 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
|
|||
# final_loss += loss*fmodel['data_aug']['prob'][tf_idx] #Take it in the forward function ?
|
||||
#loss = final_loss
|
||||
|
||||
#KLdiv=False
|
||||
if(not KLdiv):
|
||||
#Methode uniforme
|
||||
logits = fmodel(xs) # modified `params` can also be passed as a kwarg
|
||||
loss = F.cross_entropy(logits, ys, reduction='none') # no need to call loss.backwards()
|
||||
loss = F.cross_entropy(F.log_softmax(logits, dim=1), ys, reduction='none') # no need to call loss.backwards()
|
||||
|
||||
if fmodel._data_augmentation: #Weight loss
|
||||
w_loss = fmodel['data_aug'].loss_weight()#.to(device)
|
||||
|
@ -612,18 +613,15 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
|
|||
aug_logits = fmodel(xs)
|
||||
log_aug=F.log_softmax(aug_logits, dim=1)
|
||||
#KL div w/ logits
|
||||
aug_loss = sup_logits*(log_sup-log_aug)
|
||||
aug_loss = F.softmax(sup_logits, dim=1)*(log_sup-log_aug)
|
||||
aug_loss=aug_loss.sum(dim=-1)
|
||||
#aug_loss = F.kl_div(aug_logits, sup_logits, reduction='none') #Similarite predictions (distributions)
|
||||
|
||||
w_loss = fmodel['data_aug'].loss_weight()#.unsqueeze(dim=1).expand(-1,10) #Weight loss
|
||||
w_loss = fmodel['data_aug'].loss_weight() #Weight loss
|
||||
aug_loss = (w_loss * aug_loss).mean()
|
||||
unsupp_coeff = 1
|
||||
loss += aug_loss * unsupp_coeff
|
||||
|
||||
print('TF Proba :', model['data_aug']['prob'].data)
|
||||
|
||||
|
||||
#to visualize computational graph
|
||||
#print_graph(loss)
|
||||
|
||||
|
@ -637,11 +635,10 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
|
|||
#print("meta")
|
||||
#Peu utile si high_grad_track = False
|
||||
val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val) + fmodel['data_aug'].reg_loss()
|
||||
|
||||
#print_graph(val_loss)
|
||||
|
||||
val_loss.backward()
|
||||
|
||||
|
||||
countcopy+=1
|
||||
model_copy(src=fmodel, dst=model)
|
||||
optim_copy(dopt=diffopt, opt=inner_opt)
|
||||
|
@ -685,7 +682,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
|
|||
print('TF Mag :', model['data_aug']['mag'].data)
|
||||
#print('Mag grad',model['data_aug']['mag'].grad)
|
||||
#print('Reg loss:', model['data_aug'].reg_loss().item())
|
||||
print('Aug loss', aug_loss.item())
|
||||
#print('Aug loss', aug_loss.item())
|
||||
#############
|
||||
#### Log ####
|
||||
#print(type(model['data_aug']) is dataug.Data_augV5)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue