From 217f94ef8971c7a449a7e102de2a153b4d8d31b7 Mon Sep 17 00:00:00 2001
From: "Harle, Antoine (Contracteur)"
Date: Fri, 6 Dec 2019 10:44:18 -0500
Subject: [PATCH] Test KL divergence from UDA

---
 higher/datasets.py    |  1 +
 higher/model.py       |  3 +-
 higher/test_brutus.py |  2 +-
 higher/test_dataug.py | 10 +++----
 higher/train_utils.py | 64 +++++++++++++++++++++++++++++--------------
 5 files changed, 52 insertions(+), 28 deletions(-)

diff --git a/higher/datasets.py b/higher/datasets.py
index 6be5e9e..88ec045 100755
--- a/higher/datasets.py
+++ b/higher/datasets.py
@@ -50,6 +50,7 @@ class AugmentedDataset(VisionDataset):
         for idx, img in enumerate(self.sup_data):
             self.sup_data[idx]= Image.fromarray(img) #to PIL Image
 
+        self.unsup_ratio=5 #Batch size unsup = train batch size * unsup_ratio
         self.unsup_data=[]
         self.unsup_targets=[]
 
diff --git a/higher/model.py b/higher/model.py
index ba8064f..3b466fe 100755
--- a/higher/model.py
+++ b/higher/model.py
@@ -43,7 +43,8 @@ class LeNet(nn.Module):
         #print("Shape ", out.shape)
         out = F.linear(out, self._params["w4"], self._params["b4"])
         #print("Shape ", out.shape)
-        return F.log_softmax(out, dim=1)
+        #return F.log_softmax(out, dim=1)
+        return out
 
     def __getitem__(self, key):
         return self._params[key]
diff --git a/higher/test_brutus.py b/higher/test_brutus.py
index 6935046..8740184 100755
--- a/higher/test_brutus.py
+++ b/higher/test_brutus.py
@@ -35,7 +35,7 @@ if __name__ == "__main__":
 
     n_inner_iter = 1
-    epochs = 200
+    epochs = 100
     dataug_epoch_start=0
 
     tf_dict = {k: TF.TF_dict[k] for k in tf_names}
diff --git a/higher/test_dataug.py b/higher/test_dataug.py
index 40a94da..6679c39 100755
--- a/higher/test_dataug.py
+++ b/higher/test_dataug.py
@@ -70,7 +70,7 @@ if __name__ == "__main__":
     'aug_model'
     }
     n_inner_iter = 1
-    epochs = 200
+    epochs = 100
 
     dataug_epoch_start=0
 
@@ -149,12 +149,12 @@ if __name__ == "__main__":
 
         tf_dict = {k: TF.TF_dict[k] for k in tf_names}
         #aug_model = Augmented_model(Data_augV6(TF_dict=tf_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, prob_set_size=2, fixed_mag=True, shared_mag=True), LeNet(3,10)).to(device)
-        #aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), LeNet(3,10)).to(device)
-        aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), WideResNet(num_classes=10, wrn_size=32)).to(device)
+        aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), LeNet(3,10)).to(device)
+        #aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), WideResNet(num_classes=10, wrn_size=32)).to(device)
        #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), WideResNet(num_classes=10, wrn_size=32)).to(device)
         print(str(aug_model), 'on', device_name)
         #run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
-        log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, loss_patience=None)
+        log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=1, KLdiv=True, loss_patience=None)
 
         exec_time=time.process_time() - t0
         ####
@@ -162,7 +162,7 @@ if __name__ == "__main__":
         times = [x["time"] for x in log]
"Param_names": aug_model.TF_names(), "Log": log} print(str(aug_model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1]) - filename = "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter) + filename = "{}-{} epochs (dataug:{})- {} in_it (KLdiv)".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter) with open("res/log/%s.json" % filename, "w+") as f: json.dump(out, f, indent=True) print('Log :\"',f.name, '\" saved !') diff --git a/higher/train_utils.py b/higher/train_utils.py index b11fb7b..75335bb 100755 --- a/higher/train_utils.py +++ b/higher/train_utils.py @@ -542,7 +542,7 @@ def run_dist_dataug(model, epochs=1, inner_it=1, dataug_epoch_start=0): print("Copy ", countcopy) return log -def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_freq=1, loss_patience=None): +def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_freq=1, KLdiv=False, loss_patience=None): device = next(model.parameters()).device log = [] countcopy=0 @@ -578,30 +578,51 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f for i, (xs, ys) in enumerate(dl_train): xs, ys = xs.to(device), ys.to(device) - ''' + #Methode exacte - final_loss = 0 - for tf_idx in range(fmodel['data_aug']._nb_tf): - fmodel['data_aug'].transf_idx=tf_idx - logits = fmodel(xs) - loss = F.cross_entropy(logits, ys) - #loss.backward(retain_graph=True) - #print('idx', tf_idx) - #print(fmodel['data_aug']['prob'][tf_idx], fmodel['data_aug']['prob'][tf_idx].grad) - final_loss += loss*fmodel['data_aug']['prob'][tf_idx] #Take it in the forward function ? + #final_loss = 0 + #for tf_idx in range(fmodel['data_aug']._nb_tf): + # fmodel['data_aug'].transf_idx=tf_idx + # logits = fmodel(xs) + # loss = F.cross_entropy(logits, ys) + # #loss.backward(retain_graph=True) + # final_loss += loss*fmodel['data_aug']['prob'][tf_idx] #Take it in the forward function ? 
+            #loss = final_loss
 
-            loss = final_loss
-            '''
+            #KLdiv=False
+            if not KLdiv:
             #Uniform method
-
-            logits = fmodel(xs) # modified `params` can also be passed as a kwarg
-            loss = F.cross_entropy(logits, ys, reduction='none') # no need to call loss.backwards()
+                logits = fmodel(xs) # modified `params` can also be passed as a kwarg
+                loss = F.cross_entropy(logits, ys, reduction='none') # no need to call loss.backwards()
+
+                if fmodel._data_augmentation: #Weight loss
+                    w_loss = fmodel['data_aug'].loss_weight()#.to(device)
+                    loss = loss * w_loss
+                loss = loss.mean()
+
+            else:
+                #KL-divergence method (UDA-style)
+                fmodel.augment(mode=False)
+                sup_logits = fmodel(xs)
+                log_sup=F.log_softmax(sup_logits, dim=1)
+                fmodel.augment(mode=True)
+                loss = F.cross_entropy(sup_logits, ys) #cross_entropy expects raw logits: it applies log-softmax itself
+
+                if fmodel._data_augmentation:
+                    aug_logits = fmodel(xs)
+                    log_aug=F.log_softmax(aug_logits, dim=1)
+                    #KL(sup||aug) uses probabilities, not raw logits: sum p_sup*(log p_sup - log p_aug)
+                    aug_loss = F.softmax(sup_logits, dim=1)*(log_sup-log_aug)
+                    aug_loss=aug_loss.sum(dim=-1)
+                    #aug_loss = F.kl_div(log_aug, F.softmax(sup_logits, dim=1), reduction='none') #Similarity of predictions (distributions); F.kl_div takes log-probs as input, probs as target
+
+                    w_loss = fmodel['data_aug'].loss_weight()#.unsqueeze(dim=1).expand(-1,10) #Weight loss
+                    aug_loss = (w_loss * aug_loss).mean()
+                    unsup_coeff = 1
+                    loss += aug_loss * unsup_coeff
+
+            print('TF Proba :', model['data_aug']['prob'].data)
 
-            if fmodel._data_augmentation: #Weight loss
-                w_loss = fmodel['data_aug'].loss_weight()#.to(device)
-                loss = loss * w_loss
-            loss = loss.mean()
-            #'''
 
 
             #to visualize computational graph
             #print_graph(loss)
@@ -664,6 +685,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
             print('TF Mag :', model['data_aug']['mag'].data)
             #print('Mag grad',model['data_aug']['mag'].grad)
             #print('Reg loss:', model['data_aug'].reg_loss().item())
+            if KLdiv and fmodel._data_augmentation: print('Aug loss', aug_loss.item()) #aug_loss is only defined on KL-div iterations with augmentation on
         #############
         #### Log ####
         #print(type(model['data_aug']) is dataug.Data_augV5)
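
Note on the KL-divergence term tested above: as a self-contained reference, the UDA-style consistency loss can be sketched as below. This is a minimal illustration under assumed names (model, augment, unsup_coeff are placeholders, not identifiers from this patch), and it detaches the clean-branch prediction as UDA does, which the hunk above leaves attached.

    import torch
    import torch.nn.functional as F

    def uda_style_loss(model, augment, xs, ys, unsup_coeff=1.0):
        # Supervised term: F.cross_entropy applies log-softmax internally,
        # so it must be fed raw logits, not log-probabilities.
        sup_logits = model(xs)
        loss = F.cross_entropy(sup_logits, ys)

        # Consistency term: KL(p_sup || p_aug) between predictions on clean
        # and augmented views of the same batch. UDA treats the clean
        # prediction as a fixed target, hence the detach().
        aug_logits = model(augment(xs))
        p_sup   = F.softmax(sup_logits.detach(), dim=1)
        log_sup = F.log_softmax(sup_logits.detach(), dim=1)
        log_aug = F.log_softmax(aug_logits, dim=1)
        # KL(p || q) = sum_c p * (log p - log q), computed per sample.
        kl = (p_sup * (log_sup - log_aug)).sum(dim=-1).mean()

        return loss + unsup_coeff * kl

The same consistency term can be written as F.kl_div(log_aug, p_sup, reduction='batchmean'), since F.kl_div takes log-probabilities as input and probabilities as target.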