Modif solarize (Tjrs pas differentiable...)

This commit is contained in:
Harle, Antoine (Contracteur) 2019-11-27 17:19:51 -05:00
parent 4a7e73088d
commit d822f8f92e
5 changed files with 56 additions and 42 deletions

View file

@ -298,12 +298,12 @@ def equalize(x): #PAS OPTIMISE POUR DES BATCH
def solarize(x, thresholds): #PAS OPTIMISE POUR DES BATCH
# Optimisation : Mask direct sur toute les donnees (Mask = (B,C,H,W)> (B))
batch_size, channels, h, w = x.shape
imgs=[]
for idx, t in enumerate(thresholds): #Operation par image
mask = x[idx] > t #Perte du gradient
#imgs=[]
#for idx, t in enumerate(thresholds): #Operation par image
# mask = x[idx] > t #Perte du gradient
#In place
inv_x = 1-x[idx][mask]
x[idx][mask]=inv_x
# inv_x = 1-x[idx][mask]
# x[idx][mask]=inv_x
#
#Out of place
@ -316,6 +316,18 @@ def solarize(x, thresholds): #PAS OPTIMISE POUR DES BATCH
#idxs=idxs.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ...
#x=x.scatter(dim=0, index=idxs, src=torch.stack(imgs))
#
thresholds = thresholds.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ...
#print(thresholds.grad_fn)
x=torch.where(x>thresholds,1-x, x)
#print(mask.grad_fn)
#x=x.min(thresholds)
#inv_x = 1-x[mask]
#x=x.where(x<thresholds,1-x)
#x[mask]=inv_x
#x=x.masked_scatter(mask, inv_x)
return x
#https://github.com/python-pillow/Pillow/blob/9c78c3f97291bd681bc8637922d6a2fa9415916c/src/PIL/Image.py#L2818