Test KL-divergence consistency loss from UDA (Unsupervised Data Augmentation)

This commit is contained in:
Harle, Antoine (Contracteur) 2019-12-06 10:44:18 -05:00
parent fa5bc72616
commit 217f94ef89
5 changed files with 52 additions and 28 deletions

View file

@ -50,6 +50,7 @@ class AugmentedDataset(VisionDataset):
for idx, img in enumerate(self.sup_data): for idx, img in enumerate(self.sup_data):
self.sup_data[idx]= Image.fromarray(img) #to PIL Image self.sup_data[idx]= Image.fromarray(img) #to PIL Image
self.unsup_ratio=5 #Batch size unsup = train batch size * unsup_ratio
self.unsup_data=[] self.unsup_data=[]
self.unsup_targets=[] self.unsup_targets=[]

View file

@ -43,7 +43,8 @@ class LeNet(nn.Module):
#print("Shape ", out.shape) #print("Shape ", out.shape)
out = F.linear(out, self._params["w4"], self._params["b4"]) out = F.linear(out, self._params["w4"], self._params["b4"])
#print("Shape ", out.shape) #print("Shape ", out.shape)
return F.log_softmax(out, dim=1) #return F.log_softmax(out, dim=1)
return out
def __getitem__(self, key): def __getitem__(self, key):
return self._params[key] return self._params[key]

View file

@ -35,7 +35,7 @@ if __name__ == "__main__":
n_inner_iter = 1 n_inner_iter = 1
epochs = 200 epochs = 100
dataug_epoch_start=0 dataug_epoch_start=0
tf_dict = {k: TF.TF_dict[k] for k in tf_names} tf_dict = {k: TF.TF_dict[k] for k in tf_names}

View file

@ -70,7 +70,7 @@ if __name__ == "__main__":
'aug_model' 'aug_model'
} }
n_inner_iter = 1 n_inner_iter = 1
epochs = 200 epochs = 100
dataug_epoch_start=0 dataug_epoch_start=0
@ -149,12 +149,12 @@ if __name__ == "__main__":
tf_dict = {k: TF.TF_dict[k] for k in tf_names} tf_dict = {k: TF.TF_dict[k] for k in tf_names}
#aug_model = Augmented_model(Data_augV6(TF_dict=tf_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, prob_set_size=2, fixed_mag=True, shared_mag=True), LeNet(3,10)).to(device) #aug_model = Augmented_model(Data_augV6(TF_dict=tf_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, prob_set_size=2, fixed_mag=True, shared_mag=True), LeNet(3,10)).to(device)
#aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), LeNet(3,10)).to(device) aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), LeNet(3,10)).to(device)
aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), WideResNet(num_classes=10, wrn_size=32)).to(device) #aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), WideResNet(num_classes=10, wrn_size=32)).to(device)
#aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), WideResNet(num_classes=10, wrn_size=32)).to(device) #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), WideResNet(num_classes=10, wrn_size=32)).to(device)
print(str(aug_model), 'on', device_name) print(str(aug_model), 'on', device_name)
#run_simple_dataug(inner_it=n_inner_iter, epochs=epochs) #run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, loss_patience=None) log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=1, KLdiv=True, loss_patience=None)
exec_time=time.process_time() - t0 exec_time=time.process_time() - t0
#### ####
@ -162,7 +162,7 @@ if __name__ == "__main__":
times = [x["time"] for x in log] times = [x["time"] for x in log]
out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log} out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
print(str(aug_model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1]) print(str(aug_model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
filename = "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter) filename = "{}-{} epochs (dataug:{})- {} in_it (KLdiv)".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)
with open("res/log/%s.json" % filename, "w+") as f: with open("res/log/%s.json" % filename, "w+") as f:
json.dump(out, f, indent=True) json.dump(out, f, indent=True)
print('Log :\"',f.name, '\" saved !') print('Log :\"',f.name, '\" saved !')

View file

@ -542,7 +542,7 @@ def run_dist_dataug(model, epochs=1, inner_it=1, dataug_epoch_start=0):
print("Copy ", countcopy) print("Copy ", countcopy)
return log return log
def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_freq=1, loss_patience=None): def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_freq=1, KLdiv=False, loss_patience=None):
device = next(model.parameters()).device device = next(model.parameters()).device
log = [] log = []
countcopy=0 countcopy=0
@ -578,30 +578,51 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
for i, (xs, ys) in enumerate(dl_train): for i, (xs, ys) in enumerate(dl_train):
xs, ys = xs.to(device), ys.to(device) xs, ys = xs.to(device), ys.to(device)
'''
#Exact method #Exact method
final_loss = 0 #final_loss = 0
for tf_idx in range(fmodel['data_aug']._nb_tf): #for tf_idx in range(fmodel['data_aug']._nb_tf):
fmodel['data_aug'].transf_idx=tf_idx # fmodel['data_aug'].transf_idx=tf_idx
logits = fmodel(xs) # logits = fmodel(xs)
loss = F.cross_entropy(logits, ys) # loss = F.cross_entropy(logits, ys)
#loss.backward(retain_graph=True) # #loss.backward(retain_graph=True)
#print('idx', tf_idx) # final_loss += loss*fmodel['data_aug']['prob'][tf_idx] #Take it in the forward function ?
#print(fmodel['data_aug']['prob'][tf_idx], fmodel['data_aug']['prob'][tf_idx].grad) #loss = final_loss
final_loss += loss*fmodel['data_aug']['prob'][tf_idx] #Take it in the forward function ?
loss = final_loss #KLdiv=False
''' if(not KLdiv):
#Uniform method
logits = fmodel(xs) # modified `params` can also be passed as a kwarg
loss = F.cross_entropy(logits, ys, reduction='none') # no need to call loss.backwards()
logits = fmodel(xs) # modified `params` can also be passed as a kwarg if fmodel._data_augmentation: #Weight loss
loss = F.cross_entropy(logits, ys, reduction='none') # no need to call loss.backwards() w_loss = fmodel['data_aug'].loss_weight()#.to(device)
loss = loss * w_loss
loss = loss.mean()
else:
#KL divergence method
fmodel.augment(mode=False)
sup_logits = fmodel(xs)
log_sup=F.log_softmax(sup_logits, dim=1)
fmodel.augment(mode=True)
loss = F.cross_entropy(log_sup, ys)
if fmodel._data_augmentation:
aug_logits = fmodel(xs)
log_aug=F.log_softmax(aug_logits, dim=1)
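#KL div w/ logits -- NOTE(review): the term below is weighted by raw sup_logits, not by probabilities; a true KL(sup||aug) would weight by log_sup.exp() (i.e. softmax(sup_logits)) -- confirm intent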
aug_loss = sup_logits*(log_sup-log_aug)
aug_loss=aug_loss.sum(dim=-1)
#aug_loss = F.kl_div(aug_logits, sup_logits, reduction='none') #Similarite predictions (distributions)
w_loss = fmodel['data_aug'].loss_weight()#.unsqueeze(dim=1).expand(-1,10) #Weight loss
aug_loss = (w_loss * aug_loss).mean()
unsupp_coeff = 1
loss += aug_loss * unsupp_coeff
print('TF Proba :', model['data_aug']['prob'].data)
if fmodel._data_augmentation: #Weight loss
w_loss = fmodel['data_aug'].loss_weight()#.to(device)
loss = loss * w_loss
loss = loss.mean()
#'''
#to visualize computational graph #to visualize computational graph
#print_graph(loss) #print_graph(loss)
@ -664,6 +685,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
print('TF Mag :', model['data_aug']['mag'].data) print('TF Mag :', model['data_aug']['mag'].data)
#print('Mag grad',model['data_aug']['mag'].grad) #print('Mag grad',model['data_aug']['mag'].grad)
#print('Reg loss:', model['data_aug'].reg_loss().item()) #print('Reg loss:', model['data_aug'].reg_loss().item())
print('Aug loss', aug_loss.item())
############# #############
#### Log #### #### Log ####
#print(type(model['data_aug']) is dataug.Data_augV5) #print(type(model['data_aug']) is dataug.Data_augV5)