mirror of
https://github.com/AntoineHX/BU_Stoch_pool.git
synced 2025-05-04 09:40:46 +02:00
added faster BU
This commit is contained in:
parent
286921f8a0
commit
e2db0e6057
2 changed files with 112 additions and 21 deletions
|
@ -101,14 +101,36 @@ class MyLeNetMatNormal(nn.Module):#epach 21s
|
|||
#out = (self.fc1(out))
|
||||
return out
|
||||
|
||||
class MyLeNetMatStochNoceil(nn.Module):#epoch 17s
|
||||
def __init__(self):
|
||||
class MyLeNetMatNormalNoceil(nn.Module):#epoch 136s 16GB
|
||||
def __init__(self,k=3):
|
||||
super(MyLeNetMatNormalNoceil, self).__init__()
|
||||
self.conv1 = SConv2dAvg(3, 200*k, 3, stride=1,padding=1,ceil_mode=False)
|
||||
self.conv2 = SConv2dAvg(200*k, 400*k, 3, stride=1,padding=1,ceil_mode=False)
|
||||
self.conv3 = SConv2dAvg(400*k, 800*k, 3, stride=1,padding=1,ceil_mode=False)
|
||||
self.conv4 = SConv2dAvg(800*k, 1600*k, 3, stride=1,padding=1,ceil_mode=False)
|
||||
self.fc1 = nn.Linear(1600*k, 10)
|
||||
|
||||
def forward(self, x, stoch=True):
|
||||
out = F.relu(self.conv1(x,stoch=stoch))
|
||||
out = F.avg_pool2d(out,2,ceil_mode=True)
|
||||
out = F.relu(self.conv2(out,stoch=stoch))
|
||||
out = F.avg_pool2d(out,2,ceil_mode=True)
|
||||
out = F.relu(self.conv3(out,stoch=stoch))
|
||||
out = F.avg_pool2d(out,2,ceil_mode=True)
|
||||
out = F.relu(self.conv4(out,stoch=stoch))
|
||||
out = F.avg_pool2d(out,4,ceil_mode=True)
|
||||
out = out.view(out.size(0), -1 )
|
||||
out = self.fc1(out)
|
||||
return out
|
||||
|
||||
class MyLeNetMatStochNoceil(nn.Module):#epoch 41s 16BG
|
||||
def __init__(self,k=3):
|
||||
super(MyLeNetMatStochNoceil, self).__init__()
|
||||
self.conv1 = SConv2dAvg(3, 200, 3, stride=2,padding=1,ceil_mode=False)
|
||||
self.conv2 = SConv2dAvg(200, 400, 3, stride=2,padding=1,ceil_mode=False)
|
||||
self.conv3 = SConv2dAvg(400, 800, 3, stride=2,padding=1,ceil_mode=False)
|
||||
self.conv4 = SConv2dAvg(800, 10, 3, stride=4,padding=1,ceil_mode=False)
|
||||
#self.fc1 = nn.Linear(800, 10)
|
||||
self.conv1 = SConv2dAvg(3, 200*k, 3, stride=2,padding=1,ceil_mode=False)
|
||||
self.conv2 = SConv2dAvg(200*k, 400*k, 3, stride=2,padding=1,ceil_mode=False)
|
||||
self.conv3 = SConv2dAvg(400*k, 800*k, 3, stride=2,padding=1,ceil_mode=False)
|
||||
self.conv4 = SConv2dAvg(800*k, 1600*k, 3, stride=4,padding=1,ceil_mode=False)
|
||||
self.fc1 = nn.Linear(1600*k, 10)
|
||||
|
||||
def forward(self, x, stoch=True):
|
||||
#print('in',x.shape)
|
||||
|
@ -118,10 +140,10 @@ class MyLeNetMatStochNoceil(nn.Module):#epoch 17s
|
|||
#print('c2', out.shape)
|
||||
out = F.relu(self.conv3(out,stoch=stoch))
|
||||
#print('c3',out.shape)
|
||||
out = self.conv4(out,stoch=stoch)
|
||||
out = F.relu(self.conv4(out,stoch=stoch))
|
||||
#print('c4',out.shape)
|
||||
out = out.view(out.size(0), -1 )
|
||||
#out = self.fc1(out)
|
||||
out = self.fc1(out)
|
||||
return out
|
||||
|
||||
|
||||
|
@ -148,10 +170,50 @@ class MyLeNetMatStoch(nn.Module):#epoch 17s
|
|||
#out = self.fc1(out)
|
||||
return out
|
||||
|
||||
class MyLeNetMatStochBUNoceil(nn.Module):#30.5s 14GB
|
||||
def __init__(self,k=3):
|
||||
super(MyLeNetMatStochBUNoceil, self).__init__()
|
||||
self.conv1 = SConv2dAvg(3, 200*k, 3, stride=2,padding=1,ceil_mode=False)
|
||||
self.conv2 = SConv2dAvg(200*k, 400*k, 3, stride=2,padding=1,ceil_mode=False)
|
||||
self.conv3 = SConv2dAvg(400*k, 800*k, 3, stride=2,padding=1,ceil_mode=False)
|
||||
self.conv4 = SConv2dAvg(800*k, 1600*k, 3, stride=4,padding=1,ceil_mode=False)
|
||||
self.fc1 = nn.Linear(1600*k, 10)
|
||||
|
||||
def forward(self, x, stoch=True):
|
||||
#get sizes
|
||||
batch_size = x.shape[0]
|
||||
device = x.device
|
||||
h0,w0 = x.shape[2],x.shape[3]
|
||||
_,_,h1,w1 = self.conv1.get_size(h0,w0)
|
||||
_,_,h2,w2 = self.conv2.get_size(h1,w1)
|
||||
_,_,h3,w3 = self.conv3.get_size(h2,w2)
|
||||
_,_,h4,w4 = self.conv4.get_size(h3,w3)
|
||||
# print(h0,w0)
|
||||
# print(h1,w1)
|
||||
# print(h2,w2)
|
||||
# print(h3,w3)
|
||||
|
||||
#sample BU
|
||||
mask4 = torch.ones(h4,w4).to(x.device)
|
||||
# print(mask3.shape)
|
||||
index4,mask3 = self.conv4.sample(h3,w3,batch_size,device,mask4)
|
||||
index3,mask2 = self.conv3.sample(h2,w2,batch_size,device,mask3)
|
||||
index2,mask1 = self.conv2.sample(h1,w1,batch_size,device,mask2)
|
||||
index1,mask0 = self.conv1.sample(h0,w0,batch_size,device,mask1)
|
||||
|
||||
##forward
|
||||
out = F.relu(self.conv1(x,index1,mask1,stoch=stoch))
|
||||
out = F.relu(self.conv2(out,index2,mask2,stoch=stoch))
|
||||
out = F.relu(self.conv3(out,index3,mask3,stoch=stoch))
|
||||
out = F.relu(self.conv4(out,index4,mask4,stoch=stoch))
|
||||
out = out.view(out.size(0), -1 )
|
||||
out = self.fc1(out)
|
||||
return out
|
||||
|
||||
class MyLeNetMatStochBU(nn.Module):#epoch 11s
|
||||
def __init__(self):
|
||||
super(MyLeNetMatStochBU, self).__init__()
|
||||
self.conv1 = SConv2dAvg(3, 200, 3, stride=2)
|
||||
self.conv1 = SConv2dAvg(3, 200*k, 3, stride=2)
|
||||
self.conv2 = SConv2dAvg(200, 400, 3, stride=2)
|
||||
self.conv3 = SConv2dAvg(400, 800, 3, stride=2, ceil_mode=True)
|
||||
self.conv4 = SConv2dAvg(800, 10, 3, stride=1)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue