diff --git a/higher/smart_aug/test_dataug.py b/higher/smart_aug/test_dataug.py
index 0a09d0d..19da042 100755
--- a/higher/smart_aug/test_dataug.py
+++ b/higher/smart_aug/test_dataug.py
@@ -183,7 +183,7 @@ if __name__ == "__main__":
         tf_dict = {k: TF.TF_dict[k] for k in tf_names}
 
     model = Higher_model(model) #run_dist_dataugV3
-    aug_model = Augmented_model(Data_augV7(TF_dict=tf_dict, N_TF=3, mix_dist=0.8, fixed_prob=False, fixed_mag=False, shared_mag=False), model).to(device)
+    aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.8, fixed_prob=False, fixed_mag=False, shared_mag=False), model).to(device)
     #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)
 
     print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter))
@@ -192,8 +192,8 @@ if __name__ == "__main__":
             inner_it=n_inner_iter,
             dataug_epoch_start=dataug_epoch_start,
             opt_param=optim_param,
-            print_freq=20,
-            KLdiv=True,
+            print_freq=1,
+            unsup_loss=1,
             hp_opt=False)
 
     exec_time=time.process_time() - t0
diff --git a/higher/smart_aug/train_utils.py b/higher/smart_aug/train_utils.py
index b64170e..20e4ec8 100755
--- a/higher/smart_aug/train_utils.py
+++ b/higher/smart_aug/train_utils.py
@@ -68,6 +68,59 @@ def compute_vaLoss(model, dl_it, dl):
     model.eval() #Validation sans transfornations !
     return F.cross_entropy(F.log_softmax(model(xs), dim=1), ys)
 
+def mixed_loss(xs, ys, model, unsup_factor=1):
+    """Evaluate a model on a batch of data.
+
+        Compute a combination of losses:
+            + Supervised Cross-Entropy loss from original data.
+            + Unsupervised Cross-Entropy loss from augmented data.
+            + KL divergence loss encouraging similarity between original and augmented predictions.
+
+        If unsup_factor is 0 or if data augmentation is disabled, only the supervised loss is computed.
+
+        Inspired by UDA, see: https://github.com/google-research/uda/blob/master/image/main.py
+
+        Args:
+            xs (Tensor): Batch of data.
+            ys (Tensor): Batch of labels.
+            model (nn.Module): Augmented model (see dataug.py).
+            unsup_factor (float): Factor by which the unsupervised CE and KL div losses are multiplied.
+
+        Returns:
+            (Tensor) Mixed loss if there's data augmentation, the supervised CE loss alone otherwise.
+    """
+
+    #TODO: add a check to catch non-augmented models and redirect to the classic loss
+    if unsup_factor!=0 and model.is_augmenting():
+
+        # Supervised loss (classic)
+        model.augment(mode=False)
+        sup_logits = model(xs)
+        model.augment(mode=True)
+
+        log_sup = F.log_softmax(sup_logits, dim=1)
+        sup_loss = F.cross_entropy(log_sup, ys)
+
+        # Unsupervised loss
+        aug_logits = model(xs)
+        w_loss = model['data_aug'].loss_weight() #Weight loss
+
+        log_aug = F.log_softmax(aug_logits, dim=1)
+        aug_loss = (F.cross_entropy(log_aug, ys, reduction='none') * w_loss).mean()
+
+        #KL divergence loss (w/ logits) - Prediction/Distribution similarity
+        kl_loss = (F.softmax(sup_logits, dim=1)*(log_sup-log_aug)).sum(dim=-1)
+        kl_loss = (w_loss * kl_loss).mean()
+
+        loss = sup_loss + unsup_factor * (aug_loss + kl_loss)
+
+    else: #Supervised loss (classic)
+        sup_logits = model(xs)
+        log_sup = F.log_softmax(sup_logits, dim=1)
+        loss = F.cross_entropy(log_sup, ys)
+
+    return loss
+
 def train_classic(model, opt_param, epochs=1, print_freq=1):
     """Classic training of a model.
@@ -130,15 +183,14 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):
 
     return log
 
-def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start=0, print_freq=1, KLdiv=1, hp_opt=False, save_sample_freq=None):
+def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start=0, print_freq=1, unsup_loss=1, hp_opt=False, save_sample_freq=None):
     """Training of an augmented model with higher.
 
        This function is intended to be used with Augmented_model containing an Higher_model (see dataug.py).
        Ex : Augmented_model(Data_augV5(...), Higher_model(model))
 
-        Training loss can either be computed directly from augmented inputs (KLdiv=0).
-        However, it is recommended to use the KLdiv loss computation, inspired from UDA, which combine original and augmented inputs to compute the loss (KLdiv>0).
-        See : https://github.com/google-research/uda
+        Training loss can either be computed directly from augmented inputs (unsup_loss=0).
+        However, it is recommended to use the mixed loss computation, which combines original and augmented inputs to compute the loss (unsup_loss>0).
 
         Args:
             model (nn.Module): Augmented model to train.
@@ -147,7 +199,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
             inner_it (int): Number of inner iteration before a meta-step. 0 inner iteration means there's no meta-step. (default: 1)
             dataug_epoch_start (int): Epoch when to start data augmentation. (default: 0)
             print_freq (int): Number of epoch between display of the state of training. If set to None, no display will be done. (default:1)
-            KLdiv (float): Proportion of the KLdiv loss added to the supervised loss. If set to 0, the loss is classicly computed on augmented inputs. (default: 1)
+            unsup_loss (float): Proportion of the unsupervised loss added to the supervised loss. If set to 0, the loss is computed only on augmented inputs. (default: 1)
             hp_opt (bool): Wether to learn inner optimizer parameters. (default: False)
             save_sample_freq (int): Number of epochs between saves of samples of data. If set to None, only one save would be done at the end of the training.
                (default: None)
@@ -193,7 +245,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
         for i, (xs, ys) in enumerate(dl_train):
             xs, ys = xs.to(device), ys.to(device)
 
-            if(KLdiv<=0):
+            if(unsup_loss==0):
                 #Methode uniforme
                 logits = model(xs) # modified `params` can also be passed as a kwarg
                 loss = F.cross_entropy(F.log_softmax(logits, dim=1), ys, reduction='none') # no need to call loss.backwards()
@@ -204,32 +256,9 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
                 loss = loss.mean()
 
             else:
-                #Methode KL div
-                # Supervised loss (classic)
-                if model.is_augmenting() :
-                    model.augment(mode=False)
-                    sup_logits = model(xs)
-                    model.augment(mode=True)
-                else:
-                    sup_logits = model(xs)
-                log_sup=F.log_softmax(sup_logits, dim=1)
-                loss = F.cross_entropy(log_sup, ys)
+                #Mixed loss method
+                loss = mixed_loss(xs, ys, model, unsup_factor=unsup_loss)
 
-                # Unsupervised loss (KLdiv)
-                if model.is_augmenting() :
-                    aug_logits = model(xs)
-                    log_aug=F.log_softmax(aug_logits, dim=1)
-                    aug_loss=0
-                    w_loss = model['data_aug'].loss_weight() #Weight loss
-
-                    #KL div w/ logits - Similarite predictions (distributions)
-                    aug_loss = F.softmax(sup_logits, dim=1)*(log_sup-log_aug)
-                    aug_loss = aug_loss.sum(dim=-1)
-                    aug_loss = (w_loss * aug_loss).mean()
-                    aug_loss += (F.cross_entropy(log_aug, ys , reduction='none') * w_loss).mean()
-
-                    loss += aug_loss * KLdiv
-
             #print_graph(loss) #to visualize computational graph
 
             #t = time.process_time()
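
Note: for reference, a minimal standalone sketch of the loss the new mixed_loss() computes, with random tensors standing in for the repo's Augmented_model. The names sup_logits, aug_logits and w_loss mirror the diff; the batch shapes are illustrative assumptions. Since log_softmax is idempotent, passing log-probabilities to F.cross_entropy (as the diff does) is numerically equivalent to passing raw logits.

import torch
import torch.nn.functional as F

batch, n_classes = 8, 10
sup_logits = torch.randn(batch, n_classes)  # model output on original inputs
aug_logits = torch.randn(batch, n_classes)  # model output on augmented inputs
ys = torch.randint(n_classes, (batch,))     # ground-truth labels
w_loss = torch.ones(batch)                  # per-sample weights, as returned by loss_weight()
unsup_factor = 1.0                          # weight of the unsupervised terms

log_sup = F.log_softmax(sup_logits, dim=1)
log_aug = F.log_softmax(aug_logits, dim=1)

# Supervised CE on original data
sup_loss = F.cross_entropy(log_sup, ys)

# Weighted unsupervised CE on augmented data
aug_loss = (F.cross_entropy(log_aug, ys, reduction='none') * w_loss).mean()

# KL(P_sup || P_aug), computed from logits: pushes predictions on augmented
# inputs towards the predictions on the original inputs
kl_loss = (F.softmax(sup_logits, dim=1) * (log_sup - log_aug)).sum(dim=-1)
kl_loss = (w_loss * kl_loss).mean()

loss = sup_loss + unsup_factor * (aug_loss + kl_loss)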
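
A hedged sketch of the interface mixed_loss() assumes from its model argument, using hypothetical stub classes in place of Augmented_model/Data_augV5; StubDataAug and StubAugModel are invented for illustration, and mixed_loss is assumed importable from train_utils.py.

import torch
import torch.nn as nn
from train_utils import mixed_loss  # assuming higher/smart_aug is on the path

class StubDataAug(nn.Module):
    """Stands in for Data_augV5: no transformation, uniform sample weights."""
    def loss_weight(self):
        return torch.ones(8)  # one weight per sample (batch size assumed 8)

class StubAugModel(nn.Module):
    """Mimics the Augmented_model interface used by mixed_loss()."""
    def __init__(self):
        super().__init__()
        self.data_aug = StubDataAug()
        self.net = nn.Linear(32, 10)
        self._augmenting = True

    def __getitem__(self, key):   # model['data_aug'] access
        return {'data_aug': self.data_aug}[key]

    def is_augmenting(self):      # checked before computing the unsupervised terms
        return self._augmenting

    def augment(self, mode=True): # toggled to obtain non-augmented logits
        self._augmenting = mode

    def forward(self, xs):        # augmentation itself omitted here, so augmented
        return self.net(xs)       # and original logits coincide and the KL term is zero

xs, ys = torch.randn(8, 32), torch.randint(10, (8,))
loss = mixed_loss(xs, ys, StubAugModel(), unsup_factor=1)
loss.backward()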