Mirror of https://github.com/AntoineHX/smart_augmentation.git, synced 2025-05-04 20:20:46 +02:00

Commit 860d9f1bbb (parent 9ad3f0453b): Change for shared_mag

3 changed files with 10 additions and 8 deletions
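In short: Data_augV5 replaces the glob_mag flag with shared_mag. When it is set, a single 0-dim magnitude parameter is shared by all transformations instead of one entry per transformation; the `*10` rescaling of the magnitude is dropped and PARAMETER_MAX goes from 10 to 1 to compensate; and the explicit torch.float32 casts are renamed to their torch.float alias.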
@@ -531,7 +531,7 @@ class Data_augV4(nn.Module): #Transformations with mask
         return "Data_augV4(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_seqTF)
 
 class Data_augV5(nn.Module): #Joint optimization (mag, proba)
-    def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, glob_mag=True):
+    def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, shared_mag=True):
         super(Data_augV5, self).__init__()
         assert len(TF_dict)>0
 
@@ -542,11 +542,13 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)
         self._nb_tf= len(self._TF)
 
         self._N_seqTF = N_TF
+        self._shared_mag = shared_mag
 
         #self._fixed_mag=5 #[0, PARAMETER_MAX]
         self._params = nn.ParameterDict({
             "prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Uniform prob distribution
-            "mag" : nn.Parameter(torch.tensor(0.5).expand(self._nb_tf) if glob_mag else torch.tensor(0.5).repeat(self._nb_tf)) #[0, PARAMETER_MAX]/10
+            "mag" : nn.Parameter(torch.tensor(0.5) if shared_mag
+                else torch.tensor(0.5).expand(self._nb_tf)), #[0, PARAMETER_MAX]/10
         })
 
         self._samples = []
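The heart of the change is this parameter: with shared_mag=True, "mag" is a single 0-dim tensor used by every transformation; otherwise it is an expand() view over one storage element (the old glob_mag code chose between expand() and repeat()). A minimal sketch of the expand()/repeat() difference, with illustrative values:

    import torch

    base = torch.tensor(0.5)

    expanded = base.expand(3)   # view: no copy, all 3 entries alias base's storage
    repeated = base.repeat(3)   # real copy: 3 independent values

    base.fill_(0.7)             # mutate the underlying scalar in place
    print(expanded)             # tensor([0.7000, 0.7000, 0.7000]) -- follows base
    print(repeated)             # tensor([0.5000, 0.5000, 0.5000]) -- unaffected

Storing a plain 0-dim parameter for the shared case avoids relying on a stride-0 view for what is conceptually one scalar.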
@@ -591,7 +593,7 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)
             smp_x = x[mask] #torch.masked_select() ? (Requires expanding the mask to the same dim)
 
             if smp_x.shape[0]!=0: #if there's data to TF
-                magnitude=self._params["mag"][tf_idx]*10
+                magnitude=self._params["mag"] if self._shared_mag else self._params["mag"][tf_idx]
                 tf=self._TF[tf_idx]
                 #print(magnitude)
 
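At application time, the lookup mirrors that layout: the shared scalar is used as-is for every transformation, while the per-TF vector is indexed by tf_idx. The old `*10` disappears here; the PARAMETER_MAX change further down compensates for it. A self-contained sketch (nb_tf and tf_idx are illustrative):

    import torch
    import torch.nn as nn

    nb_tf, tf_idx = 3, 1
    shared_mag = True

    mag = (nn.Parameter(torch.tensor(0.5)) if shared_mag
           else nn.Parameter(torch.tensor(0.5).expand(nb_tf)))

    # same selection as the changed line: one scalar for all TFs,
    # or one entry per transformation
    magnitude = mag if shared_mag else mag[tf_idx]
    print(magnitude.shape)  # torch.Size([]) in the shared case

Because every transformation in a sequence reads the same parameter when shared_mag=True, its gradient accumulates the contributions of all of them in a single backward pass. The next hunk updates the call site in the second changed file.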
@@ -68,7 +68,7 @@ if __name__ == "__main__":
     t0 = time.process_time()
     tf_dict = {k: TF.TF_dict[k] for k in tf_names}
     #tf_dict = TF.TF_dict
-    aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=1, mix_dist=0.5, glob_mag=False), LeNet(3,10)).to(device)
+    aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=1, mix_dist=0.5, shared_mag=True), LeNet(3,10)).to(device)
     #aug_model = Augmented_model(Data_augV4(TF_dict=tf_dict, N_TF=2, mix_dist=0.0), WideResNet(num_classes=10, wrn_size=160)).to(device)
     print(str(aug_model), 'on', device_name)
     #run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
 
@@ -86,7 +86,7 @@ def zero_stack(tensor, zero_pos):
         raise Exception("Invalid zero_pos : ", zero_pos)
 
 #https://github.com/tensorflow/models/blob/fc2056bce6ab17eabdc139061fef8f4f2ee763ec/research/autoaugment/augmentation_transforms.py#L137
-PARAMETER_MAX = 10 # What is the max 'level' a transform could be predicted
+PARAMETER_MAX = 1 # What is the max 'level' a transform could be predicted
 def float_parameter(level, maxval):
     """Helper function to scale `val` between 0 and maxval .
     Args:
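Lowering PARAMETER_MAX from 10 to 1 in this third changed file is the counterpart of removing the `*10` above: the magnitude already lives in [0, 1], so it can feed float_parameter directly. A worked check, assuming an illustrative maxval of 30 (e.g. a rotation range):

    # old pipeline: mag in [0, 1], level = mag * 10, PARAMETER_MAX = 10
    mag, maxval = 0.5, 30
    old_value = (mag * 10) * maxval / 10   # 15.0

    # new pipeline: level = mag directly, PARAMETER_MAX = 1
    new_value = mag * maxval / 1           # 15.0

    print(old_value == new_value)          # True: same scaling, one rescale fewer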
@@ -98,7 +98,7 @@ def float_parameter(level, maxval):
     """
 
     #return float(level) * maxval / PARAMETER_MAX
-    return (level * maxval / PARAMETER_MAX)#.to(torch.float32)
+    return (level * maxval / PARAMETER_MAX)#.to(torch.float)
 
 def int_parameter(level, maxval): #Loses gradient
     """Helper function to scale `val` between 0 and maxval .
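The `#Loses gradient` note on int_parameter is the reason float_parameter stays on the differentiable path: an integer cast is piecewise constant, so autograd cannot propagate through it. A small demonstration:

    import torch

    level = torch.tensor(0.5, requires_grad=True)

    # float path: the gradient flows (mirrors float_parameter with PARAMETER_MAX = 1)
    y = level * 30 / 1
    y.backward()
    print(level.grad)       # tensor(30.)

    # int path: the cast cuts the graph (mirrors int_parameter)
    z = (level * 30 / 1).int()
    print(z.requires_grad)  # False: no gradient can reach `level` through z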
@@ -135,11 +135,11 @@ def flipUD(x):
     return kornia.warp_perspective(x, M, dsize=(h, w))
 
 def rotate(x, angle):
-    return kornia.rotate(x, angle=angle.type(torch.float32)) #Kornia does not support int
+    return kornia.rotate(x, angle=angle.type(torch.float)) #Kornia does not support int
 
 def translate(x, translation):
     #print(translation)
-    return kornia.translate(x, translation=translation.type(torch.float32)) #Kornia does not support int
+    return kornia.translate(x, translation=translation.type(torch.float)) #Kornia does not support int
 
 def shear(x, shear):
     return kornia.shear(x, shear=shear)
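The float32 -> float renames in this last hunk are purely cosmetic: torch.float is the alias of torch.float32, so the casts kornia receives are unchanged (its warp kernels require floating-point arguments, hence the casts in the first place). A quick check:

    import torch

    print(torch.float is torch.float32)   # True: same dtype object
    angle = torch.tensor(45)              # integer tensor by default
    print(angle.type(torch.float).dtype)  # torch.float32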