Mirror of https://github.com/AntoineHX/smart_augmentation.git (synced 2025-05-04 12:10:45 +02:00)
Amelioration visualisation des proba (improved probability visualisation)
Commit: 93d91815f5 (parent: f0c0559e73)
7 changed files with 720 additions and 211 deletions
@@ -54,6 +54,7 @@ class LeNet(nn.Module):
 ## Wide ResNet ##
 #https://github.com/xternalz/WideResNet-pytorch/blob/master/wideresnet.py
 #https://github.com/arcelien/pba/blob/master/pba/wrn.py
 #https://github.com/szagoruyko/wide-residual-networks/blob/master/pytorch/resnet.py
 class BasicBlock(nn.Module):
     def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
         super(BasicBlock, self).__init__()
@@ -97,9 +98,10 @@ class WideResNet(nn.Module):
     def __init__(self, num_classes, wrn_size, depth=28, dropRate=0.0):
         super(WideResNet, self).__init__()
 
-        kernel_size = wrn_size
+        self.kernel_size = wrn_size
+        self.depth=depth
         filter_size = 3
-        nChannels = [min(kernel_size, 16), kernel_size, kernel_size * 2, kernel_size * 4]
+        nChannels = [min(self.kernel_size, 16), self.kernel_size, self.kernel_size * 2, self.kernel_size * 4]
         strides = [1, 2, 2] # stride for each resblock
 
         #nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
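The nChannels line changed above is what ties the whole network width to the single wrn_size argument. For reference, a standalone sketch of that computation (the helper name wrn_channels is hypothetical, not part of the repository):

# Hypothetical helper mirroring the nChannels computation in WideResNet.__init__:
# the stem is capped at 16 channels, the three residual groups scale with wrn_size.
def wrn_channels(wrn_size):
    return [min(wrn_size, 16), wrn_size, wrn_size * 2, wrn_size * 4]

print(wrn_channels(160))  # [16, 160, 320, 640], the usual WRN-28-10 widths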
@@ -137,4 +139,10 @@ class WideResNet(nn.Module):
         out = self.relu(self.bn1(out))
         out = F.avg_pool2d(out, 8)
         out = out.view(-1, self.nChannels)
-        return self.fc(out)
+        return self.fc(out)
+
+    def architecture(self):
+        return super(WideResNet, self).__str__()
+
+    def __str__(self):
+        return "WideResNet(s{}-d{})".format(self.kernel_size, self.depth)
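The two methods added at the end change what printing the model shows: str(model) now returns a compact tag built from self.kernel_size and self.depth, while architecture() still exposes the full layer-by-layer nn.Module listing. A self-contained sketch of the same pattern (TinyNet is a hypothetical stand-in, not code from the repository):

import torch.nn as nn

class TinyNet(nn.Module):
    # Hypothetical stand-in for WideResNet, only to illustrate the added methods.
    def __init__(self, wrn_size=160, depth=28):
        super(TinyNet, self).__init__()
        self.kernel_size = wrn_size
        self.depth = depth
        self.fc = nn.Linear(8, 2)

    def architecture(self):
        # Fall back to the default nn.Module representation: full layer listing.
        return super(TinyNet, self).__str__()

    def __str__(self):
        # Compact identifier, same format as the WideResNet(s{size}-d{depth}) tag above.
        return "TinyNet(s{}-d{})".format(self.kernel_size, self.depth)

net = TinyNet()
print(net)                 # TinyNet(s160-d28)
print(net.architecture())  # TinyNet( (fc): Linear(in_features=8, out_features=2, bias=True) )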