Mirror of https://github.com/AntoineHX/BU_Stoch_pool.git, synced 2025-05-04 17:50:46 +02:00

Commit 9d68bc30bd (parent 3de923156c): "Commentaires" (comments)

3 changed files with 19 additions and 16 deletions
@@ -140,18 +140,18 @@ class MyLeNetMatStochBU(nn.Module): #epoch 11s
         h1,w1 = self.conv1.get_size(h0,w0)
         h2,w2 = self.conv2.get_size(h1,w1)
         h3,w3 = self.conv3.get_size(h2,w2)
-        print(h0,w0)
-        print(h1,w1)
-        print(h2,w2)
-        print(h3,w3)
+        # print(h0,w0)
+        # print(h1,w1)
+        # print(h2,w2)
+        # print(h3,w3)

         #sample BU
         mask3 = torch.ones(h3,w3).to(x.device)
-        print(mask3.shape)
-        selh3,selw3,mask2 = self.conv3.sample(h2,w2,mask=mask3)
-        print(mask2.shape)
+        # print(mask3.shape)
+        selh3,selw3,mask2 = self.conv3.sample(h2,w2,mask=mask3) #Mask2.shape != (h2,w2) ???
+        # print(mask2.shape)
         selh2,selw2,mask1 = self.conv2.sample(h1,w1,mask=mask2)
-        print(mask1.shape)
+        # print(mask1.shape)
         selh1,selw1,mask0 = self.conv1.sample(h0,w0,mask=mask1)

         #forward
         out = F.relu(self.conv1(x,selh1,selw1,mask1,stoch=stoch))
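Note: the bottom-up ("BU") sampling in this hunk starts from a full mask on the deepest feature map and lets each layer's sample call derive which positions of the previous layer are actually needed. Below is a minimal self-contained sketch of that chaining with a hypothetical stand-in for SConv2dAvg.sample (the real method is in the third file of this commit); kernel size, stride, and sizes are assumptions for illustration. In this simplified version the returned mask has the previous layer's full shape by construction, which is exactly the property the new "#Mask2.shape != (h2,w2) ???" comment is questioning for the real implementation.

    import torch

    # Hypothetical stand-in for SConv2dAvg.sample: given a layer's input size
    # and a mask over its output grid, pick one (row, col) offset per stride
    # block and mark every input position the sampled locations still need.
    def sample(h_in, w_in, kernel_size, stride, mask):
        out_h, out_w = mask.shape
        selh = torch.randint(stride, (out_h, out_w))  # random row offset per cell
        selw = torch.randint(stride, (out_h, out_w))  # random col offset per cell
        rows = selh + torch.arange(0, out_h * stride, stride).view(-1, 1)
        cols = selw + torch.arange(0, out_w * stride, stride)
        prev_mask = torch.zeros(h_in, w_in)
        for i in range(out_h):
            for j in range(out_w):
                if mask[i, j] > 0:  # mark the receptive field of the sample
                    r, c = rows[i, j].item(), cols[i, j].item()
                    prev_mask[r:r + kernel_size, c:c + kernel_size] = 1
        return selh, selw, prev_mask

    # Bottom-up chaining as in the hunk above: a full mask on the deepest
    # feature map, walked back one layer at a time (sizes are assumed).
    h2, w2 = 12, 12
    mask3 = torch.ones(5, 5)  # plays the role of mask3 = torch.ones(h3, w3)
    selh3, selw3, mask2 = sample(h2, w2, kernel_size=3, stride=2, mask=mask3)
    print(mask2.shape)        # torch.Size([12, 12]) == (h2, w2)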
@@ -10,7 +10,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from .sconv2davg import SConv2dAvg
+from .stoch import SConv2dAvg

 class BasicBlock(nn.Module):
     expansion = 1
@@ -156,7 +156,6 @@ class SConv2dAvg(nn.Module):
         out[:,:,mask>0] = out_unf
         return out

-
     def comp(self,h,w,mask=-torch.ones(1,1)):
         out_h = (h-(self.kernel_size))/self.stride
         out_w = (w-(self.kernel_size))/self.stride
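Note: for comparison, the usual closed form for a convolution's output size with no padding and dilation 1 is floor((in - kernel) / stride) + 1. comp above computes (h - kernel_size) / stride without flooring or the + 1, so the sketch below shows the textbook formula, not a claim about the author's convention; the example sizes are assumptions.

    import math

    # Standard conv output size, no padding, dilation 1:
    #   out = floor((in - kernel) / stride) + 1
    def conv_out_size(size, kernel_size, stride=1):
        return math.floor((size - kernel_size) / stride) + 1

    print(conv_out_size(32, 5))     # 28: 5x5 conv on a 32x32 LeNet-style input
    print(conv_out_size(28, 2, 2))  # 14: 2x2 window with stride 2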
@@ -181,9 +180,13 @@ class SConv2dAvg(nn.Module):
         out_channels, in_channels, kh, kw = self.weight.shape
         device=mask.device

-        afterconv_h = h-(kh-1) # Dim after deconv (or after conv in forward)
+        #Shape after a simple forward conv?
+        afterconv_h = h-(kh-1)
         afterconv_w = w-(kw-1)
-        print(afterconv_h/stride)
+        # print(afterconv_h)
+        # print(afterconv_h/stride)

+        #Shape after forward? (== mask.shape?) #Padding and dilation not taken into account?
         if self.ceil_mode:
             out_h = math.ceil(afterconv_h/stride)
             out_w = math.ceil(afterconv_w/stride)
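Note: a quick worked check of why ceil_mode matters here, with assumed values. It only changes the result when stride does not divide the post-conv size evenly, and the overshoot is exactly the resth/restw remainder the later hunk compensates for.

    import math

    afterconv_h, stride = 27, 2
    print(math.ceil(afterconv_h / stride))   # 14 (ceil_mode=True): grid overhangs by one row
    print(math.floor(afterconv_h / stride))  # 13 (ceil_mode=False): trailing row is dropped
    # With ceil_mode the sampling grid overshoots the feature map by
    #   resth = out_h*stride - afterconv_h = 14*2 - 27 = 1,
    # the remainder that the modulo on selh/selw below folds back in.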
@@ -192,8 +195,8 @@ class SConv2dAvg(nn.Module):
         out_w = math.floor(afterconv_w/stride)
         # out_h=((afterconv_h+2*self.padding-1)/stride)+1
         # out_w=((afterconv_w+2*self.padding-1)/stride)+1
-        print('Out',out_h, out_w)
-        assert(tuple(mask.shape)==(out_h,out_w))
+        # print('Out',out_h, out_w)
+        # assert(tuple(mask.shape)==(out_h,out_w))
         # out_h,out_w=mask.shape

         selh = torch.randint(stride,(out_h,out_w), device=device)
@@ -201,13 +204,13 @@ class SConv2dAvg(nn.Module):

         resth = (out_h*stride)-afterconv_h #remainder from ceil/floor, 0 or 1
         restw = (out_w*stride)-afterconv_w
-        print('rest', resth, restw)
+        # print('rest', resth, restw)
         if resth!=0:
             selh[-1,:]=selh[-1,:]%(stride-resth);selh[:,-1]=selh[:,-1]%(stride-restw)
             selw[-1,:]=selw[-1,:]%(stride-resth);selw[:,-1]=selw[:,-1]%(stride-restw)
         maskh = (out_h)*stride
         maskw = (out_w)*stride
-        print('mask', maskh, maskw)
+        # print('mask', maskh, maskw)
         rng_h = selh + torch.arange(0,out_h*stride,stride,device=device).view(-1,1)
         rng_w = selw + torch.arange(0,out_w*stride,stride,device=device)
         # rng_w = selw + torch.arange(0,out_w*self.stride,self.stride,device=device).view(-1,1)
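Note: to see what the rng_h / rng_w lines compute, here is a small standalone sketch (sizes are assumptions). Each output cell covers a stride-wide block of post-conv positions, and selh/selw choose one random position inside it.

    import torch

    stride, out_h, out_w = 2, 3, 3
    selh = torch.randint(stride, (out_h, out_w))
    rng_h = selh + torch.arange(0, out_h * stride, stride).view(-1, 1)
    print(rng_h)  # row i holds values in {2*i, 2*i + 1}: one sampled row per cell

The .view(-1, 1) turns the row offsets into a column vector so they broadcast across columns; the commented-out rng_w variant in the hunk applies the same reshape to the column ranges, which would broadcast them down rows instead, presumably why the active version drops it.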