mirror of
https://github.com/AntoineHX/BU_Stoch_pool.git
synced 2025-05-03 17:20:45 +02:00
fixed problem with inference with MyLeNetNoceil
This commit is contained in:
parent
019310e60c
commit
5504ca67ab
4 changed files with 157 additions and 138 deletions
2
main.py
2
main.py
|
@ -51,7 +51,7 @@ checkpoint=False
|
|||
|
||||
# Data
|
||||
print('==> Preparing data..')
|
||||
dataroot="~/scratch/data" #"./data"
|
||||
dataroot="./data"#"~/scratch/data" #"./data"
|
||||
download_data=False
|
||||
transform_train = [
|
||||
# transforms.RandomCrop(32, padding=4),
|
||||
|
|
|
@ -101,52 +101,6 @@ class MyLeNetMatNormal(nn.Module):#epoch 21s
|
|||
#out = (self.fc1(out))
|
||||
return out
|
||||
|
||||
class MyLeNetMatNormalNoceil(nn.Module):#epoch 136s 16GB
|
||||
def __init__(self,k=3):
|
||||
super(MyLeNetMatNormalNoceil, self).__init__()
|
||||
self.conv1 = SConv2dAvg(3, 200*k, 3, stride=1,padding=1,ceil_mode=False)
|
||||
self.conv2 = SConv2dAvg(200*k, 400*k, 3, stride=1,padding=1,ceil_mode=False)
|
||||
self.conv3 = SConv2dAvg(400*k, 800*k, 3, stride=1,padding=1,ceil_mode=False)
|
||||
self.conv4 = SConv2dAvg(800*k, 1600*k, 3, stride=1,padding=1,ceil_mode=False)
|
||||
self.fc1 = nn.Linear(1600*k, 10)
|
||||
|
||||
def forward(self, x, stoch=True):
|
||||
out = F.relu(self.conv1(x,stoch=stoch))
|
||||
out = F.avg_pool2d(out,2,ceil_mode=True)
|
||||
out = F.relu(self.conv2(out,stoch=stoch))
|
||||
out = F.avg_pool2d(out,2,ceil_mode=True)
|
||||
out = F.relu(self.conv3(out,stoch=stoch))
|
||||
out = F.avg_pool2d(out,2,ceil_mode=True)
|
||||
out = F.relu(self.conv4(out,stoch=stoch))
|
||||
out = F.avg_pool2d(out,4,ceil_mode=True)
|
||||
out = out.view(out.size(0), -1 )
|
||||
out = self.fc1(out)
|
||||
return out
|
||||
|
||||
class MyLeNetMatStochNoceil(nn.Module):#epoch 41s 16GB
|
||||
def __init__(self,k=3):
|
||||
super(MyLeNetMatStochNoceil, self).__init__()
|
||||
self.conv1 = SConv2dAvg(3, 200*k, 3, stride=2,padding=1,ceil_mode=False)
|
||||
self.conv2 = SConv2dAvg(200*k, 400*k, 3, stride=2,padding=1,ceil_mode=False)
|
||||
self.conv3 = SConv2dAvg(400*k, 800*k, 3, stride=2,padding=1,ceil_mode=False)
|
||||
self.conv4 = SConv2dAvg(800*k, 1600*k, 3, stride=4,padding=1,ceil_mode=False)
|
||||
self.fc1 = nn.Linear(1600*k, 10)
|
||||
|
||||
def forward(self, x, stoch=True):
|
||||
#print('in',x.shape)
|
||||
out = F.relu(self.conv1(x,stoch=stoch))
|
||||
#print('c1',out.shape)
|
||||
out = F.relu(self.conv2(out,stoch=stoch))
|
||||
#print('c2', out.shape)
|
||||
out = F.relu(self.conv3(out,stoch=stoch))
|
||||
#print('c3',out.shape)
|
||||
out = F.relu(self.conv4(out,stoch=stoch))
|
||||
#print('c4',out.shape)
|
||||
out = out.view(out.size(0), -1 )
|
||||
out = self.fc1(out)
|
||||
return out
|
||||
|
||||
|
||||
class MyLeNetMatStoch(nn.Module):#epoch 17s
|
||||
def __init__(self):
|
||||
super(MyLeNetMatStoch, self).__init__()
|
||||
|
@ -169,6 +123,59 @@ class MyLeNetMatStoch(nn.Module):#epoch 17s
|
|||
out = out.view(out.size(0), -1 )
|
||||
#out = self.fc1(out)
|
||||
return out
|
||||
|
||||
class MyLeNetMatNormalNoceil(nn.Module):#epoch 136s 16GB
|
||||
def __init__(self,k=3):
|
||||
super(MyLeNetMatNormalNoceil, self).__init__()
|
||||
self.conv1 = SConv2dAvg(3, 200*k, 3, stride=1,padding=1,ceil_mode=False)
|
||||
self.conv2 = SConv2dAvg(200*k, 400*k, 3, stride=1,padding=1,ceil_mode=False)
|
||||
self.conv3 = SConv2dAvg(400*k, 800*k, 3, stride=1,padding=1,ceil_mode=False)
|
||||
self.conv4 = SConv2dAvg(800*k, 1600*k, 3, stride=1,padding=1,ceil_mode=False)
|
||||
self.fc1 = nn.Linear(1600*k, 10)
|
||||
|
||||
def forward(self, x, stoch=True):
|
||||
out = F.relu(self.conv1(x,stoch=stoch))
|
||||
out = F.avg_pool2d(out,2,ceil_mode=False)
|
||||
out = F.relu(self.conv2(out,stoch=stoch))
|
||||
out = F.avg_pool2d(out,2,ceil_mode=False)
|
||||
out = F.relu(self.conv3(out,stoch=stoch))
|
||||
out = F.avg_pool2d(out,2,ceil_mode=False)
|
||||
out = F.relu(self.conv4(out,stoch=stoch))
|
||||
out = F.avg_pool2d(out,4,ceil_mode=False)
|
||||
out = out.view(out.size(0), -1 )
|
||||
out = self.fc1(out)
|
||||
return out
|
||||
|
||||
class MyLeNetMatStochNoceil(nn.Module):#epoch 41s 16GB
|
||||
def __init__(self,k=3):
|
||||
super(MyLeNetMatStochNoceil, self).__init__()
|
||||
self.conv1 = SConv2dAvg(3, 200*k, 3, stride=2,padding=1,ceil_mode=False)
|
||||
self.conv2 = SConv2dAvg(200*k, 400*k, 3, stride=2,padding=1,ceil_mode=False)
|
||||
self.conv3 = SConv2dAvg(400*k, 800*k, 3, stride=2,padding=1,ceil_mode=False)
|
||||
self.conv4 = SConv2dAvg(800*k, 1600*k, 3, stride=4,padding=1,ceil_mode=False)
|
||||
self.fc1 = nn.Linear(1600*k, 10)
|
||||
|
||||
def forward(self, x, stoch=True):
|
||||
if stoch:
|
||||
out = F.relu(self.conv1(x,stoch=stoch))
|
||||
out = F.relu(self.conv2(out,stoch=stoch))
|
||||
out = F.relu(self.conv3(out,stoch=stoch))
|
||||
out = F.relu(self.conv4(out,stoch=stoch))
|
||||
out = out.view(out.size(0), -1 )
|
||||
out = self.fc1(out)
|
||||
else:
|
||||
out = F.relu(self.conv1(x,stride=1))
|
||||
out = F.avg_pool2d(out,2,ceil_mode=False)
|
||||
out = F.relu(self.conv2(out,stride=1))
|
||||
out = F.avg_pool2d(out,2,ceil_mode=False)
|
||||
out = F.relu(self.conv3(out,stride=1))
|
||||
out = F.avg_pool2d(out,2,ceil_mode=False)
|
||||
out = F.relu(self.conv4(out,stride=1))
|
||||
out = F.avg_pool2d(out,4,ceil_mode=False)
|
||||
out = out.view(out.size(0), -1 )
|
||||
out = self.fc1(out)
|
||||
|
||||
return out
|
||||
|
||||
class MyLeNetMatStochBUNoceil(nn.Module):#30.5s 14GB
|
||||
def __init__(self,k=3):
|
||||
|
@ -180,34 +187,46 @@ class MyLeNetMatStochBUNoceil(nn.Module):#30.5s 14GB
|
|||
self.fc1 = nn.Linear(1600*k, 10)
|
||||
|
||||
def forward(self, x, stoch=True):
|
||||
#get sizes
|
||||
batch_size = x.shape[0]
|
||||
device = x.device
|
||||
h0,w0 = x.shape[2],x.shape[3]
|
||||
_,_,h1,w1 = self.conv1.get_size(h0,w0)
|
||||
_,_,h2,w2 = self.conv2.get_size(h1,w1)
|
||||
_,_,h3,w3 = self.conv3.get_size(h2,w2)
|
||||
_,_,h4,w4 = self.conv4.get_size(h3,w3)
|
||||
# print(h0,w0)
|
||||
# print(h1,w1)
|
||||
# print(h2,w2)
|
||||
# print(h3,w3)
|
||||
if stoch:
|
||||
#get sizes
|
||||
batch_size = x.shape[0]
|
||||
device = x.device
|
||||
h0,w0 = x.shape[2],x.shape[3]
|
||||
_,_,h1,w1 = self.conv1.get_size(h0,w0)
|
||||
_,_,h2,w2 = self.conv2.get_size(h1,w1)
|
||||
_,_,h3,w3 = self.conv3.get_size(h2,w2)
|
||||
_,_,h4,w4 = self.conv4.get_size(h3,w3)
|
||||
# print(h0,w0)
|
||||
# print(h1,w1)
|
||||
# print(h2,w2)
|
||||
# print(h3,w3)
|
||||
|
||||
#sample BU
|
||||
mask4 = torch.ones(h4,w4).to(x.device)
|
||||
# print(mask3.shape)
|
||||
index4,mask3 = self.conv4.sample(h3,w3,batch_size,device,mask4)
|
||||
index3,mask2 = self.conv3.sample(h2,w2,batch_size,device,mask3)
|
||||
index2,mask1 = self.conv2.sample(h1,w1,batch_size,device,mask2)
|
||||
index1,mask0 = self.conv1.sample(h0,w0,batch_size,device,mask1)
|
||||
|
||||
##forward
|
||||
out = F.relu(self.conv1(x,index1,mask1,stoch=stoch))
|
||||
out = F.relu(self.conv2(out,index2,mask2,stoch=stoch))
|
||||
out = F.relu(self.conv3(out,index3,mask3,stoch=stoch))
|
||||
out = F.relu(self.conv4(out,index4,mask4,stoch=stoch))
|
||||
out = out.view(out.size(0), -1 )
|
||||
out = self.fc1(out)
|
||||
#sample BU
|
||||
mask4 = torch.ones(h4,w4).to(x.device)
|
||||
# print(mask3.shape)
|
||||
index4,mask3 = self.conv4.sample(h3,w3,batch_size,device,mask4)
|
||||
index3,mask2 = self.conv3.sample(h2,w2,batch_size,device,mask3)
|
||||
index2,mask1 = self.conv2.sample(h1,w1,batch_size,device,mask2)
|
||||
index1,mask0 = self.conv1.sample(h0,w0,batch_size,device,mask1)
|
||||
|
||||
##forward
|
||||
out = F.relu(self.conv1(x,index1,mask1,stoch=stoch))
|
||||
out = F.relu(self.conv2(out,index2,mask2,stoch=stoch))
|
||||
out = F.relu(self.conv3(out,index3,mask3,stoch=stoch))
|
||||
out = F.relu(self.conv4(out,index4,mask4,stoch=stoch))
|
||||
out = out.view(out.size(0), -1 )
|
||||
out = self.fc1(out)
|
||||
else:
|
||||
out = F.relu(self.conv1(x,stride=1))
|
||||
out = F.avg_pool2d(out,2,ceil_mode=False)
|
||||
out = F.relu(self.conv2(out,stride=1))
|
||||
out = F.avg_pool2d(out,2,ceil_mode=False)
|
||||
out = F.relu(self.conv3(out,stride=1))
|
||||
out = F.avg_pool2d(out,2,ceil_mode=False)
|
||||
out = F.relu(self.conv4(out,stride=1))
|
||||
out = F.avg_pool2d(out,4,ceil_mode=False)
|
||||
out = out.view(out.size(0), -1 )
|
||||
out = self.fc1(out)
|
||||
return out
|
||||
|
||||
class MyLeNetMatStochBU(nn.Module):#epoch 11s
|
||||
|
|
120
models/stoch.py
120
models/stoch.py
|
@ -91,8 +91,9 @@ class SConv2dAvg(nn.Module):
|
|||
|
||||
if stride==1 or mask[0,0]==-1:# in case of no mask and stride==1
|
||||
out = out_unf.view(batch_size,out_channels,out_h,out_w) #Fold
|
||||
if stoch==False: #this is done outside for more clarity
|
||||
out = F.avg_pool2d(out,self.stride,ceil_mode=True)
|
||||
#if stoch==False: #this is done outside for more clarity
|
||||
# out = F.avg_pool2d(out,self.stride,ceil_mode=False)
|
||||
#print(self.stride)
|
||||
else:#in case of mask
|
||||
out = torch.zeros(batch_size, out_channels,out_h,out_w,device=device)
|
||||
#out = torch.gather(out.view(batch_size,in_channels*kh*kw,afterconv_h*afterconv_w),2,index.view(batch_size,in_channels*kh*kw,index.shape[2])).view(batch_size,in_channels*kh*kw,index.shape[2])
|
||||
|
@ -100,6 +101,63 @@ class SConv2dAvg(nn.Module):
|
|||
#out.masked_scatter_(mask>0, out_unf)
|
||||
return out
|
||||
|
||||
def forward_slow(self, input, selh=-torch.ones(1,1), selw=-torch.ones(1,1), mask=-torch.ones(1,1),stoch=True,stride=-1):
|
||||
device=input.device
|
||||
if stride==-1:
|
||||
stride = self.stride #if stride not defined use self.stride
|
||||
if stoch==False:
|
||||
stride=1 #test with real average pooling
|
||||
batch_size, in_channels, in_h, in_w = input.shape
|
||||
out_channels, in_channels, kh, kw = self.weight.shape
|
||||
|
||||
afterconv_h = in_h+2*self.padding-(kh-1) #size after conv
|
||||
afterconv_w = in_w+2*self.padding-(kw-1)
|
||||
if self.ceil_mode: #ceil_mode = true default mode for strided conv
|
||||
out_h = math.ceil(afterconv_h/stride)
|
||||
out_w = math.ceil(afterconv_w/stride)
|
||||
else: #ceil_mode = false default mode for pooling
|
||||
out_h = math.floor(afterconv_h/stride)
|
||||
out_w = math.floor(afterconv_w/stride)
|
||||
unfold = torch.nn.Unfold(kernel_size=(kh, kw), dilation=self.dilation, padding=self.padding, stride=1)
|
||||
inp_unf = unfold(input) #transform into a matrix (batch_size, in_channels*kh*kw,afterconv_h,afterconv_w)
|
||||
if stride!=1: # if stride==1 there is no pooling
|
||||
inp_unf = inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h,afterconv_w)
|
||||
if selh[0,0]==-1: # if not given sampled selection
|
||||
#selection of where to sample for each pooling location
|
||||
selh = torch.randint(stride,(out_h,out_w), device=device)
|
||||
selw = torch.randint(stride,(out_h,out_w), device=device)
|
||||
|
||||
resth = (out_h*stride)-afterconv_h
|
||||
restw = (out_w*stride)-afterconv_w
|
||||
if resth!=0 and self.ceil_mode: #in case of ceil_mode need to select only the good locations for the last regions
|
||||
selh[-1,:]=selh[-1,:]%(stride-resth);selh[:,-1]=selh[:,-1]%(stride-restw)
|
||||
selw[-1,:]=selw[-1,:]%(stride-resth);selw[:,-1]=selw[:,-1]%(stride-restw)
|
||||
#the position should be global by adding range...
|
||||
rng_h = selh + torch.arange(0,out_h*stride,stride,device=device).view(-1,1)
|
||||
rng_w = selw + torch.arange(0,out_w*stride,stride,device=device)
|
||||
|
||||
if mask[0,0]==-1:# in case of not given mask use only sampled selection
|
||||
inp_unf = inp_unf[:,:,rng_h,rng_w].view(batch_size,in_channels*kh*kw,-1)
|
||||
else:#in case of a valid mask use selection only on the mask locations
|
||||
inp_unf = inp_unf[:,:,rng_h[mask>0],rng_w[mask>0]]
|
||||
|
||||
#Matrix mul
|
||||
if self.bias is None:
|
||||
out_unf = inp_unf.transpose(1, 2).matmul(self.weight.view(self.weight.size(0), -1).t()).transpose(1, 2)
|
||||
#out_unf = oe.contract('bji,kj->bki',inp_unf,self.weight.view(self.weight.size(0), -1),backend='torch')
|
||||
else:
|
||||
#out_unf = oe.contract('bji,kj->bki',inp_unf,self.weight.view(self.weight.size(0), -1),backend='torch')+self.bias.view(1,-1,1)
|
||||
out_unf = (inp_unf.transpose(1, 2).matmul(self.weight.view(self.weight.size(0), -1).t()) + self.bias).transpose(1, 2)
|
||||
|
||||
if stride==1 or mask[0,0]==-1:# in case of no mask and stride==1
|
||||
out = out_unf.view(batch_size,out_channels,out_h,out_w) #Fold
|
||||
#if stoch==False: #this is done outside for more clarity
|
||||
# out = F.avg_pool2d(out,self.stride,ceil_mode=self.ceil_mode)
|
||||
else:#in case of mask
|
||||
out = torch.zeros(batch_size, out_channels,out_h,out_w,device=device)
|
||||
out[:,:,mask>0] = out_unf
|
||||
return out
|
||||
|
||||
|
||||
def forward_test(self, input, selh=-torch.ones(1,1), selw=-torch.ones(1,1), mask=-torch.ones(1,1),stoch=True,stride=-1):#ugly but faster
|
||||
device=input.device
|
||||
|
@ -174,64 +232,6 @@ class SConv2dAvg(nn.Module):
|
|||
out[:,:,mask>0] = out_unf
|
||||
return out
|
||||
|
||||
|
||||
def forward_slow(self, input, selh=-torch.ones(1,1), selw=-torch.ones(1,1), mask=-torch.ones(1,1),stoch=True,stride=-1):
|
||||
device=input.device
|
||||
if stride==-1:
|
||||
stride = self.stride #if stride not defined use self.stride
|
||||
if stoch==False:
|
||||
stride=1 #test with real average pooling
|
||||
batch_size, in_channels, in_h, in_w = input.shape
|
||||
out_channels, in_channels, kh, kw = self.weight.shape
|
||||
|
||||
afterconv_h = in_h+2*self.padding-(kh-1) #size after conv
|
||||
afterconv_w = in_w+2*self.padding-(kw-1)
|
||||
if self.ceil_mode: #ceil_mode = true default mode for strided conv
|
||||
out_h = math.ceil(afterconv_h/stride)
|
||||
out_w = math.ceil(afterconv_w/stride)
|
||||
else: #ceil_mode = false default mode for pooling
|
||||
out_h = math.floor(afterconv_h/stride)
|
||||
out_w = math.floor(afterconv_w/stride)
|
||||
unfold = torch.nn.Unfold(kernel_size=(kh, kw), dilation=self.dilation, padding=self.padding, stride=1)
|
||||
inp_unf = unfold(input) #transform into a matrix (batch_size, in_channels*kh*kw,afterconv_h,afterconv_w)
|
||||
if stride!=1: # if stride==1 there is no pooling
|
||||
inp_unf = inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h,afterconv_w)
|
||||
if selh[0,0]==-1: # if not given sampled selection
|
||||
#selection of where to sample for each pooling location
|
||||
selh = torch.randint(stride,(out_h,out_w), device=device)
|
||||
selw = torch.randint(stride,(out_h,out_w), device=device)
|
||||
|
||||
resth = (out_h*stride)-afterconv_h
|
||||
restw = (out_w*stride)-afterconv_w
|
||||
if resth!=0 and self.ceil_mode: #in case of ceil_mode need to select only the good locations for the last regions
|
||||
selh[-1,:]=selh[-1,:]%(stride-resth);selh[:,-1]=selh[:,-1]%(stride-restw)
|
||||
selw[-1,:]=selw[-1,:]%(stride-resth);selw[:,-1]=selw[:,-1]%(stride-restw)
|
||||
#the position should be global by adding range...
|
||||
rng_h = selh + torch.arange(0,out_h*stride,stride,device=device).view(-1,1)
|
||||
rng_w = selw + torch.arange(0,out_w*stride,stride,device=device)
|
||||
|
||||
if mask[0,0]==-1:# in case of not given mask use only sampled selection
|
||||
inp_unf = inp_unf[:,:,rng_h,rng_w].view(batch_size,in_channels*kh*kw,-1)
|
||||
else:#in case of a valid mask use selection only on the mask locations
|
||||
inp_unf = inp_unf[:,:,rng_h[mask>0],rng_w[mask>0]]
|
||||
|
||||
#Matrix mul
|
||||
if self.bias is None:
|
||||
#out_unf = inp_unf.transpose(1, 2).matmul(self.weight.view(self.weight.size(0), -1).t()).transpose(1, 2)
|
||||
out_unf = oe.contract('bji,kj->bki',inp_unf,self.weight.view(self.weight.size(0), -1),backend='torch')
|
||||
else:
|
||||
out_unf = oe.contract('bji,kj->bki',inp_unf,self.weight.view(self.weight.size(0), -1),backend='torch')+self.bias.view(1,-1,1)
|
||||
#out_unf = (inp_unf.transpose(1, 2).matmul(self.weight.view(self.weight.size(0), -1).t()) + self.bias).transpose(1, 2)
|
||||
|
||||
if stride==1 or mask[0,0]==-1:# in case of no mask and stride==1
|
||||
out = out_unf.view(batch_size,out_channels,out_h,out_w) #Fold
|
||||
if stoch==False: #this is done outside for more clarity
|
||||
out = F.avg_pool2d(out,self.stride,ceil_mode=True)
|
||||
else:#in case of mask
|
||||
out = torch.zeros(batch_size, out_channels,out_h,out_w,device=device)
|
||||
out[:,:,mask>0] = out_unf
|
||||
return out
|
||||
|
||||
def forward_slowwithbatch(self, input, selh=-torch.ones(1,1), selw=-torch.ones(1,1), mask=-torch.ones(1,1),stoch=True,stride=-1):
|
||||
device=input.device
|
||||
if stride==-1:
|
||||
|
|
|
@ -23,8 +23,8 @@ if __name__ == "__main__":
|
|||
accs.append(max([x["test_acc"] for x in data]))
|
||||
taccs.append(max([x["train_acc"] for x in data]))
|
||||
# aug_accs.append(data['Aug_Accuracy'][1])
|
||||
# times.append(data['Time'][0])
|
||||
# mem.append(data['Memory'][1])
|
||||
#times.append(data['Time'][0])
|
||||
#mem.append(data['Memory'][1])
|
||||
|
||||
# acc_idx = [x['acc'] for x in data['Log']].index(data['Accuracy'])
|
||||
# f1_max.append(max(data['Log'][acc_idx]['f1'])*100)
|
||||
|
@ -36,5 +36,5 @@ if __name__ == "__main__":
|
|||
print("Acc train : %.2f ~ %.2f"%(np.mean(taccs), np.std(taccs)))
|
||||
# print("Acc : %.2f ~ %.2f / Aug_Acc %d: %.2f ~ %.2f"%(np.mean(accs), np.std(accs), data['Aug_Accuracy'][0], np.mean(aug_accs), np.std(aug_accs)))
|
||||
# print("F1 max : %.2f ~ %.2f / F1 min : %.2f ~ %.2f"%(np.mean(f1_max), np.std(f1_max), np.mean(f1_min), np.std(f1_min)))
|
||||
# print("Time (h): %.2f ~ %.2f"%(np.mean(times)/3600, np.std(times)/3600))
|
||||
# print("Mem (MB): %d ~ %d"%(np.mean(mem), np.std(mem)))
|
||||
#print("Time (h): %.2f ~ %.2f"%(np.mean(times)/3600, np.std(times)/3600))
|
||||
#print("Mem (MB): %d ~ %d"%(np.mean(mem), np.std(mem)))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue