mirror of
https://github.com/AntoineHX/BU_Stoch_pool.git
synced 2025-05-04 09:40:46 +02:00
Commentaires
This commit is contained in:
parent
3de923156c
commit
9d68bc30bd
3 changed files with 19 additions and 16 deletions
|
@ -140,18 +140,18 @@ class MyLeNetMatStochBU(nn.Module):#epoch 11s
|
|||
h1,w1 = self.conv1.get_size(h0,w0)
|
||||
h2,w2 = self.conv2.get_size(h1,w1)
|
||||
h3,w3 = self.conv3.get_size(h2,w2)
|
||||
print(h0,w0)
|
||||
print(h1,w1)
|
||||
print(h2,w2)
|
||||
print(h3,w3)
|
||||
# print(h0,w0)
|
||||
# print(h1,w1)
|
||||
# print(h2,w2)
|
||||
# print(h3,w3)
|
||||
|
||||
#sample BU
|
||||
mask3 = torch.ones(h3,w3).to(x.device)
|
||||
print(mask3.shape)
|
||||
selh3,selw3,mask2 = self.conv3.sample(h2,w2,mask=mask3)
|
||||
print(mask2.shape)
|
||||
# print(mask3.shape)
|
||||
selh3,selw3,mask2 = self.conv3.sample(h2,w2,mask=mask3) #Mask2.shape != (h2,w2) ???
|
||||
# print(mask2.shape)
|
||||
selh2,selw2,mask1 = self.conv2.sample(h1,w1,mask=mask2)
|
||||
print(mask1.shape)
|
||||
# print(mask1.shape)
|
||||
selh1,selw1,mask0 = self.conv1.sample(h0,w0,mask=mask1)
|
||||
#forward
|
||||
out = F.relu(self.conv1(x,selh1,selw1,mask1,stoch=stoch))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue