diff --git a/higher/test_dataug.py b/higher/test_dataug.py index 87dae1e..bde04c4 100755 --- a/higher/test_dataug.py +++ b/higher/test_dataug.py @@ -66,10 +66,10 @@ if __name__ == "__main__": tasks={ #'classic', - #'aug_dataset', - 'aug_model' + 'aug_dataset', + #'aug_model' } - n_inner_iter = 0 + n_inner_iter = 1 epochs = 150 dataug_epoch_start=0 @@ -108,19 +108,34 @@ if __name__ == "__main__": t0 = time.process_time() - data_train_aug = AugmentedDataset("./data", train=True, download=download_data, transform=transform, subset=(0,int(len(data_train)/2))) - data_train_aug.augement_data(aug_copy=30) - print(data_train_aug) - dl_train = torch.utils.data.DataLoader(data_train_aug, batch_size=BATCH_SIZE, shuffle=True) + #data_train_aug = AugmentedDataset("./data", train=True, download=download_data, transform=transform, subset=(0,int(len(data_train)/2))) + #data_train_aug.augement_data(aug_copy=30) + #print(data_train_aug) + #dl_train = torch.utils.data.DataLoader(data_train_aug, batch_size=BATCH_SIZE, shuffle=True) - xs, ys = next(iter(dl_train)) - viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_{}'.format(str(data_train_aug))) + #xs, ys = next(iter(dl_train)) + #viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_{}'.format(str(data_train_aug))) + + #model = model.to(device) + + #print("{} on {} for {} epochs".format(str(model), device_name, epochs)) + #log= train_classic(model=model, epochs=epochs, print_freq=10) + ##log= train_classic_higher(model=model, epochs=epochs) + + data_train_aug = AugmentedDatasetV2("./data", train=True, download=download_data, transform=transform, subset=(0,int(len(data_train)/2))) + data_train_aug.augement_data(aug_copy=10) + print(data_train_aug) + unsup_ratio = 5 + dl_unsup = torch.utils.data.DataLoader(data_train_aug, batch_size=BATCH_SIZE*unsup_ratio, shuffle=True) + + unsup_xs, sup_xs, ys = next(iter(dl_unsup)) + viz_sample_data(imgs=sup_xs, labels=ys, 
fig_name='samples/data_sample_{}'.format(str(data_train_aug))) + viz_sample_data(imgs=unsup_xs, labels=ys, fig_name='samples/data_sample_{}_unsup'.format(str(data_train_aug))) model = model.to(device) print("{} on {} for {} epochs".format(str(model), device_name, epochs)) - log= train_classic(model=model, epochs=epochs, print_freq=10) - #log= train_classic_higher(model=model, epochs=epochs) + log= train_UDA(model=model, dl_unsup=dl_unsup, epochs=epochs, print_freq=10) exec_time=time.process_time() - t0 #### @@ -145,11 +160,11 @@ if __name__ == "__main__": tf_dict = {k: TF.TF_dict[k] for k in tf_names} #aug_model = Augmented_model(Data_augV6(TF_dict=tf_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, prob_set_size=2, fixed_mag=True, shared_mag=True), model).to(device) - aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=2, mix_dist=0.0, fixed_prob=True, fixed_mag=True, shared_mag=True), model).to(device) + aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), model).to(device) #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device) print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter)) - log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, KLdiv=True, loss_patience=None) + log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, KLdiv=False, loss_patience=None) exec_time=time.process_time() - t0 #### @@ -157,7 +172,7 @@ if __name__ == "__main__": times = [x["time"] for x in log] out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log} print(str(aug_model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1]) - filename = "{}-{} 
def _next_batch(batch_iter, loader):
    """Return the next batch of *loader*, restarting it once exhausted.

    Args:
        batch_iter: Iterator currently walking over ``loader``.
        loader: Re-iterable source of batches (e.g. a ``DataLoader``).

    Returns:
        Tuple ``(batch, batch_iter)``; ``batch_iter`` is the iterator to pass
        to the next call (a fresh one when a restart happened).
    """
    try:
        batch = next(batch_iter)
    except StopIteration:
        # End of a pass over the loader: start a new pass.
        batch_iter = iter(loader)
        batch = next(batch_iter)
    return batch, batch_iter


