From 2e09f07f52193bf9b25d4b00884860c35673c6ba Mon Sep 17 00:00:00 2001
From: "Harle, Antoine (Contracteur)"
Date: Fri, 24 Jan 2020 11:50:30 -0500
Subject: [PATCH] Comments + tidying

---
 higher/datasets.py            | 131 +---------------
 higher/old/train_utils_old.py | 276 ++++++++++++++++++++++++++++++++++
 higher/train_utils.py         | 204 +++++++------------------
 higher/utils.py               |   4 +-
 4 files changed, 336 insertions(+), 279 deletions(-)

diff --git a/higher/datasets.py b/higher/datasets.py
index d7e84e2..d9a8083 100755
--- a/higher/datasets.py
+++ b/higher/datasets.py
@@ -32,145 +32,16 @@ data_test = torchvision.datasets.MNIST(
     "./data", train=False, download=True, transform=torchvision.transforms.ToTensor()
 )
 
-
-from torchvision.datasets.vision import VisionDataset
-from PIL import Image
-import augmentation_transforms
-import numpy as np
-
-class AugmentedDatasetV2(VisionDataset):
-    def __init__(self, root, train=True, transform=None, target_transform=None, download=False, subset=None):
-
-        super(AugmentedDatasetV2, self).__init__(root, transform=transform, target_transform=target_transform)
-
-        supervised_dataset = torchvision.datasets.CIFAR10(root, train=train, download=download, transform=transform)
-
-        self.sup_data = supervised_dataset.data if not subset else supervised_dataset.data[subset[0]:subset[1]]
-        self.sup_targets = supervised_dataset.targets if not subset else supervised_dataset.targets[subset[0]:subset[1]]
-        assert len(self.sup_data)==len(self.sup_targets)
-
-        for idx, img in enumerate(self.sup_data):
-            self.sup_data[idx]= Image.fromarray(img) #to PIL Image
-
-        self.unsup_data=[]
-        self.unsup_targets=[]
-        self.origin_idx=[]
-
-        self.dataset_info= {
-            'name': 'CIFAR10',
-            'sup': len(self.sup_data),
-            'unsup': len(self.unsup_data),
-            'length': len(self.sup_data)+len(self.unsup_data),
-        }
-
-
-        self._TF = [
-            ## Geometric TF ##
-            'Rotate',
-            'TranslateX',
-            'TranslateY',
-            'ShearX',
-            'ShearY',
-
-            'Cutout',
-
-            ## Color TF ##
-            'Contrast',
-            'Color',
-            'Brightness',
-            'Sharpness',
-            'Posterize',
-            'Solarize',
-
-            'Invert',
-            'AutoContrast',
-            'Equalize',
-        ]
-        self._op_list =[]
-        self.prob=0.5
-        self.mag_range=(1, 10)
-        for tf in self._TF:
-            for mag in range(self.mag_range[0], self.mag_range[1]):
-                self._op_list+=[(tf, self.prob, mag)]
-        self._nb_op = len(self._op_list)
-
-    def __getitem__(self, index):
-        """
-        Args:
-            index (int): Index
-
-        Returns:
-            tuple: (image, target) where target is index of the target class.
-        """
-        aug_img, origin_img, target = self.unsup_data[index], self.sup_data[self.origin_idx[index]], self.unsup_targets[index]
-
-        # doing this so that it is consistent with all other datasets
-        # to return a PIL Image
-        #img = Image.fromarray(img)
-
-        if self.transform is not None:
-            aug_img = self.transform(aug_img)
-            origin_img = self.transform(origin_img)
-
-        if self.target_transform is not None:
-            target = self.target_transform(target)
-
-        return aug_img, origin_img, target
-
-    def augement_data(self, aug_copy=1):
-
-        policies = []
-        for op_1 in self._op_list:
-            for op_2 in self._op_list:
-                policies += [[op_1, op_2]]
-
-        for idx, image in enumerate(self.sup_data):
-            if idx%(self.dataset_info['sup']/5)==0: print("Augmenting data... ", idx,"/", self.dataset_info['sup'])
", idx,"/", self.dataset_info['sup']) - #if idx==10000:break - - for _ in range(aug_copy): - chosen_policy = policies[np.random.choice(len(policies))] - aug_image = augmentation_transforms.apply_policy(chosen_policy, image, use_mean_std=False) #Cast en float image - #aug_image = augmentation_transforms.cutout_numpy(aug_image) - - self.unsup_data+=[(aug_image*255.).astype(self.sup_data.dtype)]#Cast float image to uint8 - self.unsup_targets+=[self.sup_targets[idx]] - self.origin_idx+=[idx] - - #self.unsup_data=(np.array(self.unsup_data)*255.).astype(self.sup_data.dtype) #Cast float image to uint8 - self.unsup_data=np.array(self.unsup_data) - - assert len(self.unsup_data)==len(self.unsup_targets) - - self.dataset_info['unsup']=len(self.unsup_data) - self.dataset_info['length']=self.dataset_info['sup']+self.dataset_info['unsup'] - - - def __len__(self): - return self.dataset_info['unsup']#self.dataset_info['length'] - - def __str__(self): - return "CIFAR10(Sup:{}-Unsup:{}-{}TF(Mag{}-{}))".format(self.dataset_info['sup'], self.dataset_info['unsup'], len(self._TF), self.mag_range[0], self.mag_range[1]) - - ### Classic Dataset ### data_train = torchvision.datasets.CIFAR10("./data", train=True, download=download_data, transform=transform) #data_val = torchvision.datasets.CIFAR10("./data", train=True, download=download_data, transform=transform) data_test = torchvision.datasets.CIFAR10("./data", train=False, download=download_data, transform=transform) - train_subset_indices=range(int(len(data_train)/2)) val_subset_indices=range(int(len(data_train)/2),len(data_train)) #train_subset_indices=range(BATCH_SIZE*10) #val_subset_indices=range(BATCH_SIZE*10, BATCH_SIZE*20) + dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices), num_workers=num_workers, pin_memory=pin_memory) - -### Augmented Dataset ### -#data_train_aug = AugmentedDataset("./data", train=True, download=download_data, transform=transform, subset=(0,int(len(data_train)/2))) -#data_train_aug.augement_data(aug_copy=10) -#print(data_train_aug) -#dl_train = torch.utils.data.DataLoader(data_train_aug, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=pin_memory) - - dl_val = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(val_subset_indices), num_workers=num_workers, pin_memory=pin_memory) dl_test = torch.utils.data.DataLoader(data_test, batch_size=TEST_SIZE, shuffle=False, num_workers=num_workers, pin_memory=pin_memory) diff --git a/higher/old/train_utils_old.py b/higher/old/train_utils_old.py index 389dd9d..8e319b6 100644 --- a/higher/old/train_utils_old.py +++ b/higher/old/train_utils_old.py @@ -6,6 +6,66 @@ import higher from datasets import * from utils import * +def train_classic_higher(model, epochs=1): + device = next(model.parameters()).device + #opt = torch.optim.Adam(model.parameters(), lr=1e-3) + optim = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9) + + model.train() + dl_val_it = iter(dl_val) + log = [] + + fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True) + diffopt = higher.optim.get_diff_optim(optim, model.parameters(),fmodel=fmodel,track_higher_grads=False) + #with higher.innerloop_ctx(model, optim, copy_initial_weights=True, track_higher_grads=False) as (fmodel, diffopt): + + for epoch in range(epochs): + #print_torch_mem("Start epoch "+str(epoch)) + #print("Fast param ",len(fmodel._fast_params)) + t0 = time.process_time() 
+        for i, (features, labels) in enumerate(dl_train):
+            #print_torch_mem("Start iter")
+            features,labels = features.to(device), labels.to(device)
+
+            #optim.zero_grad()
+            logits = model.forward(features)
+            pred = F.log_softmax(logits, dim=1)
+            loss = F.cross_entropy(pred,labels)
+            #.backward()
+            #optim.step()
+            diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)
+
+            model_copy(src=fmodel, dst=model, patch_copy=False)
+            optim_copy(dopt=diffopt, opt=optim)
+            fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
+            diffopt = higher.optim.get_diff_optim(optim, model.parameters(),fmodel=fmodel,track_higher_grads=False)
+
+        #### Tests ####
+        tf = time.process_time()
+        try:
+            xs_val, ys_val = next(dl_val_it)
+        except StopIteration: #End of the validation epoch
+            dl_val_it = iter(dl_val)
+            xs_val, ys_val = next(dl_val_it)
+        xs_val, ys_val = xs_val.to(device), ys_val.to(device)
+
+        val_loss = F.cross_entropy(model(xs_val), ys_val)
+        accuracy, _ =test(model)
+        model.train()
+        #### Log ####
+        data={
+            "epoch": epoch,
+            "train_loss": loss.item(),
+            "val_loss": val_loss.item(),
+            "acc": accuracy,
+            "time": tf - t0,
+
+            "param": None,
+        }
+        log.append(data)
+
+    return log
+
 def train_classic_tests(model, epochs=1):
     device = next(model.parameters()).device
     #opt = torch.optim.Adam(model.parameters(), lr=1e-3)
@@ -148,6 +208,222 @@ def train_classic_tests(model, epochs=1):
 
     return log
 
+from torchvision.datasets.vision import VisionDataset
+from PIL import Image
+import augmentation_transforms
+import numpy as np
+class AugmentedDatasetV2(VisionDataset):
+    def __init__(self, root, train=True, transform=None, target_transform=None, download=False, subset=None):
+
+        super(AugmentedDatasetV2, self).__init__(root, transform=transform, target_transform=target_transform)
+
+        supervised_dataset = torchvision.datasets.CIFAR10(root, train=train, download=download, transform=transform)
+
+        self.sup_data = supervised_dataset.data if not subset else supervised_dataset.data[subset[0]:subset[1]]
+        self.sup_targets = supervised_dataset.targets if not subset else supervised_dataset.targets[subset[0]:subset[1]]
+        assert len(self.sup_data)==len(self.sup_targets)
+
+        for idx, img in enumerate(self.sup_data):
+            self.sup_data[idx]= Image.fromarray(img) #to PIL Image
+
+        self.unsup_data=[]
+        self.unsup_targets=[]
+        self.origin_idx=[]
+
+        self.dataset_info= {
+            'name': 'CIFAR10',
+            'sup': len(self.sup_data),
+            'unsup': len(self.unsup_data),
+            'length': len(self.sup_data)+len(self.unsup_data),
+        }
+
+
+        self._TF = [
+            ## Geometric TF ##
+            'Rotate',
+            'TranslateX',
+            'TranslateY',
+            'ShearX',
+            'ShearY',
+
+            'Cutout',
+
+            ## Color TF ##
+            'Contrast',
+            'Color',
+            'Brightness',
+            'Sharpness',
+            'Posterize',
+            'Solarize',
+
+            'Invert',
+            'AutoContrast',
+            'Equalize',
+        ]
+        self._op_list =[]
+        self.prob=0.5
+        self.mag_range=(1, 10)
+        for tf in self._TF:
+            for mag in range(self.mag_range[0], self.mag_range[1]):
+                self._op_list+=[(tf, self.prob, mag)]
+        self._nb_op = len(self._op_list)
+
+    def __getitem__(self, index):
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (aug_img, origin_img, target) where target is the index of the target class.
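+
+        Example (illustrative sketch; `dataset` stands for an instance whose
+        unsupervised set was filled beforehand by augement_data()):
+            aug_img, origin_img, target = dataset[0]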
+ """ + aug_img, origin_img, target = self.unsup_data[index], self.sup_data[self.origin_idx[index]], self.unsup_targets[index] + + # doing this so that it is consistent with all other datasets + # to return a PIL Image + #img = Image.fromarray(img) + + if self.transform is not None: + aug_img = self.transform(aug_img) + origin_img = self.transform(origin_img) + + if self.target_transform is not None: + target = self.target_transform(target) + + return aug_img, origin_img, target + + def augement_data(self, aug_copy=1): + + policies = [] + for op_1 in self._op_list: + for op_2 in self._op_list: + policies += [[op_1, op_2]] + + for idx, image in enumerate(self.sup_data): + if idx%(self.dataset_info['sup']/5)==0: print("Augmenting data... ", idx,"/", self.dataset_info['sup']) + #if idx==10000:break + + for _ in range(aug_copy): + chosen_policy = policies[np.random.choice(len(policies))] + aug_image = augmentation_transforms.apply_policy(chosen_policy, image, use_mean_std=False) #Cast en float image + #aug_image = augmentation_transforms.cutout_numpy(aug_image) + + self.unsup_data+=[(aug_image*255.).astype(self.sup_data.dtype)]#Cast float image to uint8 + self.unsup_targets+=[self.sup_targets[idx]] + self.origin_idx+=[idx] + + #self.unsup_data=(np.array(self.unsup_data)*255.).astype(self.sup_data.dtype) #Cast float image to uint8 + self.unsup_data=np.array(self.unsup_data) + + assert len(self.unsup_data)==len(self.unsup_targets) + + self.dataset_info['unsup']=len(self.unsup_data) + self.dataset_info['length']=self.dataset_info['sup']+self.dataset_info['unsup'] + + + def __len__(self): + return self.dataset_info['unsup']#self.dataset_info['length'] + + def __str__(self): + return "CIFAR10(Sup:{}-Unsup:{}-{}TF(Mag{}-{}))".format(self.dataset_info['sup'], self.dataset_info['unsup'], len(self._TF), self.mag_range[0], self.mag_range[1]) + +def train_UDA(model, dl_unsup, opt_param, epochs=1, print_freq=1): + """Training of a model using UDA inspired approach. + + Intended to be used alongside an already augmented dataset (see AugmentedDatasetV2). + + Args: + model (nn.Module): Model to train. + dl_unsup (Dataloader): Data loader of unsupervised/augmented data. + opt_param (dict): Dictionnary containing optimizers parameters. + epochs (int): Number of epochs to perform. (default: 1) + print_freq (int): Number of epoch between display of the state of training. If set to None, no display will be done. (default:1) + + Returns: + (list) Logs of training. Each items is a dict containing results of an epoch. 
+ """ + device = next(model.parameters()).device + #opt = torch.optim.Adam(model.parameters(), lr=1e-3) + opt = torch.optim.SGD(model.parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9 + + + model.train() + dl_val_it = iter(dl_val) + dl_unsup_it =iter(dl_unsup) + log = [] + for epoch in range(epochs): + #print_torch_mem("Start epoch") + t0 = time.process_time() + for i, (features, labels) in enumerate(dl_train): + #print_torch_mem("Start iter") + features,labels = features.to(device), labels.to(device) + + optim.zero_grad() + #Supervised + logits = model.forward(features) + pred = F.log_softmax(logits, dim=1) + sup_loss = F.cross_entropy(pred,labels) + + #Unsupervised + try: + aug_xs, origin_xs, ys = next(dl_unsup_it) + except StopIteration: #Fin epoch val + dl_unsup_it =iter(dl_unsup) + aug_xs, origin_xs, ys = next(dl_unsup_it) + aug_xs, origin_xs, ys = aug_xs.to(device), origin_xs.to(device), ys.to(device) + + #print(aug_xs.shape, origin_xs.shape, ys.shape) + sup_logits = model.forward(origin_xs) + unsup_logits = model.forward(aug_xs) + + log_sup=F.log_softmax(sup_logits, dim=1) + log_unsup=F.log_softmax(unsup_logits, dim=1) + #KL div w/ logits + unsup_loss = F.softmax(sup_logits, dim=1)*(log_sup-log_unsup) + unsup_loss=unsup_loss.sum(dim=-1).mean() + + #print(unsup_loss) + unsupp_coeff = 1 + loss = sup_loss + unsup_loss * unsupp_coeff + + loss.backward() + optim.step() + + #### Tests #### + tf = time.process_time() + try: + xs_val, ys_val = next(dl_val_it) + except StopIteration: #Fin epoch val + dl_val_it = iter(dl_val) + xs_val, ys_val = next(dl_val_it) + xs_val, ys_val = xs_val.to(device), ys_val.to(device) + + val_loss = F.cross_entropy(model(xs_val), ys_val) + accuracy, _ =test(model) + model.train() + + #### Print #### + if(print_freq and epoch%print_freq==0): + print('-'*9) + print('Epoch : %d/%d'%(epoch,epochs)) + print('Time : %.00f'%(tf - t0)) + print('Train loss :',loss.item(), '/ val loss', val_loss.item()) + print('Sup Loss :', sup_loss.item(), '/ unsup_loss :', unsup_loss.item()) + print('Accuracy :', accuracy) + + #### Log #### + data={ + "epoch": epoch, + "train_loss": loss.item(), + "val_loss": val_loss.item(), + "acc": accuracy, + "time": tf - t0, + + "param": None, + } + log.append(data) + + return log + def run_simple_dataug(inner_it, epochs=1): device = next(model.parameters()).device diff --git a/higher/train_utils.py b/higher/train_utils.py index 7846ef5..20270fe 100755 --- a/higher/train_utils.py +++ b/higher/train_utils.py @@ -1,3 +1,7 @@ +""" Utilities function for training. + +""" + import torch #import torch.optim import torchvision @@ -7,6 +11,14 @@ from datasets import * from utils import * def test(model): + """Evaluate a model on test data. + + Args: + model (nn.Module): Model to test. + + Returns: + (float, Tensor) Returns the accuracy and test loss of the model. + """ device = next(model.parameters()).device model.eval() @@ -35,6 +47,16 @@ def test(model): return accuracy, np.mean(loss) def compute_vaLoss(model, dl_it, dl): + """Evaluate a model on a batch of data. + + Args: + model (nn.Module): Model to evaluate. + dl_it (Iterator): Data loader iterator. + dl (DataLoader): Data loader. + + Returns: + (Tensor) Loss on a single batch of data. 
+ """ device = next(model.parameters()).device try: xs, ys = next(dl_it) @@ -47,6 +69,17 @@ def compute_vaLoss(model, dl_it, dl): return F.cross_entropy(F.log_softmax(model(xs), dim=1), ys) def train_classic(model, opt_param, epochs=1, print_freq=1): + """Classic training of a model. + + Args: + model (nn.Module): Model to train. + opt_param (dict): Dictionnary containing optimizers parameters. + epochs (int): Number of epochs to perform. (default: 1) + print_freq (int): Number of epoch between display of the state of training. If set to None, no display will be done. (default:1) + + Returns: + (list) Logs of training. Each items is a dict containing results of an epoch. + """ device = next(model.parameters()).device #opt = torch.optim.Adam(model.parameters(), lr=1e-3) optim = torch.optim.SGD(model.parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9 @@ -97,152 +130,30 @@ def train_classic(model, opt_param, epochs=1, print_freq=1): return log -def train_classic_higher(model, epochs=1): - device = next(model.parameters()).device - #opt = torch.optim.Adam(model.parameters(), lr=1e-3) - optim = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9) +def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start=0, print_freq=1, KLdiv=1, hp_opt=False, save_sample_freq=None): + """Training of an augmented model with higher. - model.train() - dl_val_it = iter(dl_val) - log = [] + This function is intended to be used with Augmented_model containing an Higher_model (see dataug.py). + Ex : Augmented_model(Data_augV5(...), Higher_model(model)) - fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True) - diffopt = higher.optim.get_diff_optim(optim, model.parameters(),fmodel=fmodel,track_higher_grads=False) - #with higher.innerloop_ctx(model, optim, copy_initial_weights=True, track_higher_grads=False) as (fmodel, diffopt): + Training loss can either be computed directly from augmented inputs (KLdiv=0). + However, it is recommended to use the KLdiv loss computation, inspired from UDA, which combine original and augmented inputs to compute the loss (KLdiv>0). + See : https://github.com/google-research/uda - for epoch in range(epochs): - #print_torch_mem("Start epoch "+str(epoch)) - #print("Fast param ",len(fmodel._fast_params)) - t0 = time.process_time() - for i, (features, labels) in enumerate(dl_train): - #print_torch_mem("Start iter") - features,labels = features.to(device), labels.to(device) + Args: + model (nn.Module): Augmented model to train. + opt_param (dict): Dictionnary containing optimizers parameters. + epochs (int): Number of epochs to perform. (default: 1) + inner_it (int): Number of inner iteration before a meta-step. 0 inner iteration means there's no meta-step. (default: 1) + dataug_epoch_start (int): Epoch when to start data augmentation. (default: 0) + print_freq (int): Number of epoch between display of the state of training. If set to None, no display will be done. (default:1) + KLdiv (float): Proportion of the KLdiv loss added to the supervised loss. If set to 0, the loss is classicly computed on augmented inputs. (default: 1) + hp_opt (bool): Wether to learn inner optimizer parameters. (default: False) + save_sample_freq (int): Number of epochs between saves of samples of data. If set to None, only one save would be done at the end of the training. 
-            #optim.zero_grad()
-            logits = model.forward(features)
-            pred = F.log_softmax(logits, dim=1)
-            loss = F.cross_entropy(pred,labels)
-            #.backward()
-            #optim.step()
-            diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)
-
-            model_copy(src=fmodel, dst=model, patch_copy=False)
-            optim_copy(dopt=diffopt, opt=optim)
-            fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
-            diffopt = higher.optim.get_diff_optim(optim, model.parameters(),fmodel=fmodel,track_higher_grads=False)
-
-        #### Tests ####
-        tf = time.process_time()
-        try:
-            xs_val, ys_val = next(dl_val_it)
-        except StopIteration: #Fin epoch val
-            dl_val_it = iter(dl_val)
-            xs_val, ys_val = next(dl_val_it)
-        xs_val, ys_val = xs_val.to(device), ys_val.to(device)
-
-        val_loss = F.cross_entropy(model(xs_val), ys_val)
-        accuracy, _ =test(model)
-        model.train()
-        #### Log ####
-        data={
-            "epoch": epoch,
-            "train_loss": loss.item(),
-            "val_loss": val_loss.item(),
-            "acc": accuracy,
-            "time": tf - t0,
-
-            "param": None,
-        }
-        log.append(data)
-
-    return log
-
-def train_UDA(model, dl_unsup, opt_param, epochs=1, print_freq=1):
-
-    device = next(model.parameters()).device
-    #opt = torch.optim.Adam(model.parameters(), lr=1e-3)
-    opt = torch.optim.SGD(model.parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
-
-
-    model.train()
-    dl_val_it = iter(dl_val)
-    dl_unsup_it =iter(dl_unsup)
-    log = []
-    for epoch in range(epochs):
-        #print_torch_mem("Start epoch")
-        t0 = time.process_time()
-        for i, (features, labels) in enumerate(dl_train):
-            #print_torch_mem("Start iter")
-            features,labels = features.to(device), labels.to(device)
-
-            optim.zero_grad()
-            #Supervised
-            logits = model.forward(features)
-            pred = F.log_softmax(logits, dim=1)
-            sup_loss = F.cross_entropy(pred,labels)
-
-            #Unsupervised
-            try:
-                aug_xs, origin_xs, ys = next(dl_unsup_it)
-            except StopIteration: #Fin epoch val
-                dl_unsup_it =iter(dl_unsup)
-                aug_xs, origin_xs, ys = next(dl_unsup_it)
-            aug_xs, origin_xs, ys = aug_xs.to(device), origin_xs.to(device), ys.to(device)
-
-            #print(aug_xs.shape, origin_xs.shape, ys.shape)
-            sup_logits = model.forward(origin_xs)
-            unsup_logits = model.forward(aug_xs)
-
-            log_sup=F.log_softmax(sup_logits, dim=1)
-            log_unsup=F.log_softmax(unsup_logits, dim=1)
-            #KL div w/ logits
-            unsup_loss = F.softmax(sup_logits, dim=1)*(log_sup-log_unsup)
-            unsup_loss=unsup_loss.sum(dim=-1).mean()
-
-            #print(unsup_loss)
-            unsupp_coeff = 1
-            loss = sup_loss + unsup_loss * unsupp_coeff
-
-            loss.backward()
-            optim.step()
-
-        #### Tests ####
-        tf = time.process_time()
-        try:
-            xs_val, ys_val = next(dl_val_it)
-        except StopIteration: #Fin epoch val
-            dl_val_it = iter(dl_val)
-            xs_val, ys_val = next(dl_val_it)
-        xs_val, ys_val = xs_val.to(device), ys_val.to(device)
-
-        val_loss = F.cross_entropy(model(xs_val), ys_val)
-        accuracy, _ =test(model)
-        model.train()
-
-        #### Print ####
-        if(print_freq and epoch%print_freq==0):
-            print('-'*9)
-            print('Epoch : %d/%d'%(epoch,epochs))
-            print('Time : %.00f'%(tf - t0))
-            print('Train loss :',loss.item(), '/ val loss', val_loss.item())
-            print('Sup Loss :', sup_loss.item(), '/ unsup_loss :', unsup_loss.item())
-            print('Accuracy :', accuracy)
-
-        #### Log ####
-        data={
-            "epoch": epoch,
-            "train_loss": loss.item(),
-            "val_loss": val_loss.item(),
-            "acc": accuracy,
-            "time": tf - t0,
-
-            "param": None,
-        }
-        log.append(data)
-
-    return log
-
-def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start=0, print_freq=1, KLdiv=False, hp_opt=False, save_sample=False):
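+
+    Example (illustrative sketch; Data_augV5's arguments are defined in
+    dataug.py, and opt_param is assumed to follow the {'Inner': {'lr',
+    'momentum'}} layout used by train_classic):
+        opt_param = {'Inner': {'lr': 1e-2, 'momentum': 0.9}}
+        aug_model = Augmented_model(Data_augV5(...), Higher_model(model)).to(device)
+        log = run_dist_dataugV3(aug_model, opt_param, epochs=10, inner_it=1, KLdiv=1)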
+    Returns:
+        (list) Logs of training. Each item is a dict containing the results of an epoch.
+    """
     device = next(model.parameters()).device
     log = []
     dl_val_it = iter(dl_val)
@@ -282,7 +193,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
         for i, (xs, ys) in enumerate(dl_train):
             xs, ys = xs.to(device), ys.to(device)
 
-            if(not KLdiv):
+            if(KLdiv<=0):
                 #Methode uniforme
                 logits = model(xs) # modified `params` can also be passed as a kwarg
                 loss = F.cross_entropy(F.log_softmax(logits, dim=1), ys, reduction='none') # no need to call loss.backwards()
@@ -317,8 +228,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
                     aug_loss = (w_loss * aug_loss).mean()
 
                 aug_loss += (F.cross_entropy(log_aug, ys , reduction='none') * w_loss).mean()
-                unsupp_coeff = 1
-                loss += aug_loss * unsupp_coeff
+                loss += aug_loss * KLdiv
 
                 #print_graph(loss) #to visualize computational graph
@@ -351,7 +261,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
 
         tf = time.process_time()
 
-        if save_sample: #Data sample saving
+        if (save_sample_freq and epoch%save_sample_freq==0): #Data sample saving
             try:
                 viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
                 viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
@@ -423,4 +333,4 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
             print("Couldn't save finals samples")
             pass
 
-    return log
\ No newline at end of file
+    return log

diff --git a/higher/utils.py b/higher/utils.py
index 6fab9bc..2dbe623 100755
--- a/higher/utils.py
+++ b/higher/utils.py
@@ -121,7 +121,7 @@ def viz_sample_data(imgs, labels, fig_name='data_sample', weight_labels=None):
 
     plt.figure(figsize=(10,10))
     for i in range(25):
-        plt.subplot(5,5,i+1)
+        plt.subplot(5,5,i+1) #Too many figures created?
         plt.xticks([])
         plt.yticks([])
        plt.grid(False)
@@ -132,7 +132,7 @@ def viz_sample_data(imgs, labels, fig_name='data_sample', weight_labels=None):
 
     plt.savefig(fig_name)
     print("Sample saved :", fig_name)
-    plt.close()
+    plt.close('all')
 
 def print_torch_mem(add_info=''):