mirror of https://github.com/AntoineHX/smart_augmentation.git (synced 2025-05-04 12:10:45 +02:00)
Test brutus suite
This commit is contained in:
parent a772d13b83
commit 61dad1ad78

4 changed files with 90 additions and 32 deletions
@@ -2,10 +2,11 @@ from utils import *
 
 if __name__ == "__main__":
 
-    #'''
+    '''
     files=[
-        "res/good_TF_tests/log/Aug_mod(Data_augV5(Mix0.5-14TFx2-MagFxSh)-LeNet)-100 epochs (dataug:0)- 0 in_it.json",
-        "res/good_TF_tests/log/Aug_mod(Data_augV5(Uniform-14TFx2-MagFxSh)-LeNet)-100 epochs (dataug:0)- 0 in_it.json",
+        #"res/good_TF_tests/log/Aug_mod(Data_augV5(Mix0.5-14TFx2-MagFxSh)-LeNet)-100 epochs (dataug:0)- 0 in_it.json",
+        #"res/good_TF_tests/log/Aug_mod(Data_augV5(Uniform-14TFx2-MagFxSh)-LeNet)-100 epochs (dataug:0)- 0 in_it.json",
+        "res/brutus-tests/log/Aug_mod(Data_augV5(Mix0.5-14TFx1-Mag)-LeNet)-150epochs(dataug:0)-1in_it-0.json",
     ]
 
     for idx, file in enumerate(files):
@@ -14,7 +15,7 @@ if __name__ == "__main__":
             data = json.load(json_file)
         plot_resV2(data['Log'], fig_name=file.replace('.json','').replace('log/',''), param_names=data['Param_names'])
         #plot_TF_influence(data['Log'], param_names=data['Param_names'])
-    #'''
+    '''
     ## Loss , Acc, Proba = f(epoch) ##
     #plot_compare(filenames=files, fig_name="res/compare")
 
@@ -72,3 +73,18 @@ if __name__ == "__main__":
         plt.savefig(fig_name, bbox_inches='tight')
         plt.close()
     '''
+
+    #Res print
+    nb_run=3
+    accs = []
+    times = []
+    files = ["res/brutus-tests/log/Aug_mod(Data_augV5(Mix1.0-14TFx2-Mag)-LeNet)-150epochs(dataug:0)-1in_it-%s.json"%str(run) for run in range(nb_run)]
+
+    for idx, file in enumerate(files):
+        #legend+=str(idx)+'-'+file+'\n'
+        with open(file) as json_file:
+            data = json.load(json_file)
+        accs.append(data['Accuracy'])
+        times.append(data['Time'][0])
+
+    print(files[0], np.mean(accs), np.std(accs), np.mean(times))
@@ -531,7 +531,7 @@ class Data_augV4(nn.Module): #Transformations avec mask
         return "Data_augV4(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_seqTF)
 
 class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
-    def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_mag=True, shared_mag=True):
+    def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True):
         super(Data_augV5, self).__init__()
         assert len(TF_dict)>0
 
@@ -555,6 +555,7 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
         #for t in TF.TF_no_mag: self._params['mag'][self._TF.index(t)].data-=self._params['mag'][self._TF.index(t)].data #Mag inutile pour les TF ignore_mag
 
         #Distribution
+        self._fixed_prob=fixed_prob
         self._samples = []
         self._mix_dist = False
         if mix_dist != 0.0:
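The new fixed_prob flag mirrors the existing fixed_mag / shared_mag switches: the transformation probabilities self._params["prob"] stay registered as a parameter, but the hunks below detach them wherever they are used and skip them in adjust_param, so they keep their initial values while everything else trains. A minimal sketch of that detach pattern, using stand-in tensors rather than the repository's classes:

import torch

# Sketch only: freezing one tensor of a jointly optimised pair via detach().
prob = torch.full((3,), 1/3, requires_grad=True)  # stand-in for _params["prob"]
mag  = torch.full((3,), 0.5, requires_grad=True)  # stand-in for _params["mag"]

fixed_prob = True
p = prob.detach() if fixed_prob else prob  # same trick as in the diff

loss = (p * mag).sum()
loss.backward()
print(prob.grad)  # None when fixed_prob=True: no gradient reaches the probabilities
print(mag.grad)   # the magnitudes still receive a gradient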
@@ -570,12 +571,12 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
             self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage amplitude max
 
     def forward(self, x):
-        if self._data_augmentation:
+        self._samples = []
+        if self._data_augmentation and TF.random.random() < 0.5:
             device = x.device
             batch_size, h, w = x.shape[0], x.shape[2], x.shape[3]
 
             x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles)
-            self._samples = []
 
             for _ in range(self._N_seqTF):
                 ## Echantillonage ##
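Two behavioural changes here: self._samples is now cleared on every call, so loss_weight (which returns 1 when no samples were recorded) stays consistent when augmentation is skipped, and augmentation is only applied to a batch with probability 0.5 via TF.random.random() < 0.5, hence the "rand0.5" tag added to __str__ further down. A standalone sketch of the gate, with placeholder augmentation rather than the repository's TFs:

import random
import torch

def forward_sketch(x, data_augmentation=True, p_apply=0.5):
    """Apply a (dummy) augmentation to the whole batch only p_apply of the time."""
    samples = []                               # reset every call, like self._samples = []
    if data_augmentation and random.random() < p_apply:
        x = x + 0.1 * torch.randn_like(x)      # placeholder for the sampled TF sequence
        samples.append("augmented")
    return x, samples

x, s = forward_sketch(torch.zeros(2, 3, 32, 32))
print(s)  # empty roughly half of the time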
@@ -584,7 +585,8 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
                 if not self._mix_dist:
                     self._distrib = uniforme_dist
                 else:
-                    self._distrib = (self._mix_factor*self._params["prob"]+(1-self._mix_factor)*uniforme_dist).softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor
+                    prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
+                    self._distrib = (self._mix_factor*prob+(1-self._mix_factor)*uniforme_dist).softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor
 
                 cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*self._distrib)
                 sample = cat_distrib.sample()
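With mix_dist non-zero, the per-TF sampling distribution is a convex mix of the learned probabilities and a uniform distribution, renormalised by a softmax; with fixed_prob the learned part is detached, so sampling still follows it but no gradient flows back into it. A self-contained sketch of that sampling step (shapes and names are illustrative, not taken from the repo):

import torch
from torch.distributions import Categorical

batch_size, nb_tf = 4, 5
prob = torch.rand(1, nb_tf, requires_grad=True)   # learned TF probabilities
uniform = torch.ones(1, nb_tf) / nb_tf
mix_factor = 0.5

# Convex mix of learned and uniform distributions, renormalised by softmax.
distrib = (mix_factor * prob + (1 - mix_factor) * uniform).softmax(dim=1)

# One categorical per image in the batch, all sharing the same distribution.
cat = Categorical(probs=torch.ones(batch_size, nb_tf) * distrib)
sample = cat.sample()   # shape (batch_size,): index of the TF applied to each image
print(sample)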
@@ -622,7 +624,7 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
         return x
 
     def adjust_param(self, soft=False): #Detach from gradient ?
-
-        if soft :
-            self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Trop 'soft', bloque en dist uniforme si lr trop faible
-        else:
+        if not self._fixed_prob:
+            if soft :
+                self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Trop 'soft', bloque en dist uniforme si lr trop faible
+            else:
@@ -630,10 +632,14 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
                 #self._params['prob'].data = self._params['prob'].clamp(min=0.0,max=1.0)
                 self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1
 
+        if not self._fixed_mag:
             #self._params['mag'].data = self._params['mag'].data.clamp(min=0.0,max=TF.PARAMETER_MAX) #Bloque une fois au extreme
             self._params['mag'].data = F.relu(self._params['mag'].data) - F.relu(self._params['mag'].data - TF.PARAMETER_MAX)
 
     def loss_weight(self):
+        if len(self._samples)==0 : return 1 #Pas d'echantillon = pas de ponderation
+
+        prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
         # 1 seule TF
         #self._sample = self._samples[-1]
         #w_loss = torch.zeros((self._sample.shape[0],self._nb_tf), device=self._sample.device)
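adjust_param is now split into two independent guards: probabilities are renormalised to sum to 1 only when fixed_prob is off, and magnitudes are clamped to [0, TF.PARAMETER_MAX] only when fixed_mag is off. The clamp is written as relu(x) - relu(x - max) rather than clamp(); the commented-out clamp line carries the author's note "#Bloque une fois au extreme" (gets stuck once it hits the extreme). A small numeric check of the identity, with an assumed PARAMETER_MAX:

import torch
import torch.nn.functional as F

PARAMETER_MAX = 1.0  # stand-in for TF.PARAMETER_MAX, not taken from the repo
mag = torch.tensor([-0.3, 0.2, 0.7, 1.4])

clamped = F.relu(mag) - F.relu(mag - PARAMETER_MAX)
print(clamped)  # tensor([0.0000, 0.2000, 0.7000, 1.0000]) -> same values as mag.clamp(0, 1)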
@@ -648,7 +654,7 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
             tmp_w.scatter_(dim=1, index=sample.view(-1,1), value=1/self._N_seqTF)
             w_loss += tmp_w
 
-        w_loss = w_loss * self._params["prob"]/self._distrib #Ponderation par les proba (divisee par la distrib pour pas diminuer la loss)
+        w_loss = w_loss * prob/self._distrib #Ponderation par les proba (divisee par la distrib pour pas diminuer la loss)
         w_loss = torch.sum(w_loss,dim=1)
         return w_loss
 
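loss_weight builds, for each image, a weight over the transformations it actually received: scatter_ writes 1/N_seqTF at every sampled TF index, accumulated over the sequence, and the result is rescaled by the learned probability divided by the sampling distribution (the French comment: weighted by the probabilities, divided by the distribution so the loss is not shrunk). With fixed_prob the detached prob is used, so this weighting no longer backpropagates into the probabilities. A self-contained sketch with illustrative names and sizes:

import torch

batch_size, nb_tf, N_seqTF = 4, 5, 2
prob = torch.rand(nb_tf).softmax(dim=0)        # learned probabilities, sum to 1
distrib = torch.ones(nb_tf) / nb_tf            # distribution actually sampled from
samples = [torch.randint(nb_tf, (batch_size,)) for _ in range(N_seqTF)]

w_loss = torch.zeros(batch_size, nb_tf)
for sample in samples:
    tmp_w = torch.zeros(batch_size, nb_tf)
    tmp_w.scatter_(dim=1, index=sample.view(-1, 1), value=1 / N_seqTF)
    w_loss += tmp_w

w_loss = w_loss * prob / distrib    # reweight by prob, divide by the sampling distrib
w_loss = torch.sum(w_loss, dim=1)   # one scalar weight per image
print(w_loss)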
@@ -676,13 +682,15 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
         return self._params[key]
 
     def __str__(self):
+        dist_param=''
+        if self._fixed_prob: dist_param+='Fx'
         mag_param='Mag'
         if self._fixed_mag: mag_param+= 'Fx'
         if self._shared_mag: mag_param+= 'Sh'
         if not self._mix_dist:
-            return "Data_augV5(Uniform-%dTFx%d-%s)" % (self._nb_tf, self._N_seqTF, mag_param)
+            return "Data_augV5(Uniform%s-%dTFx%d-%s)rand0.5" % (dist_param, self._nb_tf, self._N_seqTF, mag_param)
         else:
-            return "Data_augV5(Mix%.1f-%dTFx%d-%s)" % (self._mix_factor, self._nb_tf, self._N_seqTF, mag_param)
+            return "Data_augV5(Mix%.1f%s-%dTFx%d-%s)" % (self._mix_factor,dist_param, self._nb_tf, self._N_seqTF, mag_param)
 
 
 class Augmented_model(nn.Module):
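The string returned by __str__ is what ends up embedded in the result file names read by the plotting and aggregation code in the first file: "Uniform" or "Mix&lt;factor&gt;" gives the sampling mode, the optional "Fx" right after it now flags fixed_prob, "&lt;n&gt;TFx&lt;m&gt;" gives the number of available TFs and the length of the applied sequence, and the trailing "Mag" block collects "Fx" (fixed) and "Sh" (shared) magnitude flags; the "rand0.5" suffix on the Uniform variant records the 50% per-batch gate added in forward. For example, "Data_augV5(Mix0.5-14TFx1-Mag)" from the brutus log name above denotes a mix factor of 0.5, 14 candidate TFs applied one at a time, and a learnable per-TF magnitude.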
@@ -19,6 +19,10 @@ tf_names = [
     #'BTranslateY',
     #'BShearX',
     #'BShearY',
+    #'BadTranslateX',
+    #'BadTranslateX_neg',
+    #'BadTranslateY',
+    #'BadTranslateY_neg',
 
     ## Color TF (Expect image in the range of [0, 1]) ##
     'Contrast',
@@ -74,11 +78,11 @@ if __name__ == "__main__":
     t0 = time.process_time()
     tf_dict = {k: TF.TF_dict[k] for k in tf_names}
     #tf_dict = TF.TF_dict
-    aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=1, mix_dist=0.0, fixed_mag=False, shared_mag=False), LeNet(3,10)).to(device)
+    aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True), LeNet(3,10)).to(device)
     #aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=2, mix_dist=0.5, fixed_mag=True, shared_mag=True), WideResNet(num_classes=10, wrn_size=160)).to(device)
     print(str(aug_model), 'on', device_name)
     #run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
-    log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=1, loss_patience=None)
+    log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, loss_patience=None)
 
     ####
     print('-'*9)
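The single-run test configuration switches from a learnable per-TF magnitude (fixed_mag=False, shared_mag=False) to a single fixed, shared magnitude, so only the transformation probabilities are optimised in this quick check, and print_freq goes from 1 to 10 to reduce log output.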
@@ -99,7 +103,7 @@ if __name__ == "__main__":
     #'''
     res_folder="res/brutus-tests/"
     epochs= 150
-    inner_its = [0, 1, 10]
+    inner_its = [1, 10]
     dist_mix = [0.0, 0.5, 1]
     dataug_epoch_starts= [0]
     tf_dict = {k: TF.TF_dict[k] for k in tf_names}
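The "brutus" block narrows the grid to inner_its = [1, 10] (dropping the 0-inner-iteration baseline) while keeping dist_mix = [0.0, 0.5, 1] and a single dataug_epoch_start. The loop over this grid is outside the shown hunks, so the following is only a guess at its shape: each combination presumably builds a fresh augmented model, trains it for epochs, and writes one JSON log per run into res_folder, which is what the aggregation code in the first file reads back.

# Hypothetical sketch of the sweep implied by these settings; the real loop,
# model construction and json dump live outside the hunks shown in this commit.
import itertools

inner_its = [1, 10]
dist_mix = [0.0, 0.5, 1]
dataug_epoch_starts = [0]
nb_run = 3

for n_inner_iter, mix, start in itertools.product(inner_its, dist_mix, dataug_epoch_starts):
    for run in range(nb_run):
        # build Data_augV5(..., mix_dist=mix, ...), train for `epochs`, then dump
        # {'Accuracy': ..., 'Time': ..., 'Log': ...} to res_folder + "log/<model name>-<run>.json"
        pass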
@@ -28,6 +28,31 @@ TF_dict={ #Dataugv4
     #'Equalize': (lambda mag: None),
 }
 '''
+'''
+TF_dict={ #Dataugv5 #AutoAugment
+    ## Geometric TF ##
+    'Identity' : (lambda x, mag: x),
+    'FlipUD' : (lambda x, mag: flipUD(x)),
+    'FlipLR' : (lambda x, mag: flipLR(x)),
+    'Rotate': (lambda x, mag: rotate(x, angle=rand_floats(size=x.shape[0], mag=mag, maxval=30))),
+    'TranslateX': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20), zero_pos=0))),
+    'TranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20), zero_pos=1))),
+    'ShearX': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=0))),
+    'ShearY': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=1))),
+
+    ## Color TF (Expect image in the range of [0, 1]) ##
+    'Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
+    'Color':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
+    'Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
+    'Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
+    'Posterize': (lambda x, mag: posterize(x, bits=rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))),#Perte du gradient
+    'Solarize': (lambda x, mag: solarize(x, thresholds=rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Perte du gradient #=>Image entre [0,1] #Pas opti pour des batch
+
+    #Non fonctionnel
+    #'Auto_Contrast': (lambda mag: None), #Pas opti pour des batch (Super lent)
+    #'Equalize': (lambda mag: None),
+}
+'''
 TF_dict={ #Dataugv5
     ## Geometric TF ##
     'Identity' : (lambda x, mag: x),
@@ -45,6 +70,11 @@ TF_dict={ #Dataugv5
     'BShearX': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3*3), zero_pos=0))),
     'BShearY': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3*3), zero_pos=1))),
 
+    'BadTranslateX': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=20*2, maxval=20*3), zero_pos=0))),
+    'BadTranslateX_neg': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=-20*3, maxval=-20*2), zero_pos=0))),
+    'BadTranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=20*2, maxval=20*3), zero_pos=1))),
+    'BadTranslateY_neg': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=-20*3, maxval=-20*2), zero_pos=1))),
+
     ## Color TF (Expect image in the range of [0, 1]) ##
     'Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
     'Color':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
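The new BadTranslate* entries use the same translate / zero_stack helpers as the regular translations but draw offsets from [2*20, 3*20] pixels (or the negated range) instead of up to about 20, i.e. shifts large enough to push most of a small image (presumably 32x32 CIFAR, given LeNet(3,10)) out of frame. They are presumably there as deliberately harmful transformations, so the brutus runs can check whether the learned probabilities drive them down. An illustrative comparison of the two ranges (values only, not the repository's helpers):

import torch

n = 4
normal_dx = torch.empty(n).uniform_(-20.0, 20.0)      # roughly the range of TranslateX at full magnitude
bad_dx = torch.empty(n).uniform_(2 * 20.0, 3 * 20.0)  # BadTranslateX range: 40-60 pixels
print(normal_dx)
print(bad_dx)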
@@ -70,15 +100,15 @@ def float_image(int_image):
 #def rand_inverse(value):
 #    return value if random.random() < 0.5 else -value
 
-def rand_int(mag, maxval, minval=None): #[(-maxval,minval), maxval]
-    real_max = int_parameter(mag, maxval=maxval)
-    if not minval : minval = -real_max
-    return random.randint(minval, real_max)
+#def rand_int(mag, maxval, minval=None): #[(-maxval,minval), maxval]
+#    real_max = int_parameter(mag, maxval=maxval)
+#    if not minval : minval = -real_max
+#    return random.randint(minval, real_max)
 
-def rand_float(mag, maxval, minval=None): #[(-maxval,minval), maxval]
-    real_max = float_parameter(mag, maxval=maxval)
-    if not minval : minval = -real_max
-    return random.uniform(minval, real_max)
+#def rand_float(mag, maxval, minval=None): #[(-maxval,minval), maxval]
+#    real_max = float_parameter(mag, maxval=maxval)
+#    if not minval : minval = -real_max
+#    return random.uniform(minval, real_max)
 
 def rand_floats(size, mag, maxval, minval=None): #[(-maxval,minval), maxval]
     real_max = float_parameter(mag, maxval=maxval)
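The scalar rand_int and rand_float helpers (built on the random module) are commented out in favour of the batched rand_floats, which every TF_dict entry calls: it scales mag into a real maximum via float_parameter and then draws one value per image in the batch. float_parameter is not shown in this diff; the sketch below assumes the usual AutoAugment-style linear scaling mag / PARAMETER_MAX * maxval, so treat it as an illustration rather than the repository's implementation.

import torch

PARAMETER_MAX = 1.0  # assumed scale of `mag`; the repo's TF.PARAMETER_MAX is not shown here

def float_parameter_sketch(mag, maxval):
    """Assumed linear scaling of a magnitude in [0, PARAMETER_MAX] to [0, maxval]."""
    return float(mag) / PARAMETER_MAX * maxval

def rand_floats_sketch(size, mag, maxval, minval=None):
    """One random float per image in [minval, real_max] (symmetric range by default)."""
    real_max = float_parameter_sketch(mag, maxval=maxval)
    if not minval:
        minval = -real_max
    return torch.empty(size).uniform_(minval, real_max)

print(rand_floats_sketch(size=(4,), mag=0.5, maxval=30))  # e.g. rotation angles in [-15, 15]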