# BU_Stoch_pool/models/mylenet4.py
'''LeNet variants in PyTorch.'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from .stoch import SConv2dAvg
from .stochsim import savg_pool2d
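
# This module collects small LeNet-style CIFAR classifiers used to compare
# pooling strategies: plain average pooling (MyLeNetNormal), simulated
# stochastic pooling via savg_pool2d (MyLeNetSimNormal), strided convolutions
# (MyLeNetStride), and the unfold/matmul SConv2dAvg convolution in its
# deterministic, stochastic, and bottom-up-sampled forms (MyLeNetMatNormal,
# MyLeNetMatStoch, MyLeNetMatStochBU). The per-class "epoch Ns" comments are
# the author's rough training-time notes.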

class MyLeNetNormal(nn.Module):  # epoch 12s
    """Baseline: standard convolutions with 2x2 average pooling."""

    def __init__(self):
        super(MyLeNetNormal, self).__init__()
        self.conv1 = nn.Conv2d(3, 200, 3, stride=1)
        self.conv2 = nn.Conv2d(200, 400, 3, stride=1)
        self.conv3 = nn.Conv2d(400, 800, 3, stride=1)
        self.conv4 = nn.Conv2d(800, 10, 3, stride=1)
        #self.fc1 = nn.Linear(800, 10)

    def forward(self, x, stoch=True):
        out = F.relu(self.conv1(x))
        out = F.avg_pool2d(out, 2, ceil_mode=True)
        out = F.relu(self.conv2(out))
        out = F.avg_pool2d(out, 2, ceil_mode=True)
        out = F.relu(self.conv3(out))
        out = F.avg_pool2d(out, 2, ceil_mode=True)
        out = self.conv4(out)
        #out = F.avg_pool2d(out, 2, ceil_mode=True)
        out = out.view(out.size(0), -1)
        #out = self.fc1(out)
        return out
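
# Note (illustrative, assuming 32x32 CIFAR inputs): 32 -> conv 30 -> pool 15
# -> conv 13 -> pool 7 -> conv 5 -> pool 3 -> conv4 1x1, so conv4's 10 output
# channels flatten directly to the class logits and no fc layer is needed.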

class MyLeNetSimNormal(nn.Module):  # epoch 12s
    """Baseline topology, but with stochastic average pooling (savg_pool2d)."""

    def __init__(self):
        super(MyLeNetSimNormal, self).__init__()
        self.conv1 = nn.Conv2d(3, 200, 3, stride=1)
        self.conv2 = nn.Conv2d(200, 400, 3, stride=1)
        self.conv3 = nn.Conv2d(400, 800, 3, stride=1)
        self.conv4 = nn.Conv2d(800, 10, 3, stride=1)
        #self.fc1 = nn.Linear(800, 10)

    def forward(self, x, stoch=True):
        out = F.relu(self.conv1(x))
        out = savg_pool2d(out, 2, mode='s', ceil_mode=True)
        out = F.relu(self.conv2(out))
        out = savg_pool2d(out, 2, mode='s', ceil_mode=True)
        out = F.relu(self.conv3(out))
        out = savg_pool2d(out, 2, mode='s', ceil_mode=True)
        out = self.conv4(out)
        #out = F.avg_pool2d(out, 2, ceil_mode=True)
        out = out.view(out.size(0), -1)
        #out = self.fc1(out)
        return out
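
# savg_pool2d with mode='s' is the simulated stochastic pool from .stochsim;
# judging from the commented-out reference at the bottom of this file, it
# keeps one randomly chosen element per 2x2 window instead of averaging.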

class MyLeNetStride(nn.Module):  # epoch 6s
    """Pooling-free variant: stride-2 convolutions do the downsampling."""

    def __init__(self):
        super(MyLeNetStride, self).__init__()
        self.conv1 = nn.Conv2d(3, 200, 3, stride=2)
        self.conv2 = nn.Conv2d(200, 400, 3, stride=2)
        self.conv3 = nn.Conv2d(400, 800, 3, stride=2)
        self.conv4 = nn.Conv2d(800, 10, 3, stride=1)
        #self.fc1 = nn.Linear(800, 10)

    def forward(self, x, stoch=True):
        out = F.relu(self.conv1(x))
        out = F.relu(self.conv2(out))
        out = F.relu(self.conv3(out))
        out = self.conv4(out)
        out = out.view(out.size(0), -1)
        #out = self.fc1(out)
        return out

class MyLeNetMatNormal(nn.Module):  # epoch 21s
    """Baseline topology re-expressed with SConv2dAvg (unfold/matmul convs)."""

    def __init__(self):
        super(MyLeNetMatNormal, self).__init__()
        self.conv1 = SConv2dAvg(3, 200, 3, stride=1)
        self.conv2 = SConv2dAvg(200, 400, 3, stride=1)
        self.conv3 = SConv2dAvg(400, 800, 3, stride=1)
        self.conv4 = SConv2dAvg(800, 10, 3, stride=1)
        #self.fc1 = nn.Linear(800, 10)

    def forward(self, x, stoch=True):
        out = F.relu(self.conv1(x))
        out = F.avg_pool2d(out, 2, ceil_mode=True)
        out = F.relu(self.conv2(out))
        out = F.avg_pool2d(out, 2, ceil_mode=True)
        out = F.relu(self.conv3(out))
        out = F.avg_pool2d(out, 2, ceil_mode=True)
        out = self.conv4(out)
        #out = F.avg_pool2d(out, 1, ceil_mode=True)
        out = out.view(out.size(0), -1)
        #out = self.fc1(out)
        return out

class MyLeNetMatStoch(nn.Module):  # epoch 17s
    """Stride-2 SConv2dAvg convolutions with per-layer stochastic sampling."""

    def __init__(self):
        super(MyLeNetMatStoch, self).__init__()
        self.conv1 = SConv2dAvg(3, 200, 3, stride=2)
        self.conv2 = SConv2dAvg(200, 400, 3, stride=2)
        self.conv3 = SConv2dAvg(400, 800, 3, stride=2)
        self.conv4 = SConv2dAvg(800, 10, 3, stride=1)
        #self.fc1 = nn.Linear(800, 10)

    def forward(self, x, stoch=True):
        #print('in', x.shape)
        out = F.relu(self.conv1(x, stoch=stoch))
        #print('c1', out.shape)
        out = F.relu(self.conv2(out, stoch=stoch))
        #print('c2', out.shape)
        out = F.relu(self.conv3(out, stoch=stoch))
        #print('c3', out.shape)
        out = self.conv4(out, stoch=stoch)
        #print('c4', out.shape)
        out = out.view(out.size(0), -1)
        #out = self.fc1(out)
        return out

class MyLeNetMatStochBU(nn.Module):  # epoch 11s
    """Stochastic SConv2dAvg variant with bottom-up (BU) mask sampling: the
    deepest layer's mask is fixed first, and each earlier layer then samples
    only the input locations its successor actually needs."""

    def __init__(self):
        super(MyLeNetMatStochBU, self).__init__()
        self.conv1 = SConv2dAvg(3, 200, 3, stride=2)
        self.conv2 = SConv2dAvg(200, 400, 3, stride=2)
        self.conv3 = SConv2dAvg(400, 800, 3, stride=2, ceil_mode=True)
        self.conv4 = SConv2dAvg(800, 10, 3, stride=1)
        #self.fc1 = nn.Linear(800, 10)

    def forward(self, x, stoch=True):
        # Spatial size after each strided convolution.
        h0, w0 = x.shape[2], x.shape[3]
        h1, w1 = self.conv1.get_size(h0, w0)
        h2, w2 = self.conv2.get_size(h1, w1)
        h3, w3 = self.conv3.get_size(h2, w2)
        #print(h0, w0); print(h1, w1); print(h2, w2); print(h3, w3)
        # Sample bottom-up: start from a full mask at the deepest stochastic
        # layer and propagate the required positions back to the input.
        mask3 = torch.ones(h3, w3).to(x.device)
        selh3, selw3, mask2 = self.conv3.sample(h2, w2, mask=mask3)
        selh2, selw2, mask1 = self.conv2.sample(h1, w1, mask=mask2)
        selh1, selw1, mask0 = self.conv1.sample(h0, w0, mask=mask1)
        #print(mask3.shape); print(mask2.shape); print(mask1.shape)
        # Forward pass, restricted to the sampled locations.
        out = F.relu(self.conv1(x, selh1, selw1, mask1, stoch=stoch))
        out = F.relu(self.conv2(out, selh2, selw2, mask2, stoch=stoch))
        out = F.relu(self.conv3(out, selh3, selw3, mask3, stoch=stoch))
        out = self.conv4(out, stoch=stoch)
        out = out.view(out.size(0), -1)
        #out = self.fc1(out)
        return out
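
# ---------------------------------------------------------------------------
# Legacy reference implementations, kept commented out. SConv2dAvg is now
# imported from .stoch and savg_pool2d from .stochsim; the versions below
# document how they work (unfold + matmul with per-window random selection).
# ---------------------------------------------------------------------------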
# class SConv2dAvg(nn.Module):
#     def __init__(self, in_channels, out_channels, kernel_size, stride=1,
#                  padding=0, dilation=1, groups=1):
#         super(SConv2dAvg, self).__init__()
#         conv = nn.Conv2d(in_channels, out_channels, kernel_size)
#         self.deconv = nn.ConvTranspose2d(1, 1, kernel_size, 1, padding=0,
#             output_padding=0, groups=1, bias=False, dilation=1, padding_mode='zeros')
#         nn.init.constant_(self.deconv.weight, 1)
#         self.pooldeconv = nn.ConvTranspose2d(1, 1, kernel_size=stride, padding=0,
#             stride=stride, output_padding=0, groups=1, bias=False, dilation=1,
#             padding_mode='zeros')
#         nn.init.constant_(self.pooldeconv.weight, 1)
#         self.weight = nn.Parameter(conv.weight)
#         self.bias = nn.Parameter(conv.bias)
#         self.stride = stride
#         self.dilation = dilation
#         self.padding = padding
#         self.kernel_size = kernel_size
#
#     def forward(self, input, selh=-torch.ones(1, 1), selw=-torch.ones(1, 1),
#                 mask=-torch.ones(1, 1), stoch=True):
#         stride = self.stride
#         if stoch == False:
#             stride = 1
#         batch_size, in_channels, in_h, in_w = input.shape
#         out_channels, in_channels, kh, kw = self.weight.shape
#         afterconv_h = in_h - (kh - 1)
#         afterconv_w = in_w - (kw - 1)
#         out_h = (afterconv_h + stride - 1) // stride
#         out_w = (afterconv_w + stride - 1) // stride
#         unfold = torch.nn.Unfold(kernel_size=(kh, kw), dilation=self.dilation,
#                                  padding=self.padding, stride=1)
#         inp_unf = unfold(input)
#         if stride != 1:
#             inp_unf = inp_unf.view(batch_size, in_channels*kh*kw, afterconv_h, afterconv_w)
#             if selh[0, 0] == -1:
#                 resth = (out_h*stride) - afterconv_h
#                 restw = (out_w*stride) - afterconv_w
#                 selh = torch.cuda.LongTensor(out_h, out_w).random_(0, stride)
#                 selw = torch.cuda.LongTensor(out_h, out_w).random_(0, stride)
#                 #if resth != 0:
#                 #    selh[-1,:] = selh[-1,:]%(stride-resth); selh[:,-1] = selh[:,-1]%(stride-restw)
#                 #    selw[-1,:] = selw[-1,:]%(stride-resth); selw[:,-1] = selw[:,-1]%(stride-restw)
#                 #if mask[0, 0] == -1:
#                 #    mask = torch.ones(out_h, out_w, device=torch.device('cuda'))
#             rng_h = selh + torch.arange(0, out_h*stride, stride,
#                                         device=torch.device('cuda')).view(-1, 1)
#             rng_w = selw + torch.arange(0, out_w*stride, stride,
#                                         device=torch.device('cuda'))
#             if mask[0, 0] == -1:
#                 inp_unf = inp_unf[:, :, rng_h, rng_w].view(batch_size, in_channels*kh*kw, -1)
#             else:
#                 inp_unf = inp_unf[:, :, rng_h[mask > 0], rng_w[mask > 0]]
#         if self.bias is None:
#             out_unf = inp_unf.transpose(1, 2).matmul(
#                 self.weight.view(self.weight.size(0), -1).t()).transpose(1, 2)
#         else:
#             out_unf = (inp_unf.transpose(1, 2).matmul(
#                 self.weight.view(self.weight.size(0), -1).t()) + self.bias).transpose(1, 2)
#         if stride == 1 or mask[0, 0] == -1:
#             out = out_unf.view(batch_size, out_channels, out_h, out_w)
#             if stoch == False:
#                 out = F.avg_pool2d(out, self.stride, ceil_mode=True)
#         else:
#             out = torch.zeros(batch_size, out_channels, out_h, out_w,
#                               device=torch.device('cuda'))
#             out[:, :, mask > 0] = out_unf
#         return out
#
#     def forward_(self, input, selh=-torch.ones(1, 1), selw=-torch.ones(1, 1),
#                  mask=-torch.ones(1, 1), stoch=True):
#         stride = self.stride
#         if stoch == False:
#             stride = 1
#         batch_size, in_channels, in_h, in_w = input.shape
#         out_channels, in_channels, kh, kw = self.weight.shape
#         afterconv_h = in_h - (kh - 1)
#         afterconv_w = in_w - (kw - 1)
#         out_h = (afterconv_h + stride - 1) // stride
#         out_w = (afterconv_w + stride - 1) // stride
#         unfold = torch.nn.Unfold(kernel_size=(kh, kw), dilation=self.dilation,
#                                  padding=self.padding, stride=1)
#         inp_unf = unfold(input)
#         if self.bias is None:
#             out_unf = inp_unf.transpose(1, 2).matmul(
#                 self.weight.view(self.weight.size(0), -1).t()).transpose(1, 2)
#         else:
#             out_unf = (inp_unf.transpose(1, 2).matmul(
#                 self.weight.view(self.weight.size(0), -1).t()) + self.bias).transpose(1, 2)
#         out = out_unf.view(batch_size, out_channels, afterconv_h, afterconv_w)
#         if stoch == False:
#             out = F.avg_pool2d(out, self.stride, ceil_mode=True)
#         return out
#
#     def sample(self, h, w, mask):
#         out_channels, in_channels, kh, kw = self.weight.shape
#         afterconv_h = h - (kh - 1)
#         afterconv_w = w - (kw - 1)
#         out_h = (afterconv_h + self.stride - 1) // self.stride
#         out_w = (afterconv_w + self.stride - 1) // self.stride
#         selh = torch.cuda.LongTensor(out_h, out_w).random_(0, self.stride)
#         selw = torch.cuda.LongTensor(out_h, out_w).random_(0, self.stride)
#         resth = (out_h*self.stride) - afterconv_h
#         restw = (out_w*self.stride) - afterconv_w
#         #print(resth)
#         #if resth != 0:
#         #    selh[-1,:] = selh[-1,:]%(self.stride-resth); selh[:,-1] = selh[:,-1]%(self.stride-restw)
#         #    selw[-1,:] = selw[-1,:]%(self.stride-resth); selw[:,-1] = selw[:,-1]%(self.stride-restw)
#         maskh = out_h*self.stride  #-resth  #+self.kernel_size-1
#         maskw = out_w*self.stride  #-restw  #+self.kernel_size-1
#         rng_h = selh + torch.arange(0, out_h*self.stride, self.stride,
#                                     device=torch.device('cuda')).view(-1, 1)
#         rng_w = selw + torch.arange(0, out_w*self.stride, self.stride,
#                                     device=torch.device('cuda'))
#         nmask = torch.zeros((maskh, maskw), device=torch.device('cuda'))
#         nmask[rng_h, rng_w] = 1
#         # Upsample the successor's mask to this layer's resolution, then keep
#         # only the randomly selected position in each stride window.
#         #rmask = mask * nmask
#         dmask = self.pooldeconv(mask.float().view(1, 1, mask.shape[0], mask.shape[1]))
#         rmask = nmask * dmask
#         #rmask = rmask[:, :, :out_h, :out_w]
#         # Dilate by the kernel footprint to mark every input pixel touched.
#         fmask = self.deconv(rmask)
#         fmask = fmask[0, 0]
#         return selh, selw, fmask.long()
#
#     def get_size(self, h, w):
#         newh = (h - (self.kernel_size - 1) + (self.stride - 1)) // self.stride
#         neww = (w - (self.kernel_size - 1) + (self.stride - 1)) // self.stride
#         return newh, neww
#
# def savg_pool2d(x, size, ceil_mode=False):
#     b, c, h, w = x.shape
#     selh = torch.LongTensor(h//size, w//size).random_(0, size)
#     rngh = torch.arange(0, h, size).long().view(h//size, 1).repeat(1, w//size)
#     selx = (selh + rngh).repeat(b, c, 1, 1)
#     selw = torch.LongTensor(h//size, w//size).random_(0, size)
#     rngw = torch.arange(0, w, size).long().view(1, w//size).repeat(h//size, 1)
#     sely = (selw + rngw).repeat(b, c, 1, 1)
#     bv, cv, hv, wv = torch.meshgrid([torch.arange(0, b), torch.arange(0, c),
#                                      torch.arange(0, h//size), torch.arange(0, w//size)])
#     # Keep one randomly chosen element per size x size window.
#     newx = x[bv, cv, selx, sely]
#     return newx
#
# def savg_pool2d_(x, size, ceil_mode=False):
#     b, c, h, w = x.shape
#     selh = torch.cuda.LongTensor(h//size, w//size).random_(0, size)
#     rngh = torch.arange(0, h, size, device=torch.device('cuda')).view(-1, 1)
#     selx = selh + rngh
#     selw = torch.cuda.LongTensor(h//size, w//size).random_(0, size)
#     rngw = torch.arange(0, w, size, device=torch.device('cuda'))
#     sely = selw + rngw
#     newx = x[:, :, selx, sely]
#     return newx
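
# A minimal smoke test; a sketch, not part of the original file. It assumes
# the relative imports resolve (e.g. run as
# `python -m BU_Stoch_pool.models.mylenet4` from the repo root) and uses an
# illustrative CIFAR-10-sized batch. Only the plain-conv baseline is
# exercised here, since the SConv2dAvg variants depend on the .stoch module.
if __name__ == "__main__":
    x = torch.randn(2, 3, 32, 32)  # hypothetical batch of two 32x32 RGB images
    model = MyLeNetNormal()
    logits = model(x)
    print(logits.shape)  # expected: torch.Size([2, 10])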