mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 20:20:46 +02:00
Changes since Teledyne
This commit is contained in:
parent
03ffd7fe05
commit
b89dac9084
185 changed files with 16668 additions and 484 deletions
56
higher/smart_aug/nets/LeNet.py
Normal file
56
higher/smart_aug/nets/LeNet.py
Normal file
|
@ -0,0 +1,56 @@
|
|||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
## Basic CNN ##
|
||||
class LeNet(nn.Module):
|
||||
"""Basic CNN.
|
||||
|
||||
"""
|
||||
def __init__(self, num_inp, num_out):
|
||||
"""Init LeNet.
|
||||
|
||||
"""
|
||||
super(LeNet, self).__init__()
|
||||
self.conv1 = nn.Conv2d(num_inp, 20, 5)
|
||||
self.pool = nn.MaxPool2d(2, 2)
|
||||
self.conv2 = nn.Conv2d(20, 50, 5)
|
||||
self.pool2 = nn.MaxPool2d(2, 2)
|
||||
#self.fc1 = nn.Linear(4*4*50, 500)
|
||||
self.fc1 = nn.Linear(5*5*50, 500)
|
||||
self.fc2 = nn.Linear(500, num_out)
|
||||
|
||||
def forward(self, x):
|
||||
"""Main method of LeNet
|
||||
|
||||
"""
|
||||
x = self.pool(F.relu(self.conv1(x)))
|
||||
x = self.pool2(F.relu(self.conv2(x)))
|
||||
x = x.view(x.size(0), -1)
|
||||
x = F.relu(self.fc1(x))
|
||||
x = self.fc2(x)
|
||||
return x
|
||||
|
||||
def __str__(self):
|
||||
""" Get name of model
|
||||
|
||||
"""
|
||||
return "LeNet"
|
||||
|
||||
#MNIST
|
||||
class MLPNet(nn.Module):
|
||||
def __init__(self):
|
||||
super(MLPNet, self).__init__()
|
||||
self.fc1 = nn.Linear(28*28, 500)
|
||||
self.fc2 = nn.Linear(500, 256)
|
||||
self.fc3 = nn.Linear(256, 10)
|
||||
def forward(self, x):
|
||||
x = x.view(-1, 28*28)
|
||||
x = F.relu(self.fc1(x))
|
||||
x = F.relu(self.fc2(x))
|
||||
x = self.fc3(x)
|
||||
return x
|
||||
|
||||
def name(self):
|
||||
return "MLP"
|
426
higher/smart_aug/nets/resnet_abn.py
Normal file
426
higher/smart_aug/nets/resnet_abn.py
Normal file
|
@ -0,0 +1,426 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision.models.utils import load_state_dict_from_url
|
||||
|
||||
|
||||
# __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
|
||||
# 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
|
||||
# 'wide_resnet50_2', 'wide_resnet101_2']
|
||||
|
||||
|
||||
# model_urls = {
|
||||
# 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
||||
# 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
|
||||
# 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
||||
# 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
|
||||
# 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
|
||||
# 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
|
||||
# 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
|
||||
# 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
|
||||
# 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
|
||||
# }
|
||||
|
||||
__all__ = ['ResNet_ABN', 'resnet18_ABN', 'resnet34_ABN', 'resnet50_ABN', 'resnet101_ABN',
|
||||
'resnet152_ABN', 'resnext50_32x4d_ABN', 'resnext101_32x8d_ABN',
|
||||
'wide_resnet50_2_ABN', 'wide_resnet101_2_ABN']
|
||||
|
||||
class aux_batchNorm(nn.Module):
|
||||
def __init__(self, norm_layer, nb_features):
|
||||
super(aux_batchNorm, self).__init__()
|
||||
self.mode='clean'
|
||||
self.bn=nn.ModuleDict({
|
||||
'clean': norm_layer(nb_features),
|
||||
'augmented': norm_layer(nb_features)
|
||||
})
|
||||
def forward(self, x):
|
||||
if self.mode is 'mixed':
|
||||
running_mean=(self.bn['clean'].running_mean+self.bn['augmented'].running_mean)/2
|
||||
running_var=(self.bn['clean'].running_var+self.bn['augmented'].running_var)/2
|
||||
return nn.functional.batch_norm(x, running_mean, running_var, self.bn['clean'].weight, self.bn['clean'].bias)
|
||||
return self.bn[self.mode](x)
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
||||
"""3x3 convolution with padding"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
||||
padding=dilation, groups=groups, bias=False, dilation=dilation)
|
||||
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1):
|
||||
"""1x1 convolution"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||
|
||||
|
||||
class BasicBlock_ABN(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
|
||||
base_width=64, dilation=1, norm_layer=None):
|
||||
super(BasicBlock_ABN, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
if groups != 1 or base_width != 64:
|
||||
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
|
||||
if dilation > 1:
|
||||
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
||||
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
||||
self.conv1 = conv3x3(inplanes, planes, stride)
|
||||
#self.bn1 = norm_layer(planes)
|
||||
self.bn1 = aux_batchNorm(norm_layer, planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = conv3x3(planes, planes)
|
||||
#self.bn2 = norm_layer(planes)
|
||||
self.bn2 = aux_batchNorm(norm_layer, planes)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck_ABN(nn.Module):
|
||||
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
|
||||
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
|
||||
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
|
||||
# This variant is also known as ResNet V1.5 and improves accuracy according to
|
||||
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
|
||||
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
|
||||
base_width=64, dilation=1, norm_layer=None):
|
||||
super(Bottleneck_ABN, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
width = int(planes * (base_width / 64.)) * groups
|
||||
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
|
||||
self.conv1 = conv1x1(inplanes, width)
|
||||
#self.bn1 = norm_layer(width)
|
||||
self.bn1 = aux_batchNorm(norm_layer, width)
|
||||
self.conv2 = conv3x3(width, width, stride, groups, dilation)
|
||||
# self.bn2 = norm_layer(width)
|
||||
self.bn2 = aux_batchNorm(norm_layer, width)
|
||||
self.conv3 = conv1x1(width, planes * self.expansion)
|
||||
# self.bn3 = norm_layer(planes * self.expansion)
|
||||
self.bn3 = aux_batchNorm(norm_layer, planes * self.expansion)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNet_ABN(nn.Module):
|
||||
|
||||
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
|
||||
groups=1, width_per_group=64, replace_stride_with_dilation=None,
|
||||
norm_layer=None):
|
||||
super(ResNet_ABN, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
self._norm_layer = norm_layer
|
||||
|
||||
self.inplanes = 64
|
||||
self.dilation = 1
|
||||
if replace_stride_with_dilation is None:
|
||||
# each element in the tuple indicates if we should replace
|
||||
# the 2x2 stride with a dilated convolution instead
|
||||
replace_stride_with_dilation = [False, False, False]
|
||||
if len(replace_stride_with_dilation) != 3:
|
||||
raise ValueError("replace_stride_with_dilation should be None "
|
||||
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
||||
self.groups = groups
|
||||
self.base_width = width_per_group
|
||||
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
|
||||
bias=False)
|
||||
#self.bn1 = norm_layer(self.inplanes)
|
||||
self.bn1 = aux_batchNorm(norm_layer, self.inplanes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.layer1 = self._make_layer(block, 64, layers[0])
|
||||
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
|
||||
dilate=replace_stride_with_dilation[0])
|
||||
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
|
||||
dilate=replace_stride_with_dilation[1])
|
||||
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
|
||||
dilate=replace_stride_with_dilation[2])
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
# Zero-initialize the last BN in each residual branch,
|
||||
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
|
||||
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
|
||||
if zero_init_residual:
|
||||
print('WARNING : zero_init_residual not implemented with ABN')
|
||||
# for m in self.modules():
|
||||
# if isinstance(m, Bottleneck):
|
||||
# nn.init.constant_(m.bn3.weight, 0)
|
||||
# elif isinstance(m, BasicBlock):
|
||||
# nn.init.constant_(m.bn2.weight, 0)
|
||||
|
||||
# Memoire des BN layers pas fonctinnel avec Higher
|
||||
# self.bn_layers=[]
|
||||
# for m in self.modules():
|
||||
# if isinstance(m, aux_batchNorm):
|
||||
# self.bn_layers.append(m)
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
|
||||
norm_layer = self._norm_layer
|
||||
downsample = None
|
||||
previous_dilation = self.dilation
|
||||
if dilate:
|
||||
self.dilation *= stride
|
||||
stride = 1
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
conv1x1(self.inplanes, planes * block.expansion, stride),
|
||||
#norm_layer(planes * block.expansion),
|
||||
aux_batchNorm(norm_layer, planes * block.expansion)
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
|
||||
self.base_width, previous_dilation, norm_layer))
|
||||
self.inplanes = planes * block.expansion
|
||||
for _ in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes, groups=self.groups,
|
||||
base_width=self.base_width, dilation=self.dilation,
|
||||
norm_layer=norm_layer))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def _forward_impl(self, x):
|
||||
# See note [TorchScript super()]
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
||||
x = self.avgpool(x)
|
||||
x = torch.flatten(x, 1)
|
||||
x = self.fc(x)
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
return self._forward_impl(x)
|
||||
|
||||
def set_mode(self, mode):
|
||||
# for bn in self.bn_layers:
|
||||
for m in self.modules():
|
||||
if isinstance(m, aux_batchNorm):
|
||||
m.mode=mode
|
||||
|
||||
|
||||
|
||||
# def _resnet(arch, block, layers, pretrained, progress, **kwargs):
|
||||
# model = ResNet(block, layers, **kwargs)
|
||||
# if pretrained:
|
||||
# state_dict = load_state_dict_from_url(model_urls[arch],
|
||||
# progress=progress)
|
||||
# model.load_state_dict(state_dict)
|
||||
# return model
|
||||
|
||||
|
||||
def resnet18_ABN(pretrained=False, progress=True, **kwargs):
|
||||
"""ResNet-18 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
# return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
|
||||
# **kwargs)
|
||||
if(pretrained):
|
||||
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
|
||||
return ResNet_ABN(BasicBlock_ABN, [2, 2, 2, 2], **kwargs)
|
||||
|
||||
|
||||
|
||||
def resnet34_ABN(pretrained=False, progress=True, **kwargs):
|
||||
"""ResNet-34 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
# return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
|
||||
# **kwargs)
|
||||
if(pretrained):
|
||||
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
|
||||
return ResNet_ABN(BasicBlock_ABN, [3, 4, 6, 3], **kwargs)
|
||||
|
||||
|
||||
|
||||
def resnet50_ABN(pretrained=False, progress=True, **kwargs):
|
||||
"""ResNet-50 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
# return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
|
||||
# **kwargs)
|
||||
if(pretrained):
|
||||
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
|
||||
return ResNet_ABN(Bottleneck_ABN, [3, 4, 6, 3], **kwargs)
|
||||
|
||||
|
||||
|
||||
def resnet101_ABN(pretrained=False, progress=True, **kwargs):
|
||||
"""ResNet-101 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
# return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
|
||||
# **kwargs)
|
||||
if(pretrained):
|
||||
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
|
||||
return ResNet_ABN(Bottleneck_ABN, [3, 4, 23, 3], **kwargs)
|
||||
|
||||
|
||||
|
||||
def resnet152_ABN(pretrained=False, progress=True, **kwargs):
|
||||
"""ResNet-152 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
# return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
|
||||
# **kwargs)
|
||||
if(pretrained):
|
||||
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
|
||||
return ResNet_ABN(Bottleneck_ABN, [3, 8, 36, 3], **kwargs)
|
||||
|
||||
|
||||
|
||||
def resnext50_32x4d_ABN(pretrained=False, progress=True, **kwargs):
|
||||
"""ResNeXt-50 32x4d model from
|
||||
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['groups'] = 32
|
||||
kwargs['width_per_group'] = 4
|
||||
# return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
|
||||
# pretrained, progress, **kwargs)
|
||||
if(pretrained):
|
||||
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
|
||||
return ResNet_ABN(Bottleneck_ABN, [3, 4, 6, 3], **kwargs)
|
||||
|
||||
|
||||
def resnext101_32x8d_ABN(pretrained=False, progress=True, **kwargs):
|
||||
"""ResNeXt-101 32x8d model from
|
||||
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['groups'] = 32
|
||||
kwargs['width_per_group'] = 8
|
||||
# return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
|
||||
# pretrained, progress, **kwargs)
|
||||
if(pretrained):
|
||||
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
|
||||
return ResNet_ABN(Bottleneck_ABN, [3, 4, 23, 3], **kwargs)
|
||||
|
||||
|
||||
|
||||
def wide_resnet50_2_ABN(pretrained=False, progress=True, **kwargs):
|
||||
r"""Wide ResNet-50-2 model from
|
||||
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
|
||||
|
||||
The model is the same as ResNet except for the bottleneck number of channels
|
||||
which is twice larger in every block. The number of channels in outer 1x1
|
||||
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
|
||||
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['width_per_group'] = 64 * 2
|
||||
# return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
|
||||
# pretrained, progress, **kwargs)
|
||||
if(pretrained):
|
||||
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
|
||||
return ResNet_ABN(Bottleneck_ABN, [3, 4, 6, 3], **kwargs)
|
||||
|
||||
|
||||
|
||||
def wide_resnet101_2_ABN(pretrained=False, progress=True, **kwargs):
|
||||
r"""Wide ResNet-101-2 model from
|
||||
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
|
||||
|
||||
The model is the same as ResNet except for the bottleneck number of channels
|
||||
which is twice larger in every block. The number of channels in outer 1x1
|
||||
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
|
||||
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['width_per_group'] = 64 * 2
|
||||
# return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
|
||||
# pretrained, progress, **kwargs)
|
||||
if(pretrained):
|
||||
print('WARNING: pretrained weight support not implemented for Auxiliary Batch Norm')
|
||||
return ResNet_ABN(Bottleneck_ABN, [3, 4, 23, 3], **kwargs)
|
618
higher/smart_aug/nets/resnet_deconv.py
Normal file
618
higher/smart_aug/nets/resnet_deconv.py
Normal file
|
@ -0,0 +1,618 @@
|
|||
'''ResNet in PyTorch.
|
||||
For Pre-activation ResNet, see 'preact_resnet.py'.
|
||||
Reference:
|
||||
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
|
||||
Deep Residual Learning for Image Recognition. arXiv:1512.03385
|
||||
|
||||
https://github.com/yechengxi/deconvolution
|
||||
'''
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
from torch.nn.modules import conv
|
||||
from torch.nn.modules.utils import _pair
|
||||
|
||||
from functools import partial
|
||||
|
||||
__all__ = ['ResNet18_DC', 'ResNet34_DC', 'ResNet50_DC', 'ResNet101_DC', 'ResNet152_DC', 'WRN_DC26_10']
|
||||
|
||||
### Deconvolution ###
|
||||
|
||||
#iteratively solve for inverse sqrt of a matrix
|
||||
def isqrt_newton_schulz_autograd(A, numIters):
|
||||
dim = A.shape[0]
|
||||
normA=A.norm()
|
||||
Y = A.div(normA)
|
||||
I = torch.eye(dim,dtype=A.dtype,device=A.device)
|
||||
Z = torch.eye(dim,dtype=A.dtype,device=A.device)
|
||||
|
||||
for i in range(numIters):
|
||||
T = 0.5*(3.0*I - Z@Y)
|
||||
Y = Y@T
|
||||
Z = T@Z
|
||||
#A_sqrt = Y*torch.sqrt(normA)
|
||||
A_isqrt = Z / torch.sqrt(normA)
|
||||
return A_isqrt
|
||||
|
||||
def isqrt_newton_schulz_autograd_batch(A, numIters):
|
||||
batchSize,dim,_ = A.shape
|
||||
normA=A.view(batchSize, -1).norm(2, 1).view(batchSize, 1, 1)
|
||||
Y = A.div(normA)
|
||||
I = torch.eye(dim,dtype=A.dtype,device=A.device).unsqueeze(0).expand_as(A)
|
||||
Z = torch.eye(dim,dtype=A.dtype,device=A.device).unsqueeze(0).expand_as(A)
|
||||
|
||||
for i in range(numIters):
|
||||
T = 0.5*(3.0*I - Z.bmm(Y))
|
||||
Y = Y.bmm(T)
|
||||
Z = T.bmm(Z)
|
||||
#A_sqrt = Y*torch.sqrt(normA)
|
||||
A_isqrt = Z / torch.sqrt(normA)
|
||||
|
||||
return A_isqrt
|
||||
|
||||
|
||||
|
||||
#deconvolve channels
|
||||
class ChannelDeconv(nn.Module):
|
||||
def __init__(self, block, eps=1e-2,n_iter=5,momentum=0.1,sampling_stride=3):
|
||||
super(ChannelDeconv, self).__init__()
|
||||
|
||||
self.eps = eps
|
||||
self.n_iter=n_iter
|
||||
self.momentum=momentum
|
||||
self.block = block
|
||||
|
||||
self.register_buffer('running_mean1', torch.zeros(block, 1))
|
||||
#self.register_buffer('running_cov', torch.eye(block))
|
||||
self.register_buffer('running_deconv', torch.eye(block))
|
||||
self.register_buffer('running_mean2', torch.zeros(1, 1))
|
||||
self.register_buffer('running_var', torch.ones(1, 1))
|
||||
self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
|
||||
self.sampling_stride=sampling_stride
|
||||
def forward(self, x):
|
||||
x_shape = x.shape
|
||||
if len(x.shape)==2:
|
||||
x=x.view(x.shape[0],x.shape[1],1,1)
|
||||
if len(x.shape)==3:
|
||||
print('Error! Unsupprted tensor shape.')
|
||||
|
||||
N, C, H, W = x.size()
|
||||
B = self.block
|
||||
|
||||
#take the first c channels out for deconv
|
||||
c=int(C/B)*B
|
||||
if c==0:
|
||||
print('Error! block should be set smaller.')
|
||||
|
||||
#step 1. remove mean
|
||||
if c!=C:
|
||||
x1=x[:,:c].permute(1,0,2,3).contiguous().view(B,-1)
|
||||
else:
|
||||
x1=x.permute(1,0,2,3).contiguous().view(B,-1)
|
||||
|
||||
if self.sampling_stride > 1 and H >= self.sampling_stride and W >= self.sampling_stride:
|
||||
x1_s = x1[:,::self.sampling_stride**2]
|
||||
else:
|
||||
x1_s=x1
|
||||
|
||||
mean1 = x1_s.mean(-1, keepdim=True)
|
||||
|
||||
if self.num_batches_tracked==0:
|
||||
self.running_mean1.copy_(mean1.detach())
|
||||
if self.training:
|
||||
self.running_mean1.mul_(1-self.momentum)
|
||||
self.running_mean1.add_(mean1.detach()*self.momentum)
|
||||
else:
|
||||
mean1 = self.running_mean1
|
||||
|
||||
x1=x1-mean1
|
||||
|
||||
#step 2. calculate deconv@x1 = cov^(-0.5)@x1
|
||||
if self.training:
|
||||
cov = x1_s @ x1_s.t() / x1_s.shape[1] + self.eps * torch.eye(B, dtype=x.dtype, device=x.device)
|
||||
deconv = isqrt_newton_schulz_autograd(cov, self.n_iter)
|
||||
|
||||
if self.num_batches_tracked==0:
|
||||
#self.running_cov.copy_(cov.detach())
|
||||
self.running_deconv.copy_(deconv.detach())
|
||||
|
||||
if self.training:
|
||||
#self.running_cov.mul_(1-self.momentum)
|
||||
#self.running_cov.add_(cov.detach()*self.momentum)
|
||||
self.running_deconv.mul_(1 - self.momentum)
|
||||
self.running_deconv.add_(deconv.detach() * self.momentum)
|
||||
else:
|
||||
# cov = self.running_cov
|
||||
deconv = self.running_deconv
|
||||
|
||||
x1 =deconv@x1
|
||||
|
||||
#reshape to N,c,J,W
|
||||
x1 = x1.view(c, N, H, W).contiguous().permute(1,0,2,3)
|
||||
|
||||
# normalize the remaining channels
|
||||
if c!=C:
|
||||
x_tmp=x[:, c:].view(N,-1)
|
||||
if self.sampling_stride > 1 and H>=self.sampling_stride and W>=self.sampling_stride:
|
||||
x_s = x_tmp[:, ::self.sampling_stride ** 2]
|
||||
else:
|
||||
x_s = x_tmp
|
||||
|
||||
mean2=x_s.mean()
|
||||
var=x_s.var()
|
||||
|
||||
if self.num_batches_tracked == 0:
|
||||
self.running_mean2.copy_(mean2.detach())
|
||||
self.running_var.copy_(var.detach())
|
||||
|
||||
if self.training:
|
||||
self.running_mean2.mul_(1 - self.momentum)
|
||||
self.running_mean2.add_(mean2.detach() * self.momentum)
|
||||
self.running_var.mul_(1 - self.momentum)
|
||||
self.running_var.add_(var.detach() * self.momentum)
|
||||
else:
|
||||
mean2 = self.running_mean2
|
||||
var = self.running_var
|
||||
|
||||
x_tmp = (x[:, c:] - mean2) / (var + self.eps).sqrt()
|
||||
x1 = torch.cat([x1, x_tmp], dim=1)
|
||||
|
||||
|
||||
if self.training:
|
||||
self.num_batches_tracked.add_(1)
|
||||
|
||||
if len(x_shape)==2:
|
||||
x1=x1.view(x_shape)
|
||||
return x1
|
||||
|
||||
#An alternative implementation
|
||||
class Delinear(nn.Module):
|
||||
__constants__ = ['bias', 'in_features', 'out_features']
|
||||
|
||||
def __init__(self, in_features, out_features, bias=True, eps=1e-5, n_iter=5, momentum=0.1, block=512):
|
||||
super(Delinear, self).__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.Tensor(out_features))
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
self.reset_parameters()
|
||||
|
||||
|
||||
|
||||
if block > in_features:
|
||||
block = in_features
|
||||
else:
|
||||
if in_features%block!=0:
|
||||
block=math.gcd(block,in_features)
|
||||
print('block size set to:', block)
|
||||
self.block = block
|
||||
self.momentum = momentum
|
||||
self.n_iter = n_iter
|
||||
self.eps = eps
|
||||
self.register_buffer('running_mean', torch.zeros(self.block))
|
||||
self.register_buffer('running_deconv', torch.eye(self.block))
|
||||
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
||||
if self.bias is not None:
|
||||
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
|
||||
bound = 1 / math.sqrt(fan_in)
|
||||
nn.init.uniform_(self.bias, -bound, bound)
|
||||
|
||||
def forward(self, input):
|
||||
|
||||
if self.training:
|
||||
|
||||
# 1. reshape
|
||||
X=input.view(-1, self.block)
|
||||
|
||||
# 2. subtract mean
|
||||
X_mean = X.mean(0)
|
||||
X = X - X_mean.unsqueeze(0)
|
||||
self.running_mean.mul_(1 - self.momentum)
|
||||
self.running_mean.add_(X_mean.detach() * self.momentum)
|
||||
|
||||
# 3. calculate COV, COV^(-0.5), then deconv
|
||||
# Cov = X.t() @ X / X.shape[0] + self.eps * torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
|
||||
Id = torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
|
||||
Cov = torch.addmm(self.eps, Id, 1. / X.shape[0], X.t(), X)
|
||||
deconv = isqrt_newton_schulz_autograd(Cov, self.n_iter)
|
||||
# track stats for evaluation
|
||||
self.running_deconv.mul_(1 - self.momentum)
|
||||
self.running_deconv.add_(deconv.detach() * self.momentum)
|
||||
|
||||
else:
|
||||
X_mean = self.running_mean
|
||||
deconv = self.running_deconv
|
||||
|
||||
w = self.weight.view(-1, self.block) @ deconv
|
||||
b = self.bias
|
||||
if self.bias is not None:
|
||||
b = b - (w @ (X_mean.unsqueeze(1))).view(self.weight.shape[0], -1).sum(1)
|
||||
w = w.view(self.weight.shape)
|
||||
return F.linear(input, w, b)
|
||||
|
||||
def extra_repr(self):
|
||||
return 'in_features={}, out_features={}, bias={}'.format(
|
||||
self.in_features, self.out_features, self.bias is not None
|
||||
)
|
||||
|
||||
|
||||
|
||||
class FastDeconv(conv._ConvNd):
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,groups=1,bias=True, eps=1e-5, n_iter=5, momentum=0.1, block=64, sampling_stride=3,freeze=False,freeze_iter=100):
|
||||
self.momentum = momentum
|
||||
self.n_iter = n_iter
|
||||
self.eps = eps
|
||||
self.counter=0
|
||||
self.track_running_stats=True
|
||||
super(FastDeconv, self).__init__(
|
||||
in_channels, out_channels, _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation),
|
||||
False, _pair(0), groups, bias, padding_mode='zeros')
|
||||
|
||||
if block > in_channels:
|
||||
block = in_channels
|
||||
else:
|
||||
if in_channels%block!=0:
|
||||
block=math.gcd(block,in_channels)
|
||||
|
||||
if groups>1:
|
||||
#grouped conv
|
||||
block=in_channels//groups
|
||||
|
||||
self.block=block
|
||||
|
||||
self.num_features = kernel_size**2 *block
|
||||
if groups==1:
|
||||
self.register_buffer('running_mean', torch.zeros(self.num_features))
|
||||
self.register_buffer('running_deconv', torch.eye(self.num_features))
|
||||
else:
|
||||
self.register_buffer('running_mean', torch.zeros(kernel_size ** 2 * in_channels))
|
||||
self.register_buffer('running_deconv', torch.eye(self.num_features).repeat(in_channels // block, 1, 1))
|
||||
|
||||
self.sampling_stride=sampling_stride*stride
|
||||
self.counter=0
|
||||
self.freeze_iter=freeze_iter
|
||||
self.freeze=freeze
|
||||
|
||||
def forward(self, x):
|
||||
N, C, H, W = x.shape
|
||||
B = self.block
|
||||
frozen=self.freeze and (self.counter>self.freeze_iter)
|
||||
if self.training and self.track_running_stats:
|
||||
self.counter+=1
|
||||
self.counter %= (self.freeze_iter * 10)
|
||||
|
||||
if self.training and (not frozen):
|
||||
|
||||
# 1. im2col: N x cols x pixels -> N*pixles x cols
|
||||
if self.kernel_size[0]>1:
|
||||
X = torch.nn.functional.unfold(x, self.kernel_size,self.dilation,self.padding,self.sampling_stride).transpose(1, 2).contiguous()
|
||||
else:
|
||||
#channel wise
|
||||
X = x.permute(0, 2, 3, 1).contiguous().view(-1, C)[::self.sampling_stride**2,:]
|
||||
|
||||
if self.groups==1:
|
||||
# (C//B*N*pixels,k*k*B)
|
||||
X = X.view(-1, self.num_features, C // B).transpose(1, 2).contiguous().view(-1, self.num_features)
|
||||
else:
|
||||
X=X.view(-1,X.shape[-1])
|
||||
|
||||
# 2. subtract mean
|
||||
X_mean = X.mean(0)
|
||||
X = X - X_mean.unsqueeze(0)
|
||||
|
||||
# 3. calculate COV, COV^(-0.5), then deconv
|
||||
if self.groups==1:
|
||||
#Cov = X.t() @ X / X.shape[0] + self.eps * torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
|
||||
Id=torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
|
||||
Cov = torch.addmm(self.eps, Id, 1. / X.shape[0], X.t(), X)
|
||||
deconv = isqrt_newton_schulz_autograd(Cov, self.n_iter)
|
||||
else:
|
||||
X = X.view(-1, self.groups, self.num_features).transpose(0, 1)
|
||||
Id = torch.eye(self.num_features, dtype=X.dtype, device=X.device).expand(self.groups, self.num_features, self.num_features)
|
||||
Cov = torch.baddbmm(self.eps, Id, 1. / X.shape[1], X.transpose(1, 2), X)
|
||||
|
||||
deconv = isqrt_newton_schulz_autograd_batch(Cov, self.n_iter)
|
||||
|
||||
if self.track_running_stats:
|
||||
self.running_mean.mul_(1 - self.momentum)
|
||||
self.running_mean.add_(X_mean.detach() * self.momentum)
|
||||
# track stats for evaluation
|
||||
self.running_deconv.mul_(1 - self.momentum)
|
||||
self.running_deconv.add_(deconv.detach() * self.momentum)
|
||||
|
||||
else:
|
||||
X_mean = self.running_mean
|
||||
deconv = self.running_deconv
|
||||
|
||||
#4. X * deconv * conv = X * (deconv * conv)
|
||||
if self.groups==1:
|
||||
w = self.weight.view(-1, self.num_features, C // B).transpose(1, 2).contiguous().view(-1,self.num_features) @ deconv
|
||||
b = self.bias - (w @ (X_mean.unsqueeze(1))).view(self.weight.shape[0], -1).sum(1)
|
||||
w = w.view(-1, C // B, self.num_features).transpose(1, 2).contiguous()
|
||||
else:
|
||||
w = self.weight.view(C//B, -1,self.num_features)@deconv
|
||||
b = self.bias - (w @ (X_mean.view( -1,self.num_features,1))).view(self.bias.shape)
|
||||
|
||||
w = w.view(self.weight.shape)
|
||||
x= F.conv2d(x, w, b, self.stride, self.padding, self.dilation, self.groups)
|
||||
|
||||
return x
|
||||
|
||||
### ResNet
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, in_planes, planes, stride=1, deconv=None):
|
||||
super(BasicBlock, self).__init__()
|
||||
if deconv:
|
||||
self.conv1 = deconv(in_planes, planes, kernel_size=3, stride=stride, padding=1)
|
||||
self.conv2 = deconv(planes, planes, kernel_size=3, stride=1, padding=1)
|
||||
self.deconv = True
|
||||
else:
|
||||
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.deconv = False
|
||||
|
||||
|
||||
self.shortcut = nn.Sequential()
|
||||
|
||||
if not deconv:
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
#self.bn1 = nn.GroupNorm(planes//16,planes)
|
||||
#self.bn2 = nn.GroupNorm(planes//16,planes)
|
||||
|
||||
if stride != 1 or in_planes != self.expansion*planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(self.expansion*planes)
|
||||
#nn.GroupNorm(self.expansion * planes//16,self.expansion * planes)
|
||||
)
|
||||
else:
|
||||
if stride != 1 or in_planes != self.expansion*planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
deconv(in_planes, self.expansion*planes, kernel_size=1, stride=stride)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
if self.deconv:
|
||||
out = F.relu(self.conv1(x))
|
||||
out = self.conv2(out)
|
||||
out += self.shortcut(x)
|
||||
out = F.relu(out)
|
||||
return out
|
||||
|
||||
else: #self.batch_norm:
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out = self.bn2(self.conv2(out))
|
||||
out += self.shortcut(x)
|
||||
out = F.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, in_planes, planes, stride=1, deconv=None):
|
||||
super(Bottleneck, self).__init__()
|
||||
|
||||
|
||||
if deconv:
|
||||
self.deconv = True
|
||||
self.conv1 = deconv(in_planes, planes, kernel_size=1)
|
||||
self.conv2 = deconv(planes, planes, kernel_size=3, stride=stride, padding=1)
|
||||
self.conv3 = deconv(planes, self.expansion*planes, kernel_size=1)
|
||||
|
||||
else:
|
||||
self.deconv = False
|
||||
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.bn3 = nn.BatchNorm2d(self.expansion * planes)
|
||||
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
|
||||
|
||||
self.shortcut = nn.Sequential()
|
||||
|
||||
if not deconv:
|
||||
if stride != 1 or in_planes != self.expansion * planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(self.expansion * planes)
|
||||
)
|
||||
else:
|
||||
if stride != 1 or in_planes != self.expansion * planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
deconv(in_planes, self.expansion * planes, kernel_size=1, stride=stride)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
"""
|
||||
No batch normalization for deconv.
|
||||
"""
|
||||
if self.deconv:
|
||||
out = F.relu((self.conv1(x)))
|
||||
out = F.relu((self.conv2(out)))
|
||||
out = self.conv3(out)
|
||||
out += self.shortcut(x)
|
||||
out = F.relu(out)
|
||||
return out
|
||||
else:
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out = F.relu(self.bn2(self.conv2(out)))
|
||||
out = self.bn3(self.conv3(out))
|
||||
out += self.shortcut(x)
|
||||
out = F.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class ResNet(nn.Module):
|
||||
def __init__(self, block, num_blocks, num_classes=10, deconv=None,channel_deconv=None):
|
||||
super(ResNet, self).__init__()
|
||||
self.in_planes = 64
|
||||
|
||||
if deconv:
|
||||
self.deconv = True
|
||||
self.conv1 = deconv(3, 64, kernel_size=3, stride=1, padding=1)
|
||||
else:
|
||||
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
|
||||
|
||||
if not deconv:
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
|
||||
#this line is really recent, take extreme care if the result is not good.
|
||||
if channel_deconv:
|
||||
self.deconv1=channel_deconv()
|
||||
|
||||
self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1, deconv=deconv)
|
||||
self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2, deconv=deconv)
|
||||
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2, deconv=deconv)
|
||||
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2, deconv=deconv)
|
||||
self.linear = nn.Linear(512*block.expansion, num_classes)
|
||||
|
||||
def _make_layer(self, block, planes, num_blocks, stride, deconv):
|
||||
strides = [stride] + [1]*(num_blocks-1)
|
||||
layers = []
|
||||
for stride in strides:
|
||||
layers.append(block(self.in_planes, planes, stride, deconv))
|
||||
self.in_planes = planes * block.expansion
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
if hasattr(self,'bn1'):
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
else:
|
||||
out = F.relu(self.conv1(x))
|
||||
|
||||
out = self.layer1(out)
|
||||
out = self.layer2(out)
|
||||
out = self.layer3(out)
|
||||
out = self.layer4(out)
|
||||
if hasattr(self, 'deconv1'):
|
||||
out = self.deconv1(out)
|
||||
out = F.avg_pool2d(out, 4)
|
||||
out = out.view(out.size(0), -1)
|
||||
out = self.linear(out)
|
||||
return out
|
||||
|
||||
|
||||
def_deconv = partial(FastDeconv,bias=True, eps=1e-5, n_iter=5,block=64,sampling_stride=3)
|
||||
#channel_deconv=partial(ChannelDeconv, block=512,eps=1e-5, n_iter=5,sampling_stride=3) #Pas forcément conseillé
|
||||
|
||||
def ResNet18_DC(num_classes,deconv=def_deconv,channel_deconv=None):
|
||||
return ResNet(BasicBlock, [2,2,2,2],num_classes=num_classes, deconv=deconv,channel_deconv=channel_deconv)
|
||||
|
||||
def ResNet34_DC(num_classes,deconv=def_deconv,channel_deconv=None):
|
||||
return ResNet(BasicBlock, [3,4,6,3], num_classes=num_classes, deconv=deconv,channel_deconv=channel_deconv)
|
||||
|
||||
def ResNet50_DC(num_classes,deconv=def_deconv,channel_deconv=None):
|
||||
return ResNet(Bottleneck, [3,4,6,3], num_classes=num_classes, deconv=deconv,channel_deconv=channel_deconv)
|
||||
|
||||
def ResNet101_DC(num_classes,deconv=def_deconv,channel_deconv=None):
|
||||
return ResNet(Bottleneck, [3,4,23,3], num_classes=num_classes, deconv=deconv,channel_deconv=channel_deconv)
|
||||
|
||||
def ResNet152_DC(num_classes,deconv=def_deconv,channel_deconv=None):
|
||||
return ResNet(Bottleneck, [3,8,36,3], num_classes=num_classes, deconv=deconv,channel_deconv=channel_deconv)
|
||||
|
||||
import math
|
||||
class Wide_ResNet_Cifar_DC(nn.Module):
|
||||
|
||||
def __init__(self, block, layers, wfactor, num_classes=10, deconv=None, channel_deconv=None):
|
||||
super(Wide_ResNet_Cifar_DC, self).__init__()
|
||||
self.depth=layers[0]*6+2
|
||||
self.widen_factor=wfactor
|
||||
|
||||
self.inplanes = 16
|
||||
self.conv1 = deconv(3, 16, kernel_size=3, stride=1, padding=1)
|
||||
if channel_deconv:
|
||||
self.deconv1=channel_deconv()
|
||||
# self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
# self.bn1 = nn.BatchNorm2d(16)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.layer1 = self._make_layer(block, 16*wfactor, layers[0], stride=1, deconv=deconv)
|
||||
self.layer2 = self._make_layer(block, 32*wfactor, layers[1], stride=2, deconv=deconv)
|
||||
self.layer3 = self._make_layer(block, 64*wfactor, layers[2], stride=2, deconv=deconv)
|
||||
self.avgpool = nn.AvgPool2d(8, stride=1)
|
||||
self.fc = nn.Linear(64*block.expansion*wfactor, num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
m.weight.data.normal_(0, math.sqrt(2. / n))
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
m.weight.data.fill_(1)
|
||||
m.bias.data.zero_()
|
||||
|
||||
def _make_layer(self, block, planes, num_blocks, stride, deconv):
|
||||
# downsample = None
|
||||
# if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
# downsample = nn.Sequential(
|
||||
# nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
|
||||
# nn.BatchNorm2d(planes * block.expansion)
|
||||
# )
|
||||
|
||||
# layers = []
|
||||
# layers.append(block(self.inplanes, planes, stride, downsample))
|
||||
# self.inplanes = planes * block.expansion
|
||||
# for _ in range(1, blocks):
|
||||
# layers.append(block(self.inplanes, planes))
|
||||
|
||||
# return nn.Sequential(*layers)
|
||||
strides = [stride] + [1]*(num_blocks-1)
|
||||
layers = []
|
||||
for stride in strides:
|
||||
layers.append(block(self.inplanes, planes, stride, deconv))
|
||||
self.inplanes = planes * block.expansion
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
# x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
|
||||
if hasattr(self, 'deconv1'):
|
||||
out = self.deconv1(out)
|
||||
|
||||
x = self.avgpool(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.fc(x)
|
||||
|
||||
return x
|
||||
|
||||
def __str__(self):
|
||||
""" Get name of model
|
||||
|
||||
"""
|
||||
return "Wide_ResNet_cifar_DC%d_%d"%(self.depth,self.widen_factor)
|
||||
|
||||
def WRN_DC26_10(depth=26, width=10, deconv=def_deconv, channel_deconv=None, **kwargs):
|
||||
assert (depth - 2) % 6 == 0
|
||||
n = int((depth - 2) / 6)
|
||||
return Wide_ResNet_Cifar_DC(BasicBlock, [n, n, n], width, deconv=deconv,channel_deconv=channel_deconv, **kwargs)
|
||||
|
||||
def test():
|
||||
net = ResNet18_DC()
|
||||
y = net(torch.randn(1,3,32,32))
|
||||
print(y.size())
|
||||
|
||||
# test()
|
98
higher/smart_aug/nets/wideresnet.py
Normal file
98
higher/smart_aug/nets/wideresnet.py
Normal file
|
@ -0,0 +1,98 @@
|
|||
import torch.nn as nn
|
||||
import torch.nn.init as init
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
|
||||
_bn_momentum = 0.1
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1):
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True)
|
||||
|
||||
|
||||
def conv_init(m):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find('Conv') != -1:
|
||||
init.xavier_uniform_(m.weight, gain=np.sqrt(2))
|
||||
init.constant_(m.bias, 0)
|
||||
elif classname.find('BatchNorm') != -1:
|
||||
init.constant_(m.weight, 1)
|
||||
init.constant_(m.bias, 0)
|
||||
|
||||
|
||||
class WideBasic(nn.Module):
|
||||
def __init__(self, in_planes, planes, dropout_rate, stride=1):
|
||||
super(WideBasic, self).__init__()
|
||||
assert dropout_rate==0.0, 'dropout layer not used'
|
||||
self.bn1 = nn.BatchNorm2d(in_planes, momentum=_bn_momentum)
|
||||
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True)
|
||||
#self.dropout = nn.Dropout(p=dropout_rate)
|
||||
self.bn2 = nn.BatchNorm2d(planes, momentum=_bn_momentum)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)
|
||||
|
||||
self.shortcut = nn.Sequential()
|
||||
if stride != 1 or in_planes != planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
# out = self.dropout(self.conv1(F.relu(self.bn1(x))))
|
||||
out = self.conv1(F.relu(self.bn1(x)))
|
||||
out = self.conv2(F.relu(self.bn2(out)))
|
||||
out += self.shortcut(x)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class WideResNet(nn.Module):
|
||||
def __init__(self, depth, widen_factor, dropout_rate, num_classes):
|
||||
super(WideResNet, self).__init__()
|
||||
self.depth=depth
|
||||
self.widen_factor=widen_factor
|
||||
self.in_planes = 16
|
||||
|
||||
assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4'
|
||||
n = int((depth - 4) / 6)
|
||||
k = widen_factor
|
||||
|
||||
nStages = [16, 16*k, 32*k, 64*k]
|
||||
|
||||
self.conv1 = conv3x3(3, nStages[0])
|
||||
self.layer1 = self._wide_layer(WideBasic, nStages[1], n, dropout_rate, stride=1)
|
||||
self.layer2 = self._wide_layer(WideBasic, nStages[2], n, dropout_rate, stride=2)
|
||||
self.layer3 = self._wide_layer(WideBasic, nStages[3], n, dropout_rate, stride=2)
|
||||
self.bn1 = nn.BatchNorm2d(nStages[3], momentum=_bn_momentum)
|
||||
self.linear = nn.Linear(nStages[3], num_classes)
|
||||
|
||||
# self.apply(conv_init)
|
||||
|
||||
def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
|
||||
strides = [stride] + [1]*(num_blocks-1)
|
||||
layers = []
|
||||
|
||||
for stride in strides:
|
||||
layers.append(block(self.in_planes, planes, dropout_rate, stride))
|
||||
self.in_planes = planes
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv1(x)
|
||||
out = self.layer1(out)
|
||||
out = self.layer2(out)
|
||||
out = self.layer3(out)
|
||||
out = F.relu(self.bn1(out))
|
||||
# out = F.avg_pool2d(out, 8)
|
||||
out = F.adaptive_avg_pool2d(out, (1, 1))
|
||||
out = out.view(out.size(0), -1)
|
||||
out = self.linear(out)
|
||||
|
||||
return out
|
||||
|
||||
def __str__(self):
|
||||
""" Get name of model
|
||||
|
||||
"""
|
||||
return "Wide_ResNet%d_%d"%(self.depth,self.widen_factor)
|
119
higher/smart_aug/nets/wideresnet_cifar.py
Normal file
119
higher/smart_aug/nets/wideresnet_cifar.py
Normal file
|
@ -0,0 +1,119 @@
|
|||
"""
|
||||
wide resnet for cifar in pytorch
|
||||
Reference:
|
||||
[1] S. Zagoruyko and N. Komodakis. Wide residual networks. In BMVC, 2016.
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import math
|
||||
#from models.resnet_cifar import BasicBlock
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1):
|
||||
" 3x3 convolution with padding "
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion=1
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.conv1 = conv3x3(inplanes, planes, stride)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = conv3x3(planes, planes)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
class Wide_ResNet_Cifar(nn.Module):
|
||||
|
||||
def __init__(self, block, layers, wfactor, num_classes=10):
|
||||
super(Wide_ResNet_Cifar, self).__init__()
|
||||
self.depth=layers[0]*6+2
|
||||
self.widen_factor=wfactor
|
||||
|
||||
self.inplanes = 16
|
||||
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(16)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.layer1 = self._make_layer(block, 16*wfactor, layers[0])
|
||||
self.layer2 = self._make_layer(block, 32*wfactor, layers[1], stride=2)
|
||||
self.layer3 = self._make_layer(block, 64*wfactor, layers[2], stride=2)
|
||||
self.avgpool = nn.AvgPool2d(8, stride=1)
|
||||
self.fc = nn.Linear(64*block.expansion*wfactor, num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
m.weight.data.normal_(0, math.sqrt(2. / n))
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
m.weight.data.fill_(1)
|
||||
m.bias.data.zero_()
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1):
|
||||
downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(planes * block.expansion)
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, downsample))
|
||||
self.inplanes = planes * block.expansion
|
||||
for _ in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
|
||||
x = self.avgpool(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.fc(x)
|
||||
|
||||
return x
|
||||
|
||||
def __str__(self):
|
||||
""" Get name of model
|
||||
|
||||
"""
|
||||
return "Wide_ResNet_cifar%d_%d"%(self.depth,self.widen_factor)
|
||||
|
||||
|
||||
def wide_resnet_cifar(depth, width, **kwargs):
|
||||
assert (depth - 2) % 6 == 0
|
||||
n = int((depth - 2) / 6)
|
||||
return Wide_ResNet_Cifar(BasicBlock, [n, n, n], width, **kwargs)
|
||||
|
||||
|
||||
if __name__=='__main__':
|
||||
net = wide_resnet_cifar(20, 10)
|
||||
y = net(torch.randn(1, 3, 32, 32))
|
||||
print(isinstance(net, Wide_ResNet_Cifar))
|
||||
print(y.size())
|
Loading…
Add table
Add a link
Reference in a new issue