def train_UDA(model, dl_unsup, epochs=1, print_freq=1):
    """Train *model* with a supervised loss plus a consistency (UDA-style) loss.

    Each optimisation step combines:
      * the cross-entropy of the model on a labelled batch from ``dl_train``;
      * the KL divergence between the model's prediction on original samples
        and its prediction on their augmented versions, both from ``dl_unsup``.

    The function relies on the module-level names ``dl_train``, ``dl_val`` and
    ``test`` (they are not parameters).

    Args:
        model: Network to train; its device is taken from its first parameter.
        dl_unsup: Loader yielding ``(augmented_xs, original_xs, labels)``
            triples. The labels are ignored by the consistency loss.
        epochs: Number of passes over ``dl_train``.
        print_freq: Print a summary every ``print_freq`` epochs; a falsy value
            disables printing.

    Returns:
        A list with one dict per epoch holding the keys ``"epoch"``,
        ``"train_loss"``, ``"val_loss"``, ``"acc"``, ``"time"`` and
        ``"param"`` (always ``None``). Losses are those of the last step of
        the epoch.
    """
    device = next(model.parameters()).device
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
    unsup_coeff = 1  # weight of the consistency term in the total loss

    model.train()
    dl_val_it = iter(dl_val)
    dl_unsup_it = iter(dl_unsup)
    log = []
    for epoch in range(epochs):
        t0 = time.process_time()
        for features, labels in dl_train:
            features, labels = features.to(device), labels.to(device)

            optimizer.zero_grad()

            # Supervised term. cross_entropy applies log-softmax itself, so it
            # is given the raw logits.
            sup_loss = F.cross_entropy(model(features), labels)

            # Unsupervised term: labels of the unsupervised batch are unused.
            (aug_xs, origin_xs, _), dl_unsup_it = _next_batch(dl_unsup_it, dl_unsup)
            aug_xs, origin_xs = aug_xs.to(device), origin_xs.to(device)

            # The prediction on the original samples is the fixed target of
            # the consistency loss: no gradient is propagated through it.
            with torch.no_grad():
                origin_logits = model(origin_xs)
                origin_prob = F.softmax(origin_logits, dim=1)
                origin_log_prob = F.log_softmax(origin_logits, dim=1)
            aug_log_prob = F.log_softmax(model(aug_xs), dim=1)

            # KL(p_original || p_augmented), averaged over the batch.
            unsup_loss = (origin_prob * (origin_log_prob - aug_log_prob)).sum(dim=-1).mean()

            loss = sup_loss + unsup_loss * unsup_coeff

            loss.backward()
            optimizer.step()

        #### Tests ####
        tf = time.process_time()  # taken before validation so "time" covers training only
        (xs_val, ys_val), dl_val_it = _next_batch(dl_val_it, dl_val)
        xs_val, ys_val = xs_val.to(device), ys_val.to(device)

        # Evaluate in eval mode and without autograd: validation data must not
        # update batch-norm statistics nor build a graph.
        model.eval()
        with torch.no_grad():
            val_loss = F.cross_entropy(model(xs_val), ys_val)
        accuracy, _ = test(model)
        model.train()

        #### Print ####
        if print_freq and epoch % print_freq == 0:
            print('-'*9)
            print('Epoch : %d/%d'%(epoch,epochs))
            print('Time : %.00f'%(tf - t0))
            print('Train loss :',loss.item(), '/ val loss', val_loss.item())
            print('Sup Loss :', sup_loss.item(), '/ unsup_loss :', unsup_loss.item())
            print('Accuracy :', accuracy)

        #### Log ####
        log.append({
            "epoch": epoch,
            "train_loss": loss.item(),
            "val_loss": val_loss.item(),
            "acc": accuracy,
            "time": tf - t0,

            "param": None,
        })

    return log