Log et utils pour Dataugv5

This commit is contained in:
Harle, Antoine (Contracteur) 2019-11-18 16:48:51 -05:00
parent 860d9f1bbb
commit f4bdd9bca5
6 changed files with 57 additions and 45 deletions

View file

@ -531,7 +531,7 @@ class Data_augV4(nn.Module): #Transformations avec mask
return "Data_augV4(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_seqTF)
class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, shared_mag=True):
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_mag=True, shared_mag=True):
super(Data_augV5, self).__init__()
assert len(TF_dict)>0
@ -543,14 +543,14 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
self._N_seqTF = N_TF
self._shared_mag = shared_mag
self._fixed_mag = fixed_mag
#self._fixed_mag=5 #[0, PARAMETER_MAX]
self._params = nn.ParameterDict({
"prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Distribution prob uniforme
"mag" : nn.Parameter(torch.tensor(0.5) if shared_mag
"mag" : nn.Parameter(torch.tensor(0.5) if self._shared_mag
else torch.tensor(0.5).expand(self._nb_tf)), #[0, PARAMETER_MAX]/10
})
self._samples = []
self._mix_dist = False
@ -594,23 +594,19 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
if smp_x.shape[0]!=0: #if there's data to TF
magnitude=self._params["mag"] if self._shared_mag else self._params["mag"][tf_idx]
if self._fixed_mag: magnitude=magnitude.detach() #Fmodel tente systematiquement de tracker les gradient de tout les param
tf=self._TF[tf_idx]
#print(magnitude)
#x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude) # Refusionner eviter x[mask] : in place
#In place
#x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude)
#Out of place
smp_x = self._TF_dict[tf](x=smp_x, mag=magnitude)
idx= mask.nonzero()
#print('-'*8)
idx= idx.expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ...
#print(idx.shape, smp_x.shape)
#print(idx[0], tf_idx)
#print(smp_x[0,])
#x=x.view(-1,3*32*32)
#smp_x=smp_x.view(-1,3*32*32)
x=x.scatter(dim=0, index=idx, src=smp_x)
#x=x.view(-1,3,32,32)
#print(x[0,])
return x
@ -663,10 +659,13 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
return self._params[key]
def __str__(self):
mag_param='Mag'
if self._fixed_mag: mag_param+= 'Fx'
if self._shared_mag: mag_param+= 'Sh'
if not self._mix_dist:
return "Data_augV5(Uniform-%d TF x %d)" % (self._nb_tf, self._N_seqTF)
return "Data_augV5(Uniform-%dTFx%d-%s)" % (self._nb_tf, self._N_seqTF, mag_param)
else:
return "Data_augV5(Mix %.1f-%d TF x %d)" % (self._mix_factor, self._nb_tf, self._N_seqTF)
return "Data_augV5(Mix%.1f-%dTFx%d-%s)" % (self._mix_factor, self._nb_tf, self._N_seqTF, mag_param)
class Augmented_model(nn.Module):