mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Changement permission fichiers + Simplification utilisation Augmented_dataset
This commit is contained in:
parent
adaac437b6
commit
b26fbcd2a2
619 changed files with 41 additions and 13049 deletions
13
higher/train_utils.py
Normal file → Executable file
13
higher/train_utils.py
Normal file → Executable file
|
@ -47,7 +47,7 @@ def compute_vaLoss(model, dl_it, dl):
|
|||
|
||||
return F.cross_entropy(model(xs), ys)
|
||||
|
||||
def train_classic(model, epochs=1):
|
||||
def train_classic(model, epochs=1, print_freq=1):
|
||||
device = next(model.parameters()).device
|
||||
#opt = torch.optim.Adam(model.parameters(), lr=1e-3)
|
||||
optim = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
|
||||
|
@ -80,6 +80,15 @@ def train_classic(model, epochs=1):
|
|||
val_loss = F.cross_entropy(model(xs_val), ys_val)
|
||||
accuracy, _ =test(model)
|
||||
model.train()
|
||||
|
||||
#### Print ####
|
||||
if(print_freq and epoch%print_freq==0):
|
||||
print('-'*9)
|
||||
print('Epoch : %d/%d'%(epoch,epochs))
|
||||
print('Time : %.00f'%(tf - t0))
|
||||
print('Train loss :',loss.item(), '/ val loss', val_loss.item())
|
||||
print('Accuracy :', accuracy)
|
||||
|
||||
#### Log ####
|
||||
data={
|
||||
"epoch": epoch,
|
||||
|
@ -619,7 +628,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
|
|||
if epoch>50:
|
||||
meta_opt.step()
|
||||
model['data_aug'].adjust_param(soft=False) #Contrainte sum(proba)=1
|
||||
#model['data_aug'].next_TF_set()
|
||||
#model['data_aug'].next_TF_set()
|
||||
|
||||
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
|
||||
diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue