mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 04:00:46 +02:00
Refactoring de TF_dict
This commit is contained in:
parent
fd4dcdb392
commit
103277fadd
8 changed files with 245 additions and 23 deletions
160
higher/dataug.py
160
higher/dataug.py
|
@ -322,8 +322,9 @@ class Data_augV4(nn.Module): #Transformations avec mask
|
|||
|
||||
#self._TF_matrix={}
|
||||
#self._input_info={'h':0, 'w':0, 'device':None} #Input associe a TF_matrix
|
||||
self._mag_fct = TF_dict
|
||||
self._TF=list(self._mag_fct.keys())
|
||||
#self._mag_fct = TF_dict
|
||||
self._TF_dict = TF_dict
|
||||
self._TF= list(self._TF_dict.keys())
|
||||
self._nb_tf= len(self._TF)
|
||||
|
||||
self._N_TF = N_TF
|
||||
|
@ -356,7 +357,6 @@ class Data_augV4(nn.Module): #Transformations avec mask
|
|||
self._distrib = uniforme_dist
|
||||
else:
|
||||
self._distrib = (self._mix_factor*self._params["prob"]+(1-self._mix_factor)*uniforme_dist).softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor
|
||||
print(self.distrib.shape)
|
||||
|
||||
cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*self._distrib)
|
||||
sample = cat_distrib.sample()
|
||||
|
@ -414,6 +414,7 @@ class Data_augV4(nn.Module): #Transformations avec mask
|
|||
magnitude=self._fixed_mag
|
||||
tf=self._TF[tf_idx]
|
||||
|
||||
'''
|
||||
## Geometric TF ##
|
||||
if tf=='Identity':
|
||||
pass
|
||||
|
@ -449,6 +450,8 @@ class Data_augV4(nn.Module): #Transformations avec mask
|
|||
raise Exception("Invalid TF requested : ", tf)
|
||||
|
||||
x[mask]=smp_x # Refusionner eviter x[mask] : in place
|
||||
'''
|
||||
x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude) # Refusionner eviter x[mask] : in place
|
||||
|
||||
#idx= mask.nonzero()
|
||||
#print('-'*8)
|
||||
|
@ -527,6 +530,157 @@ class Data_augV4(nn.Module): #Transformations avec mask
|
|||
else:
|
||||
return "Data_augV4(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_TF)
|
||||
|
||||
class Data_augV5(nn.Module): #Transformations avec mask
|
||||
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0):
|
||||
super(Data_augV5, self).__init__()
|
||||
assert len(TF_dict)>0
|
||||
|
||||
self._data_augmentation = True
|
||||
|
||||
#self._TF_matrix={}
|
||||
#self._input_info={'h':0, 'w':0, 'device':None} #Input associe a TF_matrix
|
||||
self._mag_fct = TF_dict
|
||||
self._TF=list(self._mag_fct.keys())
|
||||
self._nb_tf= len(self._TF)
|
||||
|
||||
self._N_TF = N_TF
|
||||
|
||||
self._fixed_mag=5 #[0, PARAMETER_MAX]
|
||||
self._params = nn.ParameterDict({
|
||||
"prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Distribution prob uniforme
|
||||
})
|
||||
|
||||
self._samples = []
|
||||
|
||||
self._mix_dist = False
|
||||
if mix_dist != 0.0:
|
||||
self._mix_dist = True
|
||||
self._mix_factor = max(min(mix_dist, 1.0), 0.0)
|
||||
|
||||
def forward(self, x):
|
||||
if self._data_augmentation:
|
||||
device = x.device
|
||||
batch_size, h, w = x.shape[0], x.shape[2], x.shape[3]
|
||||
|
||||
x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles)
|
||||
self._samples = []
|
||||
|
||||
for _ in range(self._N_TF):
|
||||
## Echantillonage ##
|
||||
uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1)
|
||||
|
||||
if not self._mix_dist:
|
||||
self._distrib = uniforme_dist
|
||||
else:
|
||||
self._distrib = (self._mix_factor*self._params["prob"]+(1-self._mix_factor)*uniforme_dist).softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor
|
||||
|
||||
cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*self._distrib)
|
||||
sample = cat_distrib.sample()
|
||||
self._samples.append(sample)
|
||||
|
||||
## Transformations ##
|
||||
x = self.apply_TF(x, sample)
|
||||
return x
|
||||
|
||||
def apply_TF(self, x, sampled_TF):
|
||||
device = x.device
|
||||
smps_x=[]
|
||||
masks=[]
|
||||
for tf_idx in range(self._nb_tf):
|
||||
mask = sampled_TF==tf_idx #Create selection mask
|
||||
smp_x = x[mask] #torch.masked_select() ?
|
||||
|
||||
if smp_x.shape[0]!=0: #if there's data to TF
|
||||
magnitude=self._fixed_mag
|
||||
tf=self._TF[tf_idx]
|
||||
|
||||
## Geometric TF ##
|
||||
if tf=='Identity':
|
||||
pass
|
||||
elif tf=='FlipLR':
|
||||
smp_x = TF.flipLR(smp_x)
|
||||
elif tf=='FlipUD':
|
||||
smp_x = TF.flipUD(smp_x)
|
||||
elif tf=='Rotate':
|
||||
smp_x = TF.rotate(smp_x, angle=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
|
||||
elif tf=='TranslateX' or tf=='TranslateY':
|
||||
smp_x = TF.translate(smp_x, translation=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
|
||||
elif tf=='ShearX' or tf=='ShearY' :
|
||||
smp_x = TF.shear(smp_x, shear=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
|
||||
|
||||
## Color TF (Expect image in the range of [0, 1]) ##
|
||||
elif tf=='Contrast':
|
||||
smp_x = TF.contrast(smp_x, contrast_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
|
||||
elif tf=='Color':
|
||||
smp_x = TF.color(smp_x, color_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
|
||||
elif tf=='Brightness':
|
||||
smp_x = TF.brightness(smp_x, brightness_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
|
||||
elif tf=='Sharpness':
|
||||
smp_x = TF.sharpeness(smp_x, sharpness_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
|
||||
elif tf=='Posterize':
|
||||
smp_x = TF.posterize(smp_x, bits=torch.tensor([1 for _ in smp_x], device=device))
|
||||
elif tf=='Solarize':
|
||||
smp_x = TF.solarize(smp_x, thresholds=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
|
||||
elif tf=='Equalize':
|
||||
smp_x = TF.equalize(smp_x)
|
||||
elif tf=='Auto_Contrast':
|
||||
smp_x = TF.auto_contrast(smp_x)
|
||||
else:
|
||||
raise Exception("Invalid TF requested : ", tf)
|
||||
|
||||
x[mask]=smp_x # Refusionner eviter x[mask] : in place
|
||||
return x
|
||||
|
||||
def adjust_prob(self, soft=False): #Detach from gradient ?
|
||||
|
||||
if soft :
|
||||
self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Trop 'soft', bloque en dist uniforme si lr trop faible
|
||||
else:
|
||||
self._params['prob'].data = F.relu(self._params['prob'].data)
|
||||
self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1
|
||||
|
||||
def loss_weight(self):
|
||||
# 1 seule TF
|
||||
#self._sample = self._samples[-1]
|
||||
#w_loss = torch.zeros((self._sample.shape[0],self._nb_tf), device=self._sample.device)
|
||||
#w_loss.scatter_(dim=1, index=self._sample.view(-1,1), value=1)
|
||||
#w_loss = w_loss * self._params["prob"]/self._distrib #Ponderation par les proba (divisee par la distrib pour pas diminuer la loss)
|
||||
#w_loss = torch.sum(w_loss,dim=1)
|
||||
|
||||
#Plusieurs TF sequentielles (Hypothese ordre negligeable)
|
||||
w_loss = torch.zeros((self._samples[0].shape[0],self._nb_tf), device=self._samples[0].device)
|
||||
for sample in self._samples:
|
||||
tmp_w = torch.zeros(w_loss.size(),device=w_loss.device)
|
||||
tmp_w.scatter_(dim=1, index=sample.view(-1,1), value=1/self._N_TF)
|
||||
w_loss += tmp_w
|
||||
|
||||
w_loss = w_loss * self._params["prob"]/self._distrib #Ponderation par les proba (divisee par la distrib pour pas diminuer la loss)
|
||||
w_loss = torch.sum(w_loss,dim=1)
|
||||
return w_loss
|
||||
|
||||
|
||||
def train(self, mode=None):
|
||||
if mode is None :
|
||||
mode=self._data_augmentation
|
||||
self.augment(mode=mode) #Inutile si mode=None
|
||||
super(Data_augV5, self).train(mode)
|
||||
|
||||
def eval(self):
|
||||
self.train(mode=False)
|
||||
|
||||
def augment(self, mode=True):
|
||||
self._data_augmentation=mode
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._params[key]
|
||||
|
||||
def __str__(self):
|
||||
if not self._mix_dist:
|
||||
return "Data_augV5(Uniform-%d TF x %d)" % (self._nb_tf, self._N_TF)
|
||||
else:
|
||||
return "Data_augV5(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_TF)
|
||||
|
||||
|
||||
class Augmented_model(nn.Module):
|
||||
def __init__(self, data_augmenter, model):
|
||||
super(Augmented_model, self).__init__()
|
||||
|
|
Binary file not shown.
Before Width: | Height: | Size: 55 KiB |
Binary file not shown.
Before Width: | Height: | Size: 45 KiB |
Binary file not shown.
Before Width: | Height: | Size: 36 KiB |
|
@ -64,18 +64,18 @@ if __name__ == "__main__":
|
|||
print('-'*9)
|
||||
'''
|
||||
#### Augmented Model ####
|
||||
'''
|
||||
#'''
|
||||
t0 = time.process_time()
|
||||
tf_dict = {k: TF.TF_dict[k] for k in tf_names}
|
||||
#tf_dict = TF.TF_dict
|
||||
#aug_model = Augmented_model(Data_augV4(TF_dict=tf_dict, N_TF=2, mix_dist=0.0), LeNet(3,10)).to(device)
|
||||
aug_model = Augmented_model(Data_augV4(TF_dict=tf_dict, N_TF=2, mix_dist=0.0), WideResNet(num_classes=10, wrn_size=160)).to(device)
|
||||
aug_model = Augmented_model(Data_augV4(TF_dict=tf_dict, N_TF=2, mix_dist=0.5), LeNet(3,10)).to(device)
|
||||
#aug_model = Augmented_model(Data_augV4(TF_dict=tf_dict, N_TF=2, mix_dist=0.0), WideResNet(num_classes=10, wrn_size=160)).to(device)
|
||||
print(str(aug_model), 'on', device_name)
|
||||
#run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
|
||||
log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=1, loss_patience=10)
|
||||
|
||||
####
|
||||
plot_res(log, fig_name="res/{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter), param_names=tf_names)
|
||||
plot_resV2(log, fig_name="res/{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter), param_names=tf_names)
|
||||
print('-'*9)
|
||||
times = [x["time"] for x in log]
|
||||
out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times)), "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
|
||||
|
@ -86,12 +86,13 @@ if __name__ == "__main__":
|
|||
|
||||
print('Execution Time : %.00f (s?)'%(time.process_time() - t0))
|
||||
print('-'*9)
|
||||
'''
|
||||
#'''
|
||||
#### TF number tests ####
|
||||
#'''
|
||||
res_folder="res/TF_nb_tests/"
|
||||
epochs= 100
|
||||
inner_its = [0, 1, 10]
|
||||
dist_mix = [0.0, 0.5]
|
||||
dataug_epoch_starts= [0]
|
||||
TF_nb = [len(TF.TF_dict)] #range(10,len(TF.TF_dict)+1) #[len(TF.TF_dict)]
|
||||
N_seq_TF= [2, 3, 4, 6]
|
||||
|
|
|
@ -586,8 +586,8 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
|
|||
|
||||
logits = fmodel(xs) # modified `params` can also be passed as a kwarg
|
||||
loss = F.cross_entropy(logits, ys, reduction='none') # no need to call loss.backwards()
|
||||
#PAS PONDERE LOSS POUR DIST MIX
|
||||
if fmodel._data_augmentation: # and not fmodel['data_aug']._mix_dist: #Weight loss
|
||||
|
||||
if fmodel._data_augmentation: #Weight loss
|
||||
w_loss = fmodel['data_aug'].loss_weight().to(device)
|
||||
loss = loss * w_loss
|
||||
loss = loss.mean()
|
||||
|
|
|
@ -3,30 +3,54 @@ import kornia
|
|||
import random
|
||||
|
||||
### Available TF for Dataug ###
|
||||
'''
|
||||
TF_dict={ #f(mag_normalise)=mag_reelle
|
||||
## Geometric TF ##
|
||||
'Identity' : (lambda mag: None),
|
||||
'FlipUD' : (lambda mag: None),
|
||||
'FlipLR' : (lambda mag: None),
|
||||
'Rotate': (lambda mag: random.randint(-int_parameter(mag, maxval=30), int_parameter(mag, maxval=30))),
|
||||
'TranslateX': (lambda mag: [random.randint(-int_parameter(mag, maxval=20), int_parameter(mag, maxval=20)), 0]),
|
||||
'TranslateY': (lambda mag: [0, random.randint(-int_parameter(mag, maxval=20), int_parameter(mag, maxval=20))]),
|
||||
'ShearX': (lambda mag: [random.uniform(-float_parameter(mag, maxval=0.3), float_parameter(mag, maxval=0.3)), 0]),
|
||||
'ShearY': (lambda mag: [0, random.uniform(-float_parameter(mag, maxval=0.3), float_parameter(mag, maxval=0.3))]),
|
||||
'Rotate': (lambda mag: rand_int(mag,maxval=30)),
|
||||
'TranslateX': (lambda mag: [rand_int(mag,maxval=20), 0]),
|
||||
'TranslateY': (lambda mag: [0, rand_int(mag,maxval=20)]),
|
||||
'ShearX': (lambda mag: [rand_float(mag, maxval=0.3), 0]),
|
||||
'ShearY': (lambda mag: [0, rand_float(mag, maxval=0.3)]),
|
||||
|
||||
## Color TF (Expect image in the range of [0, 1]) ##
|
||||
'Contrast': (lambda mag: random.uniform(0.1, float_parameter(mag, maxval=1.9))),
|
||||
'Color':(lambda mag: random.uniform(0.1, float_parameter(mag, maxval=1.9))),
|
||||
'Brightness':(lambda mag: random.uniform(1., float_parameter(mag, maxval=1.9))),
|
||||
'Sharpness':(lambda mag: random.uniform(0.1, float_parameter(mag, maxval=1.9))),
|
||||
'Posterize': (lambda mag: random.randint(4, int_parameter(mag, maxval=8))),
|
||||
'Solarize': (lambda mag: random.randint(1, int_parameter(mag, maxval=256))/256.), #=>Image entre [0,1] #Pas opti pour des batch
|
||||
'Contrast': (lambda mag: rand_float(mag,minval=0.1, maxval=1.9)),
|
||||
'Color':(lambda mag: rand_float(mag,minval=0.1, maxval=1.9)),
|
||||
'Brightness':(lambda mag: rand_float(mag,minval=1., maxval=1.9)),
|
||||
'Sharpness':(lambda mag: rand_float(mag,minval=0.1, maxval=1.9)),
|
||||
'Posterize': (lambda mag: rand_int(mag,minval=4, maxval=8)),
|
||||
'Solarize': (lambda mag: rand_int(mag,minval=1, maxval=256)/256.), #=>Image entre [0,1] #Pas opti pour des batch
|
||||
|
||||
#Non fonctionnel
|
||||
#'Auto_Contrast': (lambda mag: None), #Pas opti pour des batch (Super lent)
|
||||
#'Equalize': (lambda mag: None),
|
||||
}
|
||||
'''
|
||||
TF_dict={
|
||||
## Geometric TF ##
|
||||
'Identity' : (lambda x, mag: x),
|
||||
'FlipUD' : (lambda x, mag: flipUD(x)),
|
||||
'FlipLR' : (lambda x, mag: flipLR(x)),
|
||||
'Rotate': (lambda x, mag: rotate(x, angle=torch.tensor([rand_int(mag, maxval=30)for _ in x], device=x.device))),
|
||||
'TranslateX': (lambda x, mag: translate(x, translation=torch.tensor([[rand_int(mag, maxval=20), 0] for _ in x], device=x.device))),
|
||||
'TranslateY': (lambda x, mag: translate(x, translation=torch.tensor([[0, rand_int(mag, maxval=20)] for _ in x], device=x.device))),
|
||||
'ShearX': (lambda x, mag: shear(x, shear=torch.tensor([[rand_float(mag, maxval=0.3), 0] for _ in x], device=x.device))),
|
||||
'ShearY': (lambda x, mag: shear(x, shear=torch.tensor([[0, rand_float(mag, maxval=0.3)] for _ in x], device=x.device))),
|
||||
|
||||
## Color TF (Expect image in the range of [0, 1]) ##
|
||||
'Contrast': (lambda x, mag: contrast(x, contrast_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))),
|
||||
'Color':(lambda x, mag: color(x, color_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))),
|
||||
'Brightness':(lambda x, mag: brightness(x, brightness_factor=torch.tensor([rand_float(mag, minval=1., maxval=1.9) for _ in x], device=x.device))),
|
||||
'Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))),
|
||||
'Posterize': (lambda x, mag: posterize(x, bits=torch.tensor([rand_int(mag, minval=4, maxval=8) for _ in x], device=x.device))),
|
||||
'Solarize': (lambda x, mag: solarize(x, thresholds=torch.tensor([rand_int(mag,minval=1, maxval=256)/256. for _ in x], device=x.device))) , #=>Image entre [0,1] #Pas opti pour des batch
|
||||
|
||||
#Non fonctionnel
|
||||
#'Auto_Contrast': (lambda mag: None), #Pas opti pour des batch (Super lent)
|
||||
#'Equalize': (lambda mag: None),
|
||||
}
|
||||
|
||||
def int_image(float_image): #ATTENTION : legere perte d'info (granularite : 1/256 = 0.0039)
|
||||
return (float_image*255.).type(torch.uint8)
|
||||
|
@ -34,8 +58,19 @@ def int_image(float_image): #ATTENTION : legere perte d'info (granularite : 1/25
|
|||
def float_image(int_image):
|
||||
return int_image.type(torch.float)/255.
|
||||
|
||||
def rand_inverse(value):
|
||||
return value if random.random() < 0.5 else -value
|
||||
#def rand_inverse(value):
|
||||
# return value if random.random() < 0.5 else -value
|
||||
|
||||
def rand_int(mag, maxval, minval=None): #[(-maxval,minval), maxval]
|
||||
real_max = int_parameter(mag, maxval=maxval)
|
||||
if not minval : minval = -real_max
|
||||
return random.randint(minval, real_max)
|
||||
|
||||
def rand_float(mag, maxval, minval=None): #[(-maxval,minval), maxval]
|
||||
real_max = float_parameter(mag, maxval=maxval)
|
||||
if not minval : minval = -real_max
|
||||
return random.uniform(minval, real_max)
|
||||
|
||||
|
||||
#https://github.com/tensorflow/models/blob/fc2056bce6ab17eabdc139061fef8f4f2ee763ec/research/autoaugment/augmentation_transforms.py#L137
|
||||
PARAMETER_MAX = 10 # What is the max 'level' a transform could be predicted
|
||||
|
|
|
@ -48,6 +48,38 @@ def plot_res(log, fig_name='res', param_names=None):
|
|||
plt.savefig(fig_name)
|
||||
plt.close()
|
||||
|
||||
def plot_resV2(log, fig_name='res', param_names=None):
|
||||
|
||||
epochs = [x["epoch"] for x in log]
|
||||
|
||||
fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(15, 15))
|
||||
|
||||
ax[0, 0].set_title('Loss')
|
||||
ax[0, 0].plot(epochs,[x["train_loss"] for x in log], label='Train')
|
||||
ax[0, 0].plot(epochs,[x["val_loss"] for x in log], label='Val')
|
||||
ax[0, 0].legend()
|
||||
|
||||
ax[0, 1].set_title('Acc')
|
||||
ax[0, 1].plot(epochs,[x["acc"] for x in log])
|
||||
|
||||
if log[0]["param"]!= None:
|
||||
ax[1, 1].set_title('Prob')
|
||||
if not param_names : param_names = ['P'+str(idx) for idx, _ in enumerate(log[0]["param"])]
|
||||
proba=[[x["param"][idx] for x in log] for idx, _ in enumerate(log[0]["param"])]
|
||||
ax[1, 1].stackplot(epochs, proba, labels=param_names)
|
||||
ax[1, 1].legend(param_names, loc='center left', bbox_to_anchor=(1, 0.5))
|
||||
|
||||
ax[1, 0].set_title('Mean prob')
|
||||
mean = np.mean([x["param"] for x in log], axis=0)
|
||||
std = np.std([x["param"] for x in log], axis=0)
|
||||
ax[1, 0].bar(param_names, mean, yerr=std)
|
||||
plt.sca(ax[1, 0]), plt.xticks(rotation=90)
|
||||
|
||||
|
||||
fig_name = fig_name.replace('.',',')
|
||||
plt.savefig(fig_name, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
def plot_compare(filenames, fig_name='res'):
|
||||
|
||||
all_data=[]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue