mirror of
https://github.com/AntoineHX/BU_Stoch_pool.git
synced 2025-05-04 09:40:46 +02:00
Commentaires
This commit is contained in:
parent
3de923156c
commit
9d68bc30bd
3 changed files with 19 additions and 16 deletions
|
@ -156,7 +156,6 @@ class SConv2dAvg(nn.Module):
|
|||
out[:,:,mask>0] = out_unf
|
||||
return out
|
||||
|
||||
|
||||
def comp(self,h,w,mask=-torch.ones(1,1)):
|
||||
out_h = (h-(self.kernel_size))/self.stride
|
||||
out_w = (w-(self.kernel_size))/self.stride
|
||||
|
@ -181,9 +180,13 @@ class SConv2dAvg(nn.Module):
|
|||
out_channels, in_channels, kh, kw = self.weight.shape
|
||||
device=mask.device
|
||||
|
||||
afterconv_h = h-(kh-1) # Dim after deconv (ou after conv in forward)
|
||||
#Shape after simple forward conv ?
|
||||
afterconv_h = h-(kh-1)
|
||||
afterconv_w = w-(kw-1)
|
||||
print(afterconv_h/stride)
|
||||
# print(afterconv_h)
|
||||
# print(afterconv_h/stride)
|
||||
|
||||
#Shape after forward ? (== mask.shape ?) #Padding, Dilatation pas pris en compte ?
|
||||
if self.ceil_mode:
|
||||
out_h = math.ceil(afterconv_h/stride)
|
||||
out_w = math.ceil(afterconv_w/stride)
|
||||
|
@ -192,8 +195,8 @@ class SConv2dAvg(nn.Module):
|
|||
out_w = math.floor(afterconv_w/stride)
|
||||
# out_h=((afterconv_h+2*self.padding-1)/stride)+1
|
||||
# out_w=((afterconv_w+2*self.padding-1)/stride)+1
|
||||
print('Out',out_h, out_w)
|
||||
assert(tuple(mask.shape)==(out_h,out_w))
|
||||
# print('Out',out_h, out_w)
|
||||
# assert(tuple(mask.shape)==(out_h,out_w))
|
||||
# out_h,out_w=mask.shape
|
||||
|
||||
selh = torch.randint(stride,(out_h,out_w), device=device)
|
||||
|
@ -201,13 +204,13 @@ class SConv2dAvg(nn.Module):
|
|||
|
||||
resth = (out_h*stride)-afterconv_h #reste de ceil/floor, 0 ou 1
|
||||
restw = (out_w*stride)-afterconv_w
|
||||
print('rest', resth, restw)
|
||||
# print('rest', resth, restw)
|
||||
if resth!=0:
|
||||
selh[-1,:]=selh[-1,:]%(stride-resth);selh[:,-1]=selh[:,-1]%(stride-restw)
|
||||
selw[-1,:]=selw[-1,:]%(stride-resth);selw[:,-1]=selw[:,-1]%(stride-restw)
|
||||
maskh = (out_h)*stride
|
||||
maskw = (out_w)*stride
|
||||
print('mask', maskh, maskw)
|
||||
# print('mask', maskh, maskw)
|
||||
rng_h = selh + torch.arange(0,out_h*stride,stride,device=device).view(-1,1)
|
||||
rng_w = selw + torch.arange(0,out_w*stride,stride,device=device)
|
||||
# rng_w = selw + torch.arange(0,out_w*self.stride,self.stride,device=device).view(-1,1)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue