Mirror of https://github.com/AntoineHX/smart_augmentation.git
Synced 2025-05-04 12:10:45 +02:00
Encapsulation of the KLdiv loss in mixed_loss
This commit is contained in:
parent a2135e4709
commit 7742f76d12

2 changed files with 63 additions and 34 deletions
@@ -183,7 +183,7 @@ if __name__ == "__main__":
     tf_dict = {k: TF.TF_dict[k] for k in tf_names}
     model = Higher_model(model) #run_dist_dataugV3
-    aug_model = Augmented_model(Data_augV7(TF_dict=tf_dict, N_TF=3, mix_dist=0.8, fixed_prob=False, fixed_mag=False, shared_mag=False), model).to(device)
+    aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.8, fixed_prob=False, fixed_mag=False, shared_mag=False), model).to(device)
     #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)

     print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter))

@@ -192,8 +192,8 @@ if __name__ == "__main__":
         inner_it=n_inner_iter,
         dataug_epoch_start=dataug_epoch_start,
         opt_param=optim_param,
-        print_freq=20,
-        KLdiv=True,
+        print_freq=1,
+        unsup_loss=1,
         hp_opt=False)

     exec_time=time.process_time() - t0

@@ -68,6 +68,59 @@ def compute_vaLoss(model, dl_it, dl):
     model.eval() #Validation without transformations!
     return F.cross_entropy(F.log_softmax(model(xs), dim=1), ys)

+def mixed_loss(xs, ys, model, unsup_factor=1):
+    """Evaluate a model on a batch of data.
+
+    Computes a combination of losses:
+        + Supervised cross-entropy loss on the original data.
+        + Unsupervised cross-entropy loss on the augmented data.
+        + KL divergence loss encouraging similarity between the original and augmented predictions.
+
+    If unsup_factor is equal to 0, or if there is no data augmentation, only the supervised loss is computed.
+
+    Inspired by UDA, see: https://github.com/google-research/uda/blob/master/image/main.py
+
+    Args:
+        xs (Tensor): Batch of data.
+        ys (Tensor): Batch of labels.
+        model (nn.Module): Augmented model (see dataug.py).
+        unsup_factor (float): Factor by which the unsupervised CE and KL div losses are multiplied.
+
+    Returns:
+        (Tensor) Mixed loss if there is data augmentation, supervised CE loss otherwise.
+    """
+
+    #TODO: add a test to catch non-augmented models and redirect to the classic loss
+    if unsup_factor!=0 and model.is_augmenting():
+
+        # Supervised loss (classic)
+        model.augment(mode=False)
+        sup_logits = model(xs)
+        model.augment(mode=True)
+
+        log_sup = F.log_softmax(sup_logits, dim=1)
+        sup_loss = F.cross_entropy(log_sup, ys)
+
+        # Unsupervised loss
+        aug_logits = model(xs)
+        w_loss = model['data_aug'].loss_weight() #Weight loss
+
+        log_aug = F.log_softmax(aug_logits, dim=1)
+        aug_loss = (F.cross_entropy(log_aug, ys, reduction='none') * w_loss).mean()
+
+        #KL divergence loss (w/ logits) - prediction/distribution similarity
+        kl_loss = (F.softmax(sup_logits, dim=1)*(log_sup-log_aug)).sum(dim=-1)
+        kl_loss = (w_loss * kl_loss).mean()
+
+        loss = sup_loss + unsup_factor * (aug_loss + kl_loss)
+
+    else: #Supervised loss (classic)
+        sup_logits = model(xs)
+        log_sup = F.log_softmax(sup_logits, dim=1)
+        loss = F.cross_entropy(log_sup, ys)
+
+    return loss
+
 def train_classic(model, opt_param, epochs=1, print_freq=1):
     """Classic training of a model.

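Side note: the KL term added in mixed_loss is the forward KL divergence KL(p_sup || p_aug), computed by hand from the logits. A minimal standalone sanity check (a sketch; the random logits and shapes below are illustrative, not from this repo) shows it matches PyTorch's built-in F.kl_div:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    sup_logits = torch.randn(8, 10)  # stand-in predictions on original inputs
    aug_logits = torch.randn(8, 10)  # stand-in predictions on augmented inputs

    log_sup = F.log_softmax(sup_logits, dim=1)
    log_aug = F.log_softmax(aug_logits, dim=1)

    # KL term exactly as written in mixed_loss: sum_c p_sup * (log p_sup - log p_aug)
    kl_manual = (F.softmax(sup_logits, dim=1) * (log_sup - log_aug)).sum(dim=-1)

    # Same quantity via the library call: input is log-probs of p_aug, target is p_sup
    kl_builtin = F.kl_div(log_aug, F.softmax(sup_logits, dim=1), reduction='none').sum(dim=-1)

    assert torch.allclose(kl_manual, kl_builtin, atol=1e-6)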
@@ -130,15 +183,14 @@ def train_classic(model, opt_param, epochs=1, print_freq=1):

     return log

-def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start=0, print_freq=1, KLdiv=1, hp_opt=False, save_sample_freq=None):
+def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start=0, print_freq=1, unsup_loss=1, hp_opt=False, save_sample_freq=None):
     """Training of an augmented model with higher.

     This function is intended to be used with an Augmented_model containing a Higher_model (see dataug.py).
     Ex: Augmented_model(Data_augV5(...), Higher_model(model))

-    Training loss can either be computed directly from augmented inputs (KLdiv=0).
-    However, it is recommended to use the KLdiv loss computation, inspired by UDA, which combines original and augmented inputs to compute the loss (KLdiv>0).
-    See: https://github.com/google-research/uda
+    Training loss can either be computed directly from augmented inputs (unsup_loss=0).
+    However, it is recommended to use the mixed loss computation, which combines original and augmented inputs to compute the loss (unsup_loss>0).

     Args:
         model (nn.Module): Augmented model to train.

@@ -147,7 +199,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
         inner_it (int): Number of inner iterations before a meta-step. 0 inner iterations means there is no meta-step. (default: 1)
         dataug_epoch_start (int): Epoch at which to start data augmentation. (default: 0)
         print_freq (int): Number of epochs between displays of the training state. If set to None, nothing is displayed. (default: 1)
-        KLdiv (float): Proportion of the KLdiv loss added to the supervised loss. If set to 0, the loss is classically computed on augmented inputs. (default: 1)
+        unsup_loss (float): Proportion of the unsupervised loss added to the supervised loss. If set to 0, the loss is only computed on augmented inputs. (default: 1)
         hp_opt (bool): Whether to learn the inner optimizer's parameters. (default: False)
         save_sample_freq (int): Number of epochs between saves of data samples. If set to None, a single save is done at the end of training. (default: None)

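For context, after this change the call site in the test script (first hunk above) reads roughly as follows; aug_model, optim_param, epochs, n_inner_iter, and dataug_epoch_start are defined earlier in that script's __main__ block, and any arguments not shown in the diff are assumed:

    run_dist_dataugV3(aug_model,
                      opt_param=optim_param,
                      epochs=epochs,
                      inner_it=n_inner_iter,
                      dataug_epoch_start=dataug_epoch_start,
                      print_freq=1,
                      unsup_loss=1,  # 0 disables the mixed loss; >0 scales the unsupervised part
                      hp_opt=False)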
@@ -193,7 +245,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
         for i, (xs, ys) in enumerate(dl_train):
             xs, ys = xs.to(device), ys.to(device)

-            if(KLdiv<=0):
+            if(unsup_loss==0):
                 #Uniform method
                 logits = model(xs) # modified `params` can also be passed as a kwarg
                 loss = F.cross_entropy(F.log_softmax(logits, dim=1), ys, reduction='none') # no need to call loss.backwards()

@@ -204,31 +256,8 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
                 loss = loss.mean()

             else:
-                #KL div method
-                # Supervised loss (classic)
-                if model.is_augmenting():
-                    model.augment(mode=False)
-                    sup_logits = model(xs)
-                    model.augment(mode=True)
-                else:
-                    sup_logits = model(xs)
-                log_sup=F.log_softmax(sup_logits, dim=1)
-                loss = F.cross_entropy(log_sup, ys)
-
-                # Unsupervised loss (KL div)
-                if model.is_augmenting():
-                    aug_logits = model(xs)
-                    log_aug=F.log_softmax(aug_logits, dim=1)
-                    aug_loss=0
-                    w_loss = model['data_aug'].loss_weight() #Weight loss
-
-                    #KL div w/ logits - prediction/distribution similarity
-                    aug_loss = F.softmax(sup_logits, dim=1)*(log_sup-log_aug)
-                    aug_loss = aug_loss.sum(dim=-1)
-                    aug_loss = (w_loss * aug_loss).mean()
-                    aug_loss += (F.cross_entropy(log_aug, ys, reduction='none') * w_loss).mean()
-
-                    loss += aug_loss * KLdiv
+                #Mixed method
+                loss = mixed_loss(xs, ys, model, unsup_factor=unsup_loss)

             #print_graph(loss) #to visualize the computational graph

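To exercise both paths of the new mixed_loss outside the full training loop, here is a toy, hypothetical stand-in that exposes only the interface the function touches (is_augmenting(), augment(mode), and model['data_aug'].loss_weight()); it assumes mixed_loss from the hunk above is in scope and is not part of this repo:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class MockAugModel(nn.Module):
        # Toy stand-in for Augmented_model: a linear classifier that perturbs
        # its input when augmentation is active.
        def __init__(self, n_feat=4, n_class=3, batch=8):
            super().__init__()
            self.fc = nn.Linear(n_feat, n_class)
            self._augmenting = True
            self._batch = batch

        def is_augmenting(self):
            return self._augmenting

        def augment(self, mode=True):
            self._augmenting = mode

        def __getitem__(self, key):  # mimics model['data_aug']
            assert key == 'data_aug'
            return self

        def loss_weight(self):
            return torch.ones(self._batch)  # uniform per-sample weights

        def forward(self, x):
            if self._augmenting:
                x = x + 0.1 * torch.randn_like(x)  # stand-in "augmentation"
            return self.fc(x)

    model = MockAugModel()
    xs, ys = torch.randn(8, 4), torch.randint(0, 3, (8,))
    print(mixed_loss(xs, ys, model, unsup_factor=1))  # supervised CE + unsup CE + KL
    print(mixed_loss(xs, ys, model, unsup_factor=0))  # falls back to plain supervised CE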