added faster BU

Marco Pedersoli 2020-06-18 21:59:19 -04:00
parent 286921f8a0
commit e2db0e6057
2 changed files with 112 additions and 21 deletions

@@ -31,7 +31,7 @@ class SConv2dAvg(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1,ceil_mode=True, bias = True):
super(SConv2dAvg, self).__init__()
conv = nn.Conv2d(in_channels, out_channels, kernel_size)
self.deconv = nn.ConvTranspose2d(1, 1, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=False, dilation=1, padding_mode='zeros')
self.deconv = nn.ConvTranspose2d(1, 1, kernel_size, stride=1, padding=padding, output_padding=0, groups=1, bias=False, dilation=1, padding_mode='zeros')
nn.init.constant_(self.deconv.weight, 1)
self.pooldeconv = nn.ConvTranspose2d(1, 1, kernel_size=stride,padding=0,stride=stride, output_padding=0, groups=1, bias=False, dilation=1, padding_mode='zeros')
nn.init.constant_(self.pooldeconv.weight, 1)
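The two constant-weight transposed convolutions set up here are what sample/get_mask later use to turn a mask over output locations into a mask over the input pixels that feed them; switching deconv from padding=0 to padding=padding presumably keeps that grown mask aligned with the input size. A minimal, self-contained sketch of the idea (toy sizes, illustrative only, not the committed code):

import torch
import torch.nn as nn

# Illustrative values only; these are not the committed module's parameters.
kernel_size, stride, padding = 3, 2, 1

# Upsamples a mask over output locations by `stride` (kernel = stride, stride = stride).
pooldeconv = nn.ConvTranspose2d(1, 1, kernel_size=stride, stride=stride, bias=False)
nn.init.constant_(pooldeconv.weight, 1)

# Dilates each marked location by the conv's receptive field; using the conv's
# `padding` here (the change in this hunk) keeps the result at the input's size.
deconv = nn.ConvTranspose2d(1, 1, kernel_size, stride=1, padding=padding, bias=False)
nn.init.constant_(deconv.weight, 1)

mask = torch.zeros(1, 1, 4, 4)
mask[0, 0, 1, 2] = 1                      # one selected output location
with torch.no_grad():
    grown = deconv(pooldeconv(mask))      # nonzero wherever the input contributes
print((grown[0, 0] > 0).long())           # 8x8 binary mask over input pixels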
@@ -47,7 +47,7 @@ class SConv2dAvg(nn.Module):
self.ceil_mode = ceil_mode
def forward_fast(self, input, index=-torch.ones(1), mask=-torch.ones(1,1),stoch=True,stride=-1): #ceil_mode = True not right
def forward(self, input, index=-torch.ones(1), mask=-torch.ones(1,1),stoch=True,stride=-1): #ceil_mode = True not right
device=input.device
if stride==-1:
stride = self.stride #if stride not defined use self.stride
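Both forward variants keep the sentinel-default convention: index and mask default to -torch.ones(...), and the body checks the shape or the first element to decide whether to resample. A tiny standalone illustration of that pattern (a generic function, not the committed one):

import torch

def gather_cols(x, index=-torch.ones(1)):
    # Sentinel check mirroring the commit's `len(index.shape) == 1` test:
    # the 1-D default means "no precomputed index, sample one now".
    if len(index.shape) == 1:
        index = torch.randint(0, x.shape[-1], (x.shape[0], x.shape[1], 4))
    return torch.gather(x, 2, index)

x = torch.randn(2, 3, 10)
y1 = gather_cols(x)                        # samples its own index
idx = torch.randint(0, 10, (2, 3, 4))
y2 = gather_cols(x, idx)                   # reuses a caller-provided index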
@@ -60,14 +60,20 @@ class SConv2dAvg(nn.Module):
unfold = torch.nn.Unfold(kernel_size=(kh, kw), dilation=self.dilation, padding=self.padding, stride=1)
inp_unf = unfold(input) #transform into a matrix (batch_size, in_channels*kh*kw,afterconv_h,afterconv_w)
if index[0]==-1 and stride!=1: #or stride!=1:
index,mask = self.sample(in_h,in_w,batch_size,device,mask)
if stride!=1:
if len(index.shape)==1: #or stride!=1:
index,mask = self.sample(in_h,in_w,batch_size,device,mask)
if mask[0,0]==-1:# in case of not given mask use only sampled selection
#inp_unf = inp_unf[:,:,rng_h,rng_w].view(batch_size,in_channels*kh*kw,-1)
index = index.repeat(batch_size,in_channels*kh*kw,1,1)
inp_unf = torch.gather(inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h*afterconv_w),2,index.view(batch_size,in_channels*kh*kw,out_h*out_w)).view(batch_size,in_channels*kh*kw,out_h*out_w)
else:#in case of a valid mask use selection only on the mask locations
inp_unf = inp_unf[:,:,rng_h[mask>0],rng_w[mask>0]]
#inp_unf = inp_unf[index[:,:,mask>0]]
#mindex = index[mask>0]
mindex = torch.masked_select(index, mask>0)
index = mindex.repeat(batch_size,in_channels*kh*kw,1)
inp_unf = torch.gather(inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h*afterconv_w),2,index.view(batch_size,in_channels*kh*kw,index.shape[2])).view(batch_size,in_channels*kh*kw,index.shape[2])
#Matrix mul
if self.bias is None:
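The faster path gathers the sampled columns of the unfolded input with torch.gather instead of fancy indexing, and, when a mask is supplied, keeps only the flat indices of active output locations via torch.masked_select before gathering. A toy, self-contained sketch of both selections (illustrative shapes, not the committed code):

import torch

# Stand-ins for the unfolded input (batch, in_channels*kh*kw, afterconv_h*afterconv_w).
batch_size, cols = 2, 4
afterconv_h = afterconv_w = 6
out_h = out_w = 3

inp_unf = torch.randn(batch_size, cols, afterconv_h * afterconv_w)

# One flat column index per output location (here: the top-left corner of each
# stride x stride cell, standing in for the randomly sampled offsets).
rng = torch.arange(afterconv_h * afterconv_w).view(afterconv_h, afterconv_w)
index = rng[::2, ::2].reshape(out_h, out_w)

# No mask: gather every sampled column (the index is repeated over batch/channels,
# as in the commit; `expand` would avoid the copy).
gidx = index.reshape(1, 1, -1).repeat(batch_size, cols, 1)
picked = torch.gather(inp_unf, 2, gidx)                  # (batch, cols, out_h*out_w)

# With a mask over output locations: keep only the active flat indices.
mask = torch.tensor([[1, 0, 1], [0, 1, 0], [1, 0, 0]])
mindex = torch.masked_select(index, mask > 0)            # 1-D, one entry per kept output
gidx_m = mindex.reshape(1, 1, -1).repeat(batch_size, cols, 1)
picked_m = torch.gather(inp_unf, 2, gidx_m)              # (batch, cols, n_active)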
@@ -89,7 +95,9 @@ class SConv2dAvg(nn.Module):
out = F.avg_pool2d(out,self.stride,ceil_mode=True)
else:#in case of mask
out = torch.zeros(batch_size, out_channels,out_h,out_w,device=device)
#out = torch.gather(out.view(batch_size,in_channels*kh*kw,afterconv_h*afterconv_w),2,index.view(batch_size,in_channels*kh*kw,index.shape[2])).view(batch_size,in_channels*kh*kw,index.shape[2])
out[:,:,mask>0] = out_unf
#out.masked_scatter_(mask>0, out_unf)
return out
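On the output side only the masked locations were computed, so the result columns are written back into a dense zero tensor at mask > 0 (the gather and masked_scatter_ alternatives are left commented out). A minimal sketch of that scatter-back, with made-up shapes:

import torch

# Illustrative shapes: out_unf holds one result column per active (mask > 0) location.
batch_size, out_channels, out_h, out_w = 2, 5, 3, 3
mask = torch.tensor([[1, 0, 1], [0, 1, 0], [1, 0, 0]])
n_active = int((mask > 0).sum())
out_unf = torch.randn(batch_size, out_channels, n_active)

# Write the computed columns into a dense map; masked-out locations stay zero.
out = torch.zeros(batch_size, out_channels, out_h, out_w)
out[:, :, mask > 0] = out_unf
print(out.shape, (out.abs().sum(dim=(0, 1)) > 0).long())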
@@ -167,7 +175,7 @@ class SConv2dAvg(nn.Module):
return out
def forward(self, input, selh=-torch.ones(1,1), selw=-torch.ones(1,1), mask=-torch.ones(1,1),stoch=True,stride=-1):
def forward_slow(self, input, selh=-torch.ones(1,1), selw=-torch.ones(1,1), mask=-torch.ones(1,1),stoch=True,stride=-1):
device=input.device
if stride==-1:
stride = self.stride #if stride not defined use self.stride
@@ -373,23 +381,28 @@ class SConv2dAvg(nn.Module):
resth = (out_h*stride)-afterconv_h
restw = (out_w*stride)-afterconv_w
if resth!=0:
print("stride",stride,"str-rest",stride-resth,stride-restw)
print('before',sel[-1],sel[:,-1])
sel[-1] = (sel[-1]//stride)%(stride-resth)*stride+(sel[-1]%stride)
sel[:,-1] = (sel[:,-1]%stride)%(stride-restw)+sel[:,-1]//stride*stride
print('after',sel[-1],sel[:,-1])
input()
rng = torch.arange(0,out_h*stride*out_w*stride,stride*stride,device=device).view(out_h,out_w)
index = sel+rng
index = index.repeat(batch_size,in_channels*kh*kw,1,1)
#index = index.repeat(batch_size,in_channels*kh*kw,1,1)
#inp_unf = torch.gather(inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h*afterconv_w),2,index.view(batch_size,in_channels*kh*kw,out_h*out_w)).view(batch_size,in_channels*kh*kw,out_h*out_w)
if mask[0,0]!=-1:
maskh = (out_h)*stride
maskw = (out_w)*stride
nmask = torch.zeros((maskh,maskw),device=device)
nmask[rng_h,rng_w] = 1
nmask = torch.zeros((maskh,maskw),device=device).view(-1)
#inp_unf = torch.gather(inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h*afterconv_w),2,index.view(batch_size,in_channels*kh*kw,out_h*out_w)).view(batch_size,in_channels*kh*kw,out_h*out_w)
nmask[index] = 1
#rmask = mask * nmask
dmask = self.pooldeconv(mask.float().view(1,1,mask.shape[0],mask.shape[1]))
rmask = nmask * dmask
rmask = nmask.view(1,1,maskh,maskw) * dmask
#rmask = rmask[:,:,:out_h,:out_w]
# print('rmask', rmask.shape)
fmask = self.deconv(rmask)
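Inside sample, the chosen offsets are marked in a flat nmask, the caller's output-location mask is upsampled by stride with pooldeconv, the two are multiplied, and deconv then grows the result by the conv's receptive field to give a per-input-pixel mask; the resth/restw correction above just folds the last row and column of offsets back inside the image when out_h*stride overshoots afterconv_h. A self-contained sketch of that pipeline, with toy sizes and a plain row-major flat index rather than the commit's blockwise layout:

import torch
import torch.nn as nn

stride, kernel_size, padding = 2, 3, 1        # illustrative values only
out_h, out_w = 4, 4
maskh, maskw = out_h * stride, out_w * stride

pooldeconv = nn.ConvTranspose2d(1, 1, kernel_size=stride, stride=stride, bias=False)
deconv = nn.ConvTranspose2d(1, 1, kernel_size, stride=1, padding=padding, bias=False)
nn.init.constant_(pooldeconv.weight, 1)
nn.init.constant_(deconv.weight, 1)

# One sampled offset per output cell, encoded as a flat index into the
# (maskh, maskw) buffer (row-major here; the commit uses its own layout).
cell_r = torch.randint(0, stride, (out_h, out_w))
cell_c = torch.randint(0, stride, (out_h, out_w))
rows = torch.arange(out_h).view(-1, 1) * stride + cell_r
cols = torch.arange(out_w).view(1, -1) * stride + cell_c
index = rows * maskw + cols

nmask = torch.zeros(maskh * maskw)
nmask[index.view(-1)] = 1                      # mark the sampled input columns

mask = (torch.rand(out_h, out_w) > 0.5).float()  # mask over output locations
with torch.no_grad():
    dmask = pooldeconv(mask.view(1, 1, out_h, out_w))   # upsample by stride
    rmask = nmask.view(1, 1, maskh, maskw) * dmask      # keep sampled & active
    fmask = deconv(rmask)                               # grow by the kernel
pixel_mask = (fmask[0, 0] > 0).long()          # which input pixels are needed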
@@ -398,6 +411,22 @@ class SConv2dAvg(nn.Module):
return index,mask#.long()
def get_mask(self,in_h,in_w,batch_size,device,mask=-torch.ones(1,1)):
maskh = (out_h)*stride
maskw = (out_w)*stride
nmask = torch.zeros((maskh,maskw),device=device).view(-1)
#inp_unf = torch.gather(inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h*afterconv_w),2,index.view(batch_size,in_channels*kh*kw,out_h*out_w)).view(batch_size,in_channels*kh*kw,out_h*out_w)
nmask[index[0,0]] = 1
#rmask = mask * nmask
dmask = self.pooldeconv(mask.float().view(1,1,mask.shape[0],mask.shape[1]))
rmask = nmask.view(1,1,maskh,maskw) * dmask
#rmask = rmask[:,:,:out_h,:out_w]
# print('rmask', rmask.shape)
fmask = self.deconv(rmask)
# print('fmask', fmask.shape)
mask = fmask[0,0].long()
return mask
def get_size(self,in_h,in_w,stride=-1):