Encapsulation of the KLdiv loss into mixed_loss

Harle, Antoine (Contracteur) 2020-01-27 17:29:45 -05:00
parent a2135e4709
commit 7742f76d12
2 changed files with 63 additions and 34 deletions

View file

@@ -183,7 +183,7 @@ if __name__ == "__main__":
tf_dict = {k: TF.TF_dict[k] for k in tf_names}
model = Higher_model(model) #run_dist_dataugV3
aug_model = Augmented_model(Data_augV7(TF_dict=tf_dict, N_TF=3, mix_dist=0.8, fixed_prob=False, fixed_mag=False, shared_mag=False), model).to(device)
aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.8, fixed_prob=False, fixed_mag=False, shared_mag=False), model).to(device)
#aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)
print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter))
@@ -192,8 +192,8 @@ if __name__ == "__main__":
inner_it=n_inner_iter,
dataug_epoch_start=dataug_epoch_start,
opt_param=optim_param,
print_freq=20,
KLdiv=True,
print_freq=1,
unsup_loss=1,
hp_opt=False)
exec_time=time.process_time() - t0

View file

@@ -68,6 +68,59 @@ def compute_vaLoss(model, dl_it, dl):
model.eval() #Validation without transformations!
return F.cross_entropy(F.log_softmax(model(xs), dim=1), ys)
def mixed_loss(xs, ys, model, unsup_factor=1):
"""Evaluate a model on a batch of data.
Compute a combinaison of losses:
+ Supervised Cross-Entropy loss from original data.
+ Unsupervised Cross-Entropy loss from augmented data.
+ KL divergence loss encouraging similarity between original and augmented prediction.
If unsup_factor is equal to 0 or if there isn't data augmentation, only the supervised loss is computed.
Inspired by UDA, see: https://github.com/google-research/uda/blob/master/image/main.py
Args:
xs (Tensor): Batch of data.
ys (Tensor): Batch of labels.
model (nn.Module): Augmented model (see dataug.py).
unsup_factor (float): Factor by which unsupervised CE and KL div loss are multiplied.
Returns:
(Tensor) Mixed loss if there's data augmentation, just supervised CE loss otherwise.
"""
#TODO: add a check to prevent errors when the model is not an Augmented_model and fall back to the classic loss
if unsup_factor!=0 and model.is_augmenting():
# Supervised loss (classic)
model.augment(mode=False)
sup_logits = model(xs)
model.augment(mode=True)
log_sup = F.log_softmax(sup_logits, dim=1)
sup_loss = F.cross_entropy(log_sup, ys)
# Unsupervised loss
aug_logits = model(xs)
w_loss = model['data_aug'].loss_weight() #Loss weights
log_aug = F.log_softmax(aug_logits, dim=1)
aug_loss = (F.cross_entropy(log_aug, ys, reduction='none') * w_loss).mean()
#KL divergence loss (w/ logits) - Prediction/Distribution similarity
kl_loss = (F.softmax(sup_logits, dim=1)*(log_sup-log_aug)).sum(dim=-1)
kl_loss = (w_loss * kl_loss).mean()
loss = sup_loss + unsup_factor * (aug_loss + kl_loss)
else: #Supervised loss (classic)
sup_logits = model(xs)
log_sup = F.log_softmax(sup_logits, dim=1)
loss = F.cross_entropy(log_sup, ys)
return loss
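
For illustration, a minimal sketch of how mixed_loss could be used in a plain training step. This is not part of the commit: the optimizer and loader names are hypothetical, aug_model is assumed to be an Augmented_model wrapping a Higher_model as built in test_dataug.py above, and run_dist_dataugV3 itself drives the model through higher's differentiable optimizer instead of a plain backward/step.

    # Sketch only (not part of this commit): direct use of mixed_loss.
    # Assumes aug_model = Augmented_model(Data_augV5(...), Higher_model(model)) and dl_train as in test_dataug.py.
    inner_opt = torch.optim.SGD(aug_model.parameters(), lr=1e-2, momentum=0.9)
    aug_model.train()
    aug_model.augment(mode=True)
    for xs, ys in dl_train:
        xs, ys = xs.to(device), ys.to(device)
        # loss = sup_CE + unsup_factor * (weighted aug_CE + weighted KL(sup || aug))
        loss = mixed_loss(xs, ys, aug_model, unsup_factor=1)
        inner_opt.zero_grad()
        loss.backward()
        inner_opt.step()
    # In the actual pipeline only the inner model's weights are updated this way;
    # the data_aug parameters are meta-learned in run_dist_dataugV3.
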
def train_classic(model, opt_param, epochs=1, print_freq=1):
"""Classic training of a model.
@@ -130,15 +183,14 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):
return log
def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start=0, print_freq=1, KLdiv=1, hp_opt=False, save_sample_freq=None):
def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start=0, print_freq=1, unsup_loss=1, hp_opt=False, save_sample_freq=None):
"""Training of an augmented model with higher.
This function is intended to be used with an Augmented_model containing a Higher_model (see dataug.py).
Ex: Augmented_model(Data_augV5(...), Higher_model(model))
Training loss can either be computed directly from augmented inputs (KLdiv=0).
However, it is recommended to use the KLdiv loss computation, inspired by UDA, which combines original and augmented inputs to compute the loss (KLdiv>0).
See: https://github.com/google-research/uda
Training loss can either be computed directly from augmented inputs (unsup_loss=0).
However, it is recommended to use the mixed loss computation, which combines original and augmented inputs to compute the loss (unsup_loss>0).
Args:
model (nn.Module): Augmented model to train.
@@ -147,7 +199,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
inner_it (int): Number of inner iterations before a meta-step. 0 inner iterations means there is no meta-step. (default: 1)
dataug_epoch_start (int): Epoch at which to start data augmentation. (default: 0)
print_freq (int): Number of epochs between displays of the training state. If set to None, no display will be done. (default: 1)
KLdiv (float): Proportion of the KLdiv loss added to the supervised loss. If set to 0, the loss is classically computed on augmented inputs. (default: 1)
unsup_loss (float): Factor applied to the unsupervised CE and KL div losses before they are added to the supervised loss. If set to 0, the loss is only computed on augmented inputs. (default: 1)
hp_opt (bool): Whether to learn the inner optimizer parameters. (default: False)
save_sample_freq (int): Number of epochs between saves of data samples. If set to None, only one save will be done at the end of the training. (default: None)
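
For reference, a sketch of the call site after this change, mirroring the keyword values shown in test_dataug.py above; aug_model, epochs and the other variables are assumed to be defined as in that file.

    # Sketch of the call after this commit (unsup_loss replaces the former KLdiv argument)
    run_dist_dataugV3(model=aug_model,
        epochs=epochs,
        inner_it=n_inner_iter,
        dataug_epoch_start=dataug_epoch_start,
        opt_param=optim_param,
        print_freq=1,
        unsup_loss=1,
        hp_opt=False)
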
@@ -193,7 +245,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
for i, (xs, ys) in enumerate(dl_train):
xs, ys = xs.to(device), ys.to(device)
if(KLdiv<=0):
if(unsup_loss==0):
#Uniform method
logits = model(xs) # modified `params` can also be passed as a kwarg
loss = F.cross_entropy(F.log_softmax(logits, dim=1), ys, reduction='none') # no need to call loss.backward()
@@ -204,32 +256,9 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
loss = loss.mean()
else:
#KL div method
# Supervised loss (classic)
if model.is_augmenting() :
model.augment(mode=False)
sup_logits = model(xs)
model.augment(mode=True)
else:
sup_logits = model(xs)
log_sup=F.log_softmax(sup_logits, dim=1)
loss = F.cross_entropy(log_sup, ys)
#Mixed loss method
loss = mixed_loss(xs, ys, model, unsup_factor=unsup_loss)
# Unsupervised loss (KLdiv)
if model.is_augmenting() :
aug_logits = model(xs)
log_aug=F.log_softmax(aug_logits, dim=1)
aug_loss=0
w_loss = model['data_aug'].loss_weight() #Loss weights
#KL div w/ logits - Similarity of predictions (distributions)
aug_loss = F.softmax(sup_logits, dim=1)*(log_sup-log_aug)
aug_loss = aug_loss.sum(dim=-1)
aug_loss = (w_loss * aug_loss).mean()
aug_loss += (F.cross_entropy(log_aug, ys , reduction='none') * w_loss).mean()
loss += aug_loss * KLdiv
#print_graph(loss) #to visualize computational graph
#t = time.process_time()