mirror of https://github.com/AntoineHX/smart_augmentation.git (synced 2025-05-04 04:00:46 +02:00)

Test KL divergence from UDA

commit 217f94ef89 (parent fa5bc72616)
5 changed files with 52 additions and 28 deletions
@@ -50,6 +50,7 @@ class AugmentedDataset(VisionDataset):
         for idx, img in enumerate(self.sup_data):
             self.sup_data[idx]= Image.fromarray(img) #to PIL Image
 
+        self.unsup_ratio=5 #Batch size unsup = train batch size * unsup_ratio
         self.unsup_data=[]
         self.unsup_targets=[]
 
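For context, UDA-style training draws a larger unlabeled batch alongside each labeled batch, which is what the `unsup_ratio` comment above encodes. A minimal sketch of that batching convention, using hypothetical dataset and loader names (none of them from this repo):

import torch
from torch.utils.data import DataLoader, TensorDataset

sup_batch_size = 32
unsup_ratio = 5  # batch size unsup = train batch size * unsup_ratio

# Dummy labeled and unlabeled data, for illustration only
sup_ds = TensorDataset(torch.randn(320, 3, 32, 32), torch.randint(0, 10, (320,)))
unsup_ds = TensorDataset(torch.randn(1600, 3, 32, 32))

sup_loader = DataLoader(sup_ds, batch_size=sup_batch_size, shuffle=True)
unsup_loader = DataLoader(unsup_ds, batch_size=sup_batch_size * unsup_ratio, shuffle=True)

for (xs, ys), (xu,) in zip(sup_loader, unsup_loader):
    pass  # supervised loss on (xs, ys); consistency loss on the larger batch xu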
@@ -43,7 +43,8 @@ class LeNet(nn.Module):
         #print("Shape ", out.shape)
         out = F.linear(out, self._params["w4"], self._params["b4"])
         #print("Shape ", out.shape)
-        return F.log_softmax(out, dim=1)
+        #return F.log_softmax(out, dim=1)
+        return out
 
     def __getitem__(self, key):
         return self._params[key]
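Returning raw logits instead of log-probabilities matters for the losses used later: `F.cross_entropy` applies `log_softmax` internally, so it expects unnormalized outputs, while the new KL-divergence branch calls `F.log_softmax` on the logits itself. A small sketch of the equivalence (tensor names are illustrative):

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)          # raw network outputs, as LeNet now returns
ys = torch.randint(0, 10, (4,))

# F.cross_entropy = log_softmax followed by nll_loss, so raw logits are expected:
loss_a = F.cross_entropy(logits, ys)
loss_b = F.nll_loss(F.log_softmax(logits, dim=1), ys)
assert torch.allclose(loss_a, loss_b)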
@@ -35,7 +35,7 @@ if __name__ == "__main__":
 
 
     n_inner_iter = 1
-    epochs = 200
+    epochs = 100
    dataug_epoch_start=0
 
    tf_dict = {k: TF.TF_dict[k] for k in tf_names}
@@ -70,7 +70,7 @@ if __name__ == "__main__":
            'aug_model'
            }
    n_inner_iter = 1
-    epochs = 200
+    epochs = 100
    dataug_epoch_start=0
 
 
@@ -149,12 +149,12 @@ if __name__ == "__main__":
    tf_dict = {k: TF.TF_dict[k] for k in tf_names}
 
    #aug_model = Augmented_model(Data_augV6(TF_dict=tf_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, prob_set_size=2, fixed_mag=True, shared_mag=True), LeNet(3,10)).to(device)
-    #aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), LeNet(3,10)).to(device)
-    aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), WideResNet(num_classes=10, wrn_size=32)).to(device)
+    aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), LeNet(3,10)).to(device)
+    #aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), WideResNet(num_classes=10, wrn_size=32)).to(device)
    #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), WideResNet(num_classes=10, wrn_size=32)).to(device)
    print(str(aug_model), 'on', device_name)
    #run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
-    log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, loss_patience=None)
+    log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=1, KLdiv=True, loss_patience=None)
 
    exec_time=time.process_time() - t0
    ####
@@ -162,7 +162,7 @@ if __name__ == "__main__":
    times = [x["time"] for x in log]
    out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
    print(str(aug_model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
-    filename = "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)
+    filename = "{}-{} epochs (dataug:{})- {} in_it (KLdiv)".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)
    with open("res/log/%s.json" % filename, "w+") as f:
        json.dump(out, f, indent=True)
        print('Log :\"',f.name, '\" saved !')
@@ -542,7 +542,7 @@ def run_dist_dataug(model, epochs=1, inner_it=1, dataug_epoch_start=0):
    print("Copy ", countcopy)
    return log
 
-def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_freq=1, loss_patience=None):
+def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_freq=1, KLdiv=False, loss_patience=None):
    device = next(model.parameters()).device
    log = []
    countcopy=0
@@ -578,30 +578,51 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
 
        for i, (xs, ys) in enumerate(dl_train):
            xs, ys = xs.to(device), ys.to(device)
            '''
 
            #Methode exacte
-            final_loss = 0
-            for tf_idx in range(fmodel['data_aug']._nb_tf):
-                fmodel['data_aug'].transf_idx=tf_idx
-                logits = fmodel(xs)
-                loss = F.cross_entropy(logits, ys)
-                #loss.backward(retain_graph=True)
-                #print('idx', tf_idx)
-                #print(fmodel['data_aug']['prob'][tf_idx], fmodel['data_aug']['prob'][tf_idx].grad)
-                final_loss += loss*fmodel['data_aug']['prob'][tf_idx] #Take it in the forward function ?
+            #final_loss = 0
+            #for tf_idx in range(fmodel['data_aug']._nb_tf):
+            #    fmodel['data_aug'].transf_idx=tf_idx
+            #    logits = fmodel(xs)
+            #    loss = F.cross_entropy(logits, ys)
+            #    #loss.backward(retain_graph=True)
+            #    final_loss += loss*fmodel['data_aug']['prob'][tf_idx] #Take it in the forward function ?
+            #loss = final_loss
 
            loss = final_loss
            '''
+            #KLdiv=False
+            if(not KLdiv):
            #Methode uniforme
 
-            logits = fmodel(xs) # modified `params` can also be passed as a kwarg
-            loss = F.cross_entropy(logits, ys, reduction='none') # no need to call loss.backwards()
+                logits = fmodel(xs) # modified `params` can also be passed as a kwarg
+                loss = F.cross_entropy(logits, ys, reduction='none') # no need to call loss.backwards()
 
+                if fmodel._data_augmentation: #Weight loss
+                    w_loss = fmodel['data_aug'].loss_weight()#.to(device)
+                    loss = loss * w_loss
+                loss = loss.mean()
+            else:
+                #Methode KL div
+                fmodel.augment(mode=False)
+                sup_logits = fmodel(xs)
+                log_sup=F.log_softmax(sup_logits, dim=1)
+                fmodel.augment(mode=True)
+                loss = F.cross_entropy(log_sup, ys)
+                if fmodel._data_augmentation:
+                    aug_logits = fmodel(xs)
+                    log_aug=F.log_softmax(aug_logits, dim=1)
+                    #KL div w/ logits
+                    aug_loss = sup_logits*(log_sup-log_aug)
+                    aug_loss=aug_loss.sum(dim=-1)
+                    #aug_loss = F.kl_div(aug_logits, sup_logits, reduction='none') #Similarite predictions (distributions)
+                    w_loss = fmodel['data_aug'].loss_weight()#.unsqueeze(dim=1).expand(-1,10) #Weight loss
+                    aug_loss = (w_loss * aug_loss).mean()
+                    unsupp_coeff = 1
+                    loss += aug_loss * unsupp_coeff
+            print('TF Proba :', model['data_aug']['prob'].data)
-            if fmodel._data_augmentation: #Weight loss
-                w_loss = fmodel['data_aug'].loss_weight()#.to(device)
-                loss = loss * w_loss
-            loss = loss.mean()
+            #'''
 
            #to visualize computational graph
            #print_graph(loss)
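The consistency term above follows UDA: penalize the KL divergence between the model's predictions on a clean input and on its augmented version, KL(p_sup ‖ p_aug) = Σ_c p_sup(c) · (log p_sup(c) − log p_aug(c)). Note that the committed line weights by the raw `sup_logits` rather than by softmax probabilities, which deviates from that formula. A minimal sketch of the textbook computation, with illustrative tensor names not taken from this repo:

import torch
import torch.nn.functional as F

sup_logits = torch.randn(8, 10)  # predictions on clean inputs
aug_logits = torch.randn(8, 10)  # predictions on augmented inputs

p_sup = F.softmax(sup_logits, dim=1)
log_sup = F.log_softmax(sup_logits, dim=1)
log_aug = F.log_softmax(aug_logits, dim=1)

# KL(p_sup || p_aug), one value per sample
kl = (p_sup * (log_sup - log_aug)).sum(dim=-1)

# Equivalent via F.kl_div (expects log-probs as input, probs as target)
kl_check = F.kl_div(log_aug, p_sup, reduction='none').sum(dim=-1)
assert torch.allclose(kl, kl_check, atol=1e-6)

In the commit, `w_loss` then weights each sample's consistency term by the data-augmentation module's transformation probabilities, and `unsupp_coeff` plays the role of UDA's unsupervised loss weight (fixed to 1 here).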
@@ -664,6 +685,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
                print('TF Mag :', model['data_aug']['mag'].data)
                #print('Mag grad',model['data_aug']['mag'].grad)
                #print('Reg loss:', model['data_aug'].reg_loss().item())
+                print('Aug loss', aug_loss.item())
            #############
            #### Log ####
            #print(type(model['data_aug']) is dataug.Data_augV5)