"""ResNet in PyTorch with network deconvolution layers.

For Pre-activation ResNet, see 'preact_resnet.py'.

Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385

https://github.com/yechengxi/deconvolution
"""
import math
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules import conv
from torch.nn.modules.utils import _pair

__all__ = ['ResNet18_DC', 'ResNet34_DC', 'ResNet50_DC', 'ResNet101_DC',
           'ResNet152_DC', 'WRN_DC26_10']


### Deconvolution ###

def isqrt_newton_schulz_autograd(A, numIters):
    """Iteratively solve for the inverse square root of a square matrix.

    Runs ``numIters`` coupled Newton-Schulz steps on ``A`` scaled by its
    norm, using only matrix products so autograd can differentiate it.

    Args:
        A: square 2-D tensor (``A.shape[0]`` is used as the dimension).
        numIters: number of iterations to run.

    Returns:
        Tensor with the same shape as ``A`` approximating ``A^(-1/2)``.
    """
    dim = A.shape[0]
    normA = A.norm()
    # Scale by the norm before iterating; the scaling is undone at the end.
    Y = A.div(normA)
    I = torch.eye(dim, dtype=A.dtype, device=A.device)
    Z = torch.eye(dim, dtype=A.dtype, device=A.device)
    for _ in range(numIters):
        T = 0.5 * (3.0 * I - Z @ Y)
        Y = Y @ T
        Z = T @ Z
    # Y * sqrt(normA) would be the matrix square root; only the inverse
    # square root is needed here.
    return Z / torch.sqrt(normA)


def isqrt_newton_schulz_autograd_batch(A, numIters):
    """Batched version of :func:`isqrt_newton_schulz_autograd`.

    Args:
        A: tensor of shape ``(batch, dim, dim)``.
        numIters: number of iterations to run.

    Returns:
        Tensor of shape ``(batch, dim, dim)`` approximating the inverse
        square root of every matrix in the batch.
    """
    batchSize, dim, _ = A.shape
    # Per-matrix 2-norm of the flattened entries, shaped for broadcasting.
    normA = A.view(batchSize, -1).norm(2, 1).view(batchSize, 1, 1)
    Y = A.div(normA)
    I = torch.eye(dim, dtype=A.dtype, device=A.device).unsqueeze(0).expand_as(A)
    Z = torch.eye(dim, dtype=A.dtype, device=A.device).unsqueeze(0).expand_as(A)
    for _ in range(numIters):
        T = 0.5 * (3.0 * I - Z.bmm(Y))
        Y = Y.bmm(T)
        Z = T.bmm(Z)
    return Z / torch.sqrt(normA)


