From f7436d0002bf3d245b9c4b19694476fa68e4c88f Mon Sep 17 00:00:00 2001 From: Marco Pedersoli Date: Sat, 13 Jun 2020 20:47:19 -0400 Subject: [PATCH] added faster filtering and convolution, but not working yet for BU --- main.py | 17 ++-- models/__init__.py | 2 +- models/myresnet3.py | 2 +- models/myresnet4.py | 213 ++++++++++++++++++++++++++++++++++++++++++++ models/stoch.py | 72 ++++++++++++++- 5 files changed, 295 insertions(+), 11 deletions(-) create mode 100644 models/myresnet4.py diff --git a/main.py b/main.py index dd25b28..70eda0d 100644 --- a/main.py +++ b/main.py @@ -49,7 +49,7 @@ checkpoint=False # Data print('==> Preparing data..') -dataroot="~/scratch/data" +dataroot="./data" download_data=False transform_train = [ # transforms.RandomCrop(32, padding=4), @@ -116,7 +116,7 @@ print('==> Building model..') # net = MyLeNetMatStochBU() # 10.5s - 45.3% #1.3GB net=globals()[args.net]() -print(net) +#print(net) net = net.to(device) if device == 'cuda': net = torch.nn.DataParallel(net) @@ -244,7 +244,7 @@ def stest(epoch,times=10): best_acc = acc import matplotlib.pyplot as plt -def plot_res(log, fig_name='res'): +def plot_res(log, best_acc,fig_name='res'): """Save a visual graph of the logs. 
Args: @@ -260,7 +260,7 @@ def plot_res(log, fig_name='res'): ax[0].plot(epochs,[x["test_loss"] for x in log], label='Test') ax[0].legend() - ax[1].set_title('Acc') + ax[1].set_title('Acc %s'%best_acc) ax[1].plot(epochs,[x["train_acc"] for x in log], label='Train') ax[1].plot(epochs,[x["test_acc"] for x in log], label='Test') ax[1].legend() @@ -270,7 +270,7 @@ def plot_res(log, fig_name='res'): plt.savefig(fig_name, bbox_inches='tight') plt.close() -from warmup_scheduler import GradualWarmupScheduler +#from warmup_scheduler import GradualWarmupScheduler def get_scheduler(schedule, epochs, warmup_mul, warmup_ep): scheduler=None if schedule=='cosine': @@ -328,11 +328,12 @@ for epoch in range(start_epoch, start_epoch+args.epochs): print('\nEpoch: %d' % epoch) print("Acc : %.2f / %.2f"%(train_acc, test_acc)) print("Loss : %.2f / %.2f"%(train_loss, test_loss)) + print('Time:',time.perf_counter() - t0) exec_time=time.perf_counter() - t0 print('-'*9) print('Best Acc : %.2f'%best_acc) -print('Training time (s):',exec_time) +print('Training time (min):',exec_time/60) import json @@ -344,8 +345,8 @@ except: print("Failed to save logs :",filename) print(sys.exc_info()[1]) try: - plot_res(log, fig_name=res_folder+filename) + plot_res(log,best_acc, fig_name=res_folder+filename) print('Plot :\"',res_folder+filename, '\" saved !') except: print("Failed to plot res") - print(sys.exc_info()[1]) \ No newline at end of file + print(sys.exc_info()[1]) diff --git a/models/__init__.py b/models/__init__.py index 5c2679e..ba67097 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1,2 +1,2 @@ from .mylenet4 import * -from .myresnet3 import * +from .myresnet4 import * diff --git a/models/myresnet3.py b/models/myresnet3.py index f85d886..f40746b 100644 --- a/models/myresnet3.py +++ b/models/myresnet3.py @@ -119,7 +119,7 @@ class ResNet(nn.Module): def forward(self, x , stoch = False): #if self.training==False: # stoch=False - print(stoch) + #print(stoch) # self.layer1.stoch=stoch # 
self.layer2.stoch=stoch # self.layer3.stoch=stoch diff --git a/models/myresnet4.py b/models/myresnet4.py new file mode 100644 index 0000000..19e07b2 --- /dev/null +++ b/models/myresnet4.py @@ -0,0 +1,213 @@ +'''ResNet in PyTorch. + +For Pre-activation ResNet, see 'preact_resnet.py'. + +Reference: +[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun + Deep Residual Learning for Image Recognition. arXiv:1512.03385 +''' +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .stochsim import savg_pool2d +from .stoch import * + + + +class SAvg_Pool2d(nn.Module): + def __init__(self, stride=1, padding=0, dilation=1, groups=1,ceil_mode=True,bias=False,mode='s'): + super(SAvg_Pool2d, self).__init__() + self.stride = stride + self.mode = mode + self.ceil_mode = ceil_mode + + def forward(self, x,stoch = True): + out = savg_pool2d(x, self.stride, mode = self.mode,ceil_mode = self.ceil_mode) + return out + +stochmode = 'sim'#'sim'#'stride''stoch''' +finalstochpool = True +simmode = 'sbc' + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, in_planes, planes, stride=1,pool=1): + super(BasicBlock, self).__init__() + + if stochmode=='' or stride==1: + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + elif stochmode=='stride': + if finalstochpool: + stride = stride*pool + self.conv1 = SConv2dStride(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + elif stochmode=='sim': + if finalstochpool: + stride = stride*pool + self.conv1 = nn.Sequential( + nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=1, bias=False), + SAvg_Pool2d(stride, mode = simmode,ceil_mode = True) + ) + elif stochmode=='stoch': + if finalstochpool: + stride = stride*pool + self.conv1 = SConv2dAvg(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + if stochmode=='stoch': + if pool!=1 and finalstochpool: + self.conv2 = 
SConv2dAvg(planes, planes, kernel_size=3, + stride=pool, padding=1, bias=False) + else: + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, + stride=1, padding=1, bias=False) + else: + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, + stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion*planes: + if stochmode=='stride': + self.shortcut = nn.Sequential( + SConv2dStride(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion*planes) + ) + elif stochmode=='stoch': + self.shortcut = nn.Sequential( + SConv2dAvg(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion*planes) + ) + elif stochmode=='sim': + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=1, bias=False), + SAvg_Pool2d(stride, mode = simmode,ceil_mode = True), + nn.BatchNorm2d(self.expansion*planes) + ) + elif stochmode=='': + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion*planes) + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + +#only basic block has been updated!!! 
+class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, in_planes, planes, stride=1): + super(Bottleneck, self).__init__() + #self.conv1 = SConv2dStride(in_planes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = SConv2dStride(planes, planes, kernel_size=3,stride=stride, padding=1, bias=False) + #self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,stride=stride, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = SConv2dStride(planes, self.expansion*planes, kernel_size=1, bias=False) + #self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(self.expansion*planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion*planes: + self.shortcut = nn.Sequential( + #nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), + SConv2dStride(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion*planes) + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = F.relu(self.bn2(self.conv2(out))) + out = self.bn3(self.conv3(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + + +class ResNet(nn.Module): + def __init__(self, block, num_blocks, num_classes=10,stoch=False): + super(ResNet, self).__init__() + self.in_planes = 64 + + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, + stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) + self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) + self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) + self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2,pool=4) + self.linear = nn.Linear(512*block.expansion, num_classes) + self.stoch = stoch + + 
def _make_layer(self, block, planes, num_blocks, stride, pool=1): + strides = [stride] + [1]*(num_blocks-1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride,pool)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, x ,stoch = True): + #if self.training==False: + # stoch=False + #stoch=True + out = F.relu(self.bn1(self.conv1(x))) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + #if self.stoch: + if stochmode=='': + if not(finalstochpool): + #if stochmode == '': + out = F.avg_pool2d(out, 4) + else: + out = savg_pool2d(out, 4, mode = simmode) + else: + if not(finalstochpool): + out = F.avg_pool2d(out, 4) +# else: +# if stoch: +# out = savg_pool2d(out, 4, mode = 's') +# else: +# out = F.avg_pool2d(out, 4) + out = out.view(out.size(0), -1) + out = self.linear(out) + return out + + + +def MyResNet18(stoch=False): + return ResNet(BasicBlock, [2, 2, 2, 2],stoch=stoch) + + +def ResNet34(): + return ResNet(BasicBlock, [3, 4, 6, 3]) + + +def MyResNet50(): + return ResNet(Bottleneck, [3, 4, 6, 3]) + + +def ResNet101(): + return ResNet(Bottleneck, [3, 4, 23, 3]) + + +def ResNet152(): + return ResNet(Bottleneck, [3, 8, 36, 3]) + + +def test(): + net = ResNet18() + y = net(torch.randn(1, 3, 32, 32)) + print(y.size()) + +# test() diff --git a/models/stoch.py b/models/stoch.py index a18a7e2..1c42eaf 100644 --- a/models/stoch.py +++ b/models/stoch.py @@ -4,6 +4,7 @@ import torch.nn as nn import torch.nn.functional as F import math +import opt_einsum as oe class SConv2dStride(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1,ceil_mode=True,bias=False): @@ -54,6 +55,73 @@ class SConv2dAvg(nn.Module): batch_size, in_channels, in_h, in_w = input.shape out_channels, in_channels, kh, kw = self.weight.shape + afterconv_h = in_h+2*self.padding-(kh-1) #size after conv + afterconv_w = 
in_w+2*self.padding-(kw-1) + if self.ceil_mode: #ceil_mode = true default mode for strided conv + out_h = math.ceil(afterconv_h/stride) + out_w = math.ceil(afterconv_w/stride) + else: #ceil_mode = false default mode for pooling + out_h = math.floor(afterconv_h/stride) + out_w = math.floor(afterconv_w/stride) + unfold = torch.nn.Unfold(kernel_size=(kh, kw), dilation=self.dilation, padding=self.padding, stride=1) + inp_unf = unfold(input) #transform into a matrix (batch_size, in_channels*kh*kw,afterconv_h,afterconv_w) + if stride!=1: # if stride==1 there is no pooling + inp_unf = inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h,afterconv_w) + if selh[0,0]==-1: # if not given sampled selection + #selection of where to sample for each pooling location + sel = torch.randint(stride*stride,(out_h,out_w), device=device) + + if self.ceil_mode: #in case of ceil_mode need to select only the good locations for the last regions + resth = (out_h*stride)-afterconv_h + restw = (out_w*stride)-afterconv_w + if resth!=0: + sel[-1] = (sel[-1]//stride)%(stride-resth)*stride+(sel[-1]%stride) + sel[:,-1] = (sel[:,-1]%stride)%(stride-restw)+sel[:,-1]//stride*stride + #print(stride-resth,sel[-1]) + #print(stride-restw,sel[:,-1]) + + rng = torch.arange(0,afterconv_h*afterconv_w,stride*stride,device=device).view(out_h,out_w) + index = sel+rng + index = index.repeat(batch_size,in_channels*kh*kw,1,1) + + + if mask[0,0]==-1:# in case of not given mask use only sampled selection + #inp_unf = inp_unf[:,:,rng_h,rng_w].view(batch_size,in_channels*kh*kw,-1) + inp_unf = torch.gather(inp_unf.view(batch_size,in_channels*kh*kw,afterconv_h*afterconv_w),2,index.view(batch_size,in_channels*kh*kw,out_h*out_w)).view(batch_size,in_channels*kh*kw,out_h*out_w) + else:#in case of a valid mask use selection only on the mask locations + inp_unf = inp_unf[:,:,rng_h[mask>0],rng_w[mask>0]] + + #Matrix mul + if self.bias is None: + #flt = self.weight.view(self.weight.size(0), -1).t() + #out_unf =
inp_unf.transpose(2,1).matmul(flt).transpose(1, 2) + out_unf = oe.contract('bji,kj->bki',inp_unf,self.weight.view(self.weight.size(0), -1),backend='torch') + #print(((out_unf-out_unf1)**2).mean()) + else: + #out_unf = oe.contract('bji,kj,b->bki',inp_unf,self.weight.view(self.weight.size(0), -1),self.bias,backend='torch')#+self.bias.view(1,-1,1)#still slow + out_unf = oe.contract('bji,kj->bki',inp_unf,self.weight.view(self.weight.size(0), -1),backend='torch')+self.bias.view(1,-1,1)#still slow + #self.flt = self.weight.view(self.weight.size(0), -1).t() + #out_unf = (inp_unf.transpose(1, 2).matmul(self.flt) + self.bias).transpose(1, 2) + + if stride==1 or mask[0,0]==-1:# in case of no mask and stride==1 + out = out_unf.view(batch_size,out_channels,out_h,out_w) #Fold + #if stoch==False: #this is done outside for more clarity + # out = F.avg_pool2d(out,self.stride,ceil_mode=True) + else:#in case of mask + out = torch.zeros(batch_size, out_channels,out_h,out_w,device=device) + out[:,:,mask>0] = out_unf + return out + + + def forward_(self, input, selh=-torch.ones(1,1), selw=-torch.ones(1,1), mask=-torch.ones(1,1),stoch=True,stride=-1): + device=input.device + if stride==-1: + stride = self.stride #if stride not defined use self.stride + if stoch==False: + stride=1 #test with real average pooling + batch_size, in_channels, in_h, in_w = input.shape + out_channels, in_channels, kh, kw = self.weight.shape + afterconv_h = in_h+2*self.padding-(kh-1) #size after conv afterconv_w = in_w+2*self.padding-(kw-1) if self.ceil_mode: #ceil_mode = talse default mode for strided conv @@ -87,8 +155,10 @@ class SConv2dAvg(nn.Module): #Matrix mul if self.bias is None: - out_unf = inp_unf.transpose(1, 2).matmul(self.weight.view(self.weight.size(0), -1).t()).transpose(1, 2) + #out_unf = inp_unf.transpose(1, 2).matmul(self.weight.view(self.weight.size(0), -1).t()).transpose(1, 2) + out_unf = oe.contract('bji,kj->bki',inp_unf,self.weight.view(self.weight.size(0), -1),backend='torch') else: + 
dgdg out_unf = (inp_unf.transpose(1, 2).matmul(self.weight.view(self.weight.size(0), -1).t()) + self.bias).transpose(1, 2) if stride==1 or mask[0,0]==-1:# in case of no mask and stride==1