Commentaires

This commit is contained in:
Antoine Harlé 2020-06-12 02:44:02 -07:00
parent 3de923156c
commit 9d68bc30bd
3 changed files with 19 additions and 16 deletions

View file

@ -156,7 +156,6 @@ class SConv2dAvg(nn.Module):
out[:,:,mask>0] = out_unf
return out
def comp(self,h,w,mask=-torch.ones(1,1)):
out_h = (h-(self.kernel_size))/self.stride
out_w = (w-(self.kernel_size))/self.stride
@ -181,9 +180,13 @@ class SConv2dAvg(nn.Module):
out_channels, in_channels, kh, kw = self.weight.shape
device=mask.device
afterconv_h = h-(kh-1) # Dim after deconv (ou after conv in forward)
#Shape after simple forward conv ?
afterconv_h = h-(kh-1)
afterconv_w = w-(kw-1)
print(afterconv_h/stride)
# print(afterconv_h)
# print(afterconv_h/stride)
#Shape after forward ? (== mask.shape ?) #Padding, Dilatation pas pris en compte ?
if self.ceil_mode:
out_h = math.ceil(afterconv_h/stride)
out_w = math.ceil(afterconv_w/stride)
@ -192,8 +195,8 @@ class SConv2dAvg(nn.Module):
out_w = math.floor(afterconv_w/stride)
# out_h=((afterconv_h+2*self.padding-1)/stride)+1
# out_w=((afterconv_w+2*self.padding-1)/stride)+1
print('Out',out_h, out_w)
assert(tuple(mask.shape)==(out_h,out_w))
# print('Out',out_h, out_w)
# assert(tuple(mask.shape)==(out_h,out_w))
# out_h,out_w=mask.shape
selh = torch.randint(stride,(out_h,out_w), device=device)
@ -201,13 +204,13 @@ class SConv2dAvg(nn.Module):
resth = (out_h*stride)-afterconv_h #reste de ceil/floor, 0 ou 1
restw = (out_w*stride)-afterconv_w
print('rest', resth, restw)
# print('rest', resth, restw)
if resth!=0:
selh[-1,:]=selh[-1,:]%(stride-resth);selh[:,-1]=selh[:,-1]%(stride-restw)
selw[-1,:]=selw[-1,:]%(stride-resth);selw[:,-1]=selw[:,-1]%(stride-restw)
maskh = (out_h)*stride
maskw = (out_w)*stride
print('mask', maskh, maskw)
# print('mask', maskh, maskw)
rng_h = selh + torch.arange(0,out_h*stride,stride,device=device).view(-1,1)
rng_w = selw + torch.arange(0,out_w*stride,stride,device=device)
# rng_w = selw + torch.arange(0,out_w*self.stride,self.stride,device=device).view(-1,1)