fixed a problem with inference in MyLeNetNoceil

Marco Pedersoli 2020-06-30 11:56:51 -04:00
parent 019310e60c
commit 5504ca67ab
4 changed files with 157 additions and 138 deletions


@@ -51,7 +51,7 @@ checkpoint=False
# Data
print('==> Preparing data..')
dataroot="~/scratch/data" #"./data"
dataroot="./data"#"~/scratch/data" #"./data"
download_data=False
transform_train = [
# transforms.RandomCrop(32, padding=4),


@@ -101,52 +101,6 @@ class MyLeNetMatNormal(nn.Module):#epoch 21s
#out = (self.fc1(out))
return out
class MyLeNetMatNormalNoceil(nn.Module):#epoch 136s 16GB
def __init__(self,k=3):
super(MyLeNetMatNormalNoceil, self).__init__()
self.conv1 = SConv2dAvg(3, 200*k, 3, stride=1,padding=1,ceil_mode=False)
self.conv2 = SConv2dAvg(200*k, 400*k, 3, stride=1,padding=1,ceil_mode=False)
self.conv3 = SConv2dAvg(400*k, 800*k, 3, stride=1,padding=1,ceil_mode=False)
self.conv4 = SConv2dAvg(800*k, 1600*k, 3, stride=1,padding=1,ceil_mode=False)
self.fc1 = nn.Linear(1600*k, 10)
def forward(self, x, stoch=True):
out = F.relu(self.conv1(x,stoch=stoch))
out = F.avg_pool2d(out,2,ceil_mode=True)
out = F.relu(self.conv2(out,stoch=stoch))
out = F.avg_pool2d(out,2,ceil_mode=True)
out = F.relu(self.conv3(out,stoch=stoch))
out = F.avg_pool2d(out,2,ceil_mode=True)
out = F.relu(self.conv4(out,stoch=stoch))
out = F.avg_pool2d(out,4,ceil_mode=True)
out = out.view(out.size(0), -1 )
out = self.fc1(out)
return out
class MyLeNetMatStochNoceil(nn.Module):#epoch 41s 16GB
def __init__(self,k=3):
super(MyLeNetMatStochNoceil, self).__init__()
self.conv1 = SConv2dAvg(3, 200*k, 3, stride=2,padding=1,ceil_mode=False)
self.conv2 = SConv2dAvg(200*k, 400*k, 3, stride=2,padding=1,ceil_mode=False)
self.conv3 = SConv2dAvg(400*k, 800*k, 3, stride=2,padding=1,ceil_mode=False)
self.conv4 = SConv2dAvg(800*k, 1600*k, 3, stride=4,padding=1,ceil_mode=False)
self.fc1 = nn.Linear(1600*k, 10)
def forward(self, x, stoch=True):
#print('in',x.shape)
out = F.relu(self.conv1(x,stoch=stoch))
#print('c1',out.shape)
out = F.relu(self.conv2(out,stoch=stoch))
#print('c2', out.shape)
out = F.relu(self.conv3(out,stoch=stoch))
#print('c3',out.shape)
out = F.relu(self.conv4(out,stoch=stoch))
#print('c4',out.shape)
out = out.view(out.size(0), -1 )
out = self.fc1(out)
return out
class MyLeNetMatStoch(nn.Module):#epoch 17s
def __init__(self):
super(MyLeNetMatStoch, self).__init__()
@@ -170,6 +124,59 @@ class MyLeNetMatStoch(nn.Module):#epoch 17s
#out = self.fc1(out)
return out
class MyLeNetMatNormalNoceil(nn.Module):#epoch 136s 16GB
def __init__(self,k=3):
super(MyLeNetMatNormalNoceil, self).__init__()
self.conv1 = SConv2dAvg(3, 200*k, 3, stride=1,padding=1,ceil_mode=False)
self.conv2 = SConv2dAvg(200*k, 400*k, 3, stride=1,padding=1,ceil_mode=False)
self.conv3 = SConv2dAvg(400*k, 800*k, 3, stride=1,padding=1,ceil_mode=False)
self.conv4 = SConv2dAvg(800*k, 1600*k, 3, stride=1,padding=1,ceil_mode=False)
self.fc1 = nn.Linear(1600*k, 10)
def forward(self, x, stoch=True):
out = F.relu(self.conv1(x,stoch=stoch))
out = F.avg_pool2d(out,2,ceil_mode=False)
out = F.relu(self.conv2(out,stoch=stoch))
out = F.avg_pool2d(out,2,ceil_mode=False)
out = F.relu(self.conv3(out,stoch=stoch))
out = F.avg_pool2d(out,2,ceil_mode=False)
out = F.relu(self.conv4(out,stoch=stoch))
out = F.avg_pool2d(out,4,ceil_mode=False)
out = out.view(out.size(0), -1 )
out = self.fc1(out)
return out
class MyLeNetMatStochNoceil(nn.Module):#epoch 41s 16GB
def __init__(self,k=3):
super(MyLeNetMatStochNoceil, self).__init__()
self.conv1 = SConv2dAvg(3, 200*k, 3, stride=2,padding=1,ceil_mode=False)
self.conv2 = SConv2dAvg(200*k, 400*k, 3, stride=2,padding=1,ceil_mode=False)
self.conv3 = SConv2dAvg(400*k, 800*k, 3, stride=2,padding=1,ceil_mode=False)
self.conv4 = SConv2dAvg(800*k, 1600*k, 3, stride=4,padding=1,ceil_mode=False)
self.fc1 = nn.Linear(1600*k, 10)
def forward(self, x, stoch=True):
if stoch:
out = F.relu(self.conv1(x,stoch=stoch))
out = F.relu(self.conv2(out,stoch=stoch))
out = F.relu(self.conv3(out,stoch=stoch))
out = F.relu(self.conv4(out,stoch=stoch))
out = out.view(out.size(0), -1 )
out = self.fc1(out)
else:
out = F.relu(self.conv1(x,stride=1))
out = F.avg_pool2d(out,2,ceil_mode=False)
out = F.relu(self.conv2(out,stride=1))
out = F.avg_pool2d(out,2,ceil_mode=False)
out = F.relu(self.conv3(out,stride=1))
out = F.avg_pool2d(out,2,ceil_mode=False)
out = F.relu(self.conv4(out,stride=1))
out = F.avg_pool2d(out,4,ceil_mode=False)
out = out.view(out.size(0), -1 )
out = self.fc1(out)
return out
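# A minimal self-contained sketch (plain nn.Conv2d as a stand-in for
# SConv2dAvg; names and sizes are illustrative, not the repo's code) of the
# two paths above: at inference (stoch=False) the strided stochastic conv is
# replaced by a stride-1 conv followed by average pooling with
# ceil_mode=False, while training keeps one random location per pooling region.
import torch
import torch.nn.functional as F
conv = torch.nn.Conv2d(3, 8, 3, stride=1, padding=1)
x = torch.randn(1, 3, 32, 32)
full = conv(x)                                 # (1, 8, 32, 32)
det = F.avg_pool2d(full, 2, ceil_mode=False)   # deterministic path: (1, 8, 16, 16)
sel_h = torch.randint(2, (16, 16))             # random offset per 2x2 region
sel_w = torch.randint(2, (16, 16))
rows = sel_h + torch.arange(0, 32, 2).view(-1, 1)
cols = sel_w + torch.arange(0, 32, 2)
sto = full[:, :, rows, cols]                   # stochastic path: (1, 8, 16, 16)
print(det.shape, sto.shape)                    # same shapes; E[sto] equals det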
class MyLeNetMatStochBUNoceil(nn.Module):#30.5s 14GB
def __init__(self,k=3):
super(MyLeNetMatStochBUNoceil, self).__init__()
@@ -180,34 +187,46 @@ class MyLeNetMatStochBUNoceil(nn.Module):#30.5s 14GB
self.fc1 = nn.Linear(1600*k, 10)
def forward(self, x, stoch=True):
#get sizes
batch_size = x.shape[0]
device = x.device
h0,w0 = x.shape[2],x.shape[3]
_,_,h1,w1 = self.conv1.get_size(h0,w0)
_,_,h2,w2 = self.conv2.get_size(h1,w1)
_,_,h3,w3 = self.conv3.get_size(h2,w2)
_,_,h4,w4 = self.conv4.get_size(h3,w3)
# print(h0,w0)
# print(h1,w1)
# print(h2,w2)
# print(h3,w3)
if stoch:
#get sizes
batch_size = x.shape[0]
device = x.device
h0,w0 = x.shape[2],x.shape[3]
_,_,h1,w1 = self.conv1.get_size(h0,w0)
_,_,h2,w2 = self.conv2.get_size(h1,w1)
_,_,h3,w3 = self.conv3.get_size(h2,w2)
_,_,h4,w4 = self.conv4.get_size(h3,w3)
# print(h0,w0)
# print(h1,w1)
# print(h2,w2)
# print(h3,w3)
#sample BU
mask4 = torch.ones(h4,w4).to(x.device)
# print(mask3.shape)
index4,mask3 = self.conv4.sample(h3,w3,batch_size,device,mask4)
index3,mask2 = self.conv3.sample(h2,w2,batch_size,device,mask3)
index2,mask1 = self.conv2.sample(h1,w1,batch_size,device,mask2)
index1,mask0 = self.conv1.sample(h0,w0,batch_size,device,mask1)
#sample BU
mask4 = torch.ones(h4,w4).to(x.device)
# print(mask3.shape)
index4,mask3 = self.conv4.sample(h3,w3,batch_size,device,mask4)
index3,mask2 = self.conv3.sample(h2,w2,batch_size,device,mask3)
index2,mask1 = self.conv2.sample(h1,w1,batch_size,device,mask2)
index1,mask0 = self.conv1.sample(h0,w0,batch_size,device,mask1)
##forward
out = F.relu(self.conv1(x,index1,mask1,stoch=stoch))
out = F.relu(self.conv2(out,index2,mask2,stoch=stoch))
out = F.relu(self.conv3(out,index3,mask3,stoch=stoch))
out = F.relu(self.conv4(out,index4,mask4,stoch=stoch))
out = out.view(out.size(0), -1 )
out = self.fc1(out)
##forward
out = F.relu(self.conv1(x,index1,mask1,stoch=stoch))
out = F.relu(self.conv2(out,index2,mask2,stoch=stoch))
out = F.relu(self.conv3(out,index3,mask3,stoch=stoch))
out = F.relu(self.conv4(out,index4,mask4,stoch=stoch))
out = out.view(out.size(0), -1 )
out = self.fc1(out)
else:
out = F.relu(self.conv1(x,stride=1))
out = F.avg_pool2d(out,2,ceil_mode=False)
out = F.relu(self.conv2(out,stride=1))
out = F.avg_pool2d(out,2,ceil_mode=False)
out = F.relu(self.conv3(out,stride=1))
out = F.avg_pool2d(out,2,ceil_mode=False)
out = F.relu(self.conv4(out,stride=1))
out = F.avg_pool2d(out,4,ceil_mode=False)
out = out.view(out.size(0), -1 )
out = self.fc1(out)
return out
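# Toy sketch of the bottom-up (BU) sampling used above, under assumed
# semantics of SConv2dAvg.sample and ignoring the conv kernel's receptive
# field: starting from the locations needed at the top, each stride-2 stage
# marks which of its input locations must be computed, so earlier layers
# skip the rest.
import torch
def sample_back(mask_out, stride=2):
    out_h, out_w = mask_out.shape
    in_h, in_w = out_h * stride, out_w * stride
    sel_h = torch.randint(stride, (out_h, out_w))  # one random pick per region
    sel_w = torch.randint(stride, (out_h, out_w))
    rows = sel_h + torch.arange(0, in_h, stride).view(-1, 1)
    cols = sel_w + torch.arange(0, in_w, stride)
    mask_in = torch.zeros(in_h, in_w)
    mask_in[rows[mask_out > 0], cols[mask_out > 0]] = 1
    return mask_in
mask4 = torch.ones(1, 1)    # every top-level output location is needed
mask3 = sample_back(mask4)  # 2x2 input mask with a single marked location
mask2 = sample_back(mask3)  # 4x4, still a single location to compute
print(int(mask3.sum()), int(mask2.sum()))  # 1 1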
class MyLeNetMatStochBU(nn.Module):#epoch 11s


@@ -91,8 +91,9 @@ class SConv2dAvg(nn.Module):
if stride==1 or mask[0,0]==-1:# if stride==1 or no mask is given
out = out_unf.view(batch_size,out_channels,out_h,out_w) #Fold
if stoch==False: #this is done outside for more clarity
out = F.avg_pool2d(out,self.stride,ceil_mode=True)
#if stoch==False: #this is done outside for more clarity
# out = F.avg_pool2d(out,self.stride,ceil_mode=False)
#print(self.stride)
else:#in case of mask
out = torch.zeros(batch_size, out_channels,out_h,out_w,device=device)
#out = torch.gather(out.view(batch_size,in_channels*kh*kw,afterconv_h*afterconv_w),2,index.view(batch_size,in_channels*kh*kw,index.shape[2])).view(batch_size,in_channels*kh*kw,index.shape[2])
@@ -100,6 +101,63 @@
#out.masked_scatter_(mask>0, out_unf)
return out
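# Self-contained sketch of the masked branch above (sizes are illustrative):
# only the columns for masked locations are computed as out_unf, then
# scattered into a zero output tensor at the positions where mask>0.
import torch
batch_size, out_channels, out_h, out_w = 2, 4, 3, 3
mask = torch.zeros(out_h, out_w)
mask[0, 0] = 1; mask[2, 1] = 1
n_kept = int((mask > 0).sum())                 # 2 of 9 locations computed
out_unf = torch.randn(batch_size, out_channels, n_kept)
out = torch.zeros(batch_size, out_channels, out_h, out_w)
out[:, :, mask > 0] = out_unf                  # fills only the masked spots
print(out[0, 0])                               # zeros except at (0,0) and (2,1)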
def forward_slow(self, input, selh=-torch.ones(1,1), selw=-torch.ones(1,1), mask=-torch.ones(1,1),stoch=True,stride=-1):
device=input.device
if stride==-1:
stride = self.stride #if stride not defined use self.stride
if stoch==False:
stride=1 #test with real average pooling
batch_size, in_channels, in_h, in_w = input.shape
out_channels, in_channels, kh, kw = self.weight.shape
afterconv_h = in_h+2*self.padding-(kh-1) #size after conv
afterconv_w = in_w+2*self.padding-(kw-1)
if self.ceil_mode: #ceil_mode = true: default mode for strided conv
out_h = math.ceil(afterconv_h/stride)
out_w = math.ceil(afterconv_w/stride)
else: #ceil_mode = false default mode for pooling
out_h = math.floor(afterconv_h/stride)
out_w = math.floor(afterconv_w/stride)
unfold = torch.nn.Unfold(kernel_size=(kh, kw), dilation=self.dilation, padding=self.padding, stride=1)
inp_unf = unfold(input) #unfold into a matrix (batch_size, in_channels*kh*kw, afterconv_h*afterconv_w)
if stride!=1: # if stride==1 there is no pooling
inp_unf = inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h,afterconv_w)
if selh[0,0]==-1: # if no selection is given, sample one
#selection of where to sample for each pooling location
selh = torch.randint(stride,(out_h,out_w), device=device)
selw = torch.randint(stride,(out_h,out_w), device=device)
resth = (out_h*stride)-afterconv_h
restw = (out_w*stride)-afterconv_w
if resth!=0 and self.ceil_mode: #with ceil_mode the last regions are truncated, so restrict the selection to valid locations
selh[-1,:]=selh[-1,:]%(stride-resth);selh[:,-1]=selh[:,-1]%(stride-restw)
selw[-1,:]=selw[-1,:]%(stride-resth);selw[:,-1]=selw[:,-1]%(stride-restw)
#the position is made global by adding the range offsets
rng_h = selh + torch.arange(0,out_h*stride,stride,device=device).view(-1,1)
rng_w = selw + torch.arange(0,out_w*stride,stride,device=device)
if mask[0,0]==-1:# if no mask is given, use only the sampled selection
inp_unf = inp_unf[:,:,rng_h,rng_w].view(batch_size,in_channels*kh*kw,-1)
else:#in case of a valid mask use selection only on the mask locations
inp_unf = inp_unf[:,:,rng_h[mask>0],rng_w[mask>0]]
#Matrix mul
if self.bias is None:
out_unf = inp_unf.transpose(1, 2).matmul(self.weight.view(self.weight.size(0), -1).t()).transpose(1, 2)
#out_unf = oe.contract('bji,kj->bki',inp_unf,self.weight.view(self.weight.size(0), -1),backend='torch')
else:
#out_unf = oe.contract('bji,kj->bki',inp_unf,self.weight.view(self.weight.size(0), -1),backend='torch')+self.bias.view(1,-1,1)
out_unf = (inp_unf.transpose(1, 2).matmul(self.weight.view(self.weight.size(0), -1).t()) + self.bias).transpose(1, 2)
if stride==1 or mask[0,0]==-1:# if stride==1 or no mask is given
out = out_unf.view(batch_size,out_channels,out_h,out_w) #Fold
#if stoch==False: #this is done outside for more clarity
# out = F.avg_pool2d(out,self.stride,ceil_mode=self.ceil_mode)
else:#in case of mask
out = torch.zeros(batch_size, out_channels,out_h,out_w,device=device)
out[:,:,mask>0] = out_unf
return out
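# Quick worked check of the out_h/out_w computation above (values assumed):
# with kh=3 and padding=1 the conv keeps the spatial size, and ceil vs floor
# only differ when that size is not divisible by the stride; this is why the
# Noceil models must use the same rounding at inference as in training.
import math
in_h, kh, padding, stride = 15, 3, 1, 2
afterconv_h = in_h + 2 * padding - (kh - 1)    # 15
print(math.ceil(afterconv_h / stride))         # 8: ceil_mode=True (strided-conv convention)
print(math.floor(afterconv_h / stride))        # 7: ceil_mode=False (F.avg_pool2d default)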
def forward_test(self, input, selh=-torch.ones(1,1), selw=-torch.ones(1,1), mask=-torch.ones(1,1),stoch=True,stride=-1):#ugly but faster
device=input.device
@@ -174,64 +232,6 @@ class SConv2dAvg(nn.Module):
out[:,:,mask>0] = out_unf
return out
def forward_slow(self, input, selh=-torch.ones(1,1), selw=-torch.ones(1,1), mask=-torch.ones(1,1),stoch=True,stride=-1):
device=input.device
if stride==-1:
stride = self.stride #if stride not defined use self.stride
if stoch==False:
stride=1 #test with real average pooling
batch_size, in_channels, in_h, in_w = input.shape
out_channels, in_channels, kh, kw = self.weight.shape
afterconv_h = in_h+2*self.padding-(kh-1) #size after conv
afterconv_w = in_w+2*self.padding-(kw-1)
if self.ceil_mode: #ceil_mode = true: default mode for strided conv
out_h = math.ceil(afterconv_h/stride)
out_w = math.ceil(afterconv_w/stride)
else: #ceil_mode = false default mode for pooling
out_h = math.floor(afterconv_h/stride)
out_w = math.floor(afterconv_w/stride)
unfold = torch.nn.Unfold(kernel_size=(kh, kw), dilation=self.dilation, padding=self.padding, stride=1)
inp_unf = unfold(input) #unfold into a matrix (batch_size, in_channels*kh*kw, afterconv_h*afterconv_w)
if stride!=1: # if stride==1 there is no pooling
inp_unf = inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h,afterconv_w)
if selh[0,0]==-1: # if no selection is given, sample one
#selection of where to sample for each pooling location
selh = torch.randint(stride,(out_h,out_w), device=device)
selw = torch.randint(stride,(out_h,out_w), device=device)
resth = (out_h*stride)-afterconv_h
restw = (out_w*stride)-afterconv_w
if resth!=0 and self.ceil_mode: #with ceil_mode the last regions are truncated, so restrict the selection to valid locations
selh[-1,:]=selh[-1,:]%(stride-resth);selh[:,-1]=selh[:,-1]%(stride-restw)
selw[-1,:]=selw[-1,:]%(stride-resth);selw[:,-1]=selw[:,-1]%(stride-restw)
#the position is made global by adding the range offsets
rng_h = selh + torch.arange(0,out_h*stride,stride,device=device).view(-1,1)
rng_w = selw + torch.arange(0,out_w*stride,stride,device=device)
if mask[0,0]==-1:# if no mask is given, use only the sampled selection
inp_unf = inp_unf[:,:,rng_h,rng_w].view(batch_size,in_channels*kh*kw,-1)
else:#in case of a valid mask use selection only on the mask locations
inp_unf = inp_unf[:,:,rng_h[mask>0],rng_w[mask>0]]
#Matrix mul
if self.bias is None:
#out_unf = inp_unf.transpose(1, 2).matmul(self.weight.view(self.weight.size(0), -1).t()).transpose(1, 2)
out_unf = oe.contract('bji,kj->bki',inp_unf,self.weight.view(self.weight.size(0), -1),backend='torch')
else:
out_unf = oe.contract('bji,kj->bki',inp_unf,self.weight.view(self.weight.size(0), -1),backend='torch')+self.bias.view(1,-1,1)
#out_unf = (inp_unf.transpose(1, 2).matmul(self.weight.view(self.weight.size(0), -1).t()) + self.bias).transpose(1, 2)
if stride==1 or mask[0,0]==-1:# if stride==1 or no mask is given
out = out_unf.view(batch_size,out_channels,out_h,out_w) #Fold
if stoch==False: #this is done outside for more clarity
out = F.avg_pool2d(out,self.stride,ceil_mode=True)
else:#in case of mask
out = torch.zeros(batch_size, out_channels,out_h,out_w,device=device)
out[:,:,mask>0] = out_unf
return out
def forward_slowwithbatch(self, input, selh=-torch.ones(1,1), selw=-torch.ones(1,1), mask=-torch.ones(1,1),stoch=True,stride=-1):
device=input.device
if stride==-1:


@@ -23,8 +23,8 @@ if __name__ == "__main__":
accs.append(max([x["test_acc"] for x in data]))
taccs.append(max([x["train_acc"] for x in data]))
# aug_accs.append(data['Aug_Accuracy'][1])
# times.append(data['Time'][0])
# mem.append(data['Memory'][1])
#times.append(data['Time'][0])
#mem.append(data['Memory'][1])
# acc_idx = [x['acc'] for x in data['Log']].index(data['Accuracy'])
# f1_max.append(max(data['Log'][acc_idx]['f1'])*100)
@@ -36,5 +36,5 @@ if __name__ == "__main__":
print("Acc train : %.2f ~ %.2f"%(np.mean(taccs), np.std(taccs)))
# print("Acc : %.2f ~ %.2f / Aug_Acc %d: %.2f ~ %.2f"%(np.mean(accs), np.std(accs), data['Aug_Accuracy'][0], np.mean(aug_accs), np.std(aug_accs)))
# print("F1 max : %.2f ~ %.2f / F1 min : %.2f ~ %.2f"%(np.mean(f1_max), np.std(f1_max), np.mean(f1_min), np.std(f1_min)))
# print("Time (h): %.2f ~ %.2f"%(np.mean(times)/3600, np.std(times)/3600))
# print("Mem (MB): %d ~ %d"%(np.mean(mem), np.std(mem)))
#print("Time (h): %.2f ~ %.2f"%(np.mean(times)/3600, np.std(times)/3600))
#print("Mem (MB): %d ~ %d"%(np.mean(mem), np.std(mem)))