Mirror of https://github.com/AntoineHX/BU_Stoch_pool.git (synced 2025-05-03 17:20:45 +02:00)

Commit 5504ca67ab (parent 019310e60c): "fixed problem with inference with MyLeNetNoceil"

4 changed files with 157 additions and 138 deletions
main.py (2 changes)

@@ -51,7 +51,7 @@ checkpoint=False
 # Data
 print('==> Preparing data..')
-dataroot="~/scratch/data" #"./data"
+dataroot="./data"#"~/scratch/data" #"./data"
 download_data=False
 transform_train = [
     # transforms.RandomCrop(32, padding=4),
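For context, a minimal sketch of how a `dataroot`/`download_data`/`transform_train` trio like this is typically consumed with torchvision. The CIFAR10 dataset class and the `Compose` wrapper are assumptions for illustration; the actual wiring in main.py may differ:

import torchvision
import torchvision.transforms as transforms

# Assumed usage of the settings above (illustrative, not repo code).
dataroot = "./data"
download_data = False
transform_train = transforms.Compose([
    # transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
])
trainset = torchvision.datasets.CIFAR10(root=dataroot, train=True,
                                        download=download_data,
                                        transform=transform_train)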
@@ -101,52 +101,6 @@ class MyLeNetMatNormal(nn.Module):#epach 21s
         #out = (self.fc1(out))
         return out
 
-class MyLeNetMatNormalNoceil(nn.Module):#epoch 136s 16GB
-    def __init__(self,k=3):
-        super(MyLeNetMatNormalNoceil, self).__init__()
-        self.conv1 = SConv2dAvg(3, 200*k, 3, stride=1,padding=1,ceil_mode=False)
-        self.conv2 = SConv2dAvg(200*k, 400*k, 3, stride=1,padding=1,ceil_mode=False)
-        self.conv3 = SConv2dAvg(400*k, 800*k, 3, stride=1,padding=1,ceil_mode=False)
-        self.conv4 = SConv2dAvg(800*k, 1600*k, 3, stride=1,padding=1,ceil_mode=False)
-        self.fc1 = nn.Linear(1600*k, 10)
-
-    def forward(self, x, stoch=True):
-        out = F.relu(self.conv1(x,stoch=stoch))
-        out = F.avg_pool2d(out,2,ceil_mode=True)
-        out = F.relu(self.conv2(out,stoch=stoch))
-        out = F.avg_pool2d(out,2,ceil_mode=True)
-        out = F.relu(self.conv3(out,stoch=stoch))
-        out = F.avg_pool2d(out,2,ceil_mode=True)
-        out = F.relu(self.conv4(out,stoch=stoch))
-        out = F.avg_pool2d(out,4,ceil_mode=True)
-        out = out.view(out.size(0), -1 )
-        out = self.fc1(out)
-        return out
-
-class MyLeNetMatStochNoceil(nn.Module):#epoch 41s 16BG
-    def __init__(self,k=3):
-        super(MyLeNetMatStochNoceil, self).__init__()
-        self.conv1 = SConv2dAvg(3, 200*k, 3, stride=2,padding=1,ceil_mode=False)
-        self.conv2 = SConv2dAvg(200*k, 400*k, 3, stride=2,padding=1,ceil_mode=False)
-        self.conv3 = SConv2dAvg(400*k, 800*k, 3, stride=2,padding=1,ceil_mode=False)
-        self.conv4 = SConv2dAvg(800*k, 1600*k, 3, stride=4,padding=1,ceil_mode=False)
-        self.fc1 = nn.Linear(1600*k, 10)
-
-    def forward(self, x, stoch=True):
-        #print('in',x.shape)
-        out = F.relu(self.conv1(x,stoch=stoch))
-        #print('c1',out.shape)
-        out = F.relu(self.conv2(out,stoch=stoch))
-        #print('c2', out.shape)
-        out = F.relu(self.conv3(out,stoch=stoch))
-        #print('c3',out.shape)
-        out = F.relu(self.conv4(out,stoch=stoch))
-        #print('c4',out.shape)
-        out = out.view(out.size(0), -1 )
-        out = self.fc1(out)
-        return out
-
 class MyLeNetMatStoch(nn.Module):#epoch 17s
     def __init__(self):
         super(MyLeNetMatStoch, self).__init__()
@@ -169,6 +123,59 @@ class MyLeNetMatStoch(nn.Module):#epoch 17s
         out = out.view(out.size(0), -1 )
         #out = self.fc1(out)
         return out
 
+class MyLeNetMatNormalNoceil(nn.Module):#epoch 136s 16GB
+    def __init__(self,k=3):
+        super(MyLeNetMatNormalNoceil, self).__init__()
+        self.conv1 = SConv2dAvg(3, 200*k, 3, stride=1,padding=1,ceil_mode=False)
+        self.conv2 = SConv2dAvg(200*k, 400*k, 3, stride=1,padding=1,ceil_mode=False)
+        self.conv3 = SConv2dAvg(400*k, 800*k, 3, stride=1,padding=1,ceil_mode=False)
+        self.conv4 = SConv2dAvg(800*k, 1600*k, 3, stride=1,padding=1,ceil_mode=False)
+        self.fc1 = nn.Linear(1600*k, 10)
+
+    def forward(self, x, stoch=True):
+        out = F.relu(self.conv1(x,stoch=stoch))
+        out = F.avg_pool2d(out,2,ceil_mode=False)
+        out = F.relu(self.conv2(out,stoch=stoch))
+        out = F.avg_pool2d(out,2,ceil_mode=False)
+        out = F.relu(self.conv3(out,stoch=stoch))
+        out = F.avg_pool2d(out,2,ceil_mode=False)
+        out = F.relu(self.conv4(out,stoch=stoch))
+        out = F.avg_pool2d(out,4,ceil_mode=False)
+        out = out.view(out.size(0), -1 )
+        out = self.fc1(out)
+        return out
+
+class MyLeNetMatStochNoceil(nn.Module):#epoch 41s 16BG
+    def __init__(self,k=3):
+        super(MyLeNetMatStochNoceil, self).__init__()
+        self.conv1 = SConv2dAvg(3, 200*k, 3, stride=2,padding=1,ceil_mode=False)
+        self.conv2 = SConv2dAvg(200*k, 400*k, 3, stride=2,padding=1,ceil_mode=False)
+        self.conv3 = SConv2dAvg(400*k, 800*k, 3, stride=2,padding=1,ceil_mode=False)
+        self.conv4 = SConv2dAvg(800*k, 1600*k, 3, stride=4,padding=1,ceil_mode=False)
+        self.fc1 = nn.Linear(1600*k, 10)
+
+    def forward(self, x, stoch=True):
+        if stoch:
+            out = F.relu(self.conv1(x,stoch=stoch))
+            out = F.relu(self.conv2(out,stoch=stoch))
+            out = F.relu(self.conv3(out,stoch=stoch))
+            out = F.relu(self.conv4(out,stoch=stoch))
+            out = out.view(out.size(0), -1 )
+            out = self.fc1(out)
+        else:
+            out = F.relu(self.conv1(x,stride=1))
+            out = F.avg_pool2d(out,2,ceil_mode=False)
+            out = F.relu(self.conv2(out,stride=1))
+            out = F.avg_pool2d(out,2,ceil_mode=False)
+            out = F.relu(self.conv3(out,stride=1))
+            out = F.avg_pool2d(out,2,ceil_mode=False)
+            out = F.relu(self.conv4(out,stride=1))
+            out = F.avg_pool2d(out,4,ceil_mode=False)
+            out = out.view(out.size(0), -1 )
+            out = self.fc1(out)
+
+        return out
+
 class MyLeNetMatStochBUNoceil(nn.Module):#30.5s 14GB
     def __init__(self,k=3):
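The re-added `MyLeNetMatStochNoceil.forward` is the substance of the fix: with `stoch=True` the stride-2 `SConv2dAvg` layers subsample stochastically, while with `stoch=False` each conv runs at `stride=1` followed by a deterministic `F.avg_pool2d(..., ceil_mode=False)`, so both paths feed `fc1` the same feature size. A small sketch of why the shapes line up (plain `nn.Conv2d` stands in for `SConv2dAvg`; the random-window indexing mirrors the sampling in models/stoch.py):

import torch
import torch.nn.functional as F

# Illustrative sketch, not repo code: with ceil_mode=False, the deterministic
# eval path (stride-1 conv + avg_pool2d) and the stochastic path (one random
# location per 2x2 window) produce the same spatial size.
x = torch.randn(8, 3, 32, 32)
conv = torch.nn.Conv2d(3, 16, 3, stride=1, padding=1)

eval_path = F.avg_pool2d(conv(x), 2, ceil_mode=False)
print(eval_path.shape)                    # torch.Size([8, 16, 16, 16])

selh = torch.randint(2, (16, 16))         # per-window offsets, as in stoch.py
selw = torch.randint(2, (16, 16))
rng_h = selh + torch.arange(0, 32, 2).view(-1, 1)
rng_w = selw + torch.arange(0, 32, 2)
stoch_path = conv(x)[:, :, rng_h, rng_w]
print(stoch_path.shape)                   # torch.Size([8, 16, 16, 16])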
@@ -180,34 +187,46 @@ class MyLeNetMatStochBUNoceil(nn.Module):#30.5s 14GB
         self.fc1 = nn.Linear(1600*k, 10)
 
     def forward(self, x, stoch=True):
-        #get sizes
-        batch_size = x.shape[0]
-        device = x.device
-        h0,w0 = x.shape[2],x.shape[3]
-        _,_,h1,w1 = self.conv1.get_size(h0,w0)
-        _,_,h2,w2 = self.conv2.get_size(h1,w1)
-        _,_,h3,w3 = self.conv3.get_size(h2,w2)
-        _,_,h4,w4 = self.conv4.get_size(h3,w3)
-        # print(h0,w0)
-        # print(h1,w1)
-        # print(h2,w2)
-        # print(h3,w3)
-
-        #sample BU
-        mask4 = torch.ones(h4,w4).to(x.device)
-        # print(mask3.shape)
-        index4,mask3 = self.conv4.sample(h3,w3,batch_size,device,mask4)
-        index3,mask2 = self.conv3.sample(h2,w2,batch_size,device,mask3)
-        index2,mask1 = self.conv2.sample(h1,w1,batch_size,device,mask2)
-        index1,mask0 = self.conv1.sample(h0,w0,batch_size,device,mask1)
-
-        ##forward
-        out = F.relu(self.conv1(x,index1,mask1,stoch=stoch))
-        out = F.relu(self.conv2(out,index2,mask2,stoch=stoch))
-        out = F.relu(self.conv3(out,index3,mask3,stoch=stoch))
-        out = F.relu(self.conv4(out,index4,mask4,stoch=stoch))
-        out = out.view(out.size(0), -1 )
-        out = self.fc1(out)
-
+        if stoch:
+            #get sizes
+            batch_size = x.shape[0]
+            device = x.device
+            h0,w0 = x.shape[2],x.shape[3]
+            _,_,h1,w1 = self.conv1.get_size(h0,w0)
+            _,_,h2,w2 = self.conv2.get_size(h1,w1)
+            _,_,h3,w3 = self.conv3.get_size(h2,w2)
+            _,_,h4,w4 = self.conv4.get_size(h3,w3)
+            # print(h0,w0)
+            # print(h1,w1)
+            # print(h2,w2)
+            # print(h3,w3)
+
+            #sample BU
+            mask4 = torch.ones(h4,w4).to(x.device)
+            # print(mask3.shape)
+            index4,mask3 = self.conv4.sample(h3,w3,batch_size,device,mask4)
+            index3,mask2 = self.conv3.sample(h2,w2,batch_size,device,mask3)
+            index2,mask1 = self.conv2.sample(h1,w1,batch_size,device,mask2)
+            index1,mask0 = self.conv1.sample(h0,w0,batch_size,device,mask1)
+
+            ##forward
+            out = F.relu(self.conv1(x,index1,mask1,stoch=stoch))
+            out = F.relu(self.conv2(out,index2,mask2,stoch=stoch))
+            out = F.relu(self.conv3(out,index3,mask3,stoch=stoch))
+            out = F.relu(self.conv4(out,index4,mask4,stoch=stoch))
+            out = out.view(out.size(0), -1 )
+            out = self.fc1(out)
+        else:
+            out = F.relu(self.conv1(x,stride=1))
+            out = F.avg_pool2d(out,2,ceil_mode=False)
+            out = F.relu(self.conv2(out,stride=1))
+            out = F.avg_pool2d(out,2,ceil_mode=False)
+            out = F.relu(self.conv3(out,stride=1))
+            out = F.avg_pool2d(out,2,ceil_mode=False)
+            out = F.relu(self.conv4(out,stride=1))
+            out = F.avg_pool2d(out,4,ceil_mode=False)
+            out = out.view(out.size(0), -1 )
+            out = self.fc1(out)
+
         return out
 
 class MyLeNetMatStochBU(nn.Module):#epoch 11s
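In the `MyLeNetMatStochBUNoceil` branch above, the bottom-up ("BU") sampling runs before the forward pass: an all-ones `mask4` at the final resolution is handed to `conv4.sample`, which returns the sampled indices plus `mask3`, the set of conv3 outputs actually needed, and so on down to `mask0` at the input. A toy stand-in for `SConv2dAvg.sample` (the real method lives in models/stoch.py; this sketch only illustrates how sparsity propagates through the mask chain for stride-2 layers):

import torch

def toy_sample(h, w, stride, mask):
    # Pick one position per stride x stride window of the (h, w) input, but
    # only for output locations where mask > 0; return the input-resolution
    # mask of positions the layer will actually touch.
    out_h, out_w = h // stride, w // stride
    selh = torch.randint(stride, (out_h, out_w))
    selw = torch.randint(stride, (out_h, out_w))
    rng_h = selh + torch.arange(0, out_h * stride, stride).view(-1, 1)
    rng_w = selw + torch.arange(0, out_w * stride, stride)
    prev_mask = torch.zeros(h, w)
    prev_mask[rng_h[mask > 0], rng_w[mask > 0]] = 1   # only needed inputs
    return (rng_h, rng_w), prev_mask

mask2 = torch.ones(8, 8)                       # final layer: keep everything
idx2, mask1 = toy_sample(16, 16, 2, mask2)     # what conv2 needs from conv1
idx1, mask0 = toy_sample(32, 32, 2, mask1)     # what conv1 needs from the input
print(mask1.sum().item(), mask0.sum().item())  # 64.0 64.0: sparsity propagates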
models/stoch.py (120 changes)

@@ -91,8 +91,9 @@ class SConv2dAvg(nn.Module):
 
         if stride==1 or mask[0,0]==-1:# in case of no mask and stride==1
             out = out_unf.view(batch_size,out_channels,out_h,out_w) #Fold
-            if stoch==False: #this is done outside for more clarity
-                out = F.avg_pool2d(out,self.stride,ceil_mode=True)
+            #if stoch==False: #this is done outside for more clarity
+            # out = F.avg_pool2d(out,self.stride,ceil_mode=False)
+            #print(self.stride)
         else:#in case of mask
             out = torch.zeros(batch_size, out_channels,out_h,out_w,device=device)
             #out = torch.gather(out.view(batch_size,in_channels*kh*kw,afterconv_h*afterconv_w),2,index.view(batch_size,in_channels*kh*kw,index.shape[2])).view(batch_size,in_channels*kh*kw,index.shape[2])
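This hunk removes the implicit average pooling that the stride==1 path applied when `stoch=False`. Note the deleted line used `ceil_mode=True`, while the Noceil models are built with `ceil_mode=False`, so at inference the layer could produce a different spatial size than the model expected, plausibly the inference problem named in the commit title; after this commit the models apply their own pooling explicitly (see main.py above). The stochastic/deterministic swap itself is sound because sampling one element per pooling window matches average pooling in expectation; a quick illustrative check (not repo code):

import torch
import torch.nn.functional as F

# Monte-Carlo check: one random element per 2x2 window averages out to
# F.avg_pool2d in expectation.
x = torch.randn(1, 1, 4, 4)
avg = F.avg_pool2d(x, 2)
est = torch.zeros_like(avg)
n = 20000
for _ in range(n):
    selh = torch.randint(2, (2, 2))
    selw = torch.randint(2, (2, 2))
    rng_h = selh + torch.arange(0, 4, 2).view(-1, 1)
    rng_w = selw + torch.arange(0, 4, 2)
    est += x[:, :, rng_h, rng_w] / n
print((est - avg).abs().max())   # small, shrinking like 1/sqrt(n)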
@@ -100,6 +101,63 @@ class SConv2dAvg(nn.Module):
             #out.masked_scatter_(mask>0, out_unf)
         return out
 
+    def forward_slow(self, input, selh=-torch.ones(1,1), selw=-torch.ones(1,1), mask=-torch.ones(1,1),stoch=True,stride=-1):
+        device=input.device
+        if stride==-1:
+            stride = self.stride #if stride not defined use self.stride
+            if stoch==False:
+                stride=1 #test with real average pooling
+        batch_size, in_channels, in_h, in_w = input.shape
+        out_channels, in_channels, kh, kw = self.weight.shape
+
+        afterconv_h = in_h+2*self.padding-(kh-1) #size after conv
+        afterconv_w = in_w+2*self.padding-(kw-1)
+        if self.ceil_mode: #ceil_mode = talse default mode for strided conv
+            out_h = math.ceil(afterconv_h/stride)
+            out_w = math.ceil(afterconv_w/stride)
+        else: #ceil_mode = false default mode for pooling
+            out_h = math.floor(afterconv_h/stride)
+            out_w = math.floor(afterconv_w/stride)
+        unfold = torch.nn.Unfold(kernel_size=(kh, kw), dilation=self.dilation, padding=self.padding, stride=1)
+        inp_unf = unfold(input) #transform into a matrix (batch_size, in_channels*kh*kw,afterconv_h,afterconv_w)
+        if stride!=1: # if stride==1 there is no pooling
+            inp_unf = inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h,afterconv_w)
+            if selh[0,0]==-1: # if not given sampled selection
+                #selction of where to sample for each pooling location
+                selh = torch.randint(stride,(out_h,out_w), device=device)
+                selw = torch.randint(stride,(out_h,out_w), device=device)
+
+                resth = (out_h*stride)-afterconv_h
+                restw = (out_w*stride)-afterconv_w
+                if resth!=0 and self.ceil_mode: #in case of ceil_mode need to select only the good locations for the last regions
+                    selh[-1,:]=selh[-1,:]%(stride-resth);selh[:,-1]=selh[:,-1]%(stride-restw)
+                    selw[-1,:]=selw[-1,:]%(stride-resth);selw[:,-1]=selw[:,-1]%(stride-restw)
+            #the postion should be global by adding range...
+            rng_h = selh + torch.arange(0,out_h*stride,stride,device=device).view(-1,1)
+            rng_w = selw + torch.arange(0,out_w*stride,stride,device=device)
+
+            if mask[0,0]==-1:# in case of not given mask use only sampled selection
+                inp_unf = inp_unf[:,:,rng_h,rng_w].view(batch_size,in_channels*kh*kw,-1)
+            else:#in case of a valid mask use selection only on the mask locations
+                inp_unf = inp_unf[:,:,rng_h[mask>0],rng_w[mask>0]]
+
+        #Matrix mul
+        if self.bias is None:
+            out_unf = inp_unf.transpose(1, 2).matmul(self.weight.view(self.weight.size(0), -1).t()).transpose(1, 2)
+            #out_unf = oe.contract('bji,kj->bki',inp_unf,self.weight.view(self.weight.size(0), -1),backend='torch')
+        else:
+            #out_unf = oe.contract('bji,kj->bki',inp_unf,self.weight.view(self.weight.size(0), -1),backend='torch')+self.bias.view(1,-1,1)
+            out_unf = (inp_unf.transpose(1, 2).matmul(self.weight.view(self.weight.size(0), -1).t()) + self.bias).transpose(1, 2)
+
+        if stride==1 or mask[0,0]==-1:# in case of no mask and stride==1
+            out = out_unf.view(batch_size,out_channels,out_h,out_w) #Fold
+            #if stoch==False: #this is done outside for more clarity
+            # out = F.avg_pool2d(out,self.stride,ceil_mode=self.ceil_mode)
+        else:#in case of mask
+            out = torch.zeros(batch_size, out_channels,out_h,out_w,device=device)
+            out[:,:,mask>0] = out_unf
+        return out
+
     def forward_test(self, input, selh=-torch.ones(1,1), selw=-torch.ones(1,1), mask=-torch.ones(1,1),stoch=True,stride=-1):#ugly but faster
         device=input.device
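The new `forward_slow` spells the layer out as an explicit unfold + matrix multiply: `Unfold` extracts every kh×kw patch at stride 1, the per-window offsets `selh`/`selw` are turned into global indices `rng_h`/`rng_w` that keep one patch per pooling window, and the weight matrix multiplies the surviving columns. For example, with `in_h=32`, `padding=1`, `kh=3` the code gives `afterconv_h = 32 + 2 - 2 = 32`, and with `stride=2` and `ceil_mode=False`, `out_h = floor(32/2) = 16`. The unfold+matmul identity it builds on can be checked standalone (illustrative snippet, not repo code):

import torch
import torch.nn.functional as F

# Convolution expressed as unfold + matmul matches F.conv2d.
x = torch.randn(2, 3, 8, 8)
w = torch.randn(5, 3, 3, 3)
inp_unf = F.unfold(x, kernel_size=3, padding=1)   # (2, 3*3*3, 8*8)
out_unf = inp_unf.transpose(1, 2).matmul(w.view(5, -1).t()).transpose(1, 2)
out = out_unf.view(2, 5, 8, 8)                    # fold back to an image
print(torch.allclose(out, F.conv2d(x, w, padding=1), atol=1e-5))  # True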
@@ -174,64 +232,6 @@ class SConv2dAvg(nn.Module):
             out[:,:,mask>0] = out_unf
         return out
 
-    def forward_slow(self, input, selh=-torch.ones(1,1), selw=-torch.ones(1,1), mask=-torch.ones(1,1),stoch=True,stride=-1):
-        device=input.device
-        if stride==-1:
-            stride = self.stride #if stride not defined use self.stride
-            if stoch==False:
-                stride=1 #test with real average pooling
-        batch_size, in_channels, in_h, in_w = input.shape
-        out_channels, in_channels, kh, kw = self.weight.shape
-
-        afterconv_h = in_h+2*self.padding-(kh-1) #size after conv
-        afterconv_w = in_w+2*self.padding-(kw-1)
-        if self.ceil_mode: #ceil_mode = talse default mode for strided conv
-            out_h = math.ceil(afterconv_h/stride)
-            out_w = math.ceil(afterconv_w/stride)
-        else: #ceil_mode = false default mode for pooling
-            out_h = math.floor(afterconv_h/stride)
-            out_w = math.floor(afterconv_w/stride)
-        unfold = torch.nn.Unfold(kernel_size=(kh, kw), dilation=self.dilation, padding=self.padding, stride=1)
-        inp_unf = unfold(input) #transform into a matrix (batch_size, in_channels*kh*kw,afterconv_h,afterconv_w)
-        if stride!=1: # if stride==1 there is no pooling
-            inp_unf = inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h,afterconv_w)
-            if selh[0,0]==-1: # if not given sampled selection
-                #selction of where to sample for each pooling location
-                selh = torch.randint(stride,(out_h,out_w), device=device)
-                selw = torch.randint(stride,(out_h,out_w), device=device)
-
-                resth = (out_h*stride)-afterconv_h
-                restw = (out_w*stride)-afterconv_w
-                if resth!=0 and self.ceil_mode: #in case of ceil_mode need to select only the good locations for the last regions
-                    selh[-1,:]=selh[-1,:]%(stride-resth);selh[:,-1]=selh[:,-1]%(stride-restw)
-                    selw[-1,:]=selw[-1,:]%(stride-resth);selw[:,-1]=selw[:,-1]%(stride-restw)
-            #the postion should be global by adding range...
-            rng_h = selh + torch.arange(0,out_h*stride,stride,device=device).view(-1,1)
-            rng_w = selw + torch.arange(0,out_w*stride,stride,device=device)
-
-            if mask[0,0]==-1:# in case of not given mask use only sampled selection
-                inp_unf = inp_unf[:,:,rng_h,rng_w].view(batch_size,in_channels*kh*kw,-1)
-            else:#in case of a valid mask use selection only on the mask locations
-                inp_unf = inp_unf[:,:,rng_h[mask>0],rng_w[mask>0]]
-
-        #Matrix mul
-        if self.bias is None:
-            #out_unf = inp_unf.transpose(1, 2).matmul(self.weight.view(self.weight.size(0), -1).t()).transpose(1, 2)
-            out_unf = oe.contract('bji,kj->bki',inp_unf,self.weight.view(self.weight.size(0), -1),backend='torch')
-        else:
-            out_unf = oe.contract('bji,kj->bki',inp_unf,self.weight.view(self.weight.size(0), -1),backend='torch')+self.bias.view(1,-1,1)
-            #out_unf = (inp_unf.transpose(1, 2).matmul(self.weight.view(self.weight.size(0), -1).t()) + self.bias).transpose(1, 2)
-
-        if stride==1 or mask[0,0]==-1:# in case of no mask and stride==1
-            out = out_unf.view(batch_size,out_channels,out_h,out_w) #Fold
-            if stoch==False: #this is done outside for more clarity
-                out = F.avg_pool2d(out,self.stride,ceil_mode=True)
-        else:#in case of mask
-            out = torch.zeros(batch_size, out_channels,out_h,out_w,device=device)
-            out[:,:,mask>0] = out_unf
-        return out
-
     def forward_slowwithbatch(self, input, selh=-torch.ones(1,1), selw=-torch.ones(1,1), mask=-torch.ones(1,1),stoch=True,stride=-1):
         device=input.device
         if stride==-1:
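The deleted copy of `forward_slow` differed in two ways: it performed the matrix multiply with `oe.contract` (opt_einsum) rather than plain `matmul`, and its stride==1 branch still applied `F.avg_pool2d(..., ceil_mode=True)` when `stoch=False`. The contraction itself is unchanged; `'bji,kj->bki'` is exactly the transpose/matmul/transpose of the new code, as a quick check shows (`torch.einsum` standing in for `oe.contract`):

import torch

# out[b,k,i] = sum_j inp[b,j,i] * w[k,j], written both ways.
inp_unf = torch.randn(2, 27, 64)   # (batch, in_channels*kh*kw, locations)
wmat = torch.randn(5, 27)          # (out_channels, in_channels*kh*kw)
a = torch.einsum('bji,kj->bki', inp_unf, wmat)
b = inp_unf.transpose(1, 2).matmul(wmat.t()).transpose(1, 2)
print(torch.allclose(a, b, atol=1e-5))  # True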
@@ -23,8 +23,8 @@ if __name__ == "__main__":
         accs.append(max([x["test_acc"] for x in data]))
         taccs.append(max([x["train_acc"] for x in data]))
         # aug_accs.append(data['Aug_Accuracy'][1])
-        # times.append(data['Time'][0])
-        # mem.append(data['Memory'][1])
+        #times.append(data['Time'][0])
+        #mem.append(data['Memory'][1])
 
         # acc_idx = [x['acc'] for x in data['Log']].index(data['Accuracy'])
         # f1_max.append(max(data['Log'][acc_idx]['f1'])*100)

@@ -36,5 +36,5 @@ if __name__ == "__main__":
     print("Acc train : %.2f ~ %.2f"%(np.mean(taccs), np.std(taccs)))
     # print("Acc : %.2f ~ %.2f / Aug_Acc %d: %.2f ~ %.2f"%(np.mean(accs), np.std(accs), data['Aug_Accuracy'][0], np.mean(aug_accs), np.std(aug_accs)))
     # print("F1 max : %.2f ~ %.2f / F1 min : %.2f ~ %.2f"%(np.mean(f1_max), np.std(f1_max), np.mean(f1_min), np.std(f1_min)))
-    # print("Time (h): %.2f ~ %.2f"%(np.mean(times)/3600, np.std(times)/3600))
-    # print("Mem (MB): %d ~ %d"%(np.mean(mem), np.std(mem)))
+    #print("Time (h): %.2f ~ %.2f"%(np.mean(times)/3600, np.std(times)/3600))
+    #print("Mem (MB): %d ~ %d"%(np.mean(mem), np.std(mem)))
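These last two hunks touch a results-aggregation script (the filename is not shown in this capture). Only the `test_acc`/`train_acc` keys and the mean/std printouts are visible in the diff; a hedged reconstruction of the surrounding loop, with assumed file paths and log layout, would look like:

import json
import numpy as np

# Hypothetical reconstruction; only the "test_acc"/"train_acc" keys and the
# mean/std printout come from the diff. Paths and JSON layout are assumed.
accs, taccs = [], []
for path in ["res/log/run1.json", "res/log/run2.json"]:   # assumed paths
    with open(path) as f:
        data = json.load(f)                               # list of epoch dicts
    accs.append(max(x["test_acc"] for x in data))
    taccs.append(max(x["train_acc"] for x in data))

print("Acc : %.2f ~ %.2f" % (np.mean(accs), np.std(accs)))
print("Acc train : %.2f ~ %.2f" % (np.mean(taccs), np.std(taccs)))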