mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Dataugv5- Modification des TF pour propagation du gradient (mag)
This commit is contained in:
parent
05f81787d6
commit
994d657a28
5 changed files with 94 additions and 21 deletions
|
@ -583,19 +583,33 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
|
|||
|
||||
def apply_TF(self, x, sampled_TF):
|
||||
device = x.device
|
||||
batch_size, channels, h, w = x.shape
|
||||
smps_x=[]
|
||||
masks=[]
|
||||
|
||||
for tf_idx in range(self._nb_tf):
|
||||
mask = sampled_TF==tf_idx #Create selection mask
|
||||
smp_x = x[mask] #torch.masked_select() ?
|
||||
smp_x = x[mask] #torch.masked_select() ? (NEcessite d'expand le mask au meme dim)
|
||||
|
||||
if smp_x.shape[0]!=0: #if there's data to TF
|
||||
magnitude=self._params["mag"][tf_idx]*10
|
||||
tf=self._TF[tf_idx]
|
||||
#print(magnitude)
|
||||
|
||||
x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude) # Refusionner eviter x[mask] : in place
|
||||
|
||||
#x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude) # Refusionner eviter x[mask] : in place
|
||||
smp_x = self._TF_dict[tf](x=smp_x, mag=magnitude)
|
||||
|
||||
idx= mask.nonzero()
|
||||
#print('-'*8)
|
||||
idx= idx.expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ...
|
||||
#print(idx.shape, smp_x.shape)
|
||||
#print(idx[0], tf_idx)
|
||||
#print(smp_x[0,])
|
||||
#x=x.view(-1,3*32*32)
|
||||
#smp_x=smp_x.view(-1,3*32*32)
|
||||
x=x.scatter(dim=0, index=idx, src=smp_x)
|
||||
#x=x.view(-1,3,32,32)
|
||||
#print(x[0,])
|
||||
|
||||
return x
|
||||
|
||||
def adjust_prob(self, soft=False): #Detach from gradient ?
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue