Amelioration visualisation des proba

This commit is contained in:
Harle, Antoine (Contracteur) 2019-11-13 16:18:53 -05:00
parent f0c0559e73
commit 93d91815f5
7 changed files with 720 additions and 211 deletions

View file

@ -54,6 +54,7 @@ class LeNet(nn.Module):
## Wide ResNet ##
#https://github.com/xternalz/WideResNet-pytorch/blob/master/wideresnet.py
#https://github.com/arcelien/pba/blob/master/pba/wrn.py
#https://github.com/szagoruyko/wide-residual-networks/blob/master/pytorch/resnet.py
class BasicBlock(nn.Module):
def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
super(BasicBlock, self).__init__()
@ -97,9 +98,10 @@ class WideResNet(nn.Module):
def __init__(self, num_classes, wrn_size, depth=28, dropRate=0.0):
super(WideResNet, self).__init__()
kernel_size = wrn_size
self.kernel_size = wrn_size
self.depth=depth
filter_size = 3
nChannels = [min(kernel_size, 16), kernel_size, kernel_size * 2, kernel_size * 4]
nChannels = [min(self.kernel_size, 16), self.kernel_size, self.kernel_size * 2, self.kernel_size * 4]
strides = [1, 2, 2] # stride for each resblock
#nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
@ -137,4 +139,10 @@ class WideResNet(nn.Module):
out = self.relu(self.bn1(out))
out = F.avg_pool2d(out, 8)
out = out.view(-1, self.nChannels)
return self.fc(out)
return self.fc(out)
def architecture(self):
return super(WideResNet, self).__str__()
def __str__(self):
return "WideResNet(s{}-d{})".format(self.kernel_size, self.depth)