Mirror of https://github.com/AntoineHX/BU_Stoch_pool.git

commit 286921f8a0 (parent f7436d0002)
added forward_fast in stoch.py, but it works only with ceil_mode=False

3 changed files with 197 additions and 59 deletions
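For orientation before the diff: a minimal sketch of how the new fast path is meant to be called. The import path models.stoch and the CIFAR-sized input are assumptions; the SConv2dAvg constructor arguments and the forward_fast signature are taken from the diff below.

    import torch
    from models.stoch import SConv2dAvg  # module path assumed

    # ceil_mode=False is required: the commit notes the fast path is not right with ceil_mode=True
    conv = SConv2dAvg(3, 200, 3, stride=2, padding=1, ceil_mode=False)
    x = torch.randn(8, 3, 32, 32)             # assumed CIFAR-sized batch
    out = conv.forward_fast(x, stoch=True)    # stochastic pooling fused into the conv
    ref = conv.forward_fast(x, stoch=False)   # falls back to a stride-1 conv + average pooling
    print(out.shape, ref.shape)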
@@ -101,6 +101,30 @@ class MyLeNetMatNormal(nn.Module): #epoch 21s
         #out = (self.fc1(out))
         return out
 
+
+class MyLeNetMatStochNoceil(nn.Module): #epoch 17s
+    def __init__(self):
+        super(MyLeNetMatStochNoceil, self).__init__()
+        self.conv1 = SConv2dAvg(3, 200, 3, stride=2, padding=1, ceil_mode=False)
+        self.conv2 = SConv2dAvg(200, 400, 3, stride=2, padding=1, ceil_mode=False)
+        self.conv3 = SConv2dAvg(400, 800, 3, stride=2, padding=1, ceil_mode=False)
+        self.conv4 = SConv2dAvg(800, 10, 3, stride=4, padding=1, ceil_mode=False)
+        #self.fc1 = nn.Linear(800, 10)
+
+    def forward(self, x, stoch=True):
+        #print('in', x.shape)
+        out = F.relu(self.conv1(x, stoch=stoch))
+        #print('c1', out.shape)
+        out = F.relu(self.conv2(out, stoch=stoch))
+        #print('c2', out.shape)
+        out = F.relu(self.conv3(out, stoch=stoch))
+        #print('c3', out.shape)
+        out = self.conv4(out, stoch=stoch)
+        #print('c4', out.shape)
+        out = out.view(out.size(0), -1)
+        #out = self.fc1(out)
+        return out
+
+
 class MyLeNetMatStoch(nn.Module): #epoch 17s
     def __init__(self):
         super(MyLeNetMatStoch, self).__init__()
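The new MyLeNetMatStochNoceil mirrors MyLeNetMatStoch below but builds every SConv2dAvg with ceil_mode=False, the setting the new fast path supports. A hedged smoke test (module path and input size are assumptions): with a 32x32 input, the after-conv size stays at 32 (padding 1, 3x3 kernel), each stride-2 layer halves the map (32 -> 16 -> 8 -> 4), and the final stride-4 layer reduces it to 1x1, so the flattened output is directly the 10 class logits.

    import torch
    from models.lenet import MyLeNetMatStochNoceil  # module path assumed

    net = MyLeNetMatStochNoceil()
    x = torch.randn(4, 3, 32, 32)   # assumed CIFAR-10 input
    logits = net(x, stoch=True)     # stochastic pooling during training
    print(logits.shape)             # expected: torch.Size([4, 10])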
@@ -111,16 +135,15 @@ class MyLeNetMatStoch(nn.Module): #epoch 17s
         #self.fc1 = nn.Linear(800, 10)
 
     def forward(self, x, stoch=True):
-        print('in', x.shape)
+        #print('in', x.shape)
         out = F.relu(self.conv1(x, stoch=stoch))
-        print('c1', out.shape)
+        #print('c1', out.shape)
         out = F.relu(self.conv2(out, stoch=stoch))
-        print('c2', out.shape)
+        #print('c2', out.shape)
         out = F.relu(self.conv3(out, stoch=stoch))
-        print('c3', out.shape)
-        #hkjhlg
-
+        #print('c3', out.shape)
         out = self.conv4(out, stoch=stoch)
-        print('c4', out.shape)
+        #print('c4', out.shape)
         out = out.view(out.size(0), -1)
         #out = self.fc1(out)
         return out
@@ -26,7 +26,7 @@ class SAvg_Pool2d(nn.Module):
         out = savg_pool2d(x, self.stride, mode=self.mode, ceil_mode=self.ceil_mode)
         return out
 
-stochmode = 'sim'    # options: 'sim', 'stride', 'stoch', ''
+stochmode = 'stoch'  # options: 'sim', 'stride', 'stoch', ''
 finalstochpool = True
 simmode = 'sbc'
models/stoch.py (219 lines changed)
@@ -45,8 +45,55 @@ class SConv2dAvg(nn.Module):
         self.padding = padding
         self.kernel_size = kernel_size
         self.ceil_mode = ceil_mode
 
+    def forward_fast(self, input, index=-torch.ones(1), mask=-torch.ones(1,1), stoch=True, stride=-1): #ceil_mode=True not right yet
+        device = input.device
+        if stride == -1:
+            stride = self.stride  #if stride is not given, use self.stride
+        if stoch == False:
+            stride = 1  #test with real average pooling
+        batch_size, in_channels, in_h, in_w = input.shape
+        out_channels, in_channels, kh, kw = self.weight.shape
+        afterconv_h, afterconv_w, out_h, out_w = self.get_size(in_h, in_w, stride)
+
+        unfold = torch.nn.Unfold(kernel_size=(kh, kw), dilation=self.dilation, padding=self.padding, stride=1)
+        inp_unf = unfold(input)  #im2col matrix (batch_size, in_channels*kh*kw, afterconv_h*afterconv_w)
+
+        if index[0] == -1 and stride != 1:
+            index, mask = self.sample(in_h, in_w, batch_size, device, mask)
+
+        if mask[0,0] == -1:  #if no mask is given, keep only the sampled positions
+            inp_unf = torch.gather(inp_unf.view(batch_size, in_channels*kh*kw, afterconv_h*afterconv_w), 2, index.view(batch_size, in_channels*kh*kw, out_h*out_w)).view(batch_size, in_channels*kh*kw, out_h*out_w)
+        else:  #if a valid mask is given, sample only at the mask locations
+            inp_unf = inp_unf[:, :, rng_h[mask>0], rng_w[mask>0]]  #note: rng_h/rng_w are not defined in this method; this branch would fail as written
+
+        #Matrix multiplication in place of the convolution
+        if self.bias is None:
+            out_unf = oe.contract('bji,kj->bki', inp_unf, self.weight.view(self.weight.size(0), -1), backend='torch')
+        else:
+            #out_unf = oe.contract('bji,kj,k->bki', inp_unf, self.weight.view(self.weight.size(0), -1), self.bias, backend='torch')  #wrong: folds the bias into the product
+            out_unf = oe.contract('bji,kj->bki', inp_unf, self.weight.view(self.weight.size(0), -1), backend='torch') + self.bias.view(1,-1,1)  #slightly slower but correct
+
+        if stride == 1 or mask[0,0] == -1:  #in case of no mask and stride==1
+            out = out_unf.view(batch_size, out_channels, out_h, out_w)  #fold
+            if stoch == False:  #this is done outside for more clarity
+                out = F.avg_pool2d(out, self.stride, ceil_mode=True)
+        else:  #in case of a mask
+            out = torch.zeros(batch_size, out_channels, out_h, out_w, device=device)
+            out[:, :, mask>0] = out_unf
+        return out
 
-    def forward(self, input, selh=-torch.ones(1,1), selw=-torch.ones(1,1), mask=-torch.ones(1,1), stoch=True, stride=-1):
+    def forward_test(self, input, selh=-torch.ones(1,1), selw=-torch.ones(1,1), mask=-torch.ones(1,1), stoch=True, stride=-1): #ugly but faster
         device = input.device
         if stride == -1:
             stride = self.stride  #if stride is not given, use self.stride
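The idea behind forward_fast, restated as a self-contained sketch: unfold once at stride 1, keep one randomly chosen position inside each stride x stride pooling window, and apply a single matrix multiplication, instead of computing the dense convolution and then pooling it. The sketch uses torch.einsum in place of oe.contract and the explicit 2-D row/column selection of the older forward; every name here is local to the sketch.

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    B, C_in, C_out, k, s, pad = 2, 3, 5, 3, 2, 1
    x = torch.randn(B, C_in, 9, 9)
    w = torch.randn(C_out, C_in, k, k)

    after_h = 9 + 2*pad - (k - 1)   # 9: spatial size after a stride-1 conv
    out_h = after_h // s            # 4: floor, i.e. ceil_mode=False

    # im2col at stride 1: (B, C_in*k*k, after_h, after_h)
    unf = F.unfold(x, k, padding=pad, stride=1).view(B, C_in*k*k, after_h, after_h)

    # one random (row, col) offset inside each s x s pooling window
    selh = torch.randint(s, (out_h, out_h))
    selw = torch.randint(s, (out_h, out_h))
    rows = torch.arange(out_h).view(-1, 1) * s + selh
    cols = torch.arange(out_h).view(1, -1) * s + selw

    picked = unf[:, :, rows, cols].reshape(B, C_in*k*k, -1)  # (B, C_in*k*k, out_h*out_h)
    out = torch.einsum('bji,kj->bki', picked, w.view(C_out, -1)).view(B, C_out, out_h, out_h)

    # each output pixel equals the dense conv evaluated at the sampled location
    dense = F.conv2d(x, w, padding=pad)   # (B, C_out, 9, 9)
    print(torch.allclose(out, dense[:, :, rows, cols], atol=1e-5))  # True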
@@ -55,41 +102,48 @@ class SConv2dAvg(nn.Module):
         batch_size, in_channels, in_h, in_w = input.shape
         out_channels, in_channels, kh, kw = self.weight.shape
 
-        afterconv_h = in_h + 2*self.padding - (kh-1)  #size after conv
-        afterconv_w = in_w + 2*self.padding - (kw-1)
-        if self.ceil_mode:  #ceil_mode=True: default mode for strided conv
-            out_h = math.ceil(afterconv_h/stride)
-            out_w = math.ceil(afterconv_w/stride)
-        else:  #ceil_mode=False: default mode for pooling
-            out_h = math.floor(afterconv_h/stride)
-            out_w = math.floor(afterconv_w/stride)
+        #afterconv_h, afterconv_w, out_h, out_w = self.get_size(in_h, in_w)
+        #if selh[0,0] == -1:
+        #    index, mask = self.sample(in_h, in_w, batch_size, device, mask)
+
+        if 1:
+            afterconv_h = in_h + 2*self.padding - (kh-1)  #size after conv
+            afterconv_w = in_w + 2*self.padding - (kw-1)
+            if self.ceil_mode:  #ceil_mode=True: default mode for strided conv
+                out_h = math.ceil(afterconv_h/stride)
+                out_w = math.ceil(afterconv_w/stride)
+            else:  #ceil_mode=False: default mode for pooling
+                out_h = math.floor(afterconv_h/stride)
+                out_w = math.floor(afterconv_w/stride)
         unfold = torch.nn.Unfold(kernel_size=(kh, kw), dilation=self.dilation, padding=self.padding, stride=1)
         inp_unf = unfold(input)  #im2col matrix (batch_size, in_channels*kh*kw, afterconv_h*afterconv_w)
-        if stride != 1:  #if stride==1 there is no pooling
-            inp_unf = inp_unf.view(batch_size, in_channels*kh*kw, afterconv_h, afterconv_w)
-            if selh[0,0] == -1:  #if no sampled selection is given
-                #selection of where to sample for each pooling location
-                sel = torch.randint(stride*stride, (out_h, out_w), device=device)
+        if 1:
+            if stride != 1:  #if stride==1 there is no pooling
+                inp_unf = inp_unf.view(batch_size, in_channels*kh*kw, afterconv_h, afterconv_w)
+                if selh[0,0] == -1:  #if no sampled selection is given
+                    #selection of where to sample for each pooling location
+                    sel = torch.randint(stride*stride, (out_h, out_w), device=device)
 
                     if self.ceil_mode:  #in case of ceil_mode, restrict the selection to the valid locations of the last regions
                         resth = (out_h*stride) - afterconv_h
                         restw = (out_w*stride) - afterconv_w
                         if resth != 0:
                             sel[-1] = (sel[-1]//stride) % (stride-resth) * stride + (sel[-1]%stride)
                             sel[:,-1] = (sel[:,-1]%stride) % (stride-restw) + sel[:,-1]//stride * stride
                         #print(stride-resth, sel[-1])
                         #print(stride-restw, sel[:,-1])
 
-            rng = torch.arange(0, afterconv_h*afterconv_w, stride*stride, device=device).view(out_h, out_w)
-            index = sel + rng
-            index = index.repeat(batch_size, in_channels*kh*kw, 1, 1)
+                #rng = torch.arange(0, afterconv_h*afterconv_w, stride*stride, device=device).view(out_h, out_w)
+                rng = torch.arange(0, out_h*stride*out_w*stride, stride*stride, device=device).view(out_h, out_w)
+                index = sel + rng
+                index = index.repeat(batch_size, in_channels*kh*kw, 1, 1)
 
             if mask[0,0] == -1:  #if no mask is given, keep only the sampled positions
                 #inp_unf = inp_unf[:, :, rng_h, rng_w].view(batch_size, in_channels*kh*kw, -1)
                 inp_unf = torch.gather(inp_unf.view(batch_size, in_channels*kh*kw, afterconv_h*afterconv_w), 2, index.view(batch_size, in_channels*kh*kw, out_h*out_w)).view(batch_size, in_channels*kh*kw, out_h*out_w)
             else:  #if a valid mask is given, sample only at the mask locations
                 inp_unf = inp_unf[:, :, rng_h[mask>0], rng_w[mask>0]]
 
         #Matrix multiplication in place of the convolution
         if self.bias is None:
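The ceil_mode correction above is dense, so here is a standalone check of what the last-row remapping does (stride and resth chosen for illustration): with stride=2 and resth=1, the last row of pooling windows has only stride-resth = 1 valid row, so the row offset sel//stride is folded back into range while the column offset sel%stride is preserved.

    import torch

    stride, resth = 2, 1
    sel_last = torch.arange(stride*stride)  # every possible raw draw: 0, 1, 2, 3
    fixed = (sel_last//stride) % (stride-resth) * stride + (sel_last % stride)
    print(fixed)  # tensor([0, 1, 0, 1]) -- row offset forced to 0, column kept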
@@ -105,15 +159,15 @@ class SConv2dAvg(nn.Module):
         if stride == 1 or mask[0,0] == -1:  #in case of no mask and stride==1
             out = out_unf.view(batch_size, out_channels, out_h, out_w)  #fold
-            #if stoch == False:  #this is done outside for more clarity
-            #    out = F.avg_pool2d(out, self.stride, ceil_mode=True)
+            if stoch == False:  #this is done outside for more clarity
+                out = F.avg_pool2d(out, self.stride, ceil_mode=True)
         else:  #in case of a mask
             out = torch.zeros(batch_size, out_channels, out_h, out_w, device=device)
             out[:, :, mask>0] = out_unf
         return out
 
-    def forward_(self, input, selh=-torch.ones(1,1), selw=-torch.ones(1,1), mask=-torch.ones(1,1), stoch=True, stride=-1):
+    def forward(self, input, selh=-torch.ones(1,1), selw=-torch.ones(1,1), mask=-torch.ones(1,1), stoch=True, stride=-1):
         device = input.device
         if stride == -1:
             stride = self.stride  #if stride is not given, use self.stride
@@ -158,19 +212,19 @@ class SConv2dAvg(nn.Module):
             #out_unf = inp_unf.transpose(1, 2).matmul(self.weight.view(self.weight.size(0), -1).t()).transpose(1, 2)
             out_unf = oe.contract('bji,kj->bki', inp_unf, self.weight.view(self.weight.size(0), -1), backend='torch')
         else:
-            dgdg
-            out_unf = (inp_unf.transpose(1, 2).matmul(self.weight.view(self.weight.size(0), -1).t()) + self.bias).transpose(1, 2)
+            out_unf = oe.contract('bji,kj->bki', inp_unf, self.weight.view(self.weight.size(0), -1), backend='torch') + self.bias.view(1,-1,1)
+            #out_unf = (inp_unf.transpose(1, 2).matmul(self.weight.view(self.weight.size(0), -1).t()) + self.bias).transpose(1, 2)
 
         if stride == 1 or mask[0,0] == -1:  #in case of no mask and stride==1
             out = out_unf.view(batch_size, out_channels, out_h, out_w)  #fold
-            #if stoch == False:  #this is done outside for more clarity
-            #    out = F.avg_pool2d(out, self.stride, ceil_mode=True)
+            if stoch == False:  #this is done outside for more clarity
+                out = F.avg_pool2d(out, self.stride, ceil_mode=True)
         else:  #in case of a mask
             out = torch.zeros(batch_size, out_channels, out_h, out_w, device=device)
             out[:, :, mask>0] = out_unf
         return out
 
-    def forward_(self, input, selh=-torch.ones(1,1), selw=-torch.ones(1,1), mask=-torch.ones(1,1), stoch=True, stride=-1):
+    def forward_slowwithbatch(self, input, selh=-torch.ones(1,1), selw=-torch.ones(1,1), mask=-torch.ones(1,1), stoch=True, stride=-1):
         device = input.device
         if stride == -1:
             stride = self.stride
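The bias fix that replaces the stray dgdg marker adds the bias after the contraction rather than folding it in (the 'bji,kj,k->bki' variant the code elsewhere comments out as wrong would multiply the bias into the product). A standalone check that the new form matches the matmul reference the diff comments out, with torch.einsum standing in for oe.contract and illustrative shapes:

    import torch

    B, J, I, K = 2, 27, 16, 5     # batch, C_in*k*k, output positions, C_out
    inp_unf = torch.randn(B, J, I)
    w2d = torch.randn(K, J)
    bias = torch.randn(K)
    a = torch.einsum('bji,kj->bki', inp_unf, w2d) + bias.view(1, -1, 1)
    b = (inp_unf.transpose(1, 2).matmul(w2d.t()) + bias).transpose(1, 2)
    print(torch.allclose(a, b, atol=1e-6))  # True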
@@ -226,6 +280,7 @@ class SConv2dAvg(nn.Module):
             out[:, :, mask>0] = out_unf
         return out
+
 
     def comp(self, h, w, mask=-torch.ones(1,1)):
         out_h = (h - self.kernel_size) / self.stride
         out_w = (w - self.kernel_size) / self.stride
@@ -241,7 +296,7 @@ class SConv2dAvg(nn.Module):
         comp = self.weight.numel() * (mask>0).sum()
         return comp
 
-    def sample(self, h, w, mask):
+    def sample_slow(self, h, w, mask):
         '''
         h, w : forward input shape
         mask : mask of the output used in the computation
@@ -251,8 +306,8 @@ class SConv2dAvg(nn.Module):
         device = mask.device
 
         #Shape after a simple forward conv?
-        afterconv_h = h - (kh-1)
-        afterconv_w = w - (kw-1)
+        afterconv_h = h + 2*padding - (kh-1)
+        afterconv_w = w + 2*padding - (kw-1)
         # print(afterconv_h)
         # print(afterconv_h/stride)
@@ -263,14 +318,9 @@ class SConv2dAvg(nn.Module):
         else:
             out_h = math.floor(afterconv_h/stride)
             out_w = math.floor(afterconv_w/stride)
-        # out_h = ((afterconv_h + 2*self.padding - 1)/stride) + 1
-        # out_w = ((afterconv_w + 2*self.padding - 1)/stride) + 1
-        # print('Out', out_h, out_w)
-        # assert(tuple(mask.shape) == (out_h, out_w))
-        # out_h, out_w = mask.shape
 
-        selh = torch.randint(stride, (out_h, out_w), device=device)
-        selw = torch.randint(stride, (out_h, out_w), device=device)
+        #selh = torch.randint(stride, (out_h, out_w), device=device)
+        #selw = torch.randint(stride, (out_h, out_w), device=device)
 
         resth = (out_h*stride) - afterconv_h  #remainder of ceil vs floor, 0 or 1
         restw = (out_w*stride) - afterconv_w
@@ -296,8 +346,73 @@ class SConv2dAvg(nn.Module):
         fmask = fmask[0,0]
         return selh, selw, fmask.long()
 
-    def get_size(self, h, w):
-        newh = math.floor(((h + 2*self.padding - self.dilation*(self.kernel_size-1) - 1)/self.stride) + 1)
-        neww = math.floor(((w + 2*self.padding - self.dilation*(self.kernel_size-1) - 1)/self.stride) + 1)
-        return newh, neww
+    def sample(self, in_h, in_w, batch_size, device, mask=-torch.ones(1,1)):
+        '''
+        in_h, in_w : forward input shape
+        mask : mask of the output used in the computation
+        '''
+        stride = self.stride
+        out_channels, in_channels, kh, kw = self.weight.shape
+        #device = mask.device
+
+        #Shape after a simple forward conv
+        afterconv_h = in_h + 2*self.padding - (kh-1)  #size after conv
+        afterconv_w = in_w + 2*self.padding - (kw-1)
+
+        #Shape after forward (== mask.shape?); padding and dilation not taken into account?
+        if self.ceil_mode:
+            out_h = math.ceil(afterconv_h/stride)
+            out_w = math.ceil(afterconv_w/stride)
+        else:
+            out_h = math.floor(afterconv_h/stride)
+            out_w = math.floor(afterconv_w/stride)
+
+        sel = torch.randint(stride*stride, (out_h, out_w), device=device)
+
+        if self.ceil_mode:  #in case of ceil_mode, restrict the selection to the valid locations of the last regions
+            resth = (out_h*stride) - afterconv_h
+            restw = (out_w*stride) - afterconv_w
+            if resth != 0:
+                sel[-1] = (sel[-1]//stride) % (stride-resth) * stride + (sel[-1]%stride)
+                sel[:,-1] = (sel[:,-1]%stride) % (stride-restw) + sel[:,-1]//stride * stride
+
+        rng = torch.arange(0, out_h*stride*out_w*stride, stride*stride, device=device).view(out_h, out_w)
+        index = sel + rng
+        index = index.repeat(batch_size, in_channels*kh*kw, 1, 1)
+
+        if mask[0,0] != -1:
+            maskh = out_h * stride
+            maskw = out_w * stride
+            nmask = torch.zeros((maskh, maskw), device=device)
+            nmask[rng_h, rng_w] = 1  #note: rng_h/rng_w are not defined here; this branch would fail as written
+            dmask = self.pooldeconv(mask.float().view(1, 1, mask.shape[0], mask.shape[1]))
+            rmask = nmask * dmask
+            fmask = self.deconv(rmask)
+            mask = fmask[0,0].long()
+
+        return index, mask
+
+    def get_size(self, in_h, in_w, stride=-1):
+        if stride == -1:
+            stride = self.stride
+        out_channels, in_channels, kh, kw = self.weight.shape
+        afterconv_h = in_h + 2*self.padding - (kh-1)  #size after conv
+        afterconv_w = in_w + 2*self.padding - (kw-1)
+        if self.ceil_mode:  #ceil_mode=True: default mode for strided conv
+            out_h = math.ceil(afterconv_h/stride)
+            out_w = math.ceil(afterconv_w/stride)
+        else:  #ceil_mode=False: default mode for pooling
+            out_h = math.floor(afterconv_h/stride)
+            out_w = math.floor(afterconv_w/stride)
+        #newh = math.floor(((in_h + 2*self.padding - self.dilation*(self.kernel_size-1) - 1)/self.stride) + 1)
+        #neww = math.floor(((in_w + 2*self.padding - self.dilation*(self.kernel_size-1) - 1)/self.stride) + 1)
+        return afterconv_h, afterconv_w, out_h, out_w
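A quick numeric check of the get_size arithmetic, and of why ceil_mode breaks the fast path. This is a standalone re-derivation of the same formulas, not a call into the class:

    import math

    def get_size(in_h, k, padding, stride, ceil_mode):
        afterconv_h = in_h + 2*padding - (k - 1)
        rnd = math.ceil if ceil_mode else math.floor
        return afterconv_h, rnd(afterconv_h/stride)

    # 31x31 input, 3x3 kernel, padding 1, stride 2:
    print(get_size(31, 3, 1, 2, ceil_mode=False))  # (31, 15)
    print(get_size(31, 3, 1, 2, ceil_mode=True))   # (31, 16)

With ceil_mode=True, out_h*stride = 32 exceeds afterconv_h = 31, so the flat index grid rng = arange(0, out_h*stride*out_w*stride, stride*stride) built in sample() can address positions past the afterconv_h*afterconv_w columns that unfold actually produces, which is consistent with the commit's note that forward_fast currently works only with ceil_mode=False.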