Mirror of https://github.com/AntoineHX/smart_augmentation.git, synced 2025-05-04 04:00:46 +02:00
Add WideResNet (to be tested) for comparison with PBA
This commit is contained in: parent 3cffac9852, commit 96bb7d5002
2 changed files with 93 additions and 4 deletions
@@ -3,6 +3,7 @@ import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 ## Basic CNN ##
 class LeNet(nn.Module):
     def __init__(self, num_inp, num_out):
         super(LeNet, self).__init__()
@@ -48,4 +49,92 @@ class LeNet(nn.Module):
         return self._params[key]
 
     def __str__(self):
-        return "LeNet"
+        return "LeNet"
+
+## Wide ResNet ##
+#https://github.com/xternalz/WideResNet-pytorch/blob/master/wideresnet.py
+#https://github.com/arcelien/pba/blob/master/pba/wrn.py
+class BasicBlock(nn.Module):
+    def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
+        super(BasicBlock, self).__init__()
+        self.bn1 = nn.BatchNorm2d(in_planes)
+        self.relu1 = nn.ReLU(inplace=True)
+        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                               padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(out_planes)
+        self.relu2 = nn.ReLU(inplace=True)
+        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
+                               padding=1, bias=False)
+        self.droprate = dropRate
+        self.equalInOut = (in_planes == out_planes)
+        self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
+                               padding=0, bias=False) or None
+
+    def forward(self, x):
+        if not self.equalInOut:
+            x = self.relu1(self.bn1(x))
+        else:
+            out = self.relu1(self.bn1(x))
+        out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
+        if self.droprate > 0:
+            out = F.dropout(out, p=self.droprate, training=self.training)
+        out = self.conv2(out)
+        return torch.add(x if self.equalInOut else self.convShortcut(x), out)
+
+class NetworkBlock(nn.Module):
+    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
+        super(NetworkBlock, self).__init__()
+        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)
+
+    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
+        layers = []
+        for i in range(int(nb_layers)):
+            layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate))
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.layer(x)
+
+class WideResNet(nn.Module):
+    #def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0):
+    def __init__(self, num_classes, wrn_size, depth=28, dropRate=0.0):
+        super(WideResNet, self).__init__()
+
+        kernel_size = wrn_size
+        filter_size = 3
+        nChannels = [min(kernel_size, 16), kernel_size, kernel_size * 2, kernel_size * 4]
+        strides = [1, 2, 2] # stride for each resblock
+
+        #nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
+        assert((depth - 4) % 6 == 0)
+        n = (depth - 4) / 6
+        block = BasicBlock
+        # 1st conv before any network block
+        self.conv1 = nn.Conv2d(filter_size, nChannels[0], kernel_size=3, stride=1,
+                               padding=1, bias=False)
+        # 1st block
+        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, strides[0], dropRate)
+        # 2nd block
+        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, strides[1], dropRate)
+        # 3rd block
+        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, strides[2], dropRate)
+        # global average pooling and classifier
+        self.bn1 = nn.BatchNorm2d(nChannels[3])
+        self.relu = nn.ReLU(inplace=True)
+        self.fc = nn.Linear(nChannels[3], num_classes)
+        self.nChannels = nChannels[3]
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+            elif isinstance(m, nn.Linear):
+                m.bias.data.zero_()
+
+    def forward(self, x):
+        out = self.conv1(x)
+        out = self.block1(out)
+        out = self.block2(out)
+        out = self.block3(out)
+        out = self.relu(self.bn1(out))
+        out = F.avg_pool2d(out, 8)
+        out = out.view(-1, self.nChannels)
+        return self.fc(out)
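
Note (not part of the commit): a minimal smoke test for the WideResNet added above, assuming 32x32 RGB inputs and the constructor signature shown in the hunk. The first conv takes 3 input channels (filter_size = 3) and F.avg_pool2d(out, 8) expects an 8x8 feature map after the two stride-2 blocks, so 32x32 images fit; with depth=28, (depth - 4) / 6 gives 4 BasicBlocks per NetworkBlock. wrn_size=32 is an arbitrary example value.

import torch

# Sketch of a shape check for the new class; assumes WideResNet from the file
# modified above is importable in the current scope.
model = WideResNet(num_classes=10, wrn_size=32, depth=28, dropRate=0.0)
x = torch.randn(4, 3, 32, 32)   # CIFAR-10 sized batch: 4 RGB images of 32x32
logits = model(x)               # conv1 -> block1/2/3 -> BN+ReLU -> avg_pool2d(8) -> fc
print(logits.shape)             # expected: torch.Size([4, 10])
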
@@ -783,8 +783,8 @@ if __name__ == "__main__":
     #### TF number tests ####
     #'''
     res_folder="res/TF_nb_tests/"
-    epochs= 200
-    inner_its = [0, 10]
+    epochs= 100
+    inner_its = [10]
     dataug_epoch_starts= [0]
     TF_nb = [len(TF.TF_dict)] #range(1,len(TF.TF_dict)+1)
     N_seq_TF= [1, 2, 3, 4]
@@ -808,7 +808,7 @@ if __name__ == "__main__":
                 aug_model = Augmented_model(Data_augV4(TF_dict=ntf_dict, N_TF=n_tf, mix_dist=0.0), LeNet(3,10)).to(device)
                 print(str(aug_model), 'on', device_name)
                 #run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
-                log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, loss_patience=10)
+                log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, loss_patience=None)
 
                 ####
                 plot_res(log, fig_name=res_folder+"{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter))
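
Note (not part of the commit): a hedged sketch of how the new WideResNet could replace the LeNet backbone for the PBA comparison the commit message mentions. It reuses only names visible in the hunk above (Augmented_model, Data_augV4, run_dist_dataugV2, ntf_dict, n_tf, device, epochs, n_inner_iter, dataug_epoch_start), whose signatures are assumed to be exactly as shown; wrn_size=32 is an arbitrary example value.

# Hypothetical backbone swap inside the same wrapper used above.
aug_model = Augmented_model(
    Data_augV4(TF_dict=ntf_dict, N_TF=n_tf, mix_dist=0.0),
    WideResNet(num_classes=10, wrn_size=32)
).to(device)
log = run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter,
                        dataug_epoch_start=dataug_epoch_start, print_freq=10,
                        loss_patience=None)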