Refactoring de TF_dict

This commit is contained in:
Harle, Antoine (Contracteur) 2019-11-14 21:17:54 -05:00
parent fd4dcdb392
commit 103277fadd
8 changed files with 245 additions and 23 deletions

View file

@ -322,8 +322,9 @@ class Data_augV4(nn.Module): #Transformations avec mask
#self._TF_matrix={} #self._TF_matrix={}
#self._input_info={'h':0, 'w':0, 'device':None} #Input associe a TF_matrix #self._input_info={'h':0, 'w':0, 'device':None} #Input associe a TF_matrix
self._mag_fct = TF_dict #self._mag_fct = TF_dict
self._TF=list(self._mag_fct.keys()) self._TF_dict = TF_dict
self._TF= list(self._TF_dict.keys())
self._nb_tf= len(self._TF) self._nb_tf= len(self._TF)
self._N_TF = N_TF self._N_TF = N_TF
@ -356,7 +357,6 @@ class Data_augV4(nn.Module): #Transformations avec mask
self._distrib = uniforme_dist self._distrib = uniforme_dist
else: else:
self._distrib = (self._mix_factor*self._params["prob"]+(1-self._mix_factor)*uniforme_dist).softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor self._distrib = (self._mix_factor*self._params["prob"]+(1-self._mix_factor)*uniforme_dist).softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor
print(self.distrib.shape)
cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*self._distrib) cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*self._distrib)
sample = cat_distrib.sample() sample = cat_distrib.sample()
@ -414,6 +414,7 @@ class Data_augV4(nn.Module): #Transformations avec mask
magnitude=self._fixed_mag magnitude=self._fixed_mag
tf=self._TF[tf_idx] tf=self._TF[tf_idx]
'''
## Geometric TF ## ## Geometric TF ##
if tf=='Identity': if tf=='Identity':
pass pass
@ -449,6 +450,8 @@ class Data_augV4(nn.Module): #Transformations avec mask
raise Exception("Invalid TF requested : ", tf) raise Exception("Invalid TF requested : ", tf)
x[mask]=smp_x # Refusionner eviter x[mask] : in place x[mask]=smp_x # Refusionner eviter x[mask] : in place
'''
x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude) # Refusionner eviter x[mask] : in place
#idx= mask.nonzero() #idx= mask.nonzero()
#print('-'*8) #print('-'*8)
@ -527,6 +530,157 @@ class Data_augV4(nn.Module): #Transformations avec mask
else: else:
return "Data_augV4(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_TF) return "Data_augV4(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_TF)
class Data_augV5(nn.Module): #Transformations avec mask
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0):
super(Data_augV5, self).__init__()
assert len(TF_dict)>0
self._data_augmentation = True
#self._TF_matrix={}
#self._input_info={'h':0, 'w':0, 'device':None} #Input associe a TF_matrix
self._mag_fct = TF_dict
self._TF=list(self._mag_fct.keys())
self._nb_tf= len(self._TF)
self._N_TF = N_TF
self._fixed_mag=5 #[0, PARAMETER_MAX]
self._params = nn.ParameterDict({
"prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Distribution prob uniforme
})
self._samples = []
self._mix_dist = False
if mix_dist != 0.0:
self._mix_dist = True
self._mix_factor = max(min(mix_dist, 1.0), 0.0)
def forward(self, x):
if self._data_augmentation:
device = x.device
batch_size, h, w = x.shape[0], x.shape[2], x.shape[3]
x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles)
self._samples = []
for _ in range(self._N_TF):
## Echantillonage ##
uniforme_dist = torch.ones(1,self._nb_tf,device=device).softmax(dim=1)
if not self._mix_dist:
self._distrib = uniforme_dist
else:
self._distrib = (self._mix_factor*self._params["prob"]+(1-self._mix_factor)*uniforme_dist).softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor
cat_distrib= Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*self._distrib)
sample = cat_distrib.sample()
self._samples.append(sample)
## Transformations ##
x = self.apply_TF(x, sample)
return x
def apply_TF(self, x, sampled_TF):
device = x.device
smps_x=[]
masks=[]
for tf_idx in range(self._nb_tf):
mask = sampled_TF==tf_idx #Create selection mask
smp_x = x[mask] #torch.masked_select() ?
if smp_x.shape[0]!=0: #if there's data to TF
magnitude=self._fixed_mag
tf=self._TF[tf_idx]
## Geometric TF ##
if tf=='Identity':
pass
elif tf=='FlipLR':
smp_x = TF.flipLR(smp_x)
elif tf=='FlipUD':
smp_x = TF.flipUD(smp_x)
elif tf=='Rotate':
smp_x = TF.rotate(smp_x, angle=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
elif tf=='TranslateX' or tf=='TranslateY':
smp_x = TF.translate(smp_x, translation=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
elif tf=='ShearX' or tf=='ShearY' :
smp_x = TF.shear(smp_x, shear=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
## Color TF (Expect image in the range of [0, 1]) ##
elif tf=='Contrast':
smp_x = TF.contrast(smp_x, contrast_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
elif tf=='Color':
smp_x = TF.color(smp_x, color_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
elif tf=='Brightness':
smp_x = TF.brightness(smp_x, brightness_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
elif tf=='Sharpness':
smp_x = TF.sharpeness(smp_x, sharpness_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
elif tf=='Posterize':
smp_x = TF.posterize(smp_x, bits=torch.tensor([1 for _ in smp_x], device=device))
elif tf=='Solarize':
smp_x = TF.solarize(smp_x, thresholds=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
elif tf=='Equalize':
smp_x = TF.equalize(smp_x)
elif tf=='Auto_Contrast':
smp_x = TF.auto_contrast(smp_x)
else:
raise Exception("Invalid TF requested : ", tf)
x[mask]=smp_x # Refusionner eviter x[mask] : in place
return x
def adjust_prob(self, soft=False): #Detach from gradient ?
if soft :
self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Trop 'soft', bloque en dist uniforme si lr trop faible
else:
self._params['prob'].data = F.relu(self._params['prob'].data)
self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1
def loss_weight(self):
# 1 seule TF
#self._sample = self._samples[-1]
#w_loss = torch.zeros((self._sample.shape[0],self._nb_tf), device=self._sample.device)
#w_loss.scatter_(dim=1, index=self._sample.view(-1,1), value=1)
#w_loss = w_loss * self._params["prob"]/self._distrib #Ponderation par les proba (divisee par la distrib pour pas diminuer la loss)
#w_loss = torch.sum(w_loss,dim=1)
#Plusieurs TF sequentielles (Hypothese ordre negligeable)
w_loss = torch.zeros((self._samples[0].shape[0],self._nb_tf), device=self._samples[0].device)
for sample in self._samples:
tmp_w = torch.zeros(w_loss.size(),device=w_loss.device)
tmp_w.scatter_(dim=1, index=sample.view(-1,1), value=1/self._N_TF)
w_loss += tmp_w
w_loss = w_loss * self._params["prob"]/self._distrib #Ponderation par les proba (divisee par la distrib pour pas diminuer la loss)
w_loss = torch.sum(w_loss,dim=1)
return w_loss
def train(self, mode=None):
if mode is None :
mode=self._data_augmentation
self.augment(mode=mode) #Inutile si mode=None
super(Data_augV5, self).train(mode)
def eval(self):
self.train(mode=False)
def augment(self, mode=True):
self._data_augmentation=mode
def __getitem__(self, key):
return self._params[key]
def __str__(self):
if not self._mix_dist:
return "Data_augV5(Uniform-%d TF x %d)" % (self._nb_tf, self._N_TF)
else:
return "Data_augV5(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_TF)
class Augmented_model(nn.Module): class Augmented_model(nn.Module):
def __init__(self, data_augmenter, model): def __init__(self, data_augmenter, model):
super(Augmented_model, self).__init__() super(Augmented_model, self).__init__()

Binary file not shown.

Before

Width:  |  Height:  |  Size: 55 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 45 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 36 KiB

View file

@ -64,18 +64,18 @@ if __name__ == "__main__":
print('-'*9) print('-'*9)
''' '''
#### Augmented Model #### #### Augmented Model ####
''' #'''
t0 = time.process_time() t0 = time.process_time()
tf_dict = {k: TF.TF_dict[k] for k in tf_names} tf_dict = {k: TF.TF_dict[k] for k in tf_names}
#tf_dict = TF.TF_dict #tf_dict = TF.TF_dict
#aug_model = Augmented_model(Data_augV4(TF_dict=tf_dict, N_TF=2, mix_dist=0.0), LeNet(3,10)).to(device) aug_model = Augmented_model(Data_augV4(TF_dict=tf_dict, N_TF=2, mix_dist=0.5), LeNet(3,10)).to(device)
aug_model = Augmented_model(Data_augV4(TF_dict=tf_dict, N_TF=2, mix_dist=0.0), WideResNet(num_classes=10, wrn_size=160)).to(device) #aug_model = Augmented_model(Data_augV4(TF_dict=tf_dict, N_TF=2, mix_dist=0.0), WideResNet(num_classes=10, wrn_size=160)).to(device)
print(str(aug_model), 'on', device_name) print(str(aug_model), 'on', device_name)
#run_simple_dataug(inner_it=n_inner_iter, epochs=epochs) #run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=1, loss_patience=10) log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=1, loss_patience=10)
#### ####
plot_res(log, fig_name="res/{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter), param_names=tf_names) plot_resV2(log, fig_name="res/{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter), param_names=tf_names)
print('-'*9) print('-'*9)
times = [x["time"] for x in log] times = [x["time"] for x in log]
out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times)), "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log} out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times)), "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
@ -86,12 +86,13 @@ if __name__ == "__main__":
print('Execution Time : %.00f (s?)'%(time.process_time() - t0)) print('Execution Time : %.00f (s?)'%(time.process_time() - t0))
print('-'*9) print('-'*9)
''' #'''
#### TF number tests #### #### TF number tests ####
#''' #'''
res_folder="res/TF_nb_tests/" res_folder="res/TF_nb_tests/"
epochs= 100 epochs= 100
inner_its = [0, 1, 10] inner_its = [0, 1, 10]
dist_mix = [0.0, 0.5]
dataug_epoch_starts= [0] dataug_epoch_starts= [0]
TF_nb = [len(TF.TF_dict)] #range(10,len(TF.TF_dict)+1) #[len(TF.TF_dict)] TF_nb = [len(TF.TF_dict)] #range(10,len(TF.TF_dict)+1) #[len(TF.TF_dict)]
N_seq_TF= [2, 3, 4, 6] N_seq_TF= [2, 3, 4, 6]

View file

@ -586,8 +586,8 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
logits = fmodel(xs) # modified `params` can also be passed as a kwarg logits = fmodel(xs) # modified `params` can also be passed as a kwarg
loss = F.cross_entropy(logits, ys, reduction='none') # no need to call loss.backwards() loss = F.cross_entropy(logits, ys, reduction='none') # no need to call loss.backwards()
#PAS PONDERE LOSS POUR DIST MIX
if fmodel._data_augmentation: # and not fmodel['data_aug']._mix_dist: #Weight loss if fmodel._data_augmentation: #Weight loss
w_loss = fmodel['data_aug'].loss_weight().to(device) w_loss = fmodel['data_aug'].loss_weight().to(device)
loss = loss * w_loss loss = loss * w_loss
loss = loss.mean() loss = loss.mean()

View file

@ -3,30 +3,54 @@ import kornia
import random import random
### Available TF for Dataug ### ### Available TF for Dataug ###
'''
TF_dict={ #f(mag_normalise)=mag_reelle TF_dict={ #f(mag_normalise)=mag_reelle
## Geometric TF ## ## Geometric TF ##
'Identity' : (lambda mag: None), 'Identity' : (lambda mag: None),
'FlipUD' : (lambda mag: None), 'FlipUD' : (lambda mag: None),
'FlipLR' : (lambda mag: None), 'FlipLR' : (lambda mag: None),
'Rotate': (lambda mag: random.randint(-int_parameter(mag, maxval=30), int_parameter(mag, maxval=30))), 'Rotate': (lambda mag: rand_int(mag,maxval=30)),
'TranslateX': (lambda mag: [random.randint(-int_parameter(mag, maxval=20), int_parameter(mag, maxval=20)), 0]), 'TranslateX': (lambda mag: [rand_int(mag,maxval=20), 0]),
'TranslateY': (lambda mag: [0, random.randint(-int_parameter(mag, maxval=20), int_parameter(mag, maxval=20))]), 'TranslateY': (lambda mag: [0, rand_int(mag,maxval=20)]),
'ShearX': (lambda mag: [random.uniform(-float_parameter(mag, maxval=0.3), float_parameter(mag, maxval=0.3)), 0]), 'ShearX': (lambda mag: [rand_float(mag, maxval=0.3), 0]),
'ShearY': (lambda mag: [0, random.uniform(-float_parameter(mag, maxval=0.3), float_parameter(mag, maxval=0.3))]), 'ShearY': (lambda mag: [0, rand_float(mag, maxval=0.3)]),
## Color TF (Expect image in the range of [0, 1]) ## ## Color TF (Expect image in the range of [0, 1]) ##
'Contrast': (lambda mag: random.uniform(0.1, float_parameter(mag, maxval=1.9))), 'Contrast': (lambda mag: rand_float(mag,minval=0.1, maxval=1.9)),
'Color':(lambda mag: random.uniform(0.1, float_parameter(mag, maxval=1.9))), 'Color':(lambda mag: rand_float(mag,minval=0.1, maxval=1.9)),
'Brightness':(lambda mag: random.uniform(1., float_parameter(mag, maxval=1.9))), 'Brightness':(lambda mag: rand_float(mag,minval=1., maxval=1.9)),
'Sharpness':(lambda mag: random.uniform(0.1, float_parameter(mag, maxval=1.9))), 'Sharpness':(lambda mag: rand_float(mag,minval=0.1, maxval=1.9)),
'Posterize': (lambda mag: random.randint(4, int_parameter(mag, maxval=8))), 'Posterize': (lambda mag: rand_int(mag,minval=4, maxval=8)),
'Solarize': (lambda mag: random.randint(1, int_parameter(mag, maxval=256))/256.), #=>Image entre [0,1] #Pas opti pour des batch 'Solarize': (lambda mag: rand_int(mag,minval=1, maxval=256)/256.), #=>Image entre [0,1] #Pas opti pour des batch
#Non fonctionnel #Non fonctionnel
#'Auto_Contrast': (lambda mag: None), #Pas opti pour des batch (Super lent) #'Auto_Contrast': (lambda mag: None), #Pas opti pour des batch (Super lent)
#'Equalize': (lambda mag: None), #'Equalize': (lambda mag: None),
} }
'''
TF_dict={
## Geometric TF ##
'Identity' : (lambda x, mag: x),
'FlipUD' : (lambda x, mag: flipUD(x)),
'FlipLR' : (lambda x, mag: flipLR(x)),
'Rotate': (lambda x, mag: rotate(x, angle=torch.tensor([rand_int(mag, maxval=30)for _ in x], device=x.device))),
'TranslateX': (lambda x, mag: translate(x, translation=torch.tensor([[rand_int(mag, maxval=20), 0] for _ in x], device=x.device))),
'TranslateY': (lambda x, mag: translate(x, translation=torch.tensor([[0, rand_int(mag, maxval=20)] for _ in x], device=x.device))),
'ShearX': (lambda x, mag: shear(x, shear=torch.tensor([[rand_float(mag, maxval=0.3), 0] for _ in x], device=x.device))),
'ShearY': (lambda x, mag: shear(x, shear=torch.tensor([[0, rand_float(mag, maxval=0.3)] for _ in x], device=x.device))),
## Color TF (Expect image in the range of [0, 1]) ##
'Contrast': (lambda x, mag: contrast(x, contrast_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))),
'Color':(lambda x, mag: color(x, color_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))),
'Brightness':(lambda x, mag: brightness(x, brightness_factor=torch.tensor([rand_float(mag, minval=1., maxval=1.9) for _ in x], device=x.device))),
'Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=torch.tensor([rand_float(mag, minval=0.1, maxval=1.9) for _ in x], device=x.device))),
'Posterize': (lambda x, mag: posterize(x, bits=torch.tensor([rand_int(mag, minval=4, maxval=8) for _ in x], device=x.device))),
'Solarize': (lambda x, mag: solarize(x, thresholds=torch.tensor([rand_int(mag,minval=1, maxval=256)/256. for _ in x], device=x.device))) , #=>Image entre [0,1] #Pas opti pour des batch
#Non fonctionnel
#'Auto_Contrast': (lambda mag: None), #Pas opti pour des batch (Super lent)
#'Equalize': (lambda mag: None),
}
def int_image(float_image): #ATTENTION : legere perte d'info (granularite : 1/256 = 0.0039) def int_image(float_image): #ATTENTION : legere perte d'info (granularite : 1/256 = 0.0039)
return (float_image*255.).type(torch.uint8) return (float_image*255.).type(torch.uint8)
@ -34,8 +58,19 @@ def int_image(float_image): #ATTENTION : legere perte d'info (granularite : 1/25
def float_image(int_image): def float_image(int_image):
return int_image.type(torch.float)/255. return int_image.type(torch.float)/255.
def rand_inverse(value): #def rand_inverse(value):
return value if random.random() < 0.5 else -value # return value if random.random() < 0.5 else -value
def rand_int(mag, maxval, minval=None): #[(-maxval,minval), maxval]
real_max = int_parameter(mag, maxval=maxval)
if not minval : minval = -real_max
return random.randint(minval, real_max)
def rand_float(mag, maxval, minval=None): #[(-maxval,minval), maxval]
real_max = float_parameter(mag, maxval=maxval)
if not minval : minval = -real_max
return random.uniform(minval, real_max)
#https://github.com/tensorflow/models/blob/fc2056bce6ab17eabdc139061fef8f4f2ee763ec/research/autoaugment/augmentation_transforms.py#L137 #https://github.com/tensorflow/models/blob/fc2056bce6ab17eabdc139061fef8f4f2ee763ec/research/autoaugment/augmentation_transforms.py#L137
PARAMETER_MAX = 10 # What is the max 'level' a transform could be predicted PARAMETER_MAX = 10 # What is the max 'level' a transform could be predicted

View file

@ -48,6 +48,38 @@ def plot_res(log, fig_name='res', param_names=None):
plt.savefig(fig_name) plt.savefig(fig_name)
plt.close() plt.close()
def plot_resV2(log, fig_name='res', param_names=None):
epochs = [x["epoch"] for x in log]
fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(15, 15))
ax[0, 0].set_title('Loss')
ax[0, 0].plot(epochs,[x["train_loss"] for x in log], label='Train')
ax[0, 0].plot(epochs,[x["val_loss"] for x in log], label='Val')
ax[0, 0].legend()
ax[0, 1].set_title('Acc')
ax[0, 1].plot(epochs,[x["acc"] for x in log])
if log[0]["param"]!= None:
ax[1, 1].set_title('Prob')
if not param_names : param_names = ['P'+str(idx) for idx, _ in enumerate(log[0]["param"])]
proba=[[x["param"][idx] for x in log] for idx, _ in enumerate(log[0]["param"])]
ax[1, 1].stackplot(epochs, proba, labels=param_names)
ax[1, 1].legend(param_names, loc='center left', bbox_to_anchor=(1, 0.5))
ax[1, 0].set_title('Mean prob')
mean = np.mean([x["param"] for x in log], axis=0)
std = np.std([x["param"] for x in log], axis=0)
ax[1, 0].bar(param_names, mean, yerr=std)
plt.sca(ax[1, 0]), plt.xticks(rotation=90)
fig_name = fig_name.replace('.',',')
plt.savefig(fig_name, bbox_inches='tight')
plt.close()
def plot_compare(filenames, fig_name='res'): def plot_compare(filenames, fig_name='res'):
all_data=[] all_data=[]