Mirror of https://github.com/AntoineHX/smart_augmentation.git (synced 2025-05-04 12:10:45 +02:00)
Test KL divergence from UDA
Parent: fa5bc72616
Commit: 217f94ef89
5 changed files with 52 additions and 28 deletions
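For context, the UDA recipe this commit borrows from (Xie et al., "Unsupervised Data Augmentation for Consistency Training") adds to the supervised cross-entropy a KL divergence between the model's prediction on the clean input and its prediction on the augmented input, with the clean prediction treated as a fixed target. A minimal sketch of that objective, independent of this repository's Data_augV* modules (the function and argument names below are illustrative, not from the repo):

import torch.nn.functional as F

def uda_style_loss(model, x_clean, x_aug, y, unsupp_coeff=1.0):
    # Supervised term on the un-augmented batch.
    sup_logits = model(x_clean)
    sup_loss = F.cross_entropy(sup_logits, y)

    # Consistency term: KL(p_clean || p_aug); the clean prediction is detached
    # so it acts as a fixed target, as in UDA.
    p_clean = F.softmax(sup_logits.detach(), dim=1)
    log_p_aug = F.log_softmax(model(x_aug), dim=1)
    consistency = F.kl_div(log_p_aug, p_clean, reduction='batchmean')

    return sup_loss + unsupp_coeff * consistency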
@@ -542,7 +542,7 @@ def run_dist_dataug(model, epochs=1, inner_it=1, dataug_epoch_start=0):
    print("Copy ", countcopy)
    return log

-def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_freq=1, loss_patience=None):
+def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_freq=1, KLdiv=False, loss_patience=None):
    device = next(model.parameters()).device
    log = []
    countcopy=0
@@ -578,30 +578,51 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_freq=1, KLdiv=False, loss_patience=None):
        for i, (xs, ys) in enumerate(dl_train):
            xs, ys = xs.to(device), ys.to(device)
            '''
            #Exact method
            final_loss = 0
            for tf_idx in range(fmodel['data_aug']._nb_tf):
                fmodel['data_aug'].transf_idx=tf_idx
                logits = fmodel(xs)
                loss = F.cross_entropy(logits, ys)
                #loss.backward(retain_graph=True)
                #print('idx', tf_idx)
                #print(fmodel['data_aug']['prob'][tf_idx], fmodel['data_aug']['prob'][tf_idx].grad)
                final_loss += loss*fmodel['data_aug']['prob'][tf_idx] #Take it in the forward function ?
            #final_loss = 0
            #for tf_idx in range(fmodel['data_aug']._nb_tf):
            #    fmodel['data_aug'].transf_idx=tf_idx
            #    logits = fmodel(xs)
            #    loss = F.cross_entropy(logits, ys)
            #    #loss.backward(retain_graph=True)
            #    final_loss += loss*fmodel['data_aug']['prob'][tf_idx] #Take it in the forward function ?
            #loss = final_loss

            loss = final_loss
            '''
            #KLdiv=False
            if(not KLdiv):
                #Uniform method
                logits = fmodel(xs) # modified `params` can also be passed as a kwarg
                loss = F.cross_entropy(logits, ys, reduction='none') # no need to call loss.backwards()
                if fmodel._data_augmentation: #Weight loss
                    w_loss = fmodel['data_aug'].loss_weight()#.to(device)
                    loss = loss * w_loss
                loss = loss.mean()

            else:
                #KL div method
                fmodel.augment(mode=False)
                sup_logits = fmodel(xs)
                log_sup=F.log_softmax(sup_logits, dim=1)
                fmodel.augment(mode=True)
                loss = F.cross_entropy(log_sup, ys)

                if fmodel._data_augmentation:
                    aug_logits = fmodel(xs)
                    log_aug=F.log_softmax(aug_logits, dim=1)
                    #KL div w/ logits
                    aug_loss = sup_logits*(log_sup-log_aug)
                    aug_loss=aug_loss.sum(dim=-1)
                    #aug_loss = F.kl_div(aug_logits, sup_logits, reduction='none') #Similarity of predictions (distributions)

                    w_loss = fmodel['data_aug'].loss_weight()#.unsqueeze(dim=1).expand(-1,10) #Weight loss
                    aug_loss = (w_loss * aug_loss).mean()
                    unsupp_coeff = 1
                    loss += aug_loss * unsupp_coeff

                print('TF Proba :', model['data_aug']['prob'].data)
            #'''

            #to visualize computational graph
            #print_graph(loss)
@@ -664,6 +685,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_freq=1, KLdiv=False, loss_patience=None):
            print('TF Mag :', model['data_aug']['mag'].data)
            #print('Mag grad',model['data_aug']['mag'].grad)
            #print('Reg loss:', model['data_aug'].reg_loss().item())
+           print('Aug loss', aug_loss.item())
        #############
        #### Log ####
        #print(type(model['data_aug']) is dataug.Data_augV5)
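A side note on the KL branch above: the committed line aug_loss = sup_logits*(log_sup-log_aug) weights the log-ratio by the raw logits, whereas the usual definition of KL(p_sup || p_aug) weights it by the softmax probabilities. A hedged sketch of that textbook variant, kept per-example so it could still be multiplied by a loss weight (e.g. fmodel['data_aug'].loss_weight()) before reduction; this is an illustration under that assumption, not what the commit does:

import torch
import torch.nn.functional as F

def per_example_kl(sup_logits: torch.Tensor, aug_logits: torch.Tensor) -> torch.Tensor:
    """KL(p_sup || p_aug) per example, for two [batch, n_classes] logit tensors."""
    p_sup = F.softmax(sup_logits, dim=1)        # probabilities, not raw logits
    log_p_sup = F.log_softmax(sup_logits, dim=1)
    log_p_aug = F.log_softmax(aug_logits, dim=1)
    # Elementwise p * (log p - log q), summed over classes -> one scalar per example.
    return (p_sup * (log_p_sup - log_p_aug)).sum(dim=-1)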