Changement mesure memoire + Tests solarize differentiable

This commit is contained in:
Harle, Antoine (Contracteur) 2020-02-10 14:36:12 -05:00
parent 6277e268c1
commit 7d5aa7c6fb
4 changed files with 102 additions and 66 deletions

View file

@ -346,10 +346,14 @@ def posterize(x, bits):
return float_image(x & mask)
import torch.nn.functional as F
def solarize(x, thresholds):
"""Invert all pixel values above a threshold.
Be warry that the use of the inequality (x>tresholds) block the gradient propagation.
TODO : Make differentiable.
Args:
x (Tensor): Batch of images.
thresholds (Tensor): All pixels above this level are inverted
@ -386,6 +390,25 @@ def solarize(x, thresholds):
#x[mask]=inv_x
#x=x.masked_scatter(mask, inv_x)
#Differentiable (/Thresholds) ?
#inv_x_bT= F.relu(x) - F.relu(x - thresholds)
#inv_x_aT= 1-x #Besoin thresholds
#print('-'*10)
#print(thresholds[0])
#print(x[0])
#print(inv_x_bT[0])
#print(inv_x_aT[0])
#x=torch.where(x>thresholds,inv_x_aT, inv_x_bT)
#print(torch.allclose(x, x+0.001, atol=1e-3))
#print(torch.allclose(x, sol_x, atol=1e-2))
#print(torch.eq(x,sol_x)[0])
#print(x[0])
#print(sol_x[0])
#'''
return x
def blend(x,y,alpha):