mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Rangement
This commit is contained in:
parent
f83c73ec17
commit
f507ff4741
16 changed files with 85 additions and 46 deletions
490
higher/smart_aug/old/augmentation_transforms.py
Executable file
490
higher/smart_aug/old/augmentation_transforms.py
Executable file
|
@ -0,0 +1,490 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2019 The Google UDA Team Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Transforms used in the Augmentation Policies.
|
||||
|
||||
Copied from AutoAugment: https://github.com/tensorflow/models/blob/master/research/autoaugment/
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import random
|
||||
import numpy as np
|
||||
# pylint:disable=g-multiple-import
|
||||
from PIL import ImageOps, ImageEnhance, ImageFilter, Image
|
||||
# pylint:enable=g-multiple-import
|
||||
|
||||
#import tensorflow as tf
|
||||
|
||||
#FLAGS = tf.flags.FLAGS
|
||||
|
||||
|
||||
IMAGE_SIZE = 32
|
||||
# What is the dataset mean and std of the images on the training set
|
||||
PARAMETER_MAX = 10 # What is the max 'level' a transform could be predicted
|
||||
|
||||
|
||||
def get_mean_and_std():
|
||||
#if FLAGS.task_name == "cifar10":
|
||||
means = [0.49139968, 0.48215841, 0.44653091]
|
||||
stds = [0.24703223, 0.24348513, 0.26158784]
|
||||
#elif FLAGS.task_name == "svhn":
|
||||
# means = [0.4376821, 0.4437697, 0.47280442]
|
||||
# stds = [0.19803012, 0.20101562, 0.19703614]
|
||||
#else:
|
||||
# assert False
|
||||
return means, stds
|
||||
|
||||
|
||||
def random_flip(x):
|
||||
"""Flip the input x horizontally with 50% probability."""
|
||||
if np.random.rand(1)[0] > 0.5:
|
||||
return np.fliplr(x)
|
||||
return x
|
||||
|
||||
|
||||
def zero_pad_and_crop(img, amount=4):
|
||||
"""Zero pad by `amount` zero pixels on each side then take a random crop.
|
||||
|
||||
Args:
|
||||
img: numpy image that will be zero padded and cropped.
|
||||
amount: amount of zeros to pad `img` with horizontally and verically.
|
||||
|
||||
Returns:
|
||||
The cropped zero padded img. The returned numpy array will be of the same
|
||||
shape as `img`.
|
||||
"""
|
||||
padded_img = np.zeros((img.shape[0] + amount * 2, img.shape[1] + amount * 2,
|
||||
img.shape[2]))
|
||||
padded_img[amount:img.shape[0] + amount, amount:
|
||||
img.shape[1] + amount, :] = img
|
||||
top = np.random.randint(low=0, high=2 * amount)
|
||||
left = np.random.randint(low=0, high=2 * amount)
|
||||
new_img = padded_img[top:top + img.shape[0], left:left + img.shape[1], :]
|
||||
return new_img
|
||||
|
||||
|
||||
def create_cutout_mask(img_height, img_width, num_channels, size):
|
||||
"""Creates a zero mask used for cutout of shape `img_height` x `img_width`.
|
||||
|
||||
Args:
|
||||
img_height: Height of image cutout mask will be applied to.
|
||||
img_width: Width of image cutout mask will be applied to.
|
||||
num_channels: Number of channels in the image.
|
||||
size: Size of the zeros mask.
|
||||
|
||||
Returns:
|
||||
A mask of shape `img_height` x `img_width` with all ones except for a
|
||||
square of zeros of shape `size` x `size`. This mask is meant to be
|
||||
elementwise multiplied with the original image. Additionally returns
|
||||
the `upper_coord` and `lower_coord` which specify where the cutout mask
|
||||
will be applied.
|
||||
"""
|
||||
assert img_height == img_width
|
||||
|
||||
# Sample center where cutout mask will be applied
|
||||
height_loc = np.random.randint(low=0, high=img_height)
|
||||
width_loc = np.random.randint(low=0, high=img_width)
|
||||
|
||||
# Determine upper right and lower left corners of patch
|
||||
upper_coord = (max(0, height_loc - size // 2), max(0, width_loc - size // 2))
|
||||
lower_coord = (min(img_height, height_loc + size // 2),
|
||||
min(img_width, width_loc + size // 2))
|
||||
mask_height = lower_coord[0] - upper_coord[0]
|
||||
mask_width = lower_coord[1] - upper_coord[1]
|
||||
assert mask_height > 0
|
||||
assert mask_width > 0
|
||||
|
||||
mask = np.ones((img_height, img_width, num_channels))
|
||||
zeros = np.zeros((mask_height, mask_width, num_channels))
|
||||
mask[upper_coord[0]:lower_coord[0], upper_coord[1]:lower_coord[1], :] = (
|
||||
zeros)
|
||||
return mask, upper_coord, lower_coord
|
||||
|
||||
|
||||
def cutout_numpy(img, size=16):
|
||||
"""Apply cutout with mask of shape `size` x `size` to `img`.
|
||||
|
||||
The cutout operation is from the paper https://arxiv.org/abs/1708.04552.
|
||||
This operation applies a `size`x`size` mask of zeros to a random location
|
||||
within `img`.
|
||||
|
||||
Args:
|
||||
img: Numpy image that cutout will be applied to.
|
||||
size: Height/width of the cutout mask that will be
|
||||
|
||||
Returns:
|
||||
A numpy tensor that is the result of applying the cutout mask to `img`.
|
||||
"""
|
||||
img_height, img_width, num_channels = (img.shape[0], img.shape[1],
|
||||
img.shape[2])
|
||||
assert len(img.shape) == 3
|
||||
mask, _, _ = create_cutout_mask(img_height, img_width, num_channels, size)
|
||||
return img * mask
|
||||
|
||||
|
||||
def float_parameter(level, maxval):
|
||||
"""Helper function to scale `val` between 0 and maxval .
|
||||
|
||||
Args:
|
||||
level: Level of the operation that will be between [0, `PARAMETER_MAX`].
|
||||
maxval: Maximum value that the operation can have. This will be scaled
|
||||
to level/PARAMETER_MAX.
|
||||
|
||||
Returns:
|
||||
A float that results from scaling `maxval` according to `level`.
|
||||
"""
|
||||
return float(level) * maxval / PARAMETER_MAX
|
||||
|
||||
|
||||
def int_parameter(level, maxval):
|
||||
"""Helper function to scale `val` between 0 and maxval .
|
||||
|
||||
Args:
|
||||
level: Level of the operation that will be between [0, `PARAMETER_MAX`].
|
||||
maxval: Maximum value that the operation can have. This will be scaled
|
||||
to level/PARAMETER_MAX.
|
||||
|
||||
Returns:
|
||||
An int that results from scaling `maxval` according to `level`.
|
||||
"""
|
||||
return int(level * maxval / PARAMETER_MAX)
|
||||
|
||||
|
||||
def pil_wrap(img, use_mean_std):
|
||||
"""Convert the `img` numpy tensor to a PIL Image."""
|
||||
|
||||
if use_mean_std:
|
||||
MEANS, STDS = get_mean_and_std()
|
||||
else:
|
||||
MEANS = [0, 0, 0]
|
||||
STDS = [1, 1, 1]
|
||||
img_ori = (img * STDS + MEANS) * 255
|
||||
|
||||
return Image.fromarray(
|
||||
np.uint8((img * STDS + MEANS) * 255.0)).convert('RGBA')
|
||||
|
||||
|
||||
def pil_unwrap(pil_img, use_mean_std, img_shape):
|
||||
"""Converts the PIL img to a numpy array."""
|
||||
if use_mean_std:
|
||||
MEANS, STDS = get_mean_and_std()
|
||||
else:
|
||||
MEANS = [0, 0, 0]
|
||||
STDS = [1, 1, 1]
|
||||
pic_array = np.array(pil_img.getdata()).reshape((img_shape[0], img_shape[1], 4)) / 255.0
|
||||
i1, i2 = np.where(pic_array[:, :, 3] == 0)
|
||||
pic_array = (pic_array[:, :, :3] - MEANS) / STDS
|
||||
pic_array[i1, i2] = [0, 0, 0]
|
||||
return pic_array
|
||||
|
||||
|
||||
def apply_policy(policy, img, use_mean_std=True):
|
||||
"""Apply the `policy` to the numpy `img`.
|
||||
|
||||
Args:
|
||||
policy: A list of tuples with the form (name, probability, level) where
|
||||
`name` is the name of the augmentation operation to apply, `probability`
|
||||
is the probability of applying the operation and `level` is what strength
|
||||
the operation to apply.
|
||||
img: Numpy image that will have `policy` applied to it.
|
||||
|
||||
Returns:
|
||||
The result of applying `policy` to `img`.
|
||||
"""
|
||||
img_shape = img.shape
|
||||
pil_img = pil_wrap(img, use_mean_std)
|
||||
for xform in policy:
|
||||
assert len(xform) == 3
|
||||
name, probability, level = xform
|
||||
xform_fn = NAME_TO_TRANSFORM[name].pil_transformer(
|
||||
probability, level, img_shape)
|
||||
pil_img = xform_fn(pil_img)
|
||||
return pil_unwrap(pil_img, use_mean_std, img_shape)
|
||||
|
||||
|
||||
class TransformFunction(object):
|
||||
"""Wraps the Transform function for pretty printing options."""
|
||||
|
||||
def __init__(self, func, name):
|
||||
self.f = func
|
||||
self.name = name
|
||||
|
||||
def __repr__(self):
|
||||
return '<' + self.name + '>'
|
||||
|
||||
def __call__(self, pil_img):
|
||||
return self.f(pil_img)
|
||||
|
||||
|
||||
class TransformT(object):
|
||||
"""Each instance of this class represents a specific transform."""
|
||||
|
||||
def __init__(self, name, xform_fn):
|
||||
self.name = name
|
||||
self.xform = xform_fn
|
||||
|
||||
def pil_transformer(self, probability, level, img_shape):
|
||||
|
||||
def return_function(im):
|
||||
if random.random() < probability:
|
||||
im = self.xform(im, level, img_shape)
|
||||
return im
|
||||
|
||||
name = self.name + '({:.1f},{})'.format(probability, level)
|
||||
return TransformFunction(return_function, name)
|
||||
|
||||
|
||||
################## Transform Functions ##################
|
||||
identity = TransformT('identity', lambda pil_img, level, _: pil_img)
|
||||
flip_lr = TransformT(
|
||||
'FlipLR',
|
||||
lambda pil_img, level, _: pil_img.transpose(Image.FLIP_LEFT_RIGHT))
|
||||
flip_ud = TransformT(
|
||||
'FlipUD',
|
||||
lambda pil_img, level, _: pil_img.transpose(Image.FLIP_TOP_BOTTOM))
|
||||
# pylint:disable=g-long-lambda
|
||||
auto_contrast = TransformT(
|
||||
'AutoContrast',
|
||||
lambda pil_img, level, _: ImageOps.autocontrast(
|
||||
pil_img.convert('RGB')).convert('RGBA'))
|
||||
equalize = TransformT(
|
||||
'Equalize',
|
||||
lambda pil_img, level, _: ImageOps.equalize(
|
||||
pil_img.convert('RGB')).convert('RGBA'))
|
||||
invert = TransformT(
|
||||
'Invert',
|
||||
lambda pil_img, level, _: ImageOps.invert(
|
||||
pil_img.convert('RGB')).convert('RGBA'))
|
||||
# pylint:enable=g-long-lambda
|
||||
blur = TransformT(
|
||||
'Blur', lambda pil_img, level, _: pil_img.filter(ImageFilter.BLUR))
|
||||
smooth = TransformT(
|
||||
'Smooth',
|
||||
lambda pil_img, level, _: pil_img.filter(ImageFilter.SMOOTH))
|
||||
|
||||
|
||||
def _rotate_impl(pil_img, level, _):
|
||||
"""Rotates `pil_img` from -30 to 30 degrees depending on `level`."""
|
||||
degrees = int_parameter(level, 30)
|
||||
if random.random() > 0.5:
|
||||
degrees = -degrees
|
||||
return pil_img.rotate(degrees)
|
||||
|
||||
|
||||
rotate = TransformT('Rotate', _rotate_impl)
|
||||
|
||||
|
||||
def _posterize_impl(pil_img, level, _):
|
||||
"""Applies PIL Posterize to `pil_img`."""
|
||||
level = int_parameter(level, 4)
|
||||
return ImageOps.posterize(pil_img.convert('RGB'), 4 - level).convert('RGBA')
|
||||
|
||||
|
||||
posterize = TransformT('Posterize', _posterize_impl)
|
||||
|
||||
|
||||
def _shear_x_impl(pil_img, level, img_shape):
|
||||
"""Applies PIL ShearX to `pil_img`.
|
||||
|
||||
The ShearX operation shears the image along the horizontal axis with `level`
|
||||
magnitude.
|
||||
|
||||
Args:
|
||||
pil_img: Image in PIL object.
|
||||
level: Strength of the operation specified as an Integer from
|
||||
[0, `PARAMETER_MAX`].
|
||||
|
||||
Returns:
|
||||
A PIL Image that has had ShearX applied to it.
|
||||
"""
|
||||
level = float_parameter(level, 0.3)
|
||||
if random.random() > 0.5:
|
||||
level = -level
|
||||
return pil_img.transform(
|
||||
(img_shape[0], img_shape[1]),
|
||||
Image.AFFINE,
|
||||
(1, level, 0, 0, 1, 0))
|
||||
|
||||
|
||||
shear_x = TransformT('ShearX', _shear_x_impl)
|
||||
|
||||
|
||||
def _shear_y_impl(pil_img, level, img_shape):
|
||||
"""Applies PIL ShearY to `pil_img`.
|
||||
|
||||
The ShearY operation shears the image along the vertical axis with `level`
|
||||
magnitude.
|
||||
|
||||
Args:
|
||||
pil_img: Image in PIL object.
|
||||
level: Strength of the operation specified as an Integer from
|
||||
[0, `PARAMETER_MAX`].
|
||||
|
||||
Returns:
|
||||
A PIL Image that has had ShearX applied to it.
|
||||
"""
|
||||
level = float_parameter(level, 0.3)
|
||||
if random.random() > 0.5:
|
||||
level = -level
|
||||
return pil_img.transform(
|
||||
(img_shape[0], img_shape[1]),
|
||||
Image.AFFINE,
|
||||
(1, 0, 0, level, 1, 0))
|
||||
|
||||
|
||||
shear_y = TransformT('ShearY', _shear_y_impl)
|
||||
|
||||
|
||||
def _translate_x_impl(pil_img, level, img_shape):
|
||||
"""Applies PIL TranslateX to `pil_img`.
|
||||
|
||||
Translate the image in the horizontal direction by `level`
|
||||
number of pixels.
|
||||
|
||||
Args:
|
||||
pil_img: Image in PIL object.
|
||||
level: Strength of the operation specified as an Integer from
|
||||
[0, `PARAMETER_MAX`].
|
||||
|
||||
Returns:
|
||||
A PIL Image that has had TranslateX applied to it.
|
||||
"""
|
||||
level = int_parameter(level, 10)
|
||||
if random.random() > 0.5:
|
||||
level = -level
|
||||
return pil_img.transform(
|
||||
(img_shape[0], img_shape[1]),
|
||||
Image.AFFINE,
|
||||
(1, 0, level, 0, 1, 0))
|
||||
|
||||
|
||||
translate_x = TransformT('TranslateX', _translate_x_impl)
|
||||
|
||||
|
||||
def _translate_y_impl(pil_img, level, img_shape):
|
||||
"""Applies PIL TranslateY to `pil_img`.
|
||||
|
||||
Translate the image in the vertical direction by `level`
|
||||
number of pixels.
|
||||
|
||||
Args:
|
||||
pil_img: Image in PIL object.
|
||||
level: Strength of the operation specified as an Integer from
|
||||
[0, `PARAMETER_MAX`].
|
||||
|
||||
Returns:
|
||||
A PIL Image that has had TranslateY applied to it.
|
||||
"""
|
||||
level = int_parameter(level, 10)
|
||||
if random.random() > 0.5:
|
||||
level = -level
|
||||
return pil_img.transform(
|
||||
(img_shape[0], img_shape[1]),
|
||||
Image.AFFINE,
|
||||
(1, 0, 0, 0, 1, level))
|
||||
|
||||
|
||||
translate_y = TransformT('TranslateY', _translate_y_impl)
|
||||
|
||||
|
||||
def _crop_impl(pil_img, level, img_shape, interpolation=Image.BILINEAR):
|
||||
"""Applies a crop to `pil_img` with the size depending on the `level`."""
|
||||
cropped = pil_img.crop((level, level, img_shape[0] - level, img_shape[1] - level))
|
||||
resized = cropped.resize((img_shape[0], img_shape[1]), interpolation)
|
||||
return resized
|
||||
|
||||
|
||||
crop_bilinear = TransformT('CropBilinear', _crop_impl)
|
||||
|
||||
|
||||
def _solarize_impl(pil_img, level, _):
|
||||
"""Applies PIL Solarize to `pil_img`.
|
||||
|
||||
Translate the image in the vertical direction by `level`
|
||||
number of pixels.
|
||||
|
||||
Args:
|
||||
pil_img: Image in PIL object.
|
||||
level: Strength of the operation specified as an Integer from
|
||||
[0, `PARAMETER_MAX`].
|
||||
|
||||
Returns:
|
||||
A PIL Image that has had Solarize applied to it.
|
||||
"""
|
||||
level = int_parameter(level, 256)
|
||||
return ImageOps.solarize(pil_img.convert('RGB'), 256 - level).convert('RGBA')
|
||||
|
||||
|
||||
solarize = TransformT('Solarize', _solarize_impl)
|
||||
|
||||
|
||||
def _cutout_pil_impl(pil_img, level, img_shape):
|
||||
"""Apply cutout to pil_img at the specified level."""
|
||||
size = int_parameter(level, 20)
|
||||
if size <= 0:
|
||||
return pil_img
|
||||
img_height, img_width, num_channels = (img_shape[0], img_shape[1], 3)
|
||||
_, upper_coord, lower_coord = (
|
||||
create_cutout_mask(img_height, img_width, num_channels, size))
|
||||
pixels = pil_img.load() # create the pixel map
|
||||
for i in range(upper_coord[0], lower_coord[0]): # for every col:
|
||||
for j in range(upper_coord[1], lower_coord[1]): # For every row
|
||||
pixels[i, j] = (125, 122, 113, 0) # set the colour accordingly
|
||||
return pil_img
|
||||
|
||||
cutout = TransformT('Cutout', _cutout_pil_impl)
|
||||
|
||||
|
||||
def _enhancer_impl(enhancer):
|
||||
"""Sets level to be between 0.1 and 1.8 for ImageEnhance transforms of PIL."""
|
||||
def impl(pil_img, level, _):
|
||||
v = float_parameter(level, 1.8) + .1 # going to 0 just destroys it
|
||||
return enhancer(pil_img).enhance(v)
|
||||
return impl
|
||||
|
||||
|
||||
color = TransformT('Color', _enhancer_impl(ImageEnhance.Color))
|
||||
contrast = TransformT('Contrast', _enhancer_impl(ImageEnhance.Contrast))
|
||||
brightness = TransformT('Brightness', _enhancer_impl(
|
||||
ImageEnhance.Brightness))
|
||||
sharpness = TransformT('Sharpness', _enhancer_impl(ImageEnhance.Sharpness))
|
||||
|
||||
ALL_TRANSFORMS = [
|
||||
flip_lr,
|
||||
flip_ud,
|
||||
auto_contrast,
|
||||
equalize,
|
||||
invert,
|
||||
rotate,
|
||||
posterize,
|
||||
crop_bilinear,
|
||||
solarize,
|
||||
color,
|
||||
contrast,
|
||||
brightness,
|
||||
sharpness,
|
||||
shear_x,
|
||||
shear_y,
|
||||
translate_x,
|
||||
translate_y,
|
||||
cutout,
|
||||
blur,
|
||||
smooth
|
||||
]
|
||||
|
||||
NAME_TO_TRANSFORM = {t.name: t for t in ALL_TRANSFORMS}
|
||||
TRANSFORM_NAMES = NAME_TO_TRANSFORM.keys()
|
1065
higher/smart_aug/old/dataug_old.py
Normal file
1065
higher/smart_aug/old/dataug_old.py
Normal file
File diff suppressed because it is too large
Load diff
85
higher/smart_aug/old/higher_repro.py
Normal file
85
higher/smart_aug/old/higher_repro.py
Normal file
|
@ -0,0 +1,85 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
import higher
|
||||
import time
|
||||
|
||||
data_train = torchvision.datasets.CIFAR10("./data", train=True, download=True, transform=torchvision.transforms.ToTensor())
|
||||
dl_train = torch.utils.data.DataLoader(data_train, batch_size=300, shuffle=True, num_workers=0, pin_memory=False)
|
||||
|
||||
|
||||
class Aug_model(nn.Module):
|
||||
def __init__(self, model, hyper_param=True):
|
||||
super(Aug_model, self).__init__()
|
||||
|
||||
#### Origin of the issue ? ####
|
||||
if hyper_param:
|
||||
self._params = nn.ParameterDict({
|
||||
"hyper_param": nn.Parameter(torch.Tensor([0.5])),
|
||||
})
|
||||
###############################
|
||||
|
||||
self._mods = nn.ModuleDict({
|
||||
'model': model,
|
||||
})
|
||||
|
||||
def forward(self, x):
|
||||
return self._mods['model'](x) #* self._params['hyper_param']
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._mods[key]
|
||||
|
||||
class Aug_model2(nn.Module): #Slow increase like no hyper_param
|
||||
def __init__(self, model, hyper_param=True):
|
||||
super(Aug_model2, self).__init__()
|
||||
|
||||
#### Origin of the issue ? ####
|
||||
if hyper_param:
|
||||
self._params = nn.ParameterDict({
|
||||
"hyper_param": nn.Parameter(torch.Tensor([0.5])),
|
||||
})
|
||||
###############################
|
||||
|
||||
self._mods = nn.ModuleDict({
|
||||
'model': model,
|
||||
'fmodel': higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
|
||||
})
|
||||
|
||||
def forward(self, x):
|
||||
return self._mods['fmodel'](x) * self._params['hyper_param']
|
||||
|
||||
def get_diffopt(self, opt, track_higher_grads=True):
|
||||
return higher.optim.get_diff_optim(opt,
|
||||
self._mods['model'].parameters(),
|
||||
fmodel=self._mods['fmodel'],
|
||||
track_higher_grads=track_higher_grads)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._mods[key]
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
device = torch.device('cuda:1')
|
||||
aug_model = Aug_model2(
|
||||
model=torch.hub.load('pytorch/vision:v0.4.2', 'resnet18', pretrained=False),
|
||||
hyper_param=True #False will not extend step time
|
||||
).to(device)
|
||||
|
||||
inner_opt = torch.optim.SGD(aug_model['model'].parameters(), lr=1e-2, momentum=0.9)
|
||||
|
||||
#fmodel = higher.patch.monkeypatch(aug_model, device=None, copy_initial_weights=True)
|
||||
#diffopt = higher.optim.get_diff_optim(inner_opt, aug_model.parameters(),fmodel=fmodel,track_higher_grads=True)
|
||||
diffopt = aug_model.get_diffopt(inner_opt)
|
||||
|
||||
for i, (xs, ys) in enumerate(dl_train):
|
||||
xs, ys = xs.to(device), ys.to(device)
|
||||
|
||||
#logits = fmodel(xs)
|
||||
logits = aug_model(xs)
|
||||
loss = F.cross_entropy(F.log_softmax(logits, dim=1), ys, reduction='mean')
|
||||
|
||||
t = time.process_time()
|
||||
diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)
|
||||
#print(len(fmodel._fast_params),"step", time.process_time()-t)
|
||||
print(len(aug_model['fmodel']._fast_params),"step", time.process_time()-t)
|
502
higher/smart_aug/old/model_old.py
Normal file
502
higher/smart_aug/old/model_old.py
Normal file
|
@ -0,0 +1,502 @@
|
|||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
## Basic CNN ##
|
||||
class LeNet_F(nn.Module):
|
||||
def __init__(self, num_inp, num_out):
|
||||
super(LeNet_F, self).__init__()
|
||||
self._params = nn.ParameterDict({
|
||||
'w1': nn.Parameter(torch.zeros(20, num_inp, 5, 5)),
|
||||
'b1': nn.Parameter(torch.zeros(20)),
|
||||
'w2': nn.Parameter(torch.zeros(50, 20, 5, 5)),
|
||||
'b2': nn.Parameter(torch.zeros(50)),
|
||||
#'w3': nn.Parameter(torch.zeros(500,4*4*50)), #num_imp=1
|
||||
'w3': nn.Parameter(torch.zeros(500,5*5*50)), #num_imp=3
|
||||
'b3': nn.Parameter(torch.zeros(500)),
|
||||
'w4': nn.Parameter(torch.zeros(num_out, 500)),
|
||||
'b4': nn.Parameter(torch.zeros(num_out))
|
||||
})
|
||||
self.initialize()
|
||||
|
||||
|
||||
def initialize(self):
|
||||
nn.init.kaiming_uniform_(self._params["w1"], a=math.sqrt(5))
|
||||
nn.init.kaiming_uniform_(self._params["w2"], a=math.sqrt(5))
|
||||
nn.init.kaiming_uniform_(self._params["w3"], a=math.sqrt(5))
|
||||
nn.init.kaiming_uniform_(self._params["w4"], a=math.sqrt(5))
|
||||
|
||||
def forward(self, x):
|
||||
#print("Start Shape ", x.shape)
|
||||
out = F.relu(F.conv2d(input=x, weight=self._params["w1"], bias=self._params["b1"]))
|
||||
#print("Shape ", out.shape)
|
||||
out = F.max_pool2d(out, 2)
|
||||
#print("Shape ", out.shape)
|
||||
out = F.relu(F.conv2d(input=out, weight=self._params["w2"], bias=self._params["b2"]))
|
||||
#print("Shape ", out.shape)
|
||||
out = F.max_pool2d(out, 2)
|
||||
#print("Shape ", out.shape)
|
||||
out = out.view(out.size(0), -1)
|
||||
#print("Shape ", out.shape)
|
||||
out = F.relu(F.linear(out, self._params["w3"], self._params["b3"]))
|
||||
#print("Shape ", out.shape)
|
||||
out = F.linear(out, self._params["w4"], self._params["b4"])
|
||||
#print("Shape ", out.shape)
|
||||
#return F.log_softmax(out, dim=1)
|
||||
return out
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._params[key]
|
||||
|
||||
def __str__(self):
|
||||
return "LeNet"
|
||||
|
||||
|
||||
## MobileNetv2 ##
|
||||
|
||||
def _make_divisible(v, divisor, min_value=None):
|
||||
"""
|
||||
This function is taken from the original tf repo.
|
||||
It ensures that all layers have a channel number that is divisible by 8
|
||||
It can be seen here:
|
||||
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
||||
:param v:
|
||||
:param divisor:
|
||||
:param min_value:
|
||||
:return:
|
||||
"""
|
||||
if min_value is None:
|
||||
min_value = divisor
|
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||
# Make sure that round down does not go down by more than 10%.
|
||||
if new_v < 0.9 * v:
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Sequential):
|
||||
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
|
||||
padding = (kernel_size - 1) // 2
|
||||
super(ConvBNReLU, self).__init__(
|
||||
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
|
||||
nn.BatchNorm2d(out_planes),
|
||||
nn.ReLU6(inplace=True)
|
||||
)
|
||||
|
||||
|
||||
class InvertedResidual(nn.Module):
|
||||
def __init__(self, inp, oup, stride, expand_ratio):
|
||||
super(InvertedResidual, self).__init__()
|
||||
self.stride = stride
|
||||
assert stride in [1, 2]
|
||||
|
||||
hidden_dim = int(round(inp * expand_ratio))
|
||||
self.use_res_connect = self.stride == 1 and inp == oup
|
||||
|
||||
layers = []
|
||||
if expand_ratio != 1:
|
||||
# pw
|
||||
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
|
||||
layers.extend([
|
||||
# dw
|
||||
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
|
||||
# pw-linear
|
||||
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
])
|
||||
self.conv = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_res_connect:
|
||||
return x + self.conv(x)
|
||||
else:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class MobileNetV2(nn.Module):
|
||||
def __init__(self,
|
||||
num_classes=1000,
|
||||
width_mult=1.0,
|
||||
inverted_residual_setting=None,
|
||||
round_nearest=8,
|
||||
block=None):
|
||||
"""
|
||||
MobileNet V2 main class
|
||||
Args:
|
||||
num_classes (int): Number of classes
|
||||
width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
|
||||
inverted_residual_setting: Network structure
|
||||
round_nearest (int): Round the number of channels in each layer to be a multiple of this number
|
||||
Set to 1 to turn off rounding
|
||||
block: Module specifying inverted residual building block for mobilenet
|
||||
"""
|
||||
super(MobileNetV2, self).__init__()
|
||||
|
||||
if block is None:
|
||||
block = InvertedResidual
|
||||
input_channel = 32
|
||||
last_channel = 1280
|
||||
|
||||
if inverted_residual_setting is None:
|
||||
inverted_residual_setting = [
|
||||
# t, c, n, s
|
||||
[1, 16, 1, 1],
|
||||
[6, 24, 2, 2],
|
||||
[6, 32, 3, 2],
|
||||
[6, 64, 4, 2],
|
||||
[6, 96, 3, 1],
|
||||
[6, 160, 3, 2],
|
||||
[6, 320, 1, 1],
|
||||
]
|
||||
|
||||
# only check the first element, assuming user knows t,c,n,s are required
|
||||
if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
|
||||
raise ValueError("inverted_residual_setting should be non-empty "
|
||||
"or a 4-element list, got {}".format(inverted_residual_setting))
|
||||
|
||||
# building first layer
|
||||
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
|
||||
self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
|
||||
features = [ConvBNReLU(3, input_channel, stride=2)]
|
||||
# building inverted residual blocks
|
||||
for t, c, n, s in inverted_residual_setting:
|
||||
output_channel = _make_divisible(c * width_mult, round_nearest)
|
||||
for i in range(n):
|
||||
stride = s if i == 0 else 1
|
||||
features.append(block(input_channel, output_channel, stride, expand_ratio=t))
|
||||
input_channel = output_channel
|
||||
# building last several layers
|
||||
features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
|
||||
# make it nn.Sequential
|
||||
self.features = nn.Sequential(*features)
|
||||
|
||||
# building classifier
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Dropout(0.2),
|
||||
nn.Linear(self.last_channel, num_classes),
|
||||
)
|
||||
|
||||
# weight initialization
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.ones_(m.weight)
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.normal_(m.weight, 0, 0.01)
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
def _forward_impl(self, x):
|
||||
# This exists since TorchScript doesn't support inheritance, so the superclass method
|
||||
# (this one) needs to have a name other than `forward` that can be accessed in a subclass
|
||||
x = self.features(x)
|
||||
x = x.mean([2, 3])
|
||||
x = self.classifier(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
return self._forward_impl(x)
|
||||
|
||||
def __str__(self):
|
||||
return "MobileNetV2"
|
||||
|
||||
## ResNet ##
|
||||
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
||||
"""3x3 convolution with padding"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
||||
padding=dilation, groups=groups, bias=False, dilation=dilation)
|
||||
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1):
|
||||
"""1x1 convolution"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
__constants__ = ['downsample']
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
|
||||
base_width=64, dilation=1, norm_layer=None):
|
||||
super(BasicBlock, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
if groups != 1 or base_width != 64:
|
||||
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
|
||||
if dilation > 1:
|
||||
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
||||
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
||||
self.conv1 = conv3x3(inplanes, planes, stride)
|
||||
self.bn1 = norm_layer(planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = conv3x3(planes, planes)
|
||||
self.bn2 = norm_layer(planes)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
expansion = 4
|
||||
__constants__ = ['downsample']
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
|
||||
base_width=64, dilation=1, norm_layer=None):
|
||||
super(Bottleneck, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
width = int(planes * (base_width / 64.)) * groups
|
||||
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
|
||||
self.conv1 = conv1x1(inplanes, width)
|
||||
self.bn1 = norm_layer(width)
|
||||
self.conv2 = conv3x3(width, width, stride, groups, dilation)
|
||||
self.bn2 = norm_layer(width)
|
||||
self.conv3 = conv1x1(width, planes * self.expansion)
|
||||
self.bn3 = norm_layer(planes * self.expansion)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
#ResNet18 : block=BasicBlock, layers=[2, 2, 2, 2]
|
||||
class ResNet(nn.Module):
|
||||
|
||||
def __init__(self, block=BasicBlock, layers=[2, 2, 2, 2], num_classes=1000, zero_init_residual=False,
|
||||
groups=1, width_per_group=64, replace_stride_with_dilation=None,
|
||||
norm_layer=None):
|
||||
super(ResNet, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
self._norm_layer = norm_layer
|
||||
|
||||
self.inplanes = 64
|
||||
self.dilation = 1
|
||||
if replace_stride_with_dilation is None:
|
||||
# each element in the tuple indicates if we should replace
|
||||
# the 2x2 stride with a dilated convolution instead
|
||||
replace_stride_with_dilation = [False, False, False]
|
||||
if len(replace_stride_with_dilation) != 3:
|
||||
raise ValueError("replace_stride_with_dilation should be None "
|
||||
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
||||
self.groups = groups
|
||||
self.base_width = width_per_group
|
||||
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
|
||||
bias=False)
|
||||
self.bn1 = norm_layer(self.inplanes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.layer1 = self._make_layer(block, 64, layers[0])
|
||||
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
|
||||
dilate=replace_stride_with_dilation[0])
|
||||
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
|
||||
dilate=replace_stride_with_dilation[1])
|
||||
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
|
||||
dilate=replace_stride_with_dilation[2])
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
# Zero-initialize the last BN in each residual branch,
|
||||
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
|
||||
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
|
||||
if zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, Bottleneck):
|
||||
nn.init.constant_(m.bn3.weight, 0)
|
||||
elif isinstance(m, BasicBlock):
|
||||
nn.init.constant_(m.bn2.weight, 0)
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
|
||||
norm_layer = self._norm_layer
|
||||
downsample = None
|
||||
previous_dilation = self.dilation
|
||||
if dilate:
|
||||
self.dilation *= stride
|
||||
stride = 1
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
conv1x1(self.inplanes, planes * block.expansion, stride),
|
||||
norm_layer(planes * block.expansion),
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
|
||||
self.base_width, previous_dilation, norm_layer))
|
||||
self.inplanes = planes * block.expansion
|
||||
for _ in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes, groups=self.groups,
|
||||
base_width=self.base_width, dilation=self.dilation,
|
||||
norm_layer=norm_layer))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def _forward_impl(self, x):
|
||||
# See note [TorchScript super()]
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
||||
x = self.avgpool(x)
|
||||
x = torch.flatten(x, 1)
|
||||
x = self.fc(x)
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
return self._forward_impl(x)
|
||||
|
||||
def __str__(self):
|
||||
return "ResNet18"
|
||||
|
||||
## Wide ResNet ##
|
||||
#https://github.com/xternalz/WideResNet-pytorch/blob/master/wideresnet.py
|
||||
#https://github.com/arcelien/pba/blob/master/pba/wrn.py
|
||||
#https://github.com/szagoruyko/wide-residual-networks/blob/master/pytorch/resnet.py
|
||||
'''
|
||||
class BasicBlock(nn.Module):
|
||||
def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.bn1 = nn.BatchNorm2d(in_planes)
|
||||
self.relu1 = nn.ReLU(inplace=True)
|
||||
self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
||||
padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(out_planes)
|
||||
self.relu2 = nn.ReLU(inplace=True)
|
||||
self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
|
||||
padding=1, bias=False)
|
||||
self.droprate = dropRate
|
||||
self.equalInOut = (in_planes == out_planes)
|
||||
self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
|
||||
padding=0, bias=False) or None
|
||||
def forward(self, x):
|
||||
if not self.equalInOut:
|
||||
x = self.relu1(self.bn1(x))
|
||||
else:
|
||||
out = self.relu1(self.bn1(x))
|
||||
out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
|
||||
if self.droprate > 0:
|
||||
out = F.dropout(out, p=self.droprate, training=self.training)
|
||||
out = self.conv2(out)
|
||||
return torch.add(x if self.equalInOut else self.convShortcut(x), out)
|
||||
|
||||
class NetworkBlock(nn.Module):
|
||||
def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
|
||||
super(NetworkBlock, self).__init__()
|
||||
self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)
|
||||
def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
|
||||
layers = []
|
||||
for i in range(int(nb_layers)):
|
||||
layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate))
|
||||
return nn.Sequential(*layers)
|
||||
def forward(self, x):
|
||||
return self.layer(x)
|
||||
|
||||
#wrn_size: 32 = WRN-28-2 ? 160 = WRN-28-10
|
||||
class WideResNet(nn.Module):
|
||||
#def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0):
|
||||
def __init__(self, num_classes, wrn_size, depth=28, dropRate=0.0):
|
||||
super(WideResNet, self).__init__()
|
||||
|
||||
self.kernel_size = wrn_size
|
||||
self.depth=depth
|
||||
filter_size = 3
|
||||
nChannels = [min(self.kernel_size, 16), self.kernel_size, self.kernel_size * 2, self.kernel_size * 4]
|
||||
strides = [1, 2, 2] # stride for each resblock
|
||||
|
||||
#nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
|
||||
assert((depth - 4) % 6 == 0)
|
||||
n = (depth - 4) / 6
|
||||
block = BasicBlock
|
||||
# 1st conv before any network block
|
||||
self.conv1 = nn.Conv2d(filter_size, nChannels[0], kernel_size=3, stride=1,
|
||||
padding=1, bias=False)
|
||||
# 1st block
|
||||
self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, strides[0], dropRate)
|
||||
# 2nd block
|
||||
self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, strides[1], dropRate)
|
||||
# 3rd block
|
||||
self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, strides[2], dropRate)
|
||||
# global average pooling and classifier
|
||||
self.bn1 = nn.BatchNorm2d(nChannels[3])
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.fc = nn.Linear(nChannels[3], num_classes)
|
||||
self.nChannels = nChannels[3]
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
m.weight.data.fill_(1)
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.Linear):
|
||||
m.bias.data.zero_()
|
||||
def forward(self, x):
|
||||
out = self.conv1(x)
|
||||
out = self.block1(out)
|
||||
out = self.block2(out)
|
||||
out = self.block3(out)
|
||||
out = self.relu(self.bn1(out))
|
||||
out = F.avg_pool2d(out, 8)
|
||||
out = out.view(-1, self.nChannels)
|
||||
return self.fc(out)
|
||||
|
||||
def architecture(self):
|
||||
return super(WideResNet, self).__str__()
|
||||
|
||||
def __str__(self):
|
||||
return "WideResNet(s{}-d{})".format(self.kernel_size, self.depth)
|
||||
'''
|
150
higher/smart_aug/old/test_lr.py
Executable file
150
higher/smart_aug/old/test_lr.py
Executable file
|
@ -0,0 +1,150 @@
|
|||
import numpy as np
|
||||
import json, math, time, os
|
||||
|
||||
from torch.utils.data import SubsetRandomSampler
|
||||
import torch.optim as optim
|
||||
import higher
|
||||
from model import *
|
||||
|
||||
import copy
|
||||
|
||||
BATCH_SIZE = 300
|
||||
TEST_SIZE = 300
|
||||
|
||||
mnist_train = torchvision.datasets.MNIST(
|
||||
"./data", train=True, download=True,
|
||||
transform=torchvision.transforms.Compose([
|
||||
#torchvision.transforms.RandomAffine(degrees=180, translate=None, scale=None, shear=None, resample=False, fillcolor=0),
|
||||
torchvision.transforms.ToTensor()
|
||||
])
|
||||
)
|
||||
|
||||
mnist_test = torchvision.datasets.MNIST(
|
||||
"./data", train=False, download=True, transform=torchvision.transforms.ToTensor()
|
||||
)
|
||||
|
||||
#train_subset_indices=range(int(len(mnist_train)/2))
|
||||
train_subset_indices=range(BATCH_SIZE)
|
||||
val_subset_indices=range(int(len(mnist_train)/2),len(mnist_train))
|
||||
|
||||
dl_train = torch.utils.data.DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices))
|
||||
dl_val = torch.utils.data.DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(val_subset_indices))
|
||||
dl_test = torch.utils.data.DataLoader(mnist_test, batch_size=TEST_SIZE, shuffle=False)
|
||||
|
||||
|
||||
def test(model):
|
||||
model.eval()
|
||||
for i, (features, labels) in enumerate(dl_test):
|
||||
pred = model.forward(features)
|
||||
return pred.argmax(dim=1).eq(labels).sum().item() / TEST_SIZE * 100
|
||||
|
||||
def train_classic(model, optim, epochs=1):
|
||||
model.train()
|
||||
log = []
|
||||
for epoch in range(epochs):
|
||||
t0 = time.process_time()
|
||||
for i, (features, labels) in enumerate(dl_train):
|
||||
|
||||
optim.zero_grad()
|
||||
pred = model.forward(features)
|
||||
loss = F.cross_entropy(pred,labels)
|
||||
loss.backward()
|
||||
optim.step()
|
||||
|
||||
#### Log ####
|
||||
tf = time.process_time()
|
||||
data={
|
||||
"time": tf - t0,
|
||||
}
|
||||
log.append(data)
|
||||
|
||||
times = [x["time"] for x in log]
|
||||
print("Vanilla : acc", test(model), "in (ms):", np.mean(times), "+/-", np.std(times))
|
||||
##########################################
|
||||
if __name__ == "__main__":
|
||||
|
||||
device = torch.device('cpu')
|
||||
|
||||
model = LeNet(1,10)
|
||||
opt_param = {
|
||||
"lr": torch.tensor(1e-2).requires_grad_(),
|
||||
"momentum": torch.tensor(0.9).requires_grad_()
|
||||
}
|
||||
n_inner_iter = 1
|
||||
dl_train_it = iter(dl_train)
|
||||
dl_val_it = iter(dl_val)
|
||||
epoch = 0
|
||||
epochs = 10
|
||||
|
||||
####
|
||||
train_classic(model=model, optim=torch.optim.Adam(model.parameters(), lr=0.001), epochs=epochs)
|
||||
model = LeNet(1,10)
|
||||
|
||||
meta_opt = torch.optim.Adam(opt_param.values(), lr=1e-2)
|
||||
inner_opt = torch.optim.SGD(model.parameters(), lr=opt_param['lr'], momentum=opt_param['momentum'])
|
||||
#for xs_val, ys_val in dl_val:
|
||||
while epoch < epochs:
|
||||
#print(data_aug.params["mag"], data_aug.params["mag"].grad)
|
||||
meta_opt.zero_grad()
|
||||
model.train()
|
||||
with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=True, track_higher_grads=True) as (fmodel, diffopt): #effet copy_initial_weight pas clair...
|
||||
|
||||
for param_group in diffopt.param_groups:
|
||||
param_group['lr'] = opt_param['lr']
|
||||
param_group['momentum'] = opt_param['momentum']
|
||||
|
||||
for i in range(n_inner_iter):
|
||||
try:
|
||||
xs, ys = next(dl_train_it)
|
||||
except StopIteration: #Fin epoch train
|
||||
epoch +=1
|
||||
dl_train_it = iter(dl_train)
|
||||
xs, ys = next(dl_train_it)
|
||||
|
||||
print('Epoch', epoch)
|
||||
print('train loss',loss.item(), '/ val loss', val_loss.item())
|
||||
print('acc', test(model))
|
||||
print('opt : lr', opt_param['lr'].item(), 'momentum', opt_param['momentum'].item())
|
||||
print('-'*9)
|
||||
model.train()
|
||||
|
||||
|
||||
logits = fmodel(xs) # modified `params` can also be passed as a kwarg
|
||||
loss = F.cross_entropy(logits, ys) # no need to call loss.backwards()
|
||||
#print('loss',loss.item())
|
||||
diffopt.step(loss) # note that `step` must take `loss` as an argument!
|
||||
# The line above gets P[t+1] from P[t] and loss[t]. `step` also returns
|
||||
# these new parameters, as an alternative to getting them from
|
||||
# `fmodel.fast_params` or `fmodel.parameters()` after calling
|
||||
# `diffopt.step`.
|
||||
|
||||
# At this point, or at any point in the iteration, you can take the
|
||||
# gradient of `fmodel.parameters()` (or equivalently
|
||||
# `fmodel.fast_params`) w.r.t. `fmodel.parameters(time=0)` (equivalently
|
||||
# `fmodel.init_fast_params`). i.e. `fast_params` will always have
|
||||
# `grad_fn` as an attribute, and be part of the gradient tape.
|
||||
|
||||
# At the end of your inner loop you can obtain these e.g. ...
|
||||
#grad_of_grads = torch.autograd.grad(
|
||||
# meta_loss_fn(fmodel.parameters()), fmodel.parameters(time=0))
|
||||
try:
|
||||
xs_val, ys_val = next(dl_val_it)
|
||||
except StopIteration: #Fin epoch val
|
||||
dl_val_it = iter(dl_val_it)
|
||||
xs_val, ys_val = next(dl_val_it)
|
||||
|
||||
val_logits = fmodel(xs_val)
|
||||
val_loss = F.cross_entropy(val_logits, ys_val)
|
||||
#print('val_loss',val_loss.item())
|
||||
|
||||
val_loss.backward()
|
||||
#meta_grads = torch.autograd.grad(val_loss, opt_lr, allow_unused=True)
|
||||
#print(meta_grads)
|
||||
for param_group in diffopt.param_groups:
|
||||
print(param_group['lr'], '/',param_group['lr'].grad)
|
||||
print(param_group['momentum'], '/',param_group['momentum'].grad)
|
||||
|
||||
#model=copy.deepcopy(fmodel)
|
||||
model.load_state_dict(fmodel.state_dict())
|
||||
|
||||
meta_opt.step()
|
866
higher/smart_aug/old/train_utils_old.py
Normal file
866
higher/smart_aug/old/train_utils_old.py
Normal file
|
@ -0,0 +1,866 @@
|
|||
import torch
|
||||
#import torch.optim
|
||||
import torchvision
|
||||
import higher
|
||||
|
||||
from datasets import *
|
||||
from utils import *
|
||||
|
||||
def train_classic_higher(model, epochs=1):
|
||||
device = next(model.parameters()).device
|
||||
#opt = torch.optim.Adam(model.parameters(), lr=1e-3)
|
||||
optim = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
|
||||
|
||||
model.train()
|
||||
dl_val_it = iter(dl_val)
|
||||
log = []
|
||||
|
||||
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
|
||||
diffopt = higher.optim.get_diff_optim(optim, model.parameters(),fmodel=fmodel,track_higher_grads=False)
|
||||
#with higher.innerloop_ctx(model, optim, copy_initial_weights=True, track_higher_grads=False) as (fmodel, diffopt):
|
||||
|
||||
for epoch in range(epochs):
|
||||
#print_torch_mem("Start epoch "+str(epoch))
|
||||
#print("Fast param ",len(fmodel._fast_params))
|
||||
t0 = time.process_time()
|
||||
for i, (features, labels) in enumerate(dl_train):
|
||||
#print_torch_mem("Start iter")
|
||||
features,labels = features.to(device), labels.to(device)
|
||||
|
||||
#optim.zero_grad()
|
||||
logits = model.forward(features)
|
||||
pred = F.log_softmax(logits, dim=1)
|
||||
loss = F.cross_entropy(pred,labels)
|
||||
#.backward()
|
||||
#optim.step()
|
||||
diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)
|
||||
|
||||
model_copy(src=fmodel, dst=model, patch_copy=False)
|
||||
optim_copy(dopt=diffopt, opt=optim)
|
||||
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
|
||||
diffopt = higher.optim.get_diff_optim(optim, model.parameters(),fmodel=fmodel,track_higher_grads=False)
|
||||
|
||||
#### Tests ####
|
||||
tf = time.process_time()
|
||||
try:
|
||||
xs_val, ys_val = next(dl_val_it)
|
||||
except StopIteration: #Fin epoch val
|
||||
dl_val_it = iter(dl_val)
|
||||
xs_val, ys_val = next(dl_val_it)
|
||||
xs_val, ys_val = xs_val.to(device), ys_val.to(device)
|
||||
|
||||
val_loss = F.cross_entropy(model(xs_val), ys_val)
|
||||
accuracy, _ =test(model)
|
||||
model.train()
|
||||
#### Log ####
|
||||
data={
|
||||
"epoch": epoch,
|
||||
"train_loss": loss.item(),
|
||||
"val_loss": val_loss.item(),
|
||||
"acc": accuracy,
|
||||
"time": tf - t0,
|
||||
|
||||
"param": None,
|
||||
}
|
||||
log.append(data)
|
||||
|
||||
return log
|
||||
|
||||
def train_classic_tests(model, epochs=1):
|
||||
device = next(model.parameters()).device
|
||||
#opt = torch.optim.Adam(model.parameters(), lr=1e-3)
|
||||
optim = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
|
||||
|
||||
countcopy=0
|
||||
model.train()
|
||||
dl_val_it = iter(dl_val)
|
||||
log = []
|
||||
|
||||
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
|
||||
doptim = higher.optim.get_diff_optim(optim, model.parameters(), fmodel=fmodel, track_higher_grads=False)
|
||||
for epoch in range(epochs):
|
||||
print_torch_mem("Start epoch")
|
||||
print(len(fmodel._fast_params))
|
||||
t0 = time.process_time()
|
||||
#with higher.innerloop_ctx(model, optim, copy_initial_weights=True, track_higher_grads=True) as (fmodel, doptim):
|
||||
|
||||
#fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
|
||||
#doptim = higher.optim.get_diff_optim(optim, model.parameters(), track_higher_grads=True)
|
||||
|
||||
for i, (features, labels) in enumerate(dl_train):
|
||||
features,labels = features.to(device), labels.to(device)
|
||||
|
||||
#with higher.innerloop_ctx(model, optim, copy_initial_weights=True, track_higher_grads=False) as (fmodel, doptim):
|
||||
|
||||
|
||||
#optim.zero_grad()
|
||||
pred = fmodel.forward(features)
|
||||
loss = F.cross_entropy(pred,labels)
|
||||
doptim.step(loss) #(opt.zero_grad, loss.backward, opt.step)
|
||||
#loss.backward()
|
||||
#new_params = doptim.step(loss, params=fmodel.parameters())
|
||||
#fmodel.update_params(new_params)
|
||||
|
||||
|
||||
#print('Fast param',len(fmodel._fast_params))
|
||||
#print('opt state', type(doptim.state[0][0]['momentum_buffer']), doptim.state[0][2]['momentum_buffer'].shape)
|
||||
|
||||
if False or (len(fmodel._fast_params)>1):
|
||||
print("fmodel fast param",len(fmodel._fast_params))
|
||||
'''
|
||||
#val_loss = F.cross_entropy(fmodel(features), labels)
|
||||
|
||||
#print_graph(val_loss)
|
||||
|
||||
#val_loss.backward()
|
||||
#print('bip')
|
||||
|
||||
tmp = fmodel.parameters()
|
||||
|
||||
#print(list(tmp)[1])
|
||||
tmp = [higher.utils._copy_tensor(t,safe_copy=True) if isinstance(t, torch.Tensor) else t for t in tmp]
|
||||
#print(len(tmp))
|
||||
|
||||
#fmodel._fast_params.clear()
|
||||
del fmodel._fast_params
|
||||
fmodel._fast_params=None
|
||||
|
||||
fmodel.fast_params=tmp # Surcharge la memoire
|
||||
#fmodel.update_params(tmp) #Meilleur perf / Surcharge la memoire avec trach higher grad
|
||||
|
||||
#optim._fmodel=fmodel
|
||||
'''
|
||||
|
||||
|
||||
countcopy+=1
|
||||
model_copy(src=fmodel, dst=model, patch_copy=False)
|
||||
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
|
||||
#doptim.detach_dyn()
|
||||
#tmp = doptim.state
|
||||
#tmp = doptim.state_dict()
|
||||
#for k, v in tmp['state'].items():
|
||||
# print('dict',k, type(v))
|
||||
|
||||
a = optim.param_groups[0]['params'][0]
|
||||
state = optim.state[a]
|
||||
#state['momentum_buffer'] = None
|
||||
#print('opt state', type(optim.state[a]), len(optim.state[a]))
|
||||
#optim.load_state_dict(tmp)
|
||||
|
||||
|
||||
for group_idx, group in enumerate(optim.param_groups):
|
||||
# print('gp idx',group_idx)
|
||||
for p_idx, p in enumerate(group['params']):
|
||||
optim.state[p]=doptim.state[group_idx][p_idx]
|
||||
|
||||
#print('opt state', type(optim.state[a]['momentum_buffer']), optim.state[a]['momentum_buffer'][0:10])
|
||||
#print('dopt state', type(doptim.state[0][0]['momentum_buffer']), doptim.state[0][0]['momentum_buffer'][0:10])
|
||||
'''
|
||||
for a in tmp:
|
||||
#print(type(a), len(a))
|
||||
for nb, b in a.items():
|
||||
#print(nb, type(b), len(b))
|
||||
for n, state in b.items():
|
||||
#print(n, type(states))
|
||||
#print(state.grad_fn)
|
||||
state = torch.tensor(state.data).requires_grad_()
|
||||
#print(state.grad_fn)
|
||||
'''
|
||||
|
||||
|
||||
doptim = higher.optim.get_diff_optim(optim, model.parameters(), track_higher_grads=True)
|
||||
#doptim.state = tmp
|
||||
|
||||
|
||||
countcopy+=1
|
||||
model_copy(src=fmodel, dst=model)
|
||||
optim_copy(dopt=diffopt, opt=inner_opt)
|
||||
|
||||
#### Tests ####
|
||||
tf = time.process_time()
|
||||
try:
|
||||
xs_val, ys_val = next(dl_val_it)
|
||||
except StopIteration: #Fin epoch val
|
||||
dl_val_it = iter(dl_val)
|
||||
xs_val, ys_val = next(dl_val_it)
|
||||
xs_val, ys_val = xs_val.to(device), ys_val.to(device)
|
||||
|
||||
val_loss = F.cross_entropy(model(xs_val), ys_val)
|
||||
accuracy, _ =test(model)
|
||||
model.train()
|
||||
#### Log ####
|
||||
data={
|
||||
"epoch": epoch,
|
||||
"train_loss": loss.item(),
|
||||
"val_loss": val_loss.item(),
|
||||
"acc": accuracy,
|
||||
"time": tf - t0,
|
||||
|
||||
"param": None,
|
||||
}
|
||||
log.append(data)
|
||||
|
||||
#countcopy+=1
|
||||
#model_copy(src=fmodel, dst=model, patch_copy=False)
|
||||
#optim.load_state_dict(doptim.state_dict()) #Besoin sauver etat otpim ?
|
||||
|
||||
print("Copy ", countcopy)
|
||||
return log
|
||||
|
||||
|
||||
from torchvision.datasets.vision import VisionDataset
|
||||
from PIL import Image
|
||||
import augmentation_transforms
|
||||
import numpy as np
|
||||
class AugmentedDatasetV2(VisionDataset):
|
||||
def __init__(self, root, train=True, transform=None, target_transform=None, download=False, subset=None):
|
||||
|
||||
super(AugmentedDatasetV2, self).__init__(root, transform=transform, target_transform=target_transform)
|
||||
|
||||
supervised_dataset = torchvision.datasets.CIFAR10(root, train=train, download=download, transform=transform)
|
||||
|
||||
self.sup_data = supervised_dataset.data if not subset else supervised_dataset.data[subset[0]:subset[1]]
|
||||
self.sup_targets = supervised_dataset.targets if not subset else supervised_dataset.targets[subset[0]:subset[1]]
|
||||
assert len(self.sup_data)==len(self.sup_targets)
|
||||
|
||||
for idx, img in enumerate(self.sup_data):
|
||||
self.sup_data[idx]= Image.fromarray(img) #to PIL Image
|
||||
|
||||
self.unsup_data=[]
|
||||
self.unsup_targets=[]
|
||||
self.origin_idx=[]
|
||||
|
||||
self.dataset_info= {
|
||||
'name': 'CIFAR10',
|
||||
'sup': len(self.sup_data),
|
||||
'unsup': len(self.unsup_data),
|
||||
'length': len(self.sup_data)+len(self.unsup_data),
|
||||
}
|
||||
|
||||
|
||||
self._TF = [
|
||||
## Geometric TF ##
|
||||
'Rotate',
|
||||
'TranslateX',
|
||||
'TranslateY',
|
||||
'ShearX',
|
||||
'ShearY',
|
||||
|
||||
'Cutout',
|
||||
|
||||
## Color TF ##
|
||||
'Contrast',
|
||||
'Color',
|
||||
'Brightness',
|
||||
'Sharpness',
|
||||
'Posterize',
|
||||
'Solarize',
|
||||
|
||||
'Invert',
|
||||
'AutoContrast',
|
||||
'Equalize',
|
||||
]
|
||||
self._op_list =[]
|
||||
self.prob=0.5
|
||||
self.mag_range=(1, 10)
|
||||
for tf in self._TF:
|
||||
for mag in range(self.mag_range[0], self.mag_range[1]):
|
||||
self._op_list+=[(tf, self.prob, mag)]
|
||||
self._nb_op = len(self._op_list)
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""
|
||||
Args:
|
||||
index (int): Index
|
||||
|
||||
Returns:
|
||||
tuple: (image, target) where target is index of the target class.
|
||||
"""
|
||||
aug_img, origin_img, target = self.unsup_data[index], self.sup_data[self.origin_idx[index]], self.unsup_targets[index]
|
||||
|
||||
# doing this so that it is consistent with all other datasets
|
||||
# to return a PIL Image
|
||||
#img = Image.fromarray(img)
|
||||
|
||||
if self.transform is not None:
|
||||
aug_img = self.transform(aug_img)
|
||||
origin_img = self.transform(origin_img)
|
||||
|
||||
if self.target_transform is not None:
|
||||
target = self.target_transform(target)
|
||||
|
||||
return aug_img, origin_img, target
|
||||
|
||||
def augement_data(self, aug_copy=1):
|
||||
|
||||
policies = []
|
||||
for op_1 in self._op_list:
|
||||
for op_2 in self._op_list:
|
||||
policies += [[op_1, op_2]]
|
||||
|
||||
for idx, image in enumerate(self.sup_data):
|
||||
if idx%(self.dataset_info['sup']/5)==0: print("Augmenting data... ", idx,"/", self.dataset_info['sup'])
|
||||
#if idx==10000:break
|
||||
|
||||
for _ in range(aug_copy):
|
||||
chosen_policy = policies[np.random.choice(len(policies))]
|
||||
aug_image = augmentation_transforms.apply_policy(chosen_policy, image, use_mean_std=False) #Cast en float image
|
||||
#aug_image = augmentation_transforms.cutout_numpy(aug_image)
|
||||
|
||||
self.unsup_data+=[(aug_image*255.).astype(self.sup_data.dtype)]#Cast float image to uint8
|
||||
self.unsup_targets+=[self.sup_targets[idx]]
|
||||
self.origin_idx+=[idx]
|
||||
|
||||
#self.unsup_data=(np.array(self.unsup_data)*255.).astype(self.sup_data.dtype) #Cast float image to uint8
|
||||
self.unsup_data=np.array(self.unsup_data)
|
||||
|
||||
assert len(self.unsup_data)==len(self.unsup_targets)
|
||||
|
||||
self.dataset_info['unsup']=len(self.unsup_data)
|
||||
self.dataset_info['length']=self.dataset_info['sup']+self.dataset_info['unsup']
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return self.dataset_info['unsup']#self.dataset_info['length']
|
||||
|
||||
def __str__(self):
|
||||
return "CIFAR10(Sup:{}-Unsup:{}-{}TF(Mag{}-{}))".format(self.dataset_info['sup'], self.dataset_info['unsup'], len(self._TF), self.mag_range[0], self.mag_range[1])
|
||||
|
||||
def train_UDA(model, dl_unsup, opt_param, epochs=1, print_freq=1):
|
||||
"""Training of a model using UDA inspired approach.
|
||||
|
||||
Intended to be used alongside an already augmented dataset (see AugmentedDatasetV2).
|
||||
|
||||
Args:
|
||||
model (nn.Module): Model to train.
|
||||
dl_unsup (Dataloader): Data loader of unsupervised/augmented data.
|
||||
opt_param (dict): Dictionnary containing optimizers parameters.
|
||||
epochs (int): Number of epochs to perform. (default: 1)
|
||||
print_freq (int): Number of epoch between display of the state of training. If set to None, no display will be done. (default:1)
|
||||
|
||||
Returns:
|
||||
(list) Logs of training. Each items is a dict containing results of an epoch.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
#opt = torch.optim.Adam(model.parameters(), lr=1e-3)
|
||||
opt = torch.optim.SGD(model.parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
|
||||
|
||||
|
||||
model.train()
|
||||
dl_val_it = iter(dl_val)
|
||||
dl_unsup_it =iter(dl_unsup)
|
||||
log = []
|
||||
for epoch in range(epochs):
|
||||
#print_torch_mem("Start epoch")
|
||||
t0 = time.process_time()
|
||||
for i, (features, labels) in enumerate(dl_train):
|
||||
#print_torch_mem("Start iter")
|
||||
features,labels = features.to(device), labels.to(device)
|
||||
|
||||
optim.zero_grad()
|
||||
#Supervised
|
||||
logits = model.forward(features)
|
||||
pred = F.log_softmax(logits, dim=1)
|
||||
sup_loss = F.cross_entropy(pred,labels)
|
||||
|
||||
#Unsupervised
|
||||
try:
|
||||
aug_xs, origin_xs, ys = next(dl_unsup_it)
|
||||
except StopIteration: #Fin epoch val
|
||||
dl_unsup_it =iter(dl_unsup)
|
||||
aug_xs, origin_xs, ys = next(dl_unsup_it)
|
||||
aug_xs, origin_xs, ys = aug_xs.to(device), origin_xs.to(device), ys.to(device)
|
||||
|
||||
#print(aug_xs.shape, origin_xs.shape, ys.shape)
|
||||
sup_logits = model.forward(origin_xs)
|
||||
unsup_logits = model.forward(aug_xs)
|
||||
|
||||
log_sup=F.log_softmax(sup_logits, dim=1)
|
||||
log_unsup=F.log_softmax(unsup_logits, dim=1)
|
||||
#KL div w/ logits
|
||||
unsup_loss = F.softmax(sup_logits, dim=1)*(log_sup-log_unsup)
|
||||
unsup_loss=unsup_loss.sum(dim=-1).mean()
|
||||
|
||||
#print(unsup_loss)
|
||||
unsupp_coeff = 1
|
||||
loss = sup_loss + unsup_loss * unsupp_coeff
|
||||
|
||||
loss.backward()
|
||||
optim.step()
|
||||
|
||||
#### Tests ####
|
||||
tf = time.process_time()
|
||||
try:
|
||||
xs_val, ys_val = next(dl_val_it)
|
||||
except StopIteration: #Fin epoch val
|
||||
dl_val_it = iter(dl_val)
|
||||
xs_val, ys_val = next(dl_val_it)
|
||||
xs_val, ys_val = xs_val.to(device), ys_val.to(device)
|
||||
|
||||
val_loss = F.cross_entropy(model(xs_val), ys_val)
|
||||
accuracy, _ =test(model)
|
||||
model.train()
|
||||
|
||||
#### Print ####
|
||||
if(print_freq and epoch%print_freq==0):
|
||||
print('-'*9)
|
||||
print('Epoch : %d/%d'%(epoch,epochs))
|
||||
print('Time : %.00f'%(tf - t0))
|
||||
print('Train loss :',loss.item(), '/ val loss', val_loss.item())
|
||||
print('Sup Loss :', sup_loss.item(), '/ unsup_loss :', unsup_loss.item())
|
||||
print('Accuracy :', accuracy)
|
||||
|
||||
#### Log ####
|
||||
data={
|
||||
"epoch": epoch,
|
||||
"train_loss": loss.item(),
|
||||
"val_loss": val_loss.item(),
|
||||
"acc": accuracy,
|
||||
"time": tf - t0,
|
||||
|
||||
"param": None,
|
||||
}
|
||||
log.append(data)
|
||||
|
||||
return log
|
||||
|
||||
|
||||
def run_simple_dataug(inner_it, epochs=1):
|
||||
device = next(model.parameters()).device
|
||||
dl_train_it = iter(dl_train)
|
||||
dl_val_it = iter(dl_val)
|
||||
|
||||
#aug_model = nn.Sequential(
|
||||
# Data_aug(),
|
||||
# LeNet(1,10),
|
||||
# )
|
||||
aug_model = Augmented_model(Data_aug(), LeNet(1,10)).to(device)
|
||||
print(str(aug_model))
|
||||
meta_opt = torch.optim.Adam(aug_model['data_aug'].parameters(), lr=1e-2)
|
||||
inner_opt = torch.optim.SGD(aug_model['model'].parameters(), lr=1e-2, momentum=0.9)
|
||||
|
||||
log = []
|
||||
t0 = time.process_time()
|
||||
|
||||
epoch = 0
|
||||
while epoch < epochs:
|
||||
meta_opt.zero_grad()
|
||||
aug_model.train()
|
||||
with higher.innerloop_ctx(aug_model, inner_opt, copy_initial_weights=True, track_higher_grads=True) as (fmodel, diffopt): #effet copy_initial_weight pas clair...
|
||||
|
||||
for i in range(n_inner_iter):
|
||||
try:
|
||||
xs, ys = next(dl_train_it)
|
||||
except StopIteration: #Fin epoch train
|
||||
tf = time.process_time()
|
||||
epoch +=1
|
||||
dl_train_it = iter(dl_train)
|
||||
xs, ys = next(dl_train_it)
|
||||
|
||||
accuracy, _ =test(model)
|
||||
aug_model.train()
|
||||
|
||||
#### Print ####
|
||||
print('-'*9)
|
||||
print('Epoch %d/%d'%(epoch,epochs))
|
||||
print('train loss',loss.item(), '/ val loss', val_loss.item())
|
||||
print('acc', accuracy)
|
||||
print('mag', aug_model['data_aug']['mag'].item())
|
||||
|
||||
#### Log ####
|
||||
data={
|
||||
"epoch": epoch,
|
||||
"train_loss": loss.item(),
|
||||
"val_loss": val_loss.item(),
|
||||
"acc": accuracy,
|
||||
"time": tf - t0,
|
||||
|
||||
"param": aug_model['data_aug']['mag'].item(),
|
||||
}
|
||||
log.append(data)
|
||||
t0 = time.process_time()
|
||||
|
||||
xs, ys = xs.to(device), ys.to(device)
|
||||
|
||||
logits = fmodel(xs) # modified `params` can also be passed as a kwarg
|
||||
|
||||
loss = F.cross_entropy(logits, ys) # no need to call loss.backwards()
|
||||
#loss.backward(retain_graph=True)
|
||||
#print(fmodel['model']._params['b4'].grad)
|
||||
#print('mag', fmodel['data_aug']['mag'].grad)
|
||||
|
||||
diffopt.step(loss) # note that `step` must take `loss` as an argument!
|
||||
# The line above gets P[t+1] from P[t] and loss[t]. `step` also returns
|
||||
# these new parameters, as an alternative to getting them from
|
||||
# `fmodel.fast_params` or `fmodel.parameters()` after calling
|
||||
# `diffopt.step`.
|
||||
|
||||
# At this point, or at any point in the iteration, you can take the
|
||||
# gradient of `fmodel.parameters()` (or equivalently
|
||||
# `fmodel.fast_params`) w.r.t. `fmodel.parameters(time=0)` (equivalently
|
||||
# `fmodel.init_fast_params`). i.e. `fast_params` will always have
|
||||
# `grad_fn` as an attribute, and be part of the gradient tape.
|
||||
|
||||
# At the end of your inner loop you can obtain these e.g. ...
|
||||
#grad_of_grads = torch.autograd.grad(
|
||||
# meta_loss_fn(fmodel.parameters()), fmodel.parameters(time=0))
|
||||
try:
|
||||
xs_val, ys_val = next(dl_val_it)
|
||||
except StopIteration: #Fin epoch val
|
||||
dl_val_it = iter(dl_val)
|
||||
xs_val, ys_val = next(dl_val_it)
|
||||
xs_val, ys_val = xs_val.to(device), ys_val.to(device)
|
||||
|
||||
fmodel.augment(mode=False)
|
||||
val_logits = fmodel(xs_val) #Validation sans transfornations !
|
||||
val_loss = F.cross_entropy(val_logits, ys_val)
|
||||
#print('val_loss',val_loss.item())
|
||||
val_loss.backward()
|
||||
|
||||
#print('mag', fmodel['data_aug']['mag'], '/', fmodel['data_aug']['mag'].grad)
|
||||
|
||||
#model=copy.deepcopy(fmodel)
|
||||
aug_model.load_state_dict(fmodel.state_dict()) #Do not copy gradient !
|
||||
#Copie des gradients
|
||||
for paramName, paramValue, in fmodel.named_parameters():
|
||||
for netCopyName, netCopyValue, in aug_model.named_parameters():
|
||||
if paramName == netCopyName:
|
||||
netCopyValue.grad = paramValue.grad
|
||||
|
||||
#print('mag', aug_model['data_aug']['mag'], '/', aug_model['data_aug']['mag'].grad)
|
||||
meta_opt.step()
|
||||
|
||||
plot_res(log, fig_name="res/{}-{} epochs- {} in_it".format(str(aug_model),epochs,inner_it))
|
||||
print('-'*9)
|
||||
times = [x["time"] for x in log]
|
||||
print(str(aug_model),": acc", max([x["acc"] for x in log]), "in (ms):", np.mean(times), "+/-", np.std(times))
|
||||
|
||||
def run_dist_dataug(model, epochs=1, inner_it=1, dataug_epoch_start=0):
|
||||
device = next(model.parameters()).device
|
||||
dl_train_it = iter(dl_train)
|
||||
dl_val_it = iter(dl_val)
|
||||
|
||||
meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=1e-3)
|
||||
inner_opt = torch.optim.SGD(model['model'].parameters(), lr=1e-2, momentum=0.9)
|
||||
|
||||
high_grad_track = True
|
||||
if dataug_epoch_start>0:
|
||||
model.augment(mode=False)
|
||||
high_grad_track = False
|
||||
|
||||
model.train()
|
||||
|
||||
log = []
|
||||
t0 = time.process_time()
|
||||
|
||||
countcopy=0
|
||||
val_loss=torch.tensor(0)
|
||||
opt_param=None
|
||||
|
||||
epoch = 0
|
||||
while epoch < epochs:
|
||||
meta_opt.zero_grad()
|
||||
with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=True, override=opt_param, track_higher_grads=high_grad_track) as (fmodel, diffopt): #effet copy_initial_weight pas clair...
|
||||
|
||||
for i in range(n_inner_iter):
|
||||
try:
|
||||
xs, ys = next(dl_train_it)
|
||||
except StopIteration: #Fin epoch train
|
||||
tf = time.process_time()
|
||||
epoch +=1
|
||||
dl_train_it = iter(dl_train)
|
||||
xs, ys = next(dl_train_it)
|
||||
|
||||
#viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
|
||||
#viz_sample_data(imgs=aug_model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))
|
||||
|
||||
accuracy, _ =test(model)
|
||||
model.train()
|
||||
|
||||
#### Print ####
|
||||
print('-'*9)
|
||||
print('Epoch : %d/%d'%(epoch,epochs))
|
||||
print('Train loss :',loss.item(), '/ val loss', val_loss.item())
|
||||
print('Accuracy :', accuracy)
|
||||
print('Data Augmention : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start))
|
||||
print('TF Proba :', model['data_aug']['prob'].data)
|
||||
#print('proba grad',aug_model['data_aug']['prob'].grad)
|
||||
#############
|
||||
#### Log ####
|
||||
data={
|
||||
"epoch": epoch,
|
||||
"train_loss": loss.item(),
|
||||
"val_loss": val_loss.item(),
|
||||
"acc": accuracy,
|
||||
"time": tf - t0,
|
||||
|
||||
"param": [p for p in model['data_aug']['prob']],
|
||||
}
|
||||
log.append(data)
|
||||
#############
|
||||
|
||||
if epoch == dataug_epoch_start:
|
||||
print('Starting Data Augmention...')
|
||||
model.augment(mode=True)
|
||||
high_grad_track = True
|
||||
|
||||
t0 = time.process_time()
|
||||
|
||||
xs, ys = xs.to(device), ys.to(device)
|
||||
|
||||
'''
|
||||
#Methode exacte
|
||||
final_loss = 0
|
||||
for tf_idx in range(fmodel['data_aug']._nb_tf):
|
||||
fmodel['data_aug'].transf_idx=tf_idx
|
||||
logits = fmodel(xs)
|
||||
loss = F.cross_entropy(logits, ys)
|
||||
#loss.backward(retain_graph=True)
|
||||
#print('idx', tf_idx)
|
||||
#print(fmodel['data_aug']['prob'][tf_idx], fmodel['data_aug']['prob'][tf_idx].grad)
|
||||
final_loss += loss*fmodel['data_aug']['prob'][tf_idx] #Take it in the forward function ?
|
||||
|
||||
loss = final_loss
|
||||
'''
|
||||
#Methode uniforme
|
||||
logits = fmodel(xs) # modified `params` can also be passed as a kwarg
|
||||
loss = F.cross_entropy(logits, ys, reduction='none') # no need to call loss.backwards()
|
||||
if fmodel._data_augmentation: #Weight loss
|
||||
w_loss = fmodel['data_aug'].loss_weight().to(device)
|
||||
loss = loss * w_loss
|
||||
loss = loss.mean()
|
||||
#'''
|
||||
|
||||
#to visualize computational graph
|
||||
#print_graph(loss)
|
||||
|
||||
#loss.backward(retain_graph=True)
|
||||
#print(fmodel['model']._params['b4'].grad)
|
||||
#print('prob grad', fmodel['data_aug']['prob'].grad)
|
||||
|
||||
diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)
|
||||
|
||||
try:
|
||||
xs_val, ys_val = next(dl_val_it)
|
||||
except StopIteration: #Fin epoch val
|
||||
dl_val_it = iter(dl_val)
|
||||
xs_val, ys_val = next(dl_val_it)
|
||||
xs_val, ys_val = xs_val.to(device), ys_val.to(device)
|
||||
|
||||
fmodel.augment(mode=False) #Validation sans transfornations !
|
||||
val_loss = F.cross_entropy(fmodel(xs_val), ys_val)
|
||||
|
||||
#print_graph(val_loss)
|
||||
|
||||
val_loss.backward()
|
||||
|
||||
countcopy+=1
|
||||
model_copy(src=fmodel, dst=model)
|
||||
optim_copy(dopt=diffopt, opt=inner_opt)
|
||||
|
||||
meta_opt.step()
|
||||
model['data_aug'].adjust_param() #Contrainte sum(proba)=1
|
||||
|
||||
print("Copy ", countcopy)
|
||||
return log
|
||||
|
||||
def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start=0, print_freq=1, KLdiv=False, loss_patience=None, save_sample=False):
|
||||
device = next(model.parameters()).device
|
||||
log = []
|
||||
countcopy=0
|
||||
val_loss=torch.tensor(0) #Necessaire si pas de metastep sur une epoch
|
||||
dl_val_it = iter(dl_val)
|
||||
|
||||
#if inner_it!=0:
|
||||
meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=opt_param['Meta']['lr']) #lr=1e-2
|
||||
inner_opt = torch.optim.SGD(model['model'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
|
||||
|
||||
high_grad_track = True
|
||||
if inner_it == 0:
|
||||
high_grad_track=False
|
||||
if dataug_epoch_start!=0:
|
||||
model.augment(mode=False)
|
||||
high_grad_track = False
|
||||
|
||||
val_loss_monitor= None
|
||||
if loss_patience != None :
|
||||
if dataug_epoch_start==-1: val_loss_monitor = loss_monitor(patience=loss_patience, end_train=2) #1st limit = dataug start
|
||||
else: val_loss_monitor = loss_monitor(patience=loss_patience) #Val loss monitor (Not on val data : used by Dataug... => Test data)
|
||||
|
||||
model.train()
|
||||
|
||||
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
|
||||
diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
|
||||
|
||||
meta_opt.zero_grad()
|
||||
|
||||
for epoch in range(1, epochs+1):
|
||||
#print_torch_mem("Start epoch "+str(epoch))
|
||||
#print(high_grad_track, fmodel._data_augmentation, len(fmodel._fast_params))
|
||||
t0 = time.process_time()
|
||||
#with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=True, override=opt_param, track_higher_grads=high_grad_track) as (fmodel, diffopt):
|
||||
|
||||
for i, (xs, ys) in enumerate(dl_train):
|
||||
xs, ys = xs.to(device), ys.to(device)
|
||||
|
||||
#Methode exacte
|
||||
#final_loss = 0
|
||||
#for tf_idx in range(fmodel['data_aug']._nb_tf):
|
||||
# fmodel['data_aug'].transf_idx=tf_idx
|
||||
# logits = fmodel(xs)
|
||||
# loss = F.cross_entropy(logits, ys)
|
||||
# #loss.backward(retain_graph=True)
|
||||
# final_loss += loss*fmodel['data_aug']['prob'][tf_idx] #Take it in the forward function ?
|
||||
#loss = final_loss
|
||||
|
||||
if(not KLdiv):
|
||||
#Methode uniforme
|
||||
logits = fmodel(xs) # modified `params` can also be passed as a kwarg
|
||||
loss = F.cross_entropy(F.log_softmax(logits, dim=1), ys, reduction='none') # no need to call loss.backwards()
|
||||
|
||||
if fmodel._data_augmentation: #Weight loss
|
||||
w_loss = fmodel['data_aug'].loss_weight()#.to(device)
|
||||
loss = loss * w_loss
|
||||
loss = loss.mean()
|
||||
|
||||
else:
|
||||
#Methode KL div
|
||||
if fmodel._data_augmentation :
|
||||
fmodel.augment(mode=False)
|
||||
sup_logits = fmodel(xs)
|
||||
fmodel.augment(mode=True)
|
||||
else:
|
||||
sup_logits = fmodel(xs)
|
||||
log_sup=F.log_softmax(sup_logits, dim=1)
|
||||
loss = F.cross_entropy(log_sup, ys)
|
||||
|
||||
if fmodel._data_augmentation:
|
||||
aug_logits = fmodel(xs)
|
||||
log_aug=F.log_softmax(aug_logits, dim=1)
|
||||
|
||||
w_loss = fmodel['data_aug'].loss_weight() #Weight loss
|
||||
|
||||
#if epoch>50: #debut differe ?
|
||||
#KL div w/ logits - Similarite predictions (distributions)
|
||||
aug_loss = F.softmax(sup_logits, dim=1)*(log_sup-log_aug)
|
||||
aug_loss = aug_loss.sum(dim=-1)
|
||||
#aug_loss = F.kl_div(aug_logits, sup_logits, reduction='none')
|
||||
aug_loss = (w_loss * aug_loss).mean()
|
||||
|
||||
aug_loss += (F.cross_entropy(log_aug, ys , reduction='none') * w_loss).mean()
|
||||
|
||||
unsupp_coeff = 1
|
||||
loss += aug_loss * unsupp_coeff
|
||||
|
||||
#to visualize computational graph
|
||||
#print_graph(loss)
|
||||
|
||||
#loss.backward(retain_graph=True)
|
||||
#print(fmodel['model']._params['b4'].grad)
|
||||
#print('prob grad', fmodel['data_aug']['prob'].grad)
|
||||
|
||||
#t = time.process_time()
|
||||
diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)
|
||||
#print(len(fmodel._fast_params),"step", time.process_time()-t)
|
||||
|
||||
if(high_grad_track and i>0 and i%inner_it==0): #Perform Meta step
|
||||
#print("meta")
|
||||
|
||||
val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val) #+ fmodel['data_aug'].reg_loss()
|
||||
#print_graph(val_loss)
|
||||
|
||||
#t = time.process_time()
|
||||
val_loss.backward()
|
||||
#print("meta", time.process_time()-t)
|
||||
#print('proba grad',model['data_aug']['prob'].grad)
|
||||
if model['data_aug']['prob'].grad is None or model['data_aug']['mag'] is None:
|
||||
print("Warning no grad (iter",i,") :\n Prob-",model['data_aug']['prob'].grad,"\n Mag-", model['data_aug']['mag'].grad)
|
||||
|
||||
countcopy+=1
|
||||
model_copy(src=fmodel, dst=model)
|
||||
optim_copy(dopt=diffopt, opt=inner_opt)
|
||||
|
||||
torch.nn.utils.clip_grad_norm_(model['data_aug'].parameters(), max_norm=10, norm_type=2) #Prevent exploding grad with RNN
|
||||
|
||||
#if epoch>50:
|
||||
meta_opt.step()
|
||||
model['data_aug'].adjust_param(soft=False) #Contrainte sum(proba)=1
|
||||
try: #Dataugv6
|
||||
model['data_aug'].next_TF_set()
|
||||
except:
|
||||
pass
|
||||
|
||||
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
|
||||
diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
|
||||
|
||||
meta_opt.zero_grad()
|
||||
|
||||
tf = time.process_time()
|
||||
|
||||
#viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
|
||||
#viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch), weight_labels=model['data_aug'].loss_weight())
|
||||
|
||||
if(not high_grad_track):
|
||||
countcopy+=1
|
||||
model_copy(src=fmodel, dst=model)
|
||||
optim_copy(dopt=diffopt, opt=inner_opt)
|
||||
val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val)
|
||||
|
||||
#Necessaire pour reset higher (Accumule les fast_param meme avec track_higher_grads = False)
|
||||
fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
|
||||
diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(),fmodel=fmodel, track_higher_grads=high_grad_track)
|
||||
|
||||
accuracy, test_loss =test(model)
|
||||
model.train()
|
||||
|
||||
#### Log ####
|
||||
#print(type(model['data_aug']) is dataug.Data_augV5)
|
||||
param = [{'p': p.item(), 'm':model['data_aug']['mag'].item()} for p in model['data_aug']['prob']] if model['data_aug']._shared_mag else [{'p': p.item(), 'm': m.item()} for p, m in zip(model['data_aug']['prob'], model['data_aug']['mag'])]
|
||||
data={
|
||||
"epoch": epoch,
|
||||
"train_loss": loss.item(),
|
||||
"val_loss": val_loss.item(),
|
||||
"acc": accuracy,
|
||||
"time": tf - t0,
|
||||
|
||||
"param": param #if isinstance(model['data_aug'], Data_augV5)
|
||||
#else [p.item() for p in model['data_aug']['prob']],
|
||||
}
|
||||
log.append(data)
|
||||
#############
|
||||
#### Print ####
|
||||
if(print_freq and epoch%print_freq==0):
|
||||
print('-'*9)
|
||||
print('Epoch : %d/%d'%(epoch,epochs))
|
||||
print('Time : %.00f'%(tf - t0))
|
||||
print('Train loss :',loss.item(), '/ val loss', val_loss.item())
|
||||
print('Accuracy :', max([x["acc"] for x in log]))
|
||||
print('Data Augmention : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start))
|
||||
print('TF Proba :', model['data_aug']['prob'].data)
|
||||
#print('proba grad',model['data_aug']['prob'].grad)
|
||||
print('TF Mag :', model['data_aug']['mag'].data)
|
||||
#print('Mag grad',model['data_aug']['mag'].grad)
|
||||
#print('Reg loss:', model['data_aug'].reg_loss().item())
|
||||
#print('Aug loss', aug_loss.item())
|
||||
#############
|
||||
if val_loss_monitor :
|
||||
model.eval()
|
||||
val_loss_monitor.register(test_loss)#val_loss.item())
|
||||
if val_loss_monitor.end_training(): break #Stop training
|
||||
model.train()
|
||||
|
||||
if not model.is_augmenting() and (epoch == dataug_epoch_start or (val_loss_monitor and val_loss_monitor.limit_reached()==1)):
|
||||
print('Starting Data Augmention...')
|
||||
dataug_epoch_start = epoch
|
||||
model.augment(mode=True)
|
||||
if inner_it != 0: high_grad_track = True
|
||||
|
||||
try:
|
||||
viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
|
||||
viz_sample_data(imgs=model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch), weight_labels=model['data_aug'].loss_weight())
|
||||
except:
|
||||
print("Couldn't save finals samples")
|
||||
pass
|
||||
|
||||
#print("Copy ", countcopy)
|
||||
return log
|
161
higher/smart_aug/old/utils_old.py
Normal file
161
higher/smart_aug/old/utils_old.py
Normal file
|
@ -0,0 +1,161 @@
|
|||
import numpy as np
|
||||
import json, math, time, os
|
||||
import matplotlib.pyplot as plt
|
||||
import copy
|
||||
import gc
|
||||
|
||||
from torchviz import make_dot
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import time
|
||||
|
||||
class timer():
|
||||
def __init__(self):
|
||||
self._start_time=time.time()
|
||||
def exec_time(self):
|
||||
end = time.time()
|
||||
res = end-self._start_time
|
||||
self._start_time=end
|
||||
return res
|
||||
|
||||
def plot_res(log, fig_name='res', param_names=None):
|
||||
|
||||
epochs = [x["epoch"] for x in log]
|
||||
|
||||
fig, ax = plt.subplots(ncols=3, figsize=(15, 3))
|
||||
|
||||
ax[0].set_title('Loss')
|
||||
ax[0].plot(epochs,[x["train_loss"] for x in log], label='Train')
|
||||
ax[0].plot(epochs,[x["val_loss"] for x in log], label='Val')
|
||||
ax[0].legend()
|
||||
|
||||
ax[1].set_title('Acc')
|
||||
ax[1].plot(epochs,[x["acc"] for x in log])
|
||||
|
||||
if log[0]["param"]!= None:
|
||||
if isinstance(log[0]["param"],float):
|
||||
ax[2].set_title('Mag')
|
||||
ax[2].plot(epochs,[x["param"] for x in log], label='Mag')
|
||||
ax[2].legend()
|
||||
else :
|
||||
ax[2].set_title('Prob')
|
||||
#for idx, _ in enumerate(log[0]["param"]):
|
||||
#ax[2].plot(epochs,[x["param"][idx] for x in log], label='P'+str(idx))
|
||||
if not param_names : param_names = ['P'+str(idx) for idx, _ in enumerate(log[0]["param"])]
|
||||
proba=[[x["param"][idx] for x in log] for idx, _ in enumerate(log[0]["param"])]
|
||||
ax[2].stackplot(epochs, proba, labels=param_names)
|
||||
ax[2].legend(param_names, loc='center left', bbox_to_anchor=(1, 0.5))
|
||||
|
||||
|
||||
fig_name = fig_name.replace('.',',')
|
||||
plt.savefig(fig_name)
|
||||
plt.close()
|
||||
|
||||
def plot_res_compare(filenames, fig_name='res'):
|
||||
|
||||
all_data=[]
|
||||
#legend=""
|
||||
for idx, file in enumerate(filenames):
|
||||
#legend+=str(idx)+'-'+file+'\n'
|
||||
with open(file) as json_file:
|
||||
data = json.load(json_file)
|
||||
all_data.append(data)
|
||||
|
||||
n_tf = [len(x["Param_names"]) for x in all_data]
|
||||
acc = [x["Accuracy"] for x in all_data]
|
||||
time = [x["Time"][0] for x in all_data]
|
||||
|
||||
fig, ax = plt.subplots(ncols=3, figsize=(30, 8))
|
||||
|
||||
ax[0].plot(n_tf, acc)
|
||||
ax[1].plot(n_tf, time)
|
||||
|
||||
ax[0].set_title('Acc')
|
||||
ax[1].set_title('Time')
|
||||
#for a in ax: a.legend()
|
||||
|
||||
fig_name = fig_name.replace('.',',')
|
||||
plt.savefig(fig_name, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
def plot_TF_res(log, tf_names, fig_name='res'):
|
||||
|
||||
mean = np.mean([x["param"] for x in log], axis=0)
|
||||
std = np.std([x["param"] for x in log], axis=0)
|
||||
|
||||
fig, ax = plt.subplots(1, 1, figsize=(30, 8), sharey=True)
|
||||
ax.bar(tf_names, mean, yerr=std)
|
||||
#ax.bar(tf_names, log[-1]["param"])
|
||||
|
||||
fig_name = fig_name.replace('.',',')
|
||||
plt.savefig(fig_name, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
def model_copy(src,dst, patch_copy=True, copy_grad=True):
|
||||
#model=copy.deepcopy(fmodel) #Pas approprie, on ne souhaite que les poids/grad (pas tout fmodel et ses etats)
|
||||
|
||||
dst.load_state_dict(src.state_dict()) #Do not copy gradient !
|
||||
|
||||
if patch_copy:
|
||||
dst['model'].load_state_dict(src['model'].state_dict()) #Copie donnee manquante ?
|
||||
dst['data_aug'].load_state_dict(src['data_aug'].state_dict())
|
||||
|
||||
#Copie des gradients
|
||||
if copy_grad:
|
||||
for paramName, paramValue, in src.named_parameters():
|
||||
for netCopyName, netCopyValue, in dst.named_parameters():
|
||||
if paramName == netCopyName:
|
||||
netCopyValue.grad = paramValue.grad
|
||||
#netCopyValue=copy.deepcopy(paramValue)
|
||||
|
||||
try: #Data_augV4
|
||||
dst['data_aug']._input_info = src['data_aug']._input_info
|
||||
dst['data_aug']._TF_matrix = src['data_aug']._TF_matrix
|
||||
except:
|
||||
pass
|
||||
|
||||
def optim_copy(dopt, opt):
|
||||
|
||||
#inner_opt.load_state_dict(diffopt.state_dict()) #Besoin sauver etat otpim (momentum, etc.) => Ne copie pas le state...
|
||||
#opt_param=higher.optim.get_trainable_opt_params(diffopt)
|
||||
|
||||
for group_idx, group in enumerate(opt.param_groups):
|
||||
# print('gp idx',group_idx)
|
||||
for p_idx, p in enumerate(group['params']):
|
||||
opt.state[p]=dopt.state[group_idx][p_idx]
|
||||
|
||||
class loss_monitor(): #Voir https://github.com/pytorch/ignite
|
||||
def __init__(self, patience, end_train=1):
|
||||
self.patience = patience
|
||||
self.end_train = end_train
|
||||
self.counter = 0
|
||||
self.best_score = None
|
||||
self.reached_limit = 0
|
||||
|
||||
def register(self, loss):
|
||||
if self.best_score is None:
|
||||
self.best_score = loss
|
||||
elif loss > self.best_score:
|
||||
self.counter += 1
|
||||
#if not self.reached_limit:
|
||||
print("loss no improve counter", self.counter, self.reached_limit)
|
||||
else:
|
||||
self.best_score = loss
|
||||
self.counter = 0
|
||||
def limit_reached(self):
|
||||
if self.counter >= self.patience:
|
||||
self.counter = 0
|
||||
self.reached_limit +=1
|
||||
self.best_score = None
|
||||
return self.reached_limit
|
||||
|
||||
def end_training(self):
|
||||
if self.limit_reached() >= self.end_train:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def reset(self):
|
||||
self.__init__(self.patience, self.end_train)
|
Loading…
Add table
Add a link
Reference in a new issue