Rangement

This commit is contained in:
Harle, Antoine (Contracteur) 2020-01-24 14:32:37 -05:00
parent f83c73ec17
commit f507ff4741
16 changed files with 85 additions and 46 deletions

26
higher/smart_aug/model.py Executable file
View file

@ -0,0 +1,26 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
## Basic CNN ##
class LeNet(nn.Module):
def __init__(self, num_inp, num_out):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(num_inp, 20, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(20, 50, 5)
self.pool2 = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(5*5*50, 500)
self.fc2 = nn.Linear(500, num_out)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool2(F.relu(self.conv2(x)))
x = x.view(x.size(0), -1)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
def __str__(self):
return "LeNet"