Mirror of https://github.com/AntoineHX/smart_augmentation.git (synced 2025-05-04 04:00:46 +02:00)
Changes to dist_dataugv3 (-copy/+faster) + slight TF change
parent e291bc2e44
commit 75901b69b4
6 changed files with 198 additions and 83 deletions
@@ -531,7 +531,7 @@ class Data_augV4(nn.Module): #Transformations with mask
        return "Data_augV4(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_seqTF)

class Data_augV5(nn.Module): #Joint optimization (mag, proba)
    def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True, ):
    def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True):
        super(Data_augV5, self).__init__()
        assert len(TF_dict)>0
@@ -545,13 +545,15 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)
        self._shared_mag = shared_mag
        self._fixed_mag = fixed_mag

        #self._fixed_mag=5 #[0, PARAMETER_MAX]
        init_mag = float(TF.PARAMETER_MAX) if self._fixed_mag else float(TF.PARAMETER_MAX)/2
        self._params = nn.ParameterDict({
            "prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Uniform probability distribution
            "mag" : nn.Parameter(torch.tensor(float(TF.PARAMETER_MAX)/2) if self._shared_mag
                else torch.tensor(float(TF.PARAMETER_MAX)/2).expand(self._nb_tf)), #[0, PARAMETER_MAX]
            "mag" : nn.Parameter(torch.tensor(init_mag) if self._shared_mag
                else torch.tensor(init_mag).repeat(self._nb_tf)), #[0, PARAMETER_MAX]
        })

        for tf in TF.TF_no_grad :
            if tf in self._TF: self._params['mag'].data[self._TF.index(tf)]=float(TF.PARAMETER_MAX) #TF fixed at max parameter
        #for t in TF.TF_no_mag: self._params['mag'][self._TF.index(t)].data-=self._params['mag'][self._TF.index(t)].data #Mag useless for the ignore_mag TFs

        #Distribution
@@ -1094,8 +1096,8 @@ class Augmented_model(nn.Module):
        self.augment(mode=True)

    def initialize(self):
        self._mods['model'].initialize()
    #def initialize(self):
    #    self._mods['model'].initialize()

    def forward(self, x):
        return self._mods['model'](self._mods['data_aug'](x))
@@ -1136,4 +1138,81 @@ class Augmented_model(nn.Module):
        return self._mods[key]

    def __str__(self):
        return "Aug_mod("+str(self._mods['data_aug'])+"-"+str(self._mods['model'])+")"
        return "Aug_mod("+str(self._mods['data_aug'])+"-"+str(self._mods['model'])+")"

'''
import higher
class Augmented_model2(nn.Module):
    def __init__(self, data_augmenter, model):
        super(Augmented_model2, self).__init__()

        self._mods = nn.ModuleDict({
            'data_aug': data_augmenter,
            'model': model,
            'fmodel': None
            })

        self.augment(mode=True)

    def initialize(self):
        self._mods['model'].initialize()

    def forward(self, x):
        if self._mods['fmodel']:
            return self._mods['fmodel'](self._mods['data_aug'](x))
        else:
            return self._mods['model'](self._mods['data_aug'](x))

    def functional(self, opt, track_higher_grads=True):
        self._mods['fmodel'] = higher.patch.monkeypatch(self._mods['model'], device=None, copy_initial_weights=True)

        return higher.optim.get_diff_optim(opt,
            self._mods['model'].parameters(),
            fmodel=self._mods['fmodel'],
            track_higher_grads=track_higher_grads)

    def detach_(self):
        tmp = self._mods['fmodel'].fast_params
        self._mods['fmodel']._fast_params=[]
        self._mods['fmodel'].update_params(tmp)
        for p in self._mods['fmodel'].fast_params:
            p.detach_().requires_grad_()

    def augment(self, mode=True):
        self._data_augmentation=mode
        self._mods['data_aug'].augment(mode)

    def train(self, mode=None):
        if mode is None :
            mode=self._data_augmentation
        self._mods['data_aug'].augment(mode)
        super(Augmented_model2, self).train(mode)
        return self

    def eval(self):
        return self.train(mode=False)
        #super(Augmented_model, self).eval()

    def items(self):
        """Return an iterable of the ModuleDict key/value pairs.
        """
        return self._mods.items()

    def update(self, modules):
        self._mods.update(modules)

    def is_augmenting(self):
        return self._data_augmentation

    def TF_names(self):
        try:
            return self._mods['data_aug']._TF
        except:
            return None

    def __getitem__(self, key):
        return self._mods[key]

    def __str__(self):
        return "Aug_mod("+str(self._mods['data_aug'])+"-"+str(self._mods['model'])+")"
'''
@@ -3,6 +3,40 @@ import torch
import torch.nn as nn
import torch.nn.functional as F


import higher
class Higher_model(nn.Module):
    def __init__(self, model):
        super(Higher_model, self).__init__()

        self._mods = nn.ModuleDict({
            'original': model,
            'functional': higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
            })

    def get_diffopt(self, opt, grad_callback=None, track_higher_grads=True):
        return higher.optim.get_diff_optim(opt,
            self._mods['original'].parameters(),
            fmodel=self._mods['functional'],
            grad_callback=grad_callback,
            track_higher_grads=track_higher_grads)

    def forward(self, x):
        return self._mods['functional'](x)

    def detach_(self):
        tmp = self._mods['functional'].fast_params
        self._mods['functional']._fast_params=[]
        self._mods['functional'].update_params(tmp)
        for p in self._mods['functional'].fast_params:
            p.detach_().requires_grad_()

    def __getitem__(self, key):
        return self._mods[key]

    def __str__(self):
        return self._mods['original'].__str__()

## Basic CNN ##
class LeNet_F(nn.Module):
    def __init__(self, num_inp, num_out):
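The new Higher_model wrapper keeps the original module and a higher "functional" copy side by side, so training code can take differentiable inner steps and later cut the unrolled graph. Below is a minimal usage sketch, assuming the Higher_model class from the hunk above is in scope; the toy linear net, tensor shapes and learning rate are illustrative only, not part of the repository.

import torch
import torch.nn as nn
import torch.nn.functional as F

net = Higher_model(nn.Linear(10, 2))                      # wraps 'original' + monkeypatched 'functional'
inner_opt = torch.optim.SGD(net['original'].parameters(), lr=0.1)
diffopt = net.get_diffopt(inner_opt, track_higher_grads=True)

x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))
for _ in range(3):
    loss = F.cross_entropy(net(x), y)                     # forward goes through the functional copy
    diffopt.step(loss)                                    # differentiable update of fast_params, no .backward() needed

net.detach_()                                             # drop the unrolled graph; fast_params become fresh leaves

The point of get_diffopt over a plain optimizer is that each diffopt.step keeps the computation graph, so a loss evaluated after several inner steps can still be backpropagated into hyper-parameters such as the augmentation probabilities and magnitudes.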
@@ -19,8 +19,8 @@ tf_names = [
    'Color',
    'Brightness',
    'Sharpness',
    #'Posterize',
    #'Solarize', #=>Image in [0,1] #Not optimal for batches
    'Posterize',
    'Solarize', #=>Image in [0,1] #Not optimal for batches

    #Color TF (Common mag scale)
    #'+Contrast',
@@ -67,7 +67,7 @@ if __name__ == "__main__":
        'aug_model'
    }
    n_inner_iter = 1
    epochs = 100
    epochs = 15
    dataug_epoch_start=0
    optim_param={
        'Meta':{
@@ -81,11 +81,13 @@ if __name__ == "__main__":
        }
    }

    model = LeNet(3,10)
    #model = LeNet(3,10)
    #model = MobileNetV2(num_classes=10)
    #model = ResNet(num_classes=10)
    model = ResNet(num_classes=10)
    #model = WideResNet(num_classes=10, wrn_size=32)

    model = Higher_model(model) #run_dist_dataugV3

    #### Classic ####
    if 'classic' in tasks:
        t0 = time.process_time()
@@ -172,12 +174,12 @@ if __name__ == "__main__":
        #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)

        print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter))
        log= run_dist_dataugV2(model=aug_model,
        log= run_dist_dataugV3(model=aug_model,
             epochs=epochs,
             inner_it=n_inner_iter,
             dataug_epoch_start=dataug_epoch_start,
             opt_param=optim_param,
             print_freq=10,
             print_freq=1,
             KLdiv=True,
             loss_patience=None)
@@ -187,7 +189,7 @@ if __name__ == "__main__":
        times = [x["time"] for x in log]
        out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), 'Optimizer': optim_param, "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
        print(str(aug_model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
        filename = "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)+"demi_mag"
        filename = "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)#+"demi_mag"
        with open("res/log/%s.json" % filename, "w+") as f:
            json.dump(out, f, indent=True)
            print('Log :\"',f.name, '\" saved !')
@@ -654,7 +654,7 @@ def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
    model.train()

    fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
    diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel,track_higher_grads=high_grad_track)
    diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)

    for epoch in range(1, epochs+1):
        #print_torch_mem("Start epoch "+str(epoch))
@@ -742,9 +742,8 @@ def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
                model_copy(src=fmodel, dst=model)
                optim_copy(dopt=diffopt, opt=inner_opt)

                torch.nn.utils.clip_grad_norm_(model['data_aug']['prob'], max_norm=10, norm_type=2) #Prevent exploding grad with RNN
                torch.nn.utils.clip_grad_norm_(model['data_aug']['mag'], max_norm=10, norm_type=2) #Prevent exploding grad with RNN

                torch.nn.utils.clip_grad_norm_(model['data_aug'].parameters(), max_norm=10, norm_type=2) #Prevent exploding grad with RNN

                #if epoch>50:
                meta_opt.step()
                model['data_aug'].adjust_param(soft=False) #Constraint: sum(proba)=1
@@ -835,7 +834,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start

    #if inner_it!=0:
    meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=opt_param['Meta']['lr']) #lr=1e-2
    inner_opt = torch.optim.SGD(model['model'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
    inner_opt = torch.optim.SGD(model['model']['original'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9

    high_grad_track = True
    if inner_it == 0:
@@ -853,12 +852,17 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start

    #fmodel = higher.patch.monkeypatch(model['model'], device=None, copy_initial_weights=True)
    #diffopt = higher.optim.get_diff_optim(inner_opt, model['model'].parameters(),fmodel=fmodel,track_higher_grads=high_grad_track)
    fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
    diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel,track_higher_grads=high_grad_track)
    #fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
    #diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel,track_higher_grads=high_grad_track)

    diffopt = model['model'].get_diffopt(
        inner_opt,
        grad_callback=(lambda grads: clip_norm(grads, max_norm=10)),
        track_higher_grads=high_grad_track)

    #meta_opt = torch.optim.Adam(fmodel['data_aug'].parameters(), lr=opt_param['Meta']['lr']) #lr=1e-2

    print(len(fmodel._fast_params))
    #print(len(model['model']['functional']._fast_params))

    for epoch in range(1, epochs+1):
        #print_torch_mem("Start epoch "+str(epoch))
@@ -871,30 +875,30 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start

            if(not KLdiv):
                #Uniform method
                logits = fmodel(xs) # modified `params` can also be passed as a kwarg
                logits = model(xs) # modified `params` can also be passed as a kwarg
                loss = F.cross_entropy(F.log_softmax(logits, dim=1), ys, reduction='none') # no need to call loss.backwards()

                if fmodel._data_augmentation: #Weight loss
                    w_loss = fmodel['data_aug'].loss_weight()#.to(device)
                if model._data_augmentation: #Weight loss
                    w_loss = model['data_aug'].loss_weight()#.to(device)
                    loss = loss * w_loss
                loss = loss.mean()

            else:
                #KL div method
                if fmodel._data_augmentation :
                    fmodel.augment(mode=False)
                    sup_logits = fmodel(xs)
                    fmodel.augment(mode=True)
                if model._data_augmentation :
                    model.augment(mode=False)
                    sup_logits = model(xs)
                    model.augment(mode=True)
                else:
                    sup_logits = fmodel(xs)
                    sup_logits = model(xs)
                log_sup=F.log_softmax(sup_logits, dim=1)
                loss = F.cross_entropy(log_sup, ys)

                if fmodel._data_augmentation:
                    aug_logits = fmodel(xs)
                if model._data_augmentation:
                    aug_logits = model(xs)
                    log_aug=F.log_softmax(aug_logits, dim=1)
                    aug_loss=0
                    w_loss = fmodel['data_aug'].loss_weight() #Weight loss
                    w_loss = model['data_aug'].loss_weight() #Weight loss

                    #if epoch>50: #delayed start?
                    #KL div w/ logits - similarity of the predictions (distributions)
@@ -915,75 +919,36 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
            #print(fmodel['model']._params['b4'].grad)
            #print('prob grad', fmodel['data_aug']['prob'].grad)

            #for _, p in fmodel['data_aug'].named_parameters():
            #    p.requires_grad = False
            t = time.process_time()
            diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)
            print(len(fmodel._fast_params),"step", time.process_time()-t)
            print(len(model['model']['functional']._fast_params),"step", time.process_time()-t)

            #for _, p in fmodel['data_aug'].named_parameters():
            #    p.requires_grad = True

            if(high_grad_track and i>0 and i%inner_it==0): #Perform Meta step
                #print("meta")

                val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val) + fmodel['data_aug'].reg_loss()
                val_loss = compute_vaLoss(model=model, dl_it=dl_val_it, dl=dl_val) + model['data_aug'].reg_loss()
                #print_graph(val_loss)

                val_loss.backward()

                print('proba grad',fmodel['data_aug']['prob'].grad)
                #countcopy+=1
                #model_copy(src=fmodel, dst=model)
                #optim_copy(dopt=diffopt, opt=inner_opt)

                torch.nn.utils.clip_grad_norm_(fmodel['data_aug']['prob'], max_norm=10, norm_type=2) #Prevent exploding grad with RNN
                torch.nn.utils.clip_grad_norm_(fmodel['data_aug']['mag'], max_norm=10, norm_type=2) #Prevent exploding grad with RNN

                for paramName, paramValue, in fmodel['data_aug'].named_parameters():
                    for netCopyName, netCopyValue, in model['data_aug'].named_parameters():
                        if paramName == netCopyName:
                            netCopyValue.grad = paramValue.grad

                #del meta_opt.param_groups[0]
                #meta_opt.add_param_group({'params' : [p for p in fmodel['data_aug'].parameters()]})
                torch.nn.utils.clip_grad_norm_(model['data_aug'].parameters(), max_norm=10, norm_type=2) #Prevent exploding grad with RNN

                meta_opt.step()
                fmodel['data_aug'].load_state_dict(model['data_aug'].state_dict())
                fmodel['data_aug'].adjust_param(soft=False) #Constraint: sum(proba)=1
                #model['data_aug'].next_TF_set()

                model['data_aug'].adjust_param(soft=False) #Constraint: sum(proba)=1

                #fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
                #diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)

                #fmodel.fast_params=[higher.utils._copy_tensor(t,safe_copy=True) if isinstance(t, torch.Tensor) else t for t in fmodel.parameters()]
                diffopt.detach_()
                tmp = fmodel.fast_params
                fmodel._fast_params=[]
                fmodel.update_params(tmp)
                for p in fmodel.fast_params:
                    p.detach_().requires_grad_()
                print(len(fmodel._fast_params))

                print('TF Proba :', fmodel['data_aug']['prob'].data)
                model['model'].detach_()

        tf = time.process_time()

        #viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
        #viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))

        #model_copy(src=fmodel, dst=model)

        if(not high_grad_track):
            #countcopy+=1
            #model_copy(src=fmodel, dst=model)
            optim_copy(dopt=diffopt, opt=inner_opt)
            val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val)
            val_loss = compute_vaLoss(model=model, dl_it=dl_val_it, dl=dl_val)

            #Needed to reset higher (it accumulates fast_params even with track_higher_grads = False)
            fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
            diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)

        accuracy, test_loss =test(model)
        model.train()
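Pieced together from this and the previous hunks, the reworked run_dist_dataugV3 inner/meta loop follows roughly the outline below. This is only a sketch of the control flow, not the author's exact function: names such as aug_model, dl_train, dl_val, dl_val_it, inner_it, opt_param and compute_vaLoss stand in for the surrounding training script, the KLdiv branch is omitted, and the meta_opt.zero_grad() call is an assumption.

meta_opt = torch.optim.Adam(aug_model['data_aug'].parameters(), lr=opt_param['Meta']['lr'])
inner_opt = torch.optim.SGD(aug_model['model']['original'].parameters(), lr=opt_param['Inner']['lr'])
diffopt = aug_model['model'].get_diffopt(inner_opt,
    grad_callback=(lambda grads: clip_norm(grads, max_norm=10)),
    track_higher_grads=True)

for i, (xs, ys) in enumerate(dl_train):
    loss = F.cross_entropy(aug_model(xs), ys)             # inner loss on the augmented batch
    diffopt.step(loss)                                    # differentiable inner step on the functional copy

    if i > 0 and i % inner_it == 0:                       # meta step every inner_it batches
        val_loss = compute_vaLoss(model=aug_model, dl_it=dl_val_it, dl=dl_val) + aug_model['data_aug'].reg_loss()
        val_loss.backward()                               # gradients reach data_aug through the unrolled inner steps
        torch.nn.utils.clip_grad_norm_(aug_model['data_aug'].parameters(), max_norm=10, norm_type=2)
        meta_opt.step()
        meta_opt.zero_grad()
        aug_model['data_aug'].adjust_param(soft=False)    # keep sum(proba)=1
        aug_model['model'].detach_()                      # reset fast_params instead of re-monkeypatching / copying

Compared with run_dist_dataugV2, the augmentation parameters are now updated directly (no fmodel-to-model copy of gradients) and the unrolled graph is cut with detach_(), which is where the "-copy/+faster" of the commit message comes from.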
@@ -103,7 +103,8 @@ TF_dict={ #Dataugv5
}

TF_no_mag={'Identity', 'FlipUD', 'FlipLR', 'Random', 'RandBlend'}
TF_ignore_mag= TF_no_mag | {'Solarize', 'Posterize', '=Solarize', '=Posterize'}
TF_no_grad={'Solarize', 'Posterize', '=Solarize', '=Posterize'}
TF_ignore_mag= TF_no_mag | TF_no_grad

def int_image(float_image): #WARNING: slight loss of information (granularity: 1/256 = 0.0039)
    return (float_image*255.).type(torch.uint8)
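The precision warning on int_image can be checked with a quick round trip; the float_image helper used as the inverse below is hypothetical (the repository's actual counterpart is not shown in this diff).

import torch

def int_image(float_image):                  # as above: scale to [0, 255] and truncate to uint8
    return (float_image*255.).type(torch.uint8)

def float_image(int_img):                    # hypothetical inverse, just for this check
    return int_img.float()/255.

x = torch.rand(10000)
err = (x - float_image(int_image(x))).abs().max()
print(err)                                   # worst case just under 1/255, about 0.0039, matching the comment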
@@ -314,4 +314,38 @@ class loss_monitor(): #See https://github.com/pytorch/ignite
        return False

    def reset(self):
        self.__init__(self.patience, self.end_train)
        self.__init__(self.patience, self.end_train)

### https://github.com/facebookresearch/higher/issues/18 ####
from torch._six import inf

def clip_norm(tensors, max_norm, norm_type=2):
    r"""Clips norm of passed tensors.
    The norm is computed over all tensors together, as if they were
    concatenated into a single vector. Clipped tensors are returned.
    Arguments:
        tensors (Iterable[Tensor]): an iterable of Tensors or a
            single Tensor to be normalized.
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.
    Returns:
        Clipped (List[Tensor]) tensors.
    """
    if isinstance(tensors, torch.Tensor):
        tensors = [tensors]
    tensors = list(tensors)
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if norm_type == inf:
        total_norm = max(t.abs().max() for t in tensors)
    else:
        total_norm = 0
        for t in tensors:
            param_norm = t.norm(norm_type)
            total_norm += param_norm.item() ** norm_type
        total_norm = total_norm ** (1. / norm_type)
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef >= 1:
        return tensors
    return [t.mul(clip_coef) for t in tensors]
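Unlike torch.nn.utils.clip_grad_norm_, which rescales .grad in place, clip_norm returns new (possibly rescaled) tensors, which is the form higher's grad_callback hook expects. A small sanity-check sketch, assuming clip_norm as defined above is in scope:

import torch

grads = [torch.randn(3, 3)*100, torch.randn(5)*100]
clipped = clip_norm(grads, max_norm=10)
total = sum(g.norm(2)**2 for g in clipped)**0.5
print(total)                                 # <= 10, up to the 1e-6 slack in clip_coef

# and, as in the run_dist_dataugV3 hunk above, plugged into the differentiable optimizer:
# diffopt = model['model'].get_diffopt(inner_opt,
#     grad_callback=(lambda grads: clip_norm(grads, max_norm=10)),
#     track_higher_grads=True)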