mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Modif solarize (Tjrs pas differentiable...)
This commit is contained in:
parent
4a7e73088d
commit
d822f8f92e
5 changed files with 56 additions and 42 deletions
|
@ -298,12 +298,12 @@ def equalize(x): #PAS OPTIMISE POUR DES BATCH
|
|||
def solarize(x, thresholds): #PAS OPTIMISE POUR DES BATCH
|
||||
# Optimisation : Mask direct sur toute les donnees (Mask = (B,C,H,W)> (B))
|
||||
batch_size, channels, h, w = x.shape
|
||||
imgs=[]
|
||||
for idx, t in enumerate(thresholds): #Operation par image
|
||||
mask = x[idx] > t #Perte du gradient
|
||||
#imgs=[]
|
||||
#for idx, t in enumerate(thresholds): #Operation par image
|
||||
# mask = x[idx] > t #Perte du gradient
|
||||
#In place
|
||||
inv_x = 1-x[idx][mask]
|
||||
x[idx][mask]=inv_x
|
||||
# inv_x = 1-x[idx][mask]
|
||||
# x[idx][mask]=inv_x
|
||||
#
|
||||
|
||||
#Out of place
|
||||
|
@ -316,6 +316,18 @@ def solarize(x, thresholds): #PAS OPTIMISE POUR DES BATCH
|
|||
#idxs=idxs.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ...
|
||||
#x=x.scatter(dim=0, index=idxs, src=torch.stack(imgs))
|
||||
#
|
||||
|
||||
thresholds = thresholds.unsqueeze(dim=1).expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #Il y a forcement plus simple ...
|
||||
#print(thresholds.grad_fn)
|
||||
x=torch.where(x>thresholds,1-x, x)
|
||||
#print(mask.grad_fn)
|
||||
|
||||
#x=x.min(thresholds)
|
||||
#inv_x = 1-x[mask]
|
||||
#x=x.where(x<thresholds,1-x)
|
||||
#x[mask]=inv_x
|
||||
#x=x.masked_scatter(mask, inv_x)
|
||||
|
||||
return x
|
||||
|
||||
#https://github.com/python-pillow/Pillow/blob/9c78c3f97291bd681bc8637922d6a2fa9415916c/src/PIL/Image.py#L2818
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue