mirror of
https://github.com/AntoineHX/BU_Stoch_pool.git
synced 2025-05-03 17:20:45 +02:00
Commentaires
This commit is contained in:
parent
3de923156c
commit
9d68bc30bd
3 changed files with 19 additions and 16 deletions
|
@ -140,18 +140,18 @@ class MyLeNetMatStochBU(nn.Module):#epoch 11s
|
|||
h1,w1 = self.conv1.get_size(h0,w0)
|
||||
h2,w2 = self.conv2.get_size(h1,w1)
|
||||
h3,w3 = self.conv3.get_size(h2,w2)
|
||||
print(h0,w0)
|
||||
print(h1,w1)
|
||||
print(h2,w2)
|
||||
print(h3,w3)
|
||||
# print(h0,w0)
|
||||
# print(h1,w1)
|
||||
# print(h2,w2)
|
||||
# print(h3,w3)
|
||||
|
||||
#sample BU
|
||||
mask3 = torch.ones(h3,w3).to(x.device)
|
||||
print(mask3.shape)
|
||||
selh3,selw3,mask2 = self.conv3.sample(h2,w2,mask=mask3)
|
||||
print(mask2.shape)
|
||||
# print(mask3.shape)
|
||||
selh3,selw3,mask2 = self.conv3.sample(h2,w2,mask=mask3) #Mask2.shape != (h2,w2) ???
|
||||
# print(mask2.shape)
|
||||
selh2,selw2,mask1 = self.conv2.sample(h1,w1,mask=mask2)
|
||||
print(mask1.shape)
|
||||
# print(mask1.shape)
|
||||
selh1,selw1,mask0 = self.conv1.sample(h0,w0,mask=mask1)
|
||||
#forward
|
||||
out = F.relu(self.conv1(x,selh1,selw1,mask1,stoch=stoch))
|
||||
|
|
|
@ -10,7 +10,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .sconv2davg import SConv2dAvg
|
||||
from .stoch import SConv2dAvg
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
|
|
@ -156,7 +156,6 @@ class SConv2dAvg(nn.Module):
|
|||
out[:,:,mask>0] = out_unf
|
||||
return out
|
||||
|
||||
|
||||
def comp(self,h,w,mask=-torch.ones(1,1)):
|
||||
out_h = (h-(self.kernel_size))/self.stride
|
||||
out_w = (w-(self.kernel_size))/self.stride
|
||||
|
@ -181,9 +180,13 @@ class SConv2dAvg(nn.Module):
|
|||
out_channels, in_channels, kh, kw = self.weight.shape
|
||||
device=mask.device
|
||||
|
||||
afterconv_h = h-(kh-1) # Dim after deconv (ou after conv in forward)
|
||||
#Shape after simple forward conv ?
|
||||
afterconv_h = h-(kh-1)
|
||||
afterconv_w = w-(kw-1)
|
||||
print(afterconv_h/stride)
|
||||
# print(afterconv_h)
|
||||
# print(afterconv_h/stride)
|
||||
|
||||
#Shape after forward ? (== mask.shape ?) #Padding, Dilatation pas pris en compte ?
|
||||
if self.ceil_mode:
|
||||
out_h = math.ceil(afterconv_h/stride)
|
||||
out_w = math.ceil(afterconv_w/stride)
|
||||
|
@ -192,8 +195,8 @@ class SConv2dAvg(nn.Module):
|
|||
out_w = math.floor(afterconv_w/stride)
|
||||
# out_h=((afterconv_h+2*self.padding-1)/stride)+1
|
||||
# out_w=((afterconv_w+2*self.padding-1)/stride)+1
|
||||
print('Out',out_h, out_w)
|
||||
assert(tuple(mask.shape)==(out_h,out_w))
|
||||
# print('Out',out_h, out_w)
|
||||
# assert(tuple(mask.shape)==(out_h,out_w))
|
||||
# out_h,out_w=mask.shape
|
||||
|
||||
selh = torch.randint(stride,(out_h,out_w), device=device)
|
||||
|
@ -201,13 +204,13 @@ class SConv2dAvg(nn.Module):
|
|||
|
||||
resth = (out_h*stride)-afterconv_h #reste de ceil/floor, 0 ou 1
|
||||
restw = (out_w*stride)-afterconv_w
|
||||
print('rest', resth, restw)
|
||||
# print('rest', resth, restw)
|
||||
if resth!=0:
|
||||
selh[-1,:]=selh[-1,:]%(stride-resth);selh[:,-1]=selh[:,-1]%(stride-restw)
|
||||
selw[-1,:]=selw[-1,:]%(stride-resth);selw[:,-1]=selw[:,-1]%(stride-restw)
|
||||
maskh = (out_h)*stride
|
||||
maskw = (out_w)*stride
|
||||
print('mask', maskh, maskw)
|
||||
# print('mask', maskh, maskw)
|
||||
rng_h = selh + torch.arange(0,out_h*stride,stride,device=device).view(-1,1)
|
||||
rng_w = selw + torch.arange(0,out_w*stride,stride,device=device)
|
||||
# rng_w = selw + torch.arange(0,out_w*self.stride,self.stride,device=device).view(-1,1)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue