Mirror of https://github.com/AntoineHX/smart_augmentation.git (synced 2025-05-04 12:10:45 +02:00)
Commit c1ad787d97 (parent d21a6bbf5c): Minor improvement + Comments
5 changed files with 165 additions and 62 deletions
@@ -876,7 +876,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
             else:
                 #Methode KL div
                 # Supervised loss (classic)
-                if model._data_augmentation :
+                if model.is_augmenting() :
                     model.augment(mode=False)
                     sup_logits = model(xs)
                     model.augment(mode=True)
@@ -886,7 +886,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
                 loss = F.cross_entropy(log_sup, ys)

                 # Unsupervised loss (KLdiv)
-                if model._data_augmentation:
+                if model.is_augmenting() :
                     aug_logits = model(xs)
                     log_aug=F.log_softmax(aug_logits, dim=1)
                     aug_loss=0
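For context, the two hunks above sit in the "Methode KL div" branch, which combines a classic supervised cross-entropy term on non-augmented inputs with an unsupervised KL-divergence consistency term between the augmented and non-augmented predictions. Below is a minimal, illustrative sketch of that pattern, assuming a model that exposes is_augmenting() and augment(mode=...) as in the diff; the helper name mixed_loss, the unsup_weight parameter, the detach() on the target distribution, and the exact KL term are assumptions for the sketch, not the repository's code.

import torch
import torch.nn.functional as F

def mixed_loss(model, xs, ys, unsup_weight=1.0):
    """Sketch of a supervised + KL-div consistency loss (illustrative only)."""
    if model.is_augmenting():
        # Supervised term: logits computed on the non-augmented inputs.
        model.augment(mode=False)
        sup_logits = model(xs)
        model.augment(mode=True)
        # The diff passes a log-softmax output (log_sup) to F.cross_entropy;
        # this sketch passes raw logits, which is the usual idiom.
        loss = F.cross_entropy(sup_logits, ys)

        # Unsupervised term: KL divergence between the predictive
        # distributions on augmented and non-augmented inputs.
        aug_logits = model(xs)                       # forward pass with augmentation on
        log_aug = F.log_softmax(aug_logits, dim=1)
        aug_loss = F.kl_div(log_aug,
                            F.softmax(sup_logits.detach(), dim=1),
                            reduction='batchmean')   # detach(): stop-gradient on the
                                                     # target, a common choice not shown
                                                     # in the diff
        loss = loss + unsup_weight * aug_loss
    else:
        # No augmentation: plain cross-entropy.
        loss = F.cross_entropy(model(xs), ys)
    return loss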
@@ -948,7 +948,6 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
             accuracy, test_loss =test(model)
             model.train()

-            print(model['data_aug']._data_augmentation)
             #### Log ####
             param = [{'p': p.item(), 'm':model['data_aug']['mag'].item()} for p in model['data_aug']['prob']] if model['data_aug']._shared_mag else [{'p': p.item(), 'm': m.item()} for p, m in zip(model['data_aug']['prob'], model['data_aug']['mag'])]
             data={
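For reference, the param line in the last hunk serializes the augmentation module's per-transformation probabilities and magnitudes for logging. The small stand-alone sketch below shows that data layout; the tensors prob and mag and the flag shared_mag are illustrative stand-ins for model['data_aug']['prob'], model['data_aug']['mag'], and model['data_aug']._shared_mag.

import torch

# Illustrative stand-ins for the augmentation module's learned parameters.
prob = torch.tensor([0.5, 0.3, 0.8])   # one probability per transformation
mag  = torch.tensor([0.2, 0.7, 0.4])   # one magnitude per transformation
shared_mag = False                     # stand-in for _shared_mag

if shared_mag:
    # With a shared magnitude, `mag` is expected to be a scalar tensor,
    # so the same value is repeated for every transformation.
    param = [{'p': p.item(), 'm': mag.item()} for p in prob]
else:
    # Otherwise each transformation gets its own (probability, magnitude) pair.
    param = [{'p': p.item(), 'm': m.item()} for p, m in zip(prob, mag)]

print(param)
# e.g. (up to float32 rounding):
# [{'p': 0.5, 'm': 0.2}, {'p': 0.3, 'm': 0.7}, {'p': 0.8, 'm': 0.4}]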