diff --git a/higher/datasets.py b/higher/datasets.py index 09fa1ee..f6fe438 100755 --- a/higher/datasets.py +++ b/higher/datasets.py @@ -35,6 +35,8 @@ import augmentation_transforms import numpy as np download_data=False +num_workers=0 +pin_memory=False class AugmentedDataset(VisionDataset): def __init__(self, root, train=True, transform=None, target_transform=None, download=False, subset=None): @@ -281,14 +283,14 @@ train_subset_indices=range(int(len(data_train)/2)) val_subset_indices=range(int(len(data_train)/2),len(data_train)) #train_subset_indices=range(BATCH_SIZE*10) #val_subset_indices=range(BATCH_SIZE*10, BATCH_SIZE*20) -dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices)) +dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices), num_workers=num_workers, pin_memory=pin_memory) ### Augmented Dataset ### #data_train_aug = AugmentedDataset("./data", train=True, download=download_data, transform=transform, subset=(0,int(len(data_train)/2))) #data_train_aug.augement_data(aug_copy=10) #print(data_train_aug) -#dl_train = torch.utils.data.DataLoader(data_train_aug, batch_size=BATCH_SIZE, shuffle=True) +#dl_train = torch.utils.data.DataLoader(data_train_aug, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=pin_memory) -dl_val = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(val_subset_indices)) -dl_test = torch.utils.data.DataLoader(data_test, batch_size=TEST_SIZE, shuffle=False) +dl_val = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(val_subset_indices), num_workers=num_workers, pin_memory=pin_memory) +dl_test = torch.utils.data.DataLoader(data_test, batch_size=TEST_SIZE, shuffle=False, num_workers=num_workers, pin_memory=pin_memory) diff --git a/higher/res/Aug_mod(Data_augV6(Uniform-18TF(2)x1-MagFxSh)-LeNet)-10 epochs (dataug:0)- 10 in_it.png b/higher/res/Aug_mod(Data_augV6(Uniform-18TF(2)x1-MagFxSh)-LeNet)-10 epochs (dataug:0)- 10 in_it.png deleted file mode 100755 index 34c4571..0000000 Binary files a/higher/res/Aug_mod(Data_augV6(Uniform-18TF(2)x1-MagFxSh)-LeNet)-10 epochs (dataug:0)- 10 in_it.png and /dev/null differ diff --git a/higher/test_dataug.py b/higher/test_dataug.py index 2c51923..7f60822 100755 --- a/higher/test_dataug.py +++ b/higher/test_dataug.py @@ -6,21 +6,21 @@ from train_utils import * tf_names = [ ## Geometric TF ## 'Identity', - 'FlipUD', - 'FlipLR', - 'Rotate', - 'TranslateX', - 'TranslateY', - 'ShearX', - 'ShearY', + #'FlipUD', + #'FlipLR', + #'Rotate', + #'TranslateX', + #'TranslateY', + #'ShearX', + #'ShearY', ## Color TF (Expect image in the range of [0, 1]) ## - 'Contrast', - 'Color', - 'Brightness', - 'Sharpness', - 'Posterize', - 'Solarize', #=>Image entre [0,1] #Pas opti pour des batch + #'Contrast', + #'Color', + #'Brightness', + #'Sharpness', + #'Posterize', + #'Solarize', #=>Image entre [0,1] #Pas opti pour des batch #Color TF (Common mag scale) #'+Contrast', @@ -49,6 +49,8 @@ tf_names = [ #'BadContrast', #'BadBrightness', + 'Random', + #'RandBlend' #Non fonctionnel #'Auto_Contrast', #Pas opti pour des batch (Super lent) #'Equalize', @@ -65,12 +67,12 @@ else: if __name__ == "__main__": tasks={ - 'classic', + #'classic', #'aug_dataset', - #'aug_model' + 'aug_model' } n_inner_iter = 1 - epochs = 100 + epochs = 1 dataug_epoch_start=0 optim_param={ 'Meta':{ @@ -84,9 +86,9 @@ if __name__ == "__main__": } } - #model = LeNet(3,10) + model = LeNet(3,10) #model = MobileNetV2(num_classes=10) - model = ResNet(num_classes=10) + #model = ResNet(num_classes=10) #model = WideResNet(num_classes=10, wrn_size=32) #### Classic #### @@ -95,8 +97,8 @@ if __name__ == "__main__": model = model.to(device) print("{} on {} for {} epochs".format(str(model), device_name, epochs)) - log= train_classic(model=model, opt_param=optim_param, epochs=epochs, print_freq=1) - #log= train_classic_higher(model=model, epochs=epochs) + #log= train_classic(model=model, opt_param=optim_param, epochs=epochs, print_freq=10) + log= train_classic_higher(model=model, epochs=epochs) exec_time=time.process_time() - t0 #### @@ -138,7 +140,7 @@ if __name__ == "__main__": data_train_aug.augement_data(aug_copy=1) print(data_train_aug) unsup_ratio = 5 - dl_unsup = torch.utils.data.DataLoader(data_train_aug, batch_size=BATCH_SIZE*unsup_ratio, shuffle=True) + dl_unsup = torch.utils.data.DataLoader(data_train_aug, batch_size=BATCH_SIZE*unsup_ratio, shuffle=True, num_workers=num_workers, pin_memory=pin_memory) unsup_xs, sup_xs, ys = next(iter(dl_unsup)) viz_sample_data(imgs=sup_xs, labels=ys, fig_name='samples/data_sample_{}'.format(str(data_train_aug))) @@ -172,7 +174,7 @@ if __name__ == "__main__": tf_dict = {k: TF.TF_dict[k] for k in tf_names} #aug_model = Augmented_model(Data_augV6(TF_dict=tf_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, prob_set_size=2, fixed_mag=True, shared_mag=True), model).to(device) - aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), model).to(device) + aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=1, mix_dist=0.5, fixed_prob=False, fixed_mag=True, shared_mag=True), model).to(device) #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device) print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter)) @@ -181,7 +183,7 @@ if __name__ == "__main__": inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, opt_param=optim_param, - print_freq=10, + print_freq=1, KLdiv=True, loss_patience=None) diff --git a/higher/train_utils.py b/higher/train_utils.py index ec3a9c5..e78dde6 100755 --- a/higher/train_utils.py +++ b/higher/train_utils.py @@ -44,8 +44,7 @@ def compute_vaLoss(model, dl_it, dl): xs, ys = xs.to(device), ys.to(device) model.eval() #Validation sans transfornations ! - - return F.cross_entropy(model(xs), ys) + return F.cross_entropy(F.log_softmax(model(xs), dim=1), ys) def train_classic(model, opt_param, epochs=1, print_freq=1): device = next(model.parameters()).device @@ -688,25 +687,30 @@ def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start else: #Methode KL div - fmodel.augment(mode=False) - sup_logits = fmodel(xs) + if fmodel._data_augmentation : + fmodel.augment(mode=False) + sup_logits = fmodel(xs) + fmodel.augment(mode=True) + else: + sup_logits = fmodel(xs) log_sup=F.log_softmax(sup_logits, dim=1) - fmodel.augment(mode=True) loss = F.cross_entropy(log_sup, ys) if fmodel._data_augmentation: aug_logits = fmodel(xs) log_aug=F.log_softmax(aug_logits, dim=1) - #KL div w/ logits - aug_loss = F.softmax(sup_logits, dim=1)*(log_sup-log_aug) - aug_loss=aug_loss.sum(dim=-1) - #aug_loss = F.kl_div(aug_logits, sup_logits, reduction='none') #Similarite predictions (distributions) - + w_loss = fmodel['data_aug'].loss_weight() #Weight loss - aug_loss = (w_loss * aug_loss).mean() #apprentissage differe ? + + #if epoch>50: #debut differe ? + #KL div w/ logits - Similarite predictions (distributions) + aug_loss = F.softmax(sup_logits, dim=1)*(log_sup-log_aug) + aug_loss = aug_loss.sum(dim=-1) + #aug_loss = F.kl_div(aug_logits, sup_logits, reduction='none') + aug_loss = (w_loss * aug_loss).mean() aug_loss += (F.cross_entropy(log_aug, ys , reduction='none') * w_loss).mean() - #print(aug_loss) + unsupp_coeff = 1 loss += aug_loss * unsupp_coeff @@ -717,20 +721,28 @@ def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start #print(fmodel['model']._params['b4'].grad) #print('prob grad', fmodel['data_aug']['prob'].grad) + #t = time.process_time() diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step) + #print(len(fmodel._fast_params),"step", time.process_time()-t) - if(high_grad_track and i%inner_it==0): #Perform Meta step + if(high_grad_track and i>0 and i%inner_it==0): #Perform Meta step #print("meta") - #Peu utile si high_grad_track = False - val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val) + fmodel['data_aug'].reg_loss() + + val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val) + fmodel['data_aug'].reg_loss() #print_graph(val_loss) + #t = time.process_time() val_loss.backward() + #print("meta", time.process_time()-t) + #print('proba grad',model['data_aug']['prob'].grad) countcopy+=1 model_copy(src=fmodel, dst=model) optim_copy(dopt=diffopt, opt=inner_opt) + torch.nn.utils.clip_grad_norm_(model['data_aug']['prob'], max_norm=10, norm_type=2) #Prevent exploding grad with RNN + torch.nn.utils.clip_grad_norm_(model['data_aug']['mag'], max_norm=10, norm_type=2) #Prevent exploding grad with RNN + #if epoch>50: meta_opt.step() model['data_aug'].adjust_param(soft=False) #Contrainte sum(proba)=1 @@ -757,21 +769,6 @@ def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start accuracy, test_loss =test(model) model.train() - #### Print #### - if(print_freq and epoch%print_freq==0): - print('-'*9) - print('Epoch : %d/%d'%(epoch,epochs)) - print('Time : %.00f'%(tf - t0)) - print('Train loss :',loss.item(), '/ val loss', val_loss.item()) - print('Accuracy :', accuracy) - print('Data Augmention : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start)) - print('TF Proba :', model['data_aug']['prob'].data) - #print('proba grad',model['data_aug']['prob'].grad) - print('TF Mag :', model['data_aug']['mag'].data) - #print('Mag grad',model['data_aug']['mag'].grad) - #print('Reg loss:', model['data_aug'].reg_loss().item()) - #print('Aug loss', aug_loss.item()) - ############# #### Log #### #print(type(model['data_aug']) is dataug.Data_augV5) param = [{'p': p.item(), 'm':model['data_aug']['mag'].item()} for p in model['data_aug']['prob']] if model['data_aug']._shared_mag else [{'p': p.item(), 'm': m.item()} for p, m in zip(model['data_aug']['prob'], model['data_aug']['mag'])] @@ -787,6 +784,235 @@ def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start } log.append(data) ############# + #### Print #### + if(print_freq and epoch%print_freq==0): + print('-'*9) + print('Epoch : %d/%d'%(epoch,epochs)) + print('Time : %.00f'%(tf - t0)) + print('Train loss :',loss.item(), '/ val loss', val_loss.item()) + print('Accuracy :', max([x["acc"] for x in log])) + print('Data Augmention : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start)) + print('TF Proba :', model['data_aug']['prob'].data) + #print('proba grad',model['data_aug']['prob'].grad) + print('TF Mag :', model['data_aug']['mag'].data) + #print('Mag grad',model['data_aug']['mag'].grad) + #print('Reg loss:', model['data_aug'].reg_loss().item()) + #print('Aug loss', aug_loss.item()) + ############# + if val_loss_monitor : + model.eval() + val_loss_monitor.register(test_loss)#val_loss.item()) + if val_loss_monitor.end_training(): break #Stop training + model.train() + + if not model.is_augmenting() and (epoch == dataug_epoch_start or (val_loss_monitor and val_loss_monitor.limit_reached()==1)): + print('Starting Data Augmention...') + dataug_epoch_start = epoch + model.augment(mode=True) + if inner_it != 0: high_grad_track = True + + try: + viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch)) + viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch), weight_labels=model['data_aug'].loss_weight()) + except: + print("Couldn't save finals samples") + pass + + #print("Copy ", countcopy) + return log + +def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start=0, print_freq=1, KLdiv=False, loss_patience=None, save_sample=False): + device = next(model.parameters()).device + log = [] + countcopy=0 + val_loss=torch.tensor(0) #Necessaire si pas de metastep sur une epoch + dl_val_it = iter(dl_val) + + #if inner_it!=0: + meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=opt_param['Meta']['lr']) #lr=1e-2 + inner_opt = torch.optim.SGD(model['model'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9 + + high_grad_track = True + if inner_it == 0: + high_grad_track=False + if dataug_epoch_start!=0: + model.augment(mode=False) + high_grad_track = False + + val_loss_monitor= None + if loss_patience != None : + if dataug_epoch_start==-1: val_loss_monitor = loss_monitor(patience=loss_patience, end_train=2) #1st limit = dataug start + else: val_loss_monitor = loss_monitor(patience=loss_patience) #Val loss monitor (Not on val data : used by Dataug... => Test data) + + model.train() + + #fmodel = higher.patch.monkeypatch(model['model'], device=None, copy_initial_weights=True) + #diffopt = higher.optim.get_diff_optim(inner_opt, model['model'].parameters(),fmodel=fmodel,track_higher_grads=high_grad_track) + fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True) + diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel,track_higher_grads=high_grad_track) + + #meta_opt = torch.optim.Adam(fmodel['data_aug'].parameters(), lr=opt_param['Meta']['lr']) #lr=1e-2 + + print(len(fmodel._fast_params)) + + for epoch in range(1, epochs+1): + #print_torch_mem("Start epoch "+str(epoch)) + #print(high_grad_track, fmodel._data_augmentation, len(fmodel._fast_params)) + t0 = time.process_time() + #with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=True, override=opt_param, track_higher_grads=high_grad_track) as (fmodel, diffopt): + + for i, (xs, ys) in enumerate(dl_train): + xs, ys = xs.to(device), ys.to(device) + + if(not KLdiv): + #Methode uniforme + logits = fmodel(xs) # modified `params` can also be passed as a kwarg + loss = F.cross_entropy(F.log_softmax(logits, dim=1), ys, reduction='none') # no need to call loss.backwards() + + if fmodel._data_augmentation: #Weight loss + w_loss = fmodel['data_aug'].loss_weight()#.to(device) + loss = loss * w_loss + loss = loss.mean() + + else: + #Methode KL div + if fmodel._data_augmentation : + fmodel.augment(mode=False) + sup_logits = fmodel(xs) + fmodel.augment(mode=True) + else: + sup_logits = fmodel(xs) + log_sup=F.log_softmax(sup_logits, dim=1) + loss = F.cross_entropy(log_sup, ys) + + if fmodel._data_augmentation: + aug_logits = fmodel(xs) + log_aug=F.log_softmax(aug_logits, dim=1) + aug_loss=0 + w_loss = fmodel['data_aug'].loss_weight() #Weight loss + + #if epoch>50: #debut differe ? + #KL div w/ logits - Similarite predictions (distributions) + aug_loss = F.softmax(sup_logits, dim=1)*(log_sup-log_aug) + aug_loss = aug_loss.sum(dim=-1) + #aug_loss = F.kl_div(aug_logits, sup_logits, reduction='none') + aug_loss = (w_loss * aug_loss).mean() + + aug_loss += (F.cross_entropy(log_aug, ys , reduction='none') * w_loss).mean() + + unsupp_coeff = 1 + loss += aug_loss * unsupp_coeff + + #to visualize computational graph + #print_graph(loss) + + #loss.backward(retain_graph=True) + #print(fmodel['model']._params['b4'].grad) + #print('prob grad', fmodel['data_aug']['prob'].grad) + + #for _, p in fmodel['data_aug'].named_parameters(): + # p.requires_grad = False + t = time.process_time() + diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step) + print(len(fmodel._fast_params),"step", time.process_time()-t) + + #for _, p in fmodel['data_aug'].named_parameters(): + # p.requires_grad = True + + if(high_grad_track and i>0 and i%inner_it==0): #Perform Meta step + #print("meta") + + val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val) + fmodel['data_aug'].reg_loss() + #print_graph(val_loss) + + val_loss.backward() + + print('proba grad',fmodel['data_aug']['prob'].grad) + #countcopy+=1 + #model_copy(src=fmodel, dst=model) + #optim_copy(dopt=diffopt, opt=inner_opt) + + torch.nn.utils.clip_grad_norm_(fmodel['data_aug']['prob'], max_norm=10, norm_type=2) #Prevent exploding grad with RNN + torch.nn.utils.clip_grad_norm_(fmodel['data_aug']['mag'], max_norm=10, norm_type=2) #Prevent exploding grad with RNN + + for paramName, paramValue, in fmodel['data_aug'].named_parameters(): + for netCopyName, netCopyValue, in model['data_aug'].named_parameters(): + if paramName == netCopyName: + netCopyValue.grad = paramValue.grad + + #del meta_opt.param_groups[0] + #meta_opt.add_param_group({'params' : [p for p in fmodel['data_aug'].parameters()]}) + + meta_opt.step() + fmodel['data_aug'].load_state_dict(model['data_aug'].state_dict()) + fmodel['data_aug'].adjust_param(soft=False) #Contrainte sum(proba)=1 + #model['data_aug'].next_TF_set() + + + #fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True) + #diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track) + + + #fmodel.fast_params=[higher.utils._copy_tensor(t,safe_copy=True) if isinstance(t, torch.Tensor) else t for t in fmodel.parameters()] + diffopt.detach_() + tmp = fmodel.fast_params + fmodel._fast_params=[] + fmodel.update_params(tmp) + for p in fmodel.fast_params: + p.detach_().requires_grad_() + print(len(fmodel._fast_params)) + + print('TF Proba :', fmodel['data_aug']['prob'].data) + tf = time.process_time() + + #viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch)) + #viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch)) + + #model_copy(src=fmodel, dst=model) + + if(not high_grad_track): + #countcopy+=1 + #model_copy(src=fmodel, dst=model) + optim_copy(dopt=diffopt, opt=inner_opt) + val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val) + + #Necessaire pour reset higher (Accumule les fast_param meme avec track_higher_grads = False) + fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True) + diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track) + + accuracy, test_loss =test(model) + model.train() + + #### Log #### + #print(type(model['data_aug']) is dataug.Data_augV5) + param = [{'p': p.item(), 'm':model['data_aug']['mag'].item()} for p in model['data_aug']['prob']] if model['data_aug']._shared_mag else [{'p': p.item(), 'm': m.item()} for p, m in zip(model['data_aug']['prob'], model['data_aug']['mag'])] + data={ + "epoch": epoch, + "train_loss": loss.item(), + "val_loss": val_loss.item(), + "acc": accuracy, + "time": tf - t0, + + "param": param #if isinstance(model['data_aug'], Data_augV5) + #else [p.item() for p in model['data_aug']['prob']], + } + log.append(data) + ############# + #### Print #### + if(print_freq and epoch%print_freq==0): + print('-'*9) + print('Epoch : %d/%d'%(epoch,epochs)) + print('Time : %.00f'%(tf - t0)) + print('Train loss :',loss.item(), '/ val loss', val_loss.item()) + print('Accuracy :', max([x["acc"] for x in log])) + print('Data Augmention : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start)) + print('TF Proba :', model['data_aug']['prob'].data) + #print('proba grad',model['data_aug']['prob'].grad) + print('TF Mag :', model['data_aug']['mag'].data) + #print('Mag grad',model['data_aug']['mag'].grad) + #print('Reg loss:', model['data_aug'].reg_loss().item()) + #print('Aug loss', aug_loss.item()) + ############# if val_loss_monitor : model.eval() val_loss_monitor.register(test_loss)#val_loss.item()) diff --git a/higher/transformations.py b/higher/transformations.py index 82a8d9e..e4cfded 100755 --- a/higher/transformations.py +++ b/higher/transformations.py @@ -96,10 +96,13 @@ TF_dict={ #Dataugv5 'BadTranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=20*2, maxval=20*3), zero_pos=1))), 'BadTranslateY_neg': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=-20*3, maxval=-20*2), zero_pos=1))), - 'BadColor':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*2))), - 'BadSharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*2))), - 'BadContrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*2))), - 'BadBrightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*2))), + 'BadColor':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*3))), + 'BadSharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*3))), + 'BadContrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*3))), + 'BadBrightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*3))), + + 'Random':(lambda x, mag: torch.rand_like(x)), + 'RandBlend': (lambda x, mag: blend(x,torch.rand_like(x), alpha=torch.tensor(0.5,device=mag.device).expand(x.shape[0]))), #Non fonctionnel #'Auto_Contrast': (lambda mag: None), #Pas opti pour des batch (Super lent) diff --git a/higher/utils.py b/higher/utils.py index ea81ea3..f7c5ab0 100755 --- a/higher/utils.py +++ b/higher/utils.py @@ -22,7 +22,7 @@ class timer(): def print_graph(PyTorch_obj, fig_name='graph'): graph=make_dot(PyTorch_obj) #Loss give the whole graph - graph.format = 'svg' #https://graphviz.readthedocs.io/en/stable/manual.html#formats + graph.format = 'pdf' #https://graphviz.readthedocs.io/en/stable/manual.html#formats graph.render(fig_name) def plot_res(log, fig_name='res', param_names=None): @@ -183,7 +183,7 @@ def plot_TF_res(log, tf_names, fig_name='res'): plt.savefig(fig_name, bbox_inches='tight') plt.close() -def viz_sample_data(imgs, labels, fig_name='data_sample'): +def viz_sample_data(imgs, labels, fig_name='data_sample', weight_labels=None): sample = imgs[0:25,].permute(0, 2, 3, 1).squeeze().cpu() @@ -194,7 +194,9 @@ def viz_sample_data(imgs, labels, fig_name='data_sample'): plt.yticks([]) plt.grid(False) plt.imshow(sample[i,].detach().numpy(), cmap=plt.cm.binary) - plt.xlabel(labels[i].item()) + label = str(labels[i].item()) + if weight_labels is not None : label+= ("- p %.2f" % weight_labels[i].item()) + plt.xlabel(label) plt.savefig(fig_name) print("Sample saved :", fig_name)