Commentaires

This commit is contained in:
Antoine Harlé 2020-06-12 02:44:02 -07:00
parent 3de923156c
commit 9d68bc30bd
3 changed files with 19 additions and 16 deletions

View file

@ -140,18 +140,18 @@ class MyLeNetMatStochBU(nn.Module):#epoch 11s
h1,w1 = self.conv1.get_size(h0,w0)
h2,w2 = self.conv2.get_size(h1,w1)
h3,w3 = self.conv3.get_size(h2,w2)
print(h0,w0)
print(h1,w1)
print(h2,w2)
print(h3,w3)
# print(h0,w0)
# print(h1,w1)
# print(h2,w2)
# print(h3,w3)
#sample BU
mask3 = torch.ones(h3,w3).to(x.device)
print(mask3.shape)
selh3,selw3,mask2 = self.conv3.sample(h2,w2,mask=mask3)
print(mask2.shape)
# print(mask3.shape)
selh3,selw3,mask2 = self.conv3.sample(h2,w2,mask=mask3) #Mask2.shape != (h2,w2) ???
# print(mask2.shape)
selh2,selw2,mask1 = self.conv2.sample(h1,w1,mask=mask2)
print(mask1.shape)
# print(mask1.shape)
selh1,selw1,mask0 = self.conv1.sample(h0,w0,mask=mask1)
#forward
out = F.relu(self.conv1(x,selh1,selw1,mask1,stoch=stoch))