diff --git a/models/mylenet4.py b/models/mylenet4.py
index 0204c47..267e361 100644
--- a/models/mylenet4.py
+++ b/models/mylenet4.py
@@ -101,6 +101,30 @@ class MyLeNetMatNormal(nn.Module):#epoch 21s
         #out = (self.fc1(out))
         return out
 
+class MyLeNetMatStochNoceil(nn.Module):#epoch 17s
+    def __init__(self):
+        super(MyLeNetMatStochNoceil, self).__init__()
+        self.conv1 = SConv2dAvg(3, 200, 3, stride=2, padding=1, ceil_mode=False)
+        self.conv2 = SConv2dAvg(200, 400, 3, stride=2, padding=1, ceil_mode=False)
+        self.conv3 = SConv2dAvg(400, 800, 3, stride=2, padding=1, ceil_mode=False)
+        self.conv4 = SConv2dAvg(800, 10, 3, stride=4, padding=1, ceil_mode=False)
+        #self.fc1 = nn.Linear(800, 10)
+
+    def forward(self, x, stoch=True):
+        #print('in',x.shape)
+        out = F.relu(self.conv1(x, stoch=stoch))
+        #print('c1',out.shape)
+        out = F.relu(self.conv2(out, stoch=stoch))
+        #print('c2',out.shape)
+        out = F.relu(self.conv3(out, stoch=stoch))
+        #print('c3',out.shape)
+        out = self.conv4(out, stoch=stoch)
+        #print('c4',out.shape)
+        out = out.view(out.size(0), -1)
+        #out = self.fc1(out)
+        return out
+
+
 class MyLeNetMatStoch(nn.Module):#epoch 17s
     def __init__(self):
         super(MyLeNetMatStoch, self).__init__()
@@ -111,16 +135,15 @@ class MyLeNetMatStoch(nn.Module):#epoch 17s
         #self.fc1 = nn.Linear(800, 10)
 
     def forward(self, x, stoch=True):
-        print('in',x.shape)
+        #print('in',x.shape)
         out = F.relu(self.conv1(x,stoch=stoch))
-        print('c1',out.shape)
+        #print('c1',out.shape)
         out = F.relu(self.conv2(out,stoch=stoch))
-        print('c2', out.shape)
+        #print('c2', out.shape)
         out = F.relu(self.conv3(out,stoch=stoch))
-        print('c3',out.shape)
-        #hkjhlg
+        #print('c3',out.shape)
         out = self.conv4(out,stoch=stoch)
-        print('c4',out.shape)
+        #print('c4',out.shape)
         out = out.view(out.size(0), -1 )
         #out = self.fc1(out)
         return out
diff --git a/models/myresnet4.py b/models/myresnet4.py
index 19e07b2..5970f98 100644
--- a/models/myresnet4.py
+++ b/models/myresnet4.py
@@ -26,7 +26,7 @@ class SAvg_Pool2d(nn.Module):
         out = savg_pool2d(x, self.stride, mode = self.mode,ceil_mode = self.ceil_mode)
         return out
 
-stochmode = 'sim'#'sim'#'stride''stoch'''
+stochmode = 'stoch' #one of: 'sim', 'stride', 'stoch'
 finalstochpool = True
 simmode = 'sbc'
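The two model hunks above wire the new SConv2dAvg layers together. A minimal smoke test for the Noceil variant, assuming this repository's layout and a CIFAR-sized input (with ceil_mode=False the spatial sizes go 32 -> 16 -> 8 -> 4 -> 1, so conv4 already yields the 10 logits):

import torch
from models.mylenet4 import MyLeNetMatStochNoceil

model = MyLeNetMatStochNoceil()
x = torch.randn(8, 3, 32, 32)                # CIFAR-sized batch
logits_stoch = model(x, stoch=True)          # stochastically pooled path
logits_det = model(x, stoch=False)           # deterministic path (real average pooling)
print(logits_stoch.shape, logits_det.shape)  # both expected to be torch.Size([8, 10])
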
diff --git a/models/stoch.py b/models/stoch.py
index 1c42eaf..936afe4 100644
--- a/models/stoch.py
+++ b/models/stoch.py
@@ -45,8 +45,55 @@ class SConv2dAvg(nn.Module):
         self.padding = padding
         self.kernel_size = kernel_size
         self.ceil_mode = ceil_mode
+
+
+    def forward_fast(self, input, index=-torch.ones(1), mask=-torch.ones(1,1),stoch=True,stride=-1): #NOTE: ceil_mode=True is not handled correctly yet
+        device=input.device
+        if stride==-1:
+            stride = self.stride #if stride not defined use self.stride
+        if stoch==False:
+            stride=1 #test with real average pooling
+        batch_size, in_channels, in_h, in_w = input.shape
+        out_channels, in_channels, kh, kw = self.weight.shape
+        afterconv_h,afterconv_w,out_h,out_w = self.get_size(in_h,in_w,stride)
+
+        unfold = torch.nn.Unfold(kernel_size=(kh, kw), dilation=self.dilation, padding=self.padding, stride=1)
+        inp_unf = unfold(input) #unfold into a matrix (batch_size, in_channels*kh*kw, afterconv_h*afterconv_w)
+
+        if index.numel()==1 and stride!=1: #no index given: sample one
+            index,mask = self.sample(in_h,in_w,batch_size,device,mask)
+
+        if stride!=1: # when stride==1 every location is kept: nothing to sample
+            if mask[0,0]==-1:# no mask given: use the sampled selection everywhere
+                inp_unf = torch.gather(inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h*afterconv_w),2,index.view(batch_size,in_channels*kh*kw,out_h*out_w)).view(batch_size,in_channels*kh*kw,out_h*out_w)
+            else:# with a valid mask, keep the sampled selection only at the mask locations
+                inp_unf = torch.gather(inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h*afterconv_w),2,index.view(batch_size,in_channels*kh*kw,out_h*out_w)).view(batch_size,in_channels*kh*kw,out_h,out_w)[:,:,mask>0]
+
+        #Matrix mul
+        if self.bias is None:
+            #flt = self.weight.view(self.weight.size(0), -1).t()
+            #out_unf = inp_unf.transpose(2,1).matmul(flt).transpose(1, 2)
+            out_unf = oe.contract('bji,kj->bki',inp_unf,self.weight.view(self.weight.size(0), -1),backend='torch')
+            #print(((out_unf-out_unf1)**2).mean())
+        else:
+            #out_unf = oe.contract('bji,kj,k->bki',inp_unf,self.weight.view(self.weight.size(0), -1),self.bias,backend='torch')#+self.bias.view(1,-1,1)#wrong
+            out_unf = oe.contract('bji,kj->bki',inp_unf,self.weight.view(self.weight.size(0), -1),backend='torch')+self.bias.view(1,-1,1)#slightly slower but correct
+            #out_unf1 = (inp_unf.transpose(1, 2).matmul(self.weight.view(self.weight.size(0), -1).t()) + self.bias).transpose(1, 2)
+            #print(((out_unf-out_unf1)**2).mean())
+            #self.flt = self.weight.view(self.weight.size(0), -1).t()
+            #out_unf = (inp_unf.transpose(1, 2).matmul(self.flt) + self.bias).transpose(1, 2)
+
+        if stride==1 or mask[0,0]==-1:# no mask given, or stride==1: dense output
+            out = out_unf.view(batch_size,out_channels,out_h,out_w) #Fold
+            if stoch==False: #emulate the layer with a real average pooling
+                out = F.avg_pool2d(out,self.stride,ceil_mode=self.ceil_mode)
+        else:# with a mask, scatter the computed values to the masked locations
+            out = torch.zeros(batch_size, out_channels,out_h,out_w,device=device)
+            out[:,:,mask>0] = out_unf
+        return out
+
 
-    def forward(self, input, selh=-torch.ones(1,1), selw=-torch.ones(1,1), mask=-torch.ones(1,1),stoch=True,stride=-1):
+    def forward_test(self, input, selh=-torch.ones(1,1), selw=-torch.ones(1,1), mask=-torch.ones(1,1),stoch=True,stride=-1):#ugly but faster
         device=input.device
         if stride==-1:
             stride = self.stride #if stride not defined use self.stride
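The gather in forward_fast relies on a single flat index per pooling region, built as sel + rng. A standalone illustration of that construction (names mirror the code above; note that the rng spacing implicitly treats each stride*stride block as contiguous in the flattened map):

import torch

stride, out_h, out_w = 2, 3, 3
# stand-in for one row of inp_unf: (batch, in_channels*kh*kw, H*W)
x = torch.arange(out_h*stride*out_w*stride, dtype=torch.float).view(1, 1, -1)

sel = torch.randint(stride*stride, (out_h, out_w))   # random offset inside each block
rng = torch.arange(0, out_h*stride*out_w*stride, stride*stride).view(out_h, out_w)
index = (sel + rng).view(1, 1, out_h*out_w)          # one flat index per pooling region

sampled = torch.gather(x, 2, index)                  # (1, 1, out_h*out_w): one value per region
print(sampled.view(out_h, out_w))
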
@@ -55,41 +102,48 @@ class SConv2dAvg(nn.Module):
         batch_size, in_channels, in_h, in_w = input.shape
         out_channels, in_channels, kh, kw = self.weight.shape
-        afterconv_h = in_h+2*self.padding-(kh-1) #size after conv
-        afterconv_w = in_w+2*self.padding-(kw-1)
-        if self.ceil_mode: #ceil_mode = talse default mode for strided conv
-            out_h = math.ceil(afterconv_h/stride)
-            out_w = math.ceil(afterconv_w/stride)
-        else: #ceil_mode = false default mode for pooling
-            out_h = math.floor(afterconv_h/stride)
-            out_w = math.floor(afterconv_w/stride)
+        #afterconv_h,afterconv_w,out_h,out_w = self.get_size(in_h,in_w)
+        #if selh[0,0]==-1:
+        #    index,mask = self.sample(in_h,in_w,batch_size,device,mask)
+
+        if 1:
+            afterconv_h = in_h+2*self.padding-(kh-1) #size after conv
+            afterconv_w = in_w+2*self.padding-(kw-1)
+            if self.ceil_mode: #ceil_mode = true, default mode for strided conv
+                out_h = math.ceil(afterconv_h/stride)
+                out_w = math.ceil(afterconv_w/stride)
+            else: #ceil_mode = false, default mode for pooling
+                out_h = math.floor(afterconv_h/stride)
+                out_w = math.floor(afterconv_w/stride)
 
         unfold = torch.nn.Unfold(kernel_size=(kh, kw), dilation=self.dilation, padding=self.padding, stride=1)
         inp_unf = unfold(input) #transform into a matrix (batch_size, in_channels*kh*kw, afterconv_h*afterconv_w)
 
-        if stride!=1: # if stride==1 there is no pooling
-            inp_unf = inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h,afterconv_w)
-            if selh[0,0]==-1: # if not given sampled selection
-                #selction of where to sample for each pooling location
-                sel = torch.randint(stride*stride,(out_h,out_w), device=device)
+        if 1:
+            if stride!=1: # if stride==1 there is no pooling
+                inp_unf = inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h,afterconv_w)
+                if selh[0,0]==-1: # if no sampled selection is given
+                    #selection of where to sample for each pooling location
+                    sel = torch.randint(stride*stride,(out_h,out_w), device=device)
 
-                if self.ceil_mode: #in case of ceil_mode need to select only the good locations for the last regions
-                    resth = (out_h*stride)-afterconv_h
-                    restw = (out_w*stride)-afterconv_w
-                    if resth!=0:
-                        sel[-1] = (sel[-1]//stride)%(stride-resth)*stride+(sel[-1]%stride)
-                        sel[:,-1] = (sel[:,-1]%stride)%(stride-restw)+sel[:,-1]//stride*stride
-                    #print(stride-resth,sel[-1])
-                    #print(stride-restw,sel[:,-1])
+                    if self.ceil_mode: #with ceil_mode, restrict the selection to the valid locations of the last (partial) regions
+                        resth = (out_h*stride)-afterconv_h
+                        restw = (out_w*stride)-afterconv_w
+                        if resth!=0:
+                            sel[-1] = (sel[-1]//stride)%(stride-resth)*stride+(sel[-1]%stride)
+                            sel[:,-1] = (sel[:,-1]%stride)%(stride-restw)+sel[:,-1]//stride*stride
+                        #print(stride-resth,sel[-1])
+                        #print(stride-restw,sel[:,-1])
 
-                rng = torch.arange(0,afterconv_h*afterconv_w,stride*stride,device=device).view(out_h,out_w)
-                index = sel+rng
-                index = index.repeat(batch_size,in_channels*kh*kw,1,1)
+                    #rng = torch.arange(0,afterconv_h*afterconv_w,stride*stride,device=device).view(out_h,out_w)
+                    rng = torch.arange(0,out_h*stride*out_w*stride,stride*stride,device=device).view(out_h,out_w)
+                    index = sel+rng
+                    index = index.repeat(batch_size,in_channels*kh*kw,1,1)
 
-            if mask[0,0]==-1:# in case of not given mask use only sampled selection
-                #inp_unf = inp_unf[:,:,rng_h,rng_w].view(batch_size,in_channels*kh*kw,-1)
-                inp_unf = torch.gather(inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h*afterconv_w),2,index.view(batch_size,in_channels*kh*kw,out_h*out_w)).view(batch_size,in_channels*kh*kw,out_h*out_w)
-            else:#in case of a valid mask use selection only on the mask locations
-                inp_unf = inp_unf[:,:,rng_h[mask>0],rng_w[mask>0]]
+                if mask[0,0]==-1:# if no mask is given use only the sampled selection
+                    #inp_unf = inp_unf[:,:,rng_h,rng_w].view(batch_size,in_channels*kh*kw,-1)
+                    inp_unf = torch.gather(inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h*afterconv_w),2,index.view(batch_size,in_channels*kh*kw,out_h*out_w)).view(batch_size,in_channels*kh*kw,out_h*out_w)
+                else:#with a valid mask use the sampled selection only at the mask locations
+                    inp_unf = torch.gather(inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h*afterconv_w),2,index.view(batch_size,in_channels*kh*kw,out_h*out_w)).view(batch_size,in_channels*kh*kw,out_h,out_w)[:,:,mask>0]
 
         #Matrix mul
         if self.bias is None:
@@ -105,15 +159,15 @@ class SConv2dAvg(nn.Module):
 
         if stride==1 or mask[0,0]==-1:# in case of no mask and stride==1
             out = out_unf.view(batch_size,out_channels,out_h,out_w) #Fold
-            #if stoch==False: #this is done outside for more clarity
-            #    out = F.avg_pool2d(out,self.stride,ceil_mode=True)
+            if stoch==False: #emulate the layer with a real average pooling
+                out = F.avg_pool2d(out,self.stride,ceil_mode=self.ceil_mode)
         else:#in case of mask
             out = torch.zeros(batch_size, out_channels,out_h,out_w,device=device)
             out[:,:,mask>0] = out_unf
         return out
 
-    def forward_(self, input, selh=-torch.ones(1,1), selw=-torch.ones(1,1), mask=-torch.ones(1,1),stoch=True,stride=-1):
+    def forward(self, input, selh=-torch.ones(1,1), selw=-torch.ones(1,1), mask=-torch.ones(1,1),stoch=True,stride=-1):
         device=input.device
         if stride==-1:
             stride = self.stride #if stride not defined use self.stride
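With stoch=False the layer now performs the stride-1 convolution followed by a real average pooling inside forward, so it should agree with the plain PyTorch composition. A sanity-check sketch (the constructor signature is assumed from its use in mylenet4.py):

import torch
import torch.nn.functional as F
from models.stoch import SConv2dAvg

layer = SConv2dAvg(3, 8, 3, stride=2, padding=1, ceil_mode=False)
x = torch.randn(2, 3, 16, 16)

out = layer(x, stoch=False)
ref = F.avg_pool2d(F.conv2d(x, layer.weight, layer.bias, stride=1, padding=1),
                   2, ceil_mode=False)
print((out - ref).abs().max())  # expected to be ~0 up to float error
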
@@ -158,19 +212,19 @@ class SConv2dAvg(nn.Module):
             #out_unf = inp_unf.transpose(1, 2).matmul(self.weight.view(self.weight.size(0), -1).t()).transpose(1, 2)
             out_unf = oe.contract('bji,kj->bki',inp_unf,self.weight.view(self.weight.size(0), -1),backend='torch')
         else:
-            dgdg
-            out_unf = (inp_unf.transpose(1, 2).matmul(self.weight.view(self.weight.size(0), -1).t()) + self.bias).transpose(1, 2)
+            out_unf = oe.contract('bji,kj->bki',inp_unf,self.weight.view(self.weight.size(0), -1),backend='torch')+self.bias.view(1,-1,1)
+            #out_unf = (inp_unf.transpose(1, 2).matmul(self.weight.view(self.weight.size(0), -1).t()) + self.bias).transpose(1, 2)
 
         if stride==1 or mask[0,0]==-1:# in case of no mask and stride==1
             out = out_unf.view(batch_size,out_channels,out_h,out_w) #Fold
-            #if stoch==False: #this is done outside for more clarity
-            #    out = F.avg_pool2d(out,self.stride,ceil_mode=True)
+            if stoch==False: #emulate the layer with a real average pooling
+                out = F.avg_pool2d(out,self.stride,ceil_mode=self.ceil_mode)
         else:#in case of mask
             out = torch.zeros(batch_size, out_channels,out_h,out_w,device=device)
             out[:,:,mask>0] = out_unf
         return out
 
-    def forward_(self, input, selh=-torch.ones(1,1), selw=-torch.ones(1,1), mask=-torch.ones(1,1),stoch=True,stride=-1):
+    def forward_slowwithbatch(self, input, selh=-torch.ones(1,1), selw=-torch.ones(1,1), mask=-torch.ones(1,1),stoch=True,stride=-1):
         device=input.device
         if stride==-1:
             stride = self.stride
@@ -226,6 +280,7 @@ class SConv2dAvg(nn.Module):
             out[:,:,mask>0] = out_unf
         return out
 
+
     def comp(self,h,w,mask=-torch.ones(1,1)):
         out_h = (h-(self.kernel_size))/self.stride
         out_w = (w-(self.kernel_size))/self.stride
@@ -241,7 +296,7 @@ class SConv2dAvg(nn.Module):
             comp = self.weight.numel()*(mask>0).sum()
         return comp
 
-    def sample(self,h,w,mask):
+    def sample_slow(self,h,w,mask):
         '''
         h, w : forward input shape
         mask : mask of output used in computation
@@ -251,8 +306,8 @@ class SConv2dAvg(nn.Module):
         device=mask.device
 
         #Shape after simple forward conv ?
-        afterconv_h = h-(kh-1)
-        afterconv_w = w-(kw-1)
+        afterconv_h = h+2*self.padding-(kh-1)
+        afterconv_w = w+2*self.padding-(kw-1)
 #        print(afterconv_h)
 #        print(afterconv_h/stride)
@@ -263,14 +318,9 @@ class SConv2dAvg(nn.Module):
         else:
             out_h = math.floor(afterconv_h/stride)
             out_w = math.floor(afterconv_w/stride)
-#        out_h=((afterconv_h+2*self.padding-1)/stride)+1
-#        out_w=((afterconv_w+2*self.padding-1)/stride)+1
-#        print('Out',out_h, out_w)
-#        assert(tuple(mask.shape)==(out_h,out_w))
-#        out_h,out_w=mask.shape
-        selh = torch.randint(stride,(out_h,out_w), device=device)
-        selw = torch.randint(stride,(out_h,out_w), device=device)
+        selh = torch.randint(stride,(out_h,out_w), device=device) #needed for the returned selection
+        selw = torch.randint(stride,(out_h,out_w), device=device)
 
         resth = (out_h*stride)-afterconv_h #remainder of ceil/floor, 0 or 1
         restw = (out_w*stride)-afterconv_w
@@ -296,8 +346,73 @@ class SConv2dAvg(nn.Module):
             fmask = fmask[0,0]
         return selh,selw,fmask.long()
 
-    def get_size(self,h,w):
-        newh=math.floor(((h + 2*self.padding - self.dilation*(self.kernel_size-1) - 1)/self.stride) + 1)
-        neww=math.floor(((w + 2*self.padding - self.dilation*(self.kernel_size-1) - 1)/self.stride) + 1)
-        return newh, neww
+    def sample(self,in_h,in_w,batch_size,device,mask=-torch.ones(1,1)):
+        '''
+        in_h, in_w : forward input shape
+        mask : mask of the output locations used in the computation
+        '''
+        stride = self.stride
+        out_channels, in_channels, kh, kw = self.weight.shape
+        #device=mask.device
+
+        #Shape after a simple (stride 1) forward conv
+        afterconv_h = in_h+2*self.padding-(kh-1) #size after conv
+        afterconv_w = in_w+2*self.padding-(kw-1)
+
+        #Shape after forward (== mask.shape ?) #dilation is not taken into account here
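The size bookkeeping that these hunks repeat (and that get_size at the end of this patch centralizes) is just the stride-1 convolution size followed by a ceil or floor division. A worked case where the two modes actually differ, which is exactly when the resth/restw correction above is needed:

import math

in_h, k, padding, stride = 9, 3, 1, 2
afterconv_h = in_h + 2*padding - (k - 1)        # stride-1 conv output height: 9
out_h_ceil = math.ceil(afterconv_h / stride)    # ceil_mode=True: 5
out_h_floor = math.floor(afterconv_h / stride)  # ceil_mode=False: 4
resth = out_h_ceil*stride - afterconv_h         # 1: the last pooling row is partial
print(afterconv_h, out_h_ceil, out_h_floor, resth)
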
+        if self.ceil_mode:
+            out_h = math.ceil(afterconv_h/stride)
+            out_w = math.ceil(afterconv_w/stride)
+        else:
+            out_h = math.floor(afterconv_h/stride)
+            out_w = math.floor(afterconv_w/stride)
+
+        sel = torch.randint(stride*stride,(out_h,out_w), device=device)
+
+        if self.ceil_mode: #with ceil_mode, restrict the selection to the valid locations of the last (partial) regions
+            resth = (out_h*stride)-afterconv_h
+            restw = (out_w*stride)-afterconv_w
+            if resth!=0:
+                sel[-1] = (sel[-1]//stride)%(stride-resth)*stride+(sel[-1]%stride)
+                sel[:,-1] = (sel[:,-1]%stride)%(stride-restw)+sel[:,-1]//stride*stride
+
+        rng = torch.arange(0,out_h*stride*out_w*stride,stride*stride,device=device).view(out_h,out_w)
+        index = sel+rng
+        index = index.repeat(batch_size,in_channels*kh*kw,1,1)
+
+        #inp_unf = torch.gather(inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h*afterconv_w),2,index.view(batch_size,in_channels*kh*kw,out_h*out_w)).view(batch_size,in_channels*kh*kw,out_h*out_w)
+
+        if mask[0,0]!=-1:
+            maskh = (out_h)*stride
+            maskw = (out_w)*stride
+            nmask = torch.zeros((maskh,maskw),device=device)
+            nmask.view(-1)[(sel+rng).view(-1)] = 1 #mark the sampled location of each pooling region
+            #rmask = mask * nmask
+            dmask = self.pooldeconv(mask.float().view(1,1,mask.shape[0],mask.shape[1]))
+            rmask = nmask * dmask
+            #rmask = rmask[:,:,:out_h,:out_w]
+            # print('rmask', rmask.shape)
+            fmask = self.deconv(rmask)
+            # print('fmask', fmask.shape)
+            mask = fmask[0,0].long()
+
+        return index,mask
+
+
+    def get_size(self,in_h,in_w,stride=-1):
+
+        if stride==-1:
+            stride = self.stride #if stride not defined use self.stride
+        out_channels, in_channels, kh, kw = self.weight.shape
+        afterconv_h = in_h+2*self.padding-(kh-1) #size after conv
+        afterconv_w = in_w+2*self.padding-(kw-1)
+        if self.ceil_mode: #ceil_mode = true, default mode for strided conv
+            out_h = math.ceil(afterconv_h/stride)
+            out_w = math.ceil(afterconv_w/stride)
+        else: #ceil_mode = false, default mode for pooling
+            out_h = math.floor(afterconv_h/stride)
+            out_w = math.floor(afterconv_w/stride)
+        #newh=math.floor(((h + 2*self.padding - self.dilation*(self.kernel_size-1) - 1)/self.stride) + 1)
+        #neww=math.floor(((w + 2*self.padding - self.dilation*(self.kernel_size-1) - 1)/self.stride) + 1)
+        return afterconv_h,afterconv_w,out_h,out_w
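Taken together, the fast path is driven by precomputing shapes and a sampling index, then handing both to forward_fast. A rough usage sketch for the default no-mask case (again assuming the constructor signature used in mylenet4.py):

import torch
from models.stoch import SConv2dAvg

layer = SConv2dAvg(3, 8, 3, stride=2, padding=1, ceil_mode=False)
x = torch.randn(2, 3, 16, 16)

afterconv_h, afterconv_w, out_h, out_w = layer.get_size(16, 16)
index, mask = layer.sample(16, 16, batch_size=2, device=x.device)
out = layer.forward_fast(x, index=index, mask=mask)
print(out.shape)  # expected: (2, 8, out_h, out_w) == (2, 8, 8, 8)
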