# smart_augmentation/higher/smart_aug/nets/resnet_deconv.py
# 2024-08-20 11:53:35 +02:00
#
# 618 lines
# No EOL
# 22 KiB
# Python

'''ResNet in PyTorch.
For Pre-activation ResNet, see 'preact_resnet.py'.
Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
Deep Residual Learning for Image Recognition. arXiv:1512.03385
[2] Network deconvolution reference implementation:
    https://github.com/yechengxi/deconvolution
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules import conv
from torch.nn.modules.utils import _pair
from functools import partial
__all__ = ['ResNet18_DC', 'ResNet34_DC', 'ResNet50_DC', 'ResNet101_DC', 'ResNet152_DC', 'WRN_DC26_10']
### Deconvolution ###
def isqrt_newton_schulz_autograd(A, numIters):
    """Approximate the inverse square root ``A^(-1/2)`` of a square matrix.

    Runs the coupled Newton-Schulz iteration, which uses only matrix products,
    so the result stays differentiable with respect to ``A``.

    Args:
        A: square 2-D tensor ``(dim, dim)``. NOTE(review): the iteration is only
            expected to converge for symmetric positive-definite input (such as
            a ridge-regularised covariance); nothing here checks that.
        numIters: number of Newton-Schulz steps to run.

    Returns:
        Tensor with the shape, dtype and device of ``A`` approximating ``A^(-1/2)``.
    """
    scale = A.norm()
    size = A.shape[0]
    eye = torch.eye(size, dtype=A.dtype, device=A.device)
    # Start from A scaled by its Frobenius norm; the scale is undone at the end.
    y = A.div(scale)
    z = torch.eye(size, dtype=A.dtype, device=A.device)
    for _ in range(numIters):
        t = 0.5 * (3.0 * eye - z @ y)
        # Both updates use the same t (and the old y / z).
        y, z = y @ t, t @ z
    # z approximates (A / scale)^(-1/2), so A^(-1/2) = z / sqrt(scale).
    return z / torch.sqrt(scale)
def isqrt_newton_schulz_autograd_batch(A, numIters):
    """Batched coupled Newton-Schulz approximation of ``A[i]^(-1/2)``.

    Args:
        A: 3-D tensor ``(batch, dim, dim)`` of square matrices (unpacked as
            such, so any other rank raises). NOTE(review): convergence is only
            expected for symmetric positive-definite matrices.
        numIters: number of Newton-Schulz steps to run.

    Returns:
        Tensor shaped like ``A`` whose ``i``-th slice approximates ``A[i]^(-1/2)``.
    """
    batch, dim, _ = A.shape
    # Frobenius norm of every matrix, kept as (batch, 1, 1) for broadcasting.
    scale = A.view(batch, -1).norm(2, 1).view(batch, 1, 1)
    eye = torch.eye(dim, dtype=A.dtype, device=A.device).unsqueeze(0).expand_as(A)
    y = A.div(scale)
    z = torch.eye(dim, dtype=A.dtype, device=A.device).unsqueeze(0).expand_as(A)
    for _ in range(numIters):
        t = 0.5 * (3.0 * eye - z.bmm(y))
        y, z = y.bmm(t), t.bmm(z)
    # Undo the per-matrix normalisation.
    return z / torch.sqrt(scale)
# Deconvolve (whiten) channels.
class ChannelDeconv(nn.Module):
    """Whiten the channels of a feature map with running statistics.

    The first ``c = (C // block) * block`` channels are permuted to the front
    and viewed as a ``(block, -1)`` matrix. Row ``i`` holds the ``c // block``
    consecutive channels starting at ``i * (c // block)``, over all batch items
    and positions. The rows are centred and left-multiplied by the inverse
    square root of their ``block x block`` covariance. Any remaining
    ``C - c`` channels are standardised with a single scalar mean and variance.
    In eval mode the running estimates replace the batch statistics.

    Args:
        block: number of rows whitened jointly; must not exceed ``C``.
        eps: ridge added to the covariance diagonal and to the variance.
        n_iter: Newton-Schulz iterations for the inverse square root.
        momentum: weight of the current batch in the running statistics.
        sampling_stride: when ``H`` and ``W`` are both at least this value,
            statistics use every ``sampling_stride**2``-th column only.
    """

    def __init__(self, block, eps=1e-2, n_iter=5, momentum=0.1, sampling_stride=3):
        super(ChannelDeconv, self).__init__()
        self.eps = eps
        self.n_iter = n_iter
        self.momentum = momentum
        self.block = block
        self.register_buffer('running_mean1', torch.zeros(block, 1))
        self.register_buffer('running_deconv', torch.eye(block))
        self.register_buffer('running_mean2', torch.zeros(1, 1))
        self.register_buffer('running_var', torch.ones(1, 1))
        self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
        self.sampling_stride = sampling_stride

    def forward(self, x):
        """Whiten ``x``; the output has the same shape as the input.

        Args:
            x: ``(N, C)`` or ``(N, C, H, W)`` tensor.

        Raises:
            ValueError: for any other rank, or when ``block`` exceeds ``C``.
        """
        x_shape = x.shape
        if x.dim() == 2:
            # Treat (N, C) input as an (N, C, 1, 1) feature map.
            x = x.view(x.shape[0], x.shape[1], 1, 1)
        if x.dim() != 4:
            raise ValueError('ChannelDeconv expects a 2-D (N, C) or 4-D (N, C, H, W) '
                             'input, got shape {}'.format(tuple(x_shape)))
        N, C, H, W = x.size()
        B = self.block
        # Only the first c channels (the largest multiple of B) are whitened.
        c = (C // B) * B
        if c == 0:
            raise ValueError('ChannelDeconv block ({}) exceeds the number of channels '
                             '({}); set a smaller block.'.format(B, C))
        subsample = (self.sampling_stride > 1 and H >= self.sampling_stride
                     and W >= self.sampling_stride)
        stride_sq = self.sampling_stride ** 2

        # Step 1: remove the mean of each of the B rows.
        head = x[:, :c] if c != C else x
        x1 = head.permute(1, 0, 2, 3).contiguous().view(B, -1)
        x1_s = x1[:, ::stride_sq] if subsample else x1
        mean1 = x1_s.mean(-1, keepdim=True)
        if self.num_batches_tracked == 0:
            self.running_mean1.copy_(mean1.detach())
        if self.training:
            self.running_mean1.mul_(1 - self.momentum)
            self.running_mean1.add_(mean1.detach() * self.momentum)
        else:
            mean1 = self.running_mean1
        x1 = x1 - mean1

        # Step 2: x1 <- cov^(-1/2) @ x1.
        if self.training:
            # x1_s is a slice of the un-centred rows; centre it so that cov is a
            # covariance (matching step 1) rather than a raw second moment.
            x1_s = x1_s - mean1
            cov = x1_s @ x1_s.t() / x1_s.shape[1] + self.eps * torch.eye(B, dtype=x.dtype, device=x.device)
            deconv = isqrt_newton_schulz_autograd(cov, self.n_iter)
            if self.num_batches_tracked == 0:
                self.running_deconv.copy_(deconv.detach())
            self.running_deconv.mul_(1 - self.momentum)
            self.running_deconv.add_(deconv.detach() * self.momentum)
        else:
            deconv = self.running_deconv
        x1 = deconv @ x1
        # Reshape back to (N, c, H, W).
        x1 = x1.view(c, N, H, W).permute(1, 0, 2, 3)

        if c != C:
            # Step 3: standardise the remaining C - c channels with one scalar
            # mean and (unbiased) variance.
            tail = x[:, c:]
            x_tmp = tail.reshape(N, -1)
            x_s = x_tmp[:, ::stride_sq] if subsample else x_tmp
            mean2 = x_s.mean()
            var = x_s.var()
            if self.num_batches_tracked == 0:
                self.running_mean2.copy_(mean2.detach())
                self.running_var.copy_(var.detach())
            if self.training:
                self.running_mean2.mul_(1 - self.momentum)
                self.running_mean2.add_(mean2.detach() * self.momentum)
                self.running_var.mul_(1 - self.momentum)
                self.running_var.add_(var.detach() * self.momentum)
            else:
                mean2 = self.running_mean2
                var = self.running_var
            x_tmp = (tail - mean2) / (var + self.eps).sqrt()
            x1 = torch.cat([x1, x_tmp], dim=1)

        if self.training:
            self.num_batches_tracked.add_(1)
        if len(x_shape) == 2:
            x1 = x1.view(x_shape)
        return x1
# An alternative implementation: deconvolution folded into a linear layer.
class Delinear(nn.Module):
    """Linear layer whose input features are whitened ("deconvolved").

    Each input row is split into consecutive chunks of ``block`` features.
    In training, the chunks are pooled as samples to estimate a mean and a
    ``block x block`` covariance whose inverse square root is the whitening
    matrix. Running averages (``momentum``) of both are used in eval mode.
    The whitening is folded into the weights: each ``block``-sized chunk of a
    weight row is right-multiplied by the whitening matrix. When a bias
    exists, it absorbs the mean-subtraction term.

    NOTE(review): with ``bias=False`` the mean term is dropped, so inputs are
    whitened but not centred. Confirm this is intended.

    Args:
        in_features, out_features, bias: as for ``nn.Linear``.
        eps: ridge added to the covariance diagonal.
        n_iter: Newton-Schulz iterations for the inverse square root.
        momentum: weight of the current batch in the running statistics.
        block: chunk size. It is clipped to ``in_features``; when it does not
            divide ``in_features`` it becomes ``gcd(block, in_features)``.
    """
    __constants__ = ['bias', 'in_features', 'out_features']

    def __init__(self, in_features, out_features, bias=True, eps=1e-5, n_iter=5, momentum=0.1, block=512):
        super(Delinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
        if block > in_features:
            block = in_features
        elif in_features % block != 0:
            # gcd(block, in_features) divides in_features, so view(-1, block) works.
            block = math.gcd(block, in_features)
            print('block size set to:', block)
        self.block = block
        self.momentum = momentum
        self.n_iter = n_iter
        self.eps = eps
        self.register_buffer('running_mean', torch.zeros(self.block))
        self.register_buffer('running_deconv', torch.eye(self.block))

    def reset_parameters(self):
        """Initialise weight and bias as ``nn.Linear`` does (Kaiming-uniform / uniform)."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        """Apply the whitened linear map to ``input`` (last dim ``in_features``)."""
        if self.training:
            # 1. Every block-sized chunk of every input row is one sample.
            #    reshape (not view) so non-contiguous inputs are accepted.
            X = input.reshape(-1, self.block)
            # 2. Subtract the mean and track it for evaluation.
            X_mean = X.mean(0)
            X = X - X_mean.unsqueeze(0)
            self.running_mean.mul_(1 - self.momentum)
            self.running_mean.add_(X_mean.detach() * self.momentum)
            # 3. Cov = eps * I + X^T X / n, then Cov^(-1/2).
            Id = torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
            Cov = torch.addmm(Id, X.t(), X, beta=self.eps, alpha=1. / X.shape[0])
            deconv = isqrt_newton_schulz_autograd(Cov, self.n_iter)
            self.running_deconv.mul_(1 - self.momentum)
            self.running_deconv.add_(deconv.detach() * self.momentum)
        else:
            X_mean = self.running_mean
            deconv = self.running_deconv
        # 4. Fold the whitening into the weights and the centring into the bias.
        w = self.weight.view(-1, self.block) @ deconv
        b = self.bias
        if b is not None:
            b = b - (w @ X_mean.unsqueeze(1)).view(self.weight.shape[0], -1).sum(1)
        w = w.view(self.weight.shape)
        return F.linear(input, w, b)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )
class FastDeconv(conv._ConvNd):
    """2-D convolution with network deconvolution (patch whitening) folded in.

    In training, im2col patches of the input are sampled, centred and
    whitened with the inverse square root of their covariance (estimated
    per block of input channels). The whitening and centring are folded into
    the convolution weight and bias, so a single ``F.conv2d`` runs. Running
    averages (``momentum``) of the mean and whitening matrix are used in eval
    mode, and while frozen.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias: as for ``nn.Conv2d``. A missing bias is treated as zero
            when the mean-correction term is added.
        eps: ridge added to the covariance diagonal.
        n_iter: Newton-Schulz iterations for the inverse square root.
        momentum: weight of the current batch in the running statistics.
        block: channels whitened together. Clipped to ``in_channels``; becomes
            ``gcd(block, in_channels)`` if it does not divide it; forced to
            ``in_channels // groups`` for grouped convolutions.
        sampling_stride: statistics use patches taken with stride
            ``sampling_stride * stride`` (every ``(sampling_stride*stride)**2``-th
            pixel for 1x1 kernels). NOTE(review): assumes an int ``stride``.
        freeze, freeze_iter: when ``freeze`` is true, batch statistics stop
            being recomputed once the internal training-step counter exceeds
            ``freeze_iter``; the counter wraps every ``10 * freeze_iter`` steps.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, eps=1e-5, n_iter=5, momentum=0.1, block=64, sampling_stride=3, freeze=False, freeze_iter=100):
        self.momentum = momentum
        self.n_iter = n_iter
        self.eps = eps
        self.counter = 0
        self.track_running_stats = True
        super(FastDeconv, self).__init__(
            in_channels, out_channels, _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation),
            False, _pair(0), groups, bias, padding_mode='zeros')
        if block > in_channels:
            block = in_channels
        elif in_channels % block != 0:
            block = math.gcd(block, in_channels)
        if groups > 1:
            # Grouped conv: whiten each group's input channels on their own.
            block = in_channels // groups
        self.block = block
        # Use the normalised (kh, kw) pair so tuple kernel sizes work too.
        window = self.kernel_size[0] * self.kernel_size[1]
        # im2col columns per block: block channels x kernel window.
        self.num_features = window * block
        if groups == 1:
            self.register_buffer('running_mean', torch.zeros(self.num_features))
            self.register_buffer('running_deconv', torch.eye(self.num_features))
        else:
            self.register_buffer('running_mean', torch.zeros(window * in_channels))
            self.register_buffer('running_deconv', torch.eye(self.num_features).repeat(in_channels // block, 1, 1))
        self.sampling_stride = sampling_stride * stride
        self.freeze_iter = freeze_iter
        self.freeze = freeze

    def forward(self, x):
        """Convolve ``x`` (``(N, C, H, W)``) with the whitened kernel."""
        N, C, H, W = x.shape
        B = self.block
        frozen = self.freeze and (self.counter > self.freeze_iter)
        if self.training and self.track_running_stats:
            self.counter += 1
            self.counter %= (self.freeze_iter * 10)
        if self.training and (not frozen):
            # 1. im2col: N x cols x pixels -> N*pixels x cols (sub-sampled).
            if self.kernel_size != (1, 1):
                X = torch.nn.functional.unfold(x, self.kernel_size, self.dilation, self.padding, self.sampling_stride).transpose(1, 2).contiguous()
            else:
                # 1x1 kernel: every sampled pixel is a C-dim sample.
                X = x.permute(0, 2, 3, 1).contiguous().view(-1, C)[::self.sampling_stride ** 2, :]
            if self.groups == 1:
                # Regroup the C*kh*kw columns into C // B interleaved sets of
                # num_features columns (mirrors the weight reshaping in step 4):
                # (C//B * N * pixels, kh*kw*B)
                X = X.view(-1, self.num_features, C // B).transpose(1, 2).contiguous().view(-1, self.num_features)
            else:
                X = X.view(-1, X.shape[-1])
            # 2. Subtract the mean.
            X_mean = X.mean(0)
            X = X - X_mean.unsqueeze(0)
            # 3. Cov = eps * I + X^T X / n, then Cov^(-1/2).
            if self.groups == 1:
                Id = torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
                Cov = torch.addmm(Id, X.t(), X, beta=self.eps, alpha=1. / X.shape[0])
                deconv = isqrt_newton_schulz_autograd(Cov, self.n_iter)
            else:
                X = X.view(-1, self.groups, self.num_features).transpose(0, 1)
                Id = torch.eye(self.num_features, dtype=X.dtype, device=X.device).expand(self.groups, self.num_features, self.num_features)
                Cov = torch.baddbmm(Id, X.transpose(1, 2), X, beta=self.eps, alpha=1. / X.shape[1])
                deconv = isqrt_newton_schulz_autograd_batch(Cov, self.n_iter)
            if self.track_running_stats:
                self.running_mean.mul_(1 - self.momentum)
                self.running_mean.add_(X_mean.detach() * self.momentum)
                # Track stats for evaluation.
                self.running_deconv.mul_(1 - self.momentum)
                self.running_deconv.add_(deconv.detach() * self.momentum)
        else:
            X_mean = self.running_mean
            deconv = self.running_deconv
        # 4. X * deconv * conv = X * (deconv * conv); the mean term goes into the bias.
        # With bias=False the original subtracted from None and crashed.
        bias = self.bias if self.bias is not None else 0.
        if self.groups == 1:
            w = self.weight.view(-1, self.num_features, C // B).transpose(1, 2).contiguous().view(-1, self.num_features) @ deconv
            b = bias - (w @ (X_mean.unsqueeze(1))).view(self.weight.shape[0], -1).sum(1)
            w = w.view(-1, C // B, self.num_features).transpose(1, 2).contiguous()
        else:
            w = self.weight.view(C // B, -1, self.num_features) @ deconv
            b = bias - (w @ (X_mean.view(-1, self.num_features, 1))).view(-1)
        w = w.view(self.weight.shape)
        return F.conv2d(x, w, b, self.stride, self.padding, self.dilation, self.groups)
### ResNet
class BasicBlock(nn.Module):
    """Two-convolution residual block (ResNet-18/34 style).

    If ``deconv`` (a conv factory) is given, it builds every convolution,
    including the 1x1 projection shortcut, and the block has no BatchNorm.
    Otherwise it is the standard bias-free ``nn.Conv2d`` + ``BatchNorm2d`` block.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, deconv=None):
        super(BasicBlock, self).__init__()
        use_deconv = bool(deconv)
        out_planes = self.expansion * planes

        def make_conv(cin, cout, **kwargs):
            # Plain convolutions are bias-free because a BatchNorm follows them.
            if use_deconv:
                return deconv(cin, cout, **kwargs)
            return nn.Conv2d(cin, cout, bias=False, **kwargs)

        self.conv1 = make_conv(in_planes, planes, kernel_size=3, stride=stride, padding=1)
        self.conv2 = make_conv(planes, planes, kernel_size=3, stride=1, padding=1)
        self.deconv = use_deconv
        # Identity shortcut unless the shape changes (registered before the BNs).
        self.shortcut = nn.Sequential()
        if not use_deconv:
            self.bn1 = nn.BatchNorm2d(planes)
            self.bn2 = nn.BatchNorm2d(planes)
        if stride == 1 and in_planes == out_planes:
            return
        if use_deconv:
            projection = [deconv(in_planes, out_planes, kernel_size=1, stride=stride)]
        else:
            projection = [
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            ]
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        if self.deconv:
            h = F.relu(self.conv1(x))
            h = self.conv2(h)
        else:
            h = F.relu(self.bn1(self.conv1(x)))
            h = self.bn2(self.conv2(h))
        h += self.shortcut(x)
        return F.relu(h)
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual bottleneck (ResNet-50/101/152 style).

    The output width is ``expansion * planes``. If ``deconv`` (a conv
    factory) is given, it builds every convolution and the block has no
    BatchNorm. Otherwise it uses bias-free ``nn.Conv2d`` + ``BatchNorm2d``.
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, deconv=None):
        super(Bottleneck, self).__init__()
        out_planes = self.expansion * planes
        self.deconv = bool(deconv)
        if self.deconv:
            make_conv = deconv
            extra = {}
        else:
            # BatchNorms are registered before the convolutions.
            self.bn1 = nn.BatchNorm2d(planes)
            self.bn2 = nn.BatchNorm2d(planes)
            self.bn3 = nn.BatchNorm2d(out_planes)
            make_conv = nn.Conv2d
            extra = {'bias': False}
        # Reduce (1x1) -> spatial (3x3, carries the stride) -> expand (1x1).
        self.conv1 = make_conv(in_planes, planes, kernel_size=1, **extra)
        self.conv2 = make_conv(planes, planes, kernel_size=3, stride=stride, padding=1, **extra)
        self.conv3 = make_conv(planes, out_planes, kernel_size=1, **extra)
        self.shortcut = nn.Sequential()
        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection and self.deconv:
            self.shortcut = nn.Sequential(
                deconv(in_planes, out_planes, kernel_size=1, stride=stride)
            )
        elif needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Residual forward pass; the deconv variant applies no BatchNorm."""
        if self.deconv:
            h = F.relu(self.conv1(x))
            h = F.relu(self.conv2(h))
            h = self.conv3(h)
        else:
            h = F.relu(self.bn1(self.conv1(x)))
            h = F.relu(self.bn2(self.conv2(h)))
            h = self.bn3(self.conv3(h))
        h += self.shortcut(x)
        return F.relu(h)
class ResNet(nn.Module):
    """Small-image ResNet (3x3 stride-1 stem) with optional network deconvolution.

    Args:
        block: residual block class. It must expose ``expansion`` and accept
            ``(in_planes, planes, stride, deconv)``.
        num_blocks: number of blocks in each of the four stages.
        num_classes: size of the classifier output.
        deconv: optional conv factory. When given, the stem uses it and has no
            BatchNorm, and it is forwarded to every block.
        channel_deconv: optional zero-argument factory for a layer applied to
            the last stage's output before pooling.
    """

    def __init__(self, block, num_blocks, num_classes=10, deconv=None, channel_deconv=None):
        super(ResNet, self).__init__()
        self.in_planes = 64
        if deconv:
            self.deconv = True
            self.conv1 = deconv(3, 64, kernel_size=3, stride=1, padding=1)
        else:
            # Defined in both modes (previously only set when deconv was given).
            self.deconv = False
            self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            self.bn1 = nn.BatchNorm2d(64)
        # Optional final channel deconv; upstream flagged it as a recent
        # addition to use with extreme care if results are poor.
        if channel_deconv:
            self.deconv1 = channel_deconv()
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1, deconv=deconv)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2, deconv=deconv)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2, deconv=deconv)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2, deconv=deconv)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride, deconv):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride, deconv))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        if hasattr(self, 'bn1'):
            out = F.relu(self.bn1(self.conv1(x)))
        else:
            out = F.relu(self.conv1(x))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        if hasattr(self, 'deconv1'):
            out = self.deconv1(out)
        # Global average pooling. This equals avg_pool2d(out, 4) on the 4x4 map
        # of a 32x32 input. Unlike the fixed 4x4 window, it neither drops
        # border pixels nor breaks the linear layer for other input sizes.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out
# Default conv factory for the *_DC models: FastDeconv with block=64.
def_deconv = partial(FastDeconv, bias=True, eps=1e-5, n_iter=5, block=64, sampling_stride=3)
# Optional final channel-deconv layer (not necessarily recommended):
# channel_deconv = partial(ChannelDeconv, block=512, eps=1e-5, n_iter=5, sampling_stride=3)


def _resnet_dc(block, num_blocks, num_classes, deconv, channel_deconv):
    """Shared constructor behind the ``ResNet*_DC`` factories."""
    return ResNet(block, num_blocks, num_classes=num_classes,
                  deconv=deconv, channel_deconv=channel_deconv)


def ResNet18_DC(num_classes, deconv=def_deconv, channel_deconv=None):
    """ResNet-18: BasicBlock with [2, 2, 2, 2] blocks per stage."""
    return _resnet_dc(BasicBlock, [2, 2, 2, 2], num_classes, deconv, channel_deconv)


def ResNet34_DC(num_classes, deconv=def_deconv, channel_deconv=None):
    """ResNet-34: BasicBlock with [3, 4, 6, 3] blocks per stage."""
    return _resnet_dc(BasicBlock, [3, 4, 6, 3], num_classes, deconv, channel_deconv)


def ResNet50_DC(num_classes, deconv=def_deconv, channel_deconv=None):
    """ResNet-50: Bottleneck with [3, 4, 6, 3] blocks per stage."""
    return _resnet_dc(Bottleneck, [3, 4, 6, 3], num_classes, deconv, channel_deconv)


def ResNet101_DC(num_classes, deconv=def_deconv, channel_deconv=None):
    """ResNet-101: Bottleneck with [3, 4, 23, 3] blocks per stage."""
    return _resnet_dc(Bottleneck, [3, 4, 23, 3], num_classes, deconv, channel_deconv)


def ResNet152_DC(num_classes, deconv=def_deconv, channel_deconv=None):
    """ResNet-152: Bottleneck with [3, 8, 36, 3] blocks per stage."""
    return _resnet_dc(Bottleneck, [3, 8, 36, 3], num_classes, deconv, channel_deconv)
import math
class Wide_ResNet_Cifar_DC(nn.Module):
    """Three-stage Wide ResNet whose convolutions come from a deconv factory.

    Args:
        block: residual block class (e.g. ``BasicBlock``). It must expose
            ``expansion`` and accept ``(in_planes, planes, stride, deconv)``.
        layers: number of blocks in each of the three stages. The reported
            depth is ``6 * layers[0] + 2``.
        wfactor: widening factor applied to the 16/32/64 stage widths.
        num_classes: size of the classifier output.
        deconv: conv factory. Required in practice, because the stem ``conv1``
            is always built with it.
        channel_deconv: optional zero-argument factory for a layer applied to
            the last stage's output before pooling.
    """

    def __init__(self, block, layers, wfactor, num_classes=10, deconv=None, channel_deconv=None):
        super(Wide_ResNet_Cifar_DC, self).__init__()
        self.depth = layers[0] * 6 + 2
        self.widen_factor = wfactor
        self.inplanes = 16
        # Stem convolution built with the deconv factory (no BatchNorm).
        self.conv1 = deconv(3, 16, kernel_size=3, stride=1, padding=1)
        if channel_deconv:
            self.deconv1 = channel_deconv()
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 16 * wfactor, layers[0], stride=1, deconv=deconv)
        self.layer2 = self._make_layer(block, 32 * wfactor, layers[1], stride=2, deconv=deconv)
        self.layer3 = self._make_layer(block, 64 * wfactor, layers[2], stride=2, deconv=deconv)
        # NOTE(review): an 8x8 window assumes 32x32 inputs (8x8 after the two
        # stride-2 stages); other sizes will not match the fc layer.
        self.avgpool = nn.AvgPool2d(8, stride=1)
        self.fc = nn.Linear(64 * block.expansion * wfactor, num_classes)
        # He-normal init for any plain nn.Conv2d; BatchNorm set to identity.
        # Deconv convolutions (e.g. FastDeconv) are not nn.Conv2d and keep their own init.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, num_blocks, stride, deconv):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.inplanes, planes, stride, deconv))
            self.inplanes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        if hasattr(self, 'deconv1'):
            # Previously referenced an undefined name `out` here, which raised
            # UnboundLocalError whenever channel_deconv was used.
            x = self.deconv1(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    def __str__(self):
        """Get name of model."""
        return "Wide_ResNet_cifar_DC%d_%d" % (self.depth, self.widen_factor)
def WRN_DC26_10(depth=26, width=10, deconv=def_deconv, channel_deconv=None, **kwargs):
    """Build a deconv Wide ResNet of the given depth and widening factor.

    ``depth - 2`` must be divisible by 6. Each of the three stages then gets
    ``(depth - 2) / 6`` ``BasicBlock``s. Extra keyword arguments (for example
    ``num_classes``) are forwarded to ``Wide_ResNet_Cifar_DC``.
    """
    residual_depth = depth - 2
    assert residual_depth % 6 == 0
    per_stage = int(residual_depth / 6)
    return Wide_ResNet_Cifar_DC(BasicBlock, [per_stage] * 3, width,
                                deconv=deconv, channel_deconv=channel_deconv, **kwargs)
def test():
    """Smoke test: build ResNet18_DC and print its output shape for one 32x32 RGB input."""
    # ResNet18_DC has no default for num_classes; calling it without one
    # raised TypeError.
    net = ResNet18_DC(num_classes=10)
    y = net(torch.randn(1, 3, 32, 32))
    print(y.size())
# test()