mirror of
https://github.com/AntoineHX/BU_Stoch_pool.git
synced 2025-05-04 09:40:46 +02:00
Initial commit
This commit is contained in:
parent
2ba6dbe7cc
commit
3de923156c
32 changed files with 4054 additions and 1 deletions
125
models/Old/pnasnet.py
Normal file
125
models/Old/pnasnet.py
Normal file
|
@ -0,0 +1,125 @@
|
|||
'''PNASNet in PyTorch.
|
||||
|
||||
Paper: Progressive Neural Architecture Search
|
||||
'''
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class SepConv(nn.Module):
|
||||
'''Separable Convolution.'''
|
||||
def __init__(self, in_planes, out_planes, kernel_size, stride):
|
||||
super(SepConv, self).__init__()
|
||||
self.conv1 = nn.Conv2d(in_planes, out_planes,
|
||||
kernel_size, stride,
|
||||
padding=(kernel_size-1)//2,
|
||||
bias=False, groups=in_planes)
|
||||
self.bn1 = nn.BatchNorm2d(out_planes)
|
||||
|
||||
def forward(self, x):
|
||||
return self.bn1(self.conv1(x))
|
||||
|
||||
|
||||
class CellA(nn.Module):
|
||||
def __init__(self, in_planes, out_planes, stride=1):
|
||||
super(CellA, self).__init__()
|
||||
self.stride = stride
|
||||
self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride)
|
||||
if stride==2:
|
||||
self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(out_planes)
|
||||
|
||||
def forward(self, x):
|
||||
y1 = self.sep_conv1(x)
|
||||
y2 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1)
|
||||
if self.stride==2:
|
||||
y2 = self.bn1(self.conv1(y2))
|
||||
return F.relu(y1+y2)
|
||||
|
||||
class CellB(nn.Module):
|
||||
def __init__(self, in_planes, out_planes, stride=1):
|
||||
super(CellB, self).__init__()
|
||||
self.stride = stride
|
||||
# Left branch
|
||||
self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride)
|
||||
self.sep_conv2 = SepConv(in_planes, out_planes, kernel_size=3, stride=stride)
|
||||
# Right branch
|
||||
self.sep_conv3 = SepConv(in_planes, out_planes, kernel_size=5, stride=stride)
|
||||
if stride==2:
|
||||
self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(out_planes)
|
||||
# Reduce channels
|
||||
self.conv2 = nn.Conv2d(2*out_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(out_planes)
|
||||
|
||||
def forward(self, x):
|
||||
# Left branch
|
||||
y1 = self.sep_conv1(x)
|
||||
y2 = self.sep_conv2(x)
|
||||
# Right branch
|
||||
y3 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1)
|
||||
if self.stride==2:
|
||||
y3 = self.bn1(self.conv1(y3))
|
||||
y4 = self.sep_conv3(x)
|
||||
# Concat & reduce channels
|
||||
b1 = F.relu(y1+y2)
|
||||
b2 = F.relu(y3+y4)
|
||||
y = torch.cat([b1,b2], 1)
|
||||
return F.relu(self.bn2(self.conv2(y)))
|
||||
|
||||
class PNASNet(nn.Module):
|
||||
def __init__(self, cell_type, num_cells, num_planes):
|
||||
super(PNASNet, self).__init__()
|
||||
self.in_planes = num_planes
|
||||
self.cell_type = cell_type
|
||||
|
||||
self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(num_planes)
|
||||
|
||||
self.layer1 = self._make_layer(num_planes, num_cells=6)
|
||||
self.layer2 = self._downsample(num_planes*2)
|
||||
self.layer3 = self._make_layer(num_planes*2, num_cells=6)
|
||||
self.layer4 = self._downsample(num_planes*4)
|
||||
self.layer5 = self._make_layer(num_planes*4, num_cells=6)
|
||||
|
||||
self.linear = nn.Linear(num_planes*4, 10)
|
||||
|
||||
def _make_layer(self, planes, num_cells):
|
||||
layers = []
|
||||
for _ in range(num_cells):
|
||||
layers.append(self.cell_type(self.in_planes, planes, stride=1))
|
||||
self.in_planes = planes
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def _downsample(self, planes):
|
||||
layer = self.cell_type(self.in_planes, planes, stride=2)
|
||||
self.in_planes = planes
|
||||
return layer
|
||||
|
||||
def forward(self, x):
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out = self.layer1(out)
|
||||
out = self.layer2(out)
|
||||
out = self.layer3(out)
|
||||
out = self.layer4(out)
|
||||
out = self.layer5(out)
|
||||
out = F.avg_pool2d(out, 8)
|
||||
out = self.linear(out.view(out.size(0), -1))
|
||||
return out
|
||||
|
||||
|
||||
def PNASNetA():
|
||||
return PNASNet(CellA, num_cells=6, num_planes=44)
|
||||
|
||||
def PNASNetB():
|
||||
return PNASNet(CellB, num_cells=6, num_planes=32)
|
||||
|
||||
|
||||
def test():
|
||||
net = PNASNetB()
|
||||
x = torch.randn(1,3,32,32)
|
||||
y = net(x)
|
||||
print(y)
|
||||
|
||||
# test()
|
Loading…
Add table
Add a link
Reference in a new issue