Test KL divergence from UDA

This commit is contained in:
Harle, Antoine (Contracteur) 2019-12-06 10:44:18 -05:00
parent fa5bc72616
commit 217f94ef89
5 changed files with 52 additions and 28 deletions

View file

@ -50,6 +50,7 @@ class AugmentedDataset(VisionDataset):
for idx, img in enumerate(self.sup_data):
self.sup_data[idx]= Image.fromarray(img) #to PIL Image
self.unsup_ratio=5 #Batch size unsup = train batch size * unsup_ratio
self.unsup_data=[]
self.unsup_targets=[]

View file

@ -43,7 +43,8 @@ class LeNet(nn.Module):
#print("Shape ", out.shape)
out = F.linear(out, self._params["w4"], self._params["b4"])
#print("Shape ", out.shape)
return F.log_softmax(out, dim=1)
#return F.log_softmax(out, dim=1)
return out
# Subscript access on the model: model[key] returns the entry stored under
# `key` in self._params (the same container the forward pass reads
# "w4"/"b4" from). A missing key propagates whatever self._params raises.
def __getitem__(self, key):
return self._params[key]

View file

@ -35,7 +35,7 @@ if __name__ == "__main__":
n_inner_iter = 1
epochs = 200
epochs = 100
dataug_epoch_start=0
tf_dict = {k: TF.TF_dict[k] for k in tf_names}

View file

@ -70,7 +70,7 @@ if __name__ == "__main__":
'aug_model'
}
n_inner_iter = 1
epochs = 200
epochs = 100
dataug_epoch_start=0
@ -149,12 +149,12 @@ if __name__ == "__main__":
tf_dict = {k: TF.TF_dict[k] for k in tf_names}
#aug_model = Augmented_model(Data_augV6(TF_dict=tf_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, prob_set_size=2, fixed_mag=True, shared_mag=True), LeNet(3,10)).to(device)
#aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), LeNet(3,10)).to(device)
aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), WideResNet(num_classes=10, wrn_size=32)).to(device)
aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), LeNet(3,10)).to(device)
#aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), WideResNet(num_classes=10, wrn_size=32)).to(device)
#aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), WideResNet(num_classes=10, wrn_size=32)).to(device)
print(str(aug_model), 'on', device_name)
#run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, loss_patience=None)
log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=1, KLdiv=True, loss_patience=None)
exec_time=time.process_time() - t0
####
@ -162,7 +162,7 @@ if __name__ == "__main__":
times = [x["time"] for x in log]
out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
print(str(aug_model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
filename = "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)
filename = "{}-{} epochs (dataug:{})- {} in_it (KLdiv)".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)
with open("res/log/%s.json" % filename, "w+") as f:
json.dump(out, f, indent=True)
print('Log :\"',f.name, '\" saved !')

View file

@ -542,7 +542,7 @@ def run_dist_dataug(model, epochs=1, inner_it=1, dataug_epoch_start=0):
print("Copy ", countcopy)
return log
def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_freq=1, loss_patience=None):
def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_freq=1, KLdiv=False, loss_patience=None):
device = next(model.parameters()).device
log = []
countcopy=0
@ -578,22 +578,20 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
for i, (xs, ys) in enumerate(dl_train):
xs, ys = xs.to(device), ys.to(device)
'''
#Methode exacte
final_loss = 0
for tf_idx in range(fmodel['data_aug']._nb_tf):
fmodel['data_aug'].transf_idx=tf_idx
logits = fmodel(xs)
loss = F.cross_entropy(logits, ys)
#loss.backward(retain_graph=True)
#print('idx', tf_idx)
#print(fmodel['data_aug']['prob'][tf_idx], fmodel['data_aug']['prob'][tf_idx].grad)
final_loss += loss*fmodel['data_aug']['prob'][tf_idx] #Take it in the forward function ?
#final_loss = 0
#for tf_idx in range(fmodel['data_aug']._nb_tf):
# fmodel['data_aug'].transf_idx=tf_idx
# logits = fmodel(xs)
# loss = F.cross_entropy(logits, ys)
# #loss.backward(retain_graph=True)
# final_loss += loss*fmodel['data_aug']['prob'][tf_idx] #Take it in the forward function ?
#loss = final_loss
loss = final_loss
'''
#KLdiv=False
if(not KLdiv):
#Uniform method: one forward pass, per-sample cross-entropy scaled by the data-aug loss weights, then averaged
logits = fmodel(xs) # modified `params` can also be passed as a kwarg
loss = F.cross_entropy(logits, ys, reduction='none') # no need to call loss.backwards()
@ -601,7 +599,30 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
w_loss = fmodel['data_aug'].loss_weight()#.to(device)
loss = loss * w_loss
loss = loss.mean()
#'''
else:
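#KL-divergence method: cross-entropy on the augment(mode=False) pass, plus a weighted divergence term between the augment(mode=False) and augment(mode=True) predictions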
fmodel.augment(mode=False)
sup_logits = fmodel(xs)
log_sup=F.log_softmax(sup_logits, dim=1)
fmodel.augment(mode=True)
loss = F.cross_entropy(log_sup, ys)
if fmodel._data_augmentation:
aug_logits = fmodel(xs)
log_aug=F.log_softmax(aug_logits, dim=1)
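#KL div w/ logits
#NOTE(review): KL(p_sup || p_aug) = sum(p_sup * (log_sup - log_aug)) with p_sup = log_sup.exp(); the line below weights by the raw sup_logits instead, which can be negative and do not sum to 1 -- confirm this is intended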
aug_loss = sup_logits*(log_sup-log_aug)
aug_loss=aug_loss.sum(dim=-1)
#aug_loss = F.kl_div(aug_logits, sup_logits, reduction='none') #Similarite predictions (distributions)
w_loss = fmodel['data_aug'].loss_weight()#.unsqueeze(dim=1).expand(-1,10) #Weight loss
aug_loss = (w_loss * aug_loss).mean()
unsupp_coeff = 1
loss += aug_loss * unsupp_coeff
print('TF Proba :', model['data_aug']['prob'].data)
#to visualize computational graph
#print_graph(loss)
@ -664,6 +685,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
print('TF Mag :', model['data_aug']['mag'].data)
#print('Mag grad',model['data_aug']['mag'].grad)
#print('Reg loss:', model['data_aug'].reg_loss().item())
print('Aug loss', aug_loss.item())
#############
#### Log ####
#print(type(model['data_aug']) is dataug.Data_augV5)