Rangement

This commit is contained in:
Harle, Antoine (Contracteur) 2020-01-24 14:32:37 -05:00
parent f83c73ec17
commit f507ff4741
16 changed files with 85 additions and 46 deletions

View file

@ -0,0 +1,490 @@
# coding=utf-8
# Copyright 2019 The Google UDA Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transforms used in the Augmentation Policies.
Copied from AutoAugment: https://github.com/tensorflow/models/blob/master/research/autoaugment/
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
import numpy as np
# pylint:disable=g-multiple-import
from PIL import ImageOps, ImageEnhance, ImageFilter, Image
# pylint:enable=g-multiple-import
#import tensorflow as tf
#FLAGS = tf.flags.FLAGS
IMAGE_SIZE = 32
# What is the dataset mean and std of the images on the training set
PARAMETER_MAX = 10 # What is the max 'level' a transform could be predicted
def get_mean_and_std():
#if FLAGS.task_name == "cifar10":
means = [0.49139968, 0.48215841, 0.44653091]
stds = [0.24703223, 0.24348513, 0.26158784]
#elif FLAGS.task_name == "svhn":
# means = [0.4376821, 0.4437697, 0.47280442]
# stds = [0.19803012, 0.20101562, 0.19703614]
#else:
# assert False
return means, stds
def random_flip(x):
"""Flip the input x horizontally with 50% probability."""
if np.random.rand(1)[0] > 0.5:
return np.fliplr(x)
return x
def zero_pad_and_crop(img, amount=4):
"""Zero pad by `amount` zero pixels on each side then take a random crop.
Args:
img: numpy image that will be zero padded and cropped.
amount: amount of zeros to pad `img` with horizontally and verically.
Returns:
The cropped zero padded img. The returned numpy array will be of the same
shape as `img`.
"""
padded_img = np.zeros((img.shape[0] + amount * 2, img.shape[1] + amount * 2,
img.shape[2]))
padded_img[amount:img.shape[0] + amount, amount:
img.shape[1] + amount, :] = img
top = np.random.randint(low=0, high=2 * amount)
left = np.random.randint(low=0, high=2 * amount)
new_img = padded_img[top:top + img.shape[0], left:left + img.shape[1], :]
return new_img
def create_cutout_mask(img_height, img_width, num_channels, size):
"""Creates a zero mask used for cutout of shape `img_height` x `img_width`.
Args:
img_height: Height of image cutout mask will be applied to.
img_width: Width of image cutout mask will be applied to.
num_channels: Number of channels in the image.
size: Size of the zeros mask.
Returns:
A mask of shape `img_height` x `img_width` with all ones except for a
square of zeros of shape `size` x `size`. This mask is meant to be
elementwise multiplied with the original image. Additionally returns
the `upper_coord` and `lower_coord` which specify where the cutout mask
will be applied.
"""
assert img_height == img_width
# Sample center where cutout mask will be applied
height_loc = np.random.randint(low=0, high=img_height)
width_loc = np.random.randint(low=0, high=img_width)
# Determine upper right and lower left corners of patch
upper_coord = (max(0, height_loc - size // 2), max(0, width_loc - size // 2))
lower_coord = (min(img_height, height_loc + size // 2),
min(img_width, width_loc + size // 2))
mask_height = lower_coord[0] - upper_coord[0]
mask_width = lower_coord[1] - upper_coord[1]
assert mask_height > 0
assert mask_width > 0
mask = np.ones((img_height, img_width, num_channels))
zeros = np.zeros((mask_height, mask_width, num_channels))
mask[upper_coord[0]:lower_coord[0], upper_coord[1]:lower_coord[1], :] = (
zeros)
return mask, upper_coord, lower_coord
def cutout_numpy(img, size=16):
"""Apply cutout with mask of shape `size` x `size` to `img`.
The cutout operation is from the paper https://arxiv.org/abs/1708.04552.
This operation applies a `size`x`size` mask of zeros to a random location
within `img`.
Args:
img: Numpy image that cutout will be applied to.
size: Height/width of the cutout mask that will be
Returns:
A numpy tensor that is the result of applying the cutout mask to `img`.
"""
img_height, img_width, num_channels = (img.shape[0], img.shape[1],
img.shape[2])
assert len(img.shape) == 3
mask, _, _ = create_cutout_mask(img_height, img_width, num_channels, size)
return img * mask
def float_parameter(level, maxval):
"""Helper function to scale `val` between 0 and maxval .
Args:
level: Level of the operation that will be between [0, `PARAMETER_MAX`].
maxval: Maximum value that the operation can have. This will be scaled
to level/PARAMETER_MAX.
Returns:
A float that results from scaling `maxval` according to `level`.
"""
return float(level) * maxval / PARAMETER_MAX
def int_parameter(level, maxval):
"""Helper function to scale `val` between 0 and maxval .
Args:
level: Level of the operation that will be between [0, `PARAMETER_MAX`].
maxval: Maximum value that the operation can have. This will be scaled
to level/PARAMETER_MAX.
Returns:
An int that results from scaling `maxval` according to `level`.
"""
return int(level * maxval / PARAMETER_MAX)
def pil_wrap(img, use_mean_std):
"""Convert the `img` numpy tensor to a PIL Image."""
if use_mean_std:
MEANS, STDS = get_mean_and_std()
else:
MEANS = [0, 0, 0]
STDS = [1, 1, 1]
img_ori = (img * STDS + MEANS) * 255
return Image.fromarray(
np.uint8((img * STDS + MEANS) * 255.0)).convert('RGBA')
def pil_unwrap(pil_img, use_mean_std, img_shape):
"""Converts the PIL img to a numpy array."""
if use_mean_std:
MEANS, STDS = get_mean_and_std()
else:
MEANS = [0, 0, 0]
STDS = [1, 1, 1]
pic_array = np.array(pil_img.getdata()).reshape((img_shape[0], img_shape[1], 4)) / 255.0
i1, i2 = np.where(pic_array[:, :, 3] == 0)
pic_array = (pic_array[:, :, :3] - MEANS) / STDS
pic_array[i1, i2] = [0, 0, 0]
return pic_array
def apply_policy(policy, img, use_mean_std=True):
"""Apply the `policy` to the numpy `img`.
Args:
policy: A list of tuples with the form (name, probability, level) where
`name` is the name of the augmentation operation to apply, `probability`
is the probability of applying the operation and `level` is what strength
the operation to apply.
img: Numpy image that will have `policy` applied to it.
Returns:
The result of applying `policy` to `img`.
"""
img_shape = img.shape
pil_img = pil_wrap(img, use_mean_std)
for xform in policy:
assert len(xform) == 3
name, probability, level = xform
xform_fn = NAME_TO_TRANSFORM[name].pil_transformer(
probability, level, img_shape)
pil_img = xform_fn(pil_img)
return pil_unwrap(pil_img, use_mean_std, img_shape)
class TransformFunction(object):
"""Wraps the Transform function for pretty printing options."""
def __init__(self, func, name):
self.f = func
self.name = name
def __repr__(self):
return '<' + self.name + '>'
def __call__(self, pil_img):
return self.f(pil_img)
class TransformT(object):
"""Each instance of this class represents a specific transform."""
def __init__(self, name, xform_fn):
self.name = name
self.xform = xform_fn
def pil_transformer(self, probability, level, img_shape):
def return_function(im):
if random.random() < probability:
im = self.xform(im, level, img_shape)
return im
name = self.name + '({:.1f},{})'.format(probability, level)
return TransformFunction(return_function, name)
################## Transform Functions ##################
identity = TransformT('identity', lambda pil_img, level, _: pil_img)
flip_lr = TransformT(
'FlipLR',
lambda pil_img, level, _: pil_img.transpose(Image.FLIP_LEFT_RIGHT))
flip_ud = TransformT(
'FlipUD',
lambda pil_img, level, _: pil_img.transpose(Image.FLIP_TOP_BOTTOM))
# pylint:disable=g-long-lambda
auto_contrast = TransformT(
'AutoContrast',
lambda pil_img, level, _: ImageOps.autocontrast(
pil_img.convert('RGB')).convert('RGBA'))
equalize = TransformT(
'Equalize',
lambda pil_img, level, _: ImageOps.equalize(
pil_img.convert('RGB')).convert('RGBA'))
invert = TransformT(
'Invert',
lambda pil_img, level, _: ImageOps.invert(
pil_img.convert('RGB')).convert('RGBA'))
# pylint:enable=g-long-lambda
blur = TransformT(
'Blur', lambda pil_img, level, _: pil_img.filter(ImageFilter.BLUR))
smooth = TransformT(
'Smooth',
lambda pil_img, level, _: pil_img.filter(ImageFilter.SMOOTH))
def _rotate_impl(pil_img, level, _):
"""Rotates `pil_img` from -30 to 30 degrees depending on `level`."""
degrees = int_parameter(level, 30)
if random.random() > 0.5:
degrees = -degrees
return pil_img.rotate(degrees)
rotate = TransformT('Rotate', _rotate_impl)
def _posterize_impl(pil_img, level, _):
"""Applies PIL Posterize to `pil_img`."""
level = int_parameter(level, 4)
return ImageOps.posterize(pil_img.convert('RGB'), 4 - level).convert('RGBA')
posterize = TransformT('Posterize', _posterize_impl)
def _shear_x_impl(pil_img, level, img_shape):
"""Applies PIL ShearX to `pil_img`.
The ShearX operation shears the image along the horizontal axis with `level`
magnitude.
Args:
pil_img: Image in PIL object.
level: Strength of the operation specified as an Integer from
[0, `PARAMETER_MAX`].
Returns:
A PIL Image that has had ShearX applied to it.
"""
level = float_parameter(level, 0.3)
if random.random() > 0.5:
level = -level
return pil_img.transform(
(img_shape[0], img_shape[1]),
Image.AFFINE,
(1, level, 0, 0, 1, 0))
shear_x = TransformT('ShearX', _shear_x_impl)
def _shear_y_impl(pil_img, level, img_shape):
"""Applies PIL ShearY to `pil_img`.
The ShearY operation shears the image along the vertical axis with `level`
magnitude.
Args:
pil_img: Image in PIL object.
level: Strength of the operation specified as an Integer from
[0, `PARAMETER_MAX`].
Returns:
A PIL Image that has had ShearX applied to it.
"""
level = float_parameter(level, 0.3)
if random.random() > 0.5:
level = -level
return pil_img.transform(
(img_shape[0], img_shape[1]),
Image.AFFINE,
(1, 0, 0, level, 1, 0))
shear_y = TransformT('ShearY', _shear_y_impl)
def _translate_x_impl(pil_img, level, img_shape):
"""Applies PIL TranslateX to `pil_img`.
Translate the image in the horizontal direction by `level`
number of pixels.
Args:
pil_img: Image in PIL object.
level: Strength of the operation specified as an Integer from
[0, `PARAMETER_MAX`].
Returns:
A PIL Image that has had TranslateX applied to it.
"""
level = int_parameter(level, 10)
if random.random() > 0.5:
level = -level
return pil_img.transform(
(img_shape[0], img_shape[1]),
Image.AFFINE,
(1, 0, level, 0, 1, 0))
translate_x = TransformT('TranslateX', _translate_x_impl)
def _translate_y_impl(pil_img, level, img_shape):
"""Applies PIL TranslateY to `pil_img`.
Translate the image in the vertical direction by `level`
number of pixels.
Args:
pil_img: Image in PIL object.
level: Strength of the operation specified as an Integer from
[0, `PARAMETER_MAX`].
Returns:
A PIL Image that has had TranslateY applied to it.
"""
level = int_parameter(level, 10)
if random.random() > 0.5:
level = -level
return pil_img.transform(
(img_shape[0], img_shape[1]),
Image.AFFINE,
(1, 0, 0, 0, 1, level))
translate_y = TransformT('TranslateY', _translate_y_impl)
def _crop_impl(pil_img, level, img_shape, interpolation=Image.BILINEAR):
"""Applies a crop to `pil_img` with the size depending on the `level`."""
cropped = pil_img.crop((level, level, img_shape[0] - level, img_shape[1] - level))
resized = cropped.resize((img_shape[0], img_shape[1]), interpolation)
return resized
crop_bilinear = TransformT('CropBilinear', _crop_impl)
def _solarize_impl(pil_img, level, _):
"""Applies PIL Solarize to `pil_img`.
Translate the image in the vertical direction by `level`
number of pixels.
Args:
pil_img: Image in PIL object.
level: Strength of the operation specified as an Integer from
[0, `PARAMETER_MAX`].
Returns:
A PIL Image that has had Solarize applied to it.
"""
level = int_parameter(level, 256)
return ImageOps.solarize(pil_img.convert('RGB'), 256 - level).convert('RGBA')
solarize = TransformT('Solarize', _solarize_impl)
def _cutout_pil_impl(pil_img, level, img_shape):
"""Apply cutout to pil_img at the specified level."""
size = int_parameter(level, 20)
if size <= 0:
return pil_img
img_height, img_width, num_channels = (img_shape[0], img_shape[1], 3)
_, upper_coord, lower_coord = (
create_cutout_mask(img_height, img_width, num_channels, size))
pixels = pil_img.load() # create the pixel map
for i in range(upper_coord[0], lower_coord[0]): # for every col:
for j in range(upper_coord[1], lower_coord[1]): # For every row
pixels[i, j] = (125, 122, 113, 0) # set the colour accordingly
return pil_img
cutout = TransformT('Cutout', _cutout_pil_impl)
def _enhancer_impl(enhancer):
"""Sets level to be between 0.1 and 1.8 for ImageEnhance transforms of PIL."""
def impl(pil_img, level, _):
v = float_parameter(level, 1.8) + .1 # going to 0 just destroys it
return enhancer(pil_img).enhance(v)
return impl
color = TransformT('Color', _enhancer_impl(ImageEnhance.Color))
contrast = TransformT('Contrast', _enhancer_impl(ImageEnhance.Contrast))
brightness = TransformT('Brightness', _enhancer_impl(
ImageEnhance.Brightness))
sharpness = TransformT('Sharpness', _enhancer_impl(ImageEnhance.Sharpness))
ALL_TRANSFORMS = [
flip_lr,
flip_ud,
auto_contrast,
equalize,
invert,
rotate,
posterize,
crop_bilinear,
solarize,
color,
contrast,
brightness,
sharpness,
shear_x,
shear_y,
translate_x,
translate_y,
cutout,
blur,
smooth
]
NAME_TO_TRANSFORM = {t.name: t for t in ALL_TRANSFORMS}
TRANSFORM_NAMES = NAME_TO_TRANSFORM.keys()

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,85 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import higher
import time
data_train = torchvision.datasets.CIFAR10("./data", train=True, download=True, transform=torchvision.transforms.ToTensor())
dl_train = torch.utils.data.DataLoader(data_train, batch_size=300, shuffle=True, num_workers=0, pin_memory=False)
class Aug_model(nn.Module):
def __init__(self, model, hyper_param=True):
super(Aug_model, self).__init__()
#### Origin of the issue ? ####
if hyper_param:
self._params = nn.ParameterDict({
"hyper_param": nn.Parameter(torch.Tensor([0.5])),
})
###############################
self._mods = nn.ModuleDict({
'model': model,
})
def forward(self, x):
return self._mods['model'](x) #* self._params['hyper_param']
def __getitem__(self, key):
return self._mods[key]
class Aug_model2(nn.Module): #Slow increase like no hyper_param
def __init__(self, model, hyper_param=True):
super(Aug_model2, self).__init__()
#### Origin of the issue ? ####
if hyper_param:
self._params = nn.ParameterDict({
"hyper_param": nn.Parameter(torch.Tensor([0.5])),
})
###############################
self._mods = nn.ModuleDict({
'model': model,
'fmodel': higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
})
def forward(self, x):
return self._mods['fmodel'](x) * self._params['hyper_param']
def get_diffopt(self, opt, track_higher_grads=True):
return higher.optim.get_diff_optim(opt,
self._mods['model'].parameters(),
fmodel=self._mods['fmodel'],
track_higher_grads=track_higher_grads)
def __getitem__(self, key):
return self._mods[key]
if __name__ == "__main__":
device = torch.device('cuda:1')
aug_model = Aug_model2(
model=torch.hub.load('pytorch/vision:v0.4.2', 'resnet18', pretrained=False),
hyper_param=True #False will not extend step time
).to(device)
inner_opt = torch.optim.SGD(aug_model['model'].parameters(), lr=1e-2, momentum=0.9)
#fmodel = higher.patch.monkeypatch(aug_model, device=None, copy_initial_weights=True)
#diffopt = higher.optim.get_diff_optim(inner_opt, aug_model.parameters(),fmodel=fmodel,track_higher_grads=True)
diffopt = aug_model.get_diffopt(inner_opt)
for i, (xs, ys) in enumerate(dl_train):
xs, ys = xs.to(device), ys.to(device)
#logits = fmodel(xs)
logits = aug_model(xs)
loss = F.cross_entropy(F.log_softmax(logits, dim=1), ys, reduction='mean')
t = time.process_time()
diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)
#print(len(fmodel._fast_params),"step", time.process_time()-t)
print(len(aug_model['fmodel']._fast_params),"step", time.process_time()-t)

View file

@ -0,0 +1,502 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
## Basic CNN ##
class LeNet_F(nn.Module):
def __init__(self, num_inp, num_out):
super(LeNet_F, self).__init__()
self._params = nn.ParameterDict({
'w1': nn.Parameter(torch.zeros(20, num_inp, 5, 5)),
'b1': nn.Parameter(torch.zeros(20)),
'w2': nn.Parameter(torch.zeros(50, 20, 5, 5)),
'b2': nn.Parameter(torch.zeros(50)),
#'w3': nn.Parameter(torch.zeros(500,4*4*50)), #num_imp=1
'w3': nn.Parameter(torch.zeros(500,5*5*50)), #num_imp=3
'b3': nn.Parameter(torch.zeros(500)),
'w4': nn.Parameter(torch.zeros(num_out, 500)),
'b4': nn.Parameter(torch.zeros(num_out))
})
self.initialize()
def initialize(self):
nn.init.kaiming_uniform_(self._params["w1"], a=math.sqrt(5))
nn.init.kaiming_uniform_(self._params["w2"], a=math.sqrt(5))
nn.init.kaiming_uniform_(self._params["w3"], a=math.sqrt(5))
nn.init.kaiming_uniform_(self._params["w4"], a=math.sqrt(5))
def forward(self, x):
#print("Start Shape ", x.shape)
out = F.relu(F.conv2d(input=x, weight=self._params["w1"], bias=self._params["b1"]))
#print("Shape ", out.shape)
out = F.max_pool2d(out, 2)
#print("Shape ", out.shape)
out = F.relu(F.conv2d(input=out, weight=self._params["w2"], bias=self._params["b2"]))
#print("Shape ", out.shape)
out = F.max_pool2d(out, 2)
#print("Shape ", out.shape)
out = out.view(out.size(0), -1)
#print("Shape ", out.shape)
out = F.relu(F.linear(out, self._params["w3"], self._params["b3"]))
#print("Shape ", out.shape)
out = F.linear(out, self._params["w4"], self._params["b4"])
#print("Shape ", out.shape)
#return F.log_softmax(out, dim=1)
return out
def __getitem__(self, key):
return self._params[key]
def __str__(self):
return "LeNet"
## MobileNetv2 ##
def _make_divisible(v, divisor, min_value=None):
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8
It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
:param v:
:param divisor:
:param min_value:
:return:
"""
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
class ConvBNReLU(nn.Sequential):
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
padding = (kernel_size - 1) // 2
super(ConvBNReLU, self).__init__(
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
nn.BatchNorm2d(out_planes),
nn.ReLU6(inplace=True)
)
class InvertedResidual(nn.Module):
def __init__(self, inp, oup, stride, expand_ratio):
super(InvertedResidual, self).__init__()
self.stride = stride
assert stride in [1, 2]
hidden_dim = int(round(inp * expand_ratio))
self.use_res_connect = self.stride == 1 and inp == oup
layers = []
if expand_ratio != 1:
# pw
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
layers.extend([
# dw
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
])
self.conv = nn.Sequential(*layers)
def forward(self, x):
if self.use_res_connect:
return x + self.conv(x)
else:
return self.conv(x)
class MobileNetV2(nn.Module):
def __init__(self,
num_classes=1000,
width_mult=1.0,
inverted_residual_setting=None,
round_nearest=8,
block=None):
"""
MobileNet V2 main class
Args:
num_classes (int): Number of classes
width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
inverted_residual_setting: Network structure
round_nearest (int): Round the number of channels in each layer to be a multiple of this number
Set to 1 to turn off rounding
block: Module specifying inverted residual building block for mobilenet
"""
super(MobileNetV2, self).__init__()
if block is None:
block = InvertedResidual
input_channel = 32
last_channel = 1280
if inverted_residual_setting is None:
inverted_residual_setting = [
# t, c, n, s
[1, 16, 1, 1],
[6, 24, 2, 2],
[6, 32, 3, 2],
[6, 64, 4, 2],
[6, 96, 3, 1],
[6, 160, 3, 2],
[6, 320, 1, 1],
]
# only check the first element, assuming user knows t,c,n,s are required
if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
raise ValueError("inverted_residual_setting should be non-empty "
"or a 4-element list, got {}".format(inverted_residual_setting))
# building first layer
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
features = [ConvBNReLU(3, input_channel, stride=2)]
# building inverted residual blocks
for t, c, n, s in inverted_residual_setting:
output_channel = _make_divisible(c * width_mult, round_nearest)
for i in range(n):
stride = s if i == 0 else 1
features.append(block(input_channel, output_channel, stride, expand_ratio=t))
input_channel = output_channel
# building last several layers
features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
# make it nn.Sequential
self.features = nn.Sequential(*features)
# building classifier
self.classifier = nn.Sequential(
nn.Dropout(0.2),
nn.Linear(self.last_channel, num_classes),
)
# weight initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.zeros_(m.bias)
def _forward_impl(self, x):
# This exists since TorchScript doesn't support inheritance, so the superclass method
# (this one) needs to have a name other than `forward` that can be accessed in a subclass
x = self.features(x)
x = x.mean([2, 3])
x = self.classifier(x)
return x
def forward(self, x):
return self._forward_impl(x)
def __str__(self):
return "MobileNetV2"
## ResNet ##
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, groups=groups, bias=False, dilation=dilation)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class BasicBlock(nn.Module):
expansion = 1
__constants__ = ['downsample']
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
__constants__ = ['downsample']
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(Bottleneck, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, stride, groups, dilation)
self.bn2 = norm_layer(width)
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
#ResNet18 : block=BasicBlock, layers=[2, 2, 2, 2]
class ResNet(nn.Module):
def __init__(self, block=BasicBlock, layers=[2, 2, 2, 2], num_classes=1000, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None,
norm_layer=None):
super(ResNet, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
self.inplanes = 64
self.dilation = 1
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError("replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
dilate=replace_stride_with_dilation[0])
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
dilate=replace_stride_with_dilation[1])
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
dilate=replace_stride_with_dilation[2])
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
norm_layer(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
self.base_width, previous_dilation, norm_layer))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, groups=self.groups,
base_width=self.base_width, dilation=self.dilation,
norm_layer=norm_layer))
return nn.Sequential(*layers)
def _forward_impl(self, x):
# See note [TorchScript super()]
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
def forward(self, x):
return self._forward_impl(x)
def __str__(self):
return "ResNet18"
## Wide ResNet ##
#https://github.com/xternalz/WideResNet-pytorch/blob/master/wideresnet.py
#https://github.com/arcelien/pba/blob/master/pba/wrn.py
#https://github.com/szagoruyko/wide-residual-networks/blob/master/pytorch/resnet.py
'''
class BasicBlock(nn.Module):
def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
super(BasicBlock, self).__init__()
self.bn1 = nn.BatchNorm2d(in_planes)
self.relu1 = nn.ReLU(inplace=True)
self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_planes)
self.relu2 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
padding=1, bias=False)
self.droprate = dropRate
self.equalInOut = (in_planes == out_planes)
self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
padding=0, bias=False) or None
def forward(self, x):
if not self.equalInOut:
x = self.relu1(self.bn1(x))
else:
out = self.relu1(self.bn1(x))
out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
if self.droprate > 0:
out = F.dropout(out, p=self.droprate, training=self.training)
out = self.conv2(out)
return torch.add(x if self.equalInOut else self.convShortcut(x), out)
class NetworkBlock(nn.Module):
def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
super(NetworkBlock, self).__init__()
self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)
def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
layers = []
for i in range(int(nb_layers)):
layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate))
return nn.Sequential(*layers)
def forward(self, x):
return self.layer(x)
#wrn_size: 32 = WRN-28-2 ? 160 = WRN-28-10
class WideResNet(nn.Module):
#def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0):
def __init__(self, num_classes, wrn_size, depth=28, dropRate=0.0):
super(WideResNet, self).__init__()
self.kernel_size = wrn_size
self.depth=depth
filter_size = 3
nChannels = [min(self.kernel_size, 16), self.kernel_size, self.kernel_size * 2, self.kernel_size * 4]
strides = [1, 2, 2] # stride for each resblock
#nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
assert((depth - 4) % 6 == 0)
n = (depth - 4) / 6
block = BasicBlock
# 1st conv before any network block
self.conv1 = nn.Conv2d(filter_size, nChannels[0], kernel_size=3, stride=1,
padding=1, bias=False)
# 1st block
self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, strides[0], dropRate)
# 2nd block
self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, strides[1], dropRate)
# 3rd block
self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, strides[2], dropRate)
# global average pooling and classifier
self.bn1 = nn.BatchNorm2d(nChannels[3])
self.relu = nn.ReLU(inplace=True)
self.fc = nn.Linear(nChannels[3], num_classes)
self.nChannels = nChannels[3]
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.bias.data.zero_()
def forward(self, x):
out = self.conv1(x)
out = self.block1(out)
out = self.block2(out)
out = self.block3(out)
out = self.relu(self.bn1(out))
out = F.avg_pool2d(out, 8)
out = out.view(-1, self.nChannels)
return self.fc(out)
def architecture(self):
return super(WideResNet, self).__str__()
def __str__(self):
return "WideResNet(s{}-d{})".format(self.kernel_size, self.depth)
'''

150
higher/smart_aug/old/test_lr.py Executable file
View file

@ -0,0 +1,150 @@
import numpy as np
import json, math, time, os
from torch.utils.data import SubsetRandomSampler
import torch.optim as optim
import higher
from model import *
import copy
BATCH_SIZE = 300
TEST_SIZE = 300
mnist_train = torchvision.datasets.MNIST(
"./data", train=True, download=True,
transform=torchvision.transforms.Compose([
#torchvision.transforms.RandomAffine(degrees=180, translate=None, scale=None, shear=None, resample=False, fillcolor=0),
torchvision.transforms.ToTensor()
])
)
mnist_test = torchvision.datasets.MNIST(
"./data", train=False, download=True, transform=torchvision.transforms.ToTensor()
)
#train_subset_indices=range(int(len(mnist_train)/2))
train_subset_indices=range(BATCH_SIZE)
val_subset_indices=range(int(len(mnist_train)/2),len(mnist_train))
dl_train = torch.utils.data.DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices))
dl_val = torch.utils.data.DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(val_subset_indices))
dl_test = torch.utils.data.DataLoader(mnist_test, batch_size=TEST_SIZE, shuffle=False)
def test(model):
model.eval()
for i, (features, labels) in enumerate(dl_test):
pred = model.forward(features)
return pred.argmax(dim=1).eq(labels).sum().item() / TEST_SIZE * 100
def train_classic(model, optim, epochs=1):
model.train()
log = []
for epoch in range(epochs):
t0 = time.process_time()
for i, (features, labels) in enumerate(dl_train):
optim.zero_grad()
pred = model.forward(features)
loss = F.cross_entropy(pred,labels)
loss.backward()
optim.step()
#### Log ####
tf = time.process_time()
data={
"time": tf - t0,
}
log.append(data)
times = [x["time"] for x in log]
print("Vanilla : acc", test(model), "in (ms):", np.mean(times), "+/-", np.std(times))
##########################################
if __name__ == "__main__":
device = torch.device('cpu')
model = LeNet(1,10)
opt_param = {
"lr": torch.tensor(1e-2).requires_grad_(),
"momentum": torch.tensor(0.9).requires_grad_()
}
n_inner_iter = 1
dl_train_it = iter(dl_train)
dl_val_it = iter(dl_val)
epoch = 0
epochs = 10
####
train_classic(model=model, optim=torch.optim.Adam(model.parameters(), lr=0.001), epochs=epochs)
model = LeNet(1,10)
meta_opt = torch.optim.Adam(opt_param.values(), lr=1e-2)
inner_opt = torch.optim.SGD(model.parameters(), lr=opt_param['lr'], momentum=opt_param['momentum'])
#for xs_val, ys_val in dl_val:
while epoch < epochs:
#print(data_aug.params["mag"], data_aug.params["mag"].grad)
meta_opt.zero_grad()
model.train()
with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=True, track_higher_grads=True) as (fmodel, diffopt): #effet copy_initial_weight pas clair...
for param_group in diffopt.param_groups:
param_group['lr'] = opt_param['lr']
param_group['momentum'] = opt_param['momentum']
for i in range(n_inner_iter):
try:
xs, ys = next(dl_train_it)
except StopIteration: #Fin epoch train
epoch +=1
dl_train_it = iter(dl_train)
xs, ys = next(dl_train_it)
print('Epoch', epoch)
print('train loss',loss.item(), '/ val loss', val_loss.item())
print('acc', test(model))
print('opt : lr', opt_param['lr'].item(), 'momentum', opt_param['momentum'].item())
print('-'*9)
model.train()
logits = fmodel(xs) # modified `params` can also be passed as a kwarg
loss = F.cross_entropy(logits, ys) # no need to call loss.backwards()
#print('loss',loss.item())
diffopt.step(loss) # note that `step` must take `loss` as an argument!
# The line above gets P[t+1] from P[t] and loss[t]. `step` also returns
# these new parameters, as an alternative to getting them from
# `fmodel.fast_params` or `fmodel.parameters()` after calling
# `diffopt.step`.
# At this point, or at any point in the iteration, you can take the
# gradient of `fmodel.parameters()` (or equivalently
# `fmodel.fast_params`) w.r.t. `fmodel.parameters(time=0)` (equivalently
# `fmodel.init_fast_params`). i.e. `fast_params` will always have
# `grad_fn` as an attribute, and be part of the gradient tape.
# At the end of your inner loop you can obtain these e.g. ...
#grad_of_grads = torch.autograd.grad(
# meta_loss_fn(fmodel.parameters()), fmodel.parameters(time=0))
try:
xs_val, ys_val = next(dl_val_it)
except StopIteration: #Fin epoch val
dl_val_it = iter(dl_val_it)
xs_val, ys_val = next(dl_val_it)
val_logits = fmodel(xs_val)
val_loss = F.cross_entropy(val_logits, ys_val)
#print('val_loss',val_loss.item())
val_loss.backward()
#meta_grads = torch.autograd.grad(val_loss, opt_lr, allow_unused=True)
#print(meta_grads)
for param_group in diffopt.param_groups:
print(param_group['lr'], '/',param_group['lr'].grad)
print(param_group['momentum'], '/',param_group['momentum'].grad)
#model=copy.deepcopy(fmodel)
model.load_state_dict(fmodel.state_dict())
meta_opt.step()

View file

@ -0,0 +1,866 @@
import torch
#import torch.optim
import torchvision
import higher
from datasets import *
from utils import *
def train_classic_higher(model, epochs=1):
device = next(model.parameters()).device
#opt = torch.optim.Adam(model.parameters(), lr=1e-3)
optim = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
model.train()
dl_val_it = iter(dl_val)
log = []
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
diffopt = higher.optim.get_diff_optim(optim, model.parameters(),fmodel=fmodel,track_higher_grads=False)
#with higher.innerloop_ctx(model, optim, copy_initial_weights=True, track_higher_grads=False) as (fmodel, diffopt):
for epoch in range(epochs):
#print_torch_mem("Start epoch "+str(epoch))
#print("Fast param ",len(fmodel._fast_params))
t0 = time.process_time()
for i, (features, labels) in enumerate(dl_train):
#print_torch_mem("Start iter")
features,labels = features.to(device), labels.to(device)
#optim.zero_grad()
logits = model.forward(features)
pred = F.log_softmax(logits, dim=1)
loss = F.cross_entropy(pred,labels)
#.backward()
#optim.step()
diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)
model_copy(src=fmodel, dst=model, patch_copy=False)
optim_copy(dopt=diffopt, opt=optim)
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
diffopt = higher.optim.get_diff_optim(optim, model.parameters(),fmodel=fmodel,track_higher_grads=False)
#### Tests ####
tf = time.process_time()
try:
xs_val, ys_val = next(dl_val_it)
except StopIteration: #Fin epoch val
dl_val_it = iter(dl_val)
xs_val, ys_val = next(dl_val_it)
xs_val, ys_val = xs_val.to(device), ys_val.to(device)
val_loss = F.cross_entropy(model(xs_val), ys_val)
accuracy, _ =test(model)
model.train()
#### Log ####
data={
"epoch": epoch,
"train_loss": loss.item(),
"val_loss": val_loss.item(),
"acc": accuracy,
"time": tf - t0,
"param": None,
}
log.append(data)
return log
def train_classic_tests(model, epochs=1):
device = next(model.parameters()).device
#opt = torch.optim.Adam(model.parameters(), lr=1e-3)
optim = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
countcopy=0
model.train()
dl_val_it = iter(dl_val)
log = []
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
doptim = higher.optim.get_diff_optim(optim, model.parameters(), fmodel=fmodel, track_higher_grads=False)
for epoch in range(epochs):
print_torch_mem("Start epoch")
print(len(fmodel._fast_params))
t0 = time.process_time()
#with higher.innerloop_ctx(model, optim, copy_initial_weights=True, track_higher_grads=True) as (fmodel, doptim):
#fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
#doptim = higher.optim.get_diff_optim(optim, model.parameters(), track_higher_grads=True)
for i, (features, labels) in enumerate(dl_train):
features,labels = features.to(device), labels.to(device)
#with higher.innerloop_ctx(model, optim, copy_initial_weights=True, track_higher_grads=False) as (fmodel, doptim):
#optim.zero_grad()
pred = fmodel.forward(features)
loss = F.cross_entropy(pred,labels)
doptim.step(loss) #(opt.zero_grad, loss.backward, opt.step)
#loss.backward()
#new_params = doptim.step(loss, params=fmodel.parameters())
#fmodel.update_params(new_params)
#print('Fast param',len(fmodel._fast_params))
#print('opt state', type(doptim.state[0][0]['momentum_buffer']), doptim.state[0][2]['momentum_buffer'].shape)
if False or (len(fmodel._fast_params)>1):
print("fmodel fast param",len(fmodel._fast_params))
'''
#val_loss = F.cross_entropy(fmodel(features), labels)
#print_graph(val_loss)
#val_loss.backward()
#print('bip')
tmp = fmodel.parameters()
#print(list(tmp)[1])
tmp = [higher.utils._copy_tensor(t,safe_copy=True) if isinstance(t, torch.Tensor) else t for t in tmp]
#print(len(tmp))
#fmodel._fast_params.clear()
del fmodel._fast_params
fmodel._fast_params=None
fmodel.fast_params=tmp # Surcharge la memoire
#fmodel.update_params(tmp) #Meilleur perf / Surcharge la memoire avec trach higher grad
#optim._fmodel=fmodel
'''
countcopy+=1
model_copy(src=fmodel, dst=model, patch_copy=False)
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
#doptim.detach_dyn()
#tmp = doptim.state
#tmp = doptim.state_dict()
#for k, v in tmp['state'].items():
# print('dict',k, type(v))
a = optim.param_groups[0]['params'][0]
state = optim.state[a]
#state['momentum_buffer'] = None
#print('opt state', type(optim.state[a]), len(optim.state[a]))
#optim.load_state_dict(tmp)
for group_idx, group in enumerate(optim.param_groups):
# print('gp idx',group_idx)
for p_idx, p in enumerate(group['params']):
optim.state[p]=doptim.state[group_idx][p_idx]
#print('opt state', type(optim.state[a]['momentum_buffer']), optim.state[a]['momentum_buffer'][0:10])
#print('dopt state', type(doptim.state[0][0]['momentum_buffer']), doptim.state[0][0]['momentum_buffer'][0:10])
'''
for a in tmp:
#print(type(a), len(a))
for nb, b in a.items():
#print(nb, type(b), len(b))
for n, state in b.items():
#print(n, type(states))
#print(state.grad_fn)
state = torch.tensor(state.data).requires_grad_()
#print(state.grad_fn)
'''
doptim = higher.optim.get_diff_optim(optim, model.parameters(), track_higher_grads=True)
#doptim.state = tmp
countcopy+=1
model_copy(src=fmodel, dst=model)
optim_copy(dopt=diffopt, opt=inner_opt)
#### Tests ####
tf = time.process_time()
try:
xs_val, ys_val = next(dl_val_it)
except StopIteration: #Fin epoch val
dl_val_it = iter(dl_val)
xs_val, ys_val = next(dl_val_it)
xs_val, ys_val = xs_val.to(device), ys_val.to(device)
val_loss = F.cross_entropy(model(xs_val), ys_val)
accuracy, _ =test(model)
model.train()
#### Log ####
data={
"epoch": epoch,
"train_loss": loss.item(),
"val_loss": val_loss.item(),
"acc": accuracy,
"time": tf - t0,
"param": None,
}
log.append(data)
#countcopy+=1
#model_copy(src=fmodel, dst=model, patch_copy=False)
#optim.load_state_dict(doptim.state_dict()) #Besoin sauver etat otpim ?
print("Copy ", countcopy)
return log
from torchvision.datasets.vision import VisionDataset
from PIL import Image
import augmentation_transforms
import numpy as np
class AugmentedDatasetV2(VisionDataset):
def __init__(self, root, train=True, transform=None, target_transform=None, download=False, subset=None):
super(AugmentedDatasetV2, self).__init__(root, transform=transform, target_transform=target_transform)
supervised_dataset = torchvision.datasets.CIFAR10(root, train=train, download=download, transform=transform)
self.sup_data = supervised_dataset.data if not subset else supervised_dataset.data[subset[0]:subset[1]]
self.sup_targets = supervised_dataset.targets if not subset else supervised_dataset.targets[subset[0]:subset[1]]
assert len(self.sup_data)==len(self.sup_targets)
for idx, img in enumerate(self.sup_data):
self.sup_data[idx]= Image.fromarray(img) #to PIL Image
self.unsup_data=[]
self.unsup_targets=[]
self.origin_idx=[]
self.dataset_info= {
'name': 'CIFAR10',
'sup': len(self.sup_data),
'unsup': len(self.unsup_data),
'length': len(self.sup_data)+len(self.unsup_data),
}
self._TF = [
## Geometric TF ##
'Rotate',
'TranslateX',
'TranslateY',
'ShearX',
'ShearY',
'Cutout',
## Color TF ##
'Contrast',
'Color',
'Brightness',
'Sharpness',
'Posterize',
'Solarize',
'Invert',
'AutoContrast',
'Equalize',
]
self._op_list =[]
self.prob=0.5
self.mag_range=(1, 10)
for tf in self._TF:
for mag in range(self.mag_range[0], self.mag_range[1]):
self._op_list+=[(tf, self.prob, mag)]
self._nb_op = len(self._op_list)
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
aug_img, origin_img, target = self.unsup_data[index], self.sup_data[self.origin_idx[index]], self.unsup_targets[index]
# doing this so that it is consistent with all other datasets
# to return a PIL Image
#img = Image.fromarray(img)
if self.transform is not None:
aug_img = self.transform(aug_img)
origin_img = self.transform(origin_img)
if self.target_transform is not None:
target = self.target_transform(target)
return aug_img, origin_img, target
def augement_data(self, aug_copy=1):
policies = []
for op_1 in self._op_list:
for op_2 in self._op_list:
policies += [[op_1, op_2]]
for idx, image in enumerate(self.sup_data):
if idx%(self.dataset_info['sup']/5)==0: print("Augmenting data... ", idx,"/", self.dataset_info['sup'])
#if idx==10000:break
for _ in range(aug_copy):
chosen_policy = policies[np.random.choice(len(policies))]
aug_image = augmentation_transforms.apply_policy(chosen_policy, image, use_mean_std=False) #Cast en float image
#aug_image = augmentation_transforms.cutout_numpy(aug_image)
self.unsup_data+=[(aug_image*255.).astype(self.sup_data.dtype)]#Cast float image to uint8
self.unsup_targets+=[self.sup_targets[idx]]
self.origin_idx+=[idx]
#self.unsup_data=(np.array(self.unsup_data)*255.).astype(self.sup_data.dtype) #Cast float image to uint8
self.unsup_data=np.array(self.unsup_data)
assert len(self.unsup_data)==len(self.unsup_targets)
self.dataset_info['unsup']=len(self.unsup_data)
self.dataset_info['length']=self.dataset_info['sup']+self.dataset_info['unsup']
def __len__(self):
return self.dataset_info['unsup']#self.dataset_info['length']
def __str__(self):
return "CIFAR10(Sup:{}-Unsup:{}-{}TF(Mag{}-{}))".format(self.dataset_info['sup'], self.dataset_info['unsup'], len(self._TF), self.mag_range[0], self.mag_range[1])
def train_UDA(model, dl_unsup, opt_param, epochs=1, print_freq=1):
"""Training of a model using UDA inspired approach.
Intended to be used alongside an already augmented dataset (see AugmentedDatasetV2).
Args:
model (nn.Module): Model to train.
dl_unsup (Dataloader): Data loader of unsupervised/augmented data.
opt_param (dict): Dictionnary containing optimizers parameters.
epochs (int): Number of epochs to perform. (default: 1)
print_freq (int): Number of epoch between display of the state of training. If set to None, no display will be done. (default:1)
Returns:
(list) Logs of training. Each items is a dict containing results of an epoch.
"""
device = next(model.parameters()).device
#opt = torch.optim.Adam(model.parameters(), lr=1e-3)
opt = torch.optim.SGD(model.parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
model.train()
dl_val_it = iter(dl_val)
dl_unsup_it =iter(dl_unsup)
log = []
for epoch in range(epochs):
#print_torch_mem("Start epoch")
t0 = time.process_time()
for i, (features, labels) in enumerate(dl_train):
#print_torch_mem("Start iter")
features,labels = features.to(device), labels.to(device)
optim.zero_grad()
#Supervised
logits = model.forward(features)
pred = F.log_softmax(logits, dim=1)
sup_loss = F.cross_entropy(pred,labels)
#Unsupervised
try:
aug_xs, origin_xs, ys = next(dl_unsup_it)
except StopIteration: #Fin epoch val
dl_unsup_it =iter(dl_unsup)
aug_xs, origin_xs, ys = next(dl_unsup_it)
aug_xs, origin_xs, ys = aug_xs.to(device), origin_xs.to(device), ys.to(device)
#print(aug_xs.shape, origin_xs.shape, ys.shape)
sup_logits = model.forward(origin_xs)
unsup_logits = model.forward(aug_xs)
log_sup=F.log_softmax(sup_logits, dim=1)
log_unsup=F.log_softmax(unsup_logits, dim=1)
#KL div w/ logits
unsup_loss = F.softmax(sup_logits, dim=1)*(log_sup-log_unsup)
unsup_loss=unsup_loss.sum(dim=-1).mean()
#print(unsup_loss)
unsupp_coeff = 1
loss = sup_loss + unsup_loss * unsupp_coeff
loss.backward()
optim.step()
#### Tests ####
tf = time.process_time()
try:
xs_val, ys_val = next(dl_val_it)
except StopIteration: #Fin epoch val
dl_val_it = iter(dl_val)
xs_val, ys_val = next(dl_val_it)
xs_val, ys_val = xs_val.to(device), ys_val.to(device)
val_loss = F.cross_entropy(model(xs_val), ys_val)
accuracy, _ =test(model)
model.train()
#### Print ####
if(print_freq and epoch%print_freq==0):
print('-'*9)
print('Epoch : %d/%d'%(epoch,epochs))
print('Time : %.00f'%(tf - t0))
print('Train loss :',loss.item(), '/ val loss', val_loss.item())
print('Sup Loss :', sup_loss.item(), '/ unsup_loss :', unsup_loss.item())
print('Accuracy :', accuracy)
#### Log ####
data={
"epoch": epoch,
"train_loss": loss.item(),
"val_loss": val_loss.item(),
"acc": accuracy,
"time": tf - t0,
"param": None,
}
log.append(data)
return log
def run_simple_dataug(inner_it, epochs=1):
device = next(model.parameters()).device
dl_train_it = iter(dl_train)
dl_val_it = iter(dl_val)
#aug_model = nn.Sequential(
# Data_aug(),
# LeNet(1,10),
# )
aug_model = Augmented_model(Data_aug(), LeNet(1,10)).to(device)
print(str(aug_model))
meta_opt = torch.optim.Adam(aug_model['data_aug'].parameters(), lr=1e-2)
inner_opt = torch.optim.SGD(aug_model['model'].parameters(), lr=1e-2, momentum=0.9)
log = []
t0 = time.process_time()
epoch = 0
while epoch < epochs:
meta_opt.zero_grad()
aug_model.train()
with higher.innerloop_ctx(aug_model, inner_opt, copy_initial_weights=True, track_higher_grads=True) as (fmodel, diffopt): #effet copy_initial_weight pas clair...
for i in range(n_inner_iter):
try:
xs, ys = next(dl_train_it)
except StopIteration: #Fin epoch train
tf = time.process_time()
epoch +=1
dl_train_it = iter(dl_train)
xs, ys = next(dl_train_it)
accuracy, _ =test(model)
aug_model.train()
#### Print ####
print('-'*9)
print('Epoch %d/%d'%(epoch,epochs))
print('train loss',loss.item(), '/ val loss', val_loss.item())
print('acc', accuracy)
print('mag', aug_model['data_aug']['mag'].item())
#### Log ####
data={
"epoch": epoch,
"train_loss": loss.item(),
"val_loss": val_loss.item(),
"acc": accuracy,
"time": tf - t0,
"param": aug_model['data_aug']['mag'].item(),
}
log.append(data)
t0 = time.process_time()
xs, ys = xs.to(device), ys.to(device)
logits = fmodel(xs) # modified `params` can also be passed as a kwarg
loss = F.cross_entropy(logits, ys) # no need to call loss.backwards()
#loss.backward(retain_graph=True)
#print(fmodel['model']._params['b4'].grad)
#print('mag', fmodel['data_aug']['mag'].grad)
diffopt.step(loss) # note that `step` must take `loss` as an argument!
# The line above gets P[t+1] from P[t] and loss[t]. `step` also returns
# these new parameters, as an alternative to getting them from
# `fmodel.fast_params` or `fmodel.parameters()` after calling
# `diffopt.step`.
# At this point, or at any point in the iteration, you can take the
# gradient of `fmodel.parameters()` (or equivalently
# `fmodel.fast_params`) w.r.t. `fmodel.parameters(time=0)` (equivalently
# `fmodel.init_fast_params`). i.e. `fast_params` will always have
# `grad_fn` as an attribute, and be part of the gradient tape.
# At the end of your inner loop you can obtain these e.g. ...
#grad_of_grads = torch.autograd.grad(
# meta_loss_fn(fmodel.parameters()), fmodel.parameters(time=0))
try:
xs_val, ys_val = next(dl_val_it)
except StopIteration: #Fin epoch val
dl_val_it = iter(dl_val)
xs_val, ys_val = next(dl_val_it)
xs_val, ys_val = xs_val.to(device), ys_val.to(device)
fmodel.augment(mode=False)
val_logits = fmodel(xs_val) #Validation sans transfornations !
val_loss = F.cross_entropy(val_logits, ys_val)
#print('val_loss',val_loss.item())
val_loss.backward()
#print('mag', fmodel['data_aug']['mag'], '/', fmodel['data_aug']['mag'].grad)
#model=copy.deepcopy(fmodel)
aug_model.load_state_dict(fmodel.state_dict()) #Do not copy gradient !
#Copie des gradients
for paramName, paramValue, in fmodel.named_parameters():
for netCopyName, netCopyValue, in aug_model.named_parameters():
if paramName == netCopyName:
netCopyValue.grad = paramValue.grad
#print('mag', aug_model['data_aug']['mag'], '/', aug_model['data_aug']['mag'].grad)
meta_opt.step()
plot_res(log, fig_name="res/{}-{} epochs- {} in_it".format(str(aug_model),epochs,inner_it))
print('-'*9)
times = [x["time"] for x in log]
print(str(aug_model),": acc", max([x["acc"] for x in log]), "in (ms):", np.mean(times), "+/-", np.std(times))
def run_dist_dataug(model, epochs=1, inner_it=1, dataug_epoch_start=0):
device = next(model.parameters()).device
dl_train_it = iter(dl_train)
dl_val_it = iter(dl_val)
meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=1e-3)
inner_opt = torch.optim.SGD(model['model'].parameters(), lr=1e-2, momentum=0.9)
high_grad_track = True
if dataug_epoch_start>0:
model.augment(mode=False)
high_grad_track = False
model.train()
log = []
t0 = time.process_time()
countcopy=0
val_loss=torch.tensor(0)
opt_param=None
epoch = 0
while epoch < epochs:
meta_opt.zero_grad()
with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=True, override=opt_param, track_higher_grads=high_grad_track) as (fmodel, diffopt): #effet copy_initial_weight pas clair...
for i in range(n_inner_iter):
try:
xs, ys = next(dl_train_it)
except StopIteration: #Fin epoch train
tf = time.process_time()
epoch +=1
dl_train_it = iter(dl_train)
xs, ys = next(dl_train_it)
#viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
#viz_sample_data(imgs=aug_model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
accuracy, _ =test(model)
model.train()
#### Print ####
print('-'*9)
print('Epoch : %d/%d'%(epoch,epochs))
print('Train loss :',loss.item(), '/ val loss', val_loss.item())
print('Accuracy :', accuracy)
print('Data Augmention : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start))
print('TF Proba :', model['data_aug']['prob'].data)
#print('proba grad',aug_model['data_aug']['prob'].grad)
#############
#### Log ####
data={
"epoch": epoch,
"train_loss": loss.item(),
"val_loss": val_loss.item(),
"acc": accuracy,
"time": tf - t0,
"param": [p for p in model['data_aug']['prob']],
}
log.append(data)
#############
if epoch == dataug_epoch_start:
print('Starting Data Augmention...')
model.augment(mode=True)
high_grad_track = True
t0 = time.process_time()
xs, ys = xs.to(device), ys.to(device)
'''
#Methode exacte
final_loss = 0
for tf_idx in range(fmodel['data_aug']._nb_tf):
fmodel['data_aug'].transf_idx=tf_idx
logits = fmodel(xs)
loss = F.cross_entropy(logits, ys)
#loss.backward(retain_graph=True)
#print('idx', tf_idx)
#print(fmodel['data_aug']['prob'][tf_idx], fmodel['data_aug']['prob'][tf_idx].grad)
final_loss += loss*fmodel['data_aug']['prob'][tf_idx] #Take it in the forward function ?
loss = final_loss
'''
#Methode uniforme
logits = fmodel(xs) # modified `params` can also be passed as a kwarg
loss = F.cross_entropy(logits, ys, reduction='none') # no need to call loss.backwards()
if fmodel._data_augmentation: #Weight loss
w_loss = fmodel['data_aug'].loss_weight().to(device)
loss = loss * w_loss
loss = loss.mean()
#'''
#to visualize computational graph
#print_graph(loss)
#loss.backward(retain_graph=True)
#print(fmodel['model']._params['b4'].grad)
#print('prob grad', fmodel['data_aug']['prob'].grad)
diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)
try:
xs_val, ys_val = next(dl_val_it)
except StopIteration: #Fin epoch val
dl_val_it = iter(dl_val)
xs_val, ys_val = next(dl_val_it)
xs_val, ys_val = xs_val.to(device), ys_val.to(device)
fmodel.augment(mode=False) #Validation sans transfornations !
val_loss = F.cross_entropy(fmodel(xs_val), ys_val)
#print_graph(val_loss)
val_loss.backward()
countcopy+=1
model_copy(src=fmodel, dst=model)
optim_copy(dopt=diffopt, opt=inner_opt)
meta_opt.step()
model['data_aug'].adjust_param() #Contrainte sum(proba)=1
print("Copy ", countcopy)
return log
def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start=0, print_freq=1, KLdiv=False, loss_patience=None, save_sample=False):
device = next(model.parameters()).device
log = []
countcopy=0
val_loss=torch.tensor(0) #Necessaire si pas de metastep sur une epoch
dl_val_it = iter(dl_val)
#if inner_it!=0:
meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=opt_param['Meta']['lr']) #lr=1e-2
inner_opt = torch.optim.SGD(model['model'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
high_grad_track = True
if inner_it == 0:
high_grad_track=False
if dataug_epoch_start!=0:
model.augment(mode=False)
high_grad_track = False
val_loss_monitor= None
if loss_patience != None :
if dataug_epoch_start==-1: val_loss_monitor = loss_monitor(patience=loss_patience, end_train=2) #1st limit = dataug start
else: val_loss_monitor = loss_monitor(patience=loss_patience) #Val loss monitor (Not on val data : used by Dataug... => Test data)
model.train()
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
meta_opt.zero_grad()
for epoch in range(1, epochs+1):
#print_torch_mem("Start epoch "+str(epoch))
#print(high_grad_track, fmodel._data_augmentation, len(fmodel._fast_params))
t0 = time.process_time()
#with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=True, override=opt_param, track_higher_grads=high_grad_track) as (fmodel, diffopt):
for i, (xs, ys) in enumerate(dl_train):
xs, ys = xs.to(device), ys.to(device)
#Methode exacte
#final_loss = 0
#for tf_idx in range(fmodel['data_aug']._nb_tf):
# fmodel['data_aug'].transf_idx=tf_idx
# logits = fmodel(xs)
# loss = F.cross_entropy(logits, ys)
# #loss.backward(retain_graph=True)
# final_loss += loss*fmodel['data_aug']['prob'][tf_idx] #Take it in the forward function ?
#loss = final_loss
if(not KLdiv):
#Methode uniforme
logits = fmodel(xs) # modified `params` can also be passed as a kwarg
loss = F.cross_entropy(F.log_softmax(logits, dim=1), ys, reduction='none') # no need to call loss.backwards()
if fmodel._data_augmentation: #Weight loss
w_loss = fmodel['data_aug'].loss_weight()#.to(device)
loss = loss * w_loss
loss = loss.mean()
else:
#Methode KL div
if fmodel._data_augmentation :
fmodel.augment(mode=False)
sup_logits = fmodel(xs)
fmodel.augment(mode=True)
else:
sup_logits = fmodel(xs)
log_sup=F.log_softmax(sup_logits, dim=1)
loss = F.cross_entropy(log_sup, ys)
if fmodel._data_augmentation:
aug_logits = fmodel(xs)
log_aug=F.log_softmax(aug_logits, dim=1)
w_loss = fmodel['data_aug'].loss_weight() #Weight loss
#if epoch>50: #debut differe ?
#KL div w/ logits - Similarite predictions (distributions)
aug_loss = F.softmax(sup_logits, dim=1)*(log_sup-log_aug)
aug_loss = aug_loss.sum(dim=-1)
#aug_loss = F.kl_div(aug_logits, sup_logits, reduction='none')
aug_loss = (w_loss * aug_loss).mean()
aug_loss += (F.cross_entropy(log_aug, ys , reduction='none') * w_loss).mean()
unsupp_coeff = 1
loss += aug_loss * unsupp_coeff
#to visualize computational graph
#print_graph(loss)
#loss.backward(retain_graph=True)
#print(fmodel['model']._params['b4'].grad)
#print('prob grad', fmodel['data_aug']['prob'].grad)
#t = time.process_time()
diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)
#print(len(fmodel._fast_params),"step", time.process_time()-t)
if(high_grad_track and i>0 and i%inner_it==0): #Perform Meta step
#print("meta")
val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val) #+ fmodel['data_aug'].reg_loss()
#print_graph(val_loss)
#t = time.process_time()
val_loss.backward()
#print("meta", time.process_time()-t)
#print('proba grad',model['data_aug']['prob'].grad)
if model['data_aug']['prob'].grad is None or model['data_aug']['mag'] is None:
print("Warning no grad (iter",i,") :\n Prob-",model['data_aug']['prob'].grad,"\n Mag-", model['data_aug']['mag'].grad)
countcopy+=1
model_copy(src=fmodel, dst=model)
optim_copy(dopt=diffopt, opt=inner_opt)
torch.nn.utils.clip_grad_norm_(model['data_aug'].parameters(), max_norm=10, norm_type=2) #Prevent exploding grad with RNN
#if epoch>50:
meta_opt.step()
model['data_aug'].adjust_param(soft=False) #Contrainte sum(proba)=1
try: #Dataugv6
model['data_aug'].next_TF_set()
except:
pass
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
meta_opt.zero_grad()
tf = time.process_time()
#viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
#viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch), weight_labels=model['data_aug'].loss_weight())
if(not high_grad_track):
countcopy+=1
model_copy(src=fmodel, dst=model)
optim_copy(dopt=diffopt, opt=inner_opt)
val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val)
#Necessaire pour reset higher (Accumule les fast_param meme avec track_higher_grads = False)
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
accuracy, test_loss =test(model)
model.train()
#### Log ####
#print(type(model['data_aug']) is dataug.Data_augV5)
param = [{'p': p.item(), 'm':model['data_aug']['mag'].item()} for p in model['data_aug']['prob']] if model['data_aug']._shared_mag else [{'p': p.item(), 'm': m.item()} for p, m in zip(model['data_aug']['prob'], model['data_aug']['mag'])]
data={
"epoch": epoch,
"train_loss": loss.item(),
"val_loss": val_loss.item(),
"acc": accuracy,
"time": tf - t0,
"param": param #if isinstance(model['data_aug'], Data_augV5)
#else [p.item() for p in model['data_aug']['prob']],
}
log.append(data)
#############
#### Print ####
if(print_freq and epoch%print_freq==0):
print('-'*9)
print('Epoch : %d/%d'%(epoch,epochs))
print('Time : %.00f'%(tf - t0))
print('Train loss :',loss.item(), '/ val loss', val_loss.item())
print('Accuracy :', max([x["acc"] for x in log]))
print('Data Augmention : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start))
print('TF Proba :', model['data_aug']['prob'].data)
#print('proba grad',model['data_aug']['prob'].grad)
print('TF Mag :', model['data_aug']['mag'].data)
#print('Mag grad',model['data_aug']['mag'].grad)
#print('Reg loss:', model['data_aug'].reg_loss().item())
#print('Aug loss', aug_loss.item())
#############
if val_loss_monitor :
model.eval()
val_loss_monitor.register(test_loss)#val_loss.item())
if val_loss_monitor.end_training(): break #Stop training
model.train()
if not model.is_augmenting() and (epoch == dataug_epoch_start or (val_loss_monitor and val_loss_monitor.limit_reached()==1)):
print('Starting Data Augmention...')
dataug_epoch_start = epoch
model.augment(mode=True)
if inner_it != 0: high_grad_track = True
try:
viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch), weight_labels=model['data_aug'].loss_weight())
except:
print("Couldn't save finals samples")
pass
#print("Copy ", countcopy)
return log

View file

@ -0,0 +1,161 @@
import numpy as np
import json, math, time, os
import matplotlib.pyplot as plt
import copy
import gc
from torchviz import make_dot
import torch
import torch.nn.functional as F
import time
class timer():
def __init__(self):
self._start_time=time.time()
def exec_time(self):
end = time.time()
res = end-self._start_time
self._start_time=end
return res
def plot_res(log, fig_name='res', param_names=None):
epochs = [x["epoch"] for x in log]
fig, ax = plt.subplots(ncols=3, figsize=(15, 3))
ax[0].set_title('Loss')
ax[0].plot(epochs,[x["train_loss"] for x in log], label='Train')
ax[0].plot(epochs,[x["val_loss"] for x in log], label='Val')
ax[0].legend()
ax[1].set_title('Acc')
ax[1].plot(epochs,[x["acc"] for x in log])
if log[0]["param"]!= None:
if isinstance(log[0]["param"],float):
ax[2].set_title('Mag')
ax[2].plot(epochs,[x["param"] for x in log], label='Mag')
ax[2].legend()
else :
ax[2].set_title('Prob')
#for idx, _ in enumerate(log[0]["param"]):
#ax[2].plot(epochs,[x["param"][idx] for x in log], label='P'+str(idx))
if not param_names : param_names = ['P'+str(idx) for idx, _ in enumerate(log[0]["param"])]
proba=[[x["param"][idx] for x in log] for idx, _ in enumerate(log[0]["param"])]
ax[2].stackplot(epochs, proba, labels=param_names)
ax[2].legend(param_names, loc='center left', bbox_to_anchor=(1, 0.5))
fig_name = fig_name.replace('.',',')
plt.savefig(fig_name)
plt.close()
def plot_res_compare(filenames, fig_name='res'):
all_data=[]
#legend=""
for idx, file in enumerate(filenames):
#legend+=str(idx)+'-'+file+'\n'
with open(file) as json_file:
data = json.load(json_file)
all_data.append(data)
n_tf = [len(x["Param_names"]) for x in all_data]
acc = [x["Accuracy"] for x in all_data]
time = [x["Time"][0] for x in all_data]
fig, ax = plt.subplots(ncols=3, figsize=(30, 8))
ax[0].plot(n_tf, acc)
ax[1].plot(n_tf, time)
ax[0].set_title('Acc')
ax[1].set_title('Time')
#for a in ax: a.legend()
fig_name = fig_name.replace('.',',')
plt.savefig(fig_name, bbox_inches='tight')
plt.close()
def plot_TF_res(log, tf_names, fig_name='res'):
mean = np.mean([x["param"] for x in log], axis=0)
std = np.std([x["param"] for x in log], axis=0)
fig, ax = plt.subplots(1, 1, figsize=(30, 8), sharey=True)
ax.bar(tf_names, mean, yerr=std)
#ax.bar(tf_names, log[-1]["param"])
fig_name = fig_name.replace('.',',')
plt.savefig(fig_name, bbox_inches='tight')
plt.close()
def model_copy(src,dst, patch_copy=True, copy_grad=True):
#model=copy.deepcopy(fmodel) #Pas approprie, on ne souhaite que les poids/grad (pas tout fmodel et ses etats)
dst.load_state_dict(src.state_dict()) #Do not copy gradient !
if patch_copy:
dst['model'].load_state_dict(src['model'].state_dict()) #Copie donnee manquante ?
dst['data_aug'].load_state_dict(src['data_aug'].state_dict())
#Copie des gradients
if copy_grad:
for paramName, paramValue, in src.named_parameters():
for netCopyName, netCopyValue, in dst.named_parameters():
if paramName == netCopyName:
netCopyValue.grad = paramValue.grad
#netCopyValue=copy.deepcopy(paramValue)
try: #Data_augV4
dst['data_aug']._input_info = src['data_aug']._input_info
dst['data_aug']._TF_matrix = src['data_aug']._TF_matrix
except:
pass
def optim_copy(dopt, opt):
#inner_opt.load_state_dict(diffopt.state_dict()) #Besoin sauver etat otpim (momentum, etc.) => Ne copie pas le state...
#opt_param=higher.optim.get_trainable_opt_params(diffopt)
for group_idx, group in enumerate(opt.param_groups):
# print('gp idx',group_idx)
for p_idx, p in enumerate(group['params']):
opt.state[p]=dopt.state[group_idx][p_idx]
class loss_monitor(): #Voir https://github.com/pytorch/ignite
def __init__(self, patience, end_train=1):
self.patience = patience
self.end_train = end_train
self.counter = 0
self.best_score = None
self.reached_limit = 0
def register(self, loss):
if self.best_score is None:
self.best_score = loss
elif loss > self.best_score:
self.counter += 1
#if not self.reached_limit:
print("loss no improve counter", self.counter, self.reached_limit)
else:
self.best_score = loss
self.counter = 0
def limit_reached(self):
if self.counter >= self.patience:
self.counter = 0
self.reached_limit +=1
self.best_score = None
return self.reached_limit
def end_training(self):
if self.limit_reached() >= self.end_train:
return True
else:
return False
def reset(self):
self.__init__(self.patience, self.end_train)