diff --git a/higher/dataug.py b/higher/dataug.py
index f2b1d6a..ee21d00 100644
--- a/higher/dataug.py
+++ b/higher/dataug.py
@@ -531,7 +531,7 @@ class Data_augV4(nn.Module): #Transformations with mask
         return "Data_augV4(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_seqTF)
 
 class Data_augV5(nn.Module): #Joint optimization (mag, prob)
-    def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, glob_mag=True):
+    def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, shared_mag=True):
         super(Data_augV5, self).__init__()
         assert len(TF_dict)>0
 
@@ -542,11 +542,13 @@ class Data_augV5(nn.Module): #Joint optimization (mag, prob)
 
         self._nb_tf= len(self._TF)
         self._N_seqTF = N_TF
+        self._shared_mag = shared_mag
 
         #self._fixed_mag=5 #[0, PARAMETER_MAX]
         self._params = nn.ParameterDict({
             "prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Uniform probability distribution
-            "mag" : nn.Parameter(torch.tensor(0.5).expand(self._nb_tf) if glob_mag else torch.tensor(0.5).repeat(self._nb_tf)) #[0, PARAMETER_MAX]/10
+            "mag" : nn.Parameter(torch.tensor(0.5) if shared_mag #One scalar magnitude shared by all TF
+                                 else torch.tensor(0.5).repeat(self._nb_tf)), #One magnitude per TF; repeat(), not expand(), so each element has its own storage and gradient. [0, PARAMETER_MAX]
         })
 
         self._samples = []
@@ -591,7 +593,7 @@ class Data_augV5(nn.Module): #Joint optimization (mag, prob)
                 smp_x = x[mask] #torch.masked_select() ? (requires expanding the mask to the same dims)
 
                 if smp_x.shape[0]!=0: #if there's data to TF
-                    magnitude=self._params["mag"][tf_idx]*10
+                    magnitude=self._params["mag"] if self._shared_mag else self._params["mag"][tf_idx]
                     tf=self._TF[tf_idx]
 
                     #print(magnitude)
diff --git a/higher/test_dataug.py b/higher/test_dataug.py
index 1fb9849..2b9ab7a 100644
--- a/higher/test_dataug.py
+++ b/higher/test_dataug.py
@@ -68,7 +68,7 @@ if __name__ == "__main__":
     t0 = time.process_time()
     tf_dict = {k: TF.TF_dict[k] for k in tf_names}
     #tf_dict = TF.TF_dict
-    aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=1, mix_dist=0.5, glob_mag=False), LeNet(3,10)).to(device)
+    aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=1, mix_dist=0.5, shared_mag=True), LeNet(3,10)).to(device)
     #aug_model = Augmented_model(Data_augV4(TF_dict=tf_dict, N_TF=2, mix_dist=0.0), WideResNet(num_classes=10, wrn_size=160)).to(device)
     print(str(aug_model), 'on', device_name)
     #run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
diff --git a/higher/transformations.py b/higher/transformations.py
index a172313..fa4b4b8 100644
--- a/higher/transformations.py
+++ b/higher/transformations.py
@@ -86,7 +86,7 @@ def zero_stack(tensor, zero_pos):
         raise Exception("Invalid zero_pos : ", zero_pos)
 
 #https://github.com/tensorflow/models/blob/fc2056bce6ab17eabdc139061fef8f4f2ee763ec/research/autoaugment/augmentation_transforms.py#L137
-PARAMETER_MAX = 10 # What is the max 'level' a transform could be predicted
+PARAMETER_MAX = 1 # What is the max 'level' a transform could be predicted
 def float_parameter(level, maxval):
   """Helper function to scale `val` between 0 and maxval .
   Args:
@@ -98,7 +98,7 @@ def float_parameter(level, maxval):
   """
   #return float(level) * maxval / PARAMETER_MAX
-  return (level * maxval / PARAMETER_MAX)#.to(torch.float32)
+  return (level * maxval / PARAMETER_MAX)#.to(torch.float)
 
 def int_parameter(level, maxval): #Loses the gradient (int conversion)
   """Helper function to scale `val` between 0 and maxval .
@@ -135,11 +135,11 @@ def flipUD(x):
     return kornia.warp_perspective(x, M, dsize=(h, w))
 
 def rotate(x, angle):
-    return kornia.rotate(x, angle=angle.type(torch.float32)) #Kornia does not support ints
+    return kornia.rotate(x, angle=angle.type(torch.float)) #Kornia does not support ints
 
 def translate(x, translation):
     #print(translation)
-    return kornia.translate(x, translation=translation.type(torch.float32)) #Kornia does not support ints
+    return kornia.translate(x, translation=translation.type(torch.float)) #Kornia does not support ints
 
 def shear(x, shear):
     return kornia.shear(x, shear=shear)
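
For context on the change above: the glob_mag flag becomes shared_mag, meaning either one scalar magnitude parameter shared by all transformations, or one independent magnitude per transformation; with PARAMETER_MAX = 1 the magnitude is consumed directly in [0, 1], so the old "*10" rescaling disappears. The sketch below illustrates that pattern in isolation under stated assumptions: MagSketch, nb_tf, and magnitude() are hypothetical stand-ins for illustration, not the actual Data_augV5 code.

# Minimal sketch of the shared vs. per-TF magnitude pattern
# (MagSketch is a hypothetical stand-in for Data_augV5).
import torch
import torch.nn as nn

class MagSketch(nn.Module):
    def __init__(self, nb_tf, shared_mag=True):
        super().__init__()
        self._shared_mag = shared_mag
        self._params = nn.ParameterDict({
            # Shared: a single 0-dim scalar used for every TF.
            # Per-TF: repeat() gives each TF its own storage and gradient;
            # expand() would alias one storage element across all TFs.
            "mag": nn.Parameter(torch.tensor(0.5) if shared_mag
                                else torch.tensor(0.5).repeat(nb_tf)),
        })

    def magnitude(self, tf_idx):
        # With PARAMETER_MAX = 1 the value is used directly in [0, 1].
        return self._params["mag"] if self._shared_mag else self._params["mag"][tf_idx]

shared = MagSketch(nb_tf=5, shared_mag=True)
per_tf = MagSketch(nb_tf=5, shared_mag=False)
print(shared.magnitude(0))  # same scalar parameter whatever tf_idx is
print(per_tf.magnitude(3))  # the 4th TF's own, independently trained magnitude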