added faster BU

This commit is contained in:
Marco Pedersoli 2020-06-18 21:59:19 -04:00
parent 286921f8a0
commit e2db0e6057
2 changed files with 112 additions and 21 deletions

View file

@ -101,14 +101,36 @@ class MyLeNetMatNormal(nn.Module):#epach 21s
#out = (self.fc1(out))
return out
class MyLeNetMatStochNoceil(nn.Module):#epoch 17s
def __init__(self):
class MyLeNetMatNormalNoceil(nn.Module):  # epoch 136s 16GB
    """Four-stage ``SConv2dAvg`` classifier with explicit average pooling.

    Every convolution is built with stride 1, padding 1 and
    ``ceil_mode=False``. Spatial reduction is performed by
    ``F.avg_pool2d`` (kernel sizes 2, 2, 2 and 4, all with
    ``ceil_mode=True``) applied after each ReLU. The result is flattened
    and passed to a linear layer with 10 outputs.

    Args:
        k: Width multiplier; the four stages produce 200*k, 400*k, 800*k
            and 1600*k channels respectively.
    """

    def __init__(self, k=3):
        super(MyLeNetMatNormalNoceil, self).__init__()
        channels = (3, 200 * k, 400 * k, 800 * k, 1600 * k)
        conv_opts = dict(stride=1, padding=1, ceil_mode=False)
        # Layers are created in the same order as before so parameter
        # initialisation consumes the RNG identically.
        self.conv1 = SConv2dAvg(channels[0], channels[1], 3, **conv_opts)
        self.conv2 = SConv2dAvg(channels[1], channels[2], 3, **conv_opts)
        self.conv3 = SConv2dAvg(channels[2], channels[3], 3, **conv_opts)
        self.conv4 = SConv2dAvg(channels[3], channels[4], 3, **conv_opts)
        self.fc1 = nn.Linear(channels[4], 10)

    def forward(self, x, stoch=True):
        """Run the network on ``x``.

        Args:
            x: Input batch; fed directly to ``conv1`` (3 input channels).
            stoch: Forwarded unchanged to every ``SConv2dAvg`` call.

        Returns:
            Output of ``fc1``. ``fc1`` takes 1600*k features, so the
            flattened feature map is expected to have exactly that many
            elements per sample.
        """
        # (layer, pooling kernel) applied in sequence.
        stages = (
            (self.conv1, 2),
            (self.conv2, 2),
            (self.conv3, 2),
            (self.conv4, 4),
        )
        feat = x
        for conv, pool_kernel in stages:
            feat = F.relu(conv(feat, stoch=stoch))
            feat = F.avg_pool2d(feat, pool_kernel, ceil_mode=True)
        flat = feat.view(feat.size(0), -1)
        return self.fc1(flat)
