mirror of
https://github.com/AntoineHX/BU_Stoch_pool.git
synced 2025-05-04 09:40:46 +02:00
added faster filtering and convolution, but not working yet for BU
This commit is contained in:
parent
9d68bc30bd
commit
f7436d0002
5 changed files with 295 additions and 11 deletions
|
@ -4,6 +4,7 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
|
||||
import math
|
||||
import opt_einsum as oe
|
||||
|
||||
class SConv2dStride(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1,ceil_mode=True,bias=False):
|
||||
|
@ -54,6 +55,73 @@ class SConv2dAvg(nn.Module):
|
|||
batch_size, in_channels, in_h, in_w = input.shape
|
||||
out_channels, in_channels, kh, kw = self.weight.shape
|
||||
|
||||
afterconv_h = in_h+2*self.padding-(kh-1) #size after conv
|
||||
afterconv_w = in_w+2*self.padding-(kw-1)
|
||||
if self.ceil_mode: #ceil_mode = talse default mode for strided conv
|
||||
out_h = math.ceil(afterconv_h/stride)
|
||||
out_w = math.ceil(afterconv_w/stride)
|
||||
else: #ceil_mode = false default mode for pooling
|
||||
out_h = math.floor(afterconv_h/stride)
|
||||
out_w = math.floor(afterconv_w/stride)
|
||||
unfold = torch.nn.Unfold(kernel_size=(kh, kw), dilation=self.dilation, padding=self.padding, stride=1)
|
||||
inp_unf = unfold(input) #transform into a matrix (batch_size, in_channels*kh*kw,afterconv_h,afterconv_w)
|
||||
if stride!=1: # if stride==1 there is no pooling
|
||||
inp_unf = inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h,afterconv_w)
|
||||
if selh[0,0]==-1: # if not given sampled selection
|
||||
#selction of where to sample for each pooling location
|
||||
sel = torch.randint(stride*stride,(out_h,out_w), device=device)
|
||||
|
||||
if self.ceil_mode: #in case of ceil_mode need to select only the good locations for the last regions
|
||||
resth = (out_h*stride)-afterconv_h
|
||||
restw = (out_w*stride)-afterconv_w
|
||||
if resth!=0:
|
||||
sel[-1] = (sel[-1]//stride)%(stride-resth)*stride+(sel[-1]%stride)
|
||||
sel[:,-1] = (sel[:,-1]%stride)%(stride-restw)+sel[:,-1]//stride*stride
|
||||
#print(stride-resth,sel[-1])
|
||||
#print(stride-restw,sel[:,-1])
|
||||
|
||||
rng = torch.arange(0,afterconv_h*afterconv_w,stride*stride,device=device).view(out_h,out_w)
|
||||
index = sel+rng
|
||||
index = index.repeat(batch_size,in_channels*kh*kw,1,1)
|
||||
|
||||
|
||||
if mask[0,0]==-1:# in case of not given mask use only sampled selection
|
||||
#inp_unf = inp_unf[:,:,rng_h,rng_w].view(batch_size,in_channels*kh*kw,-1)
|
||||
inp_unf = torch.gather(inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h*afterconv_w),2,index.view(batch_size,in_channels*kh*kw,out_h*out_w)).view(batch_size,in_channels*kh*kw,out_h*out_w)
|
||||
else:#in case of a valid mask use selection only on the mask locations
|
||||
inp_unf = inp_unf[:,:,rng_h[mask>0],rng_w[mask>0]]
|
||||
|
||||
#Matrix mul
|
||||
if self.bias is None:
|
||||
#flt = self.weight.view(self.weight.size(0), -1).t()
|
||||
#out_unf = inp_unf.transpose(2,1).matmul(flt).transpose(1, 2)
|
||||
out_unf = oe.contract('bji,kj->bki',inp_unf,self.weight.view(self.weight.size(0), -1),backend='torch')
|
||||
#print(((out_unf-out_unf1)**2).mean())
|
||||
else:
|
||||
#out_unf = oe.contract('bji,kj,b->bki',inp_unf,self.weight.view(self.weight.size(0), -1),self.bias,backend='torch')#+self.bias.view(1,-1,1)#still slow
|
||||
out_unf = oe.contract('bji,kj->bki',inp_unf,self.weight.view(self.weight.size(0), -1),backend='torch')+self.bias.view(1,-1,1)#still slow
|
||||
#self.flt = self.weight.view(self.weight.size(0), -1).t()
|
||||
#out_unf = (inp_unf.transpose(1, 2).matmul(self.flt) + self.bias).transpose(1, 2)
|
||||
|
||||
if stride==1 or mask[0,0]==-1:# in case of no mask and stride==1
|
||||
out = out_unf.view(batch_size,out_channels,out_h,out_w) #Fold
|
||||
#if stoch==False: #this is done outside for more clarity
|
||||
# out = F.avg_pool2d(out,self.stride,ceil_mode=True)
|
||||
else:#in case of mask
|
||||
out = torch.zeros(batch_size, out_channels,out_h,out_w,device=device)
|
||||
out[:,:,mask>0] = out_unf
|
||||
return out
|
||||
|
||||
|
||||
def forward_(self, input, selh=-torch.ones(1,1), selw=-torch.ones(1,1), mask=-torch.ones(1,1),stoch=True,stride=-1):
|
||||
device=input.device
|
||||
if stride==-1:
|
||||
stride = self.stride #if stride not defined use self.stride
|
||||
if stoch==False:
|
||||
stride=1 #test with real average pooling
|
||||
batch_size, in_channels, in_h, in_w = input.shape
|
||||
out_channels, in_channels, kh, kw = self.weight.shape
|
||||
|
||||
afterconv_h = in_h+2*self.padding-(kh-1) #size after conv
|
||||
afterconv_w = in_w+2*self.padding-(kw-1)
|
||||
if self.ceil_mode: #ceil_mode = talse default mode for strided conv
|
||||
|
@ -87,8 +155,10 @@ class SConv2dAvg(nn.Module):
|
|||
|
||||
#Matrix mul
|
||||
if self.bias is None:
|
||||
out_unf = inp_unf.transpose(1, 2).matmul(self.weight.view(self.weight.size(0), -1).t()).transpose(1, 2)
|
||||
#out_unf = inp_unf.transpose(1, 2).matmul(self.weight.view(self.weight.size(0), -1).t()).transpose(1, 2)
|
||||
out_unf = oe.contract('bji,kj->bki',inp_unf,self.weight.view(self.weight.size(0), -1),backend='torch')
|
||||
else:
|
||||
dgdg
|
||||
out_unf = (inp_unf.transpose(1, 2).matmul(self.weight.view(self.weight.size(0), -1).t()) + self.bias).transpose(1, 2)
|
||||
|
||||
if stride==1 or mask[0,0]==-1:# in case of no mask and stride==1
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue