From e2db0e6057eb68694013c901ef59bd56eaf471ca Mon Sep 17 00:00:00 2001
From: Marco Pedersoli
Date: Thu, 18 Jun 2020 21:59:19 -0400
Subject: [PATCH] added faster BU

---
 models/mylenet4.py | 82 ++++++++++++++++++++++++++++++++++++++++------
 models/stoch.py    | 51 +++++++++++++++++++++-------
 2 files changed, 112 insertions(+), 21 deletions(-)

diff --git a/models/mylenet4.py b/models/mylenet4.py
index 267e361..c0b8bb4 100644
--- a/models/mylenet4.py
+++ b/models/mylenet4.py
@@ -101,14 +101,36 @@ class MyLeNetMatNormal(nn.Module):#epach 21s
         #out = (self.fc1(out))
         return out
 
-class MyLeNetMatStochNoceil(nn.Module):#epoch 17s
-    def __init__(self):
+class MyLeNetMatNormalNoceil(nn.Module):#epoch 136s 16GB
+    def __init__(self,k=3):
+        super(MyLeNetMatNormalNoceil, self).__init__()
+        self.conv1 = SConv2dAvg(3, 200*k, 3, stride=1,padding=1,ceil_mode=False)
+        self.conv2 = SConv2dAvg(200*k, 400*k, 3, stride=1,padding=1,ceil_mode=False)
+        self.conv3 = SConv2dAvg(400*k, 800*k, 3, stride=1,padding=1,ceil_mode=False)
+        self.conv4 = SConv2dAvg(800*k, 1600*k, 3, stride=1,padding=1,ceil_mode=False)
+        self.fc1 = nn.Linear(1600*k, 10)
+
+    def forward(self, x, stoch=True):
+        out = F.relu(self.conv1(x,stoch=stoch))
+        out = F.avg_pool2d(out,2,ceil_mode=True)
+        out = F.relu(self.conv2(out,stoch=stoch))
+        out = F.avg_pool2d(out,2,ceil_mode=True)
+        out = F.relu(self.conv3(out,stoch=stoch))
+        out = F.avg_pool2d(out,2,ceil_mode=True)
+        out = F.relu(self.conv4(out,stoch=stoch))
+        out = F.avg_pool2d(out,4,ceil_mode=True)
+        out = out.view(out.size(0), -1 )
+        out = self.fc1(out)
+        return out
+
+class MyLeNetMatStochNoceil(nn.Module):#epoch 41s 16GB
+    def __init__(self,k=3):
         super(MyLeNetMatStochNoceil, self).__init__()
-        self.conv1 = SConv2dAvg(3, 200, 3, stride=2,padding=1,ceil_mode=False)
-        self.conv2 = SConv2dAvg(200, 400, 3, stride=2,padding=1,ceil_mode=False)
-        self.conv3 = SConv2dAvg(400, 800, 3, stride=2,padding=1,ceil_mode=False)
-        self.conv4 = SConv2dAvg(800, 10, 3, stride=4,padding=1,ceil_mode=False)
-        #self.fc1 = nn.Linear(800, 10)
+        self.conv1 = SConv2dAvg(3, 200*k, 3, stride=2,padding=1,ceil_mode=False)
+        self.conv2 = SConv2dAvg(200*k, 400*k, 3, stride=2,padding=1,ceil_mode=False)
+        self.conv3 = SConv2dAvg(400*k, 800*k, 3, stride=2,padding=1,ceil_mode=False)
+        self.conv4 = SConv2dAvg(800*k, 1600*k, 3, stride=4,padding=1,ceil_mode=False)
+        self.fc1 = nn.Linear(1600*k, 10)
 
     def forward(self, x, stoch=True):
         #print('in',x.shape)
@@ -118,10 +140,10 @@ class MyLeNetMatStochNoceil(nn.Module):#epoch 17s
         #print('c2', out.shape)
         out = F.relu(self.conv3(out,stoch=stoch))
         #print('c3',out.shape)
-        out = self.conv4(out,stoch=stoch)
+        out = F.relu(self.conv4(out,stoch=stoch))
         #print('c4',out.shape)
         out = out.view(out.size(0), -1 )
-        #out = self.fc1(out)
+        out = self.fc1(out)
         return out
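A minimal sketch of the idea behind these layers (not part of the patch; a plain-PyTorch stand-in, with a CIFAR-sized 32x32 input assumed): with stoch=False an SConv2dAvg of stride 2 reduces to a stride-1 convolution followed by 2x2 average pooling, while the stochastic path keeps one randomly chosen position per 2x2 block, which equals the pooled value in expectation:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    torch.manual_seed(0)
    x = torch.randn(2, 3, 32, 32)                  # assumed CIFAR-sized input
    conv = nn.Conv2d(3, 8, 3, stride=1, padding=1)
    y = F.relu(conv(x))                            # (2, 8, 32, 32)

    det = F.avg_pool2d(y, 2)                       # deterministic path: (2, 8, 16, 16)

    # stochastic path: one random position per 2x2 block, shared over batch/channels
    blocks = y.unfold(2, 2, 2).unfold(3, 2, 2).reshape(2, 8, 16, 16, 4)
    pick = torch.randint(0, 4, (16, 16))           # one offset per output location
    idx = pick.view(1, 1, 16, 16, 1).expand(2, 8, 16, 16, 1)
    sto = blocks.gather(4, idx).squeeze(4)         # (2, 8, 16, 16); E[sto] = det

    print(det.shape, sto.shape)

The gather-with-a-shared-index pattern is the same one the new forward() below applies to the unfolded input.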
@@ -148,10 +170,50 @@ class MyLeNetMatStoch(nn.Module):#epoch 17s
         #print('c4',out.shape)
         out = out.view(out.size(0), -1 )
         #out = self.fc1(out)
         return out
 
+class MyLeNetMatStochBUNoceil(nn.Module):#30.5s 14GB
+    def __init__(self,k=3):
+        super(MyLeNetMatStochBUNoceil, self).__init__()
+        self.conv1 = SConv2dAvg(3, 200*k, 3, stride=2,padding=1,ceil_mode=False)
+        self.conv2 = SConv2dAvg(200*k, 400*k, 3, stride=2,padding=1,ceil_mode=False)
+        self.conv3 = SConv2dAvg(400*k, 800*k, 3, stride=2,padding=1,ceil_mode=False)
+        self.conv4 = SConv2dAvg(800*k, 1600*k, 3, stride=4,padding=1,ceil_mode=False)
+        self.fc1 = nn.Linear(1600*k, 10)
+
+    def forward(self, x, stoch=True):
+        #get sizes
+        batch_size = x.shape[0]
+        device = x.device
+        h0,w0 = x.shape[2],x.shape[3]
+        _,_,h1,w1 = self.conv1.get_size(h0,w0)
+        _,_,h2,w2 = self.conv2.get_size(h1,w1)
+        _,_,h3,w3 = self.conv3.get_size(h2,w2)
+        _,_,h4,w4 = self.conv4.get_size(h3,w3)
+        # print(h0,w0)
+        # print(h1,w1)
+        # print(h2,w2)
+        # print(h3,w3)
+
+        #sample BU: sample the last layer first, then propagate the induced
+        #masks backward so earlier layers only compute the needed positions
+        mask4 = torch.ones(h4,w4).to(x.device)
+        # print(mask4.shape)
+        index4,mask3 = self.conv4.sample(h3,w3,batch_size,device,mask4)
+        index3,mask2 = self.conv3.sample(h2,w2,batch_size,device,mask3)
+        index2,mask1 = self.conv2.sample(h1,w1,batch_size,device,mask2)
+        index1,mask0 = self.conv1.sample(h0,w0,batch_size,device,mask1)
+
+        ##forward
+        out = F.relu(self.conv1(x,index1,mask1,stoch=stoch))
+        out = F.relu(self.conv2(out,index2,mask2,stoch=stoch))
+        out = F.relu(self.conv3(out,index3,mask3,stoch=stoch))
+        out = F.relu(self.conv4(out,index4,mask4,stoch=stoch))
+        out = out.view(out.size(0), -1 )
+        out = self.fc1(out)
+        return out
+
 class MyLeNetMatStochBU(nn.Module):#epoch 11s
-    def __init__(self):
+    def __init__(self,k=1): #k was undefined below; k=1 keeps conv1's 200*k output matching conv2's 200-channel input
         super(MyLeNetMatStochBU, self).__init__()
-        self.conv1 = SConv2dAvg(3, 200, 3, stride=2)
+        self.conv1 = SConv2dAvg(3, 200*k, 3, stride=2)
         self.conv2 = SConv2dAvg(200, 400, 3, stride=2)
         self.conv3 = SConv2dAvg(400, 800, 3, stride=2, ceil_mode=True)
         self.conv4 = SConv2dAvg(800, 10, 3, stride=1)
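The point of the bottom-up ("BU") order above, and of the "faster BU" in the commit title, is that selections are drawn for conv4 first and each sample() call returns the mask of input positions that the already-chosen selections actually read, so each earlier layer computes only those. A toy version of that order (a sketch with one position per stride block; sample_layer is a hypothetical stand-in that ignores the kernel-footprint growth handled by sample()'s transposed convolutions):

    import torch

    torch.manual_seed(0)

    def sample_layer(out_h, out_w, stride, mask_out):
        # pick one input position inside each stride x stride block, then
        # mark which input positions the needed outputs actually read
        sel = torch.randint(0, stride * stride, (out_h, out_w))
        mask_in = torch.zeros(out_h * stride, out_w * stride)
        oy, ox = torch.nonzero(mask_out, as_tuple=True)   # outputs someone needs
        dy, dx = sel[oy, ox] // stride, sel[oy, ox] % stride
        mask_in[oy * stride + dy, ox * stride + dx] = 1
        return sel, mask_in

    # two stride-2 stages, 16x16 -> 8x8 -> 4x4; sample the LAST stage first
    mask2 = torch.ones(4, 4)                      # all final outputs are needed
    sel2, mask1 = sample_layer(4, 4, 2, mask2)
    sel1, mask0 = sample_layer(8, 8, 2, mask1)
    print(int(mask1.sum()), int(mask0.sum()))     # 16 of 64, then 16 of 256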
diff --git a/models/stoch.py b/models/stoch.py
index 936afe4..931714b 100644
--- a/models/stoch.py
+++ b/models/stoch.py
@@ -31,7 +31,7 @@ class SConv2dAvg(nn.Module):
     def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1,ceil_mode=True, bias = True):
         super(SConv2dAvg, self).__init__()
         conv = nn.Conv2d(in_channels, out_channels, kernel_size)
-        self.deconv = nn.ConvTranspose2d(1, 1, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=False, dilation=1, padding_mode='zeros')
+        self.deconv = nn.ConvTranspose2d(1, 1, kernel_size, stride=1, padding=padding, output_padding=0, groups=1, bias=False, dilation=1, padding_mode='zeros')
         nn.init.constant_(self.deconv.weight, 1)
         self.pooldeconv = nn.ConvTranspose2d(1, 1, kernel_size=stride,padding=0,stride=stride, output_padding=0, groups=1, bias=False, dilation=1, padding_mode='zeros')
         nn.init.constant_(self.pooldeconv.weight, 1)
@@ -47,7 +47,7 @@
 
         self.ceil_mode = ceil_mode
 
-    def forward_fast(self, input, index=-torch.ones(1), mask=-torch.ones(1,1),stoch=True,stride=-1): #ceil_mode = True not right
+    def forward(self, input, index=-torch.ones(1), mask=-torch.ones(1,1),stoch=True,stride=-1): #ceil_mode = True not right
         device=input.device
         if stride==-1:
             stride = self.stride #if stride not defined use self.stride
@@ -60,14 +60,20 @@
         unfold = torch.nn.Unfold(kernel_size=(kh, kw), dilation=self.dilation, padding=self.padding, stride=1)
         inp_unf = unfold(input) #transform into a matrix (batch_size, in_channels*kh*kw, afterconv_h*afterconv_w)
 
-        if index[0]==-1 and stride!=1: #or stride!=1:
-            index,mask = self.sample(in_h,in_w,batch_size,device,mask)
-
+        if stride!=1:
+            if len(index.shape)==1: #no precomputed index given: sample one here
+                index,mask = self.sample(in_h,in_w,batch_size,device,mask)
+            if mask[0,0]==-1:# in case of not given mask use only sampled selection
                 #inp_unf = inp_unf[:,:,rng_h,rng_w].view(batch_size,in_channels*kh*kw,-1)
+                index = index.repeat(batch_size,in_channels*kh*kw,1,1)
                 inp_unf = torch.gather(inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h*afterconv_w),2,index.view(batch_size,in_channels*kh*kw,out_h*out_w)).view(batch_size,in_channels*kh*kw,out_h*out_w)
             else:#in case of a valid mask use selection only on the mask locations
-                inp_unf = inp_unf[:,:,rng_h[mask>0],rng_w[mask>0]]
+                #inp_unf = inp_unf[index[:,:,mask>0]]
+                #mindex = index[mask>0]
+                mindex = torch.masked_select(index, mask>0)
+                index = mindex.repeat(batch_size,in_channels*kh*kw,1)
+                inp_unf = torch.gather(inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h*afterconv_w),2,index.view(batch_size,in_channels*kh*kw,index.shape[2])).view(batch_size,in_channels*kh*kw,index.shape[2])
 
         #Matrix mul
         if self.bias is None:
@@ -89,7 +95,9 @@
                 out = F.avg_pool2d(out,self.stride,ceil_mode=True)
         else:#in case of mask
             out = torch.zeros(batch_size, out_channels,out_h,out_w,device=device)
+            #out = torch.gather(out.view(batch_size,in_channels*kh*kw,afterconv_h*afterconv_w),2,index.view(batch_size,in_channels*kh*kw,index.shape[2])).view(batch_size,in_channels*kh*kw,index.shape[2])
             out[:,:,mask>0] = out_unf
+            #out.masked_scatter_(mask>0, out_unf)
 
         return out
 
@@ -167,7 +175,7 @@
 
         return out
 
-    def forward(self, input, selh=-torch.ones(1,1), selw=-torch.ones(1,1), mask=-torch.ones(1,1),stoch=True,stride=-1):
+    def forward_slow(self, input, selh=-torch.ones(1,1), selw=-torch.ones(1,1), mask=-torch.ones(1,1),stoch=True,stride=-1):
         device=input.device
         if stride==-1:
             stride = self.stride #if stride not defined use self.stride
@@ -373,23 +381,28 @@
         resth = (out_h*stride)-afterconv_h
         restw = (out_w*stride)-afterconv_w
         if resth!=0:
+            print("stride",stride,"str-rest",stride-resth,stride-restw) #debug
+            print('before',sel[-1],sel[:,-1]) #debug
             sel[-1] = (sel[-1]//stride)%(stride-resth)*stride+(sel[-1]%stride)
             sel[:,-1] = (sel[:,-1]%stride)%(stride-restw)+sel[:,-1]//stride*stride
+            print('after',sel[-1],sel[:,-1]) #debug
+            #input() #debug pause; disabled so the layer cannot stall a run
         rng = torch.arange(0,out_h*stride*out_w*stride,stride*stride,device=device).view(out_h,out_w)
         index = sel+rng
-        index = index.repeat(batch_size,in_channels*kh*kw,1,1)
+        #index = index.repeat(batch_size,in_channels*kh*kw,1,1) #the repeat now happens in forward()
         #inp_unf = torch.gather(inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h*afterconv_w),2,index.view(batch_size,in_channels*kh*kw,out_h*out_w)).view(batch_size,in_channels*kh*kw,out_h*out_w)
         if mask[0,0]!=-1:
             maskh = (out_h)*stride
             maskw = (out_w)*stride
-            nmask = torch.zeros((maskh,maskw),device=device)
-            nmask[rng_h,rng_w] = 1
+            nmask = torch.zeros((maskh,maskw),device=device).view(-1)
+            nmask[index] = 1 #flag the selected (flattened) input positions
             #rmask = mask * nmask
             dmask = self.pooldeconv(mask.float().view(1,1,mask.shape[0],mask.shape[1]))
-            rmask = nmask * dmask
+            rmask = nmask.view(1,1,maskh,maskw) * dmask
             #rmask = rmask[:,:,:out_h,:out_w]
             # print('rmask', rmask.shape)
             fmask = self.deconv(rmask)
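The two all-ones transposed convolutions built in __init__ act as footprint stamps in the mask arithmetic above, not as learned layers: pooldeconv spreads each required output over its stride x stride input block, and deconv then dilates the result by the kernel's reach. A self-contained sketch (padding left at 0 for clarity):

    import torch
    import torch.nn as nn

    stride, k = 2, 3
    pooldeconv = nn.ConvTranspose2d(1, 1, kernel_size=stride, stride=stride, bias=False)
    deconv = nn.ConvTranspose2d(1, 1, kernel_size=k, stride=1, bias=False)
    nn.init.constant_(pooldeconv.weight, 1)
    nn.init.constant_(deconv.weight, 1)

    mask = torch.zeros(1, 1, 4, 4)
    mask[0, 0, 1, 2] = 1                          # a single needed output
    with torch.no_grad():
        up = pooldeconv(mask)                     # (1,1,8,8): its 2x2 input block
        full = deconv(up)                         # (1,1,10,10): plus the 3x3 kernel reach
    print((up > 0).sum().item(), (full > 0).sum().item())   # 4, then 16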
@@ -398,6 +411,22 @@
 
         return index,mask#.long()
 
+    def get_mask(self,index,in_h,in_w,batch_size,device,mask=-torch.ones(1,1)):
+        #mask propagation for a precomputed index, same logic as the tail of
+        #sample(); index is the flattened selection returned by sample(), and
+        #stride and the output sizes are taken from the layer configuration
+        stride = self.stride
+        _,_,out_h,out_w = self.get_size(in_h,in_w)
+        maskh = out_h*stride
+        maskw = out_w*stride
+        nmask = torch.zeros((maskh,maskw),device=device).view(-1)
+        nmask[index] = 1
+        #rmask = mask * nmask
+        dmask = self.pooldeconv(mask.float().view(1,1,mask.shape[0],mask.shape[1]))
+        rmask = nmask.view(1,1,maskh,maskw) * dmask
+        # print('rmask', rmask.shape)
+        fmask = self.deconv(rmask)
+        # print('fmask', fmask.shape)
+        mask = fmask[0,0].long()
+        return mask
+
     def get_size(self,in_h,in_w,stride=-1):
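get_size's body lies outside this hunk, but its call sites (`_,_,h1,w1 = self.conv1.get_size(h0,w0)` in MyLeNetMatStochBUNoceil.forward) suggest it returns (afterconv_h, afterconv_w, out_h, out_w). A plausible reconstruction, consistent with `resth = (out_h*stride)-afterconv_h` in sample(); this is an assumption, not the repository's code:

    import math

    def get_size(in_h, in_w, kernel=3, padding=1, stride=2, ceil_mode=False):
        # afterconv: stride-1 convolution output; out: after stride-s subsampling
        afterconv_h = in_h + 2 * padding - (kernel - 1)   # dilation=1 assumed
        afterconv_w = in_w + 2 * padding - (kernel - 1)
        rnd = math.ceil if ceil_mode else math.floor
        out_h = int(rnd(afterconv_h / stride))
        out_w = int(rnd(afterconv_w / stride))
        return afterconv_h, afterconv_w, out_h, out_w

    print(get_size(32, 32))   # (32, 32, 16, 16) for a 3x3/pad-1/stride-2 layer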