diff --git a/higher/model.py b/higher/model.py
index 4aa3b7c..794aefd 100755
--- a/higher/model.py
+++ b/higher/model.py
@@ -323,7 +323,7 @@ class Bottleneck(nn.Module):
 #ResNet18 : block=BasicBlock, layers=[2, 2, 2, 2]
 class ResNet(nn.Module):
 
-    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
+    def __init__(self, block=BasicBlock, layers=[2, 2, 2, 2], num_classes=1000, zero_init_residual=False,
                  groups=1, width_per_group=64, replace_stride_with_dilation=None,
                  norm_layer=None):
         super(ResNet, self).__init__()
@@ -419,11 +419,14 @@ class ResNet(nn.Module):
     def forward(self, x):
         return self._forward_impl(x)
 
+    def __str__(self):
+        return "ResNet18"
+
 ## Wide ResNet ##
 #https://github.com/xternalz/WideResNet-pytorch/blob/master/wideresnet.py
 #https://github.com/arcelien/pba/blob/master/pba/wrn.py
 #https://github.com/szagoruyko/wide-residual-networks/blob/master/pytorch/resnet.py
-
+'''
 class BasicBlock(nn.Module):
     def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
         super(BasicBlock, self).__init__()
@@ -516,3 +519,4 @@ class WideResNet(nn.Module):
 
     def __str__(self):
         return "WideResNet(s{}-d{})".format(self.kernel_size, self.depth)
+'''
\ No newline at end of file
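ResNet now instantiates as a ResNet-18 by default (matching the comment above the class) and gains a __str__ used for the log filenames in test_dataug.py. Two review notes: the mutable default layers=[2, 2, 2, 2] is shared across calls (harmless here, since it is only indexed and never mutated), and __str__ reports "ResNet18" even when a caller passes a different block/layers configuration, so log filenames would be misleading for non-default instances. The WideResNet code below is fenced off with a ''' string literal, presumably because its BasicBlock class rebinds the module-level name that the new ResNet default relies on.

A minimal sketch of how the new defaults are exercised (the "from model import ResNet" path is an assumption about how higher/model.py is imported; the CIFAR-10-sized head matches its use in test_dataug.py below):

    import torch
    from model import ResNet  # assumed import path for higher/model.py

    net = ResNet(num_classes=10)          # block/layers default to the ResNet-18 config
    print(str(net))                       # -> "ResNet18" (the new __str__)
    out = net(torch.randn(2, 3, 32, 32))  # CIFAR-sized input
    print(out.shape)                      # torch.Size([2, 10])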
diff --git a/higher/test_dataug.py b/higher/test_dataug.py
index bde04c4..2c51923 100755
--- a/higher/test_dataug.py
+++ b/higher/test_dataug.py
@@ -65,16 +65,28 @@ else:
 if __name__ == "__main__":
 
     tasks={
-        #'classic',
-        'aug_dataset',
+        'classic',
+        #'aug_dataset',
         #'aug_model'
     }
     n_inner_iter = 1
-    epochs = 150
+    epochs = 100
     dataug_epoch_start=0
 
+    optim_param={
+        'Meta':{
+            'optim':'Adam',
+            'lr':1e-2, #1e-2
+        },
+        'Inner':{
+            'optim': 'SGD',
+            'lr':1e-2, #1e-2
+            'momentum':0.9, #0.9
+        }
+    }
 
-    model = LeNet(3,10)
+    #model = LeNet(3,10)
     #model = MobileNetV2(num_classes=10)
+    model = ResNet(num_classes=10)
     #model = WideResNet(num_classes=10, wrn_size=32)
 
     #### Classic ####
@@ -83,14 +95,14 @@ if __name__ == "__main__":
         model = model.to(device)
 
         print("{} on {} for {} epochs".format(str(model), device_name, epochs))
-        log= train_classic(model=model, epochs=epochs, print_freq=1)
+        log= train_classic(model=model, opt_param=optim_param, epochs=epochs, print_freq=1)
         #log= train_classic_higher(model=model, epochs=epochs)
 
         exec_time=time.process_time() - t0
         ####
         print('-'*9)
         times = [x["time"] for x in log]
-        out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), "Device": device_name, "Log": log}
+        out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), 'Optimizer': optim_param['Inner'], "Device": device_name, "Log": log}
         print(str(model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
         filename = "{}-{} epochs".format(str(model),epochs)
         with open("res/log/%s.json" % filename, "w+") as f:
@@ -123,7 +135,7 @@ if __name__ == "__main__":
         ##log= train_classic_higher(model=model, epochs=epochs)
 
         data_train_aug = AugmentedDatasetV2("./data", train=True, download=download_data, transform=transform, subset=(0,int(len(data_train)/2)))
-        data_train_aug.augement_data(aug_copy=10)
+        data_train_aug.augement_data(aug_copy=1)
         print(data_train_aug)
         unsup_ratio = 5
         dl_unsup = torch.utils.data.DataLoader(data_train_aug, batch_size=BATCH_SIZE*unsup_ratio, shuffle=True)
@@ -135,13 +147,13 @@ if __name__ == "__main__":
         model = model.to(device)
 
         print("{} on {} for {} epochs".format(str(model), device_name, epochs))
-        log= train_UDA(model=model, dl_unsup=dl_unsup, epochs=epochs, print_freq=10)
+        log= train_UDA(model=model, dl_unsup=dl_unsup, epochs=epochs, opt_param=optim_param, print_freq=10)
 
         exec_time=time.process_time() - t0
         ####
         print('-'*9)
         times = [x["time"] for x in log]
-        out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), "Device": device_name, "Log": log}
+        out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), 'Optimizer': optim_param['Inner'], "Device": device_name, "Log": log}
         print(str(model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
         filename = "{}-{}-{} epochs".format(str(data_train_aug),str(model),epochs)
         with open("res/log/%s.json" % filename, "w+") as f:
@@ -164,13 +176,20 @@ if __name__ == "__main__":
         #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)
 
         print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter))
-        log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, KLdiv=False, loss_patience=None)
+        log= run_dist_dataugV2(model=aug_model,
+                 epochs=epochs,
+                 inner_it=n_inner_iter,
+                 dataug_epoch_start=dataug_epoch_start,
+                 opt_param=optim_param,
+                 print_freq=10,
+                 KLdiv=True,
+                 loss_patience=None)
 
         exec_time=time.process_time() - t0
         ####
         print('-'*9)
         times = [x["time"] for x in log]
-        out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
+        out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), 'Optimizer': optim_param, "Device": device_name, "Log": log}
         print(str(aug_model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
         filename = "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)
         with open("res/log/%s.json" % filename, "w+") as f:
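The new optim_param dict is threaded through every training entry point below. Note that the 'optim' fields ('Adam', 'SGD') are currently informational only: they are recorded in the JSON logs via the 'Optimizer' key, but the optimizer classes themselves remain hard-coded in train_utils.py (Adam for the meta step, SGD for the inner step), which read just 'lr' and 'momentum'. A hedged sketch of a dispatcher that would actually honor those fields (make_optimizer is a hypothetical helper, not part of this patch):

    import torch

    def make_optimizer(params, spec):
        # spec is one sub-dict of optim_param, e.g. optim_param['Inner'].
        # Hypothetical helper: selects the optimizer class from spec['optim'].
        if spec['optim'] == 'SGD':
            return torch.optim.SGD(params, lr=spec['lr'], momentum=spec.get('momentum', 0))
        elif spec['optim'] == 'Adam':
            return torch.optim.Adam(params, lr=spec['lr'])
        raise ValueError("Unknown optimizer: %s" % spec['optim'])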
diff --git a/higher/train_utils.py b/higher/train_utils.py
index c1fc880..ec3a9c5 100755
--- a/higher/train_utils.py
+++ b/higher/train_utils.py
@@ -47,10 +47,10 @@ def compute_vaLoss(model, dl_it, dl):
 
     return F.cross_entropy(model(xs), ys)
 
-def train_classic(model, epochs=1, print_freq=1):
+def train_classic(model, opt_param, epochs=1, print_freq=1):
     device = next(model.parameters()).device
     #opt = torch.optim.Adam(model.parameters(), lr=1e-3)
-    optim = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
+    optim = torch.optim.SGD(model.parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
 
     model.train()
     dl_val_it = iter(dl_val)
@@ -305,11 +305,12 @@ def train_classic_tests(model, epochs=1):
     print("Copy ", countcopy)
     return log
 
-def train_UDA(model, dl_unsup, epochs=1, print_freq=1):
+def train_UDA(model, dl_unsup, opt_param, epochs=1, print_freq=1):
 
     device = next(model.parameters()).device
     #opt = torch.optim.Adam(model.parameters(), lr=1e-3)
-    optim = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
+    optim = torch.optim.SGD(model.parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
+
     model.train()
     dl_val_it = iter(dl_val)
@@ -340,14 +341,13 @@ def train_UDA(model, dl_unsup, epochs=1, print_freq=1):
             sup_logits = model.forward(origin_xs)
             unsup_logits = model.forward(aug_xs)
 
-            #print(unsup_logits.shape, sup_logits.shape)
             log_sup=F.log_softmax(sup_logits, dim=1)
             log_unsup=F.log_softmax(unsup_logits, dim=1)
             #KL div w/ logits
             unsup_loss = F.softmax(sup_logits, dim=1)*(log_sup-log_unsup)
             unsup_loss=unsup_loss.sum(dim=-1).mean()
-            #print(unsup_loss.shape)
+            #print(unsup_loss)
 
             unsupp_coeff = 1
             loss = sup_loss + unsup_loss * unsupp_coeff
@@ -629,7 +629,7 @@ def run_dist_dataug(model, epochs=1, inner_it=1, dataug_epoch_start=0):
     print("Copy ", countcopy)
     return log
 
-def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_freq=1, KLdiv=False, loss_patience=None, save_sample=False):
+def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start=0, print_freq=1, KLdiv=False, loss_patience=None, save_sample=False):
     device = next(model.parameters()).device
     log = []
     countcopy=0
@@ -637,8 +637,8 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
     dl_val_it = iter(dl_val)
 
     #if inner_it!=0:
-    meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=1e-2) #lr=1e-2
-    inner_opt = torch.optim.SGD(model['model'].parameters(), lr=1e-2, momentum=0.9)
+    meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=opt_param['Meta']['lr']) #lr=1e-2
+    inner_opt = torch.optim.SGD(model['model'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
 
     high_grad_track = True
     if inner_it == 0:
@@ -703,7 +703,10 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
                 #aug_loss = F.kl_div(aug_logits, sup_logits, reduction='none') #Prediction similarity (distributions)
 
                 w_loss = fmodel['data_aug'].loss_weight() #Weight loss
-                aug_loss = (w_loss * aug_loss).mean()
+                aug_loss = (w_loss * aug_loss).mean() #deferred learning?
+
+                aug_loss += (F.cross_entropy(log_aug, ys, reduction='none') * w_loss).mean()
+                #print(aug_loss)
 
                 unsupp_coeff = 1
                 loss += aug_loss * unsupp_coeff
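For reference, train_UDA's consistency term computes KL(p_sup || p_aug) by hand from the logits. A self-contained sketch of that term (tensor names are illustrative; both inputs are [batch, classes]), with the equivalent built-in form:

    import torch
    import torch.nn.functional as F

    def uda_consistency_loss(sup_logits, unsup_logits):
        # KL(softmax(sup) || softmax(unsup)), from logits, as in train_UDA:
        log_sup = F.log_softmax(sup_logits, dim=1)
        log_unsup = F.log_softmax(unsup_logits, dim=1)
        kl = F.softmax(sup_logits, dim=1) * (log_sup - log_unsup)
        return kl.sum(dim=-1).mean()  # sum over classes, average over batch

    # Equivalent via the built-in (per-sample KL averaged over the batch):
    # F.kl_div(F.log_softmax(unsup_logits, dim=1),
    #          F.softmax(sup_logits, dim=1), reduction='batchmean')

One caveat worth flagging: the UDA recipe applies a stop-gradient to the predictions on the unaugmented examples, whereas this code lets gradients flow through both branches, so a sup_logits.detach() may be intended.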
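run_dist_dataugV2 now takes both optimizers' hyper-parameters from opt_param: Adam on the augmentation parameters (the meta/outer step) and SGD with momentum on the network weights (the inner step). Below is a minimal bilevel sketch of that split using the higher library this code is built on; it is a simplification under assumptions, not the actual function (which constructs its differentiable optimizer and loop differently): dl_train is an assumed training loader, and a single meta step is shown.

    import higher
    import torch
    import torch.nn.functional as F

    meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=opt_param['Meta']['lr'])
    inner_opt = torch.optim.SGD(model['model'].parameters(),
                                lr=opt_param['Inner']['lr'],
                                momentum=opt_param['Inner']['momentum'])

    # copy_initial_weights=False keeps fmodel's starting weights tied to the real
    # parameters, so val_loss.backward() can reach model['data_aug'].parameters().
    with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=False) as (fmodel, diffopt):
        it = iter(dl_train)                       # dl_train: assumed training loader
        for _ in range(n_inner_iter):             # a few differentiable inner steps
            xs, ys = next(it)
            loss = F.cross_entropy(fmodel(xs), ys)
            diffopt.step(loss)                    # functional SGD update, graph kept
        val_loss = compute_vaLoss(fmodel, dl_val_it, dl_val)
        val_loss.backward()                       # gradients flow to augmentation params
    meta_opt.step()                               # outer (meta) update via Adam
    meta_opt.zero_grad()

The design point this illustrates: the inner optimizer only ever touches model['model'] weights inside the differentiable context, while the meta optimizer updates model['data_aug'] from the validation loss of the unrolled inner steps.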