Memory/time consumption tests + KL divergence method (UDA)

This commit is contained in:
Harle, Antoine (Contractor) 2019-12-06 14:13:28 -05:00
parent b60610d9a7
commit d68034eec1
5 changed files with 214 additions and 37 deletions


@@ -63,7 +63,8 @@ def train_classic(model, epochs=1, print_freq=1):
         features,labels = features.to(device), labels.to(device)
         optim.zero_grad()
-        pred = model.forward(features)
+        logits = model.forward(features)
+        pred = F.log_softmax(logits, dim=1)
         loss = F.cross_entropy(pred,labels)
         loss.backward()
         optim.step()
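For reference, `F.cross_entropy` applies `log_softmax` internally, so an explicit `log_softmax` is normally paired with `F.nll_loss` instead. A minimal runnable sketch of the two equivalent formulations (toy model and data, not from this repo):

import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(4, 3)                   # toy classifier standing in for the real model
features = torch.randn(8, 4)
labels = torch.randint(0, 3, (8,))

logits = model(features)                  # raw, unnormalized scores
loss_a = F.cross_entropy(logits, labels)  # log_softmax + NLL applied internally

log_probs = F.log_softmax(logits, dim=1)
loss_b = F.nll_loss(log_probs, labels)    # equivalent: loss_a == loss_b (up to numerics)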
@@ -125,7 +126,8 @@ def train_classic_higher(model, epochs=1):
         features,labels = features.to(device), labels.to(device)
         #optim.zero_grad()
-        pred = fmodel.forward(features)
+        logits = model.forward(features)
+        pred = F.log_softmax(logits, dim=1)
         loss = F.cross_entropy(pred,labels)
         #loss.backward()
         #optim.step()
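`train_classic_higher` drives the same loop through a functional copy of the model, which is why the zero_grad/backward/step calls are commented out. A minimal sketch of the `higher` pattern involved, with a toy model standing in for the real one:

import torch
import torch.nn as nn
import torch.nn.functional as F
import higher

model = nn.Linear(4, 3)                   # toy stand-in for the real network
optim = torch.optim.SGD(model.parameters(), lr=1e-2)
features, labels = torch.randn(8, 4), torch.randint(0, 3, (8,))

with higher.innerloop_ctx(model, optim) as (fmodel, diffopt):
    logits = fmodel.forward(features)     # functional copy; params stay in the graph
    loss = F.cross_entropy(logits, labels)
    diffopt.step(loss)                    # replaces zero_grad/backward/step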
@@ -550,7 +552,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
     dl_val_it = iter(dl_val)
     #if inner_it!=0:
-    meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=1e-2)
+    meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=1e-2) #lr=1e-2
     inner_opt = torch.optim.SGD(model['model'].parameters(), lr=1e-2, momentum=0.9)
     high_grad_track = True
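One optimizer per level of the bi-level problem: Adam on the augmentation parameters (the outer/meta level) and SGD with momentum on the network weights (the inner level). A sketch of that arrangement, with stand-in parameter groups in place of this repo's dict-style model:

import torch

aug_params = [torch.zeros(14, requires_grad=True)]  # stand-in for model['data_aug'].parameters()
net = torch.nn.Linear(4, 3)                         # stand-in for model['model']

meta_opt = torch.optim.Adam(aug_params, lr=1e-2)    # outer: augmentation policy
inner_opt = torch.optim.SGD(net.parameters(), lr=1e-2, momentum=0.9)  # inner: weights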
@@ -589,11 +591,10 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
                 # final_loss += loss*fmodel['data_aug']['prob'][tf_idx] #Take it in the forward function ?
                 #loss = final_loss
-                #KLdiv=False
                 if(not KLdiv):
                     #Uniform method
                     logits = fmodel(xs) # modified `params` can also be passed as a kwarg
-                    loss = F.cross_entropy(logits, ys, reduction='none') # no need to call loss.backwards()
+                    loss = F.cross_entropy(F.log_softmax(logits, dim=1), ys, reduction='none') # no need to call loss.backwards()
                     if fmodel._data_augmentation: #Weight loss
                         w_loss = fmodel['data_aug'].loss_weight()#.to(device)
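In this uniform branch each sample keeps its own loss (`reduction='none'`), which is then reweighted by the augmentation module before averaging. A self-contained sketch of the weighting step (random weights stand in for `fmodel['data_aug'].loss_weight()`):

import torch
import torch.nn.functional as F

logits = torch.randn(8, 10)                           # toy batch of predictions
ys = torch.randint(0, 10, (8,))

loss = F.cross_entropy(logits, ys, reduction='none')  # per-sample losses, shape (8,)
w_loss = torch.rand(8)                                # stand-in for the learned weights
loss = (w_loss * loss).mean()                         # weighted batch average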
@@ -612,18 +613,15 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
                     aug_logits = fmodel(xs)
                     log_aug=F.log_softmax(aug_logits, dim=1)
                     #KL div w/ logits
-                    aug_loss = sup_logits*(log_sup-log_aug)
+                    aug_loss = F.softmax(sup_logits, dim=1)*(log_sup-log_aug)
                     aug_loss=aug_loss.sum(dim=-1)
-                    #aug_loss = F.kl_div(aug_logits, sup_logits, reduction='none') #Prediction similarity (distributions)
-                    w_loss = fmodel['data_aug'].loss_weight()#.unsqueeze(dim=1).expand(-1,10) #Weight loss
+                    w_loss = fmodel['data_aug'].loss_weight() #Weight loss
                     aug_loss = (w_loss * aug_loss).mean()
                     unsupp_coeff = 1
                     loss += aug_loss * unsupp_coeff
-                print('TF Proba :', model['data_aug']['prob'].data)
                 #to visualize computational graph
                 #print_graph(loss)
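The replaced line is a by-hand KL divergence between the clean and augmented predictions, the consistency term used in UDA: with p = softmax(sup_logits) and q = softmax(aug_logits), it computes sum_c p_c (log p_c - log q_c) = KL(p || q) per sample. A sketch checking it against `F.kl_div`, which expects log-probabilities as its first argument (toy logits, not this repo's tensors):

import torch
import torch.nn.functional as F

sup_logits = torch.randn(8, 10)   # predictions on the original samples
aug_logits = torch.randn(8, 10)   # predictions on the augmented samples

log_sup = F.log_softmax(sup_logits, dim=1)
log_aug = F.log_softmax(aug_logits, dim=1)

# By hand, as in the diff: sum_c p_c * (log p_c - log q_c) = KL(p || q)
kl_manual = (F.softmax(sup_logits, dim=1) * (log_sup - log_aug)).sum(dim=-1)

# Built-in equivalent: F.kl_div takes log-probs first, probs second.
kl_builtin = F.kl_div(log_aug, F.softmax(sup_logits, dim=1), reduction='none').sum(dim=-1)

assert torch.allclose(kl_manual, kl_builtin, atol=1e-6)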
@@ -637,11 +635,10 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
                 #print("meta")
                 #Of little use if high_grad_track = False
                 val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val) + fmodel['data_aug'].reg_loss()
                 #print_graph(val_loss)
                 val_loss.backward()
                 countcopy+=1
                 model_copy(src=fmodel, dst=model)
                 optim_copy(dopt=diffopt, opt=inner_opt)
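The meta step backpropagates the validation loss through the unrolled inner updates to reach the augmentation parameters, then syncs the adapted copy back into the persistent model. A compressed, self-contained sketch of that loop; the final copy-back is only a guess at what the repo's `model_copy`/`optim_copy` helpers do:

import torch
import torch.nn as nn
import torch.nn.functional as F
import higher

net = nn.Linear(4, 3)
aug_params = [torch.zeros((), requires_grad=True)]   # toy augmentation parameter
meta_opt = torch.optim.Adam(aug_params, lr=1e-2)
inner_opt = torch.optim.SGD(net.parameters(), lr=1e-2, momentum=0.9)

xs, ys = torch.randn(8, 4), torch.randint(0, 3, (8,))
xs_val, ys_val = torch.randn(8, 4), torch.randint(0, 3, (8,))

with higher.innerloop_ctx(net, inner_opt) as (fmodel, diffopt):
    for _ in range(2):                               # unrolled inner updates
        noisy = xs + aug_params[0] * torch.randn_like(xs)  # toy "augmentation"
        diffopt.step(F.cross_entropy(fmodel(noisy), ys))

    val_loss = F.cross_entropy(fmodel(xs_val), ys_val)
    meta_opt.zero_grad()
    val_loss.backward()      # flows through the unrolled steps into aug_params
    meta_opt.step()

    # Copy the adapted fast weights back into the persistent model
    # (a guess at the effect of model_copy; the repo helper may differ).
    with torch.no_grad():
        for p, fp in zip(net.parameters(), fmodel.parameters()):
            p.copy_(fp)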
@@ -685,7 +682,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
             print('TF Mag :', model['data_aug']['mag'].data)
             #print('Mag grad',model['data_aug']['mag'].grad)
             #print('Reg loss:', model['data_aug'].reg_loss().item())
-            print('Aug loss', aug_loss.item())
+            #print('Aug loss', aug_loss.item())
         #############
         #### Log ####
         #print(type(model['data_aug']) is dataug.Data_augV5)