mirror of
https://github.com/AntoineHX/BU_Stoch_pool.git
synced 2025-05-04 17:50:46 +02:00
107 lines
3.1 KiB
Python
107 lines
3.1 KiB
Python
'''GoogLeNet with PyTorch.'''
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
class Inception(nn.Module):
    """Four-branch Inception block whose branch outputs are concatenated on dim 1.

    Branches, in concatenation order:
      * b1: 1x1 conv
      * b2: 1x1 reduction conv -> 3x3 conv
      * b3: 1x1 reduction conv -> 3x3 conv -> 3x3 conv (two stacked 3x3
        convs stand in for a single 5x5 conv)
      * b4: 3x3 max-pool with stride 1 -> 1x1 projection conv

    Every conv is followed by BatchNorm2d and an in-place ReLU. Every conv and
    the pool use padding that keeps H and W unchanged. The output therefore has
    ``n1x1 + n3x3 + n5x5 + pool_planes`` channels at the input's resolution.

    Args:
        in_planes: number of channels of the input tensor.
        n1x1: output channels of branch b1.
        n3x3red: channels of the 1x1 reduction in branch b2.
        n3x3: output channels of branch b2.
        n5x5red: channels of the 1x1 reduction in branch b3.
        n5x5: output channels of branch b3 (both of its 3x3 convs).
        pool_planes: output channels of branch b4.
    """

    def __init__(self, in_planes, n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes):
        super().__init__()
        unit = self._conv_bn_relu
        # Branches (and the layers inside each) are created in the same order
        # as before, so Sequential indices / state_dict keys and the order in
        # which parameter initialisation draws random numbers are unchanged.
        self.b1 = nn.Sequential(*unit(in_planes, n1x1, 1))
        self.b2 = nn.Sequential(
            *unit(in_planes, n3x3red, 1),
            *unit(n3x3red, n3x3, 3),
        )
        self.b3 = nn.Sequential(
            *unit(in_planes, n5x5red, 1),
            *unit(n5x5red, n5x5, 3),
            *unit(n5x5, n5x5, 3),
        )
        self.b4 = nn.Sequential(
            nn.MaxPool2d(3, stride=1, padding=1),
            *unit(in_planes, pool_planes, 1),
        )

    @staticmethod
    def _conv_bn_relu(c_in, c_out, k):
        """Return [Conv2d, BatchNorm2d, in-place ReLU]; padding k // 2 keeps H, W for odd k."""
        return [
            nn.Conv2d(c_in, c_out, kernel_size=k, padding=k // 2),
            nn.BatchNorm2d(c_out),
            nn.ReLU(True),
        ]

    def forward(self, x):
        # Each branch starts with a conv or a pool, so the in-place ReLUs never
        # touch x itself; every branch sees the same unmodified input.
        branches = (self.b1, self.b2, self.b3, self.b4)
        return torch.cat([branch(x) for branch in branches], 1)
|
|
|
|
|
|
class GoogLeNet(nn.Module):
    """GoogLeNet (Inception v1) variant for small RGB images such as CIFAR-10.

    The stem is a single 3x3 conv with no downsampling. The network then
    downsamples only twice, via ``self.maxpool`` (stride 2), so a 32x32 input
    reaches the last Inception stage at 8x8. A global average pool and a
    linear layer produce the logits.

    Args:
        num_classes: number of output logits. The default of 10 matches the
            previously hard-coded classifier size.

    Input is expected as ``(N, 3, H, W)``; the output is ``(N, num_classes)``.
    """

    def __init__(self, num_classes=10):
        super(GoogLeNet, self).__init__()
        # Stem: 3 -> 192 channels, spatial size preserved.
        self.pre_layers = nn.Sequential(
            nn.Conv2d(3, 192, kernel_size=3, padding=1),
            nn.BatchNorm2d(192),
            nn.ReLU(True),
        )

        # Stage 3. An Inception block outputs n1x1 + n3x3 + n5x5 + pool_planes
        # channels; the trailing comments give that total.
        self.a3 = Inception(192, 64, 96, 128, 16, 32, 32)      # -> 256
        self.b3 = Inception(256, 128, 128, 192, 32, 96, 64)    # -> 480

        # Parameter-free stride-2 pool, reused between stages 3/4 and 4/5.
        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)

        # Stage 4.
        self.a4 = Inception(480, 192, 96, 208, 16, 48, 64)     # -> 512
        self.b4 = Inception(512, 160, 112, 224, 24, 64, 64)    # -> 512
        self.c4 = Inception(512, 128, 128, 256, 24, 64, 64)    # -> 512
        self.d4 = Inception(512, 112, 144, 288, 32, 64, 64)    # -> 528
        self.e4 = Inception(528, 256, 160, 320, 32, 128, 128)  # -> 832

        # Stage 5.
        self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)  # -> 832
        self.b5 = Inception(832, 384, 192, 384, 48, 128, 128)  # -> 1024

        # Global average pooling down to 1x1. A fixed AvgPool2d(8, stride=1)
        # only reaches 1x1 when the final maps are exactly 8x8 (e.g. 32x32
        # inputs). At other sizes it either raises or leaves a flattened size
        # other than 1024, which breaks self.linear. Adaptive pooling computes
        # the same spatial mean for 8x8 maps and works at any resolution.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(1024, num_classes)

    def forward(self, x):
        out = self.pre_layers(x)

        out = self.a3(out)
        out = self.b3(out)
        out = self.maxpool(out)

        out = self.a4(out)
        out = self.b4(out)
        out = self.c4(out)
        out = self.d4(out)
        out = self.e4(out)
        out = self.maxpool(out)

        out = self.a5(out)
        out = self.b5(out)

        out = self.avgpool(out)
        out = torch.flatten(out, 1)  # (N, 1024, 1, 1) -> (N, 1024)
        return self.linear(out)
|
|
|
|
|
|
def test():
    """Smoke test: run GoogLeNet on one random 3x32x32 image and print the output size."""
    model = GoogLeNet()
    dummy_batch = torch.randn(1, 3, 32, 32)
    print(model(dummy_batch).size())
|
|
|
|
# test()
|