Changes to dist_dataugv3 (less copying / faster) + slight TF change

Harle, Antoine (Contracteur) 2020-01-15 16:55:03 -05:00
parent e291bc2e44
commit 75901b69b4
6 changed files with 198 additions and 83 deletions

View file

@@ -531,7 +531,7 @@ class Data_augV4(nn.Module): #Transformations avec mask
return "Data_augV4(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_seqTF)
class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True, ):
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True):
super(Data_augV5, self).__init__()
assert len(TF_dict)>0
@@ -545,13 +545,15 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
self._shared_mag = shared_mag
self._fixed_mag = fixed_mag
#self._fixed_mag=5 #[0, PARAMETER_MAX]
init_mag = float(TF.PARAMETER_MAX) if self._fixed_mag else float(TF.PARAMETER_MAX)/2
self._params = nn.ParameterDict({
"prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Distribution prob uniforme
"mag" : nn.Parameter(torch.tensor(float(TF.PARAMETER_MAX)/2) if self._shared_mag
else torch.tensor(float(TF.PARAMETER_MAX)/2).expand(self._nb_tf)), #[0, PARAMETER_MAX]
"mag" : nn.Parameter(torch.tensor(init_mag) if self._shared_mag
else torch.tensor(init_mag).repeat(self._nb_tf)), #[0, PARAMETER_MAX]
})
for tf in TF.TF_no_grad :
if tf in self._TF: self._params['mag'].data[self._TF.index(tf)]=float(TF.PARAMETER_MAX) #TF fixe a max parameter
#for t in TF.TF_no_mag: self._params['mag'][self._TF.index(t)].data-=self._params['mag'][self._TF.index(t)].data #Mag inutile pour les TF ignore_mag
#Distribution
@@ -1094,8 +1096,8 @@ class Augmented_model(nn.Module):
self.augment(mode=True)
def initialize(self):
self._mods['model'].initialize()
#def initialize(self):
# self._mods['model'].initialize()
def forward(self, x):
return self._mods['model'](self._mods['data_aug'](x))
@@ -1136,4 +1138,81 @@ class Augmented_model(nn.Module):
return self._mods[key]
def __str__(self):
return "Aug_mod("+str(self._mods['data_aug'])+"-"+str(self._mods['model'])+")"
return "Aug_mod("+str(self._mods['data_aug'])+"-"+str(self._mods['model'])+")"
'''
import higher
class Augmented_model2(nn.Module):
def __init__(self, data_augmenter, model):
super(Augmented_model2, self).__init__()
self._mods = nn.ModuleDict({
'data_aug': data_augmenter,
'model': model,
'fmodel': None
})
self.augment(mode=True)
def initialize(self):
self._mods['model'].initialize()
def forward(self, x):
if self._mods['fmodel']:
return self._mods['fmodel'](self._mods['data_aug'](x))
else:
return self._mods['model'](self._mods['data_aug'](x))
def functional(self, opt, track_higher_grads=True):
self._mods['fmodel'] = higher.patch.monkeypatch(self._mods['model'], device=None, copy_initial_weights=True)
return higher.optim.get_diff_optim(opt,
self._mods['model'].parameters(),
fmodel=self._mods['fmodel'],
track_higher_grads=track_higher_grads)
def detach_(self):
tmp = self._mods['fmodel'].fast_params
self._mods['fmodel']._fast_params=[]
self._mods['fmodel'].update_params(tmp)
for p in self._mods['fmodel'].fast_params:
p.detach_().requires_grad_()
def augment(self, mode=True):
self._data_augmentation=mode
self._mods['data_aug'].augment(mode)
def train(self, mode=None):
if mode is None :
mode=self._data_augmentation
self._mods['data_aug'].augment(mode)
super(Augmented_model2, self).train(mode)
return self
def eval(self):
return self.train(mode=False)
#super(Augmented_model, self).eval()
def items(self):
"""Return an iterable of the ModuleDict key/value pairs.
"""
return self._mods.items()
def update(self, modules):
self._mods.update(modules)
def is_augmenting(self):
return self._data_augmentation
def TF_names(self):
try:
return self._mods['data_aug']._TF
except:
return None
def __getitem__(self, key):
return self._mods[key]
def __str__(self):
return "Aug_mod("+str(self._mods['data_aug'])+"-"+str(self._mods['model'])+")"
'''

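Note on the magnitude hunk above: replacing .expand(self._nb_tf) with .repeat(self._nb_tf) matters because expand returns a stride-0 view over a single scalar, so the later per-TF write self._params['mag'].data[self._TF.index(tf)] = TF.PARAMETER_MAX cannot target one transformation independently, while repeat allocates one value per TF. A minimal sketch of the difference (PARAMETER_MAX is an assumed stand-in value, not the repo's constant):

import torch
import torch.nn as nn

PARAMETER_MAX = 10.0   # assumed stand-in for TF.PARAMETER_MAX
nb_tf = 3

# expand(): a stride-0 view over one scalar -- every entry aliases the same memory,
# so an in-place write to a single index is not possible.
expanded = torch.tensor(PARAMETER_MAX / 2).expand(nb_tf)
print(expanded.stride())               # (0,)

# repeat(): materializes nb_tf independent values, one learnable magnitude per TF.
mag = nn.Parameter(torch.tensor(PARAMETER_MAX / 2).repeat(nb_tf))
mag.data[0] = PARAMETER_MAX            # e.g. pin a TF from TF_no_grad to max magnitude
print(mag)                             # tensor([10., 5., 5.], requires_grad=True)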
View file

@@ -3,6 +3,40 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import higher
class Higher_model(nn.Module):
def __init__(self, model):
super(Higher_model, self).__init__()
self._mods = nn.ModuleDict({
'original': model,
'functional': higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
})
def get_diffopt(self, opt, grad_callback=None, track_higher_grads=True):
return higher.optim.get_diff_optim(opt,
self._mods['original'].parameters(),
fmodel=self._mods['functional'],
grad_callback=grad_callback,
track_higher_grads=track_higher_grads)
def forward(self, x):
return self._mods['functional'](x)
def detach_(self):
tmp = self._mods['functional'].fast_params
self._mods['functional']._fast_params=[]
self._mods['functional'].update_params(tmp)
for p in self._mods['functional'].fast_params:
p.detach_().requires_grad_()
def __getitem__(self, key):
return self._mods[key]
def __str__(self):
return self._mods['original'].__str__()
## Basic CNN ##
class LeNet_F(nn.Module):
def __init__(self, num_inp, num_out):

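The Higher_model wrapper added above keeps the original nn.Module next to its higher functional patch and exposes get_diffopt()/detach_(), so callers no longer monkeypatch the model themselves. A minimal usage sketch, assuming the class is importable from this file as model.Higher_model; the toy network and data are placeholders:

import torch
import torch.nn as nn
import torch.nn.functional as F
from model import Higher_model          # assumed import path

net = nn.Linear(10, 2)                  # toy stand-in for LeNet/ResNet/WideResNet
hmodel = Higher_model(net)

inner_opt = torch.optim.SGD(hmodel['original'].parameters(), lr=1e-2, momentum=0.9)
diffopt = hmodel.get_diffopt(inner_opt, track_higher_grads=True)

x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))
loss = F.cross_entropy(hmodel(x), y)    # forward runs through the functional copy
diffopt.step(loss)                      # differentiable inner step, no loss.backward()

hmodel.detach_()                        # truncate the chain of fast params between meta steps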
View file

@@ -19,8 +19,8 @@ tf_names = [
'Color',
'Brightness',
'Sharpness',
#'Posterize',
#'Solarize', #=>Image entre [0,1] #Pas opti pour des batch
'Posterize',
'Solarize', #=>Image entre [0,1] #Pas opti pour des batch
#Color TF (Common mag scale)
#'+Contrast',
@@ -67,7 +67,7 @@ if __name__ == "__main__":
'aug_model'
}
n_inner_iter = 1
epochs = 100
epochs = 15
dataug_epoch_start=0
optim_param={
'Meta':{
@@ -81,11 +81,13 @@ if __name__ == "__main__":
}
}
model = LeNet(3,10)
#model = LeNet(3,10)
#model = MobileNetV2(num_classes=10)
#model = ResNet(num_classes=10)
model = ResNet(num_classes=10)
#model = WideResNet(num_classes=10, wrn_size=32)
model = Higher_model(model) #run_dist_dataugV3
#### Classic ####
if 'classic' in tasks:
t0 = time.process_time()
@@ -172,12 +174,12 @@ if __name__ == "__main__":
#aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)
print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter))
log= run_dist_dataugV2(model=aug_model,
log= run_dist_dataugV3(model=aug_model,
epochs=epochs,
inner_it=n_inner_iter,
dataug_epoch_start=dataug_epoch_start,
opt_param=optim_param,
print_freq=10,
print_freq=1,
KLdiv=True,
loss_patience=None)
@@ -187,7 +189,7 @@ if __name__ == "__main__":
times = [x["time"] for x in log]
out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), 'Optimizer': optim_param, "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
print(str(aug_model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
filename = "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)+"demi_mag"
filename = "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)#+"demi_mag"
with open("res/log/%s.json" % filename, "w+") as f:
json.dump(out, f, indent=True)
print('Log :\"',f.name, '\" saved !')

View file

@@ -654,7 +654,7 @@ def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
model.train()
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel,track_higher_grads=high_grad_track)
diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
for epoch in range(1, epochs+1):
#print_torch_mem("Start epoch "+str(epoch))
@@ -742,9 +742,8 @@ def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
model_copy(src=fmodel, dst=model)
optim_copy(dopt=diffopt, opt=inner_opt)
torch.nn.utils.clip_grad_norm_(model['data_aug']['prob'], max_norm=10, norm_type=2) #Prevent exploding grad with RNN
torch.nn.utils.clip_grad_norm_(model['data_aug']['mag'], max_norm=10, norm_type=2) #Prevent exploding grad with RNN
torch.nn.utils.clip_grad_norm_(model['data_aug'].parameters(), max_norm=10, norm_type=2) #Prevent exploding grad with RNN
#if epoch>50:
meta_opt.step()
model['data_aug'].adjust_param(soft=False) #Contrainte sum(proba)=1
@@ -835,7 +834,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
#if inner_it!=0:
meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=opt_param['Meta']['lr']) #lr=1e-2
inner_opt = torch.optim.SGD(model['model'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
inner_opt = torch.optim.SGD(model['model']['original'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
high_grad_track = True
if inner_it == 0:
@@ -853,12 +852,17 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
#fmodel = higher.patch.monkeypatch(model['model'], device=None, copy_initial_weights=True)
#diffopt = higher.optim.get_diff_optim(inner_opt, model['model'].parameters(),fmodel=fmodel,track_higher_grads=high_grad_track)
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel,track_higher_grads=high_grad_track)
#fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
#diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel,track_higher_grads=high_grad_track)
diffopt = model['model'].get_diffopt(
inner_opt,
grad_callback=(lambda grads: clip_norm(grads, max_norm=10)),
track_higher_grads=high_grad_track)
#meta_opt = torch.optim.Adam(fmodel['data_aug'].parameters(), lr=opt_param['Meta']['lr']) #lr=1e-2
print(len(fmodel._fast_params))
#print(len(model['model']['functional']._fast_params))
for epoch in range(1, epochs+1):
#print_torch_mem("Start epoch "+str(epoch))
@@ -871,30 +875,30 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
if(not KLdiv):
#Methode uniforme
logits = fmodel(xs) # modified `params` can also be passed as a kwarg
logits = model(xs) # modified `params` can also be passed as a kwarg
loss = F.cross_entropy(F.log_softmax(logits, dim=1), ys, reduction='none') # no need to call loss.backwards()
if fmodel._data_augmentation: #Weight loss
w_loss = fmodel['data_aug'].loss_weight()#.to(device)
if model._data_augmentation: #Weight loss
w_loss = model['data_aug'].loss_weight()#.to(device)
loss = loss * w_loss
loss = loss.mean()
else:
#Methode KL div
if fmodel._data_augmentation :
fmodel.augment(mode=False)
sup_logits = fmodel(xs)
fmodel.augment(mode=True)
if model._data_augmentation :
model.augment(mode=False)
sup_logits = model(xs)
model.augment(mode=True)
else:
sup_logits = fmodel(xs)
sup_logits = model(xs)
log_sup=F.log_softmax(sup_logits, dim=1)
loss = F.cross_entropy(log_sup, ys)
if fmodel._data_augmentation:
aug_logits = fmodel(xs)
if model._data_augmentation:
aug_logits = model(xs)
log_aug=F.log_softmax(aug_logits, dim=1)
aug_loss=0
w_loss = fmodel['data_aug'].loss_weight() #Weight loss
w_loss = model['data_aug'].loss_weight() #Weight loss
#if epoch>50: #debut differe ?
#KL div w/ logits - Similarite predictions (distributions)
@@ -915,75 +919,36 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
#print(fmodel['model']._params['b4'].grad)
#print('prob grad', fmodel['data_aug']['prob'].grad)
#for _, p in fmodel['data_aug'].named_parameters():
# p.requires_grad = False
t = time.process_time()
diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)
print(len(fmodel._fast_params),"step", time.process_time()-t)
print(len(model['model']['functional']._fast_params),"step", time.process_time()-t)
#for _, p in fmodel['data_aug'].named_parameters():
# p.requires_grad = True
if(high_grad_track and i>0 and i%inner_it==0): #Perform Meta step
#print("meta")
val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val) + fmodel['data_aug'].reg_loss()
val_loss = compute_vaLoss(model=model, dl_it=dl_val_it, dl=dl_val) + model['data_aug'].reg_loss()
#print_graph(val_loss)
val_loss.backward()
print('proba grad',fmodel['data_aug']['prob'].grad)
#countcopy+=1
#model_copy(src=fmodel, dst=model)
#optim_copy(dopt=diffopt, opt=inner_opt)
torch.nn.utils.clip_grad_norm_(fmodel['data_aug']['prob'], max_norm=10, norm_type=2) #Prevent exploding grad with RNN
torch.nn.utils.clip_grad_norm_(fmodel['data_aug']['mag'], max_norm=10, norm_type=2) #Prevent exploding grad with RNN
for paramName, paramValue, in fmodel['data_aug'].named_parameters():
for netCopyName, netCopyValue, in model['data_aug'].named_parameters():
if paramName == netCopyName:
netCopyValue.grad = paramValue.grad
#del meta_opt.param_groups[0]
#meta_opt.add_param_group({'params' : [p for p in fmodel['data_aug'].parameters()]})
torch.nn.utils.clip_grad_norm_(model['data_aug'].parameters(), max_norm=10, norm_type=2) #Prevent exploding grad with RNN
meta_opt.step()
fmodel['data_aug'].load_state_dict(model['data_aug'].state_dict())
fmodel['data_aug'].adjust_param(soft=False) #Contrainte sum(proba)=1
#model['data_aug'].next_TF_set()
model['data_aug'].adjust_param(soft=False) #Contrainte sum(proba)=1
#fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
#diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
#fmodel.fast_params=[higher.utils._copy_tensor(t,safe_copy=True) if isinstance(t, torch.Tensor) else t for t in fmodel.parameters()]
diffopt.detach_()
tmp = fmodel.fast_params
fmodel._fast_params=[]
fmodel.update_params(tmp)
for p in fmodel.fast_params:
p.detach_().requires_grad_()
print(len(fmodel._fast_params))
print('TF Proba :', fmodel['data_aug']['prob'].data)
model['model'].detach_()
tf = time.process_time()
#viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
#viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
#model_copy(src=fmodel, dst=model)
if(not high_grad_track):
#countcopy+=1
#model_copy(src=fmodel, dst=model)
optim_copy(dopt=diffopt, opt=inner_opt)
val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val)
val_loss = compute_vaLoss(model=model, dl_it=dl_val_it, dl=dl_val)
#Necessaire pour reset higher (Accumule les fast_param meme avec track_higher_grads = False)
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
accuracy, test_loss =test(model)
model.train()

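Taken together, the run_dist_dataugV3 hunks above replace the per-call monkeypatch and the manual gradient copy between fmodel and model with the diffopt provided by Higher_model: inner-loop gradients are clipped through the clip_norm grad_callback, the meta step clips and updates model['data_aug'].parameters() directly, and the unrolled graph is truncated with diffopt.detach_() / model['model'].detach_(). A condensed sketch of the resulting alternation (imports, loaders and helpers are the script's own; the zero_grad placement is an assumption):

meta_opt  = torch.optim.Adam(aug_model['data_aug'].parameters(), lr=opt_param['Meta']['lr'])
inner_opt = torch.optim.SGD(aug_model['model']['original'].parameters(),
                            lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum'])
diffopt = aug_model['model'].get_diffopt(
    inner_opt,
    grad_callback=(lambda grads: clip_norm(grads, max_norm=10)),   # clip inner grads
    track_higher_grads=True)

for i, (xs, ys) in enumerate(dl_train):
    loss = F.cross_entropy(aug_model(xs), ys)        # weighted / KLdiv variants in the real code
    diffopt.step(loss)                               # differentiable inner update

    if i > 0 and i % inner_it == 0:                  # meta step on validation data
        val_loss = compute_vaLoss(model=aug_model, dl_it=dl_val_it, dl=dl_val) \
                   + aug_model['data_aug'].reg_loss()
        val_loss.backward()
        torch.nn.utils.clip_grad_norm_(aug_model['data_aug'].parameters(), max_norm=10, norm_type=2)
        meta_opt.step()
        meta_opt.zero_grad()                         # assumed, not shown in the hunk above
        aug_model['data_aug'].adjust_param(soft=False)   # constraint: sum(prob) = 1
        diffopt.detach_()                            # cut the graph so fast params do not
        aug_model['model'].detach_()                 # accumulate across meta steps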
View file

@@ -103,7 +103,8 @@ TF_dict={ #Dataugv5
}
TF_no_mag={'Identity', 'FlipUD', 'FlipLR', 'Random', 'RandBlend'}
TF_ignore_mag= TF_no_mag | {'Solarize', 'Posterize', '=Solarize', '=Posterize'}
TF_no_grad={'Solarize', 'Posterize', '=Solarize', '=Posterize'}
TF_ignore_mag= TF_no_mag | TF_no_grad
def int_image(float_image): #ATTENTION : legere perte d'info (granularite : 1/256 = 0.0039)
return (float_image*255.).type(torch.uint8)

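Background on the new TF_no_grad set: Solarize and Posterize consume their magnitude as a discrete threshold/bit count, so no gradient can flow through it, which is why Data_augV5 now pins their magnitude to PARAMETER_MAX. A hedged illustration; the posterize below is a simplified assumption, not this repo's implementation:

import torch

def posterize(batch, bits):
    # `bits` is used as an integer bit count -> not differentiable w.r.t. the magnitude
    mask = 255 - (2 ** (8 - int(bits)) - 1)          # keep only the top `bits` bits
    int_imgs = (batch * 255.).type(torch.uint8)      # cf. int_image(): ~1/256 precision loss
    return (int_imgs & mask).float() / 255.

imgs = torch.rand(4, 3, 32, 32)
out = posterize(imgs, bits=4)    # forward pass works, but d(out)/d(bits) is undefined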
View file

@@ -314,4 +314,38 @@ class loss_monitor(): #Voir https://github.com/pytorch/ignite
return False
def reset(self):
self.__init__(self.patience, self.end_train)
self.__init__(self.patience, self.end_train)
### https://github.com/facebookresearch/higher/issues/18 ####
from torch._six import inf
def clip_norm(tensors, max_norm, norm_type=2):
r"""Clips norm of passed tensors.
The norm is computed over all tensors together, as if they were
concatenated into a single vector. Clipped tensors are returned.
Arguments:
tensors (Iterable[Tensor]): an iterable of Tensors or a
single Tensor to be normalized.
max_norm (float or int): max norm of the gradients
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
Returns:
Clipped (List[Tensor]) tensors.
"""
if isinstance(tensors, torch.Tensor):
tensors = [tensors]
tensors = list(tensors)
max_norm = float(max_norm)
norm_type = float(norm_type)
if norm_type == inf:
total_norm = max(t.abs().max() for t in tensors)
else:
total_norm = 0
for t in tensors:
param_norm = t.norm(norm_type)
total_norm += param_norm.item() ** norm_type
total_norm = total_norm ** (1. / norm_type)
clip_coef = max_norm / (total_norm + 1e-6)
if clip_coef >= 1:
return tensors
return [t.mul(clip_coef) for t in tensors]
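clip_norm is the functional counterpart of torch.nn.utils.clip_grad_norm_: it returns the clipped tensors instead of mutating .grad in place, which is the contract higher expects from grad_callback (as used by get_diffopt earlier in this commit). A hedged usage sketch with a toy model:

import higher
import torch
import torch.nn as nn
import torch.nn.functional as F

net = nn.Linear(10, 2)                               # toy model
opt = torch.optim.SGD(net.parameters(), lr=1e-2)

fmodel  = higher.patch.monkeypatch(net, device=None, copy_initial_weights=True)
diffopt = higher.optim.get_diff_optim(opt, net.parameters(), fmodel=fmodel,
                                      grad_callback=(lambda grads: clip_norm(grads, max_norm=10)),
                                      track_higher_grads=True)

x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))
loss = F.cross_entropy(fmodel(x), y)
diffopt.step(loss)         # gradients pass through clip_norm before the SGD update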