added faster BU

Marco Pedersoli 2020-06-18 21:59:19 -04:00
parent 286921f8a0
commit e2db0e6057
2 changed files with 112 additions and 21 deletions

@@ -101,14 +101,36 @@ class MyLeNetMatNormal(nn.Module):#epoch 21s
#out = (self.fc1(out))
return out
class MyLeNetMatStochNoceil(nn.Module):#epoch 17s
def __init__(self):
class MyLeNetMatNormalNoceil(nn.Module):#epoch 136s 16GB
def __init__(self,k=3):
super(MyLeNetMatNormalNoceil, self).__init__()
self.conv1 = SConv2dAvg(3, 200*k, 3, stride=1,padding=1,ceil_mode=False)
self.conv2 = SConv2dAvg(200*k, 400*k, 3, stride=1,padding=1,ceil_mode=False)
self.conv3 = SConv2dAvg(400*k, 800*k, 3, stride=1,padding=1,ceil_mode=False)
self.conv4 = SConv2dAvg(800*k, 1600*k, 3, stride=1,padding=1,ceil_mode=False)
self.fc1 = nn.Linear(1600*k, 10)
def forward(self, x, stoch=True):
out = F.relu(self.conv1(x,stoch=stoch))
out = F.avg_pool2d(out,2,ceil_mode=True)
out = F.relu(self.conv2(out,stoch=stoch))
out = F.avg_pool2d(out,2,ceil_mode=True)
out = F.relu(self.conv3(out,stoch=stoch))
out = F.avg_pool2d(out,2,ceil_mode=True)
out = F.relu(self.conv4(out,stoch=stoch))
out = F.avg_pool2d(out,4,ceil_mode=True)
out = out.view(out.size(0), -1 )
out = self.fc1(out)
return out
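For orientation, a minimal sketch of the spatial sizes the MyLeNetMatNormalNoceil forward pass above produces, assuming a 32x32 (CIFAR-sized) input; the stride-1, padding-1, 3x3 convolutions keep H and W, so only the pooling shrinks the map:

import math

# Sketch (assumption: 32x32 input). Only the four F.avg_pool2d calls change H,W.
h = 32
for pool in (2, 2, 2, 4):
    h = math.ceil(h / pool)  # ceil_mode=True rounds up
    print(h)                 # 16, 8, 4, 1 -> flattening leaves 1600*k features for fc1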
class MyLeNetMatStochNoceil(nn.Module):#epoch 41s 16GB
def __init__(self,k=3):
super(MyLeNetMatStochNoceil, self).__init__()
self.conv1 = SConv2dAvg(3, 200, 3, stride=2,padding=1,ceil_mode=False)
self.conv2 = SConv2dAvg(200, 400, 3, stride=2,padding=1,ceil_mode=False)
self.conv3 = SConv2dAvg(400, 800, 3, stride=2,padding=1,ceil_mode=False)
self.conv4 = SConv2dAvg(800, 10, 3, stride=4,padding=1,ceil_mode=False)
#self.fc1 = nn.Linear(800, 10)
self.conv1 = SConv2dAvg(3, 200*k, 3, stride=2,padding=1,ceil_mode=False)
self.conv2 = SConv2dAvg(200*k, 400*k, 3, stride=2,padding=1,ceil_mode=False)
self.conv3 = SConv2dAvg(400*k, 800*k, 3, stride=2,padding=1,ceil_mode=False)
self.conv4 = SConv2dAvg(800*k, 1600*k, 3, stride=4,padding=1,ceil_mode=False)
self.fc1 = nn.Linear(1600*k, 10)
def forward(self, x, stoch=True):
#print('in',x.shape)
@@ -118,10 +140,10 @@ class MyLeNetMatStochNoceil(nn.Module):#epoch 17s
#print('c2', out.shape)
out = F.relu(self.conv3(out,stoch=stoch))
#print('c3',out.shape)
out = self.conv4(out,stoch=stoch)
out = F.relu(self.conv4(out,stoch=stoch))
#print('c4',out.shape)
out = out.view(out.size(0), -1 )
#out = self.fc1(out)
out = self.fc1(out)
return out
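A hedged usage sketch for this stochastic variant: forward takes a stoch flag, so the same module can sample spatial positions during training and fall back to the deterministic averaging path at eval time (the batch shape below is an assumption, not from this commit):

import torch

model = MyLeNetMatStochNoceil(k=3)
x = torch.randn(8, 3, 32, 32)         # assumed CIFAR-sized batch
logits_train = model(x, stoch=True)   # sampled positions in each strided conv
logits_eval = model(x, stoch=False)   # deterministic average-pooling path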
@@ -148,10 +170,50 @@ class MyLeNetMatStoch(nn.Module):#epoch 17s
#out = self.fc1(out)
return out
class MyLeNetMatStochBUNoceil(nn.Module):#30.5s 14GB
def __init__(self,k=3):
super(MyLeNetMatStochBUNoceil, self).__init__()
self.conv1 = SConv2dAvg(3, 200*k, 3, stride=2,padding=1,ceil_mode=False)
self.conv2 = SConv2dAvg(200*k, 400*k, 3, stride=2,padding=1,ceil_mode=False)
self.conv3 = SConv2dAvg(400*k, 800*k, 3, stride=2,padding=1,ceil_mode=False)
self.conv4 = SConv2dAvg(800*k, 1600*k, 3, stride=4,padding=1,ceil_mode=False)
self.fc1 = nn.Linear(1600*k, 10)
def forward(self, x, stoch=True):
#get sizes
batch_size = x.shape[0]
device = x.device
h0,w0 = x.shape[2],x.shape[3]
_,_,h1,w1 = self.conv1.get_size(h0,w0)
_,_,h2,w2 = self.conv2.get_size(h1,w1)
_,_,h3,w3 = self.conv3.get_size(h2,w2)
_,_,h4,w4 = self.conv4.get_size(h3,w3)
# print(h0,w0)
# print(h1,w1)
# print(h2,w2)
# print(h3,w3)
#sample BU
mask4 = torch.ones(h4,w4).to(x.device)
# print(mask4.shape)
index4,mask3 = self.conv4.sample(h3,w3,batch_size,device,mask4)
index3,mask2 = self.conv3.sample(h2,w2,batch_size,device,mask3)
index2,mask1 = self.conv2.sample(h1,w1,batch_size,device,mask2)
index1,mask0 = self.conv1.sample(h0,w0,batch_size,device,mask1)
##forward
out = F.relu(self.conv1(x,index1,mask1,stoch=stoch))
out = F.relu(self.conv2(out,index2,mask2,stoch=stoch))
out = F.relu(self.conv3(out,index3,mask3,stoch=stoch))
out = F.relu(self.conv4(out,index4,mask4,stoch=stoch))
out = out.view(out.size(0), -1 )
out = self.fc1(out)
return out
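Reading the bottom-up (BU) pass above as a pattern: get_size is chained forward to obtain each layer's output resolution, then sample is chained backward, so the mask over a layer's output decides which of its input locations need computing at all. A sketch of that pattern generalized to a list of layers (the sample_bottom_up name and the layers/sizes arguments are illustrative, not from this commit):

import torch

def sample_bottom_up(layers, sizes, batch_size, device):
    # sizes[i] is the (h, w) input size of layers[i]; sizes[-1] is the final output size
    h_out, w_out = sizes[-1]
    masks = [torch.ones(h_out, w_out, device=device)]  # keep every top-level output
    indices = []
    for layer, (h, w) in zip(reversed(layers), reversed(sizes[:-1])):
        index, mask = layer.sample(h, w, batch_size, device, masks[0])
        indices.insert(0, index)  # sampled positions for this layer
        masks.insert(0, mask)     # mask over this layer's input
    return indices, masks         # forward then calls layers[i](x, indices[i], masks[i+1])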
class MyLeNetMatStochBU(nn.Module):#epoch 11s
def __init__(self,k=3):
super(MyLeNetMatStochBU, self).__init__()
self.conv1 = SConv2dAvg(3, 200, 3, stride=2)
self.conv1 = SConv2dAvg(3, 200*k, 3, stride=2)
self.conv2 = SConv2dAvg(200*k, 400*k, 3, stride=2)
self.conv3 = SConv2dAvg(400*k, 800*k, 3, stride=2, ceil_mode=True)
self.conv4 = SConv2dAvg(800*k, 10, 3, stride=1)

@@ -31,7 +31,7 @@ class SConv2dAvg(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1,ceil_mode=True, bias = True):
super(SConv2dAvg, self).__init__()
conv = nn.Conv2d(in_channels, out_channels, kernel_size)
self.deconv = nn.ConvTranspose2d(1, 1, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=False, dilation=1, padding_mode='zeros')
self.deconv = nn.ConvTranspose2d(1, 1, kernel_size, stride=1, padding=padding, output_padding=0, groups=1, bias=False, dilation=1, padding_mode='zeros')
nn.init.constant_(self.deconv.weight, 1)
self.pooldeconv = nn.ConvTranspose2d(1, 1, kernel_size=stride,padding=0,stride=stride, output_padding=0, groups=1, bias=False, dilation=1, padding_mode='zeros')
nn.init.constant_(self.pooldeconv.weight, 1)
@@ -47,7 +47,7 @@ class SConv2dAvg(nn.Module):
self.ceil_mode = ceil_mode
def forward_fast(self, input, index=-torch.ones(1), mask=-torch.ones(1,1),stoch=True,stride=-1): #ceil_mode=True not handled correctly yet
def forward(self, input, index=-torch.ones(1), mask=-torch.ones(1,1),stoch=True,stride=-1): #ceil_mode=True not handled correctly yet
device=input.device
if stride==-1:
stride = self.stride #if stride not defined use self.stride
@@ -60,14 +60,20 @@ class SConv2dAvg(nn.Module):
unfold = torch.nn.Unfold(kernel_size=(kh, kw), dilation=self.dilation, padding=self.padding, stride=1)
inp_unf = unfold(input) #transform into a matrix (batch_size, in_channels*kh*kw,afterconv_h,afterconv_w)
if index[0]==-1 and stride!=1: #or stride!=1:
if stride!=1:
if len(index.shape)==1: #or stride!=1:
index,mask = self.sample(in_h,in_w,batch_size,device,mask)
if mask[0,0]==-1:# in case of not given mask use only sampled selection
#inp_unf = inp_unf[:,:,rng_h,rng_w].view(batch_size,in_channels*kh*kw,-1)
index = index.repeat(batch_size,in_channels*kh*kw,1,1)
inp_unf = torch.gather(inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h*afterconv_w),2,index.view(batch_size,in_channels*kh*kw,out_h*out_w)).view(batch_size,in_channels*kh*kw,out_h*out_w)
else:#in case of a valid mask use selection only on the mask locations
inp_unf = inp_unf[:,:,rng_h[mask>0],rng_w[mask>0]]
#inp_unf = inp_unf[index[:,:,mask>0]]
#mindex = index[mask>0]
mindex = torch.masked_select(index, mask>0)
index = mindex.repeat(batch_size,in_channels*kh*kw,1)
inp_unf = torch.gather(inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h*afterconv_w),2,index.view(batch_size,in_channels*kh*kw,index.shape[2])).view(batch_size,in_channels*kh*kw,index.shape[2])
#Matrix mul
if self.bias is None:
@@ -89,7 +95,9 @@ class SConv2dAvg(nn.Module):
out = F.avg_pool2d(out,self.stride,ceil_mode=True)
else:#in case of mask
out = torch.zeros(batch_size, out_channels,out_h,out_w,device=device)
#out = torch.gather(out.view(batch_size,in_channels*kh*kw,afterconv_h*afterconv_w),2,index.view(batch_size,in_channels*kh*kw,index.shape[2])).view(batch_size,in_channels*kh*kw,index.shape[2])
out[:,:,mask>0] = out_unf
#out.masked_scatter_(mask>0, out_unf)
return out
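The masked branch above reduces to two tensor operations: torch.gather pulls one unfolded column per active output location, and boolean-mask assignment scatters the sparse results back into a dense map. A toy-shaped sketch (all sizes and names here are illustrative):

import torch

B, CKK, C = 2, 27, 4                        # batch, in_channels*kh*kw, out_channels
inp_unf = torch.randn(B, CKK, 36)           # unfolded columns over a 6x6 output grid
mask = torch.rand(6, 6) > 0.5               # active output locations
flat = mask.view(-1).nonzero().squeeze(1)   # flat indices of the active locations
idx = flat.repeat(B, CKK, 1)                # same positions for every batch/channel row
picked = torch.gather(inp_unf, 2, idx)      # (B, CKK, n_active) selected columns

out_unf = torch.randn(B, C, flat.numel())   # stand-in for the matmul result
out = torch.zeros(B, C, 6, 6)
out[:, :, mask] = out_unf                   # dense scatter, as in out[:,:,mask>0] = out_unf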
@@ -167,7 +175,7 @@ class SConv2dAvg(nn.Module):
return out
def forward(self, input, selh=-torch.ones(1,1), selw=-torch.ones(1,1), mask=-torch.ones(1,1),stoch=True,stride=-1):
def forward_slow(self, input, selh=-torch.ones(1,1), selw=-torch.ones(1,1), mask=-torch.ones(1,1),stoch=True,stride=-1):
device=input.device
if stride==-1:
stride = self.stride #if stride not defined use self.stride
@@ -373,23 +381,28 @@ class SConv2dAvg(nn.Module):
resth = (out_h*stride)-afterconv_h
restw = (out_w*stride)-afterconv_w
if resth!=0:
print("stride",stride,"str-rest",stride-resth,stride-restw)
print('before',sel[-1],sel[:,-1])
sel[-1] = (sel[-1]//stride)%(stride-resth)*stride+(sel[-1]%stride)
sel[:,-1] = (sel[:,-1]%stride)%(stride-restw)+sel[:,-1]//stride*stride
print('after',sel[-1],sel[:,-1])
input()
rng = torch.arange(0,out_h*stride*out_w*stride,stride*stride,device=device).view(out_h,out_w)
index = sel+rng
index = index.repeat(batch_size,in_channels*kh*kw,1,1)
#index = index.repeat(batch_size,in_channels*kh*kw,1,1)
#inp_unf = torch.gather(inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h*afterconv_w),2,index.view(batch_size,in_channels*kh*kw,out_h*out_w)).view(batch_size,in_channels*kh*kw,out_h*out_w)
if mask[0,0]!=-1:
maskh = (out_h)*stride
maskw = (out_w)*stride
nmask = torch.zeros((maskh,maskw),device=device)
nmask[rng_h,rng_w] = 1
nmask = torch.zeros((maskh,maskw),device=device).view(-1)
#inp_unf = torch.gather(inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h*afterconv_w),2,index.view(batch_size,in_channels*kh*kw,out_h*out_w)).view(batch_size,in_channels*kh*kw,out_h*out_w)
nmask[index] = 1
#rmask = mask * nmask
dmask = self.pooldeconv(mask.float().view(1,1,mask.shape[0],mask.shape[1]))
rmask = nmask * dmask
rmask = nmask.view(1,1,maskh,maskw) * dmask
#rmask = rmask[:,:,:out_h,:out_w]
# print('rmask', rmask.shape)
fmask = self.deconv(rmask)
@@ -398,6 +411,22 @@ class SConv2dAvg(nn.Module):
return index,mask#.long()
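How sample appears to push a mask one layer down, sketched with standalone all-ones transposed convolutions mirroring pooldeconv and deconv above (the toy sizes and the final thresholding are assumptions):

import torch
import torch.nn as nn

stride, k = 2, 3
pooldeconv = nn.ConvTranspose2d(1, 1, kernel_size=stride, stride=stride, bias=False)
nn.init.constant_(pooldeconv.weight, 1)  # all-ones weights: pure upsampling by stride
deconv = nn.ConvTranspose2d(1, 1, kernel_size=k, stride=1, bias=False)
nn.init.constant_(deconv.weight, 1)      # all-ones weights: dilate by the 3x3 receptive field

out_mask = torch.ones(1, 1, 4, 4)        # mask over the layer's output
dmask = pooldeconv(out_mask)             # (1, 1, 8, 8): stride-expanded grid
in_mask = (deconv(dmask) > 0).long()     # (1, 1, 10, 10): input positions that matter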
def get_mask(self,in_h,in_w,batch_size,device,index,mask=-torch.ones(1,1)):
stride = self.stride #use the module stride
_,_,out_h,out_w = self.get_size(in_h,in_w) #output resolution for this input size
maskh = (out_h)*stride
maskw = (out_w)*stride
nmask = torch.zeros((maskh,maskw),device=device).view(-1)
#inp_unf = torch.gather(inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h*afterconv_w),2,index.view(batch_size,in_channels*kh*kw,out_h*out_w)).view(batch_size,in_channels*kh*kw,out_h*out_w)
nmask[index[0,0]] = 1
#rmask = mask * nmask
dmask = self.pooldeconv(mask.float().view(1,1,mask.shape[0],mask.shape[1]))
rmask = nmask.view(1,1,maskh,maskw) * dmask
#rmask = rmask[:,:,:out_h,:out_w]
# print('rmask', rmask.shape)
fmask = self.deconv(rmask)
# print('fmask', fmask.shape)
mask = fmask[0,0].long()
return mask
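For reference, the size arithmetic get_size presumably performs for these 3x3 convolutions; the formula below is the standard one, and the four-value return convention is only what the BU forward above suggests:

import math

def conv_out_size(h, w, k=3, stride=2, padding=1, ceil_mode=False):
    # standard convolution output-size formula; ceil_mode switches floor to ceil
    f = math.ceil if ceil_mode else math.floor
    return f((h + 2*padding - k) / stride) + 1, f((w + 2*padding - k) / stride) + 1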
def get_size(self,in_h,in_w,stride=-1):