Mirror of https://github.com/AntoineHX/smart_augmentation.git

Commit 64282bda3a (parent f4bdd9bca5): Mag bound + Mag regularization
10 changed files with 43 additions and 228 deletions
@@ -38,8 +38,8 @@ data_test = torchvision.datasets.CIFAR10(
     "./data", train=False, download=True, transform=transform
 )
 #'''
-#train_subset_indices=range(int(len(data_train)/2))
-train_subset_indices=range(BATCH_SIZE*10)
+train_subset_indices=range(int(len(data_train)/2))
+#train_subset_indices=range(BATCH_SIZE*10)
 val_subset_indices=range(int(len(data_train)/2),len(data_train))

 dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices))
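The effect of this change is to train on the first half of CIFAR-10 and validate on the second half, instead of a tiny debug subset of BATCH_SIZE*10 images. A minimal, self-contained sketch of the split (the BATCH_SIZE value here is hypothetical; the real constant is defined earlier in the file):

    import torch
    import torchvision
    import torchvision.transforms as T
    from torch.utils.data import DataLoader, SubsetRandomSampler

    BATCH_SIZE = 300  # hypothetical value for illustration

    transform = T.ToTensor()
    data_train = torchvision.datasets.CIFAR10("./data", train=True, download=True, transform=transform)

    # First half of the train set for (inner) training, second half held out for validation.
    train_subset_indices = range(int(len(data_train) / 2))
    val_subset_indices = range(int(len(data_train) / 2), len(data_train))

    # shuffle=False is required: the sampler already draws the subset indices in random order.
    dl_train = DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False,
                          sampler=SubsetRandomSampler(train_subset_indices))
    dl_val = DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False,
                        sampler=SubsetRandomSampler(val_subset_indices))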
@@ -114,7 +114,7 @@ class Data_augV2(nn.Module): #Exact method
         return kornia.warp_affine(x, M, dsize=(x.shape[2], x.shape[3])) #dsize=(h, w)


-    def adjust_prob(self): #Detach from gradient ?
+    def adjust_param(self): #Detach from gradient ?
         self._params['prob'].data = self._params['prob'].clamp(min=0.0,max=1.0)
         #print('proba',self._params['prob'])
         self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Constraint: sum(p)=1
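What the renamed method enforces for prob can be reproduced in isolation; a minimal sketch of the clamp-then-renormalize step:

    import torch

    prob = torch.tensor([0.5, -0.1, 0.8])  # raw parameter after a gradient step; may leave [0, 1]
    prob = prob.clamp(min=0.0, max=1.0)    # tensor([0.5000, 0.0000, 0.8000])
    prob = prob / prob.sum()               # constraint sum(p)=1 -> tensor([0.3846, 0.0000, 0.6154])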
@@ -262,7 +262,7 @@ class Data_augV3(nn.Module): #Uniform/Mixed sampling
         # warp the original image by the found transform
         return kornia.warp_perspective(x, M, dsize=(h, w))

-    def adjust_prob(self, soft=False): #Detach from gradient ?
+    def adjust_param(self, soft=False): #Detach from gradient ?

         if soft :
             self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Too 'soft': gets stuck in a uniform dist if the lr is too low
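The "too soft" caveat on the softmax branch is easy to check numerically: softmax applied to values that already lie in [0, 1] compresses their differences, so with a small learning rate the distribution keeps collapsing back toward uniform. A small illustration (not repo code):

    import torch
    import torch.nn.functional as F

    p = torch.tensor([0.9, 0.05, 0.05])  # a fairly peaked distribution
    for _ in range(5):
        p = F.softmax(p, dim=0)
    print(p)  # roughly tensor([0.34, 0.33, 0.33]): nearly uniform after a few steps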
@@ -478,7 +478,7 @@ class Data_augV4(nn.Module): #Transformations with mask
         '''
         return x

-    def adjust_prob(self, soft=False): #Detach from gradient ?
+    def adjust_param(self, soft=False): #Detach from gradient ?

         if soft :
             self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Too 'soft': gets stuck in a uniform dist if the lr is too low
@@ -549,15 +549,22 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)
         self._params = nn.ParameterDict({
             "prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Uniform prob distribution
             "mag" : nn.Parameter(torch.tensor(0.5) if self._shared_mag
-                    else torch.tensor(0.5).expand(self._nb_tf)), #[0, PARAMETER_MAX]/10
+                    else torch.tensor(0.5).expand(self._nb_tf)), #[0, PARAMETER_MAX]
         })
-        self._samples = []

+        #Distribution
+        self._samples = []
         self._mix_dist = False
         if mix_dist != 0.0:
             self._mix_dist = True
             self._mix_factor = max(min(mix_dist, 1.0), 0.0)

+        #Mag regularisation
+        if not self._fixed_mag:
+            ignore={'Identity', 'FlipUD', 'FlipLR', 'Solarize', 'Posterize'}
+            self._reg_mask=[self._TF.index(t) for t in self._TF if t not in ignore]
+            self._reg_tgt = torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage max amplitude
+
     def forward(self, x):
         if self._data_augmentation:
             device = x.device
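The new regularization setup can be traced by hand. In the sketch below, PARAMETER_MAX = 1.0 is a stand-in for TF.PARAMETER_MAX, and the exclusion of Identity, the flips, Solarize and Posterize presumably reflects that their magnitude is either unused or should not be pushed toward its maximum:

    import torch

    PARAMETER_MAX = 1.0  # stand-in for TF.PARAMETER_MAX

    TF_list = ['Identity', 'FlipUD', 'FlipLR', 'Rotate', 'TranslateX', 'Solarize']
    ignore = {'Identity', 'FlipUD', 'FlipLR', 'Solarize', 'Posterize'}

    reg_mask = [TF_list.index(t) for t in TF_list if t not in ignore]
    reg_tgt = torch.full(size=(len(reg_mask),), fill_value=PARAMETER_MAX)
    print(reg_mask, reg_tgt)  # [3, 4] tensor([1., 1.]) -> only Rotate and TranslateX are regularized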
@@ -610,18 +617,17 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)

         return x

-    def adjust_prob(self, soft=False): #Detach from gradient ?
+    def adjust_param(self, soft=False): #Detach from gradient ?

         if soft :
             self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Too 'soft': gets stuck in a uniform dist if the lr is too low
         else:
-            #self._params['prob'].clamp(min=0.0,max=1.0)
             self._params['prob'].data = F.relu(self._params['prob'].data)
             #self._params['prob'].data = self._params['prob'].clamp(min=0.0,max=1.0)

         self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Constraint: sum(p)=1

+        #self._params['mag'].data = self._params['mag'].data.clamp(min=0.0,max=TF.PARAMETER_MAX) #Gets stuck once at the extremes
+        self._params['mag'].data = F.relu(self._params['mag'].data) - F.relu(self._params['mag'].data - TF.PARAMETER_MAX)

     def loss_weight(self):
         # only 1 TF
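The new magnitude bound replaces clamp with a pair of ReLUs; elementwise the two are the same function, F.relu(x) - F.relu(x - MAX) = clamp(x, 0, MAX), as a quick stand-alone check confirms (PARAMETER_MAX again standing in for TF.PARAMETER_MAX):

    import torch
    import torch.nn.functional as F

    PARAMETER_MAX = 1.0
    mag = torch.tensor([-0.3, 0.2, 0.9, 1.7])

    bounded = F.relu(mag) - F.relu(mag - PARAMETER_MAX)
    print(bounded)  # tensor([0.0000, 0.2000, 0.9000, 1.0000])
    print(torch.allclose(bounded, mag.clamp(0.0, PARAMETER_MAX)))  # True

Since the update is applied through .data, it sits outside autograd in either form.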
@@ -642,6 +648,9 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)
         w_loss = torch.sum(w_loss,dim=1)
         return w_loss

+    def reg_loss(self, reg_factor=0.005):
+        #return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean')
+        return reg_factor * F.mse_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt.to(self._params['mag'].device), reduction='mean')

     def train(self, mode=None):
         if mode is None :
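A worked miniature of reg_loss with the same stand-ins: the MSE toward the maximum penalizes small magnitudes on the masked TFs, which is what "encourage max amplitude" means in the init above:

    import torch
    import torch.nn.functional as F

    PARAMETER_MAX = 1.0                       # stand-in for TF.PARAMETER_MAX
    mag = torch.tensor([0.5, 0.2, 0.9, 0.1])  # learned per-TF magnitudes
    reg_mask = [0, 2, 3]                      # hypothetical mask: TF 1 is excluded
    reg_tgt = torch.full((len(reg_mask),), PARAMETER_MAX)

    reg_factor = 0.005
    reg = reg_factor * F.mse_loss(mag[reg_mask], target=reg_tgt, reduction='mean')
    # mean((0.5-1)^2, (0.9-1)^2, (0.1-1)^2) = mean(0.25, 0.01, 0.81) ~ 0.3567
    print(reg)  # tensor(0.0018)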
(two binary image files deleted: 38 KiB and 37 KiB; contents not shown)
@@ -1,72 +0,0 @@ (deleted file: JSON training log)
-{
-    "Accuracy": 20.8,
-    "Time": [51.4427050715, 0.4778038694999971],
-    "Device": "TITAN RTX",
-    "Param_names": ["Identity", "FlipUD", "FlipLR", "Rotate", "TranslateX", "TranslateY", "ShearX", "ShearY", "Contrast", "Color", "Brightness", "Sharpness", "Posterize", "Solarize"],
-    "Log": [
-        {"epoch": 1, "train_loss": 2.3032476902008057, "val_loss": 2.2924728393554688, "acc": 11.1, "time": 51.920508941,
-         "param": [0.07925213128328323, 0.08312409371137619, 0.08779778331518173, 0.0853320062160492, 0.08577536046504974, 0.057290591299533844, 0.0774931013584137, 0.08246791362762451, 0.047001805156469345, 0.07887403666973114, 0.05897113308310509, 0.05021947622299194, 0.07581018656492233, 0.050590354949235916]},
-        {"epoch": 2, "train_loss": 2.171858787536621, "val_loss": 2.078795909881592, "acc": 20.8, "time": 50.96490120200001,
-         "param": [0.07892196625471115, 0.07488056272268295, 0.08041033148765564, 0.09144628793001175, 0.09114645421504974, 0.055715303868055344, 0.0672164335846901, 0.07994510233402252, 0.05105787515640259, 0.09191103279590607, 0.07849953323602676, 0.07014491409063339, 0.07624118775129318, 0.012463102117180824]}
-    ]
-}
@@ -1,95 +0,0 @@ (deleted file: JSON training log; every "param" entry is the uniform 1/14 distribution)
-{
-    "Accuracy": 31.369999999999997,
-    "Time": [38.67262149066667, 0.4140408795968137],
-    "Device": "TITAN RTX",
-    "Param_names": ["Identity", "FlipUD", "FlipLR", "Rotate", "TranslateX", "TranslateY", "ShearX", "ShearY", "Contrast", "Color", "Brightness", "Sharpness", "Posterize", "Solarize"],
-    "Log": [
-        {"epoch": 1, "train_loss": 2.2571041584014893, "val_loss": 2.212921142578125, "acc": 20.169999999999998, "time": 38.788926192000005,
-         "param": [0.0714285746216774, 0.0714285746216774, 0.0714285746216774, 0.0714285746216774, 0.0714285746216774, 0.0714285746216774, 0.0714285746216774, 0.0714285746216774, 0.0714285746216774, 0.0714285746216774, 0.0714285746216774, 0.0714285746216774, 0.0714285746216774, 0.0714285746216774]},
-        {"epoch": 2, "train_loss": 2.212834358215332, "val_loss": 2.043567180633545, "acc": 25.009999999999998, "time": 38.117478509,
-         "param": [0.0714285746216774, 0.0714285746216774, 0.0714285746216774, 0.0714285746216774, 0.0714285746216774, 0.0714285746216774, 0.0714285746216774, 0.0714285746216774, 0.0714285746216774, 0.0714285746216774, 0.0714285746216774, 0.0714285746216774, 0.0714285746216774, 0.0714285746216774]},
-        {"epoch": 3, "train_loss": 2.091825008392334, "val_loss": 1.9359350204467773, "acc": 31.369999999999997, "time": 39.111459771,
-         "param": [0.0714285746216774, 0.0714285746216774, 0.0714285746216774, 0.0714285746216774, 0.0714285746216774, 0.0714285746216774, 0.0714285746216774, 0.0714285746216774, 0.0714285746216774, 0.0714285746216774, 0.0714285746216774, 0.0714285746216774, 0.0714285746216774, 0.0714285746216774]}
-    ]
-}
@@ -1,34 +0,0 @@ (deleted file: JSON training log)
-{
-    "Accuracy": 39.2,
-    "Time": [3.9452463850000012, 0.2891758564900622],
-    "Device": "TITAN RTX",
-    "Log": [
-        {"epoch": 0, "train_loss": 2.109266757965088, "val_loss": 2.1106348037719727, "acc": 22.3, "time": 4.312489993, "param": null},
-        {"epoch": 1, "train_loss": 1.7782783508300781, "val_loss": 1.8776130676269531, "acc": 33.76, "time": 3.605794182000002, "param": null},
-        {"epoch": 2, "train_loss": 1.8152618408203125, "val_loss": 1.6963396072387695, "acc": 39.2, "time": 3.9174549800000023, "param": null}
-    ]
-}
@@ -5,9 +5,9 @@ from train_utils import *

 tf_names = [
     ## Geometric TF ##
-    #'Identity',
-    #'FlipUD',
-    #'FlipLR',
+    'Identity',
+    'FlipUD',
+    'FlipLR',
     'Rotate',
     'TranslateX',
     'TranslateY',
@@ -37,8 +37,8 @@ else:
 ##########################################
 if __name__ == "__main__":

-    n_inner_iter = 1
-    epochs = 2
+    n_inner_iter = 10
+    epochs = 200
     dataug_epoch_start=0

     #### Classic ####
@@ -57,7 +57,7 @@ if __name__ == "__main__":
         print('-'*9)
         times = [x["time"] for x in log]
         out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times)), "Device": device_name, "Log": log}
-        print(str(model),": acc", out["Accuracy"], "in (ms):", out["Time"][0], "+/-", out["Time"][1])
+        print(str(model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
         with open("res/log/%s.json" % "{}-{} epochs".format(str(model),epochs), "w+") as f:
             json.dump(out, f, indent=True)
             print('Log :\"',f.name, '\" saved !')
@@ -68,7 +68,7 @@ if __name__ == "__main__":
         t0 = time.process_time()
         tf_dict = {k: TF.TF_dict[k] for k in tf_names}
         #tf_dict = TF.TF_dict
-        aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=1, mix_dist=0.5, fixed_mag=False, shared_mag=True), LeNet(3,10)).to(device)
+        aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=2, mix_dist=0.5, fixed_mag=False, shared_mag=False), LeNet(3,10)).to(device)
         #aug_model = Augmented_model(Data_augV4(TF_dict=tf_dict, N_TF=2, mix_dist=0.0), WideResNet(num_classes=10, wrn_size=160)).to(device)
         print(str(aug_model), 'on', device_name)
         #run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
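The switch to N_TF=2, shared_mag=False means two transformations are now sampled per image and every TF keeps its own learnable magnitude instead of one shared scalar. A sketch of the resulting parameter shapes, mirroring the ParameterDict from the Data_augV5 hunk above (with a .clone() added here so the expanded tensor owns its storage):

    import torch
    import torch.nn as nn

    nb_tf, shared_mag = 14, False
    params = nn.ParameterDict({
        "prob": nn.Parameter(torch.ones(nb_tf) / nb_tf),  # one probability per TF, uniform init
        "mag": nn.Parameter(torch.tensor(0.5) if shared_mag
                            else torch.tensor(0.5).expand(nb_tf).clone()),
    })
    print(params["prob"].shape, params["mag"].shape)  # torch.Size([14]) torch.Size([14])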
@@ -79,12 +79,13 @@ if __name__ == "__main__":
         print('-'*9)
         times = [x["time"] for x in log]
         out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times)), "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
-        print(str(aug_model),": acc", out["Accuracy"], "in (s?):", out["Time"][0], "+/-", out["Time"][1])
+        print(str(aug_model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
         with open("res/log/%s.json" % "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter), "w+") as f:
             json.dump(out, f, indent=True)
             print('Log :\"',f.name, '\" saved !')

-        print('Execution Time : %.00f (s?)'%(time.process_time() - t0))
+        print('TF influence', TF_influence(log))
+        print('Execution Time : %.00f '%(time.process_time() - t0))
         print('-'*9)
         #'''
         #### TF number tests ####
@@ -528,7 +528,7 @@ def run_dist_dataug(model, epochs=1, inner_it=1, dataug_epoch_start=0):
             optim_copy(dopt=diffopt, opt=inner_opt)

             meta_opt.step()
-            model['data_aug'].adjust_prob() #Constraint: sum(proba)=1
+            model['data_aug'].adjust_param() #Constraint: sum(proba)=1

     print("Copy ", countcopy)
     return log
@@ -588,7 +588,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
             loss = F.cross_entropy(logits, ys, reduction='none') # no need to call loss.backwards()

             if fmodel._data_augmentation: #Weight loss
-                w_loss = fmodel['data_aug'].loss_weight().to(device)
+                w_loss = fmodel['data_aug'].loss_weight()#.to(device)
                 loss = loss * w_loss
             loss = loss.mean()
             #'''
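The weighting pattern in this hunk, unreduced cross-entropy scaled per sample before the mean, looks like this in isolation (a generic sketch, not repo code; in the repo the weights come from Data_augV5.loss_weight()):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(4, 10)                  # batch of 4, 10 classes
    ys = torch.randint(0, 10, (4,))
    w_loss = torch.tensor([1.2, 0.8, 1.0, 1.0])  # one weight per sample

    loss = F.cross_entropy(logits, ys, reduction='none')  # shape (4,): per-sample losses
    loss = (loss * w_loss).mean()                # weight first, then reduce to a scalar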
@@ -605,7 +605,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
             if(high_grad_track and i%inner_it==0): #Perform Meta step
                 #print("meta")
                 #Of little use if high_grad_track = False
-                val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val)
+                val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val) + fmodel['data_aug'].reg_loss()

                 #print_graph(val_loss)

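With this change the meta objective becomes validation loss plus the magnitude regularizer. A self-contained miniature of that composition (the quadratic val_loss and the Adam step are stand-ins for compute_vaLoss and meta_opt):

    import torch
    import torch.nn.functional as F

    PARAMETER_MAX, reg_factor = 1.0, 0.005
    mag = torch.nn.Parameter(torch.tensor([0.5, 0.8]))  # stand-in for the 'mag' parameters
    meta_opt = torch.optim.Adam([mag], lr=1e-2)

    val_loss = (mag.sum() - 1.0) ** 2                   # stand-in for compute_vaLoss(...)
    reg = reg_factor * F.mse_loss(mag, torch.full_like(mag, PARAMETER_MAX))

    (val_loss + reg).backward()                         # one objective, two terms
    meta_opt.step()
    mag.data.clamp_(0.0, PARAMETER_MAX)                 # the adjust_param() bound, in clamp form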
@@ -616,15 +616,15 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
                 optim_copy(dopt=diffopt, opt=inner_opt)

                 meta_opt.step()
-                model['data_aug'].adjust_prob(soft=False) #Constraint: sum(proba)=1
+                model['data_aug'].adjust_param(soft=False) #Constraint: sum(proba)=1

                 fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
                 diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)

         tf = time.process_time()

-        viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
-        viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
+        #viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
+        #viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))

         if(not high_grad_track):
             countcopy+=1
@@ -643,7 +643,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
         if(print_freq and epoch%print_freq==0):
             print('-'*9)
             print('Epoch : %d/%d'%(epoch,epochs))
-            print('Time : %.00f s'%(tf - t0))
+            print('Time : %.00f'%(tf - t0))
             print('Train loss :',loss.item(), '/ val loss', val_loss.item())
             print('Accuracy :', accuracy)
             print('Data Augmention : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start))
@@ -651,6 +651,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
             #print('proba grad',model['data_aug']['prob'].grad)
             print('TF Mag :', model['data_aug']['mag'].data)
             #print('Mag grad',model['data_aug']['mag'].grad)
+            print('Reg loss:', model['data_aug'].reg_loss().item())
         #############
         #### Log ####
         #print(type(model['data_aug']) is dataug.Data_augV5)
@@ -254,6 +254,11 @@ def print_torch_mem(add_info=''):
         torch.cuda.max_memory_cached()/ mega_bytes)
     print(string)

+def TF_influence(log):
+    proba=[[x["param"][idx]['p'] for x in log] for idx, _ in enumerate(log[0]["param"])]
+    mag=[[x["param"][idx]['m'] for x in log] for idx, _ in enumerate(log[0]["param"])]
+
+    return np.mean(proba, axis=1)*np.mean(mag, axis=1) #Could be interesting to multiply before the mean
+
 class loss_monitor(): #See https://github.com/pytorch/ignite
     def __init__(self, patience, end_train=1):
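TF_influence expects each log entry's "param" to be a list of {'p': probability, 'm': magnitude} dicts, one per TF, and returns mean(p) * mean(m) per TF over the run. A hypothetical two-epoch, two-TF log shows the shape of the result:

    import numpy as np

    log = [  # hypothetical log format
        {"param": [{'p': 0.6, 'm': 0.5}, {'p': 0.4, 'm': 0.9}]},
        {"param": [{'p': 0.8, 'm': 0.7}, {'p': 0.2, 'm': 0.9}]},
    ]

    proba = [[x["param"][idx]['p'] for x in log] for idx, _ in enumerate(log[0]["param"])]
    mag = [[x["param"][idx]['m'] for x in log] for idx, _ in enumerate(log[0]["param"])]

    print(np.mean(proba, axis=1) * np.mean(mag, axis=1))  # [0.7*0.6, 0.3*0.9] = [0.42, 0.27]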