Test KL divergence from UDA

Harle, Antoine (Contracteur) 2019-12-06 10:44:18 -05:00
parent fa5bc72616
commit 217f94ef89
5 changed files with 52 additions and 28 deletions
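
The KLdiv path added below mirrors the consistency objective of UDA (Unsupervised Data Augmentation): supervised cross-entropy on non-augmented inputs, plus a KL divergence pulling the prediction on an augmented input toward the clean-input distribution. A minimal self-contained sketch of that objective, assuming an ordinary PyTorch classifier and labeled batches as in the diff (names here are illustrative, not from this repo; UDA proper also detaches the clean branch and computes the KL term on unlabeled data):

import torch.nn.functional as F

def uda_style_loss(model, x_clean, x_aug, y, unsup_coeff=1.0):
    sup_logits = model(x_clean)                    # predictions on non-augmented inputs
    sup_loss = F.cross_entropy(sup_logits, y)      # supervised term (raw logits)
    log_sup = F.log_softmax(sup_logits, dim=1)     # log p(y|x_clean)
    log_aug = F.log_softmax(model(x_aug), dim=1)   # log p(y|x_aug)
    # KL(p_clean || p_aug): sum over classes, average over the batch
    kl = (log_sup.exp() * (log_sup - log_aug)).sum(dim=1).mean()
    return sup_loss + unsup_coeff * kl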


@@ -542,7 +542,7 @@ def run_dist_dataug(model, epochs=1, inner_it=1, dataug_epoch_start=0):
print("Copy ", countcopy)
return log
def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_freq=1, loss_patience=None):
def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_freq=1, KLdiv=False, loss_patience=None):
device = next(model.parameters()).device
log = []
countcopy=0
@@ -578,30 +578,51 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
     for i, (xs, ys) in enumerate(dl_train):
         xs, ys = xs.to(device), ys.to(device)
         '''
         #Exact method
-        final_loss = 0
-        for tf_idx in range(fmodel['data_aug']._nb_tf):
-            fmodel['data_aug'].transf_idx=tf_idx
-            logits = fmodel(xs)
-            loss = F.cross_entropy(logits, ys)
-            #loss.backward(retain_graph=True)
-            #print('idx', tf_idx)
-            #print(fmodel['data_aug']['prob'][tf_idx], fmodel['data_aug']['prob'][tf_idx].grad)
-            final_loss += loss*fmodel['data_aug']['prob'][tf_idx] #Take it in the forward function ?
-        loss = final_loss
+        #final_loss = 0
+        #for tf_idx in range(fmodel['data_aug']._nb_tf):
+        #    fmodel['data_aug'].transf_idx=tf_idx
+        #    logits = fmodel(xs)
+        #    loss = F.cross_entropy(logits, ys)
+        #    #loss.backward(retain_graph=True)
+        #    final_loss += loss*fmodel['data_aug']['prob'][tf_idx] #Take it in the forward function ?
+        #loss = final_loss
         '''
-        #Uniform method
-        logits = fmodel(xs) # modified `params` can also be passed as a kwarg
-        loss = F.cross_entropy(logits, ys, reduction='none') # no need to call loss.backward()
-        if fmodel._data_augmentation: #Weight loss
-            w_loss = fmodel['data_aug'].loss_weight()#.to(device)
-            loss = loss * w_loss
-        loss = loss.mean()
+        #KLdiv=False
+        if not KLdiv:
+            #Uniform method
+            logits = fmodel(xs) # modified `params` can also be passed as a kwarg
+            loss = F.cross_entropy(logits, ys, reduction='none') # no need to call loss.backward()
+            if fmodel._data_augmentation: #Weight loss
+                w_loss = fmodel['data_aug'].loss_weight()#.to(device)
+                loss = loss * w_loss
+            loss = loss.mean()
+        else:
+            #KL div method
+            fmodel.augment(mode=False)
+            sup_logits = fmodel(xs)
+            log_sup = F.log_softmax(sup_logits, dim=1)
+            fmodel.augment(mode=True)
+            loss = F.cross_entropy(log_sup, ys)
+            if fmodel._data_augmentation:
+                aug_logits = fmodel(xs)
+                log_aug = F.log_softmax(aug_logits, dim=1)
+                #KL div w/ logits
+                aug_loss = sup_logits*(log_sup-log_aug)
+                aug_loss = aug_loss.sum(dim=-1)
+                #aug_loss = F.kl_div(aug_logits, sup_logits, reduction='none') #Similarity of predictions (distributions)
+                w_loss = fmodel['data_aug'].loss_weight()#.unsqueeze(dim=1).expand(-1,10) #Weight loss
+                aug_loss = (w_loss * aug_loss).mean()
+                unsupp_coeff = 1
+                loss += aug_loss * unsupp_coeff
+            print('TF Proba :', model['data_aug']['prob'].data)
         #'''
         #to visualize computational graph
         #print_graph(loss)
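
Two details of the new KL branch are worth flagging. F.cross_entropy expects raw logits, so passing log_sup applies log-softmax a second time; the usual supervised term would be F.cross_entropy(sup_logits, ys). Likewise, KL(p_sup || p_aug) weights the log-ratio by probabilities, whereas the line above multiplies by the raw sup_logits. A sketch of the per-sample KL term, reusing the diff's own tensor names, alongside the F.kl_div form the commented-out line appears to aim at (F.kl_div takes log-probabilities as input and probabilities as target):

p_sup = F.softmax(sup_logits, dim=1)                                 # clean-branch probabilities
kl_manual = (p_sup * (log_sup - log_aug)).sum(dim=-1)                # KL(p_sup || p_aug) per sample
kl_builtin = F.kl_div(log_aug, p_sup, reduction='none').sum(dim=-1)  # same values as kl_manual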
@@ -664,6 +685,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
         print('TF Mag :', model['data_aug']['mag'].data)
         #print('Mag grad',model['data_aug']['mag'].grad)
         #print('Reg loss:', model['data_aug'].reg_loss().item())
+        print('Aug loss', aug_loss.item())
         #############
         #### Log ####
         #print(type(model['data_aug']) is dataug.Data_augV5)
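
With the default KLdiv=False the function keeps its previous weighted cross-entropy behaviour, so existing callers are unaffected. Note also that aug_loss is only bound on the KL path with augmentation active, so the print('Aug loss', ...) added to the logging block assumes that path was taken. Based on the new signature, a hypothetical call exercising the KL-divergence test (argument values illustrative):

log = run_dist_dataugV2(model, epochs=1, inner_it=1, print_freq=1, KLdiv=True)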