mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Minor improvement + Comments
This commit is contained in:
parent
d21a6bbf5c
commit
c1ad787d97
5 changed files with 165 additions and 62 deletions
|
@ -6,15 +6,51 @@ import torch.nn.functional as F
|
||||||
|
|
||||||
import higher
|
import higher
|
||||||
class Higher_model(nn.Module):
|
class Higher_model(nn.Module):
|
||||||
|
"""Model wrapper for higher gradient tracking.
|
||||||
|
|
||||||
|
Keep in memory the orginial model and it's functionnal, higher, version.
|
||||||
|
|
||||||
|
Might not be needed anymore if Higher implement detach for fmodel.
|
||||||
|
|
||||||
|
see : https://github.com/facebookresearch/higher
|
||||||
|
|
||||||
|
TODO: Get rid of the original model if not needed by user.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
_name (string): Name of the model.
|
||||||
|
_mods (nn.ModuleDict): Models (Orginial and Higher version).
|
||||||
|
"""
|
||||||
def __init__(self, model):
|
def __init__(self, model):
|
||||||
|
"""Init Higher_model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (nn.Module): Network for which higher gradients can be tracked.
|
||||||
|
"""
|
||||||
super(Higher_model, self).__init__()
|
super(Higher_model, self).__init__()
|
||||||
|
|
||||||
|
self._name = model.__str__()
|
||||||
self._mods = nn.ModuleDict({
|
self._mods = nn.ModuleDict({
|
||||||
'original': model,
|
'original': model,
|
||||||
'functional': higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
|
'functional': higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
|
||||||
})
|
})
|
||||||
|
|
||||||
def get_diffopt(self, opt, grad_callback=None, track_higher_grads=True):
|
def get_diffopt(self, opt, grad_callback=None, track_higher_grads=True):
|
||||||
|
"""Get a differentiable version of an Optimizer.
|
||||||
|
|
||||||
|
Higher/Differentiable optimizer required to be used for higher gradient tracking.
|
||||||
|
Usage : diffopt.step(loss) == (opt.zero_grad, loss.backward, opt.step)
|
||||||
|
|
||||||
|
Be warry that if track_higher_grads is set to True, a new state of the model would be saved each time diffopt.step() is called.
|
||||||
|
Thus increasing memory consumption. The detach_() method should be called to reset the gradient tape and prevent memory saturation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
opt (torch.optim): Optimizer to make differentiable.
|
||||||
|
grad_callback (fct(grads)=grads): Function applied to the list of gradients parameters (ex: clipping). (default: None)
|
||||||
|
track_higher_grads (bool): Wether higher gradient are tracked. If True, the graph/states will be retained to allow backpropagation. (default: True)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(Higher.DifferentiableOptimizer): Differentiable version of the optimizer.
|
||||||
|
"""
|
||||||
return higher.optim.get_diff_optim(opt,
|
return higher.optim.get_diff_optim(opt,
|
||||||
self._mods['original'].parameters(),
|
self._mods['original'].parameters(),
|
||||||
fmodel=self._mods['functional'],
|
fmodel=self._mods['functional'],
|
||||||
|
@ -22,20 +58,49 @@ class Higher_model(nn.Module):
|
||||||
track_higher_grads=track_higher_grads)
|
track_higher_grads=track_higher_grads)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
""" Main method of the model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): Batch of data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor : Output of the network. Should be logits.
|
||||||
|
"""
|
||||||
return self._mods['functional'](x)
|
return self._mods['functional'](x)
|
||||||
|
|
||||||
def detach_(self):
|
def detach_(self):
|
||||||
|
"""Detach from the graph.
|
||||||
|
|
||||||
|
Needed to limit the number of state kept in memory.
|
||||||
|
"""
|
||||||
tmp = self._mods['functional'].fast_params
|
tmp = self._mods['functional'].fast_params
|
||||||
self._mods['functional']._fast_params=[]
|
self._mods['functional']._fast_params=[]
|
||||||
self._mods['functional'].update_params(tmp)
|
self._mods['functional'].update_params(tmp)
|
||||||
for p in self._mods['functional'].fast_params:
|
for p in self._mods['functional'].fast_params:
|
||||||
p.detach_().requires_grad_()
|
p.detach_().requires_grad_()
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
"""Returns a dictionary containing a whole state of the module.
|
||||||
|
"""
|
||||||
|
return self._mods['functional'].state_dict()
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
|
"""Access to modules
|
||||||
|
Args:
|
||||||
|
key (string): Name of the module to access.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
nn.Module.
|
||||||
|
"""
|
||||||
return self._mods[key]
|
return self._mods[key]
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self._mods['original'].__str__()
|
"""Name of the module
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
String containing the name of the module.
|
||||||
|
"""
|
||||||
|
return self._name
|
||||||
|
|
||||||
## Basic CNN ##
|
## Basic CNN ##
|
||||||
class LeNet_F(nn.Module):
|
class LeNet_F(nn.Module):
|
||||||
|
|
|
@ -50,7 +50,7 @@ if __name__ == "__main__":
|
||||||
}
|
}
|
||||||
|
|
||||||
#model = LeNet(3,10)
|
#model = LeNet(3,10)
|
||||||
model = ResNet(num_classes=10)
|
#model = ResNet(num_classes=10)
|
||||||
#model = MobileNetV2(num_classes=10)
|
#model = MobileNetV2(num_classes=10)
|
||||||
#model = WideResNet(num_classes=10, wrn_size=32)
|
#model = WideResNet(num_classes=10, wrn_size=32)
|
||||||
|
|
||||||
|
@ -126,6 +126,8 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
t0 = time.process_time()
|
t0 = time.process_time()
|
||||||
|
|
||||||
|
model = ResNet(num_classes=10)
|
||||||
|
model = Higher_model(model) #run_dist_dataugV3
|
||||||
aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=n_tf, mix_dist=dist, fixed_prob=p_setup, fixed_mag=m_setup[0], shared_mag=m_setup[1]), model).to(device)
|
aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=n_tf, mix_dist=dist, fixed_prob=p_setup, fixed_mag=m_setup[0], shared_mag=m_setup[1]), model).to(device)
|
||||||
#aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)
|
#aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)
|
||||||
|
|
||||||
|
@ -136,8 +138,7 @@ if __name__ == "__main__":
|
||||||
dataug_epoch_start=dataug_epoch_start,
|
dataug_epoch_start=dataug_epoch_start,
|
||||||
opt_param=optim_param,
|
opt_param=optim_param,
|
||||||
print_freq=50,
|
print_freq=50,
|
||||||
KLdiv=True,
|
KLdiv=True)
|
||||||
loss_patience=None)
|
|
||||||
|
|
||||||
exec_time=time.process_time() - t0
|
exec_time=time.process_time() - t0
|
||||||
####
|
####
|
||||||
|
|
|
@ -68,7 +68,7 @@ if __name__ == "__main__":
|
||||||
}
|
}
|
||||||
n_inner_iter = 1
|
n_inner_iter = 1
|
||||||
epochs = 150
|
epochs = 150
|
||||||
dataug_epoch_start=10
|
dataug_epoch_start=0
|
||||||
optim_param={
|
optim_param={
|
||||||
'Meta':{
|
'Meta':{
|
||||||
'optim':'Adam',
|
'optim':'Adam',
|
||||||
|
|
|
@ -876,7 +876,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
|
||||||
else:
|
else:
|
||||||
#Methode KL div
|
#Methode KL div
|
||||||
# Supervised loss (classic)
|
# Supervised loss (classic)
|
||||||
if model._data_augmentation :
|
if model.is_augmenting() :
|
||||||
model.augment(mode=False)
|
model.augment(mode=False)
|
||||||
sup_logits = model(xs)
|
sup_logits = model(xs)
|
||||||
model.augment(mode=True)
|
model.augment(mode=True)
|
||||||
|
@ -886,7 +886,7 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
|
||||||
loss = F.cross_entropy(log_sup, ys)
|
loss = F.cross_entropy(log_sup, ys)
|
||||||
|
|
||||||
# Unsupervised loss (KLdiv)
|
# Unsupervised loss (KLdiv)
|
||||||
if model._data_augmentation:
|
if model.is_augmenting() :
|
||||||
aug_logits = model(xs)
|
aug_logits = model(xs)
|
||||||
log_aug=F.log_softmax(aug_logits, dim=1)
|
log_aug=F.log_softmax(aug_logits, dim=1)
|
||||||
aug_loss=0
|
aug_loss=0
|
||||||
|
@ -948,7 +948,6 @@ def run_dist_dataugV3(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start
|
||||||
accuracy, test_loss =test(model)
|
accuracy, test_loss =test(model)
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
print(model['data_aug']._data_augmentation)
|
|
||||||
#### Log ####
|
#### Log ####
|
||||||
param = [{'p': p.item(), 'm':model['data_aug']['mag'].item()} for p in model['data_aug']['prob']] if model['data_aug']._shared_mag else [{'p': p.item(), 'm': m.item()} for p, m in zip(model['data_aug']['prob'], model['data_aug']['mag'])]
|
param = [{'p': p.item(), 'm':model['data_aug']['mag'].item()} for p in model['data_aug']['prob']] if model['data_aug']._shared_mag else [{'p': p.item(), 'm': m.item()} for p, m in zip(model['data_aug']['prob'], model['data_aug']['mag'])]
|
||||||
data={
|
data={
|
||||||
|
|
|
@ -53,64 +53,91 @@ TF_dict={ #Dataugv5 #AutoAugment
|
||||||
#'Equalize': (lambda mag: None),
|
#'Equalize': (lambda mag: None),
|
||||||
}
|
}
|
||||||
'''
|
'''
|
||||||
|
# Dictionnary mapping tranformations identifiers to their function.
|
||||||
|
# Each value of the dict should be a lambda function taking a (batch of data, magnitude of transformations) tuple as input and returns a batch of data.
|
||||||
TF_dict={ #Dataugv5
|
TF_dict={ #Dataugv5
|
||||||
## Geometric TF ##
|
## Geometric TF ##
|
||||||
'Identity' : (lambda x, mag: x),
|
'Identity' : (lambda x, mag: x),
|
||||||
'FlipUD' : (lambda x, mag: flipUD(x)),
|
'FlipUD' : (lambda x, mag: flipUD(x)),
|
||||||
'FlipLR' : (lambda x, mag: flipLR(x)),
|
'FlipLR' : (lambda x, mag: flipLR(x)),
|
||||||
'Rotate': (lambda x, mag: rotate(x, angle=rand_floats(size=x.shape[0], mag=mag, maxval=30))),
|
'Rotate': (lambda x, mag: rotate(x, angle=rand_floats(size=x.shape[0], mag=mag, maxval=30))),
|
||||||
'TranslateX': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20), zero_pos=0))),
|
'TranslateX': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20), zero_pos=0))),
|
||||||
'TranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20), zero_pos=1))),
|
'TranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=20), zero_pos=1))),
|
||||||
'ShearX': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=0))),
|
'ShearX': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=0))),
|
||||||
'ShearY': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=1))),
|
'ShearY': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, maxval=0.3), zero_pos=1))),
|
||||||
|
|
||||||
## Color TF (Expect image in the range of [0, 1]) ##
|
## Color TF (Expect image in the range of [0, 1]) ##
|
||||||
'Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
'Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
||||||
'Color':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
'Color':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
||||||
'Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
'Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
||||||
'Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
'Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
|
||||||
'Posterize': (lambda x, mag: posterize(x, bits=rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))),#Perte du gradient
|
'Posterize': (lambda x, mag: posterize(x, bits=rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))),#Perte du gradient
|
||||||
'Solarize': (lambda x, mag: solarize(x, thresholds=rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Perte du gradient #=>Image entre [0,1]
|
'Solarize': (lambda x, mag: solarize(x, thresholds=rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Perte du gradient #=>Image entre [0,1]
|
||||||
|
|
||||||
#Color TF (Common mag scale)
|
#Color TF (Common mag scale)
|
||||||
'+Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
|
'+Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
|
||||||
'+Color':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
|
'+Color':(lambda x, mag: color(x, color_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
|
||||||
'+Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
|
'+Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
|
||||||
'+Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
|
'+Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
|
||||||
'-Contrast': (lambda x, mag: contrast(x, contrast_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
|
'-Contrast': (lambda x, mag: contrast(x, contrast_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
|
||||||
'-Color':(lambda x, mag: color(x, color_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
|
'-Color':(lambda x, mag: color(x, color_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
|
||||||
'-Brightness':(lambda x, mag: brightness(x, brightness_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
|
'-Brightness':(lambda x, mag: brightness(x, brightness_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
|
||||||
'-Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
|
'-Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
|
||||||
'=Posterize': (lambda x, mag: posterize(x, bits=invScale_rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))),#Perte du gradient
|
'=Posterize': (lambda x, mag: posterize(x, bits=invScale_rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))),#Perte du gradient
|
||||||
'=Solarize': (lambda x, mag: solarize(x, thresholds=invScale_rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Perte du gradient #=>Image entre [0,1]
|
'=Solarize': (lambda x, mag: solarize(x, thresholds=invScale_rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Perte du gradient #=>Image entre [0,1]
|
||||||
|
|
||||||
'BShearX': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=0.3*3, maxval=0.3*4), zero_pos=0))),
|
|
||||||
'BShearY': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=0.3*3, maxval=0.3*4), zero_pos=1))),
|
|
||||||
'BTranslateX': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=25, maxval=30), zero_pos=0))),
|
|
||||||
'BTranslateX-': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=-25, maxval=-30), zero_pos=0))),
|
|
||||||
'BTranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=25, maxval=30), zero_pos=1))),
|
|
||||||
'BTranslateY-': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=-25, maxval=-30), zero_pos=1))),
|
|
||||||
|
|
||||||
'BadContrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9*2, maxval=2*4))),
|
|
||||||
'BadBrightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*3))),
|
|
||||||
|
|
||||||
'Random':(lambda x, mag: torch.rand_like(x)),
|
## Bad Tranformations ##
|
||||||
'RandBlend': (lambda x, mag: blend(x,torch.rand_like(x), alpha=torch.tensor(0.7,device=mag.device).expand(x.shape[0]))),
|
# Bad Geometric TF #
|
||||||
|
'BShearX': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=0.3*3, maxval=0.3*4), zero_pos=0))),
|
||||||
#Non fonctionnel
|
'BShearY': (lambda x, mag: shear(x, shear=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=0.3*3, maxval=0.3*4), zero_pos=1))),
|
||||||
#'Auto_Contrast': (lambda mag: None), #Pas opti pour des batch (Super lent)
|
'BTranslateX': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=25, maxval=30), zero_pos=0))),
|
||||||
#'Equalize': (lambda mag: None),
|
'BTranslateX-': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=-25, maxval=-30), zero_pos=0))),
|
||||||
|
'BTranslateY': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=25, maxval=30), zero_pos=1))),
|
||||||
|
'BTranslateY-': (lambda x, mag: translate(x, translation=zero_stack(rand_floats(size=(x.shape[0],), mag=mag, minval=-25, maxval=-30), zero_pos=1))),
|
||||||
|
|
||||||
|
# Bad Color TF #
|
||||||
|
'BadContrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9*2, maxval=2*4))),
|
||||||
|
'BadBrightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.9, maxval=2*3))),
|
||||||
|
|
||||||
|
# Random TF #
|
||||||
|
'Random':(lambda x, mag: torch.rand_like(x)),
|
||||||
|
'RandBlend': (lambda x, mag: blend(x,torch.rand_like(x), alpha=torch.tensor(0.7,device=mag.device).expand(x.shape[0]))),
|
||||||
|
|
||||||
|
#Non fonctionnel
|
||||||
|
#'Auto_Contrast': (lambda mag: None), #Pas opti pour des batch (Super lent)
|
||||||
|
#'Equalize': (lambda mag: None),
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_no_mag={'Identity', 'FlipUD', 'FlipLR', 'Random', 'RandBlend'}
|
TF_no_mag={'Identity', 'FlipUD', 'FlipLR', 'Random', 'RandBlend'} #TF that don't have use for magnitude parameter.
|
||||||
TF_no_grad={'Solarize', 'Posterize', '=Solarize', '=Posterize'}
|
TF_no_grad={'Solarize', 'Posterize', '=Solarize', '=Posterize'} #TF which implemetation doesn't allow gradient propagaition.
|
||||||
TF_ignore_mag= TF_no_mag | TF_no_grad
|
TF_ignore_mag= TF_no_mag | TF_no_grad #TF for which magnitude should be ignored (Magnitude fixed).
|
||||||
|
|
||||||
def int_image(float_image): #ATTENTION : legere perte d'info (granularite : 1/256 = 0.0039)
|
def int_image(float_image):
|
||||||
return (float_image*255.).type(torch.uint8)
|
"""Convert a float Tensor/Image to an int Tensor/Image.
|
||||||
|
|
||||||
|
Be warry that this transformation isn't bijective, each conversion will result in small loss of information.
|
||||||
|
Granularity: 1/256 = 0.0039.
|
||||||
|
|
||||||
|
This will also result in the loss of the gradient associated to input as gradient cannot be tracked on int Tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
float_image (torch.float): Image tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(torch.uint8) Converted tensor.
|
||||||
|
"""
|
||||||
|
return (float_image*255.).type(torch.uint8)
|
||||||
|
|
||||||
def float_image(int_image):
|
def float_image(int_image):
|
||||||
return int_image.type(torch.float)/255.
|
"""Convert a int Tensor/Image to an float Tensor/Image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
int_image (torch.uint8): Image tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(torch.float) Converted tensor.
|
||||||
|
"""
|
||||||
|
return int_image.type(torch.float)/255.
|
||||||
|
|
||||||
#def rand_inverse(value):
|
#def rand_inverse(value):
|
||||||
# return value if random.random() < 0.5 else -value
|
# return value if random.random() < 0.5 else -value
|
||||||
|
@ -125,11 +152,22 @@ def float_image(int_image):
|
||||||
# if not minval : minval = -real_max
|
# if not minval : minval = -real_max
|
||||||
# return random.uniform(minval, real_max)
|
# return random.uniform(minval, real_max)
|
||||||
|
|
||||||
def rand_floats(size, mag, maxval, minval=None): #[(-maxval,minval), maxval]
|
def rand_floats(size, mag, maxval, minval=None):
|
||||||
real_mag = float_parameter(mag, maxval=maxval)
|
"""Generate a batch of random values.
|
||||||
if not minval : minval = -real_mag
|
|
||||||
#return random.uniform(minval, real_max)
|
Args:
|
||||||
return minval + (real_mag-minval) * torch.rand(size, device=mag.device) #[min_val, real_mag]
|
size (int): Number of value to generate.
|
||||||
|
mag (float): Level of the operation that will be between [PARAMETER_MIN, PARAMETER_MAX].
|
||||||
|
maxval (float): Maximum value that can be generated. This will be scaled to mag/PARAMETER_MAX.
|
||||||
|
minval (float): Minimum value that can be generated. (default: -maxval)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Generated batch of float values between [minval, maxval].
|
||||||
|
"""
|
||||||
|
real_mag = float_parameter(mag, maxval=maxval)
|
||||||
|
if not minval : minval = -real_mag
|
||||||
|
#return random.uniform(minval, real_max)
|
||||||
|
return minval + (real_mag-minval) * torch.rand(size, device=mag.device) #[min_val, real_mag]
|
||||||
|
|
||||||
def invScale_rand_floats(size, mag, maxval, minval):
|
def invScale_rand_floats(size, mag, maxval, minval):
|
||||||
#Mag=[0,PARAMETER_MAX] => [PARAMETER_MAX, 0] = [maxval, minval]
|
#Mag=[0,PARAMETER_MAX] => [PARAMETER_MAX, 0] = [maxval, minval]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue