mirror of
https://github.com/AntoineHX/BU_Stoch_pool.git
synced 2025-05-04 09:40:46 +02:00
added faster BU
This commit is contained in:
parent
286921f8a0
commit
e2db0e6057
2 changed files with 112 additions and 21 deletions
|
@ -31,7 +31,7 @@ class SConv2dAvg(nn.Module):
|
|||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1,ceil_mode=True, bias = True):
|
||||
super(SConv2dAvg, self).__init__()
|
||||
conv = nn.Conv2d(in_channels, out_channels, kernel_size)
|
||||
self.deconv = nn.ConvTranspose2d(1, 1, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=False, dilation=1, padding_mode='zeros')
|
||||
self.deconv = nn.ConvTranspose2d(1, 1, kernel_size, stride=1, padding=padding, output_padding=0, groups=1, bias=False, dilation=1, padding_mode='zeros')
|
||||
nn.init.constant_(self.deconv.weight, 1)
|
||||
self.pooldeconv = nn.ConvTranspose2d(1, 1, kernel_size=stride,padding=0,stride=stride, output_padding=0, groups=1, bias=False, dilation=1, padding_mode='zeros')
|
||||
nn.init.constant_(self.pooldeconv.weight, 1)
|
||||
|
@ -47,7 +47,7 @@ class SConv2dAvg(nn.Module):
|
|||
self.ceil_mode = ceil_mode
|
||||
|
||||
|
||||
def forward_fast(self, input, index=-torch.ones(1), mask=-torch.ones(1,1),stoch=True,stride=-1): #ceil_mode = True not right
|
||||
def forward(self, input, index=-torch.ones(1), mask=-torch.ones(1,1),stoch=True,stride=-1): #ceil_mode = True not right
|
||||
device=input.device
|
||||
if stride==-1:
|
||||
stride = self.stride #if stride not defined use self.stride
|
||||
|
@ -60,14 +60,20 @@ class SConv2dAvg(nn.Module):
|
|||
unfold = torch.nn.Unfold(kernel_size=(kh, kw), dilation=self.dilation, padding=self.padding, stride=1)
|
||||
inp_unf = unfold(input) #transform into a matrix (batch_size, in_channels*kh*kw,afterconv_h,afterconv_w)
|
||||
|
||||
if index[0]==-1 and stride!=1: #or stride!=1:
|
||||
index,mask = self.sample(in_h,in_w,batch_size,device,mask)
|
||||
|
||||
if stride!=1:
|
||||
if len(index.shape)==1: #or stride!=1:
|
||||
index,mask = self.sample(in_h,in_w,batch_size,device,mask)
|
||||
|
||||
if mask[0,0]==-1:# in case of not given mask use only sampled selection
|
||||
#inp_unf = inp_unf[:,:,rng_h,rng_w].view(batch_size,in_channels*kh*kw,-1)
|
||||
index = index.repeat(batch_size,in_channels*kh*kw,1,1)
|
||||
inp_unf = torch.gather(inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h*afterconv_w),2,index.view(batch_size,in_channels*kh*kw,out_h*out_w)).view(batch_size,in_channels*kh*kw,out_h*out_w)
|
||||
else:#in case of a valid mask use selection only on the mask locations
|
||||
inp_unf = inp_unf[:,:,rng_h[mask>0],rng_w[mask>0]]
|
||||
#inp_unf = inp_unf[index[:,:,mask>0]]
|
||||
#mindex = index[mask>0]
|
||||
mindex = torch.masked_select(index, mask>0)
|
||||
index = mindex.repeat(batch_size,in_channels*kh*kw,1)
|
||||
inp_unf = torch.gather(inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h*afterconv_w),2,index.view(batch_size,in_channels*kh*kw,index.shape[2])).view(batch_size,in_channels*kh*kw,index.shape[2])
|
||||
|
||||
#Matrix mul
|
||||
if self.bias is None:
|
||||
|
@ -89,7 +95,9 @@ class SConv2dAvg(nn.Module):
|
|||
out = F.avg_pool2d(out,self.stride,ceil_mode=True)
|
||||
else:#in case of mask
|
||||
out = torch.zeros(batch_size, out_channels,out_h,out_w,device=device)
|
||||
#out = torch.gather(out.view(batch_size,in_channels*kh*kw,afterconv_h*afterconv_w),2,index.view(batch_size,in_channels*kh*kw,index.shape[2])).view(batch_size,in_channels*kh*kw,index.shape[2])
|
||||
out[:,:,mask>0] = out_unf
|
||||
#out.masked_scatter_(mask>0, out_unf)
|
||||
return out
|
||||
|
||||
|
||||
|
@ -167,7 +175,7 @@ class SConv2dAvg(nn.Module):
|
|||
return out
|
||||
|
||||
|
||||
def forward(self, input, selh=-torch.ones(1,1), selw=-torch.ones(1,1), mask=-torch.ones(1,1),stoch=True,stride=-1):
|
||||
def forward_slow(self, input, selh=-torch.ones(1,1), selw=-torch.ones(1,1), mask=-torch.ones(1,1),stoch=True,stride=-1):
|
||||
device=input.device
|
||||
if stride==-1:
|
||||
stride = self.stride #if stride not defined use self.stride
|
||||
|
@ -373,23 +381,28 @@ class SConv2dAvg(nn.Module):
|
|||
resth = (out_h*stride)-afterconv_h
|
||||
restw = (out_w*stride)-afterconv_w
|
||||
if resth!=0:
|
||||
print("stride",stride,"str-rest",stride-resth,stride-restw)
|
||||
print('before',sel[-1],sel[:,-1])
|
||||
sel[-1] = (sel[-1]//stride)%(stride-resth)*stride+(sel[-1]%stride)
|
||||
sel[:,-1] = (sel[:,-1]%stride)%(stride-restw)+sel[:,-1]//stride*stride
|
||||
print('after',sel[-1],sel[:,-1])
|
||||
input()
|
||||
|
||||
rng = torch.arange(0,out_h*stride*out_w*stride,stride*stride,device=device).view(out_h,out_w)
|
||||
index = sel+rng
|
||||
index = index.repeat(batch_size,in_channels*kh*kw,1,1)
|
||||
#index = index.repeat(batch_size,in_channels*kh*kw,1,1)
|
||||
|
||||
#inp_unf = torch.gather(inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h*afterconv_w),2,index.view(batch_size,in_channels*kh*kw,out_h*out_w)).view(batch_size,in_channels*kh*kw,out_h*out_w)
|
||||
|
||||
if mask[0,0]!=-1:
|
||||
maskh = (out_h)*stride
|
||||
maskw = (out_w)*stride
|
||||
nmask = torch.zeros((maskh,maskw),device=device)
|
||||
nmask[rng_h,rng_w] = 1
|
||||
nmask = torch.zeros((maskh,maskw),device=device).view(-1)
|
||||
#inp_unf = torch.gather(inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h*afterconv_w),2,index.view(batch_size,in_channels*kh*kw,out_h*out_w)).view(batch_size,in_channels*kh*kw,out_h*out_w)
|
||||
nmask[index] = 1
|
||||
#rmask = mask * nmask
|
||||
dmask = self.pooldeconv(mask.float().view(1,1,mask.shape[0],mask.shape[1]))
|
||||
rmask = nmask * dmask
|
||||
rmask = nmask.view(1,1,maskh,maskw) * dmask
|
||||
#rmask = rmask[:,:,:out_h,:out_w]
|
||||
# print('rmask', rmask.shape)
|
||||
fmask = self.deconv(rmask)
|
||||
|
@ -398,6 +411,22 @@ class SConv2dAvg(nn.Module):
|
|||
|
||||
return index,mask#.long()
|
||||
|
||||
def get_mask(self,in_h,in_w,batch_size,device,mask=-torch.ones(1,1)):
|
||||
maskh = (out_h)*stride
|
||||
maskw = (out_w)*stride
|
||||
nmask = torch.zeros((maskh,maskw),device=device).view(-1)
|
||||
#inp_unf = torch.gather(inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h*afterconv_w),2,index.view(batch_size,in_channels*kh*kw,out_h*out_w)).view(batch_size,in_channels*kh*kw,out_h*out_w)
|
||||
nmask[index[0,0]] = 1
|
||||
#rmask = mask * nmask
|
||||
dmask = self.pooldeconv(mask.float().view(1,1,mask.shape[0],mask.shape[1]))
|
||||
rmask = nmask.view(1,1,maskh,maskw) * dmask
|
||||
#rmask = rmask[:,:,:out_h,:out_w]
|
||||
# print('rmask', rmask.shape)
|
||||
fmask = self.deconv(rmask)
|
||||
# print('fmask', fmask.shape)
|
||||
mask = fmask[0,0].long()
|
||||
return mask
|
||||
|
||||
|
||||
def get_size(self,in_h,in_w,stride=-1):
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue