""" Utilities function for training.
"""

import sys
import time

import torch
import torch.nn.functional as F
#import torch.optim
#import torchvision
import higher
import higher_patch

from datasets import *
from utils import *
from transformations import Normalizer, translate, zero_stack

norm = Normalizer(MEAN, STD)
confmat = ConfusionMatrix(num_classes=len(dl_test.dataset.classes))

max_grad = 1 #Max gradient value #Limits catastrophic drops

def test(model, augment=0):
    """Evaluate a model on test data.

        Args:
            model (nn.Module): Model to test.
            augment (int): Number of augmented examples for each sample. (default: 0)

        Returns:
            (float, Tensor) Accuracy and F1 score of the model.
    """
    device = next(model.parameters()).device
    model.eval()
    # model['model']['functional'].set_mode('mixed') #ABN

    #for i, (features, labels) in enumerate(dl_test):
    #    features, labels = features.to(device), labels.to(device)

    #    pred = model.forward(features)
    #    return pred.argmax(dim=1).eq(labels).sum().item() / dl_test.batch_size * 100

    correct = 0
    total = 0
    #loss = []
    global confmat
    confmat.reset()

    with torch.no_grad():
        for features, labels in dl_test:
            features, labels = features.to(device), labels.to(device)

            if augment > 0: #Test Time Augmentation
                model.augment(True)
                # V2
                features = torch.cat([features for _ in range(augment)], dim=0) # (B,C,H,W)=>(B*augment,C,H,W)
                outputs = model(features)
                outputs = torch.cat([o.unsqueeze(dim=0) for o in outputs.chunk(chunks=augment, dim=0)], dim=0) # (B*augment,nb_class)=>(augment,B,nb_class)

                w_losses = model['data_aug'].loss_weight(batch_norm=False) #(B*augment) if Dataug
                if w_losses.shape[0] == 1: #RandAugment
                    outputs = torch.sum(outputs, axis=0) / augment #Mean
                else: #Dataug
                    w_losses = torch.cat([w.unsqueeze(dim=0) for w in w_losses.chunk(chunks=augment, dim=0)], dim=0) #(augment, B)
                    w_losses = w_losses / w_losses.sum(axis=0, keepdim=True) #sum(w_losses)=1 for each sample
                    outputs = torch.sum(outputs * w_losses.unsqueeze(dim=2).expand_as(outputs), axis=0) / augment #Weighted mean
            else:
                outputs = model(features)

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            #loss.append(F.cross_entropy(outputs, labels).item())
            confmat.update(labels, predicted)

    accuracy = 100 * correct / total

    #print(confmat)
    #from sklearn.metrics import f1_score
    #f1 = f1_score(labels.data.to('cpu'), predicted.data.to('cpu'), average="macro")

    return accuracy, confmat.f1_metric(average=None)
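
# A minimal, self-contained sketch (not part of the training pipeline) of the
# test-time-augmentation averaging performed in test() above, on dummy tensors.
# Shapes follow the in-line comments there: (B*augment, nb_class) is split back
# into (augment, B, nb_class) and reduced over the augment axis. All names and
# sizes here are illustrative only.
def _demo_tta_averaging(B=2, augment=3, nb_class=4):
    """Reproduce the uniform and weighted TTA reductions on random tensors."""
    outputs = torch.randn(B * augment, nb_class)
    outputs = torch.stack(outputs.chunk(chunks=augment, dim=0), dim=0) # (augment, B, nb_class)

    uniform_mean = outputs.sum(dim=0) / augment # RandAugment branch: plain mean

    w = torch.rand(augment, B)
    w = w / w.sum(dim=0, keepdim=True) # weights sum to 1 for each sample
    weighted_mean = (outputs * w.unsqueeze(dim=2).expand_as(outputs)).sum(dim=0) / augment # Dataug branch
    return uniform_mean, weighted_mean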

def compute_vaLoss(model, dl_it, dl):
    """Evaluate a model on a batch of data.

        Args:
            model (nn.Module): Model to evaluate.
            dl_it (Iterator): Data loader iterator.
            dl (DataLoader): Data loader.

        Returns:
            (Tensor) Loss on a single batch of data.
    """
    device = next(model.parameters()).device
    try:
        xs, ys = next(dl_it)
    except StopIteration: #End of the validation epoch
        #Note: this rebinding is local, the caller's iterator is not replaced,
        #so a fresh iterator (hence a fresh batch) is built on each later call.
        dl_it = iter(dl)
        xs, ys = next(dl_it)
    xs, ys = xs.to(device), ys.to(device)

    model.eval() #Validation without transformations!
    # model['model']['functional'].set_mode('mixed') #ABN
    # return F.cross_entropy(F.log_softmax(model(xs), dim=1), ys)
    return F.cross_entropy(model(xs), ys)
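
# A minimal sketch (an assumption, not part of the repo's API) of an alternative
# pattern that avoids the local-rebinding pitfall noted in compute_vaLoss():
# wrap the DataLoader in a generator that restarts itself, so callers never have
# to pass and resynchronize an iterator. `dl` stands for any DataLoader.
def _infinite_batches(dl):
    """Yield batches forever, restarting the DataLoader when it is exhausted."""
    while True:
        for batch in dl:
            yield batch
# Usage sketch:
#   val_batches = _infinite_batches(dl_val)
#   xs, ys = next(val_batches)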

def mixed_loss(xs, ys, model, unsup_factor=1, augment=1):
    """Evaluate a model on a batch of data.

        Computes a combination of losses:
            + Supervised Cross-Entropy loss on the original data.
            + Unsupervised Cross-Entropy loss on the augmented data.
            + KL divergence loss encouraging similarity between the original and augmented predictions.

        If unsup_factor is equal to 0 or if there is no data augmentation, only the supervised loss is computed.

        Inspired by UDA, see: https://github.com/google-research/uda/blob/master/image/main.py

        Args:
            xs (Tensor): Batch of data.
            ys (Tensor): Batch of labels.
            model (nn.Module): Augmented model (see dataug.py).
            unsup_factor (float): Factor by which the unsupervised CE and KL div losses are multiplied.
            augment (int): Number of augmented examples for each sample. (default: 1)

        Returns:
            (Tensor) Mixed loss if there is data augmentation, supervised CE loss otherwise.
    """
    #TODO: add a test to prevent augmented-model errors and fall back to the classic loss
    if unsup_factor != 0 and model.is_augmenting() and augment > 0:

        # Supervised loss - Cross-entropy
        model.augment(mode=False)
        sup_logits = model(xs)
        model.augment(mode=True)

        log_sup = F.log_softmax(sup_logits, dim=1)
        sup_loss = F.nll_loss(log_sup, ys)

        if augment > 1:
            # Unsupervised loss - Cross-entropy
            xs_a = torch.cat([xs for _ in range(augment)], dim=0) # (B,C,H,W)=>(B*augment,C,H,W)
            ys_a = torch.cat([ys for _ in range(augment)], dim=0)
            aug_logits = model(xs_a) # (B*augment,nb_class)
            w_loss = model['data_aug'].loss_weight() #(B*augment) if Dataug

            log_aug = F.log_softmax(aug_logits, dim=1)
            aug_loss = F.nll_loss(log_aug, ys_a, reduction='none')
            aug_loss = (aug_loss * w_loss).mean()

            #KL divergence loss (w/ logits) - Prediction/Distribution similarity
            sup_logits_a = torch.cat([sup_logits for _ in range(augment)], dim=0)
            log_sup_a = torch.cat([log_sup for _ in range(augment)], dim=0)

            kl_loss = (F.softmax(sup_logits_a, dim=1) * (log_sup_a - log_aug)).sum(dim=-1)
            kl_loss = (w_loss * kl_loss).mean()
        else:
            # Unsupervised loss - Cross-entropy
            aug_logits = model(xs)
            w_loss = model['data_aug'].loss_weight() #Weight loss
            log_aug = F.log_softmax(aug_logits, dim=1)
            aug_loss = F.nll_loss(log_aug, ys, reduction='none')
            aug_loss = (aug_loss * w_loss).mean()

            #KL divergence loss (w/ logits) - Prediction/Distribution similarity
            kl_loss = (F.softmax(sup_logits, dim=1) * (log_sup - log_aug)).sum(dim=-1)
            kl_loss = (w_loss * kl_loss).mean()

        loss = sup_loss + unsup_factor * (aug_loss + kl_loss)

    else: #Supervised loss - Cross-entropy
        sup_logits = model(xs)
        loss = F.cross_entropy(sup_logits, ys)

    return loss
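
# A minimal, self-contained check (illustrative only) that the manual
# KL-divergence-from-logits term used in mixed_loss() matches F.kl_div:
# both compute sum_c p_sup(c) * (log p_sup(c) - log p_aug(c)) per sample.
def _demo_kl_with_logits(B=4, nb_class=10):
    """Compare the manual KL term of mixed_loss() against F.kl_div on random logits."""
    sup_logits = torch.randn(B, nb_class)
    aug_logits = torch.randn(B, nb_class)
    log_sup = F.log_softmax(sup_logits, dim=1)
    log_aug = F.log_softmax(aug_logits, dim=1)

    manual = (F.softmax(sup_logits, dim=1) * (log_sup - log_aug)).sum(dim=-1)
    reference = F.kl_div(log_aug, F.softmax(sup_logits, dim=1), reduction='none').sum(dim=-1)
    assert torch.allclose(manual, reference, atol=1e-5)
    return manual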

def train_classic(model, opt_param, epochs=1, print_freq=1):
    """Classic training of a model.

        Args:
            model (nn.Module): Model to train.
            opt_param (dict): Dictionary containing the optimizer parameters.
            epochs (int): Number of epochs to perform. (default: 1)
            print_freq (int): Number of epochs between displays of the training state. If set to None, no display will be done. (default: 1)

        Returns:
            (list) Logs of training. Each item is a dict containing the results of an epoch.
    """
    device = next(model.parameters()).device

    #Optimizer
    #opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    optim = torch.optim.SGD(model.parameters(),
        lr=opt_param['Inner']['lr'],
        momentum=opt_param['Inner']['momentum'],
        weight_decay=opt_param['Inner']['weight_decay'],
        nesterov=opt_param['Inner']['nesterov']) #lr=1e-2 / momentum=0.9

    #Scheduler
    inner_scheduler = None
    if opt_param['Inner']['scheduler'] == 'cosine':
        inner_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=epochs, eta_min=0.)
    elif opt_param['Inner']['scheduler'] == 'multiStep':
        #Multistep milestones inspired by AutoAugment
        inner_scheduler = torch.optim.lr_scheduler.MultiStepLR(optim,
            milestones=[int(epochs/3), int(epochs*2/3), int(epochs*2.7/3)],
            gamma=0.1)
    elif opt_param['Inner']['scheduler'] == 'exponential':
        #inner_scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.1) #Wrong gamma
        inner_scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lambda epoch: (1 - epoch / epochs)**0.9)
    elif opt_param['Inner']['scheduler'] is not None:
        raise ValueError("Unknown LR scheduler: %s" % opt_param['Inner']['scheduler'])

    # from warmup_scheduler import GradualWarmupScheduler
    # inner_scheduler = GradualWarmupScheduler(optim, multiplier=2, total_epoch=5, after_scheduler=inner_scheduler)

    #Training
    model.train()
    dl_val_it = iter(dl_val)
    log = []
    for epoch in range(epochs):
        #print_torch_mem("Start epoch")
        #print(optim.param_groups[0]['lr'])
        t0 = time.perf_counter()
        for i, (features, labels) in enumerate(dl_train):
2020-01-31 10:34:44 -05:00
#viz_sample_data(imgs=features, labels=labels, fig_name='../samples/data_sample_epoch{}_noTF'.format(epoch))
2019-11-13 11:45:05 -05:00
#print_torch_mem("Start iter")
features , labels = features . to ( device ) , labels . to ( device )
optim . zero_grad ( )
2019-12-06 14:13:28 -05:00
logits = model . forward ( features )
pred = F . log_softmax ( logits , dim = 1 )
2019-11-13 11:45:05 -05:00
loss = F . cross_entropy ( pred , labels )
loss . backward ( )
optim . step ( )

            # print_graph(loss, '../samples/torchvision_WRN') #to visualize the computational graph
            # sys.exit()

        if inner_scheduler is not None:
            inner_scheduler.step()
            # print(optim.param_groups[0]['lr'])

        #### Tests ####
        tf = time.perf_counter()

        val_loss = compute_vaLoss(model=model, dl_it=dl_val_it, dl=dl_val)
        accuracy, f1 = test(model)
        model.train()

        #### Log ####
        data = {
            "epoch": epoch,
            "train_loss": loss.item(),
            "val_loss": val_loss.item(),
            "acc": accuracy,
            "f1": f1.tolist(),
            "time": tf - t0,
            "param": None,
        }
        log.append(data)

        #### Print ####
        if print_freq and epoch % print_freq == 0:
            print('-' * 9)
            print('Epoch : %d/%d' % (epoch, epochs))
            print('Time : %.00f' % (tf - t0))
            print('Train loss :', loss.item(), '/ val loss', val_loss.item())
            print('Accuracy max:', max([x["acc"] for x in log]))
            print('F1 :', ["{0:0.4f}".format(i) for i in f1])

    return log
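
# A hedged usage sketch for train_classic(). The 'Inner' keys are inferred from
# the accesses above (lr, momentum, weight_decay, nesterov, scheduler); the
# values below are purely illustrative, and `model` plus the dataloaders are
# assumed to be defined elsewhere (datasets.py).
#
#   opt_param = {
#       'Inner': {
#           'lr': 1e-2,
#           'momentum': 0.9,
#           'weight_decay': 5e-4,
#           'nesterov': True,
#           'scheduler': 'cosine',
#       },
#   }
#   log = train_classic(model, opt_param, epochs=100, print_freq=10)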

def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start=0, unsup_loss=1, augment_loss=1, hp_opt=False, print_freq=1, save_sample_freq=None):
    """Training of an augmented model with higher.

        This function is intended to be used with an Augmented_model containing a Higher_model (see dataug.py).
        Ex : Augmented_model(Data_augV5(...), Higher_model(model))

        The training loss can either be computed directly from the augmented inputs (unsup_loss=0).
        However, it is recommended to use the mixed loss computation, which combines original and augmented inputs to compute the loss (unsup_loss>0).

        Args:
            model (nn.Module): Augmented model to train.
            opt_param (dict): Dictionary containing the optimizer parameters.
            epochs (int): Number of epochs to perform. (default: 1)
            inner_it (int): Number of inner iterations before a meta-step. 0 inner iterations means there's no meta-step. (default: 1)
            dataug_epoch_start (int): Epoch when to start data augmentation. (default: 0)
            unsup_loss (float): Proportion of the unsupervised loss added to the supervised loss. If set to 0, the loss is only computed on augmented inputs. (default: 1)
            augment_loss (int): Number of augmented examples for each sample in the loss computation. (default: 1)
            hp_opt (bool or list): Whether to learn the inner optimizer parameters; if used, an iterable of parameter-group keys (e.g. ['lr', 'momentum']) is expected. (default: False)
            print_freq (int): Number of epochs between displays of the training state. If set to None, no display will be done. (default: 1)
            save_sample_freq (int): Number of epochs between saves of data samples. If set to None, no sample will be saved. (default: None)

        Returns:
            (list, float) Logs of training (each item is a dict containing the results of an epoch) and the test accuracy with test-time augmentation.
    """
    device = next(model.parameters()).device
    log = []
    # kl_log={"prob":[], "mag":[]}
    dl_val_it = iter(dl_val)

    high_grad_track = True
    if inner_it == 0: #No HP optimization
        high_grad_track = False
    if dataug_epoch_start != 0: #Delayed data augmentation
        model.augment(mode=False)
        high_grad_track = False

    ## Optimizers ##
    #Inner Opt
    inner_opt = torch.optim.SGD(model['model']['original'].parameters(),
        lr=opt_param['Inner']['lr'],
        momentum=opt_param['Inner']['momentum'],
        weight_decay=opt_param['Inner']['weight_decay'],
        nesterov=opt_param['Inner']['nesterov']) #lr=1e-2 / momentum=0.9

    diffopt = model['model'].get_diffopt(
        inner_opt,
        grad_callback=(lambda grads: clip_norm(grads, max_norm=max_grad)),
        track_higher_grads=high_grad_track)

    #Scheduler
    inner_scheduler = None
    if opt_param['Inner']['scheduler'] == 'cosine':
        inner_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(inner_opt, T_max=epochs, eta_min=0.)
    elif opt_param['Inner']['scheduler'] == 'multiStep':
        #Multistep milestones inspired by AutoAugment
        inner_scheduler = torch.optim.lr_scheduler.MultiStepLR(inner_opt,
            milestones=[int(epochs/3), int(epochs*2/3), int(epochs*2.7/3)],
            gamma=0.1)
    elif opt_param['Inner']['scheduler'] == 'exponential':
        #inner_scheduler = torch.optim.lr_scheduler.ExponentialLR(inner_opt, gamma=0.1) #Wrong gamma
        inner_scheduler = torch.optim.lr_scheduler.LambdaLR(inner_opt, lambda epoch: (1 - epoch / epochs)**0.9)
    elif not (opt_param['Inner']['scheduler'] is None or opt_param['Inner']['scheduler'] == ''):
        raise ValueError("Unknown LR scheduler: %s" % opt_param['Inner']['scheduler'])

    #Warmup
    if opt_param['Inner']['warmup']['multiplier'] >= 1:
        from warmup_scheduler import GradualWarmupScheduler
        inner_scheduler = GradualWarmupScheduler(inner_opt,
            multiplier=opt_param['Inner']['warmup']['multiplier'],
            total_epoch=opt_param['Inner']['warmup']['epochs'],
            after_scheduler=inner_scheduler)

    #Meta Opt
    hyper_param = list(model['data_aug'].parameters())
    if hp_opt: #(deprecated) expects an iterable of parameter-group keys, e.g. ['lr', 'momentum']
        for param_group in diffopt.param_groups:
            # print(param_group)
            for param in hp_opt:
                param_group[param] = torch.tensor(param_group[param]).to(device).requires_grad_()
                hyper_param += [param_group[param]]
    meta_opt = torch.optim.Adam(hyper_param, lr=opt_param['Meta']['lr'])

    #Meta-Scheduler (deprecated)
    meta_scheduler = None
    if opt_param['Meta']['scheduler'] == 'multiStep':
        meta_scheduler = torch.optim.lr_scheduler.MultiStepLR(meta_opt,
            milestones=[int(epochs/3), int(epochs*2/3)], # int(epochs*2.7/3)],
            gamma=3.16)
    elif opt_param['Meta']['scheduler'] is not None:
        raise ValueError("Unknown LR scheduler: %s" % opt_param['Meta']['scheduler'])

    model.train()
    meta_opt.zero_grad()

    for epoch in range(1, epochs + 1):
        t0 = time.perf_counter()
        val_loss = None

        #Cross-Validation
        #dl_train, dl_val = cvs.next_split()
        #dl_val_it = iter(dl_val)

        for i, (xs, ys) in enumerate(dl_train):
            xs, ys = xs.to(device), ys.to(device)

            if unsup_loss == 0:
                #Uniform method
                logits = model(xs) # modified `params` can also be passed as a kwarg
                loss = F.nll_loss(F.log_softmax(logits, dim=1), ys, reduction='none') # no need to call loss.backward(); nll_loss on log-probabilities (cross_entropy here would apply log_softmax a second time)

                if model._data_augmentation: #Weight loss
                    w_loss = model['data_aug'].loss_weight() #.to(device)
                    loss = loss * w_loss
                loss = loss.mean()
            else:
                #Mixed method
                loss = mixed_loss(xs, ys, model, unsup_factor=unsup_loss, augment=augment_loss)

            # print_graph(loss, '../samples/pytorch_WRN') #to visualize the computational graph
            # sys.exit()

            # t = time.process_time()
            diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)
            # print(len(model['model']['functional']._fast_params), "step", time.process_time()-t)

            if high_grad_track and i > 0 and i % inner_it == 0 and epoch >= opt_param['Meta']['epoch_start']: #Perform a meta-step
                val_loss = compute_vaLoss(model=model, dl_it=dl_val_it, dl=dl_val) + model['data_aug'].reg_loss(opt_param['Meta']['reg_factor'])
                model.train()
                #print_graph(val_loss) #to visualize the computational graph

                val_loss.backward()
                torch.nn.utils.clip_grad_norm_(model['data_aug'].parameters(), max_norm=max_grad, norm_type=2) #Prevent exploding grads with RNN

                # print("Grad mix", model['data_aug']["temp"].grad)
                # prv_param = model['data_aug']._params
                meta_opt.step()
                # kl_log["prob"].append(F.kl_div(prv_param["prob"], model['data_aug']["prob"], reduction='batchmean').item())
                # kl_log["mag"].append(F.kl_div(prv_param["mag"], model['data_aug']["mag"], reduction='batchmean').item())

                #Adjust hyper-parameters
                model['data_aug'].adjust_param()
                if hp_opt:
                    for param_group in diffopt.param_groups:
                        for param in hp_opt:
                            param_group[param].data = param_group[param].data.clamp(min=1e-5)

                #Reset gradients
                diffopt.detach_()
                model['model'].detach_()
                meta_opt.zero_grad()

            elif not high_grad_track or epoch < opt_param['Meta']['epoch_start']:
                diffopt.detach_()
                model['model'].detach_()
                meta_opt.zero_grad()

        tf = time.perf_counter()

        #Schedulers
        if inner_scheduler is not None:
            inner_scheduler.step()
            #Transfer the inner_opt lr to diffopt
            for diff_param_group in diffopt.param_groups:
                for param_group in inner_opt.param_groups:
                    diff_param_group['lr'] = param_group['lr']
        if meta_scheduler is not None:
            meta_scheduler.step()

        # if epoch < epochs/3:
        #     model['data_aug']['temp'].data = torch.tensor(0.5, device=device)
        # elif epoch > epochs/3 and epoch < (epochs*2/3):
        #     model['data_aug']['temp'].data = torch.tensor(0.75, device=device)
        # elif epoch > (epochs*2/3):
        #     model['data_aug']['temp'].data = torch.tensor(1.0, device=device)
        # model['data_aug']['temp'].data = torch.tensor(1./3 + 2/3*(epoch/epochs), device=device)
        # print('Temp', model['data_aug']['temp'])

        if save_sample_freq and epoch % save_sample_freq == 0: #Save data samples
            try:
                viz_sample_data(imgs=xs, labels=ys, fig_name='../samples/data_sample_epoch{}_noTF'.format(epoch))
                model.train()
                viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='../samples/data_sample_epoch{}'.format(epoch), weight_labels=model['data_aug'].loss_weight())
                model.eval()
            except Exception:
                print("Couldn't save samples at epoch %d : %s" % (epoch, str(sys.exc_info()[1])))

        if val_loss is None: #Compute the validation loss for logs
            val_loss = compute_vaLoss(model=model, dl_it=dl_val_it, dl=dl_val)

        # Test model
        accuracy, f1 = test(model)
        model.train()

        #### Log ####
        param = [{'p': p.item(), 'm': model['data_aug']['mag'].item()} for p in model['data_aug']['prob']] if model['data_aug']._shared_mag else [{'p': p.item(), 'm': m.item()} for p, m in zip(model['data_aug']['prob'], model['data_aug']['mag'])]

        data = {
            "epoch": epoch,
            "train_loss": loss.item(),
            "val_loss": val_loss.item(),
            "acc": accuracy,
            "f1": f1.tolist(),
            "time": tf - t0,
            "param": param,
        }
        if not model['data_aug']._fixed_temp: data["temp"] = model['data_aug']['temp'].item()
        if hp_opt: data["opt_param"] = [{'lr': p_grp['lr'], 'momentum': p_grp['momentum']} for p_grp in diffopt.param_groups]
        log.append(data)
        #############

        #### Print ####
        if print_freq and epoch % print_freq == 0:
            print('-' * 9)
            print('Epoch : %d/%d' % (epoch, epochs))
            print('Time : %.00f' % (tf - t0))
            print('Train loss :', loss.item(), '/ val loss', val_loss.item())
            print('Accuracy max:', max([x["acc"] for x in log]))
            print('F1 :', ["{0:0.4f}".format(i) for i in f1])
            print('Data Augmentation : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start))

            if not model['data_aug']._fixed_prob: print('TF Proba :', ["{0:0.4f}".format(p) for p in model['data_aug']['prob']])
            #print('proba grad', model['data_aug']['prob'].grad)
            if not model['data_aug']._fixed_mag:
                if model['data_aug']._shared_mag:
                    print('TF Mag :', "{0:0.4f}".format(model['data_aug']['mag']))
                else:
                    print('TF Mag :', ["{0:0.4f}".format(m) for m in model['data_aug']['mag']])
            #print('Mag grad', model['data_aug']['mag'].grad)
            if not model['data_aug']._fixed_temp: print('Temp:', model['data_aug']['temp'].item())
            #print('Reg loss:', model['data_aug'].reg_loss().item())

            # if len(kl_log["prob"]) != 0:
            #     print("KL prob : mean %f, std %f, max %f, min %f" % (np.mean(kl_log["prob"]), np.std(kl_log["prob"]), max(kl_log["prob"]), min(kl_log["prob"])))
            #     print("KL mag : mean %f, std %f, max %f, min %f" % (np.mean(kl_log["mag"]), np.std(kl_log["mag"]), max(kl_log["mag"]), min(kl_log["mag"])))
            #     kl_log = {"prob": [], "mag": []}

            if hp_opt:
                for param_group in diffopt.param_groups:
                    print('Opt param - lr:', param_group['lr'].item(), '- momentum:', param_group['momentum'].item())
        #############

        #Delayed data augmentation
        if not model.is_augmenting() and epoch == dataug_epoch_start:
            print('Starting Data Augmentation...')
            dataug_epoch_start = epoch
            model.augment(mode=True)
            if inner_it != 0: #Rebuild diffopt if needed
                high_grad_track = True
                diffopt = model['model'].get_diffopt(
                    inner_opt,
                    grad_callback=(lambda grads: clip_norm(grads, max_norm=max_grad)),
                    track_higher_grads=high_grad_track)

    aug_acc, aug_f1 = test(model, augment=augment_loss)

    return log, aug_acc
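
# A hedged usage sketch for run_dist_dataugV3(). Data_augV5, Higher_model and
# Augmented_model come from dataug.py; the opt_param keys are inferred from the
# accesses in this function, and every value below is illustrative only.
#
#   aug_model = Augmented_model(Data_augV5(...), Higher_model(model))
#   opt_param = {
#       'Inner': {'lr': 1e-2, 'momentum': 0.9, 'weight_decay': 5e-4,
#                 'nesterov': True, 'scheduler': 'cosine',
#                 'warmup': {'multiplier': 0, 'epochs': 5}},
#       'Meta': {'lr': 1e-2, 'epoch_start': 2, 'reg_factor': 0.001,
#                'scheduler': None},
#   }
#   log, aug_acc = run_dist_dataugV3(aug_model, opt_param, epochs=100, inner_it=1)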

#OLD
# def run_simple_smartaug(model, opt_param, epochs=1, inner_it=1, print_freq=1, unsup_loss=1):
#     """Simple training of an augmented model with higher.

#         This function is intended to be used with Augmented_model containing a Higher_model (see dataug.py).
#         Ex : Augmented_model(Data_augV5(...), Higher_model(model))

#         Training loss can either be computed directly from augmented inputs (unsup_loss=0).
#         However, it is recommended to use the mixed loss computation, which combines original and augmented inputs to compute the loss (unsup_loss>0).

#         Does not support LR schedulers.

#         Args:
#             model (nn.Module): Augmented model to train.
#             opt_param (dict): Dictionary containing the optimizer parameters.
#             epochs (int): Number of epochs to perform. (default: 1)
#             inner_it (int): Number of inner iterations before a meta-step. 0 inner iterations means there's no meta-step. (default: 1)
#             print_freq (int): Number of epochs between displays of the training state. If set to None, no display will be done. (default: 1)
#             unsup_loss (float): Proportion of the unsupervised loss added to the supervised loss. If set to 0, the loss is only computed on augmented inputs. (default: 1)

#         Returns:
#             (dict) A dictionary containing the whole state of the trained network.
#     """
#     device = next(model.parameters()).device

#     ## Optimizers ##
#     hyper_param = list(model['data_aug'].parameters())
#     model.start_bilevel_opt(inner_it=inner_it, hp_list=hyper_param, opt_param=opt_param, dl_val=dl_val)

#     model.train()

#     for epoch in range(1, epochs+1):
#         t0 = time.process_time()

#         for i, (xs, ys) in enumerate(dl_train):
#             xs, ys = xs.to(device), ys.to(device)

#             #Mixed method
#             loss = mixed_loss(xs, ys, model, unsup_factor=unsup_loss)

#             model.step(loss) #(opt.zero_grad, loss.backward, opt.step) + automatic meta-optimisation

#         tf = time.process_time()

#         #### Print ####
#         if print_freq and epoch % print_freq == 0:
#             print('-'*9)
#             print('Epoch : %d/%d' % (epoch, epochs))
#             print('Time : %.00f' % (tf - t0))
#             print('Train loss :', loss.item(), '/ val loss', model.val_loss().item())
#             if not model['data_aug']._fixed_prob: print('TF Proba :', model['data_aug']['prob'].data)
#             if not model['data_aug']._fixed_mag: print('TF Mag :', model['data_aug']['mag'].data)
#             if not model['data_aug']._fixed_temp: print('Temp:', model['data_aug']['temp'].item())
#         #############
#     return model['model'].state_dict()