diff --git a/models/mylenet4.py b/models/mylenet4.py index b2364c5..0204c47 100644 --- a/models/mylenet4.py +++ b/models/mylenet4.py @@ -140,18 +140,18 @@ class MyLeNetMatStochBU(nn.Module):#epoch 11s h1,w1 = self.conv1.get_size(h0,w0) h2,w2 = self.conv2.get_size(h1,w1) h3,w3 = self.conv3.get_size(h2,w2) - print(h0,w0) - print(h1,w1) - print(h2,w2) - print(h3,w3) + # print(h0,w0) + # print(h1,w1) + # print(h2,w2) + # print(h3,w3) #sample BU mask3 = torch.ones(h3,w3).to(x.device) - print(mask3.shape) - selh3,selw3,mask2 = self.conv3.sample(h2,w2,mask=mask3) - print(mask2.shape) + # print(mask3.shape) + selh3,selw3,mask2 = self.conv3.sample(h2,w2,mask=mask3) #Mask2.shape != (h2,w2) ??? + # print(mask2.shape) selh2,selw2,mask1 = self.conv2.sample(h1,w1,mask=mask2) - print(mask1.shape) + # print(mask1.shape) selh1,selw1,mask0 = self.conv1.sample(h0,w0,mask=mask1) #forward out = F.relu(self.conv1(x,selh1,selw1,mask1,stoch=stoch)) diff --git a/models/myresnet3.py b/models/myresnet3.py index 1c56b8f..f85d886 100644 --- a/models/myresnet3.py +++ b/models/myresnet3.py @@ -10,7 +10,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from .sconv2davg import SConv2dAvg +from .stoch import SConv2dAvg class BasicBlock(nn.Module): expansion = 1 diff --git a/models/stoch.py b/models/stoch.py index 34cc92a..a18a7e2 100644 --- a/models/stoch.py +++ b/models/stoch.py @@ -156,7 +156,6 @@ class SConv2dAvg(nn.Module): out[:,:,mask>0] = out_unf return out - def comp(self,h,w,mask=-torch.ones(1,1)): out_h = (h-(self.kernel_size))/self.stride out_w = (w-(self.kernel_size))/self.stride @@ -181,9 +180,13 @@ class SConv2dAvg(nn.Module): out_channels, in_channels, kh, kw = self.weight.shape device=mask.device - afterconv_h = h-(kh-1) # Dim after deconv (ou after conv in forward) + #Shape after simple forward conv ? + afterconv_h = h-(kh-1) afterconv_w = w-(kw-1) - print(afterconv_h/stride) + # print(afterconv_h) + # print(afterconv_h/stride) + + #Shape after forward ? (== mask.shape ?) #Padding, Dilatation pas pris en compte ? if self.ceil_mode: out_h = math.ceil(afterconv_h/stride) out_w = math.ceil(afterconv_w/stride) @@ -192,8 +195,8 @@ class SConv2dAvg(nn.Module): out_w = math.floor(afterconv_w/stride) # out_h=((afterconv_h+2*self.padding-1)/stride)+1 # out_w=((afterconv_w+2*self.padding-1)/stride)+1 - print('Out',out_h, out_w) - assert(tuple(mask.shape)==(out_h,out_w)) + # print('Out',out_h, out_w) + # assert(tuple(mask.shape)==(out_h,out_w)) # out_h,out_w=mask.shape selh = torch.randint(stride,(out_h,out_w), device=device) @@ -201,13 +204,13 @@ class SConv2dAvg(nn.Module): resth = (out_h*stride)-afterconv_h #reste de ceil/floor, 0 ou 1 restw = (out_w*stride)-afterconv_w - print('rest', resth, restw) + # print('rest', resth, restw) if resth!=0: selh[-1,:]=selh[-1,:]%(stride-resth);selh[:,-1]=selh[:,-1]%(stride-restw) selw[-1,:]=selw[-1,:]%(stride-resth);selw[:,-1]=selw[:,-1]%(stride-restw) maskh = (out_h)*stride maskw = (out_w)*stride - print('mask', maskh, maskw) + # print('mask', maskh, maskw) rng_h = selh + torch.arange(0,out_h*stride,stride,device=device).view(-1,1) rng_w = selw + torch.arange(0,out_w*stride,stride,device=device) # rng_w = selw + torch.arange(0,out_w*self.stride,self.stride,device=device).view(-1,1)