mirror of https://github.com/AntoineHX/BU_Stoch_pool.git
synced 2025-05-04 17:50:46 +02:00

added faster BU

parent 286921f8a0
commit e2db0e6057
2 changed files with 112 additions and 21 deletions
@@ -101,14 +101,36 @@ class MyLeNetMatNormal(nn.Module):#epach 21s
         #out = (self.fc1(out))
         return out
 
-class MyLeNetMatStochNoceil(nn.Module):#epoch 17s
-    def __init__(self):
+class MyLeNetMatNormalNoceil(nn.Module):#epoch 136s 16GB
+    def __init__(self,k=3):
+        super(MyLeNetMatNormalNoceil, self).__init__()
+        self.conv1 = SConv2dAvg(3, 200*k, 3, stride=1,padding=1,ceil_mode=False)
+        self.conv2 = SConv2dAvg(200*k, 400*k, 3, stride=1,padding=1,ceil_mode=False)
+        self.conv3 = SConv2dAvg(400*k, 800*k, 3, stride=1,padding=1,ceil_mode=False)
+        self.conv4 = SConv2dAvg(800*k, 1600*k, 3, stride=1,padding=1,ceil_mode=False)
+        self.fc1 = nn.Linear(1600*k, 10)
+
+    def forward(self, x, stoch=True):
+        out = F.relu(self.conv1(x,stoch=stoch))
+        out = F.avg_pool2d(out,2,ceil_mode=True)
+        out = F.relu(self.conv2(out,stoch=stoch))
+        out = F.avg_pool2d(out,2,ceil_mode=True)
+        out = F.relu(self.conv3(out,stoch=stoch))
+        out = F.avg_pool2d(out,2,ceil_mode=True)
+        out = F.relu(self.conv4(out,stoch=stoch))
+        out = F.avg_pool2d(out,4,ceil_mode=True)
+        out = out.view(out.size(0), -1 )
+        out = self.fc1(out)
+        return out
+
+class MyLeNetMatStochNoceil(nn.Module):#epoch 41s 16BG
+    def __init__(self,k=3):
         super(MyLeNetMatStochNoceil, self).__init__()
-        self.conv1 = SConv2dAvg(3, 200, 3, stride=2,padding=1,ceil_mode=False)
+        self.conv1 = SConv2dAvg(3, 200*k, 3, stride=2,padding=1,ceil_mode=False)
-        self.conv2 = SConv2dAvg(200, 400, 3, stride=2,padding=1,ceil_mode=False)
+        self.conv2 = SConv2dAvg(200*k, 400*k, 3, stride=2,padding=1,ceil_mode=False)
-        self.conv3 = SConv2dAvg(400, 800, 3, stride=2,padding=1,ceil_mode=False)
+        self.conv3 = SConv2dAvg(400*k, 800*k, 3, stride=2,padding=1,ceil_mode=False)
-        self.conv4 = SConv2dAvg(800, 10, 3, stride=4,padding=1,ceil_mode=False)
+        self.conv4 = SConv2dAvg(800*k, 1600*k, 3, stride=4,padding=1,ceil_mode=False)
-        #self.fc1 = nn.Linear(800, 10)
+        self.fc1 = nn.Linear(1600*k, 10)
 
     def forward(self, x, stoch=True):
         #print('in',x.shape)
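
For orientation, a minimal usage sketch of the new width-multiplied model (a sketch, assuming SConv2dAvg is importable from the module changed below; the 3-channel input and 10-way nn.Linear head suggest 32x32 CIFAR-10-sized inputs, which is an assumption, and 32x32 is exactly what the 2,2,2,4 pooling chain needs to reach 1x1 before fc1):

    import torch
    model = MyLeNetMatNormalNoceil(k=3)
    x = torch.randn(8, 3, 32, 32)    # hypothetical batch of 8 RGB images
    logits = model(x, stoch=True)    # stoch=True enables stochastic sampling in SConv2dAvg
    print(logits.shape)              # expected: torch.Size([8, 10])
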
@@ -118,10 +140,10 @@ class MyLeNetMatStochNoceil(nn.Module):#epoch 17s
         #print('c2', out.shape)
         out = F.relu(self.conv3(out,stoch=stoch))
         #print('c3',out.shape)
-        out = self.conv4(out,stoch=stoch)
+        out = F.relu(self.conv4(out,stoch=stoch))
         #print('c4',out.shape)
         out = out.view(out.size(0), -1 )
-        #out = self.fc1(out)
+        out = self.fc1(out)
         return out
 
 
@@ -148,10 +170,50 @@ class MyLeNetMatStoch(nn.Module):#epoch 17s
         #out = self.fc1(out)
         return out
 
+class MyLeNetMatStochBUNoceil(nn.Module):#30.5s 14GB
+    def __init__(self,k=3):
+        super(MyLeNetMatStochBUNoceil, self).__init__()
+        self.conv1 = SConv2dAvg(3, 200*k, 3, stride=2,padding=1,ceil_mode=False)
+        self.conv2 = SConv2dAvg(200*k, 400*k, 3, stride=2,padding=1,ceil_mode=False)
+        self.conv3 = SConv2dAvg(400*k, 800*k, 3, stride=2,padding=1,ceil_mode=False)
+        self.conv4 = SConv2dAvg(800*k, 1600*k, 3, stride=4,padding=1,ceil_mode=False)
+        self.fc1 = nn.Linear(1600*k, 10)
+
+    def forward(self, x, stoch=True):
+        #get sizes
+        batch_size = x.shape[0]
+        device = x.device
+        h0,w0 = x.shape[2],x.shape[3]
+        _,_,h1,w1 = self.conv1.get_size(h0,w0)
+        _,_,h2,w2 = self.conv2.get_size(h1,w1)
+        _,_,h3,w3 = self.conv3.get_size(h2,w2)
+        _,_,h4,w4 = self.conv4.get_size(h3,w3)
+        # print(h0,w0)
+        # print(h1,w1)
+        # print(h2,w2)
+        # print(h3,w3)
+
+        #sample BU
+        mask4 = torch.ones(h4,w4).to(x.device)
+        # print(mask3.shape)
+        index4,mask3 = self.conv4.sample(h3,w3,batch_size,device,mask4)
+        index3,mask2 = self.conv3.sample(h2,w2,batch_size,device,mask3)
+        index2,mask1 = self.conv2.sample(h1,w1,batch_size,device,mask2)
+        index1,mask0 = self.conv1.sample(h0,w0,batch_size,device,mask1)
+
+        ##forward
+        out = F.relu(self.conv1(x,index1,mask1,stoch=stoch))
+        out = F.relu(self.conv2(out,index2,mask2,stoch=stoch))
+        out = F.relu(self.conv3(out,index3,mask3,stoch=stoch))
+        out = F.relu(self.conv4(out,index4,mask4,stoch=stoch))
+        out = out.view(out.size(0), -1 )
+        out = self.fc1(out)
+        return out
+
 class MyLeNetMatStochBU(nn.Module):#epoch 11s
     def __init__(self):
         super(MyLeNetMatStochBU, self).__init__()
-        self.conv1 = SConv2dAvg(3, 200, 3, stride=2)
+        self.conv1 = SConv2dAvg(3, 200*k, 3, stride=2)
         self.conv2 = SConv2dAvg(200, 400, 3, stride=2)
         self.conv3 = SConv2dAvg(400, 800, 3, stride=2, ceil_mode=True)
         self.conv4 = SConv2dAvg(800, 10, 3, stride=1)
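
The new MyLeNetMatStochBUNoceil is the "faster BU" of the commit title: before any convolution runs, its forward() samples masks top-down from the last layer (mask4, all ones) back to the input, so each layer later computes only the output locations its successor actually needs. A compact sketch of that chain, using only the get_size()/sample() contracts visible in this diff. (Note that the older MyLeNetMatStochBU above keeps def __init__(self) while its conv1 now references 200*k, so that class would not run as committed.)

    # Schematic of the bottom-up sampling chain (a sketch, not a replacement
    # for the committed forward()).
    def bu_sample_chain(model, x):
        b, dev = x.shape[0], x.device
        h0, w0 = x.shape[2], x.shape[3]
        _, _, h1, w1 = model.conv1.get_size(h0, w0)   # spatial size after each layer
        _, _, h2, w2 = model.conv2.get_size(h1, w1)
        _, _, h3, w3 = model.conv3.get_size(h2, w2)
        _, _, h4, w4 = model.conv4.get_size(h3, w3)
        mask4 = torch.ones(h4, w4, device=dev)        # every final location is needed
        index4, mask3 = model.conv4.sample(h3, w3, b, dev, mask4)  # what conv4 needs of conv3
        index3, mask2 = model.conv3.sample(h2, w2, b, dev, mask3)
        index2, mask1 = model.conv2.sample(h1, w1, b, dev, mask2)
        index1, mask0 = model.conv1.sample(h0, w0, b, dev, mask1)  # mask0: needed input pixels
        return (index1, mask1), (index2, mask2), (index3, mask3), (index4, mask4)
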
@@ -31,7 +31,7 @@ class SConv2dAvg(nn.Module):
     def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1,ceil_mode=True, bias = True):
         super(SConv2dAvg, self).__init__()
         conv = nn.Conv2d(in_channels, out_channels, kernel_size)
-        self.deconv = nn.ConvTranspose2d(1, 1, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=False, dilation=1, padding_mode='zeros')
+        self.deconv = nn.ConvTranspose2d(1, 1, kernel_size, stride=1, padding=padding, output_padding=0, groups=1, bias=False, dilation=1, padding_mode='zeros')
         nn.init.constant_(self.deconv.weight, 1)
         self.pooldeconv = nn.ConvTranspose2d(1, 1, kernel_size=stride,padding=0,stride=stride, output_padding=0, groups=1, bias=False, dilation=1, padding_mode='zeros')
         nn.init.constant_(self.pooldeconv.weight, 1)
@@ -47,7 +47,7 @@ class SConv2dAvg(nn.Module):
         self.ceil_mode = ceil_mode
 
 
-    def forward_fast(self, input, index=-torch.ones(1), mask=-torch.ones(1,1),stoch=True,stride=-1): #ceil_mode = True not right
+    def forward(self, input, index=-torch.ones(1), mask=-torch.ones(1,1),stoch=True,stride=-1): #ceil_mode = True not right
         device=input.device
         if stride==-1:
             stride = self.stride #if stride not defined use self.stride
@@ -60,14 +60,20 @@ class SConv2dAvg(nn.Module):
         unfold = torch.nn.Unfold(kernel_size=(kh, kw), dilation=self.dilation, padding=self.padding, stride=1)
         inp_unf = unfold(input) #transform into a matrix (batch_size, in_channels*kh*kw,afterconv_h,afterconv_w)
 
-        if index[0]==-1 and stride!=1: #or stride!=1:
-            index,mask = self.sample(in_h,in_w,batch_size,device,mask)
+        if stride!=1:
+            if len(index.shape)==1: #or stride!=1:
+                index,mask = self.sample(in_h,in_w,batch_size,device,mask)
 
             if mask[0,0]==-1:# in case of not given mask use only sampled selection
                 #inp_unf = inp_unf[:,:,rng_h,rng_w].view(batch_size,in_channels*kh*kw,-1)
+                index = index.repeat(batch_size,in_channels*kh*kw,1,1)
                 inp_unf = torch.gather(inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h*afterconv_w),2,index.view(batch_size,in_channels*kh*kw,out_h*out_w)).view(batch_size,in_channels*kh*kw,out_h*out_w)
             else:#in case of a valid mask use selection only on the mask locations
-                inp_unf = inp_unf[:,:,rng_h[mask>0],rng_w[mask>0]]
+                #inp_unf = inp_unf[index[:,:,mask>0]]
+                #mindex = index[mask>0]
+                mindex = torch.masked_select(index, mask>0)
+                index = mindex.repeat(batch_size,in_channels*kh*kw,1)
+                inp_unf = torch.gather(inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h*afterconv_w),2,index.view(batch_size,in_channels*kh*kw,index.shape[2])).view(batch_size,in_channels*kh*kw,index.shape[2])
 
         #Matrix mul
         if self.bias is None:
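
The rewritten branch above replaces fancy indexing with torch.gather over the unfolded input: unfold turns the input into (batch, C*kh*kw, H*W) columns, and gather keeps one column per stride x stride block, which is what index encodes. The idea in isolation (a standalone sketch with hypothetical sizes; the committed sample() builds sel and rng differently in detail):

    import torch
    b, c, kh, kw, H, W, s = 2, 3, 3, 3, 8, 8, 2
    x = torch.randn(b, c, H, W)
    cols = torch.nn.Unfold((kh, kw), padding=1, stride=1)(x)   # (b, c*kh*kw, H*W)
    out_h, out_w = H // s, W // s
    # one random offset inside each s x s block, as a flat column id
    sel = torch.randint(0, s, (out_h, out_w)) * W + torch.randint(0, s, (out_h, out_w))
    rng = (torch.arange(out_h) * s * W).view(-1, 1) + torch.arange(out_w) * s  # block top-left ids
    index = (sel + rng).view(1, 1, -1).repeat(b, c * kh * kw, 1)
    sampled = torch.gather(cols, 2, index)                     # (b, c*kh*kw, out_h*out_w)
    print(sampled.shape)                                       # torch.Size([2, 27, 16])
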
@@ -89,7 +95,9 @@ class SConv2dAvg(nn.Module):
             out = F.avg_pool2d(out,self.stride,ceil_mode=True)
         else:#in case of mask
             out = torch.zeros(batch_size, out_channels,out_h,out_w,device=device)
+            #out = torch.gather(out.view(batch_size,in_channels*kh*kw,afterconv_h*afterconv_w),2,index.view(batch_size,in_channels*kh*kw,index.shape[2])).view(batch_size,in_channels*kh*kw,index.shape[2])
             out[:,:,mask>0] = out_unf
+            #out.masked_scatter_(mask>0, out_unf)
         return out
 
 
@@ -167,7 +175,7 @@ class SConv2dAvg(nn.Module):
         return out
 
 
-    def forward(self, input, selh=-torch.ones(1,1), selw=-torch.ones(1,1), mask=-torch.ones(1,1),stoch=True,stride=-1):
+    def forward_slow(self, input, selh=-torch.ones(1,1), selw=-torch.ones(1,1), mask=-torch.ones(1,1),stoch=True,stride=-1):
         device=input.device
         if stride==-1:
             stride = self.stride #if stride not defined use self.stride
@@ -373,23 +381,28 @@ class SConv2dAvg(nn.Module):
         resth = (out_h*stride)-afterconv_h
         restw = (out_w*stride)-afterconv_w
         if resth!=0:
+            print("stride",stride,"str-rest",stride-resth,stride-restw)
+            print('before',sel[-1],sel[:,-1])
             sel[-1] = (sel[-1]//stride)%(stride-resth)*stride+(sel[-1]%stride)
             sel[:,-1] = (sel[:,-1]%stride)%(stride-restw)+sel[:,-1]//stride*stride
+            print('after',sel[-1],sel[:,-1])
+            input()
 
         rng = torch.arange(0,out_h*stride*out_w*stride,stride*stride,device=device).view(out_h,out_w)
         index = sel+rng
-        index = index.repeat(batch_size,in_channels*kh*kw,1,1)
+        #index = index.repeat(batch_size,in_channels*kh*kw,1,1)
 
         #inp_unf = torch.gather(inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h*afterconv_w),2,index.view(batch_size,in_channels*kh*kw,out_h*out_w)).view(batch_size,in_channels*kh*kw,out_h*out_w)
 
         if mask[0,0]!=-1:
             maskh = (out_h)*stride
             maskw = (out_w)*stride
-            nmask = torch.zeros((maskh,maskw),device=device)
-            nmask[rng_h,rng_w] = 1
+            nmask = torch.zeros((maskh,maskw),device=device).view(-1)
+            #inp_unf = torch.gather(inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h*afterconv_w),2,index.view(batch_size,in_channels*kh*kw,out_h*out_w)).view(batch_size,in_channels*kh*kw,out_h*out_w)
+            nmask[index] = 1
             #rmask = mask * nmask
             dmask = self.pooldeconv(mask.float().view(1,1,mask.shape[0],mask.shape[1]))
-            rmask = nmask * dmask
+            rmask = nmask.view(1,1,maskh,maskw) * dmask
             #rmask = rmask[:,:,:out_h,:out_w]
             # print('rmask', rmask.shape)
             fmask = self.deconv(rmask)
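
In sample(), the mask handed down from the layer above is pushed back through two all-ones transposed convolutions: pooldeconv spreads each required output over its stride x stride block, the flattened nmask keeps only the position actually sampled in each block, and deconv dilates the result by the kernel's receptive field to get the mask for the layer below (which is why the padding=padding fix to deconv earlier in this diff matters for the geometry). A tiny self-contained sketch of that mechanism (a hypothetical 2x2 mask, not the exact committed code):

    import torch
    import torch.nn as nn

    stride, k = 2, 3
    mask = torch.tensor([[1., 0.],
                         [0., 1.]])                   # outputs required by the layer above
    pooldeconv = nn.ConvTranspose2d(1, 1, kernel_size=stride, stride=stride, bias=False)
    deconv = nn.ConvTranspose2d(1, 1, kernel_size=k, stride=1, padding=1, bias=False)
    nn.init.constant_(pooldeconv.weight, 1)
    nn.init.constant_(deconv.weight, 1)
    with torch.no_grad():
        dmask = pooldeconv(mask.view(1, 1, 2, 2))     # (1,1,4,4): each 1 covers its stride block
        fmask = deconv(dmask)                         # dilate by the 3x3 receptive field
    print((fmask[0, 0] > 0).long())                   # input positions that must be computed
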
@@ -398,6 +411,22 @@ class SConv2dAvg(nn.Module):
 
         return index,mask#.long()
 
+    def get_mask(self,in_h,in_w,batch_size,device,mask=-torch.ones(1,1)):
+        maskh = (out_h)*stride
+        maskw = (out_w)*stride
+        nmask = torch.zeros((maskh,maskw),device=device).view(-1)
+        #inp_unf = torch.gather(inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h*afterconv_w),2,index.view(batch_size,in_channels*kh*kw,out_h*out_w)).view(batch_size,in_channels*kh*kw,out_h*out_w)
+        nmask[index[0,0]] = 1
+        #rmask = mask * nmask
+        dmask = self.pooldeconv(mask.float().view(1,1,mask.shape[0],mask.shape[1]))
+        rmask = nmask.view(1,1,maskh,maskw) * dmask
+        #rmask = rmask[:,:,:out_h,:out_w]
+        # print('rmask', rmask.shape)
+        fmask = self.deconv(rmask)
+        # print('fmask', fmask.shape)
+        mask = fmask[0,0].long()
+        return mask
+
 
     def get_size(self,in_h,in_w,stride=-1):
 
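
The added get_mask reads like the mask block of sample() factored out, but as committed it refers to out_h, stride and index without binding them in its own scope. A self-contained reading under that assumption (a sketch with the unbound names passed in explicitly, not the committed code):

    def get_mask_sketch(self, index, mask, out_h, out_w, stride, device):
        maskh = out_h * stride
        maskw = out_w * stride
        nmask = torch.zeros((maskh, maskw), device=device).view(-1)
        nmask[index[0, 0]] = 1                        # mark the sampled positions
        dmask = self.pooldeconv(mask.float().view(1, 1, mask.shape[0], mask.shape[1]))
        rmask = nmask.view(1, 1, maskh, maskw) * dmask
        fmask = self.deconv(rmask)                    # dilate to the conv receptive field
        return fmask[0, 0].long()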