mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Modif solarize (Tjrs pas differentiable...)
This commit is contained in:
parent
4a7e73088d
commit
d822f8f92e
5 changed files with 56 additions and 42 deletions
|
@ -2,11 +2,12 @@ from utils import *
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
'''
|
#'''
|
||||||
files=[
|
files=[
|
||||||
#"res/good_TF_tests/log/Aug_mod(Data_augV5(Mix0.5-14TFx2-MagFxSh)-LeNet)-100 epochs (dataug:0)- 0 in_it.json",
|
#"res/good_TF_tests/log/Aug_mod(Data_augV5(Mix0.5-14TFx2-MagFxSh)-LeNet)-100 epochs (dataug:0)- 0 in_it.json",
|
||||||
#"res/good_TF_tests/log/Aug_mod(Data_augV5(Uniform-14TFx2-MagFxSh)-LeNet)-100 epochs (dataug:0)- 0 in_it.json",
|
#"res/good_TF_tests/log/Aug_mod(Data_augV5(Uniform-14TFx2-MagFxSh)-LeNet)-100 epochs (dataug:0)- 0 in_it.json",
|
||||||
"res/brutus-tests/log/Aug_mod(Data_augV5(Mix0.5-14TFx1-Mag)-LeNet)-150epochs(dataug:0)-1in_it-0.json",
|
#"res/brutus-tests/log/Aug_mod(Data_augV5(Mix0.5-14TFx1-Mag)-LeNet)-150epochs(dataug:0)-1in_it-0.json",
|
||||||
|
"res/log/Aug_mod(RandAugUDA(18TFx2-Mag1)-LeNet)-100 epochs (dataug:0)- 0 in_it.json",
|
||||||
]
|
]
|
||||||
|
|
||||||
for idx, file in enumerate(files):
|
for idx, file in enumerate(files):
|
||||||
|
@ -15,7 +16,7 @@ if __name__ == "__main__":
|
||||||
data = json.load(json_file)
|
data = json.load(json_file)
|
||||||
plot_resV2(data['Log'], fig_name=file.replace('.json','').replace('log/',''), param_names=data['Param_names'])
|
plot_resV2(data['Log'], fig_name=file.replace('.json','').replace('log/',''), param_names=data['Param_names'])
|
||||||
#plot_TF_influence(data['Log'], param_names=data['Param_names'])
|
#plot_TF_influence(data['Log'], param_names=data['Param_names'])
|
||||||
'''
|
#'''
|
||||||
## Loss , Acc, Proba = f(epoch) ##
|
## Loss , Acc, Proba = f(epoch) ##
|
||||||
#plot_compare(filenames=files, fig_name="res/compare")
|
#plot_compare(filenames=files, fig_name="res/compare")
|
||||||
|
|
||||||
|
@ -75,6 +76,7 @@ if __name__ == "__main__":
|
||||||
'''
|
'''
|
||||||
|
|
||||||
#Res print
|
#Res print
|
||||||
|
'''
|
||||||
nb_run=3
|
nb_run=3
|
||||||
accs = []
|
accs = []
|
||||||
times = []
|
times = []
|
||||||
|
@ -88,3 +90,4 @@ if __name__ == "__main__":
|
||||||
times.append(data['Time'][0])
|
times.append(data['Time'][0])
|
||||||
|
|
||||||
print(files[0], np.mean(accs), np.std(accs), np.mean(times))
|
print(files[0], np.mean(accs), np.std(accs), np.mean(times))
|
||||||
|
'''
|
|
@ -692,7 +692,7 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
||||||
else:
|
else:
|
||||||
return "Data_augV5(Mix%.1f%s-%dTFx%d-%s)" % (self._mix_factor,dist_param, self._nb_tf, self._N_seqTF, mag_param)
|
return "Data_augV5(Mix%.1f%s-%dTFx%d-%s)" % (self._mix_factor,dist_param, self._nb_tf, self._N_seqTF, mag_param)
|
||||||
|
|
||||||
class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh
|
class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + rapide
|
||||||
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mag=TF.PARAMETER_MAX):
|
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mag=TF.PARAMETER_MAX):
|
||||||
super(RandAug, self).__init__()
|
super(RandAug, self).__init__()
|
||||||
|
|
||||||
|
@ -773,9 +773,9 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return "RandAug(%dTFx%d-Mag%d)" % (self._nb_tf, self._N_seqTF, self.mag)
|
return "RandAug(%dTFx%d-Mag%d)" % (self._nb_tf, self._N_seqTF, self.mag)
|
||||||
|
|
||||||
class RandAugUDA(nn.Module): #RandAugment = UniformFx-MagFxSh
|
class RandAugUDA(nn.Module): #RandAugment from UDA (for DA during training)
|
||||||
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mag=TF.PARAMETER_MAX):
|
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mag=TF.PARAMETER_MAX):
|
||||||
super(RandAug, self).__init__()
|
super(RandAugUDA, self).__init__()
|
||||||
|
|
||||||
self._data_augmentation = True
|
self._data_augmentation = True
|
||||||
|
|
||||||
|
@ -786,7 +786,7 @@ class RandAugUDA(nn.Module): #RandAugment = UniformFx-MagFxSh
|
||||||
|
|
||||||
self.mag=nn.Parameter(torch.tensor(float(mag)))
|
self.mag=nn.Parameter(torch.tensor(float(mag)))
|
||||||
self._params = nn.ParameterDict({
|
self._params = nn.ParameterDict({
|
||||||
"prob": nn.Parameter(torch.tensor(0.5)),
|
"prob": nn.Parameter(torch.tensor(0.5).unsqueeze(dim=0)),
|
||||||
"mag" : nn.Parameter(torch.tensor(float(TF.PARAMETER_MAX))),
|
"mag" : nn.Parameter(torch.tensor(float(TF.PARAMETER_MAX))),
|
||||||
})
|
})
|
||||||
self._shared_mag = True
|
self._shared_mag = True
|
||||||
|
@ -794,12 +794,10 @@ class RandAugUDA(nn.Module): #RandAugment = UniformFx-MagFxSh
|
||||||
|
|
||||||
self._op_list =[]
|
self._op_list =[]
|
||||||
for tf in self._TF:
|
for tf in self._TF:
|
||||||
for mag in range(0.1, self._params['mag'], 0.1):
|
for mag in range(1, int(self._params['mag']*10), 1):
|
||||||
op_list+=[(tf, self._params['prob'], mag)]
|
self._op_list+=[(tf, self._params['prob'].item(), mag/10)]
|
||||||
self._nb_op = len(self._op_list)
|
self._nb_op = len(self._op_list)
|
||||||
|
|
||||||
print(self._op_list)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if self._data_augmentation:# and TF.random.random() < 0.5:
|
if self._data_augmentation:# and TF.random.random() < 0.5:
|
||||||
device = x.device
|
device = x.device
|
||||||
|
@ -821,16 +819,16 @@ class RandAugUDA(nn.Module): #RandAugment = UniformFx-MagFxSh
|
||||||
smps_x=[]
|
smps_x=[]
|
||||||
|
|
||||||
for op_idx in range(self._nb_op):
|
for op_idx in range(self._nb_op):
|
||||||
mask = sampled_TF==tf_idx #Create selection mask
|
mask = sampled_TF==op_idx #Create selection mask
|
||||||
smp_x = x[mask] #torch.masked_select() ? (Necessite d'expand le mask au meme dim)
|
smp_x = x[mask] #torch.masked_select() ? (Necessite d'expand le mask au meme dim)
|
||||||
|
|
||||||
if smp_x.shape[0]!=0: #if there's data to TF
|
if smp_x.shape[0]!=0: #if there's data to TF
|
||||||
if TF.random.random() < self.op_list[op_idx][1]:
|
if TF.random.random() < self._op_list[op_idx][1]:
|
||||||
magnitude=self.op_list[op_idx][2]
|
magnitude=self._op_list[op_idx][2]
|
||||||
tf=self.op_list[op_idx][0]
|
tf=self._op_list[op_idx][0]
|
||||||
|
|
||||||
#In place
|
#In place
|
||||||
x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude)
|
x[mask]=self._TF_dict[tf](x=smp_x, mag=torch.tensor(magnitude, device=x.device))
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@ -847,7 +845,7 @@ class RandAugUDA(nn.Module): #RandAugment = UniformFx-MagFxSh
|
||||||
if mode is None :
|
if mode is None :
|
||||||
mode=self._data_augmentation
|
mode=self._data_augmentation
|
||||||
self.augment(mode=mode) #Inutile si mode=None
|
self.augment(mode=mode) #Inutile si mode=None
|
||||||
super(RandAug, self).train(mode)
|
super(RandAugUDA, self).train(mode)
|
||||||
|
|
||||||
def eval(self):
|
def eval(self):
|
||||||
self.train(mode=False)
|
self.train(mode=False)
|
||||||
|
@ -859,7 +857,7 @@ class RandAugUDA(nn.Module): #RandAugment = UniformFx-MagFxSh
|
||||||
return self._params[key]
|
return self._params[key]
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return "RandAug(%dTFx%d-Mag%d)" % (self._nb_tf, self._N_seqTF, self.mag)
|
return "RandAugUDA(%dTFx%d-Mag%d)" % (self._nb_tf, self._N_seqTF, self.mag)
|
||||||
|
|
||||||
class Augmented_model(nn.Module):
|
class Augmented_model(nn.Module):
|
||||||
def __init__(self, data_augmenter, model):
|
def __init__(self, data_augmenter, model):
|
||||||
|
|
|
@ -5,21 +5,21 @@ from train_utils import *
|
||||||
|
|
||||||
tf_names = [
|
tf_names = [
|
||||||
## Geometric TF ##
|
## Geometric TF ##
|
||||||
'Identity',
|
#'Identity',
|
||||||
'FlipUD',
|
#'FlipUD',
|
||||||
'FlipLR',
|
#'FlipLR',
|
||||||
'Rotate',
|
#'Rotate',
|
||||||
'TranslateX',
|
#'TranslateX',
|
||||||
'TranslateY',
|
#'TranslateY',
|
||||||
'ShearX',
|
#'ShearX',
|
||||||
'ShearY',
|
#'ShearY',
|
||||||
|
|
||||||
## Color TF (Expect image in the range of [0, 1]) ##
|
## Color TF (Expect image in the range of [0, 1]) ##
|
||||||
'Contrast',
|
#'Contrast',
|
||||||
'Color',
|
#'Color',
|
||||||
'Brightness',
|
#'Brightness',
|
||||||
'Sharpness',
|
#'Sharpness',
|
||||||
'Posterize',
|
#'Posterize',
|
||||||
'Solarize', #=>Image entre [0,1] #Pas opti pour des batch
|
'Solarize', #=>Image entre [0,1] #Pas opti pour des batch
|
||||||
|
|
||||||
#Color TF (Common mag scale)
|
#Color TF (Common mag scale)
|
||||||
|
@ -48,6 +48,7 @@ tf_names = [
|
||||||
#'BadSharpness',
|
#'BadSharpness',
|
||||||
#'BadContrast',
|
#'BadContrast',
|
||||||
#'BadBrightness',
|
#'BadBrightness',
|
||||||
|
|
||||||
#Non fonctionnel
|
#Non fonctionnel
|
||||||
#'Auto_Contrast', #Pas opti pour des batch (Super lent)
|
#'Auto_Contrast', #Pas opti pour des batch (Super lent)
|
||||||
#'Equalize',
|
#'Equalize',
|
||||||
|
@ -63,8 +64,8 @@ else:
|
||||||
##########################################
|
##########################################
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
n_inner_iter = 0
|
n_inner_iter = 10
|
||||||
epochs = 150
|
epochs = 1
|
||||||
dataug_epoch_start=0
|
dataug_epoch_start=0
|
||||||
|
|
||||||
#### Classic ####
|
#### Classic ####
|
||||||
|
@ -94,12 +95,12 @@ if __name__ == "__main__":
|
||||||
t0 = time.process_time()
|
t0 = time.process_time()
|
||||||
tf_dict = {k: TF.TF_dict[k] for k in tf_names}
|
tf_dict = {k: TF.TF_dict[k] for k in tf_names}
|
||||||
#tf_dict = TF.TF_dict
|
#tf_dict = TF.TF_dict
|
||||||
#aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True), LeNet(3,10)).to(device)
|
aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=1, mix_dist=0.0, fixed_prob=True, fixed_mag=False, shared_mag=True), LeNet(3,10)).to(device)
|
||||||
#aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=2, mix_dist=0.5, fixed_mag=True, shared_mag=True), WideResNet(num_classes=10, wrn_size=160)).to(device)
|
#aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=2, mix_dist=0.5, fixed_mag=True, shared_mag=True), WideResNet(num_classes=10, wrn_size=160)).to(device)
|
||||||
aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), LeNet(3,10)).to(device)
|
#aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), LeNet(3,10)).to(device)
|
||||||
print(str(aug_model), 'on', device_name)
|
print(str(aug_model), 'on', device_name)
|
||||||
#run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
|
#run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
|
||||||
log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, loss_patience=None)
|
log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=1, loss_patience=None)
|
||||||
|
|
||||||
####
|
####
|
||||||
print('-'*9)
|
print('-'*9)
|
||||||
|
|
|
@ -651,7 +651,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
|
||||||
print('TF Proba :', model['data_aug']['prob'].data)
|
print('TF Proba :', model['data_aug']['prob'].data)
|
||||||
#print('proba grad',model['data_aug']['prob'].grad)
|
#print('proba grad',model['data_aug']['prob'].grad)
|
||||||
print('TF Mag :', model['data_aug']['mag'].data)
|
print('TF Mag :', model['data_aug']['mag'].data)
|
||||||
#print('Mag grad',model['data_aug']['mag'].grad)
|
print('Mag grad',model['data_aug']['mag'].grad)
|
||||||
#print('Reg loss:', model['data_aug'].reg_loss().item())
|
#print('Reg loss:', model['data_aug'].reg_loss().item())
|
||||||
#############
|
#############
|
||||||
#### Log ####
|
#### Log ####
|
||||||
|
|
|
@ -298,12 +298,12 @@ def equalize(x): #PAS OPTIMISE POUR DES BATCH
|
||||||
def solarize(x, thresholds): #PAS OPTIMISE POUR DES BATCH
|
def solarize(x, thresholds): #PAS OPTIMISE POUR DES BATCH
|
||||||
# Optimisation : Mask direct sur toute les donnees (Mask = (B,C,H,W)> (B))
|
# Optimisation : Mask direct sur toute les donnees (Mask = (B,C,H,W)> (B))
|
||||||
batch_size, channels, h, w = x.shape
|
batch_size, channels, h, w = x.shape
|
||||||
imgs=[]
|
#imgs=[]
|
||||||
for idx, t in enumerate(thresholds): #Operation par image
|
#for idx, t in enumerate(thresholds): #Operation par image
|
||||||
mask = x[idx] > t #Perte du gradient
|
# mask = x[idx] > t #Perte du gradient
|
||||||
#In place
|
#In place
|
||||||
inv_x = 1-x[idx][mask]
|
# inv_x = 1-x[idx][mask]
|
||||||
x[idx][mask]=inv_x
|
# x[idx][mask]=inv_x
|
||||||
#
|
#
|
||||||
|
|
||||||
#Out of place
|
#Out of place
|
||||||
|
@ -316,6 +316,18 @@ def solarize(x, thresholds): #PAS OPTIMISE POUR DES BATCH
|
||||||
#idxs=idxs.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ...
|
#idxs=idxs.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ...
|
||||||
#x=x.scatter(dim=0, index=idxs, src=torch.stack(imgs))
|
#x=x.scatter(dim=0, index=idxs, src=torch.stack(imgs))
|
||||||
#
|
#
|
||||||
|
|
||||||
|
thresholds = thresholds.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ...
|
||||||
|
#print(thresholds.grad_fn)
|
||||||
|
x=torch.where(x>thresholds,1-x, x)
|
||||||
|
#print(mask.grad_fn)
|
||||||
|
|
||||||
|
#x=x.min(thresholds)
|
||||||
|
#inv_x = 1-x[mask]
|
||||||
|
#x=x.where(x<thresholds,1-x)
|
||||||
|
#x[mask]=inv_x
|
||||||
|
#x=x.masked_scatter(mask, inv_x)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
#https://github.com/python-pillow/Pillow/blob/9c78c3f97291bd681bc8637922d6a2fa9415916c/src/PIL/Image.py#L2818
|
#https://github.com/python-pillow/Pillow/blob/9c78c3f97291bd681bc8637922d6a2fa9415916c/src/PIL/Image.py#L2818
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue