mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 20:20:46 +02:00
Log et utils pour Dataugv5
This commit is contained in:
parent
860d9f1bbb
commit
f4bdd9bca5
6 changed files with 57 additions and 45 deletions
|
@ -2,21 +2,18 @@ from utils import *
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
#### Comparison ####
|
'''
|
||||||
|
|
||||||
## Loss , Acc, Proba = f(epoch) ##
|
|
||||||
files=[
|
files=[
|
||||||
#"res/log/LeNet-100 epochs.json",
|
"res/log/Aug_mod(Data_augV5(Mix0.5-11TFx1-MagSh)-LeNet)-2 epochs (dataug:0)- 1 in_it.json",
|
||||||
#"res/log/Aug_mod(Data_augV4(Uniform-4 TF)-LeNet)-100 epochs (dataug:0)- 0 in_it.json",
|
|
||||||
#"res/log/Aug_mod(Data_augV4(Uniform-4 TF)-LeNet)-100 epochs (dataug:50)- 0 in_it.json",
|
|
||||||
#"res/log/Aug_mod(Data_augV4(Uniform-3 TF)-LeNet)-100 epochs (dataug:0)- 0 in_it.json",
|
|
||||||
#"res/log/Aug_mod(Data_augV3(Uniform-3 TF)-LeNet)-100 epochs (dataug:50)- 10 in_it.json",
|
|
||||||
#"res/log/Aug_mod(Data_augV4(Mix 0,5-3 TF)-LeNet)-100 epochs (dataug:0)- 1 in_it.json",
|
|
||||||
#"res/log/Aug_mod(Data_augV4(Mix 0.5-3 TF)-LeNet)-100 epochs (dataug:50)- 10 in_it.json",
|
|
||||||
#"res/log/Aug_mod(Data_augV4(Uniform-3 TF)-LeNet)-100 epochs (dataug:0)- 10 in_it.json",
|
|
||||||
#"res/log/Aug_mod(Data_augV4(Uniform-10 TF)-LeNet)-100 epochs (dataug:50)- 10 in_it.json",
|
|
||||||
#"res/log/Aug_mod(Data_augV4(Uniform-10 TF)-LeNet)-100 epochs (dataug:50)- 0 in_it.json",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
for idx, file in enumerate(files):
|
||||||
|
#legend+=str(idx)+'-'+file+'\n'
|
||||||
|
with open(file) as json_file:
|
||||||
|
data = json.load(json_file)
|
||||||
|
plot_resV2(data['Log'], fig_name=file.replace('.json','').replace('log/',''), param_names=data['Param_names'])
|
||||||
|
'''
|
||||||
|
## Loss , Acc, Proba = f(epoch) ##
|
||||||
#plot_compare(filenames=files, fig_name="res/compare")
|
#plot_compare(filenames=files, fig_name="res/compare")
|
||||||
|
|
||||||
## Acc, Time, Epochs = f(n_tf) ##
|
## Acc, Time, Epochs = f(n_tf) ##
|
||||||
|
|
|
@ -38,8 +38,8 @@ data_test = torchvision.datasets.CIFAR10(
|
||||||
"./data", train=False, download=True, transform=transform
|
"./data", train=False, download=True, transform=transform
|
||||||
)
|
)
|
||||||
#'''
|
#'''
|
||||||
train_subset_indices=range(int(len(data_train)/2))
|
#train_subset_indices=range(int(len(data_train)/2))
|
||||||
#train_subset_indices=range(BATCH_SIZE*10)
|
train_subset_indices=range(BATCH_SIZE*10)
|
||||||
val_subset_indices=range(int(len(data_train)/2),len(data_train))
|
val_subset_indices=range(int(len(data_train)/2),len(data_train))
|
||||||
|
|
||||||
dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices))
|
dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices))
|
||||||
|
|
|
@ -531,7 +531,7 @@ class Data_augV4(nn.Module): #Transformations avec mask
|
||||||
return "Data_augV4(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_seqTF)
|
return "Data_augV4(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_seqTF)
|
||||||
|
|
||||||
class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
||||||
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, shared_mag=True):
|
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_mag=True, shared_mag=True):
|
||||||
super(Data_augV5, self).__init__()
|
super(Data_augV5, self).__init__()
|
||||||
assert len(TF_dict)>0
|
assert len(TF_dict)>0
|
||||||
|
|
||||||
|
@ -543,14 +543,14 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
||||||
|
|
||||||
self._N_seqTF = N_TF
|
self._N_seqTF = N_TF
|
||||||
self._shared_mag = shared_mag
|
self._shared_mag = shared_mag
|
||||||
|
self._fixed_mag = fixed_mag
|
||||||
|
|
||||||
#self._fixed_mag=5 #[0, PARAMETER_MAX]
|
#self._fixed_mag=5 #[0, PARAMETER_MAX]
|
||||||
self._params = nn.ParameterDict({
|
self._params = nn.ParameterDict({
|
||||||
"prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Distribution prob uniforme
|
"prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Distribution prob uniforme
|
||||||
"mag" : nn.Parameter(torch.tensor(0.5) if shared_mag
|
"mag" : nn.Parameter(torch.tensor(0.5) if self._shared_mag
|
||||||
else torch.tensor(0.5).expand(self._nb_tf)), #[0, PARAMETER_MAX]/10
|
else torch.tensor(0.5).expand(self._nb_tf)), #[0, PARAMETER_MAX]/10
|
||||||
})
|
})
|
||||||
|
|
||||||
self._samples = []
|
self._samples = []
|
||||||
|
|
||||||
self._mix_dist = False
|
self._mix_dist = False
|
||||||
|
@ -594,23 +594,19 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
||||||
|
|
||||||
if smp_x.shape[0]!=0: #if there's data to TF
|
if smp_x.shape[0]!=0: #if there's data to TF
|
||||||
magnitude=self._params["mag"] if self._shared_mag else self._params["mag"][tf_idx]
|
magnitude=self._params["mag"] if self._shared_mag else self._params["mag"][tf_idx]
|
||||||
|
if self._fixed_mag: magnitude=magnitude.detach() #Fmodel tente systematiquement de tracker les gradient de tout les param
|
||||||
|
|
||||||
tf=self._TF[tf_idx]
|
tf=self._TF[tf_idx]
|
||||||
#print(magnitude)
|
#print(magnitude)
|
||||||
|
|
||||||
#x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude) # Refusionner eviter x[mask] : in place
|
#In place
|
||||||
smp_x = self._TF_dict[tf](x=smp_x, mag=magnitude)
|
#x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude)
|
||||||
|
|
||||||
|
#Out of place
|
||||||
|
smp_x = self._TF_dict[tf](x=smp_x, mag=magnitude)
|
||||||
idx= mask.nonzero()
|
idx= mask.nonzero()
|
||||||
#print('-'*8)
|
|
||||||
idx= idx.expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ...
|
idx= idx.expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ...
|
||||||
#print(idx.shape, smp_x.shape)
|
|
||||||
#print(idx[0], tf_idx)
|
|
||||||
#print(smp_x[0,])
|
|
||||||
#x=x.view(-1,3*32*32)
|
|
||||||
#smp_x=smp_x.view(-1,3*32*32)
|
|
||||||
x=x.scatter(dim=0, index=idx, src=smp_x)
|
x=x.scatter(dim=0, index=idx, src=smp_x)
|
||||||
#x=x.view(-1,3,32,32)
|
|
||||||
#print(x[0,])
|
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@ -663,10 +659,13 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
||||||
return self._params[key]
|
return self._params[key]
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
|
mag_param='Mag'
|
||||||
|
if self._fixed_mag: mag_param+= 'Fx'
|
||||||
|
if self._shared_mag: mag_param+= 'Sh'
|
||||||
if not self._mix_dist:
|
if not self._mix_dist:
|
||||||
return "Data_augV5(Uniform-%d TF x %d)" % (self._nb_tf, self._N_seqTF)
|
return "Data_augV5(Uniform-%dTFx%d-%s)" % (self._nb_tf, self._N_seqTF, mag_param)
|
||||||
else:
|
else:
|
||||||
return "Data_augV5(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_seqTF)
|
return "Data_augV5(Mix%.1f-%dTFx%d-%s)" % (self._mix_factor, self._nb_tf, self._N_seqTF, mag_param)
|
||||||
|
|
||||||
|
|
||||||
class Augmented_model(nn.Module):
|
class Augmented_model(nn.Module):
|
||||||
|
|
|
@ -68,7 +68,7 @@ if __name__ == "__main__":
|
||||||
t0 = time.process_time()
|
t0 = time.process_time()
|
||||||
tf_dict = {k: TF.TF_dict[k] for k in tf_names}
|
tf_dict = {k: TF.TF_dict[k] for k in tf_names}
|
||||||
#tf_dict = TF.TF_dict
|
#tf_dict = TF.TF_dict
|
||||||
aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=1, mix_dist=0.5, shared_mag=True), LeNet(3,10)).to(device)
|
aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=1, mix_dist=0.5, fixed_mag=False, shared_mag=True), LeNet(3,10)).to(device)
|
||||||
#aug_model = Augmented_model(Data_augV4(TF_dict=tf_dict, N_TF=2, mix_dist=0.0), WideResNet(num_classes=10, wrn_size=160)).to(device)
|
#aug_model = Augmented_model(Data_augV4(TF_dict=tf_dict, N_TF=2, mix_dist=0.0), WideResNet(num_classes=10, wrn_size=160)).to(device)
|
||||||
print(str(aug_model), 'on', device_name)
|
print(str(aug_model), 'on', device_name)
|
||||||
#run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
|
#run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
|
||||||
|
|
|
@ -650,9 +650,11 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
|
||||||
print('TF Proba :', model['data_aug']['prob'].data)
|
print('TF Proba :', model['data_aug']['prob'].data)
|
||||||
#print('proba grad',model['data_aug']['prob'].grad)
|
#print('proba grad',model['data_aug']['prob'].grad)
|
||||||
print('TF Mag :', model['data_aug']['mag'].data)
|
print('TF Mag :', model['data_aug']['mag'].data)
|
||||||
print('Mag grad',model['data_aug']['mag'].grad)
|
#print('Mag grad',model['data_aug']['mag'].grad)
|
||||||
#############
|
#############
|
||||||
#### Log ####
|
#### Log ####
|
||||||
|
#print(type(model['data_aug']) is dataug.Data_augV5)
|
||||||
|
param = [{'p': p.item(), 'm':model['data_aug']['mag'].item()} for p in model['data_aug']['prob']] if model['data_aug']._shared_mag else [{'p': p.item(), 'm': m.item()} for p, m in zip(model['data_aug']['prob'], model['data_aug']['mag'])]
|
||||||
data={
|
data={
|
||||||
"epoch": epoch,
|
"epoch": epoch,
|
||||||
"train_loss": loss.item(),
|
"train_loss": loss.item(),
|
||||||
|
@ -660,7 +662,8 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
|
||||||
"acc": accuracy,
|
"acc": accuracy,
|
||||||
"time": tf - t0,
|
"time": tf - t0,
|
||||||
|
|
||||||
"param": [p.item() for p in model['data_aug']['prob']],
|
"param": param #if isinstance(model['data_aug'], Data_augV5)
|
||||||
|
#else [p.item() for p in model['data_aug']['prob']],
|
||||||
}
|
}
|
||||||
log.append(data)
|
log.append(data)
|
||||||
#############
|
#############
|
||||||
|
|
|
@ -52,28 +52,41 @@ def plot_resV2(log, fig_name='res', param_names=None):
|
||||||
|
|
||||||
epochs = [x["epoch"] for x in log]
|
epochs = [x["epoch"] for x in log]
|
||||||
|
|
||||||
fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(15, 15))
|
fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(30, 15))
|
||||||
|
|
||||||
ax[0, 0].set_title('Loss')
|
ax[0, 0].set_title('Loss')
|
||||||
ax[0, 0].plot(epochs,[x["train_loss"] for x in log], label='Train')
|
ax[0, 0].plot(epochs,[x["train_loss"] for x in log], label='Train')
|
||||||
ax[0, 0].plot(epochs,[x["val_loss"] for x in log], label='Val')
|
ax[0, 0].plot(epochs,[x["val_loss"] for x in log], label='Val')
|
||||||
ax[0, 0].legend()
|
ax[0, 0].legend()
|
||||||
|
|
||||||
ax[0, 1].set_title('Acc')
|
ax[1, 0].set_title('Acc')
|
||||||
ax[0, 1].plot(epochs,[x["acc"] for x in log])
|
ax[1, 0].plot(epochs,[x["acc"] for x in log])
|
||||||
|
|
||||||
if log[0]["param"]!= None:
|
if log[0]["param"]!= None:
|
||||||
ax[1, 1].set_title('Prob')
|
|
||||||
if not param_names : param_names = ['P'+str(idx) for idx, _ in enumerate(log[0]["param"])]
|
if not param_names : param_names = ['P'+str(idx) for idx, _ in enumerate(log[0]["param"])]
|
||||||
proba=[[x["param"][idx] for x in log] for idx, _ in enumerate(log[0]["param"])]
|
#proba=[[x["param"][idx] for x in log] for idx, _ in enumerate(log[0]["param"])]
|
||||||
ax[1, 1].stackplot(epochs, proba, labels=param_names)
|
proba=[[x["param"][idx]['p'] for x in log] for idx, _ in enumerate(log[0]["param"])]
|
||||||
ax[1, 1].legend(param_names, loc='center left', bbox_to_anchor=(1, 0.5))
|
mag=[[x["param"][idx]['m'] for x in log] for idx, _ in enumerate(log[0]["param"])]
|
||||||
|
|
||||||
ax[1, 0].set_title('Mean prob')
|
ax[0, 1].set_title('Prob =f(epoch)')
|
||||||
mean = np.mean([x["param"] for x in log], axis=0)
|
ax[0, 1].stackplot(epochs, proba, labels=param_names)
|
||||||
std = np.std([x["param"] for x in log], axis=0)
|
#ax[0, 1].legend(param_names, loc='center left', bbox_to_anchor=(1, 0.5))
|
||||||
ax[1, 0].bar(param_names, mean, yerr=std)
|
|
||||||
plt.sca(ax[1, 0]), plt.xticks(rotation=90)
|
ax[1, 1].set_title('Prob =f(TF)')
|
||||||
|
mean = np.mean(proba, axis=1)
|
||||||
|
std = np.std(proba, axis=1)
|
||||||
|
ax[1, 1].bar(param_names, mean, yerr=std)
|
||||||
|
plt.sca(ax[1, 1]), plt.xticks(rotation=90)
|
||||||
|
|
||||||
|
ax[0, 2].set_title('Mag =f(epoch)')
|
||||||
|
ax[0, 2].stackplot(epochs, mag, labels=param_names)
|
||||||
|
ax[0, 2].legend(param_names, loc='center left', bbox_to_anchor=(1, 0.5))
|
||||||
|
|
||||||
|
ax[1, 2].set_title('Mag =f(TF)')
|
||||||
|
mean = np.mean(mag, axis=1)
|
||||||
|
std = np.std(mag, axis=1)
|
||||||
|
ax[1, 2].bar(param_names, mean, yerr=std)
|
||||||
|
plt.sca(ax[1, 2]), plt.xticks(rotation=90)
|
||||||
|
|
||||||
|
|
||||||
fig_name = fig_name.replace('.',',')
|
fig_name = fig_name.replace('.',',')
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue