Add RandAugment
parent 3c2022de32
commit 4a7e73088d
4 changed files with 249 additions and 37 deletions
higher/dataug.py (170 changed lines)
@@ -659,7 +659,7 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)
        return w_loss

    def reg_loss(self, reg_factor=0.005):
-       if self._fixed_mag:
+       if self._fixed_mag: # or self._fixed_prob: #No regularization if there are too few degrees of freedom
            return torch.tensor(0)
        else:
            #return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean')
@@ -692,6 +692,174 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)
        else:
            return "Data_augV5(Mix%.1f%s-%dTFx%d-%s)" % (self._mix_factor, dist_param, self._nb_tf, self._N_seqTF, mag_param)

class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh
    def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mag=TF.PARAMETER_MAX):
        super(RandAug, self).__init__()

        self._data_augmentation = True

        self._TF_dict = TF_dict
        self._TF = list(self._TF_dict.keys())
        self._nb_tf = len(self._TF)
        self._N_seqTF = N_TF

        self.mag = nn.Parameter(torch.tensor(float(mag)))
        self._params = nn.ParameterDict({
            "prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #not used
            "mag" : nn.Parameter(torch.tensor(float(mag))),
        })
        self._shared_mag = True
        self._fixed_mag = True

    def forward(self, x):
        if self._data_augmentation:# and TF.random.random() < 0.5:
            device = x.device
            batch_size, h, w = x.shape[0], x.shape[2], x.shape[3]

            x = copy.deepcopy(x) #Avoid modifying the samples by reference (problematic for parallel use)

            for _ in range(self._N_seqTF):
                ## Sampling ## == sampled_ops = np.random.choice(transforms, N)
                uniforme_dist = torch.ones(1, self._nb_tf, device=device).softmax(dim=1)
                cat_distrib = Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*uniforme_dist)
                sample = cat_distrib.sample()

                ## Transformations ##
                x = self.apply_TF(x, sample)
        return x

    def apply_TF(self, x, sampled_TF):
        smps_x = []

        for tf_idx in range(self._nb_tf):
            mask = sampled_TF==tf_idx #Create selection mask
            smp_x = x[mask] #torch.masked_select() ? (Would require expanding the mask to the same dims)

            if smp_x.shape[0]!=0: #if there's data to TF
                magnitude = self._params["mag"].detach()

                tf = self._TF[tf_idx]
                #print(magnitude)

                #In place
                x[mask] = self._TF_dict[tf](x=smp_x, mag=magnitude)

        return x

    def adjust_param(self, soft=False):
        pass #No parameter to optimize

    def loss_weight(self):
        return 1 #No sampling = no weighting

    def reg_loss(self, reg_factor=0.005):
        return torch.tensor(0) #No regularization

    def train(self, mode=None):
        if mode is None:
            mode = self._data_augmentation
        self.augment(mode=mode) #Useless if mode=None
        super(RandAug, self).train(mode)

    def eval(self):
        self.train(mode=False)

    def augment(self, mode=True):
        self._data_augmentation = mode

    def __getitem__(self, key):
        return self._params[key]

    def __str__(self):
        return "RandAug(%dTFx%d-Mag%d)" % (self._nb_tf, self._N_seqTF, self.mag)

class RandAugUDA(nn.Module): #RandAugment = UniformFx-MagFxSh (UDA-style op list)
    def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mag=TF.PARAMETER_MAX):
        super(RandAugUDA, self).__init__()

        self._data_augmentation = True

        self._TF_dict = TF_dict
        self._TF = list(self._TF_dict.keys())
        self._nb_tf = len(self._TF)
        self._N_seqTF = N_TF

        self.mag = nn.Parameter(torch.tensor(float(mag)))
        self._params = nn.ParameterDict({
            "prob": nn.Parameter(torch.tensor(0.5)),
            "mag" : nn.Parameter(torch.tensor(float(TF.PARAMETER_MAX))),
        })
        self._shared_mag = True
        self._fixed_mag = True

        self._op_list = []
        for tf in self._TF:
            for mag in torch.arange(0.1, self._params['mag'].item(), 0.1): #float magnitude steps 0.1, 0.2, ... up to mag
                self._op_list += [(tf, self._params['prob'], mag)]
        self._nb_op = len(self._op_list)

        print(self._op_list)

    def forward(self, x):
        if self._data_augmentation:# and TF.random.random() < 0.5:
            device = x.device
            batch_size, h, w = x.shape[0], x.shape[2], x.shape[3]

            x = copy.deepcopy(x) #Avoid modifying the samples by reference (problematic for parallel use)

            for _ in range(self._N_seqTF):
                ## Sampling ## == sampled_ops = np.random.choice(transforms, N)
                uniforme_dist = torch.ones(1, self._nb_op, device=device).softmax(dim=1)
                cat_distrib = Categorical(probs=torch.ones((batch_size, self._nb_op), device=device)*uniforme_dist)
                sample = cat_distrib.sample()

                ## Transformations ##
                x = self.apply_TF(x, sample)
        return x

    def apply_TF(self, x, sampled_TF):
        smps_x = []

        for op_idx in range(self._nb_op):
            mask = sampled_TF==op_idx #Create selection mask
            smp_x = x[mask] #torch.masked_select() ? (Would require expanding the mask to the same dims)

            if smp_x.shape[0]!=0: #if there's data to TF
                if TF.random.random() < self._op_list[op_idx][1]:
                    magnitude = self._op_list[op_idx][2]
                    tf = self._op_list[op_idx][0]

                    #In place
                    x[mask] = self._TF_dict[tf](x=smp_x, mag=magnitude)

        return x

    def adjust_param(self, soft=False):
        pass #No parameter to optimize

    def loss_weight(self):
        return 1 #No sampling = no weighting

    def reg_loss(self, reg_factor=0.005):
        return torch.tensor(0) #No regularization

    def train(self, mode=None):
        if mode is None:
            mode = self._data_augmentation
        self.augment(mode=mode) #Useless if mode=None
        super(RandAugUDA, self).train(mode)

    def eval(self):
        self.train(mode=False)

    def augment(self, mode=True):
        self._data_augmentation = mode

    def __getitem__(self, key):
        return self._params[key]

    def __str__(self):
        return "RandAugUDA(%dTFx%d-Mag%d)" % (self._nb_tf, self._N_seqTF, self.mag)

class Augmented_model(nn.Module):
    def __init__(self, data_augmenter, model):
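
For context, a minimal usage sketch of the new RandAug module (not part of the diff above; it assumes dataug.py's existing imports, i.e. torch and the TF transformation module, are in scope, and that x is an image batch of shape (B, C, H, W)):

# Hypothetical example: apply RandAug to a batch of images.
aug = RandAug(TF_dict=TF.TF_dict, N_TF=2, mag=TF.PARAMETER_MAX) # 2 uniformly sampled TF per image, fixed shared magnitude
aug.augment(mode=True)   # enable augmentation; forward() returns x unchanged otherwise
x_aug = aug(x)           # at each of the N_TF steps, one TF is sampled per image and applied at magnitude mag
print(aug)               # e.g. "RandAug(<nb_tf>TFx2-Mag<mag>)"

As with Data_augV5 elsewhere in this file, the augmenter can presumably also be placed in front of a task model via Augmented_model(data_augmenter, model).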