mirror of
https://github.com/AntoineHX/BU_Stoch_pool.git
synced 2025-05-04 09:40:46 +02:00
added forward fast in stoch.py, but it works only with ceil_mode=False
This commit is contained in:
parent
f7436d0002
commit
286921f8a0
3 changed files with 197 additions and 59 deletions
|
@ -101,6 +101,30 @@ class MyLeNetMatNormal(nn.Module):#epach 21s
|
|||
#out = (self.fc1(out))
|
||||
return out
|
||||
|
||||
class MyLeNetMatStochNoceil(nn.Module):#epoch 17s
|
||||
def __init__(self):
|
||||
super(MyLeNetMatStochNoceil, self).__init__()
|
||||
self.conv1 = SConv2dAvg(3, 200, 3, stride=2,padding=1,ceil_mode=False)
|
||||
self.conv2 = SConv2dAvg(200, 400, 3, stride=2,padding=1,ceil_mode=False)
|
||||
self.conv3 = SConv2dAvg(400, 800, 3, stride=2,padding=1,ceil_mode=False)
|
||||
self.conv4 = SConv2dAvg(800, 10, 3, stride=4,padding=1,ceil_mode=False)
|
||||
#self.fc1 = nn.Linear(800, 10)
|
||||
|
||||
def forward(self, x, stoch=True):
|
||||
#print('in',x.shape)
|
||||
out = F.relu(self.conv1(x,stoch=stoch))
|
||||
#print('c1',out.shape)
|
||||
out = F.relu(self.conv2(out,stoch=stoch))
|
||||
#print('c2', out.shape)
|
||||
out = F.relu(self.conv3(out,stoch=stoch))
|
||||
#print('c3',out.shape)
|
||||
out = self.conv4(out,stoch=stoch)
|
||||
#print('c4',out.shape)
|
||||
out = out.view(out.size(0), -1 )
|
||||
#out = self.fc1(out)
|
||||
return out
|
||||
|
||||
|
||||
class MyLeNetMatStoch(nn.Module):#epoch 17s
|
||||
def __init__(self):
|
||||
super(MyLeNetMatStoch, self).__init__()
|
||||
|
@ -111,16 +135,15 @@ class MyLeNetMatStoch(nn.Module):#epoch 17s
|
|||
#self.fc1 = nn.Linear(800, 10)
|
||||
|
||||
def forward(self, x, stoch=True):
|
||||
print('in',x.shape)
|
||||
#print('in',x.shape)
|
||||
out = F.relu(self.conv1(x,stoch=stoch))
|
||||
print('c1',out.shape)
|
||||
#print('c1',out.shape)
|
||||
out = F.relu(self.conv2(out,stoch=stoch))
|
||||
print('c2', out.shape)
|
||||
#print('c2', out.shape)
|
||||
out = F.relu(self.conv3(out,stoch=stoch))
|
||||
print('c3',out.shape)
|
||||
#hkjhlg
|
||||
#print('c3',out.shape)
|
||||
out = self.conv4(out,stoch=stoch)
|
||||
print('c4',out.shape)
|
||||
#print('c4',out.shape)
|
||||
out = out.view(out.size(0), -1 )
|
||||
#out = self.fc1(out)
|
||||
return out
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue