Minor improvement (RandAug)

Harle, Antoine (Contractor) 2020-01-30 11:21:25 -05:00
parent 6bba069d8a
commit 561b71b30a
5 changed files with 50 additions and 179 deletions


@@ -187,11 +187,11 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)
         Ensure that the parameter values stay in the right intervals. This should be called after each update of those parameters.

         Args:
-            soft (bool): Whether to use a softmax function for TF probabilities. Not recommended, as it tends to lock the probabilities, preventing them from being learned. (default: False)
+            soft (bool): Whether to use a softmax function for TF probabilities. Tends to lock the probabilities if the learning rate is low, preventing them from being learned. (default: False)
         """
         if not self._fixed_prob:
             if soft :
-                self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Too 'soft': locks into a uniform distribution if the lr is too low
+                self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0)
             else:
                 self._params['prob'].data = self._params['prob'].data.clamp(min=1/(self._nb_tf*100),max=1.0)
                 self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Constraint: sum(p)=1
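The non-soft branch is a simple projection back onto the probability simplex: clamp each probability into [1/(nb_tf*100), 1], then renormalize so the values sum to 1. A standalone sketch of that projection (the function name is illustrative, not the repo's API):

    import torch

    def project_probs(prob, nb_tf):
        # Keep every TF probability strictly positive and at most 1,
        # so no transformation can be ruled out entirely.
        prob = prob.clamp(min=1/(nb_tf*100), max=1.0)
        # Renormalize to enforce the constraint sum(p) = 1.
        return prob / prob.sum()

    p = project_probs(torch.tensor([0.0, 2.0, 1.0]), nb_tf=3)
    print(p, p.sum())  # probabilities in (0, 1], summing to 1

Unlike the softmax branch, this projection keeps the distribution responsive at low learning rates, which matches the updated docstring's warning about softmax locking the probabilities.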
@@ -269,6 +269,14 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)
         """
         self._data_augmentation=mode
+
+    def is_augmenting(self):
+        """ Return whether data augmentation is applied.
+
+            Returns:
+                bool : True if data augmentation is applied.
+        """
+        return self._data_augmentation

     def __getitem__(self, key):
         """Access to the learnable parameters
             Args:
@@ -588,6 +596,14 @@ class Data_augV7(nn.Module): #Sequential probabilities
         """
         self._data_augmentation=mode
+
+    def is_augmenting(self):
+        """ Return whether data augmentation is applied.
+
+            Returns:
+                bool : True if data augmentation is applied.
+        """
+        return self._data_augmentation

     def __getitem__(self, key):
         """Access to the learnable parameters
             Args:
@@ -659,6 +675,8 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + fast
             })
         self._shared_mag = True
         self._fixed_mag = True
+        self._fixed_prob=True
+        self._fixed_mix=True

         self._params['mag'].data = self._params['mag'].data.clamp(min=TF.PARAMETER_MIN, max=TF.PARAMETER_MAX)
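For context, RandAugment picks transformations uniformly at random with a single fixed, shared magnitude (the "UniformFx-MagFxSh" in the class comment), so the two new flags simply mark the probability and mixing parameters as frozen alongside the already-fixed magnitude. A hypothetical sketch of how such flags typically gate a meta-update (the helper name and loop are illustrative, not the repo's code):

    # Hypothetical: only parameters whose '_fixed_*' flag is False get updated.
    def meta_step(aug, lr=1e-2):
        for name, fixed in [('prob', aug._fixed_prob), ('mag', aug._fixed_mag)]:
            p = aug._params[name]
            if not fixed and p.grad is not None:
                p.data -= lr * p.grad  # plain SGD step on the learnable parameter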
@@ -753,6 +771,14 @@ class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + fast
         """
         self._data_augmentation=mode
+
+    def is_augmenting(self):
+        """ Return whether data augmentation is applied.
+
+            Returns:
+                bool : True if data augmentation is applied.
+        """
+        return self._data_augmentation

     def __getitem__(self, key):
         """Access to the learnable parameters
             Args:
@@ -796,7 +822,7 @@ class Higher_model(nn.Module):
         """
         super(Higher_model, self).__init__()

-        self._name = model.__str__()
+        self._name = model.__class__.__name__ #model.__str__()
         self._mods = nn.ModuleDict({
             'original': model,
             'functional': higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
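The rename is worth a note: calling `__str__()` on an `nn.Module` returns the full multi-line architecture dump, while `__class__.__name__` gives just the class name, which is what you want for log and figure file names. A quick illustration:

    import torch.nn as nn

    model = nn.Linear(3, 10)
    print(str(model))                # Linear(in_features=3, out_features=10, bias=True)
    print(model.__class__.__name__)  # Linear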


@@ -1,163 +0,0 @@
-from model import *
-from dataug import *
-#from utils import *
-from train_utils import *
-
-tf_names = [
-    ## Geometric TF ##
-    'Identity',
-    'FlipUD',
-    'FlipLR',
-    'Rotate',
-    'TranslateX',
-    'TranslateY',
-    'ShearX',
-    'ShearY',
-
-    ## Color TF (Expect image in the range of [0, 1]) ##
-    'Contrast',
-    'Color',
-    'Brightness',
-    'Sharpness',
-    'Posterize',
-    'Solarize', #=>Image in [0,1] #Not optimized for batches
-]
-
-device = torch.device('cuda')
-
-if device == torch.device('cpu'):
-    device_name = 'CPU'
-else:
-    device_name = torch.cuda.get_device_name(device)
-
-##########################################
-if __name__ == "__main__":
-
-    n_inner_iter = 1
-    epochs = 150
-    dataug_epoch_start=0
-
-    optim_param={
-        'Meta':{
-            'optim':'Adam',
-            'lr':1e-2, #1e-2
-        },
-        'Inner':{
-            'optim': 'SGD',
-            'lr':1e-1, #1e-2
-            'momentum':0.9, #0.9
-        }
-    }
-
-    #model = LeNet(3,10)
-    #model = ResNet(num_classes=10)
-    #model = MobileNetV2(num_classes=10)
-    #model = WideResNet(num_classes=10, wrn_size=32)
-
-    tf_dict = {k: TF.TF_dict[k] for k in tf_names}
-
-    ####
-    '''
-    t0 = time.process_time()
-    aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)
-
-    print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter))
-    log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, KLdiv=True, loss_patience=None)
-
-    exec_time=time.process_time() - t0
-    ####
-    times = [x["time"] for x in log]
-    out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
-    filename = "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)
-    with open("res/log/%s.json" % filename, "w+") as f:
-        json.dump(out, f, indent=True)
-        print('Log :\"',f.name, '\" saved !')
-    '''
-
-    ####
-    '''
-    t0 = time.process_time()
-    aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), model).to(device)
-
-    print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter))
-    log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, KLdiv=True, loss_patience=None)
-
-    exec_time=time.process_time() - t0
-    ####
-    times = [x["time"] for x in log]
-    out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
-    filename = "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)
-    with open("res/log/%s.json" % filename, "w+") as f:
-        json.dump(out, f, indent=True)
-        print('Log :\"',f.name, '\" saved !')
-    '''
-
-    res_folder="../res/brutus-tests2/"
-    epochs= 150
-    inner_its = [1]
-    dist_mix = [0.0, 0.5, 0.8, 1.0]
-    dataug_epoch_starts= [0]
-    tf_dict = {k: TF.TF_dict[k] for k in tf_names}
-    TF_nb = [len(tf_dict)] #range(10,len(TF.TF_dict)+1) #[len(TF.TF_dict)]
-    N_seq_TF= [4, 3, 2]
-    mag_setup = [(True,True), (False, False)] #(Fixed, Shared)
-    #prob_setup = [True, False]
-    nb_run= 3
-
-    try:
-        os.mkdir(res_folder)
-        os.mkdir(res_folder+"log/")
-    except FileExistsError:
-        pass
-
-    for n_inner_iter in inner_its:
-        for dataug_epoch_start in dataug_epoch_starts:
-            for n_tf in N_seq_TF:
-                for dist in dist_mix:
-                    #for i in TF_nb:
-                    for m_setup in mag_setup:
-                        #for p_setup in prob_setup:
-                        p_setup=False
-                        for run in range(nb_run):
-                            if (n_inner_iter == 0 and (m_setup!=(True,True) and p_setup!=True)) or (p_setup and dist!=0.0): continue #Other setups are useless without meta-optimization
-
-                            #keys = list(TF.TF_dict.keys())[0:i]
-                            #ntf_dict = {k: TF.TF_dict[k] for k in keys}
-
-                            t0 = time.process_time()
-
-                            model = ResNet(num_classes=10)
-                            model = Higher_model(model) #run_dist_dataugV3
-                            aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=n_tf, mix_dist=dist, fixed_prob=p_setup, fixed_mag=m_setup[0], shared_mag=m_setup[1]), model).to(device)
-                            #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)
-
-                            print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter))
-                            log= run_dist_dataugV3(model=aug_model,
-                                epochs=epochs,
-                                inner_it=n_inner_iter,
-                                dataug_epoch_start=dataug_epoch_start,
-                                opt_param=optim_param,
-                                print_freq=50,
-                                KLdiv=True)
-
-                            exec_time=time.process_time() - t0
-                            ####
-                            print('-'*9)
-                            times = [x["time"] for x in log]
-                            out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), 'Optimizer': optim_param, "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
-                            print(str(aug_model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
-                            filename = "{}-{} epochs (dataug:{})- {} in_it-{}".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter, run)
-                            with open("../res/log/%s.json" % filename, "w+") as f:
-                                try:
-                                    json.dump(out, f, indent=True)
-                                    print('Log :\"',f.name, '\" saved !')
-                                except:
-                                    print("Failed to save logs :",f.name)
-
-                            try:
-                                plot_resV2(log, fig_name="../res/"+filename, param_names=aug_model.TF_names())
-                            except:
-                                print("Failed to plot res")
-
-                            print('Execution Time : %.00f '%(exec_time))
-                            print('-'*9)
-    #'''


@@ -53,10 +53,6 @@ tf_names = [
     #'Random',
     #'RandBlend'
-
-    #Not functional
-    #'Auto_Contrast', #Not optimized for batches (very slow)
-    #'Equalize',
 ]
@@ -67,6 +63,12 @@ if device == torch.device('cpu'):
 else:
     device_name = torch.cuda.get_device_name(device)

+torch.backends.cudnn.benchmark = True #Faster if the input size stays constant #Not recommended for reproducibility
+
+#Increase reproducibility
+torch.manual_seed(0)
+np.random.seed(0)
+
 ##########################################
 if __name__ == "__main__":
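A note on the seeding added above: `cudnn.benchmark = True` trades reproducibility for speed, and the two seeds only cover the torch and numpy CPU generators. Stricter determinism in PyTorch usually also involves the following standard calls (shown for reference, not part of this commit):

    import random
    import numpy as np
    import torch

    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)              # seed every GPU generator
    np.random.seed(0)
    random.seed(0)
    torch.backends.cudnn.deterministic = True  # force deterministic conv kernels
    torch.backends.cudnn.benchmark = False     # disable non-deterministic autotuning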
@@ -78,7 +80,7 @@ if __name__ == "__main__":
     }

     #Parameters
     n_inner_iter = 1
-    epochs = 1
+    epochs = 150
     dataug_epoch_start=0

     optim_param={
         'Meta':{
@@ -95,9 +97,8 @@ if __name__ == "__main__":
     #Models
     model = LeNet(3,10)
     #model = ResNet(num_classes=10)
-    #Slow
-    #model = MobileNetV2(num_classes=10)
-    #model = WideResNet(num_classes=10, wrn_size=32)
+    #import torchvision.models as models
+    #model=models.resnet18()

     #### Classic ####
     if 'classic' in tasks:
@@ -105,7 +106,7 @@ if __name__ == "__main__":
         model = model.to(device)

         print("{} on {} for {} epochs".format(str(model), device_name, epochs))
-        log= train_classic(model=model, opt_param=optim_param, epochs=epochs, print_freq=1)
+        log= train_classic(model=model, opt_param=optim_param, epochs=epochs, print_freq=20)
         #log= train_classic_higher(model=model, epochs=epochs)

         exec_time=time.process_time() - t0
@@ -130,11 +131,10 @@ if __name__ == "__main__":
         tf_dict = {k: TF.TF_dict[k] for k in tf_names}
         model = Higher_model(model) #run_dist_dataugV3
-        aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.8, fixed_prob=False, fixed_mag=False, shared_mag=False), model).to(device)
+        aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=2, mix_dist=0.8, fixed_prob=False, fixed_mag=False, shared_mag=False), model).to(device)
         #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)

         print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter))
-        log= run_simple_smartaug(model=aug_model, epochs=epochs, inner_it=n_inner_iter, opt_param=optim_param)
         log= run_dist_dataugV3(model=aug_model,
             epochs=epochs,
             inner_it=n_inner_iter,
@@ -142,7 +142,8 @@ if __name__ == "__main__":
             opt_param=optim_param,
             print_freq=1,
             unsup_loss=1,
-            hp_opt=False)
+            hp_opt=False,
+            save_sample_freq=None)

         exec_time=time.process_time() - t0
         ####


@@ -288,12 +288,18 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
                 model['model'].detach_()
                 meta_opt.zero_grad()

+            elif not high_grad_track:
+                diffopt.detach_()
+                model['model'].detach_()
+
         tf = time.process_time()

         if (save_sample_freq and epoch%save_sample_freq==0): #Data sample saving
             try:
                 viz_sample_data(imgs=xs, labels=ys, fig_name='../samples/data_sample_epoch{}_noTF'.format(epoch))
+                model.train()
                 viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='../samples/data_sample_epoch{}'.format(epoch))
+                model.eval()
             except:
                 print("Couldn't save samples epoch"+epoch)
                 pass
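Background for the new `elif not high_grad_track` branch: a differentiable optimizer like `higher`'s keeps every inner update in the autograd graph, so when higher-order gradient tracking is disabled the graph (and memory) would still grow each iteration unless the patched model and optimizer are detached. A pure-PyTorch toy of the same pattern (not the repo's or `higher`'s API):

    import torch

    w = torch.ones(3, requires_grad=True)  # "fast" weight carrying unrolled history
    for step in range(100):
        loss = (w * torch.randn(3)).sum()
        g, = torch.autograd.grad(loss, w, create_graph=True)
        w = w - 0.1 * g                    # differentiable update: graph grows
        w = w.detach().requires_grad_()    # cut history, as detach_() does above

The `model.train()` / `model.eval()` pair added in the same hunk toggles training mode so the augmented samples are actually produced for visualization, then restores evaluation mode.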
@@ -315,9 +321,9 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=1, dataug_epoch_start
             "acc": accuracy,
             "time": tf - t0,

-            "mix_dist": model['data_aug']['mix_dist'].item(),
             "param": param,
         }
+        if not model['data_aug']._fixed_mix: data["mix_dist"]=model['data_aug']['mix_dist'].item()
         if hp_opt : data["opt_param"]=[{'lr': p_grp['lr'].item(), 'momentum': p_grp['momentum'].item()} for p_grp in diffopt.param_groups]
         log.append(data)
         #############


@@ -131,6 +131,7 @@ def viz_sample_data(imgs, labels, fig_name='data_sample', weight_labels=None):
         fig_name (string): Relative path where to save the graph. (default: data_sample)
         weight_labels (Tensor): Weights associated with each label. (default: None)
     """
     sample = imgs[0:25,].permute(0, 2, 3, 1).squeeze().cpu()

     plt.figure(figsize=(10,10))