class ChannelDeconv(nn.Module):
    """Deconvolve (whiten) the channels of an activation tensor.

    The first ``floor(C / block) * block`` channels are mean-centred and
    multiplied by an approximate inverse square root of their covariance.
    Any remaining channels are standardised with one scalar mean/variance.
    Running statistics are stored in buffers and used in eval mode.

    Args:
        block: number of channels whitened together.
        eps: value added to the covariance diagonal and to the variance.
        n_iter: Newton-Schulz iterations for the inverse square root.
        momentum: update rate of the running statistics.
        sampling_stride: when > 1 (and the feature map is at least that
            large), statistics use every ``sampling_stride**2``-th column.
    """

    def __init__(self, block, eps=1e-2, n_iter=5, momentum=0.1, sampling_stride=3):
        super(ChannelDeconv, self).__init__()
        self.eps = eps
        self.n_iter = n_iter
        self.momentum = momentum
        self.block = block
        self.register_buffer('running_mean1', torch.zeros(block, 1))
        self.register_buffer('running_deconv', torch.eye(block))
        self.register_buffer('running_mean2', torch.zeros(1, 1))
        self.register_buffer('running_var', torch.ones(1, 1))
        self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
        self.sampling_stride = sampling_stride

    def forward(self, x):
        """Whiten ``x`` and return a tensor of the same shape.

        Args:
            x: tensor of shape ``(N, C)`` or ``(N, C, H, W)``.

        Raises:
            ValueError: if ``x`` is neither 2-D nor 4-D, or if it has fewer
                than ``block`` channels.
        """
        x_shape = x.shape
        if x.dim() == 2:
            x = x.view(x.shape[0], x.shape[1], 1, 1)
        if x.dim() != 4:
            raise ValueError(
                'Unsupported tensor shape {}: expected a 2-D or 4-D input.'.format(tuple(x_shape)))

        N, C, H, W = x.size()
        B = self.block

        # Take the first c channels out for deconvolution.
        c = (C // B) * B
        if c == 0:
            raise ValueError(
                'block ({}) should be set no larger than the number of channels ({}).'.format(B, C))

        # Until the first training batch is tracked, the running statistics
        # are (re)initialised from the current batch.
        first_batch = self.num_batches_tracked.item() == 0
        subsample = (self.sampling_stride > 1
                     and H >= self.sampling_stride
                     and W >= self.sampling_stride)

        # Step 1: remove the mean.
        if c != C:
            x1 = x[:, :c].permute(1, 0, 2, 3).contiguous().view(B, -1)
        else:
            x1 = x.permute(1, 0, 2, 3).contiguous().view(B, -1)

        x1_s = x1[:, ::self.sampling_stride ** 2] if subsample else x1
        mean1 = x1_s.mean(-1, keepdim=True)

        if first_batch:
            self.running_mean1.copy_(mean1.detach())
        if self.training:
            self.running_mean1.mul_(1 - self.momentum)
            self.running_mean1.add_(mean1.detach() * self.momentum)
        else:
            mean1 = self.running_mean1

        x1 = x1 - mean1

        # Step 2: calculate deconv @ x1 = cov^(-0.5) @ x1.
        # The batch estimate is also needed in eval mode when nothing has
        # been tracked yet; previously that path hit an undefined `deconv`.
        deconv = None
        if self.training or first_batch:
            # NOTE(review): x1_s still refers to the samples taken before
            # the mean was subtracted, so this is an uncentred second
            # moment. Kept as-is; confirm whether centring was intended.
            cov = (x1_s @ x1_s.t() / x1_s.shape[1]
                   + self.eps * torch.eye(B, dtype=x.dtype, device=x.device))
            deconv = isqrt_newton_schulz_autograd(cov, self.n_iter)

        if first_batch:
            self.running_deconv.copy_(deconv.detach())
        if self.training:
            self.running_deconv.mul_(1 - self.momentum)
            self.running_deconv.add_(deconv.detach() * self.momentum)
        else:
            deconv = self.running_deconv

        x1 = deconv @ x1

        # Reshape back to (N, c, H, W).
        x1 = x1.view(c, N, H, W).contiguous().permute(1, 0, 2, 3)

        # Normalise the remaining channels with scalar statistics.
        if c != C:
            x_tmp = x[:, c:].view(N, -1)
            x_s = x_tmp[:, ::self.sampling_stride ** 2] if subsample else x_tmp

            mean2 = x_s.mean()
            var = x_s.var()

            if first_batch:
                self.running_mean2.copy_(mean2.detach())
                self.running_var.copy_(var.detach())
            if self.training:
                self.running_mean2.mul_(1 - self.momentum)
                self.running_mean2.add_(mean2.detach() * self.momentum)
                self.running_var.mul_(1 - self.momentum)
                self.running_var.add_(var.detach() * self.momentum)
            else:
                mean2 = self.running_mean2
                var = self.running_var

            x_tmp = (x[:, c:] - mean2) / (var + self.eps).sqrt()
            x1 = torch.cat([x1, x_tmp], dim=1)

        if self.training:
            self.num_batches_tracked.add_(1)

        if len(x_shape) == 2:
            x1 = x1.view(x_shape)
        return x1


class Delinear(nn.Module):
    """Linear layer that deconvolves its input in blocks of features.

    An alternative implementation: the whitening transform is folded into
    the weight and bias so a single ``F.linear`` call is issued.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: whether to learn an additive bias.
        eps: value added to the covariance diagonal.
        n_iter: Newton-Schulz iterations for the inverse square root.
        momentum: update rate of the running statistics.
        block: number of features whitened together; clamped to
            ``in_features`` or reduced to ``gcd(block, in_features)`` when
            it does not divide ``in_features``.
    """

    __constants__ = ['bias', 'in_features', 'out_features']

    def __init__(self, in_features, out_features, bias=True, eps=1e-5, n_iter=5,
                 momentum=0.1, block=512):
        super(Delinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

        if block > in_features:
            block = in_features
        elif in_features % block != 0:
            block = math.gcd(block, in_features)
            print('block size set to:', block)
        self.block = block
        self.momentum = momentum
        self.n_iter = n_iter
        self.eps = eps
        self.register_buffer('running_mean', torch.zeros(self.block))
        self.register_buffer('running_deconv', torch.eye(self.block))

    def reset_parameters(self):
        """Initialise weight and bias the same way ``nn.Linear`` code here did."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        """Apply the deconvolved linear map to ``input``."""
        if self.training:
            # 1. Reshape so each row is one block of features.
            X = input.view(-1, self.block)

            # 2. Subtract the mean.
            X_mean = X.mean(0)
            X = X - X_mean.unsqueeze(0)
            self.running_mean.mul_(1 - self.momentum)
            self.running_mean.add_(X_mean.detach() * self.momentum)

            # 3. Cov = eps * I + X^T X / n, then Cov^(-0.5).
            # Keyword beta/alpha replace the deprecated positional overload
            # addmm(beta, input, alpha, mat1, mat2).
            Id = torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
            Cov = torch.addmm(Id, X.t(), X, beta=self.eps, alpha=1. / X.shape[0])
            deconv = isqrt_newton_schulz_autograd(Cov, self.n_iter)

            # Track statistics for evaluation.
            self.running_deconv.mul_(1 - self.momentum)
            self.running_deconv.add_(deconv.detach() * self.momentum)
        else:
            X_mean = self.running_mean
            deconv = self.running_deconv

        # Fold the whitening into the weight; mean removal goes into the bias.
        w = self.weight.view(-1, self.block) @ deconv
        b = self.bias
        if self.bias is not None:
            # NOTE(review): without a bias the mean-removal term is dropped
            # entirely; behaviour preserved, confirm this is intended.
            b = b - (w @ (X_mean.unsqueeze(1))).view(self.weight.shape[0], -1).sum(1)
        w = w.view(self.weight.shape)
        return F.linear(input, w, b)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )


class FastDeconv(conv._ConvNd):
    """2-D convolution with the deconvolution folded into its kernel.

    During training the covariance of (sub-sampled) input patches is
    estimated, its inverse square root is multiplied into the weight, and
    the mean removal is folded into the bias. Running statistics are used
    in eval mode, or while frozen.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias: as for ``nn.Conv2d``. ``kernel_size`` and ``stride``
            are used in integer arithmetic here, so they must be ints.
        eps: value added to the covariance diagonal.
        n_iter: Newton-Schulz iterations for the inverse square root.
        momentum: update rate of the running statistics.
        block: number of input channels whitened together (groups == 1);
            for grouped convolutions it is ``in_channels // groups``.
        sampling_stride: spatial sub-sampling factor for the statistics
            (multiplied by ``stride``).
        freeze: when true, stop re-estimating statistics once the internal
            counter exceeds ``freeze_iter``.
        freeze_iter: counter threshold; the counter wraps at
            ``10 * freeze_iter``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, bias=True, eps=1e-5, n_iter=5, momentum=0.1,
                 block=64, sampling_stride=3, freeze=False, freeze_iter=100):
        super(FastDeconv, self).__init__(
            in_channels, out_channels, _pair(kernel_size), _pair(stride),
            _pair(padding), _pair(dilation), False, _pair(0), groups, bias,
            padding_mode='zeros')

        # Plain attributes are set after Module initialisation.
        self.momentum = momentum
        self.n_iter = n_iter
        self.eps = eps
        self.counter = 0
        self.track_running_stats = True

        if block > in_channels:
            block = in_channels
        elif in_channels % block != 0:
            block = math.gcd(block, in_channels)

        if groups > 1:
            # Grouped conv: one whitening block per group.
            block = in_channels // groups

        self.block = block
        self.num_features = kernel_size ** 2 * block

        if groups == 1:
            self.register_buffer('running_mean', torch.zeros(self.num_features))
            self.register_buffer('running_deconv', torch.eye(self.num_features))
        else:
            self.register_buffer('running_mean', torch.zeros(kernel_size ** 2 * in_channels))
            self.register_buffer(
                'running_deconv',
                torch.eye(self.num_features).repeat(in_channels // block, 1, 1))

        self.sampling_stride = sampling_stride * stride
        self.freeze_iter = freeze_iter
        self.freeze = freeze

    def forward(self, x):
        """Convolve ``x`` (shape ``(N, C, H, W)``) with the deconvolved kernel."""
        N, C, H, W = x.shape
        B = self.block
        frozen = self.freeze and (self.counter > self.freeze_iter)
        if self.training and self.track_running_stats:
            self.counter += 1
            self.counter %= (self.freeze_iter * 10)

        if self.training and (not frozen):
            # 1. im2col: N x cols x pixels -> N*pixels x cols.
            if self.kernel_size[0] > 1:
                X = F.unfold(x, self.kernel_size, self.dilation, self.padding,
                             self.sampling_stride).transpose(1, 2).contiguous()
            else:
                # Channel-wise (1x1 kernel): sub-sample pixel rows directly.
                X = x.permute(0, 2, 3, 1).contiguous().view(-1, C)[::self.sampling_stride ** 2, :]

            if self.groups == 1:
                # (C//B * N * pixels, k*k*B)
                X = (X.view(-1, self.num_features, C // B)
                      .transpose(1, 2).contiguous()
                      .view(-1, self.num_features))
            else:
                X = X.view(-1, X.shape[-1])

            # 2. Subtract the mean.
            X_mean = X.mean(0)
            X = X - X_mean.unsqueeze(0)

            # 3. Cov = eps * I + X^T X / n, then Cov^(-0.5).
            # Keyword beta/alpha replace the deprecated positional overloads
            # addmm/baddbmm(beta, input, alpha, mat1, mat2).
            if self.groups == 1:
                Id = torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
                Cov = torch.addmm(Id, X.t(), X, beta=self.eps, alpha=1. / X.shape[0])
                deconv = isqrt_newton_schulz_autograd(Cov, self.n_iter)
            else:
                X = X.view(-1, self.groups, self.num_features).transpose(0, 1)
                Id = torch.eye(self.num_features, dtype=X.dtype, device=X.device).expand(
                    self.groups, self.num_features, self.num_features)
                Cov = torch.baddbmm(Id, X.transpose(1, 2), X,
                                    beta=self.eps, alpha=1. / X.shape[1])
                deconv = isqrt_newton_schulz_autograd_batch(Cov, self.n_iter)

            if self.track_running_stats:
                # Track statistics for evaluation.
                self.running_mean.mul_(1 - self.momentum)
                self.running_mean.add_(X_mean.detach() * self.momentum)
                self.running_deconv.mul_(1 - self.momentum)
                self.running_deconv.add_(deconv.detach() * self.momentum)
        else:
            X_mean = self.running_mean
            deconv = self.running_deconv

        # 4. X * deconv * conv = X * (deconv * conv)
        if self.groups == 1:
            w = (self.weight.view(-1, self.num_features, C // B)
                     .transpose(1, 2).contiguous()
                     .view(-1, self.num_features)) @ deconv
            mean_term = (w @ (X_mean.unsqueeze(1))).view(self.weight.shape[0], -1).sum(1)
            w = w.view(-1, C // B, self.num_features).transpose(1, 2).contiguous()
        else:
            w = self.weight.view(C // B, -1, self.num_features) @ deconv
            mean_term = (w @ (X_mean.view(-1, self.num_features, 1))).view(self.out_channels)

        # Mean removal becomes a bias correction. With bias=False the
        # correction alone is the bias (previously this was None - Tensor).
        b = -mean_term if self.bias is None else self.bias - mean_term

        w = w.view(self.weight.shape)
        return F.conv2d(x, w, b, self.stride, self.padding, self.dilation, self.groups)


### ResNet

class BasicBlock(nn.Module):
    """Two-convolution residual block.

    With ``deconv`` given, both 3x3 convolutions (and the projection
    shortcut, if any) are built by calling it and no batch normalisation
    is used; otherwise ``nn.Conv2d`` + ``nn.BatchNorm2d`` are used.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, deconv=None):
        super(BasicBlock, self).__init__()
        if deconv:
            self.conv1 = deconv(in_planes, planes, kernel_size=3, stride=stride, padding=1)
            self.conv2 = deconv(planes, planes, kernel_size=3, stride=1, padding=1)
            self.deconv = True
        else:
            self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride,
                                   padding=1, bias=False)
            self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                                   padding=1, bias=False)
            self.deconv = False

        self.shortcut = nn.Sequential()
        needs_projection = stride != 1 or in_planes != self.expansion * planes

        if not deconv:
            self.bn1 = nn.BatchNorm2d(planes)
            self.bn2 = nn.BatchNorm2d(planes)
            if needs_projection:
                self.shortcut = nn.Sequential(
                    nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
                              stride=stride, bias=False),
                    nn.BatchNorm2d(self.expansion * planes)
                )
        elif needs_projection:
            self.shortcut = nn.Sequential(
                deconv(in_planes, self.expansion * planes, kernel_size=1, stride=stride)
            )

    def forward(self, x):
        if self.deconv:
            out = F.relu(self.conv1(x))
            out = self.conv2(out)
        else:
            out = F.relu(self.bn1(self.conv1(x)))
            out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        return F.relu(out)


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block with 4x channel expansion.

    With ``deconv`` given, the convolutions are built by calling it and no
    batch normalisation is used.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1, deconv=None):
        super(Bottleneck, self).__init__()
        if deconv:
            self.deconv = True
            self.conv1 = deconv(in_planes, planes, kernel_size=1)
            self.conv2 = deconv(planes, planes, kernel_size=3, stride=stride, padding=1)
            self.conv3 = deconv(planes, self.expansion * planes, kernel_size=1)
        else:
            self.deconv = False
            self.bn1 = nn.BatchNorm2d(planes)
            self.bn2 = nn.BatchNorm2d(planes)
            self.bn3 = nn.BatchNorm2d(self.expansion * planes)
            self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
            self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                                   padding=1, bias=False)
            self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)

        self.shortcut = nn.Sequential()
        needs_projection = stride != 1 or in_planes != self.expansion * planes

        if not deconv:
            if needs_projection:
                self.shortcut = nn.Sequential(
                    nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
                              stride=stride, bias=False),
                    nn.BatchNorm2d(self.expansion * planes)
                )
        elif needs_projection:
            self.shortcut = nn.Sequential(
                deconv(in_planes, self.expansion * planes, kernel_size=1, stride=stride)
            )

    def forward(self, x):
        """Run the block. No batch normalization for deconv."""
        if self.deconv:
            out = F.relu(self.conv1(x))
            out = F.relu(self.conv2(out))
            out = self.conv3(out)
        else:
            out = F.relu(self.bn1(self.conv1(x)))
            out = F.relu(self.bn2(self.conv2(out)))
            out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        return F.relu(out)


class ResNet(nn.Module):
    """ResNet with a 3x3 stem, four stages and a 4x4 average pool.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``).
        num_blocks: number of blocks in each of the four stages.
        num_classes: size of the classifier output.
        deconv: optional factory for deconvolution layers; when omitted,
            plain convolutions with batch normalisation are used.
        channel_deconv: optional zero-argument factory for a channel
            deconvolution module applied before the final pooling.
    """

    def __init__(self, block, num_blocks, num_classes=10, deconv=None, channel_deconv=None):
        super(ResNet, self).__init__()
        self.in_planes = 64

        if deconv:
            self.deconv = True
            self.conv1 = deconv(3, 64, kernel_size=3, stride=1, padding=1)
        else:
            self.deconv = False
            self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            self.bn1 = nn.BatchNorm2d(64)

        # This line is really recent, take extreme care if the result is not good.
        if channel_deconv:
            self.deconv1 = channel_deconv()

        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1, deconv=deconv)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2, deconv=deconv)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2, deconv=deconv)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2, deconv=deconv)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride, deconv):
        """Build one stage; only its first block uses ``stride``."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_planes, planes, block_stride, deconv))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        if hasattr(self, 'bn1'):
            out = F.relu(self.bn1(self.conv1(x)))
        else:
            out = F.relu(self.conv1(x))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        if hasattr(self, 'deconv1'):
            out = self.deconv1(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        return self.linear(out)


# Default deconvolution layer factory used by the model constructors below.
def_deconv = partial(FastDeconv, bias=True, eps=1e-5, n_iter=5, block=64, sampling_stride=3)
# Optional (not necessarily recommended):
# channel_deconv = partial(ChannelDeconv, block=512, eps=1e-5, n_iter=5, sampling_stride=3)


def ResNet18_DC(num_classes, deconv=def_deconv, channel_deconv=None):
    """Build a ResNet-18 whose convolutions are created by ``deconv``."""
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes,
                  deconv=deconv, channel_deconv=channel_deconv)


def ResNet34_DC(num_classes, deconv=def_deconv, channel_deconv=None):
    """Build a ResNet-34 whose convolutions are created by ``deconv``."""
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes,
                  deconv=deconv, channel_deconv=channel_deconv)


def ResNet50_DC(num_classes, deconv=def_deconv, channel_deconv=None):
    """Build a ResNet-50 whose convolutions are created by ``deconv``."""
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes,
                  deconv=deconv, channel_deconv=channel_deconv)


def ResNet101_DC(num_classes, deconv=def_deconv, channel_deconv=None):
    """Build a ResNet-101 whose convolutions are created by ``deconv``."""
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes,
                  deconv=deconv, channel_deconv=channel_deconv)


def ResNet152_DC(num_classes, deconv=def_deconv, channel_deconv=None):
    """Build a ResNet-152 whose convolutions are created by ``deconv``."""
    return ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes,
                  deconv=deconv, channel_deconv=channel_deconv)


class Wide_ResNet_Cifar_DC(nn.Module):
    """Three-stage wide ResNet with an 8x8 average pool.

    Args:
        block: residual block class.
        layers: number of blocks in each of the three stages.
        wfactor: widening factor applied to the 16/32/64 base widths.
        num_classes: size of the classifier output.
        deconv: optional factory for deconvolution layers; when omitted,
            a plain convolution followed by batch normalisation is used
            for the stem (and inside the blocks).
        channel_deconv: optional zero-argument factory for a channel
            deconvolution module applied before the final pooling.
    """

    def __init__(self, block, layers, wfactor, num_classes=10, deconv=None,
                 channel_deconv=None):
        super(Wide_ResNet_Cifar_DC, self).__init__()
        self.depth = layers[0] * 6 + 2
        self.widen_factor = wfactor
        self.inplanes = 16

        if deconv:
            self.conv1 = deconv(3, 16, kernel_size=3, stride=1, padding=1)
        else:
            # Fallback for the default deconv=None, which used to crash.
            self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
            self.bn1 = nn.BatchNorm2d(16)
        if channel_deconv:
            self.deconv1 = channel_deconv()

        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 16 * wfactor, layers[0], stride=1, deconv=deconv)
        self.layer2 = self._make_layer(block, 32 * wfactor, layers[1], stride=2, deconv=deconv)
        self.layer3 = self._make_layer(block, 64 * wfactor, layers[2], stride=2, deconv=deconv)
        self.avgpool = nn.AvgPool2d(8, stride=1)
        self.fc = nn.Linear(64 * block.expansion * wfactor, num_classes)

        # Re-initialise the plain nn.Conv2d / nn.BatchNorm2d modules only.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, num_blocks, stride, deconv):
        """Build one stage; only its first block uses ``stride``."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.inplanes, planes, block_stride, deconv))
            self.inplanes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        if hasattr(self, 'bn1'):
            x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        if hasattr(self, 'deconv1'):
            # Was `out = self.deconv1(out)`: `out` is undefined in this method.
            x = self.deconv1(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

    def __str__(self):
        """Get name of model."""
        return "Wide_ResNet_cifar_DC%d_%d" % (self.depth, self.widen_factor)


def WRN_DC26_10(depth=26, width=10, deconv=def_deconv, channel_deconv=None, **kwargs):
    """Build a wide ResNet of the given depth and width.

    ``depth`` must satisfy ``(depth - 2) % 6 == 0``. Extra keyword
    arguments are forwarded to :class:`Wide_ResNet_Cifar_DC`.
    """
    assert (depth - 2) % 6 == 0, 'depth must be of the form 6n + 2'
    n = (depth - 2) // 6
    return Wide_ResNet_Cifar_DC(BasicBlock, [n, n, n], width, deconv=deconv,
                                channel_deconv=channel_deconv, **kwargs)


def test():
    """Smoke test: run a ResNet-18 on one 32x32 RGB input and print the shape."""
    net = ResNet18_DC(num_classes=10)
    y = net(torch.randn(1, 3, 32, 32))
    print(y.size())

# test()