mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Ajout RandAugment
This commit is contained in:
parent
3c2022de32
commit
4a7e73088d
4 changed files with 249 additions and 37 deletions
170
higher/dataug.py
170
higher/dataug.py
|
@ -659,7 +659,7 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
||||||
return w_loss
|
return w_loss
|
||||||
|
|
||||||
def reg_loss(self, reg_factor=0.005):
|
def reg_loss(self, reg_factor=0.005):
|
||||||
if self._fixed_mag:
|
if self._fixed_mag: # or self._fixed_prob: #Pas de regularisation si trop peu de DOF
|
||||||
return torch.tensor(0)
|
return torch.tensor(0)
|
||||||
else:
|
else:
|
||||||
#return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean')
|
#return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean')
|
||||||
|
@ -692,6 +692,174 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
||||||
else:
|
else:
|
||||||
return "Data_augV5(Mix%.1f%s-%dTFx%d-%s)" % (self._mix_factor,dist_param, self._nb_tf, self._N_seqTF, mag_param)
|
return "Data_augV5(Mix%.1f%s-%dTFx%d-%s)" % (self._mix_factor,dist_param, self._nb_tf, self._N_seqTF, mag_param)
|
||||||
|
|
||||||
|
class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh
|
||||||
|
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mag=TF.PARAMETER_MAX):
|
||||||
|
super(RandAug, self).__init__()
|
||||||
|
|
||||||
|
self._data_augmentation = True
|
||||||
|
|
||||||
|
self._TF_dict = TF_dict
|
||||||
|
self._TF= list(self._TF_dict.keys())
|
||||||
|
self._nb_tf= len(self._TF)
|
||||||
|
self._N_seqTF = N_TF
|
||||||
|
|
||||||
|
self.mag=nn.Parameter(torch.tensor(float(mag)))
|
||||||
|
self._params = nn.ParameterDict({
|
||||||
|
"prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #pas utilise
|
||||||
|
"mag" : nn.Parameter(torch.tensor(float(mag))),
|
||||||
|
})
|
||||||
|
self._shared_mag = True
|
||||||
|
self._fixed_mag = True
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self._data_augmentation:# and TF.random.random() < 0.5:
|
||||||
|
device = x.device
|
||||||
|
batch_size, h, w = x.shape[0], x.shape[2], x.shape[3]
|
||||||
|
|
||||||
|
x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles)
|
||||||
|
|
||||||
|
for _ in range(self._N_seqTF):
|
||||||
|
## Echantillonage ## == sampled_ops = np.random.choice(transforms, N)
|
||||||
|
uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1)
|
||||||
|
cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*uniforme_dist)
|
||||||
|
sample = cat_distrib.sample()
|
||||||
|
|
||||||
|
## Transformations ##
|
||||||
|
x = self.apply_TF(x, sample)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def apply_TF(self, x, sampled_TF):
|
||||||
|
smps_x=[]
|
||||||
|
|
||||||
|
for tf_idx in range(self._nb_tf):
|
||||||
|
mask = sampled_TF==tf_idx #Create selection mask
|
||||||
|
smp_x = x[mask] #torch.masked_select() ? (NEcessite d'expand le mask au meme dim)
|
||||||
|
|
||||||
|
if smp_x.shape[0]!=0: #if there's data to TF
|
||||||
|
magnitude=self._params["mag"].detach()
|
||||||
|
|
||||||
|
tf=self._TF[tf_idx]
|
||||||
|
#print(magnitude)
|
||||||
|
|
||||||
|
#In place
|
||||||
|
x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def adjust_param(self, soft=False):
|
||||||
|
pass #Pas de parametre a opti
|
||||||
|
|
||||||
|
def loss_weight(self):
|
||||||
|
return 1 #Pas d'echantillon = pas de ponderation
|
||||||
|
|
||||||
|
def reg_loss(self, reg_factor=0.005):
|
||||||
|
return torch.tensor(0) #Pas de regularisation
|
||||||
|
|
||||||
|
def train(self, mode=None):
|
||||||
|
if mode is None :
|
||||||
|
mode=self._data_augmentation
|
||||||
|
self.augment(mode=mode) #Inutile si mode=None
|
||||||
|
super(RandAug, self).train(mode)
|
||||||
|
|
||||||
|
def eval(self):
|
||||||
|
self.train(mode=False)
|
||||||
|
|
||||||
|
def augment(self, mode=True):
|
||||||
|
self._data_augmentation=mode
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
return self._params[key]
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return "RandAug(%dTFx%d-Mag%d)" % (self._nb_tf, self._N_seqTF, self.mag)
|
||||||
|
|
||||||
|
class RandAugUDA(nn.Module): #RandAugment = UniformFx-MagFxSh
|
||||||
|
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mag=TF.PARAMETER_MAX):
|
||||||
|
super(RandAug, self).__init__()
|
||||||
|
|
||||||
|
self._data_augmentation = True
|
||||||
|
|
||||||
|
self._TF_dict = TF_dict
|
||||||
|
self._TF= list(self._TF_dict.keys())
|
||||||
|
self._nb_tf= len(self._TF)
|
||||||
|
self._N_seqTF = N_TF
|
||||||
|
|
||||||
|
self.mag=nn.Parameter(torch.tensor(float(mag)))
|
||||||
|
self._params = nn.ParameterDict({
|
||||||
|
"prob": nn.Parameter(torch.tensor(0.5)),
|
||||||
|
"mag" : nn.Parameter(torch.tensor(float(TF.PARAMETER_MAX))),
|
||||||
|
})
|
||||||
|
self._shared_mag = True
|
||||||
|
self._fixed_mag = True
|
||||||
|
|
||||||
|
self._op_list =[]
|
||||||
|
for tf in self._TF:
|
||||||
|
for mag in range(0.1, self._params['mag'], 0.1):
|
||||||
|
op_list+=[(tf, self._params['prob'], mag)]
|
||||||
|
self._nb_op = len(self._op_list)
|
||||||
|
|
||||||
|
print(self._op_list)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self._data_augmentation:# and TF.random.random() < 0.5:
|
||||||
|
device = x.device
|
||||||
|
batch_size, h, w = x.shape[0], x.shape[2], x.shape[3]
|
||||||
|
|
||||||
|
x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles)
|
||||||
|
|
||||||
|
for _ in range(self._N_seqTF):
|
||||||
|
## Echantillonage ## == sampled_ops = np.random.choice(transforms, N)
|
||||||
|
uniforme_dist = torch.ones(1, self._nb_op, device=device).softmax(dim=1)
|
||||||
|
cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_op), device=device)*uniforme_dist)
|
||||||
|
sample = cat_distrib.sample()
|
||||||
|
|
||||||
|
## Transformations ##
|
||||||
|
x = self.apply_TF(x, sample)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def apply_TF(self, x, sampled_TF):
|
||||||
|
smps_x=[]
|
||||||
|
|
||||||
|
for op_idx in range(self._nb_op):
|
||||||
|
mask = sampled_TF==tf_idx #Create selection mask
|
||||||
|
smp_x = x[mask] #torch.masked_select() ? (Necessite d'expand le mask au meme dim)
|
||||||
|
|
||||||
|
if smp_x.shape[0]!=0: #if there's data to TF
|
||||||
|
if TF.random.random() < self.op_list[op_idx][1]:
|
||||||
|
magnitude=self.op_list[op_idx][2]
|
||||||
|
tf=self.op_list[op_idx][0]
|
||||||
|
|
||||||
|
#In place
|
||||||
|
x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def adjust_param(self, soft=False):
|
||||||
|
pass #Pas de parametre a opti
|
||||||
|
|
||||||
|
def loss_weight(self):
|
||||||
|
return 1 #Pas d'echantillon = pas de ponderation
|
||||||
|
|
||||||
|
def reg_loss(self, reg_factor=0.005):
|
||||||
|
return torch.tensor(0) #Pas de regularisation
|
||||||
|
|
||||||
|
def train(self, mode=None):
|
||||||
|
if mode is None :
|
||||||
|
mode=self._data_augmentation
|
||||||
|
self.augment(mode=mode) #Inutile si mode=None
|
||||||
|
super(RandAug, self).train(mode)
|
||||||
|
|
||||||
|
def eval(self):
|
||||||
|
self.train(mode=False)
|
||||||
|
|
||||||
|
def augment(self, mode=True):
|
||||||
|
self._data_augmentation=mode
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
return self._params[key]
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return "RandAug(%dTFx%d-Mag%d)" % (self._nb_tf, self._N_seqTF, self.mag)
|
||||||
|
|
||||||
class Augmented_model(nn.Module):
|
class Augmented_model(nn.Module):
|
||||||
def __init__(self, data_augmenter, model):
|
def __init__(self, data_augmenter, model):
|
||||||
|
|
|
@ -14,6 +14,26 @@ tf_names = [
|
||||||
'ShearX',
|
'ShearX',
|
||||||
'ShearY',
|
'ShearY',
|
||||||
|
|
||||||
|
## Color TF (Expect image in the range of [0, 1]) ##
|
||||||
|
'Contrast',
|
||||||
|
'Color',
|
||||||
|
'Brightness',
|
||||||
|
'Sharpness',
|
||||||
|
'Posterize',
|
||||||
|
'Solarize', #=>Image entre [0,1] #Pas opti pour des batch
|
||||||
|
|
||||||
|
#Color TF (Common mag scale)
|
||||||
|
#'+Contrast',
|
||||||
|
#'+Color',
|
||||||
|
#'+Brightness',
|
||||||
|
#'+Sharpness',
|
||||||
|
#'-Contrast',
|
||||||
|
#'-Color',
|
||||||
|
#'-Brightness',
|
||||||
|
#'-Sharpness',
|
||||||
|
#'=Posterize',
|
||||||
|
#'=Solarize',
|
||||||
|
|
||||||
#'BRotate',
|
#'BRotate',
|
||||||
#'BTranslateX',
|
#'BTranslateX',
|
||||||
#'BTranslateY',
|
#'BTranslateY',
|
||||||
|
@ -24,14 +44,10 @@ tf_names = [
|
||||||
#'BadTranslateY',
|
#'BadTranslateY',
|
||||||
#'BadTranslateY_neg',
|
#'BadTranslateY_neg',
|
||||||
|
|
||||||
## Color TF (Expect image in the range of [0, 1]) ##
|
#'BadColor',
|
||||||
'Contrast',
|
#'BadSharpness',
|
||||||
'Color',
|
#'BadContrast',
|
||||||
'Brightness',
|
#'BadBrightness',
|
||||||
'Sharpness',
|
|
||||||
'Posterize',
|
|
||||||
'Solarize', #=>Image entre [0,1] #Pas opti pour des batch
|
|
||||||
|
|
||||||
#Non fonctionnel
|
#Non fonctionnel
|
||||||
#'Auto_Contrast', #Pas opti pour des batch (Super lent)
|
#'Auto_Contrast', #Pas opti pour des batch (Super lent)
|
||||||
#'Equalize',
|
#'Equalize',
|
||||||
|
@ -47,8 +63,8 @@ else:
|
||||||
##########################################
|
##########################################
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
n_inner_iter = 10
|
n_inner_iter = 0
|
||||||
epochs = 100
|
epochs = 150
|
||||||
dataug_epoch_start=0
|
dataug_epoch_start=0
|
||||||
|
|
||||||
#### Classic ####
|
#### Classic ####
|
||||||
|
@ -74,12 +90,13 @@ if __name__ == "__main__":
|
||||||
print('-'*9)
|
print('-'*9)
|
||||||
'''
|
'''
|
||||||
#### Augmented Model ####
|
#### Augmented Model ####
|
||||||
'''
|
#'''
|
||||||
t0 = time.process_time()
|
t0 = time.process_time()
|
||||||
tf_dict = {k: TF.TF_dict[k] for k in tf_names}
|
tf_dict = {k: TF.TF_dict[k] for k in tf_names}
|
||||||
#tf_dict = TF.TF_dict
|
#tf_dict = TF.TF_dict
|
||||||
aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=1, mix_dist=0.0, fixed_prob=True, fixed_mag=False, shared_mag=False), LeNet(3,10)).to(device)
|
#aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True), LeNet(3,10)).to(device)
|
||||||
#aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=2, mix_dist=0.5, fixed_mag=True, shared_mag=True), WideResNet(num_classes=10, wrn_size=160)).to(device)
|
#aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=2, mix_dist=0.5, fixed_mag=True, shared_mag=True), WideResNet(num_classes=10, wrn_size=160)).to(device)
|
||||||
|
aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), LeNet(3,10)).to(device)
|
||||||
print(str(aug_model), 'on', device_name)
|
print(str(aug_model), 'on', device_name)
|
||||||
#run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
|
#run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
|
||||||
log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, loss_patience=None)
|
log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, loss_patience=None)
|
||||||
|
@ -98,9 +115,9 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
print('Execution Time : %.00f '%(time.process_time() - t0))
|
print('Execution Time : %.00f '%(time.process_time() - t0))
|
||||||
print('-'*9)
|
print('-'*9)
|
||||||
'''
|
|
||||||
#### TF tests ####
|
|
||||||
#'''
|
#'''
|
||||||
|
#### TF tests ####
|
||||||
|
'''
|
||||||
res_folder="res/brutus-tests/"
|
res_folder="res/brutus-tests/"
|
||||||
epochs= 150
|
epochs= 150
|
||||||
inner_its = [1, 10]
|
inner_its = [1, 10]
|
||||||
|
@ -150,4 +167,4 @@ if __name__ == "__main__":
|
||||||
#plot_resV2(log, fig_name=res_folder+filename, param_names=tf_names)
|
#plot_resV2(log, fig_name=res_folder+filename, param_names=tf_names)
|
||||||
print('-'*9)
|
print('-'*9)
|
||||||
|
|
||||||
#'''
|
'''
|
||||||
|
|
|
@ -540,6 +540,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
|
||||||
val_loss=torch.tensor(0) #Necessaire si pas de metastep sur une epoch
|
val_loss=torch.tensor(0) #Necessaire si pas de metastep sur une epoch
|
||||||
dl_val_it = iter(dl_val)
|
dl_val_it = iter(dl_val)
|
||||||
|
|
||||||
|
#if inner_it!=0:
|
||||||
meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=1e-2)
|
meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=1e-2)
|
||||||
inner_opt = torch.optim.SGD(model['model'].parameters(), lr=1e-2, momentum=0.9)
|
inner_opt = torch.optim.SGD(model['model'].parameters(), lr=1e-2, momentum=0.9)
|
||||||
|
|
||||||
|
@ -680,5 +681,8 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
|
||||||
model.augment(mode=True)
|
model.augment(mode=True)
|
||||||
if inner_it != 0: high_grad_track = True
|
if inner_it != 0: high_grad_track = True
|
||||||
|
|
||||||
|
viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
|
||||||
|
viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
|
||||||
|
|
||||||
#print("Copy ", countcopy)
|
#print("Copy ", countcopy)
|
||||||
return log
|
return log
|
|
@ -64,6 +64,27 @@ TF_dict={ #Dataugv5
|
||||||
'ShearX': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=0))),
|
'ShearX': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=0))),
|
||||||
'ShearY': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=1))),
|
'ShearY': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=1))),
|
||||||
|
|
||||||
|
## Color TF (Expect image in the range of [0, 1]) ##
|
||||||
|
'Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
||||||
|
'Color':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
||||||
|
'Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
||||||
|
'Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
||||||
|
'Posterize': (lambda x, mag: posterize(x, bits=rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))),#Perte du gradient
|
||||||
|
'Solarize': (lambda x, mag: solarize(x, thresholds=rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Perte du gradient #=>Image entre [0,1] #Pas opti pour des batch
|
||||||
|
|
||||||
|
#Color TF (Common mag scale)
|
||||||
|
'+Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
|
||||||
|
'+Color':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
|
||||||
|
'+Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
|
||||||
|
'+Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
|
||||||
|
'-Contrast': (lambda x, mag: contrast(x, contrast_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
|
||||||
|
'-Color':(lambda x, mag: color(x, color_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
|
||||||
|
'-Brightness':(lambda x, mag: brightness(x, brightness_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
|
||||||
|
'-Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
|
||||||
|
'=Posterize': (lambda x, mag: posterize(x, bits=invScale_rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))),#Perte du gradient
|
||||||
|
'=Solarize': (lambda x, mag: solarize(x, thresholds=invScale_rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Perte du gradient #=>Image entre [0,1] #Pas opti pour des batch
|
||||||
|
|
||||||
|
|
||||||
'BRotate': (lambda x, mag: rotate(x, angle=rand_floats(size=x.shape[0], mag=mag, maxval=30*3))),
|
'BRotate': (lambda x, mag: rotate(x, angle=rand_floats(size=x.shape[0], mag=mag, maxval=30*3))),
|
||||||
'BTranslateX': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20*3), zero_pos=0))),
|
'BTranslateX': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20*3), zero_pos=0))),
|
||||||
'BTranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20*3), zero_pos=1))),
|
'BTranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20*3), zero_pos=1))),
|
||||||
|
@ -74,14 +95,11 @@ TF_dict={ #Dataugv5
|
||||||
'BadTranslateX_neg': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=-20*3, maxval=-20*2), zero_pos=0))),
|
'BadTranslateX_neg': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=-20*3, maxval=-20*2), zero_pos=0))),
|
||||||
'BadTranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=20*2, maxval=20*3), zero_pos=1))),
|
'BadTranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=20*2, maxval=20*3), zero_pos=1))),
|
||||||
'BadTranslateY_neg': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=-20*3, maxval=-20*2), zero_pos=1))),
|
'BadTranslateY_neg': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=-20*3, maxval=-20*2), zero_pos=1))),
|
||||||
|
|
||||||
## Color TF (Expect image in the range of [0, 1]) ##
|
'BadColor':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*2))),
|
||||||
'Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
'BadSharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*2))),
|
||||||
'Color':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
'BadContrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*2))),
|
||||||
'Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
'BadBrightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*2))),
|
||||||
'Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
|
||||||
'Posterize': (lambda x, mag: posterize(x, bits=rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))),#Perte du gradient
|
|
||||||
'Solarize': (lambda x, mag: solarize(x, thresholds=rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Perte du gradient #=>Image entre [0,1] #Pas opti pour des batch
|
|
||||||
|
|
||||||
#Non fonctionnel
|
#Non fonctionnel
|
||||||
#'Auto_Contrast': (lambda mag: None), #Pas opti pour des batch (Super lent)
|
#'Auto_Contrast': (lambda mag: None), #Pas opti pour des batch (Super lent)
|
||||||
|
@ -111,10 +129,15 @@ def float_image(int_image):
|
||||||
# return random.uniform(minval, real_max)
|
# return random.uniform(minval, real_max)
|
||||||
|
|
||||||
def rand_floats(size, mag, maxval, minval=None): #[(-maxval,minval), maxval]
|
def rand_floats(size, mag, maxval, minval=None): #[(-maxval,minval), maxval]
|
||||||
real_max = float_parameter(mag, maxval=maxval)
|
real_mag = float_parameter(mag, maxval=maxval)
|
||||||
if not minval : minval = -real_max
|
if not minval : minval = -real_mag
|
||||||
#return random.uniform(minval, real_max)
|
#return random.uniform(minval, real_max)
|
||||||
return minval +(real_max-minval) * torch.rand(size, device=mag.device)
|
return minval + (real_mag-minval) * torch.rand(size, device=mag.device) #[min_val, real_mag]
|
||||||
|
|
||||||
|
def invScale_rand_floats(size, mag, maxval, minval):
|
||||||
|
#Mag=[0,PARAMETER_MAX] => [PARAMETER_MAX, 0] = [maxval, minval]
|
||||||
|
real_mag = float_parameter(float(PARAMETER_MAX) - mag, maxval=maxval-minval)+minval
|
||||||
|
return real_mag + (maxval-real_mag) * torch.rand(size, device=mag.device) #[real_mag, max_val]
|
||||||
|
|
||||||
def zero_stack(tensor, zero_pos):
|
def zero_stack(tensor, zero_pos):
|
||||||
if zero_pos==0:
|
if zero_pos==0:
|
||||||
|
@ -139,7 +162,7 @@ def float_parameter(level, maxval):
|
||||||
#return float(level) * maxval / PARAMETER_MAX
|
#return float(level) * maxval / PARAMETER_MAX
|
||||||
return (level * maxval / PARAMETER_MAX)#.to(torch.float)
|
return (level * maxval / PARAMETER_MAX)#.to(torch.float)
|
||||||
|
|
||||||
def int_parameter(level, maxval): #Perte de gradient
|
#def int_parameter(level, maxval): #Perte de gradient
|
||||||
"""Helper function to scale `val` between 0 and maxval .
|
"""Helper function to scale `val` between 0 and maxval .
|
||||||
Args:
|
Args:
|
||||||
level: Level of the operation that will be between [0, `PARAMETER_MAX`].
|
level: Level of the operation that will be between [0, `PARAMETER_MAX`].
|
||||||
|
@ -149,7 +172,7 @@ def int_parameter(level, maxval): #Perte de gradient
|
||||||
An int that results from scaling `maxval` according to `level`.
|
An int that results from scaling `maxval` according to `level`.
|
||||||
"""
|
"""
|
||||||
#return int(level * maxval / PARAMETER_MAX)
|
#return int(level * maxval / PARAMETER_MAX)
|
||||||
return (level * maxval / PARAMETER_MAX)
|
# return (level * maxval / PARAMETER_MAX)
|
||||||
|
|
||||||
def flipLR(x):
|
def flipLR(x):
|
||||||
device = x.device
|
device = x.device
|
||||||
|
@ -279,19 +302,19 @@ def solarize(x, thresholds): #PAS OPTIMISE POUR DES BATCH
|
||||||
for idx, t in enumerate(thresholds): #Operation par image
|
for idx, t in enumerate(thresholds): #Operation par image
|
||||||
mask = x[idx] > t #Perte du gradient
|
mask = x[idx] > t #Perte du gradient
|
||||||
#In place
|
#In place
|
||||||
#inv_x = 1-x[idx][mask]
|
inv_x = 1-x[idx][mask]
|
||||||
#x[idx][mask]=inv_x
|
x[idx][mask]=inv_x
|
||||||
#
|
#
|
||||||
|
|
||||||
#Out of place
|
#Out of place
|
||||||
im = x[idx]
|
# im = x[idx]
|
||||||
inv_x = 1-im[mask]
|
# inv_x = 1-im[mask]
|
||||||
|
|
||||||
imgs.append(im.masked_scatter(mask,inv_x))
|
# imgs.append(im.masked_scatter(mask,inv_x))
|
||||||
|
|
||||||
idxs=torch.tensor(range(x.shape[0]), device=x.device)
|
#idxs=torch.tensor(range(x.shape[0]), device=x.device)
|
||||||
idxs=idxs.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ...
|
#idxs=idxs.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ...
|
||||||
x=x.scatter(dim=0, index=idxs, src=torch.stack(imgs))
|
#x=x.scatter(dim=0, index=idxs, src=torch.stack(imgs))
|
||||||
#
|
#
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue