mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Ajout WideResNet (A tester) pour comparaison a PBA
This commit is contained in:
parent
3cffac9852
commit
96bb7d5002
2 changed files with 93 additions and 4 deletions
|
@ -3,6 +3,7 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
## Basic CNN ##
|
||||||
class LeNet(nn.Module):
|
class LeNet(nn.Module):
|
||||||
def __init__(self, num_inp, num_out):
|
def __init__(self, num_inp, num_out):
|
||||||
super(LeNet, self).__init__()
|
super(LeNet, self).__init__()
|
||||||
|
@ -49,3 +50,91 @@ class LeNet(nn.Module):
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return "LeNet"
|
return "LeNet"
|
||||||
|
|
||||||
|
## Wide ResNet ##
|
||||||
|
#https://github.com/xternalz/WideResNet-pytorch/blob/master/wideresnet.py
|
||||||
|
#https://github.com/arcelien/pba/blob/master/pba/wrn.py
|
||||||
|
class BasicBlock(nn.Module):
|
||||||
|
def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
|
||||||
|
super(BasicBlock, self).__init__()
|
||||||
|
self.bn1 = nn.BatchNorm2d(in_planes)
|
||||||
|
self.relu1 = nn.ReLU(inplace=True)
|
||||||
|
self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
||||||
|
padding=1, bias=False)
|
||||||
|
self.bn2 = nn.BatchNorm2d(out_planes)
|
||||||
|
self.relu2 = nn.ReLU(inplace=True)
|
||||||
|
self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
|
||||||
|
padding=1, bias=False)
|
||||||
|
self.droprate = dropRate
|
||||||
|
self.equalInOut = (in_planes == out_planes)
|
||||||
|
self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
|
||||||
|
padding=0, bias=False) or None
|
||||||
|
def forward(self, x):
|
||||||
|
if not self.equalInOut:
|
||||||
|
x = self.relu1(self.bn1(x))
|
||||||
|
else:
|
||||||
|
out = self.relu1(self.bn1(x))
|
||||||
|
out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
|
||||||
|
if self.droprate > 0:
|
||||||
|
out = F.dropout(out, p=self.droprate, training=self.training)
|
||||||
|
out = self.conv2(out)
|
||||||
|
return torch.add(x if self.equalInOut else self.convShortcut(x), out)
|
||||||
|
|
||||||
|
class NetworkBlock(nn.Module):
|
||||||
|
def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
|
||||||
|
super(NetworkBlock, self).__init__()
|
||||||
|
self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)
|
||||||
|
def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
|
||||||
|
layers = []
|
||||||
|
for i in range(int(nb_layers)):
|
||||||
|
layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate))
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
def forward(self, x):
|
||||||
|
return self.layer(x)
|
||||||
|
|
||||||
|
class WideResNet(nn.Module):
|
||||||
|
#def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0):
|
||||||
|
def __init__(self, num_classes, wrn_size, depth=28, dropRate=0.0):
|
||||||
|
super(WideResNet, self).__init__()
|
||||||
|
|
||||||
|
kernel_size = wrn_size
|
||||||
|
filter_size = 3
|
||||||
|
nChannels = [min(kernel_size, 16), kernel_size, kernel_size * 2, kernel_size * 4]
|
||||||
|
strides = [1, 2, 2] # stride for each resblock
|
||||||
|
|
||||||
|
#nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
|
||||||
|
assert((depth - 4) % 6 == 0)
|
||||||
|
n = (depth - 4) / 6
|
||||||
|
block = BasicBlock
|
||||||
|
# 1st conv before any network block
|
||||||
|
self.conv1 = nn.Conv2d(filter_size, nChannels[0], kernel_size=3, stride=1,
|
||||||
|
padding=1, bias=False)
|
||||||
|
# 1st block
|
||||||
|
self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, strides[0], dropRate)
|
||||||
|
# 2nd block
|
||||||
|
self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, strides[1], dropRate)
|
||||||
|
# 3rd block
|
||||||
|
self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, strides[2], dropRate)
|
||||||
|
# global average pooling and classifier
|
||||||
|
self.bn1 = nn.BatchNorm2d(nChannels[3])
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
self.fc = nn.Linear(nChannels[3], num_classes)
|
||||||
|
self.nChannels = nChannels[3]
|
||||||
|
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||||
|
elif isinstance(m, nn.BatchNorm2d):
|
||||||
|
m.weight.data.fill_(1)
|
||||||
|
m.bias.data.zero_()
|
||||||
|
elif isinstance(m, nn.Linear):
|
||||||
|
m.bias.data.zero_()
|
||||||
|
def forward(self, x):
|
||||||
|
out = self.conv1(x)
|
||||||
|
out = self.block1(out)
|
||||||
|
out = self.block2(out)
|
||||||
|
out = self.block3(out)
|
||||||
|
out = self.relu(self.bn1(out))
|
||||||
|
out = F.avg_pool2d(out, 8)
|
||||||
|
out = out.view(-1, self.nChannels)
|
||||||
|
return self.fc(out)
|
|
@ -783,8 +783,8 @@ if __name__ == "__main__":
|
||||||
#### TF number tests ####
|
#### TF number tests ####
|
||||||
#'''
|
#'''
|
||||||
res_folder="res/TF_nb_tests/"
|
res_folder="res/TF_nb_tests/"
|
||||||
epochs= 200
|
epochs= 100
|
||||||
inner_its = [0, 10]
|
inner_its = [10]
|
||||||
dataug_epoch_starts= [0]
|
dataug_epoch_starts= [0]
|
||||||
TF_nb = [len(TF.TF_dict)] #range(1,len(TF.TF_dict)+1)
|
TF_nb = [len(TF.TF_dict)] #range(1,len(TF.TF_dict)+1)
|
||||||
N_seq_TF= [1, 2, 3, 4]
|
N_seq_TF= [1, 2, 3, 4]
|
||||||
|
@ -808,7 +808,7 @@ if __name__ == "__main__":
|
||||||
aug_model = Augmented_model(Data_augV4(TF_dict=ntf_dict, N_TF=n_tf, mix_dist=0.0), LeNet(3,10)).to(device)
|
aug_model = Augmented_model(Data_augV4(TF_dict=ntf_dict, N_TF=n_tf, mix_dist=0.0), LeNet(3,10)).to(device)
|
||||||
print(str(aug_model), 'on', device_name)
|
print(str(aug_model), 'on', device_name)
|
||||||
#run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
|
#run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
|
||||||
log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, loss_patience=10)
|
log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, loss_patience=None)
|
||||||
|
|
||||||
####
|
####
|
||||||
plot_res(log, fig_name=res_folder+"{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter))
|
plot_res(log, fig_name=res_folder+"{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter))
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue