Mirror of https://github.com/AntoineHX/smart_augmentation.git (synced 2025-05-04 12:10:45 +02:00)
Add more control/visibility over the optimizers
This commit is contained in: parent d1ee0c632e, commit 41c7273241
3 changed files with 49 additions and 23 deletions
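The change replaces hard-coded optimizer settings with a single opt_param dictionary threaded through the training functions. From the hunks below it needs at least an 'Inner' entry with 'lr' and 'momentum' (used for the task model's SGD) and a 'Meta' entry with 'lr' (used for the augmentation module's Adam). A minimal sketch of such a dictionary, using the old hard-coded defaults as illustrative values (not code from the repo):

# Sketch only: keys come from the diff, values are the previous hard-coded defaults.
opt_param = {
    'Inner': {'lr': 1e-2, 'momentum': 0.9},  # task-model SGD
    'Meta':  {'lr': 1e-2},                   # data-augmentation Adam
}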
@@ -47,10 +47,10 @@ def compute_vaLoss(model, dl_it, dl):
     return F.cross_entropy(model(xs), ys)
 
-def train_classic(model, epochs=1, print_freq=1):
+def train_classic(model, opt_param, epochs=1, print_freq=1):
     device = next(model.parameters()).device
     #opt = torch.optim.Adam(model.parameters(), lr=1e-3)
-    optim = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
+    optim = torch.optim.SGD(model.parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
 
     model.train()
     dl_val_it = iter(dl_val)
@@ -305,11 +305,12 @@ def train_classic_tests(model, epochs=1):
     print("Copy ", countcopy)
     return log
 
-def train_UDA(model, dl_unsup, epochs=1, print_freq=1):
+def train_UDA(model, dl_unsup, opt_param, epochs=1, print_freq=1):
+
     device = next(model.parameters()).device
     #opt = torch.optim.Adam(model.parameters(), lr=1e-3)
-    optim = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
+    opt = torch.optim.SGD(model.parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
 
 
     model.train()
     dl_val_it = iter(dl_val)
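Call sites (presumably in the other changed files, which are not shown here) now have to provide this dictionary as well. A hypothetical invocation, following the new signatures above:

# Hypothetical usage; epoch counts, model and dataloaders are placeholders.
log = train_classic(model, opt_param, epochs=10, print_freq=1)
log = train_UDA(model, dl_unsup, opt_param, epochs=10, print_freq=1)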
@@ -340,14 +341,13 @@ def train_UDA(model, dl_unsup, epochs=1, print_freq=1):
         sup_logits = model.forward(origin_xs)
         unsup_logits = model.forward(aug_xs)
 
         #print(unsup_logits.shape, sup_logits.shape)
         log_sup=F.log_softmax(sup_logits, dim=1)
         log_unsup=F.log_softmax(unsup_logits, dim=1)
         #KL div w/ logits
         unsup_loss = F.softmax(sup_logits, dim=1)*(log_sup-log_unsup)
         unsup_loss=unsup_loss.sum(dim=-1).mean()
 
         #print(unsup_loss.shape)
         #print(unsup_loss)
         unsupp_coeff = 1
         loss = sup_loss + unsup_loss * unsupp_coeff
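The unsupervised UDA term above computes KL(p_sup || p_unsup) by hand from log-softmax outputs. A small self-contained check (not code from the repo; random logits stand in for the model outputs) showing that this manual form matches PyTorch's built-in F.kl_div:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
sup_logits = torch.randn(8, 10)    # stand-in for model.forward(origin_xs)
unsup_logits = torch.randn(8, 10)  # stand-in for model.forward(aug_xs)

log_sup = F.log_softmax(sup_logits, dim=1)
log_unsup = F.log_softmax(unsup_logits, dim=1)

# Manual form used in train_UDA: sum_c p_sup(c) * (log p_sup(c) - log p_unsup(c)), averaged over the batch
manual = (F.softmax(sup_logits, dim=1) * (log_sup - log_unsup)).sum(dim=-1).mean()

# Built-in equivalent: first argument is log-probabilities, second is probabilities
builtin = F.kl_div(log_unsup, F.softmax(sup_logits, dim=1), reduction='batchmean')

print(torch.allclose(manual, builtin))  # True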
@@ -629,7 +629,7 @@ def run_dist_dataug(model, epochs=1, inner_it=1, dataug_epoch_start=0):
     print("Copy ", countcopy)
     return log
 
-def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_freq=1, KLdiv=False, loss_patience=None, save_sample=False):
+def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start=0, print_freq=1, KLdiv=False, loss_patience=None, save_sample=False):
     device = next(model.parameters()).device
     log = []
     countcopy=0
@@ -637,8 +637,8 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
     dl_val_it = iter(dl_val)
 
     #if inner_it!=0:
-    meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=1e-2) #lr=1e-2
-    inner_opt = torch.optim.SGD(model['model'].parameters(), lr=1e-2, momentum=0.9)
+    meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=opt_param['Meta']['lr']) #lr=1e-2
+    inner_opt = torch.optim.SGD(model['model'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
 
     high_grad_track = True
     if inner_it == 0:
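run_dist_dataugV2 now builds both optimizers of the bilevel setup from the same dictionary: Adam on the augmentation parameters with the 'Meta' learning rate, SGD on the task model with the 'Inner' settings. A minimal sketch of that pattern, with toy modules standing in for model['data_aug'] and model['model'] (not the repo's classes):

import torch

data_aug = torch.nn.Linear(4, 4)    # stand-in for model['data_aug']
task_model = torch.nn.Linear(4, 2)  # stand-in for model['model']
opt_param = {'Inner': {'lr': 1e-2, 'momentum': 0.9}, 'Meta': {'lr': 1e-2}}

meta_opt = torch.optim.Adam(data_aug.parameters(), lr=opt_param['Meta']['lr'])
inner_opt = torch.optim.SGD(task_model.parameters(),
                            lr=opt_param['Inner']['lr'],
                            momentum=opt_param['Inner']['momentum'])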
@@ -703,7 +703,10 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
                 #aug_loss = F.kl_div(aug_logits, sup_logits, reduction='none') #Similarity of predictions (distributions)
 
                 w_loss = fmodel['data_aug'].loss_weight() #Weight loss
-                aug_loss = (w_loss * aug_loss).mean()
+                aug_loss = (w_loss * aug_loss).mean() #deferred learning ?
+
+                aug_loss += (F.cross_entropy(log_aug, ys , reduction='none') * w_loss).mean()
+                #print(aug_loss)
                 unsupp_coeff = 1
                 loss += aug_loss * unsupp_coeff
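In the last hunk, each augmented sample's loss is scaled by the per-sample weights from fmodel['data_aug'].loss_weight(), and a weighted cross-entropy term on the augmented predictions is added before the usual unsupervised coefficient is applied. A small sketch of that per-sample weighting, with random tensors standing in for the repo's objects (aug_loss, log_aug, ys, w_loss):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, n_classes = 8, 10
aug_logits = torch.randn(batch, n_classes)   # stand-in for the augmented-branch outputs
ys = torch.randint(0, n_classes, (batch,))   # stand-in labels
w_loss = torch.rand(batch)                   # stand-in for fmodel['data_aug'].loss_weight()
aug_loss = torch.rand(batch)                 # stand-in per-sample consistency loss

# Weight each sample's loss, average, then add a weighted per-sample cross-entropy term.
aug_loss = (w_loss * aug_loss).mean()
aug_loss = aug_loss + (F.cross_entropy(aug_logits, ys, reduction='none') * w_loss).mean()

unsupp_coeff = 1
print(aug_loss * unsupp_coeff)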