Modif Dataugv6

This commit is contained in:
Harle, Antoine (Contracteur) 2019-12-02 06:37:19 -05:00
parent ebee1b789f
commit 3ec99bf729
6 changed files with 334 additions and 36 deletions

View file

@ -692,6 +692,208 @@ class Data_augV5(nn.Module): #Optimisation jointe (mag, proba)
else:
return "Data_augV5(Mix%.1f%s-%dTFx%d-%s)" % (self._mix_factor,dist_param, self._nb_tf, self._N_seqTF, mag_param)
import numpy as np
class Data_augV6(nn.Module): #Optimisation sequentielle
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True):
super(Data_augV6, self).__init__()
assert len(TF_dict)>0
self._data_augmentation = True
self._TF_dict = TF_dict
self._TF= list(self._TF_dict.keys())
self._nb_tf= len(self._TF)
self._N_seqTF = N_TF
self._shared_mag = shared_mag
self._fixed_mag = fixed_mag
self._TF_set_size=3
#if self._TF_set_size>self._nb_tf:
# print("Warning : TF sets size higher than number of TF. Reducing set size to %d"%self._nb_tf)
# self._TF_set_size=self._nb_tf
assert self._nb_tf>=self._TF_set_size
self._TF_sets=[]
for i in range(1,self._nb_tf):
for j in range(i,self._nb_tf):
if i!=j:
self._TF_sets+=[torch.tensor([0, i, j])]
#print(self._TF_sets)
#self._TF_sets=[torch.tensor([0, i, j]) for i in range(1,self._nb_tf)] #All VS Identity
self._TF_schedule = [list(range(len(self._TF_sets))) for _ in range(self._N_seqTF)]
for n_tf in range(self._N_seqTF) :
TF.random.shuffle(self._TF_schedule[n_tf])
#print(self._TF_schedule)
self._current_TF_idx=0 #random.randint
self._start_prob = 1/self._TF_set_size
self._params = nn.ParameterDict({
"prob": nn.Parameter(torch.tensor(self._start_prob).expand(self._nb_tf)), #Proba independantes
"mag" : nn.Parameter(torch.tensor(float(TF.PARAMETER_MAX)) if self._shared_mag
else torch.tensor(float(TF.PARAMETER_MAX)).expand(self._nb_tf)), #[0, PARAMETER_MAX]
})
#for t in TF.TF_no_mag: self._params['mag'][self._TF.index(t)].data-=self._params['mag'][self._TF.index(t)].data #Mag inutile pour les TF ignore_mag
#Distribution
self._fixed_prob=fixed_prob
self._samples = []
self._mix_dist = False
if mix_dist != 0.0:
self._mix_dist = True
self._mix_factor = max(min(mix_dist, 1.0), 0.0)
#Mag regularisation
if not self._fixed_mag:
if self._shared_mag :
self._reg_tgt = torch.tensor(TF.PARAMETER_MAX, dtype=torch.float) #Encourage amplitude max
else:
self._reg_mask=[self._TF.index(t) for t in self._TF if t not in TF.TF_ignore_mag]
self._reg_tgt=torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage amplitude max
def forward(self, x):
self._samples = []
if self._data_augmentation:# and TF.random.random() < 0.5:
device = x.device
batch_size, h, w = x.shape[0], x.shape[2], x.shape[3]
x = copy.deepcopy(x) #Evite de modifier les echantillons par reference (Problematique pour des utilisations paralleles)
for n_tf in range(self._N_seqTF):
tf_set = self._TF_sets[self._TF_schedule[n_tf][self._current_TF_idx]].to(device)
#print(n_tf, tf_set)
## Echantillonage ##
uniforme_dist = torch.ones(1,len(tf_set),device=device).softmax(dim=1)
if not self._mix_dist:
self._distrib = uniforme_dist
else:
prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
curr_prob = torch.index_select(prob, 0, tf_set)
curr_prob = curr_prob /sum(curr_prob) #Contrainte sum(p)=1
self._distrib = (self._mix_factor*curr_prob+(1-self._mix_factor)*uniforme_dist).softmax(dim=1) #Mix distrib reel / uniforme avec mix_factor
cat_distrib= Categorical(probs=torch.ones((batch_size, len(tf_set)), device=device)*self._distrib)
sample = cat_distrib.sample()
self._samples.append(sample)
## Transformations ##
x = self.apply_TF(x, sample)
return x
def apply_TF(self, x, sampled_TF):
device = x.device
batch_size, channels, h, w = x.shape
smps_x=[]
for sel_idx, tf_idx in enumerate(self._TF_sets[self._current_TF_idx]):
mask = sampled_TF==sel_idx #Create selection mask
smp_x = x[mask] #torch.masked_select() ? (NEcessite d'expand le mask au meme dim)
if smp_x.shape[0]!=0: #if there's data to TF
magnitude=self._params["mag"] if self._shared_mag else self._params["mag"][tf_idx]
if self._fixed_mag: magnitude=magnitude.detach() #Fmodel tente systematiquement de tracker les gradient de tout les param
tf=self._TF[tf_idx]
#print(magnitude)
#In place
#x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude)
#Out of place
smp_x = self._TF_dict[tf](x=smp_x, mag=magnitude)
idx= mask.nonzero()
idx= idx.expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ...
x=x.scatter(dim=0, index=idx, src=smp_x)
return x
def adjust_param(self, soft=False): #Detach from gradient ?
if not self._fixed_prob:
if soft :
self._params['prob'].data=F.softmax(self._params['prob'].data, dim=0) #Trop 'soft', bloque en dist uniforme si lr trop faible
else:
self._params['prob'].data = F.relu(self._params['prob'].data)
#self._params['prob'].data = self._params['prob'].clamp(min=0.0,max=1.0)
#self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Contrainte sum(p)=1
self._params['prob'].data[0]=self._start_prob #Fixe p identite
if not self._fixed_mag:
#self._params['mag'].data = self._params['mag'].data.clamp(min=0.0,max=TF.PARAMETER_MAX) #Bloque une fois au extreme
self._params['mag'].data = F.relu(self._params['mag'].data) - F.relu(self._params['mag'].data - TF.PARAMETER_MAX)
def loss_weight(self): #A verifier
if len(self._samples)==0 : return 1 #Pas d'echantillon = pas de ponderation
prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
#Plusieurs TF sequentielles (Attention ne prend pas en compte ordre !)
w_loss = torch.zeros((self._samples[0].shape[0],self._TF_set_size), device=self._samples[0].device)
for n_tf in range(self._N_seqTF):
tmp_w = torch.zeros(w_loss.size(),device=w_loss.device)
tmp_w.scatter_(dim=1, index=self._samples[n_tf].view(-1,1), value=1/self._N_seqTF)
tf_set = self._TF_sets[self._TF_schedule[n_tf][self._current_TF_idx]].to(prob.device)
curr_prob = torch.index_select(prob, 0, tf_set)
curr_prob = curr_prob /sum(curr_prob) #Contrainte sum(p)=1
#ATTENTION DISTRIB DIFFERENTE AVEC MIX
assert not self._mix_dist
w_loss += tmp_w * curr_prob /self._distrib #Ponderation par les proba (divisee par la distrib pour pas diminuer la loss)
w_loss = torch.sum(w_loss,dim=1)
return w_loss
def reg_loss(self, reg_factor=0.005):
if self._fixed_mag: # or self._fixed_prob: #Pas de regularisation si trop peu de DOF
return torch.tensor(0)
else:
#return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean')
params = self._params['mag'] if self._params['mag'].shape==torch.Size([]) else self._params['mag'][self._reg_mask]
return reg_factor * F.mse_loss(params, target=self._reg_tgt.to(params.device), reduction='mean')
def next_TF_set(self, idx=None):
if idx:
self._current_TF_idx=idx
else:
self._current_TF_idx+=1
if self._current_TF_idx== len(self._TF_schedule[0]):
self._current_TF_idx=0
#for n_tf in range(self._N_seqTF) :
# TF.random.shuffle(self._TF_schedule[n_tf])
#print(self._TF_schedule)
#print("Current TF :",self._TF_sets[self._current_TF_idx])
def train(self, mode=None):
if mode is None :
mode=self._data_augmentation
self.augment(mode=mode) #Inutile si mode=None
super(Data_augV6, self).train(mode)
def eval(self):
self.train(mode=False)
def augment(self, mode=True):
self._data_augmentation=mode
def __getitem__(self, key):
return self._params[key]
def __str__(self):
dist_param=''
if self._fixed_prob: dist_param+='Fx'
mag_param='Mag'
if self._fixed_mag: mag_param+= 'Fx'
if self._shared_mag: mag_param+= 'Sh'
if not self._mix_dist:
return "Data_augV6(Uniform%s-%dTF(%d)x%d-%s)" % (dist_param, self._nb_tf, self._TF_set_size, self._N_seqTF, mag_param)
else:
return "Data_augV6(Mix%.1f%s-%dTF(%d)x%d-%s)" % (self._mix_factor,dist_param, self._nb_tf, self._TF_set_size, self._N_seqTF, mag_param)
class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + rapide
def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mag=TF.PARAMETER_MAX):
super(RandAug, self).__init__()