Commentaires

This commit is contained in:
Antoine Harlé 2020-06-12 02:44:02 -07:00
parent 3de923156c
commit 9d68bc30bd
3 changed files with 19 additions and 16 deletions

View file

@ -140,18 +140,18 @@ class MyLeNetMatStochBU(nn.Module):#epoch 11s
h1,w1 = self.conv1.get_size(h0,w0)
h2,w2 = self.conv2.get_size(h1,w1)
h3,w3 = self.conv3.get_size(h2,w2)
print(h0,w0)
print(h1,w1)
print(h2,w2)
print(h3,w3)
# print(h0,w0)
# print(h1,w1)
# print(h2,w2)
# print(h3,w3)
#sample BU
mask3 = torch.ones(h3,w3).to(x.device)
print(mask3.shape)
selh3,selw3,mask2 = self.conv3.sample(h2,w2,mask=mask3)
print(mask2.shape)
# print(mask3.shape)
selh3,selw3,mask2 = self.conv3.sample(h2,w2,mask=mask3) #Mask2.shape != (h2,w2) ???
# print(mask2.shape)
selh2,selw2,mask1 = self.conv2.sample(h1,w1,mask=mask2)
print(mask1.shape)
# print(mask1.shape)
selh1,selw1,mask0 = self.conv1.sample(h0,w0,mask=mask1)
#forward
out = F.relu(self.conv1(x,selh1,selw1,mask1,stoch=stoch))

View file

@ -10,7 +10,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from .sconv2davg import SConv2dAvg
from .stoch import SConv2dAvg
class BasicBlock(nn.Module):
expansion = 1

View file

@ -156,7 +156,6 @@ class SConv2dAvg(nn.Module):
out[:,:,mask>0] = out_unf
return out
def comp(self,h,w,mask=-torch.ones(1,1)):
out_h = (h-(self.kernel_size))/self.stride
out_w = (w-(self.kernel_size))/self.stride
@ -181,9 +180,13 @@ class SConv2dAvg(nn.Module):
out_channels, in_channels, kh, kw = self.weight.shape
device=mask.device
afterconv_h = h-(kh-1) # Dim after deconv (ou after conv in forward)
#Shape after simple forward conv ?
afterconv_h = h-(kh-1)
afterconv_w = w-(kw-1)
print(afterconv_h/stride)
# print(afterconv_h)
# print(afterconv_h/stride)
#Shape after forward ? (== mask.shape ?) #Padding, Dilatation pas pris en compte ?
if self.ceil_mode:
out_h = math.ceil(afterconv_h/stride)
out_w = math.ceil(afterconv_w/stride)
@ -192,8 +195,8 @@ class SConv2dAvg(nn.Module):
out_w = math.floor(afterconv_w/stride)
# out_h=((afterconv_h+2*self.padding-1)/stride)+1
# out_w=((afterconv_w+2*self.padding-1)/stride)+1
print('Out',out_h, out_w)
assert(tuple(mask.shape)==(out_h,out_w))
# print('Out',out_h, out_w)
# assert(tuple(mask.shape)==(out_h,out_w))
# out_h,out_w=mask.shape
selh = torch.randint(stride,(out_h,out_w), device=device)
@ -201,13 +204,13 @@ class SConv2dAvg(nn.Module):
resth = (out_h*stride)-afterconv_h #reste de ceil/floor, 0 ou 1
restw = (out_w*stride)-afterconv_w
print('rest', resth, restw)
# print('rest', resth, restw)
if resth!=0:
selh[-1,:]=selh[-1,:]%(stride-resth);selh[:,-1]=selh[:,-1]%(stride-restw)
selw[-1,:]=selw[-1,:]%(stride-resth);selw[:,-1]=selw[:,-1]%(stride-restw)
maskh = (out_h)*stride
maskw = (out_w)*stride
print('mask', maskh, maskw)
# print('mask', maskh, maskw)
rng_h = selh + torch.arange(0,out_h*stride,stride,device=device).view(-1,1)
rng_w = selw + torch.arange(0,out_w*stride,stride,device=device)
# rng_w = selw + torch.arange(0,out_w*self.stride,self.stride,device=device).view(-1,1)