class MyLeNetMatStochNoceil(nn.Module):#epoch 41s 16GB
def __init__(self,k=3):
super(MyLeNetMatStochNoceil, self).__init__()
self.conv1 = SConv2dAvg(3, 200, 3, stride=2,padding=1,ceil_mode=False)
self.conv2 = SConv2dAvg(200, 400, 3, stride=2,padding=1,ceil_mode=False)
self.conv3 = SConv2dAvg(400, 800, 3, stride=2,padding=1,ceil_mode=False)
self.conv4 = SConv2dAvg(800, 10, 3, stride=4,padding=1,ceil_mode=False)
#self.fc1 = nn.Linear(800, 10)
self.conv1 = SConv2dAvg(3, 200*k, 3, stride=2,padding=1,ceil_mode=False)
self.conv2 = SConv2dAvg(200*k, 400*k, 3, stride=2,padding=1,ceil_mode=False)
self.conv3 = SConv2dAvg(400*k, 800*k, 3, stride=2,padding=1,ceil_mode=False)
self.conv4 = SConv2dAvg(800*k, 1600*k, 3, stride=4,padding=1,ceil_mode=False)
self.fc1 = nn.Linear(1600*k, 10)
def forward(self, x, stoch=True):
#print('in',x.shape)
@ -118,10 +140,10 @@ class MyLeNetMatStochNoceil(nn.Module):#epoch 17s
#print('c2', out.shape)
out = F.relu(self.conv3(out,stoch=stoch))
#print('c3',out.shape)
out = self.conv4(out,stoch=stoch)
out = F.relu(self.conv4(out,stoch=stoch))
#print('c4',out.shape)
out = out.view(out.size(0), -1 )
#out = self.fc1(out)
out = self.fc1(out)
return out
@ -148,10 +170,50 @@ class MyLeNetMatStoch(nn.Module):#epoch 17s
#out = self.fc1(out)
return out
class MyLeNetMatStochBUNoceil(nn.Module):  # 30.5s 14GB
    """Four-stage ``SConv2dAvg`` classifier with pre-sampled indices/masks.

    Before the forward pass, sampling indices and masks are drawn for all
    layers, starting from the last layer with an all-ones mask and feeding
    the mask returned by each ``sample`` call into the preceding layer.
    The convolutions are then applied in order with those indices/masks.

    Args:
        k: Width multiplier; the four stages produce 200*k, 400*k, 800*k
            and 1600*k channels respectively.
    """

    def __init__(self, k=3):
        super(MyLeNetMatStochBUNoceil, self).__init__()
        self.conv1 = SConv2dAvg(3, 200 * k, 3, stride=2, padding=1, ceil_mode=False)
        self.conv2 = SConv2dAvg(200 * k, 400 * k, 3, stride=2, padding=1, ceil_mode=False)
        self.conv3 = SConv2dAvg(400 * k, 800 * k, 3, stride=2, padding=1, ceil_mode=False)
        self.conv4 = SConv2dAvg(800 * k, 1600 * k, 3, stride=4, padding=1, ceil_mode=False)
        self.fc1 = nn.Linear(1600 * k, 10)

    def forward(self, x, stoch=True):
        """Run the network on ``x``.

        Args:
            x: Input batch; dimension 0 is used as the batch size and
                dimensions 2 and 3 as the spatial size.
            stoch: Forwarded unchanged to every ``SConv2dAvg`` call.

        Returns:
            Output of ``fc1`` applied to the flattened ``conv4`` activation.
        """
        batch_size = x.shape[0]
        device = x.device

        # Spatial size after each layer; only the last two entries of the
        # tuple returned by ``get_size`` are used.
        h0, w0 = x.shape[2], x.shape[3]
        _, _, h1, w1 = self.conv1.get_size(h0, w0)
        _, _, h2, w2 = self.conv2.get_size(h1, w1)
        _, _, h3, w3 = self.conv3.get_size(h2, w2)
        _, _, h4, w4 = self.conv4.get_size(h3, w3)

        # Sample from the last layer back to the first. The all-ones mask
        # is allocated directly on the input's device instead of being
        # created on the host and copied.
        mask4 = torch.ones(h4, w4, device=device)
        index4, mask3 = self.conv4.sample(h3, w3, batch_size, device, mask4)
        index3, mask2 = self.conv3.sample(h2, w2, batch_size, device, mask3)
        index2, mask1 = self.conv2.sample(h1, w1, batch_size, device, mask2)
        # The mask returned for the network input is not needed.
        index1, _ = self.conv1.sample(h0, w0, batch_size, device, mask1)

        # Forward pass using the pre-sampled indices and masks.
        out = F.relu(self.conv1(x, index1, mask1, stoch=stoch))
        out = F.relu(self.conv2(out, index2, mask2, stoch=stoch))
        out = F.relu(self.conv3(out, index3, mask3, stoch=stoch))
        out = F.relu(self.conv4(out, index4, mask4, stoch=stoch))
        out = out.view(out.size(0), -1)
        out = self.fc1(out)
        return out
class MyLeNetMatStochBU(nn.Module):#epoch 11s
def __init__(self):
super(MyLeNetMatStochBU, self).__init__()
self.conv1 = SConv2dAvg(3, 200, 3, stride=2)
self.conv1 = SConv2dAvg(3, 200*k, 3, stride=2)
self.conv2 = SConv2dAvg(200, 400, 3, stride=2)
self.conv3 = SConv2dAvg(400, 800, 3, stride=2, ceil_mode=True)
self.conv4 = SConv2dAvg(800, 10, 3, stride=1)