Mirror of https://github.com/AntoineHX/smart_augmentation.git, synced 2025-05-04 20:20:46 +02:00
Update of all the changes... (Higher: added two TFs, modified the val loss, added the probability to the sample images, ...)
This commit is contained in:
parent e75fb96716
commit c8ce6c8024
6 changed files with 299 additions and 64 deletions
@@ -35,6 +35,8 @@ import augmentation_transforms
 import numpy as np

 download_data=False
+num_workers=0
+pin_memory=False

 class AugmentedDataset(VisionDataset):
     def __init__(self, root, train=True, transform=None, target_transform=None, download=False, subset=None):
@@ -281,14 +283,14 @@ train_subset_indices=range(int(len(data_train)/2))
 val_subset_indices=range(int(len(data_train)/2),len(data_train))
 #train_subset_indices=range(BATCH_SIZE*10)
 #val_subset_indices=range(BATCH_SIZE*10, BATCH_SIZE*20)
-dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices))
+dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices), num_workers=num_workers, pin_memory=pin_memory)

 ### Augmented Dataset ###
 #data_train_aug = AugmentedDataset("./data", train=True, download=download_data, transform=transform, subset=(0,int(len(data_train)/2)))
 #data_train_aug.augement_data(aug_copy=10)
 #print(data_train_aug)
-#dl_train = torch.utils.data.DataLoader(data_train_aug, batch_size=BATCH_SIZE, shuffle=True)
+#dl_train = torch.utils.data.DataLoader(data_train_aug, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)


-dl_val = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(val_subset_indices))
+dl_val = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(val_subset_indices), num_workers=num_workers, pin_memory=pin_memory)
-dl_test = torch.utils.data.DataLoader(data_test, batch_size=TEST_SIZE, shuffle=False)
+dl_test = torch.utils.data.DataLoader(data_test, batch_size=TEST_SIZE, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)
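Note: the two new module-level switches are threaded into every DataLoader above. For context, a minimal, self-contained sketch of the resulting loader setup, with a toy TensorDataset standing in for the real torchvision dataset and an arbitrary batch size (num_workers > 0 loads batches in background worker processes; pin_memory=True speeds up host-to-GPU copies when training on CUDA):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler

# Toy stand-in for data_train; the real script builds a torchvision dataset.
data_train = TensorDataset(torch.rand(1000, 3, 32, 32), torch.randint(0, 10, (1000,)))

BATCH_SIZE = 64        # hypothetical value, not the repository's setting
num_workers = 0        # >0 loads batches in background worker processes
pin_memory = False     # True speeds up host-to-GPU copies when training on CUDA

train_subset_indices = range(int(len(data_train) / 2))
val_subset_indices = range(int(len(data_train) / 2), len(data_train))

dl_train = DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False,
                      sampler=SubsetRandomSampler(train_subset_indices),
                      num_workers=num_workers, pin_memory=pin_memory)
dl_val = DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False,
                    sampler=SubsetRandomSampler(val_subset_indices),
                    num_workers=num_workers, pin_memory=pin_memory)
```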
Binary file not shown (before: 191 KiB).
@@ -6,21 +6,21 @@ from train_utils import *
 tf_names = [
     ## Geometric TF ##
     'Identity',
-    'FlipUD',
-    'FlipLR',
-    'Rotate',
-    'TranslateX',
-    'TranslateY',
-    'ShearX',
-    'ShearY',
+    #'FlipUD',
+    #'FlipLR',
+    #'Rotate',
+    #'TranslateX',
+    #'TranslateY',
+    #'ShearX',
+    #'ShearY',

     ## Color TF (Expect image in the range of [0, 1]) ##
-    'Contrast',
-    'Color',
-    'Brightness',
-    'Sharpness',
-    'Posterize',
-    'Solarize', #=>Image in [0, 1] #Not optimized for batches
+    #'Contrast',
+    #'Color',
+    #'Brightness',
+    #'Sharpness',
+    #'Posterize',
+    #'Solarize', #=>Image in [0, 1] #Not optimized for batches

     #Color TF (Common mag scale)
     #'+Contrast',
@@ -49,6 +49,8 @@ tf_names = [
     #'BadContrast',
     #'BadBrightness',

+    'Random',
+    #'RandBlend'
     #Not functional
     #'Auto_Contrast', #Not optimized for batches (very slow)
     #'Equalize',
@@ -65,12 +67,12 @@ else:
 if __name__ == "__main__":

     tasks={
-        'classic',
+        #'classic',
         #'aug_dataset',
-        #'aug_model'
+        'aug_model'
     }
     n_inner_iter = 1
-    epochs = 100
+    epochs = 1
     dataug_epoch_start=0
     optim_param={
         'Meta':{
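The optim_param dictionary is only partially visible in this hunk; from the way it is read in train_utils (opt_param['Meta']['lr'], opt_param['Inner']['lr'], opt_param['Inner']['momentum']), its shape is roughly as follows (the concrete values here are placeholders, not the repository's settings):

```python
# Shape of optim_param as consumed by the training routines (values are placeholders).
optim_param = {
    'Meta': {
        'lr': 1e-2,        # learning rate of the Adam meta-optimizer (augmentation parameters)
    },
    'Inner': {
        'lr': 1e-2,        # learning rate of the SGD inner optimizer (model weights)
        'momentum': 0.9,   # momentum of the inner optimizer
    },
}
```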
@@ -84,9 +86,9 @@ if __name__ == "__main__":
         }
     }

-    #model = LeNet(3,10)
+    model = LeNet(3,10)
     #model = MobileNetV2(num_classes=10)
-    model = ResNet(num_classes=10)
+    #model = ResNet(num_classes=10)
     #model = WideResNet(num_classes=10, wrn_size=32)

     #### Classic ####
@@ -95,8 +97,8 @@ if __name__ == "__main__":
         model = model.to(device)

         print("{} on {} for {} epochs".format(str(model), device_name, epochs))
-        log= train_classic(model=model, opt_param=optim_param, epochs=epochs, print_freq=1)
-        #log= train_classic_higher(model=model, epochs=epochs)
+        #log= train_classic(model=model, opt_param=optim_param, epochs=epochs, print_freq=10)
+        log= train_classic_higher(model=model, epochs=epochs)

         exec_time=time.process_time() - t0
         ####
@@ -138,7 +140,7 @@ if __name__ == "__main__":
         data_train_aug.augement_data(aug_copy=1)
         print(data_train_aug)
         unsup_ratio = 5
-        dl_unsup = torch.utils.data.DataLoader(data_train_aug, batch_size=BATCH_SIZE*unsup_ratio, shuffle=True)
+        dl_unsup = torch.utils.data.DataLoader(data_train_aug, batch_size=BATCH_SIZE*unsup_ratio, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)

         unsup_xs, sup_xs, ys = next(iter(dl_unsup))
         viz_sample_data(imgs=sup_xs, labels=ys, fig_name='samples/data_sample_{}'.format(str(data_train_aug)))
@@ -172,7 +174,7 @@ if __name__ == "__main__":

         tf_dict = {k: TF.TF_dict[k] for k in tf_names}
         #aug_model = Augmented_model(Data_augV6(TF_dict=tf_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, prob_set_size=2, fixed_mag=True, shared_mag=True), model).to(device)
-        aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), model).to(device)
+        aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=1, mix_dist=0.5, fixed_prob=False, fixed_mag=True, shared_mag=True), model).to(device)
         #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)

         print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter))
@@ -181,7 +183,7 @@ if __name__ == "__main__":
             inner_it=n_inner_iter,
             dataug_epoch_start=dataug_epoch_start,
             opt_param=optim_param,
-            print_freq=10,
+            print_freq=1,
             KLdiv=True,
             loss_patience=None)

@@ -44,8 +44,7 @@ def compute_vaLoss(model, dl_it, dl):
     xs, ys = xs.to(device), ys.to(device)

     model.eval() #Validation without transformations!
-    return F.cross_entropy(model(xs), ys)
+    return F.cross_entropy(F.log_softmax(model(xs), dim=1), ys)

 def train_classic(model, opt_param, epochs=1, print_freq=1):
     device = next(model.parameters()).device
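The functional change in this hunk is the validation loss: the new line feeds log-softmax outputs into F.cross_entropy, matching how the training losses in this file are written (F.cross_entropy(log_sup, ys) and F.cross_entropy(F.log_softmax(logits, dim=1), ys, ...)). Since F.cross_entropy already applies log_softmax internally, this is a double log-softmax rather than plain cross-entropy on raw logits; a small sketch on hypothetical tensors:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)              # hypothetical raw model outputs
ys = torch.randint(0, 10, (4,))

# Old line: standard cross-entropy on raw logits
# (F.cross_entropy = log_softmax followed by nll_loss).
loss_old = F.cross_entropy(logits, ys)

# New line: log_softmax is applied explicitly and then again inside F.cross_entropy,
# i.e. nll_loss(log_softmax(log_softmax(logits))) rather than plain cross-entropy.
loss_new = F.cross_entropy(F.log_softmax(logits, dim=1), ys)
```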
@@ -688,25 +687,30 @@ def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start

     else:
         #KL div method
+        if fmodel._data_augmentation :
             fmodel.augment(mode=False)
             sup_logits = fmodel(xs)
-            log_sup=F.log_softmax(sup_logits, dim=1)
             fmodel.augment(mode=True)
+        else:
+            sup_logits = fmodel(xs)
+        log_sup=F.log_softmax(sup_logits, dim=1)
         loss = F.cross_entropy(log_sup, ys)

         if fmodel._data_augmentation:
             aug_logits = fmodel(xs)
             log_aug=F.log_softmax(aug_logits, dim=1)
-            #KL div w/ logits
-            aug_loss = F.softmax(sup_logits, dim=1)*(log_sup-log_aug)
-            aug_loss=aug_loss.sum(dim=-1)
-            #aug_loss = F.kl_div(aug_logits, sup_logits, reduction='none') #Similarity of predictions (distributions)

             w_loss = fmodel['data_aug'].loss_weight() #Weight loss
-            aug_loss = (w_loss * aug_loss).mean() #deferred learning?
+            #if epoch>50: #deferred start?
+            #KL div w/ logits - Similarity of predictions (distributions)
+            aug_loss = F.softmax(sup_logits, dim=1)*(log_sup-log_aug)
+            aug_loss = aug_loss.sum(dim=-1)
+            #aug_loss = F.kl_div(aug_logits, sup_logits, reduction='none')
+            aug_loss = (w_loss * aug_loss).mean()

             aug_loss += (F.cross_entropy(log_aug, ys , reduction='none') * w_loss).mean()
-            #print(aug_loss)
             unsupp_coeff = 1
             loss += aug_loss * unsupp_coeff

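Rewritten this way, the KLdiv branch computes a supervised cross-entropy on the un-augmented logits plus a weighted consistency term between un-augmented and augmented predictions: softmax(sup) * (log_sup - log_aug) summed over classes is the per-sample KL divergence KL(p_sup || p_aug). A condensed, self-contained sketch of that loss (w_loss is assumed to be a per-sample weight vector from the augmentation module; the surrounding augment(mode=...) toggling is omitted):

```python
import torch
import torch.nn.functional as F

def kl_consistency_loss(sup_logits, aug_logits, ys, w_loss, unsupp_coeff=1):
    """Sketch of the KLdiv branch after this commit.
    sup_logits / aug_logits: (N, C) logits without / with augmentation,
    w_loss: per-sample weights from the augmentation module, shape (N,)."""
    log_sup = F.log_softmax(sup_logits, dim=1)
    log_aug = F.log_softmax(aug_logits, dim=1)

    loss = F.cross_entropy(log_sup, ys)                                    # supervised term

    kl = (F.softmax(sup_logits, dim=1) * (log_sup - log_aug)).sum(dim=-1)  # KL(p_sup || p_aug) per sample
    aug_loss = (w_loss * kl).mean()
    aug_loss += (F.cross_entropy(log_aug, ys, reduction='none') * w_loss).mean()

    return loss + aug_loss * unsupp_coeff

sup, aug = torch.randn(8, 10), torch.randn(8, 10)
print(kl_consistency_loss(sup, aug, torch.randint(0, 10, (8,)), torch.rand(8)))
```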
@@ -717,20 +721,28 @@ def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
     #print(fmodel['model']._params['b4'].grad)
     #print('prob grad', fmodel['data_aug']['prob'].grad)

+    #t = time.process_time()
     diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)
+    #print(len(fmodel._fast_params),"step", time.process_time()-t)

-    if(high_grad_track and i%inner_it==0): #Perform Meta step
+    if(high_grad_track and i>0 and i%inner_it==0): #Perform Meta step
         #print("meta")
-        #Of little use if high_grad_track = False
         val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val) + fmodel['data_aug'].reg_loss()
         #print_graph(val_loss)

+        #t = time.process_time()
         val_loss.backward()
+        #print("meta", time.process_time()-t)
+        #print('proba grad',model['data_aug']['prob'].grad)

         countcopy+=1
         model_copy(src=fmodel, dst=model)
         optim_copy(dopt=diffopt, opt=inner_opt)

+        torch.nn.utils.clip_grad_norm_(model['data_aug']['prob'], max_norm=10, norm_type=2) #Prevent exploding grad with RNN
+        torch.nn.utils.clip_grad_norm_(model['data_aug']['mag'], max_norm=10, norm_type=2) #Prevent exploding grad with RNN

         #if epoch>50:
         meta_opt.step()
         model['data_aug'].adjust_param(soft=False) #Constraint: sum(proba)=1
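Around this hunk, training follows the usual higher bilevel pattern: differentiable inner steps on the training loss, then every inner_it steps a meta step that backpropagates the validation loss into the augmentation parameters, with the newly added gradient clipping keeping those gradients bounded. A toy, self-contained sketch of that pattern using higher.innerloop_ctx (the repository instead keeps a single monkeypatched model and copies, clips and adjusts as above; meta_param here merely stands in for the prob/mag parameters):

```python
import torch
import torch.nn.functional as F
import higher

# Toy setup: meta_param stands in for the augmentation prob/mag parameters.
model = torch.nn.Linear(10, 2)
meta_param = torch.zeros(1, requires_grad=True)
inner_opt = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
meta_opt = torch.optim.Adam([meta_param], lr=1e-2)

xs, ys = torch.randn(8, 10), torch.randint(0, 2, (8,))
x_val, y_val = torch.randn(8, 10), torch.randint(0, 2, (8,))
inner_it = 2

with higher.innerloop_ctx(model, inner_opt, track_higher_grads=True) as (fmodel, diffopt):
    for _ in range(inner_it):
        # The inner loss depends on meta_param, so its influence is recorded in the unrolled steps.
        inner_loss = F.cross_entropy(fmodel(xs) * torch.exp(meta_param), ys)
        diffopt.step(inner_loss)                      # differentiable (zero_grad, backward, step)

    val_loss = F.cross_entropy(fmodel(x_val), y_val)  # meta objective on held-out data
    meta_opt.zero_grad()
    val_loss.backward()                               # backprops through the unrolled inner steps
    torch.nn.utils.clip_grad_norm_([meta_param], max_norm=10, norm_type=2)
    meta_opt.step()

print(meta_param.grad)   # non-zero: the inner updates depended on the meta parameter
```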
@@ -757,21 +769,6 @@ def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
         accuracy, test_loss =test(model)
         model.train()

-        #### Print ####
-        if(print_freq and epoch%print_freq==0):
-            print('-'*9)
-            print('Epoch : %d/%d'%(epoch,epochs))
-            print('Time : %.00f'%(tf - t0))
-            print('Train loss :',loss.item(), '/ val loss', val_loss.item())
-            print('Accuracy :', accuracy)
-            print('Data Augmention : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start))
-            print('TF Proba :', model['data_aug']['prob'].data)
-            #print('proba grad',model['data_aug']['prob'].grad)
-            print('TF Mag :', model['data_aug']['mag'].data)
-            #print('Mag grad',model['data_aug']['mag'].grad)
-            #print('Reg loss:', model['data_aug'].reg_loss().item())
-            #print('Aug loss', aug_loss.item())
-        #############
         #### Log ####
         #print(type(model['data_aug']) is dataug.Data_augV5)
         param = [{'p': p.item(), 'm':model['data_aug']['mag'].item()} for p in model['data_aug']['prob']] if model['data_aug']._shared_mag else [{'p': p.item(), 'm': m.item()} for p, m in zip(model['data_aug']['prob'], model['data_aug']['mag'])]
@@ -787,6 +784,235 @@ def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
         }
         log.append(data)
         #############
+        #### Print ####
+        if(print_freq and epoch%print_freq==0):
+            print('-'*9)
+            print('Epoch : %d/%d'%(epoch,epochs))
+            print('Time : %.00f'%(tf - t0))
+            print('Train loss :',loss.item(), '/ val loss', val_loss.item())
+            print('Accuracy :', max([x["acc"] for x in log]))
+            print('Data Augmention : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start))
+            print('TF Proba :', model['data_aug']['prob'].data)
+            #print('proba grad',model['data_aug']['prob'].grad)
+            print('TF Mag :', model['data_aug']['mag'].data)
+            #print('Mag grad',model['data_aug']['mag'].grad)
+            #print('Reg loss:', model['data_aug'].reg_loss().item())
+            #print('Aug loss', aug_loss.item())
+        #############
+        if val_loss_monitor :
+            model.eval()
+            val_loss_monitor.register(test_loss)#val_loss.item())
+            if val_loss_monitor.end_training(): break #Stop training
+            model.train()
+
+        if not model.is_augmenting() and (epoch == dataug_epoch_start or (val_loss_monitor and val_loss_monitor.limit_reached()==1)):
+            print('Starting Data Augmention...')
+            dataug_epoch_start = epoch
+            model.augment(mode=True)
+            if inner_it != 0: high_grad_track = True
+
+    try:
+        viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
+        viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch), weight_labels=model['data_aug'].loss_weight())
+    except:
+        print("Couldn't save finals samples")
+        pass
+
+    #print("Copy ", countcopy)
+    return log
+
+def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start=0, print_freq=1, KLdiv=False, loss_patience=None, save_sample=False):
+    device = next(model.parameters()).device
+    log = []
+    countcopy=0
+    val_loss=torch.tensor(0) #Needed if there is no meta step during an epoch
+    dl_val_it = iter(dl_val)
+
+    #if inner_it!=0:
+    meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=opt_param['Meta']['lr']) #lr=1e-2
+    inner_opt = torch.optim.SGD(model['model'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
+
+    high_grad_track = True
+    if inner_it == 0:
+        high_grad_track=False
+    if dataug_epoch_start!=0:
+        model.augment(mode=False)
+        high_grad_track = False
+
+    val_loss_monitor= None
+    if loss_patience != None :
+        if dataug_epoch_start==-1: val_loss_monitor = loss_monitor(patience=loss_patience, end_train=2) #1st limit = dataug start
+        else: val_loss_monitor = loss_monitor(patience=loss_patience) #Val loss monitor (Not on val data : used by Dataug... => Test data)
+
+    model.train()
+
+    #fmodel = higher.patch.monkeypatch(model['model'], device=None, copy_initial_weights=True)
+    #diffopt = higher.optim.get_diff_optim(inner_opt, model['model'].parameters(),fmodel=fmodel,track_higher_grads=high_grad_track)
+    fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
+    diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel,track_higher_grads=high_grad_track)
+
+    #meta_opt = torch.optim.Adam(fmodel['data_aug'].parameters(), lr=opt_param['Meta']['lr']) #lr=1e-2
+
+    print(len(fmodel._fast_params))
+
+    for epoch in range(1, epochs+1):
+        #print_torch_mem("Start epoch "+str(epoch))
+        #print(high_grad_track, fmodel._data_augmentation, len(fmodel._fast_params))
+        t0 = time.process_time()
+        #with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=True, override=opt_param, track_higher_grads=high_grad_track) as (fmodel, diffopt):
+
+        for i, (xs, ys) in enumerate(dl_train):
+            xs, ys = xs.to(device), ys.to(device)
+
+            if(not KLdiv):
+                #Uniform method
+                logits = fmodel(xs) # modified `params` can also be passed as a kwarg
+                loss = F.cross_entropy(F.log_softmax(logits, dim=1), ys, reduction='none') # no need to call loss.backwards()
+
+                if fmodel._data_augmentation: #Weight loss
+                    w_loss = fmodel['data_aug'].loss_weight()#.to(device)
+                    loss = loss * w_loss
+                loss = loss.mean()
+
+            else:
+                #KL div method
+                if fmodel._data_augmentation :
+                    fmodel.augment(mode=False)
+                    sup_logits = fmodel(xs)
+                    fmodel.augment(mode=True)
+                else:
+                    sup_logits = fmodel(xs)
+                log_sup=F.log_softmax(sup_logits, dim=1)
+                loss = F.cross_entropy(log_sup, ys)
+
+                if fmodel._data_augmentation:
+                    aug_logits = fmodel(xs)
+                    log_aug=F.log_softmax(aug_logits, dim=1)
+                    aug_loss=0
+                    w_loss = fmodel['data_aug'].loss_weight() #Weight loss
+
+                    #if epoch>50: #deferred start?
+                    #KL div w/ logits - Similarity of predictions (distributions)
+                    aug_loss = F.softmax(sup_logits, dim=1)*(log_sup-log_aug)
+                    aug_loss = aug_loss.sum(dim=-1)
+                    #aug_loss = F.kl_div(aug_logits, sup_logits, reduction='none')
+                    aug_loss = (w_loss * aug_loss).mean()
+
+                    aug_loss += (F.cross_entropy(log_aug, ys , reduction='none') * w_loss).mean()
+
+                    unsupp_coeff = 1
+                    loss += aug_loss * unsupp_coeff
+
+            #to visualize computational graph
+            #print_graph(loss)
+
+            #loss.backward(retain_graph=True)
+            #print(fmodel['model']._params['b4'].grad)
+            #print('prob grad', fmodel['data_aug']['prob'].grad)
+
+            #for _, p in fmodel['data_aug'].named_parameters():
+            #    p.requires_grad = False
+            t = time.process_time()
+            diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)
+            print(len(fmodel._fast_params),"step", time.process_time()-t)
+
+            #for _, p in fmodel['data_aug'].named_parameters():
+            #    p.requires_grad = True
+
+            if(high_grad_track and i>0 and i%inner_it==0): #Perform Meta step
+                #print("meta")
+
+                val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val) + fmodel['data_aug'].reg_loss()
+                #print_graph(val_loss)
+
+                val_loss.backward()
+
+                print('proba grad',fmodel['data_aug']['prob'].grad)
+                #countcopy+=1
+                #model_copy(src=fmodel, dst=model)
+                #optim_copy(dopt=diffopt, opt=inner_opt)
+
+                torch.nn.utils.clip_grad_norm_(fmodel['data_aug']['prob'], max_norm=10, norm_type=2) #Prevent exploding grad with RNN
+                torch.nn.utils.clip_grad_norm_(fmodel['data_aug']['mag'], max_norm=10, norm_type=2) #Prevent exploding grad with RNN
+
+                for paramName, paramValue, in fmodel['data_aug'].named_parameters():
+                    for netCopyName, netCopyValue, in model['data_aug'].named_parameters():
+                        if paramName == netCopyName:
+                            netCopyValue.grad = paramValue.grad
+
+                #del meta_opt.param_groups[0]
+                #meta_opt.add_param_group({'params' : [p for p in fmodel['data_aug'].parameters()]})
+
+                meta_opt.step()
+                fmodel['data_aug'].load_state_dict(model['data_aug'].state_dict())
+                fmodel['data_aug'].adjust_param(soft=False) #Constraint: sum(proba)=1
+                #model['data_aug'].next_TF_set()
+
+
+                #fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
+                #diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
+
+
+                #fmodel.fast_params=[higher.utils._copy_tensor(t,safe_copy=True) if isinstance(t, torch.Tensor) else t for t in fmodel.parameters()]
+                diffopt.detach_()
+                tmp = fmodel.fast_params
+                fmodel._fast_params=[]
+                fmodel.update_params(tmp)
+                for p in fmodel.fast_params:
+                    p.detach_().requires_grad_()
+                print(len(fmodel._fast_params))
+
+        print('TF Proba :', fmodel['data_aug']['prob'].data)
+        tf = time.process_time()
+
+        #viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
+        #viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
+
+        #model_copy(src=fmodel, dst=model)
+
+        if(not high_grad_track):
+            #countcopy+=1
+            #model_copy(src=fmodel, dst=model)
+            optim_copy(dopt=diffopt, opt=inner_opt)
+            val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val)
+
+            #Needed to reset higher (fast_params accumulate even with track_higher_grads = False)
+            fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
+            diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
+
+        accuracy, test_loss =test(model)
+        model.train()
+
+        #### Log ####
+        #print(type(model['data_aug']) is dataug.Data_augV5)
+        param = [{'p': p.item(), 'm':model['data_aug']['mag'].item()} for p in model['data_aug']['prob']] if model['data_aug']._shared_mag else [{'p': p.item(), 'm': m.item()} for p, m in zip(model['data_aug']['prob'], model['data_aug']['mag'])]
+        data={
+            "epoch": epoch,
+            "train_loss": loss.item(),
+            "val_loss": val_loss.item(),
+            "acc": accuracy,
+            "time": tf - t0,
+
+            "param": param #if isinstance(model['data_aug'], Data_augV5)
+            #else [p.item() for p in model['data_aug']['prob']],
+        }
+        log.append(data)
+        #############
+        #### Print ####
+        if(print_freq and epoch%print_freq==0):
+            print('-'*9)
+            print('Epoch : %d/%d'%(epoch,epochs))
+            print('Time : %.00f'%(tf - t0))
+            print('Train loss :',loss.item(), '/ val loss', val_loss.item())
+            print('Accuracy :', max([x["acc"] for x in log]))
+            print('Data Augmention : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start))
+            print('TF Proba :', model['data_aug']['prob'].data)
+            #print('proba grad',model['data_aug']['prob'].grad)
+            print('TF Mag :', model['data_aug']['mag'].data)
+            #print('Mag grad',model['data_aug']['mag'].grad)
+            #print('Reg loss:', model['data_aug'].reg_loss().item())
+            #print('Aug loss', aug_loss.item())
+        #############
         if val_loss_monitor :
             model.eval()
             val_loss_monitor.register(test_loss)#val_loss.item())
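run_dist_dataugV3 keeps one monkeypatched model for the whole run, so after each meta step it truncates the unrolled graph: the differentiable optimizer is detached and the accumulated fast weights are replaced by detached copies. A condensed sketch of that reset, reusing the calls from the diff (fast_params, _fast_params, update_params and detach_ are higher internals, so this is tied to the higher version the repository targets; the Linear model and loss are only illustrative):

```python
import torch
import torch.nn.functional as F
import higher

model = torch.nn.Linear(10, 2)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)

fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
diffopt = higher.optim.get_diff_optim(opt, model.parameters(), fmodel=fmodel,
                                      track_higher_grads=True)

xs, ys = torch.randn(8, 10), torch.randint(0, 2, (8,))
for _ in range(3):
    diffopt.step(F.cross_entropy(fmodel(xs), ys))   # each step appends a new set of fast params

# End-of-step reset, mirroring the diff: keep the current values, drop their history.
diffopt.detach_()
detached = [p.detach().requires_grad_() for p in fmodel.fast_params]
fmodel._fast_params = []
fmodel.update_params(detached)
```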
@@ -96,10 +96,13 @@ TF_dict={ #Dataugv5
     'BadTranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=20*2, maxval=20*3), zero_pos=1))),
     'BadTranslateY_neg': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=-20*3, maxval=-20*2), zero_pos=1))),

-    'BadColor':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*2))),
-    'BadSharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*2))),
-    'BadContrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*2))),
-    'BadBrightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*2))),
+    'BadColor':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*3))),
+    'BadSharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*3))),
+    'BadContrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*3))),
+    'BadBrightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*3))),

+    'Random':(lambda x, mag: torch.rand_like(x)),
+    'RandBlend': (lambda x, mag: blend(x,torch.rand_like(x), alpha=torch.tensor(0.5,device=mag.device).expand(x.shape[0]))),

     #Not functional
     #'Auto_Contrast': (lambda mag: None), #Not optimized for batches (very slow)
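Both new entries follow the convention used throughout TF_dict: each value is a callable taking a batch x and a magnitude tensor mag and returning a batch of the same shape ('Random' ignores mag entirely, 'RandBlend' mixes the batch with noise at alpha=0.5 via the repository's blend helper). A tiny sketch exercising such an entry on a hypothetical batch:

```python
import torch

# TF_dict convention: name -> lambda(x, mag) returning a batch with x's shape.
random_tf = lambda x, mag: torch.rand_like(x)     # the 'Random' entry added here

x = torch.rand(8, 3, 32, 32)                      # hypothetical image batch in [0, 1]
mag = torch.tensor(0.5)                           # magnitude tensor; ignored by 'Random'

out = random_tf(x, mag)
assert out.shape == x.shape                       # same shape, content replaced by noise
```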
@@ -22,7 +22,7 @@ class timer():

 def print_graph(PyTorch_obj, fig_name='graph'):
     graph=make_dot(PyTorch_obj) #Loss give the whole graph
-    graph.format = 'svg' #https://graphviz.readthedocs.io/en/stable/manual.html#formats
+    graph.format = 'pdf' #https://graphviz.readthedocs.io/en/stable/manual.html#formats
     graph.render(fig_name)

 def plot_res(log, fig_name='res', param_names=None):
@@ -183,7 +183,7 @@ def plot_TF_res(log, tf_names, fig_name='res'):
     plt.savefig(fig_name, bbox_inches='tight')
     plt.close()

-def viz_sample_data(imgs, labels, fig_name='data_sample'):
+def viz_sample_data(imgs, labels, fig_name='data_sample', weight_labels=None):

     sample = imgs[0:25,].permute(0, 2, 3, 1).squeeze().cpu()

@@ -194,7 +194,9 @@ def viz_sample_data(imgs, labels, fig_name='data_sample'):
         plt.yticks([])
         plt.grid(False)
         plt.imshow(sample[i,].detach().numpy(), cmap=plt.cm.binary)
-        plt.xlabel(labels[i].item())
+        label = str(labels[i].item())
+        if weight_labels is not None : label+= ("- p %.2f" % weight_labels[i].item())
+        plt.xlabel(label)

     plt.savefig(fig_name)
     print("Sample saved :", fig_name)
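With the new weight_labels argument, the per-sample weights produced by the augmentation module can be printed under each image ("added the probability to the sample images" in the commit message). A usage sketch with stand-in tensors (the import path is an assumption; adjust it to the module where viz_sample_data lives):

```python
import torch
from utils import viz_sample_data   # assumed module name

imgs = torch.rand(25, 3, 32, 32)                 # hypothetical batch of images
labels = torch.randint(0, 10, (25,))
weights = torch.rand(25)                         # e.g. model['data_aug'].loss_weight()

# With weight_labels=None the behaviour is unchanged; with a tensor, each subplot's
# xlabel becomes "<label>- p <weight>".
viz_sample_data(imgs=imgs, labels=labels, fig_name='samples/demo_sample', weight_labels=weights)
```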