Initial Commit
higher/dataug.py (new file, 583 lines)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import *

import kornia
import random
import numpy as np
import copy

import transformations as TF

class Data_aug(nn.Module): #Parameterized rotation
    def __init__(self):
        super(Data_aug, self).__init__()
        self._data_augmentation = True
        self._params = nn.ParameterDict({
            "prob": nn.Parameter(torch.tensor(0.5)),
            "mag": nn.Parameter(torch.tensor(1.0))
        })

        #self.params["mag"].register_hook(print)

    def forward(self, x):

        if self._data_augmentation and random.random() < self._params["prob"]:
            #print('Aug')
            batch_size = x.shape[0]
            # create transformation (rotation)
            alpha = self._params["mag"]*180 # in degrees
            angle = torch.ones(batch_size, device=x.device) * alpha

            # define the rotation center
            center = torch.ones(batch_size, 2, device=x.device)
            center[..., 0] = x.shape[3] / 2  # x
            center[..., 1] = x.shape[2] / 2  # y

            #print(x.shape, center)
            # define the scale factor
            scale = torch.ones(batch_size, device=x.device)

            # compute the transformation matrix
            M = kornia.get_rotation_matrix2d(center, angle, scale)

            # apply the transformation to original image
            x = kornia.warp_affine(x, M, dsize=(x.shape[2], x.shape[3])) #dsize=(h, w)

        return x

    def eval(self):
        self.augment(mode=False)
        nn.Module.eval(self)

    def augment(self, mode=True):
        self._data_augmentation = mode

    def __getitem__(self, key):
        return self._params[key]

    def __str__(self):
        return "Data_aug(Mag-1 TF)"

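# Usage sketch (illustrative only, assuming a (N, C, H, W) float batch in [0, 1]):
#   aug = Data_aug()
#   x_aug = aug(torch.rand(8, 3, 32, 32))  # rotates the whole batch by mag*180 deg with probability `prob`
# Gradients w.r.t. "mag" flow through kornia.warp_affine; "prob" only gates a
# Python-level random draw here, so it receives no gradient from this forward pass.
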
class Data_augV2(nn.Module): #Exact method (loss computed for every transformation)
    def __init__(self):
        super(Data_augV2, self).__init__()
        self._data_augmentation = True

        self._fixed_transf = [0.0, 45.0, 180.0] #Rotation degrees
        #self._fixed_transf=[0.0]
        self._nb_tf = len(self._fixed_transf)

        self._params = nn.ParameterDict({
            "prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Uniform probability distribution
            #"prob2": nn.Parameter(torch.ones(len(self._fixed_transf)).softmax(dim=0))
        })

        #print(self._params["prob"], self._params["prob2"])

        self.transf_idx = 0

    def forward(self, x):

        if self._data_augmentation:
            #print('Aug',self._fixed_transf[self.transf_idx])
            device = x.device
            batch_size = x.shape[0]

            # create transformation (rotation)
            #alpha = 180 # in degrees
            alpha = self._fixed_transf[self.transf_idx]
            angle = torch.ones(batch_size, device=device) * alpha

            x = self.rotate(x, angle)

        return x

    def rotate(self, x, angle):

        device = x.device
        batch_size = x.shape[0]
        # define the rotation center
        center = torch.ones(batch_size, 2, device=device)
        center[..., 0] = x.shape[3] / 2  # x
        center[..., 1] = x.shape[2] / 2  # y

        #print(x.shape, center)
        # define the scale factor
        scale = torch.ones(batch_size, device=device)

        # compute the transformation matrix
        M = kornia.get_rotation_matrix2d(center, angle, scale)

        # apply the transformation to original image
        return kornia.warp_affine(x, M, dsize=(x.shape[2], x.shape[3])) #dsize=(h, w)

    def adjust_prob(self): #Detach from gradient ?
        self._params['prob'].data = self._params['prob'].clamp(min=0.0, max=1.0)
        #print('proba',self._params['prob'])
        self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Constraint: sum(p)=1
        #print('Sum p', sum(self._params['prob']))

    def eval(self):
        self.augment(mode=False)
        nn.Module.eval(self)

    def augment(self, mode=True):
        self._data_augmentation = mode

    def __getitem__(self, key):
        return self._params[key]

    def __str__(self):
        return "Data_augV2(Exact-%d TF)" % self._nb_tf

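# With this "exact" variant the training script is expected to loop over every
# transformation index and weight each loss by its probability, e.g. (sketch,
# mirroring the commented "Exact method" block in test_dataug.py):
#   final_loss = 0
#   for i in range(aug._nb_tf):
#       aug.transf_idx = i
#       final_loss += F.cross_entropy(model(aug(x)), y) * aug['prob'][i]
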
class Data_augV3(nn.Module): #Uniform / mixed sampling
    def __init__(self, mix_dist=0.0):
        super(Data_augV3, self).__init__()
        self._data_augmentation = True

        #self._fixed_transf=[0.0, 45.0, 180.0] #Rotation degrees
        self._fixed_transf = [0.0, 1.0, -1.0] #Flips (Identity, Horizontal, Vertical)
        #self._fixed_transf=[0.0]
        self._nb_tf = len(self._fixed_transf)

        self._params = nn.ParameterDict({
            "prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Uniform probability distribution
            #"prob2": nn.Parameter(torch.ones(len(self._fixed_transf)).softmax(dim=0))
        })

        #print(self._params["prob"], self._params["prob2"])
        self._sample = []

        self._mix_dist = False
        if mix_dist != 0.0:
            self._mix_dist = True
            self._mix_factor = max(min(mix_dist, 1.0), 0.0)

    def forward(self, x):

        if self._data_augmentation:
            device = x.device
            batch_size = x.shape[0]

            #good_distrib = Uniform(low=torch.zeros(batch_size,1, device=device),high=torch.new_full((batch_size,1),self._params["prob"], device=device))
            #bad_distrib = Uniform(low=torch.zeros(batch_size,1, device=device),high=torch.new_full((batch_size,1), 1-self._params["prob"], device=device))

            #transform_dist = Categorical(probs=torch.tensor([self._params["prob"], 1-self._params["prob"]], device=device))
            #self._sample = transform_dist._sample(sample_shape=torch.Size([batch_size,1]))

            uniforme_dist = torch.ones(1, self._nb_tf, device=device).softmax(dim=1)

            if not self._mix_dist:
                distrib = uniforme_dist
            else:
                distrib = (self._mix_factor*self._params["prob"]+(1-self._mix_factor)*uniforme_dist).softmax(dim=1) #Mix the learned and uniform distributions with mix_factor

            cat_distrib = Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*distrib)
            self._sample = cat_distrib.sample()

            TF_param = torch.tensor([self._fixed_transf[x] for x in self._sample], device=device) #Marco's approach might be faster

            #x = self.rotate(x,angle=TF_param)
            x = self.flip(x, flip_mat=TF_param)

        return x

    def rotate(self, x, angle):

        device = x.device
        batch_size = x.shape[0]
        # define the rotation center
        center = torch.ones(batch_size, 2, device=device)
        center[..., 0] = x.shape[3] / 2  # x
        center[..., 1] = x.shape[2] / 2  # y

        #print(x.shape, center)
        # define the scale factor
        scale = torch.ones(batch_size, device=device)

        # compute the transformation matrix
        M = kornia.get_rotation_matrix2d(center, angle, scale)

        # apply the transformation to original image
        return kornia.warp_affine(x, M, dsize=(x.shape[2], x.shape[3])) #dsize=(h, w)

    def flip(self, x, flip_mat):

        #print(flip_mat)
        device = x.device
        batch_size = x.shape[0]

        h, w = x.shape[2], x.shape[3]  # destination size
        #points_src = torch.ones(batch_size, 4, 2, device=device)
        #points_dst = torch.ones(batch_size, 4, 2, device=device)

        #Identity
        iM = torch.tensor(np.eye(3))

        #Horizontal flip
        # the source points are the region to crop corners
        #points_src = torch.FloatTensor([[
        #    [w - 1, 0], [0, 0], [0, h - 1], [w - 1, h - 1],
        #]])
        # the destination points are the image vertexes
        #points_dst = torch.FloatTensor([[
        #    [0, 0], [w - 1, 0], [w - 1, h - 1], [0, h - 1],
        #]])
        # compute perspective transform
        #hM = kornia.get_perspective_transform(points_src, points_dst)
        hM = torch.tensor([[[-1., 0., w-1],
                            [ 0., 1., 0.],
                            [ 0., 0., 1.]]])

        #Vertical flip
        # the source points are the region to crop corners
        #points_src = torch.FloatTensor([[
        #    [0, h - 1], [w - 1, h - 1], [w - 1, 0], [0, 0],
        #]])
        # the destination points are the image vertexes
        #points_dst = torch.FloatTensor([[
        #    [0, 0], [w - 1, 0], [w - 1, h - 1], [0, h - 1],
        #]])
        # compute perspective transform
        #vM = kornia.get_perspective_transform(points_src, points_dst)
        vM = torch.tensor([[[ 1., 0., 0.],
                            [ 0., -1., h-1],
                            [ 0., 0., 1.]]])
        #print(vM)

        M = torch.ones(batch_size, 3, 3, device=device)

        for i in range(batch_size):  # To be optimized
            if flip_mat[i] == 0.0:
                M[i,] = iM
            elif flip_mat[i] == 1.0:
                M[i,] = hM
            elif flip_mat[i] == -1.0:
                M[i,] = vM

        # warp the original image by the found transform
        return kornia.warp_perspective(x, M, dsize=(h, w))

    def adjust_prob(self, soft=False): #Detach from gradient ?

        if soft :
            self._params['prob'].data = F.softmax(self._params['prob'].data, dim=0) #Too 'soft': gets stuck at the uniform distribution if the lr is too small
        else:
            #self._params['prob'].clamp(min=0.0,max=1.0)
            self._params['prob'].data = F.relu(self._params['prob'].data)
            #self._params['prob'].data = self._params['prob'].clamp(min=0.0,max=1.0)
        #print('proba',self._params['prob'])
        self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Constraint: sum(p)=1
        #print('Sum p', sum(self._params['prob']))

    def loss_weight(self):
        #w_loss = [self._params["prob"][x] for x in self._sample]
        #print(self._sample.view(-1,1).shape)
        #print(self._sample[:10])

        w_loss = torch.zeros((self._sample.shape[0], self._nb_tf), device=self._sample.device)
        w_loss.scatter_(1, self._sample.view(-1, 1), 1)
        #print(w_loss.shape)
        #print(w_loss[:10,:])
        w_loss = w_loss * self._params["prob"]
        #print(w_loss.shape)
        #print(w_loss[:10,:])
        w_loss = torch.sum(w_loss, dim=1)
        #print(w_loss.shape)
        #print(w_loss[:10])
        return w_loss

    def train(self, mode=None):
        if mode is None :
            mode = self._data_augmentation
        self.augment(mode=mode) #No effect when mode was None
        super(Data_augV3, self).train(mode)

    def eval(self):
        self.train(mode=False)
        #super(Augmented_model, self).eval()

    def augment(self, mode=True):
        self._data_augmentation = mode

    def __getitem__(self, key):
        return self._params[key]

    def __str__(self):
        if not self._mix_dist:
            return "Data_augV3(Uniform-%d TF)" % self._nb_tf
        else:
            return "Data_augV3(Mix %.1f-%d TF)" % (self._mix_factor, self._nb_tf)

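# Sketch of how loss_weight() is meant to be consumed (see test_dataug.py):
#   loss = F.cross_entropy(logits, ys, reduction='none')
#   loss = (loss * aug.loss_weight()).mean()
# The one-hot scatter multiplied by `prob` makes the weighted loss differentiable
# w.r.t. the transformation probabilities even though sampling itself is not.
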
class Data_augV4(nn.Module): #Mask-based transformations
    def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0):
        super(Data_augV4, self).__init__()
        self._data_augmentation = True

        #self._TF_matrix={}
        #self._input_info={'h':0, 'w':0, 'device':None} #Input associated with TF_matrix
        '''
        self._mag_fct={ #f(normalized_mag)=actual_mag
            ## Geometric TF ##
            'Identity' : (lambda mag: None),
            'FlipUD' : (lambda mag: None),
            'FlipLR' : (lambda mag: None),
            'Rotate': (lambda mag: random.randint(-int_parameter(mag, maxval=30), int_parameter(mag, maxval=30))),
            'TranslateX': (lambda mag: [random.randint(-int_parameter(mag, maxval=20), int_parameter(mag, maxval=20)), 0]),
            'TranslateY': (lambda mag: [0, random.randint(-int_parameter(mag, maxval=20), int_parameter(mag, maxval=20))]),
            'ShearX': (lambda mag: [random.uniform(-float_parameter(mag, maxval=0.3), float_parameter(mag, maxval=0.3)), 0]),
            'ShearY': (lambda mag: [0, random.uniform(-float_parameter(mag, maxval=0.3), float_parameter(mag, maxval=0.3))]),

            ## Color TF (Expect image in the range of [0, 1]) ##
            'Contrast': (lambda mag: random.uniform(0.1, float_parameter(mag, maxval=1.9))),
            'Color':(lambda mag: random.uniform(0.1, float_parameter(mag, maxval=1.9))),
            'Brightness':(lambda mag: random.uniform(1., float_parameter(mag, maxval=1.9))),
            'Sharpness':(lambda mag: random.uniform(0.1, float_parameter(mag, maxval=1.9))),
            'Posterize': (lambda mag: random.randint(4, int_parameter(mag, maxval=8))),
            'Solarize': (lambda mag: random.randint(1, int_parameter(mag, maxval=256))/256.), #=>Image in [0,1] #Not batch-efficient

            #Not functional
            'Auto_Contrast': (lambda mag: None), #Not batch-efficient (very slow)
            #'Equalize': (lambda mag: None),
        }
        '''
        self._mag_fct = TF_dict
        self._TF = list(self._mag_fct.keys())
        self._nb_tf = len(self._TF)

        self._fixed_mag = 5 #[0, PARAMETER_MAX]
        self._params = nn.ParameterDict({
            "prob": nn.Parameter(torch.ones(self._nb_tf)/self._nb_tf), #Uniform probability distribution
        })

        self._sample = []

        self._mix_dist = False
        if mix_dist != 0.0:
            self._mix_dist = True
            self._mix_factor = max(min(mix_dist, 1.0), 0.0)

    def forward(self, x):
        if self._data_augmentation:
            device = x.device
            batch_size, h, w = x.shape[0], x.shape[2], x.shape[3]

            ## Sampling ##
            uniforme_dist = torch.ones(1, self._nb_tf, device=device).softmax(dim=1)

            if not self._mix_dist:
                self._distrib = uniforme_dist
            else:
                self._distrib = (self._mix_factor*self._params["prob"]+(1-self._mix_factor)*uniforme_dist).softmax(dim=1) #Mix the learned and uniform distributions with mix_factor
                #print(self._distrib.shape)

            cat_distrib = Categorical(probs=torch.ones((batch_size, self._nb_tf), device=device)*self._distrib)
            self._sample = cat_distrib.sample()

            ## Transformations ##
            #'''
            x = copy.deepcopy(x) #Avoid modifying the samples by reference (problematic for parallel use)
            smps_x = []
            masks = []
            for tf_idx in range(self._nb_tf):
                mask = self._sample == tf_idx #Create selection mask
                smp_x = x[mask] #torch.masked_select() ?

                if smp_x.shape[0] != 0: #if there's data to TF
                    magnitude = self._fixed_mag
                    tf = self._TF[tf_idx]

                    ## Geometric TF ##
                    if tf=='Identity':
                        pass
                    elif tf=='FlipLR':
                        smp_x = TF.flipLR(smp_x)
                    elif tf=='FlipUD':
                        smp_x = TF.flipUD(smp_x)
                    elif tf=='Rotate':
                        smp_x = TF.rotate(smp_x, angle=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
                    elif tf=='TranslateX' or tf=='TranslateY':
                        smp_x = TF.translate(smp_x, translation=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
                    elif tf=='ShearX' or tf=='ShearY' :
                        smp_x = TF.shear(smp_x, shear=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))

                    ## Color TF (Expect image in the range of [0, 1]) ##
                    elif tf=='Contrast':
                        smp_x = TF.contrast(smp_x, contrast_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
                    elif tf=='Color':
                        smp_x = TF.color(smp_x, color_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
                    elif tf=='Brightness':
                        smp_x = TF.brightness(smp_x, brightness_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
                    elif tf=='Sharpness':
                        smp_x = TF.sharpeness(smp_x, sharpness_factor=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
                    elif tf=='Posterize':
                        smp_x = TF.posterize(smp_x, bits=torch.tensor([1 for _ in smp_x], device=device))
                    elif tf=='Solarize':
                        smp_x = TF.solarize(smp_x, thresholds=torch.tensor([self._mag_fct[tf](magnitude) for _ in smp_x], device=device))
                    elif tf=='Equalize':
                        smp_x = TF.equalize(smp_x)
                    elif tf=='Auto_Contrast':
                        smp_x = TF.auto_contrast(smp_x)
                    else:
                        raise Exception("Invalid TF requested : ", tf)

                    x[mask] = smp_x # Merge back; note that x[mask]= is an in-place op

                    #idx= mask.nonzero()
                    #print('-'*8)
                    #print(idx[0], tf_idx)
                    #print(smp_x[0,])
                    #x=x.view(-1,3*32*32)
                    #x=x.scatter(dim=0, index=idx, src=smp_x.view(-1,3*32*32)) #Tensors are changed but it does not show up in the visualization...
                    #x=x.view(-1,3,32,32)
                    #print(x[0,])

            '''
            if len(self._TF_matrix)==0 or self._input_info['h']!=h or self._input_info['w']!=w or self._input_info['device']!=device: #Different device: no need to recompute everything
                self.compute_TF_matrix(sample_info={'h': x.shape[2],
                                                    'w': x.shape[3],
                                                    'device': x.device})

            TF_matrix = torch.zeros(batch_size, 3, 3, device=device) #All geom TF

            for tf_idx in range(self._nb_tf):
                mask = self._sample==tf_idx #Create selection mask
                TF_matrix[mask,]=self._TF_matrix[self._TF[tf_idx]]

            x=kornia.warp_perspective(x, TF_matrix, dsize=(h, w))
            '''
        return x
    '''
    def compute_TF_matrix(self, magnitude=None, sample_info= None):
        print('Computing TF_matrix...')
        if not magnitude :
            magnitude=self._fixed_mag

        if sample_info:
            self._input_info['h']= sample_info['h']
            self._input_info['w']= sample_info['w']
            self._input_info['device'] = sample_info['device']
        h, w, device= self._input_info['h'], self._input_info['w'], self._input_info['device']

        self._TF_matrix={}
        for tf in self._TF :
            if tf=='Id':
                self._TF_matrix[tf]=torch.tensor([[[ 1., 0., 0.],
                                                   [ 0., 1., 0.],
                                                   [ 0., 0., 1.]]], device=device)
            elif tf=='Rot':
                center = torch.ones(1, 2, device=device)
                center[0, 0] = w / 2  # x
                center[0, 1] = h / 2  # y
                scale = torch.ones(1, device=device)
                angle = self._mag_fct[tf](magnitude) * torch.ones(1, device=device)
                R = kornia.get_rotation_matrix2d(center, angle, scale) #Rotation matrix (1,2,3)
                self._TF_matrix[tf]=torch.cat((R,torch.tensor([[[ 0., 0., 1.]]], device=device)), dim=1) #TF matrix (1,3,3)
            elif tf=='FlipLR':
                self._TF_matrix[tf]=torch.tensor([[[-1., 0., w-1],
                                                   [ 0., 1., 0.],
                                                   [ 0., 0., 1.]]], device=device)
            elif tf=='FlipUD':
                self._TF_matrix[tf]=torch.tensor([[[ 1., 0., 0.],
                                                   [ 0., -1., h-1],
                                                   [ 0., 0., 1.]]], device=device)
            else:
                raise Exception("Invalid TF requested")
    '''
    def adjust_prob(self, soft=False): #Detach from gradient ?

        if soft :
            self._params['prob'].data = F.softmax(self._params['prob'].data, dim=0) #Too 'soft': gets stuck at the uniform distribution if the lr is too small
        else:
            #self._params['prob'].clamp(min=0.0,max=1.0)
            self._params['prob'].data = F.relu(self._params['prob'].data)
            #self._params['prob'].data = self._params['prob'].clamp(min=0.0,max=1.0)

        self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Constraint: sum(p)=1

    def loss_weight(self):
        w_loss = torch.zeros((self._sample.shape[0], self._nb_tf), device=self._sample.device)
        w_loss.scatter_(1, self._sample.view(-1, 1), 1)
        w_loss = w_loss * self._params["prob"]/self._distrib #Weight by the probabilities (divided by the sampling distribution so the loss is not shrunk)
        w_loss = torch.sum(w_loss, dim=1)
        return w_loss

    def train(self, mode=None):
        if mode is None :
            mode = self._data_augmentation
        self.augment(mode=mode) #No effect when mode was None
        super(Data_augV4, self).train(mode)

    def eval(self):
        self.train(mode=False)

    def augment(self, mode=True):
        self._data_augmentation = mode

    def __getitem__(self, key):
        return self._params[key]

    def __str__(self):
        if not self._mix_dist:
            return "Data_augV4(Uniform-%d TF)" % self._nb_tf
        else:
            return "Data_augV4(Mix %.1f-%d TF)" % (self._mix_factor, self._nb_tf)

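# Note on loss_weight(): with mixed sampling the per-sample weight is
# prob[tf] / distrib[tf], i.e. an importance-weighting correction between the
# learned distribution and the one actually sampled from, so the expected loss
# stays on the same scale whatever mix_factor is used.
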
class Augmented_model(nn.Module):
    def __init__(self, data_augmenter, model):
        super(Augmented_model, self).__init__()

        self._mods = nn.ModuleDict({
            'data_aug': data_augmenter,
            'model': model
        })

        self.augment(mode=True)

    def initialize(self):
        self._mods['model'].initialize()

    def forward(self, x):
        return self._mods['model'](self._mods['data_aug'](x))

    def augment(self, mode=True):
        self._data_augmentation = mode
        self._mods['data_aug'].augment(mode)

    def train(self, mode=None):
        if mode is None :
            mode = self._data_augmentation
        self._mods['data_aug'].augment(mode)
        super(Augmented_model, self).train(mode)

    def eval(self):
        self.train(mode=False)
        #super(Augmented_model, self).eval()

    def items(self):
        """Return an iterable of the ModuleDict key/value pairs."""
        return self._mods.items()

    def update(self, modules):
        self._mods.update(modules)

    def is_augmenting(self):
        return self._data_augmentation

    def TF_names(self):
        try:
            return self._mods['data_aug']._TF
        except:
            return None

    def __getitem__(self, key):
        return self._mods[key]

    def __str__(self):
        return "Aug_mod("+str(self._mods['data_aug'])+"-"+str(self._mods['model'])+")"

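if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch only, not part of the training scripts).
    # Assumes a CIFAR-like batch of shape (N, 3, 32, 32) with values in [0, 1] and
    # that transformations.TF_dict and its TF functions are importable.
    aug = Data_augV4(TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0)
    x = torch.rand(4, 3, 32, 32)
    print(aug, aug(x).shape, aug['prob'].data)
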
higher/model.py (new file, 51 lines)
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):
    def __init__(self, num_inp, num_out):
        super(LeNet, self).__init__()
        self._params = nn.ParameterDict({
            'w1': nn.Parameter(torch.zeros(20, num_inp, 5, 5)),
            'b1': nn.Parameter(torch.zeros(20)),
            'w2': nn.Parameter(torch.zeros(50, 20, 5, 5)),
            'b2': nn.Parameter(torch.zeros(50)),
            #'w3': nn.Parameter(torch.zeros(500,4*4*50)), #num_inp=1
            'w3': nn.Parameter(torch.zeros(500, 5*5*50)), #num_inp=3
            'b3': nn.Parameter(torch.zeros(500)),
            'w4': nn.Parameter(torch.zeros(num_out, 500)),
            'b4': nn.Parameter(torch.zeros(num_out))
        })
        self.initialize()


    def initialize(self):
        nn.init.kaiming_uniform_(self._params["w1"], a=math.sqrt(5))
        nn.init.kaiming_uniform_(self._params["w2"], a=math.sqrt(5))
        nn.init.kaiming_uniform_(self._params["w3"], a=math.sqrt(5))
        nn.init.kaiming_uniform_(self._params["w4"], a=math.sqrt(5))

    def forward(self, x):
        #print("Start Shape ", x.shape)
        out = F.relu(F.conv2d(input=x, weight=self._params["w1"], bias=self._params["b1"]))
        #print("Shape ", out.shape)
        out = F.max_pool2d(out, 2)
        #print("Shape ", out.shape)
        out = F.relu(F.conv2d(input=out, weight=self._params["w2"], bias=self._params["b2"]))
        #print("Shape ", out.shape)
        out = F.max_pool2d(out, 2)
        #print("Shape ", out.shape)
        out = out.view(out.size(0), -1)
        #print("Shape ", out.shape)
        out = F.relu(F.linear(out, self._params["w3"], self._params["b3"]))
        #print("Shape ", out.shape)
        out = F.linear(out, self._params["w4"], self._params["b4"])
        #print("Shape ", out.shape)
        return F.log_softmax(out, dim=1)

    def __getitem__(self, key):
        return self._params[key]

    def __str__(self):
        return "LeNet"

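if __name__ == "__main__":
    # Quick shape check (a sketch): with num_inp=3 and 32x32 inputs the flattened
    # feature map is 50*5*5=1250, matching 'w3'; output is 10 log-probabilities.
    net = LeNet(3, 10)
    print(net(torch.rand(2, 3, 32, 32)).shape)  # expected: torch.Size([2, 10])
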
Binary image files added under higher/res/ (PNG result plots), including:
higher/res/LeNet-100 epochs.png
higher/res/MNIST/LeNet-10 epochs.png
higher/test_dataug.py (new file, 764 lines)
from torch.utils.data import SubsetRandomSampler
import torch.optim as optim
import torchvision
import higher

from model import *
from dataug import *
from utils import *

BATCH_SIZE = 300
#TEST_SIZE = 300
TEST_SIZE = 10000

#WARNING: Dataug (Kornia) expects images in the range [0, 1]
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    #torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), #CIFAR10
])
'''
data_train = torchvision.datasets.MNIST(
    "./data", train=True, download=True,
    transform=torchvision.transforms.Compose([
        #torchvision.transforms.RandomAffine(degrees=180, translate=None, scale=None, shear=None, resample=False, fillcolor=0),
        torchvision.transforms.ToTensor()
    ])
)
data_test = torchvision.datasets.MNIST(
    "./data", train=False, download=True, transform=torchvision.transforms.ToTensor()
)
'''
data_train = torchvision.datasets.CIFAR10(
    "./data", train=True, download=True, transform=transform
)
data_test = torchvision.datasets.CIFAR10(
    "./data", train=False, download=True, transform=transform
)
#'''
train_subset_indices = range(int(len(data_train)/2))
#train_subset_indices=range(BATCH_SIZE*10)
val_subset_indices = range(int(len(data_train)/2), len(data_train))

dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices))
dl_val = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(val_subset_indices))
dl_test = torch.utils.data.DataLoader(data_test, batch_size=TEST_SIZE, shuffle=False)

device = torch.device('cuda')

if device == torch.device('cpu'):
    device_name = 'CPU'
else:
    device_name = torch.cuda.get_device_name(device)

def test(model):
    model.eval()
    for i, (features, labels) in enumerate(dl_test):
        features, labels = features.to(device), labels.to(device)

        pred = model.forward(features)
        return pred.argmax(dim=1).eq(labels).sum().item() / TEST_SIZE * 100

def compute_vaLoss(model, dl_val_it):
    try:
        xs_val, ys_val = next(dl_val_it)
    except StopIteration: #End of the validation epoch
        dl_val_it = iter(dl_val)
        xs_val, ys_val = next(dl_val_it)
    xs_val, ys_val = xs_val.to(device), ys_val.to(device)

    try:
        model.augment(mode=False) #Validation without transformations!
    except:
        pass
    return F.cross_entropy(model(xs_val), ys_val)

def train_classic(model, epochs=1):
    #opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    optim = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)

    model.train()
    dl_val_it = iter(dl_val)
    log = []
    for epoch in range(epochs):
        print_torch_mem("Start epoch")
        t0 = time.process_time()
        for i, (features, labels) in enumerate(dl_train):
            #print_torch_mem("Start iter")
            features, labels = features.to(device), labels.to(device)

            optim.zero_grad()
            pred = model.forward(features)
            loss = F.cross_entropy(pred, labels)
            loss.backward()
            optim.step()

        #### Tests ####
        tf = time.process_time()
        try:
            xs_val, ys_val = next(dl_val_it)
        except StopIteration: #End of the validation epoch
            dl_val_it = iter(dl_val)
            xs_val, ys_val = next(dl_val_it)
        xs_val, ys_val = xs_val.to(device), ys_val.to(device)

        val_loss = F.cross_entropy(model(xs_val), ys_val)
        accuracy = test(model)
        model.train()
        #### Log ####
        data = {
            "epoch": epoch,
            "train_loss": loss.item(),
            "val_loss": val_loss.item(),
            "acc": accuracy,
            "time": tf - t0,

            "param": None,
        }
        log.append(data)

    return log

def train_classic_higher(model, epochs=1):
    #opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    optim = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)

    model.train()
    dl_val_it = iter(dl_val)
    log = []

    fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
    diffopt = higher.optim.get_diff_optim(optim, model.parameters(), fmodel=fmodel, track_higher_grads=False)
    #with higher.innerloop_ctx(model, optim, copy_initial_weights=True, track_higher_grads=False) as (fmodel, diffopt):

    for epoch in range(epochs):
        print_torch_mem("Start epoch "+str(epoch))
        print("Fast param ", len(fmodel._fast_params))
        t0 = time.process_time()
        for i, (features, labels) in enumerate(dl_train):
            #print_torch_mem("Start iter")
            features, labels = features.to(device), labels.to(device)

            #optim.zero_grad()
            pred = fmodel.forward(features)
            loss = F.cross_entropy(pred, labels)
            #loss.backward()
            #optim.step()
            diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)

            model_copy(src=fmodel, dst=model, patch_copy=False)
            optim_copy(dopt=diffopt, opt=optim)
            fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
            diffopt = higher.optim.get_diff_optim(optim, model.parameters(), fmodel=fmodel, track_higher_grads=False)

        #### Tests ####
        tf = time.process_time()
        try:
            xs_val, ys_val = next(dl_val_it)
        except StopIteration: #End of the validation epoch
            dl_val_it = iter(dl_val)
            xs_val, ys_val = next(dl_val_it)
        xs_val, ys_val = xs_val.to(device), ys_val.to(device)

        val_loss = F.cross_entropy(model(xs_val), ys_val)
        accuracy = test(model)
        model.train()
        #### Log ####
        data = {
            "epoch": epoch,
            "train_loss": loss.item(),
            "val_loss": val_loss.item(),
            "acc": accuracy,
            "time": tf - t0,

            "param": None,
        }
        log.append(data)

    return log

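# Note on the pattern above: higher keeps every intermediate ("fast") parameter
# version alive so that training itself can be backpropagated through. Copying
# the weights and optimizer state back into the regular model and re-monkeypatching
# after each step is a way of resetting that history so memory stays bounded when
# no meta-gradient is needed (track_higher_grads=False).
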
def train_classic_tests(model, epochs=1):
    #opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    optim = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)

    countcopy = 0
    model.train()
    dl_val_it = iter(dl_val)
    log = []

    fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
    doptim = higher.optim.get_diff_optim(optim, model.parameters(), fmodel=fmodel, track_higher_grads=False)
    for epoch in range(epochs):
        print_torch_mem("Start epoch")
        print(len(fmodel._fast_params))
        t0 = time.process_time()
        #with higher.innerloop_ctx(model, optim, copy_initial_weights=True, track_higher_grads=True) as (fmodel, doptim):

        #fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
        #doptim = higher.optim.get_diff_optim(optim, model.parameters(), track_higher_grads=True)

        for i, (features, labels) in enumerate(dl_train):
            features, labels = features.to(device), labels.to(device)

            #with higher.innerloop_ctx(model, optim, copy_initial_weights=True, track_higher_grads=False) as (fmodel, doptim):


            #optim.zero_grad()
            pred = fmodel.forward(features)
            loss = F.cross_entropy(pred, labels)
            doptim.step(loss) #(opt.zero_grad, loss.backward, opt.step)
            #loss.backward()
            #new_params = doptim.step(loss, params=fmodel.parameters())
            #fmodel.update_params(new_params)


            #print('Fast param',len(fmodel._fast_params))
            #print('opt state', type(doptim.state[0][0]['momentum_buffer']), doptim.state[0][2]['momentum_buffer'].shape)

            if False or (len(fmodel._fast_params) > 1):
                print("fmodel fast param", len(fmodel._fast_params))
                '''
                #val_loss = F.cross_entropy(fmodel(features), labels)

                #print_graph(val_loss)

                #val_loss.backward()
                #print('bip')

                tmp = fmodel.parameters()

                #print(list(tmp)[1])
                tmp = [higher.utils._copy_tensor(t,safe_copy=True) if isinstance(t, torch.Tensor) else t for t in tmp]
                #print(len(tmp))

                #fmodel._fast_params.clear()
                del fmodel._fast_params
                fmodel._fast_params = None

                fmodel.fast_params = tmp # Overloads memory
                #fmodel.update_params(tmp) #Better perf / overloads memory with track_higher_grads

                #optim._fmodel=fmodel
                '''


                countcopy += 1
                model_copy(src=fmodel, dst=model, patch_copy=False)
                fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
                #doptim.detach_dyn()
                #tmp = doptim.state
                #tmp = doptim.state_dict()
                #for k, v in tmp['state'].items():
                #    print('dict',k, type(v))

                a = optim.param_groups[0]['params'][0]
                state = optim.state[a]
                #state['momentum_buffer'] = None
                #print('opt state', type(optim.state[a]), len(optim.state[a]))
                #optim.load_state_dict(tmp)


                for group_idx, group in enumerate(optim.param_groups):
                    # print('gp idx',group_idx)
                    for p_idx, p in enumerate(group['params']):
                        optim.state[p] = doptim.state[group_idx][p_idx]

                #print('opt state', type(optim.state[a]['momentum_buffer']), optim.state[a]['momentum_buffer'][0:10])
                #print('dopt state', type(doptim.state[0][0]['momentum_buffer']), doptim.state[0][0]['momentum_buffer'][0:10])
                '''
                for a in tmp:
                    #print(type(a), len(a))
                    for nb, b in a.items():
                        #print(nb, type(b), len(b))
                        for n, state in b.items():
                            #print(n, type(states))
                            #print(state.grad_fn)
                            state = torch.tensor(state.data).requires_grad_()
                            #print(state.grad_fn)
                '''


                doptim = higher.optim.get_diff_optim(optim, model.parameters(), track_higher_grads=True)
                #doptim.state = tmp


        countcopy += 1
        model_copy(src=fmodel, dst=model)
        optim_copy(dopt=doptim, opt=optim)

        #### Tests ####
        tf = time.process_time()
        try:
            xs_val, ys_val = next(dl_val_it)
        except StopIteration: #End of the validation epoch
            dl_val_it = iter(dl_val)
            xs_val, ys_val = next(dl_val_it)
        xs_val, ys_val = xs_val.to(device), ys_val.to(device)

        val_loss = F.cross_entropy(model(xs_val), ys_val)
        accuracy = test(model)
        model.train()
        #### Log ####
        data = {
            "epoch": epoch,
            "train_loss": loss.item(),
            "val_loss": val_loss.item(),
            "acc": accuracy,
            "time": tf - t0,

            "param": None,
        }
        log.append(data)

        #countcopy+=1
        #model_copy(src=fmodel, dst=model, patch_copy=False)
        #optim.load_state_dict(doptim.state_dict()) #Is it necessary to save the optimizer state?

    print("Copy ", countcopy)
    return log

def run_simple_dataug(inner_it, epochs=1):

    dl_train_it = iter(dl_train)
    dl_val_it = iter(dl_val)

    #aug_model = nn.Sequential(
    #    Data_aug(),
    #    LeNet(1,10),
    #    )
    aug_model = Augmented_model(Data_aug(), LeNet(1, 10)).to(device)
    print(str(aug_model))
    meta_opt = torch.optim.Adam(aug_model['data_aug'].parameters(), lr=1e-2)
    inner_opt = torch.optim.SGD(aug_model['model'].parameters(), lr=1e-2, momentum=0.9)

    log = []
    t0 = time.process_time()

    epoch = 0
    while epoch < epochs:
        meta_opt.zero_grad()
        aug_model.train()
        with higher.innerloop_ctx(aug_model, inner_opt, copy_initial_weights=True, track_higher_grads=True) as (fmodel, diffopt): #the effect of copy_initial_weights is not clear...

            for i in range(inner_it):
                try:
                    xs, ys = next(dl_train_it)
                except StopIteration: #End of the training epoch
                    tf = time.process_time()
                    epoch += 1
                    dl_train_it = iter(dl_train)
                    xs, ys = next(dl_train_it)

                    accuracy = test(aug_model)
                    aug_model.train()

                    #### Print ####
                    print('-'*9)
                    print('Epoch %d/%d'%(epoch, epochs))
                    print('train loss', loss.item(), '/ val loss', val_loss.item())
                    print('acc', accuracy)
                    print('mag', aug_model['data_aug']['mag'].item())

                    #### Log ####
                    data = {
                        "epoch": epoch,
                        "train_loss": loss.item(),
                        "val_loss": val_loss.item(),
                        "acc": accuracy,
                        "time": tf - t0,

                        "param": aug_model['data_aug']['mag'].item(),
                    }
                    log.append(data)
                    t0 = time.process_time()

                xs, ys = xs.to(device), ys.to(device)

                logits = fmodel(xs)  # modified `params` can also be passed as a kwarg

                loss = F.cross_entropy(logits, ys)  # no need to call loss.backwards()
                #loss.backward(retain_graph=True)
                #print(fmodel['model']._params['b4'].grad)
                #print('mag', fmodel['data_aug']['mag'].grad)

                diffopt.step(loss)  # note that `step` must take `loss` as an argument!
                # The line above gets P[t+1] from P[t] and loss[t]. `step` also returns
                # these new parameters, as an alternative to getting them from
                # `fmodel.fast_params` or `fmodel.parameters()` after calling
                # `diffopt.step`.

                # At this point, or at any point in the iteration, you can take the
                # gradient of `fmodel.parameters()` (or equivalently
                # `fmodel.fast_params`) w.r.t. `fmodel.parameters(time=0)` (equivalently
                # `fmodel.init_fast_params`). i.e. `fast_params` will always have
                # `grad_fn` as an attribute, and be part of the gradient tape.

            # At the end of your inner loop you can obtain these e.g. ...
            #grad_of_grads = torch.autograd.grad(
            #    meta_loss_fn(fmodel.parameters()), fmodel.parameters(time=0))
            try:
                xs_val, ys_val = next(dl_val_it)
            except StopIteration: #End of the validation epoch
                dl_val_it = iter(dl_val)
                xs_val, ys_val = next(dl_val_it)
            xs_val, ys_val = xs_val.to(device), ys_val.to(device)

            fmodel.augment(mode=False)
            val_logits = fmodel(xs_val) #Validation without transformations!
            val_loss = F.cross_entropy(val_logits, ys_val)
            #print('val_loss',val_loss.item())
            val_loss.backward()

            #print('mag', fmodel['data_aug']['mag'], '/', fmodel['data_aug']['mag'].grad)

            #model=copy.deepcopy(fmodel)
            aug_model.load_state_dict(fmodel.state_dict()) #Do not copy gradient !
            #Copy the gradients
            for paramName, paramValue, in fmodel.named_parameters():
                for netCopyName, netCopyValue, in aug_model.named_parameters():
                    if paramName == netCopyName:
                        netCopyValue.grad = paramValue.grad

            #print('mag', aug_model['data_aug']['mag'], '/', aug_model['data_aug']['mag'].grad)
            meta_opt.step()

    plot_res(log, fig_name="res/{}-{} epochs- {} in_it".format(str(aug_model), epochs, inner_it))
    print('-'*9)
    times = [x["time"] for x in log]
    print(str(aug_model), ": acc", max([x["acc"] for x in log]), "in (ms):", np.mean(times), "+/-", np.std(times))

def run_dist_dataug(model, epochs=1, inner_it=1, dataug_epoch_start=0):

    dl_train_it = iter(dl_train)
    dl_val_it = iter(dl_val)

    meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=1e-3)
    inner_opt = torch.optim.SGD(model['model'].parameters(), lr=1e-2, momentum=0.9)

    high_grad_track = True
    if dataug_epoch_start > 0:
        model.augment(mode=False)
        high_grad_track = False

    model.train()

    log = []
    t0 = time.process_time()

    countcopy = 0
    val_loss = torch.tensor(0)
    opt_param = None

    epoch = 0
    while epoch < epochs:
        meta_opt.zero_grad()
        with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=True, override=opt_param, track_higher_grads=high_grad_track) as (fmodel, diffopt): #the effect of copy_initial_weights is not clear...

            for i in range(inner_it):
                try:
                    xs, ys = next(dl_train_it)
                except StopIteration: #End of the training epoch
                    tf = time.process_time()
                    epoch += 1
                    dl_train_it = iter(dl_train)
                    xs, ys = next(dl_train_it)

                    #viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
                    #viz_sample_data(imgs=aug_model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))

                    accuracy = test(model)
                    model.train()

                    #### Print ####
                    print('-'*9)
                    print('Epoch : %d/%d'%(epoch, epochs))
                    print('Train loss :', loss.item(), '/ val loss', val_loss.item())
                    print('Accuracy :', accuracy)
                    print('Data Augmentation : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start))
                    print('TF Proba :', model['data_aug']['prob'].data)
                    #print('proba grad',aug_model['data_aug']['prob'].grad)
                    #############
                    #### Log ####
                    data = {
                        "epoch": epoch,
                        "train_loss": loss.item(),
                        "val_loss": val_loss.item(),
                        "acc": accuracy,
                        "time": tf - t0,

                        "param": [p for p in model['data_aug']['prob']],
                    }
                    log.append(data)
                    #############

                    if epoch == dataug_epoch_start:
                        print('Starting Data Augmentation...')
                        model.augment(mode=True)
                        high_grad_track = True

                    t0 = time.process_time()

                xs, ys = xs.to(device), ys.to(device)

                '''
                #Exact method
                final_loss = 0
                for tf_idx in range(fmodel['data_aug']._nb_tf):
                    fmodel['data_aug'].transf_idx = tf_idx
                    logits = fmodel(xs)
                    loss = F.cross_entropy(logits, ys)
                    #loss.backward(retain_graph=True)
                    #print('idx', tf_idx)
                    #print(fmodel['data_aug']['prob'][tf_idx], fmodel['data_aug']['prob'][tf_idx].grad)
                    final_loss += loss*fmodel['data_aug']['prob'][tf_idx] #Take it in the forward function ?

                loss = final_loss
                '''
                #Uniform method
                logits = fmodel(xs)  # modified `params` can also be passed as a kwarg
                loss = F.cross_entropy(logits, ys, reduction='none')  # no need to call loss.backwards()
                if fmodel._data_augmentation: #Weight the loss
                    w_loss = fmodel['data_aug'].loss_weight().to(device)
                    loss = loss * w_loss
                loss = loss.mean()
                #'''

                #to visualize computational graph
                #print_graph(loss)

                #loss.backward(retain_graph=True)
                #print(fmodel['model']._params['b4'].grad)
                #print('prob grad', fmodel['data_aug']['prob'].grad)

                diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)

            try:
                xs_val, ys_val = next(dl_val_it)
            except StopIteration: #End of the validation epoch
                dl_val_it = iter(dl_val)
                xs_val, ys_val = next(dl_val_it)
            xs_val, ys_val = xs_val.to(device), ys_val.to(device)

            fmodel.augment(mode=False) #Validation without transformations!
            val_loss = F.cross_entropy(fmodel(xs_val), ys_val)

            #print_graph(val_loss)

            val_loss.backward()

            countcopy += 1
            model_copy(src=fmodel, dst=model)
            optim_copy(dopt=diffopt, opt=inner_opt)

            meta_opt.step()
            model['data_aug'].adjust_prob() #Constraint: sum(proba)=1

    print("Copy ", countcopy)
    return log

def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_freq=1, loss_patience=None):

    log = []
    countcopy = 0
    val_loss = torch.tensor(0) #Needed when no meta step happens during an epoch
    dl_val_it = iter(dl_val)

    meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=1e-2)
    inner_opt = torch.optim.SGD(model['model'].parameters(), lr=1e-2, momentum=0.9)

    high_grad_track = True
    if inner_it == 0:
        high_grad_track = False
    if dataug_epoch_start != 0:
        model.augment(mode=False)
        high_grad_track = False

    val_loss_monitor = None
    if loss_patience is not None :
        if dataug_epoch_start == -1: val_loss_monitor = loss_monitor(patience=loss_patience, end_train=2) #1st limit = dataug start
        else: val_loss_monitor = loss_monitor(patience=loss_patience) #Val loss monitor

    model.train()

    fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
    diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(), fmodel=fmodel, track_higher_grads=high_grad_track)

    for epoch in range(1, epochs+1):
        #print_torch_mem("Start epoch "+str(epoch))
        #print(high_grad_track, fmodel._data_augmentation, len(fmodel._fast_params))
        t0 = time.process_time()
        #with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=True, override=opt_param, track_higher_grads=high_grad_track) as (fmodel, diffopt):

        for i, (xs, ys) in enumerate(dl_train):
            xs, ys = xs.to(device), ys.to(device)
            '''
            #Exact method
            final_loss = 0
            for tf_idx in range(fmodel['data_aug']._nb_tf):
                fmodel['data_aug'].transf_idx = tf_idx
                logits = fmodel(xs)
                loss = F.cross_entropy(logits, ys)
                #loss.backward(retain_graph=True)
                #print('idx', tf_idx)
                #print(fmodel['data_aug']['prob'][tf_idx], fmodel['data_aug']['prob'][tf_idx].grad)
                final_loss += loss*fmodel['data_aug']['prob'][tf_idx] #Take it in the forward function ?

            loss = final_loss
            '''
            #Uniform method

            logits = fmodel(xs)  # modified `params` can also be passed as a kwarg
            loss = F.cross_entropy(logits, ys, reduction='none')  # no need to call loss.backwards()
            #DO NOT WEIGHT THE LOSS FOR THE MIX DISTRIBUTION
            if fmodel._data_augmentation:  # and not fmodel['data_aug']._mix_dist: #Weight the loss
                w_loss = fmodel['data_aug'].loss_weight().to(device)
                loss = loss * w_loss
            loss = loss.mean()
            #'''

            #to visualize computational graph
            #print_graph(loss)

            #loss.backward(retain_graph=True)
            #print(fmodel['model']._params['b4'].grad)
            #print('prob grad', fmodel['data_aug']['prob'].grad)

            diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)

            if(high_grad_track and i%inner_it==0): #Perform a meta step
                #print("meta")
                #Of little use if high_grad_track = False
                val_loss = compute_vaLoss(model=fmodel, dl_val_it=dl_val_it)

                #print_graph(val_loss)

                val_loss.backward()

                countcopy += 1
                model_copy(src=fmodel, dst=model)
                optim_copy(dopt=diffopt, opt=inner_opt)

                meta_opt.step()
                model['data_aug'].adjust_prob(soft=False) #Constraint: sum(proba)=1

                fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
                diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(), fmodel=fmodel, track_higher_grads=high_grad_track)

        tf = time.process_time()

        #viz_sample_data(imgs=xs, labels=ys, fig_name='samples/data_sample_epoch{}_noTF'.format(epoch))
        #viz_sample_data(imgs=aug_model['data_aug'](xs), labels=ys, fig_name='samples/data_sample_epoch{}'.format(epoch))

        if(not high_grad_track):
            countcopy += 1
            model_copy(src=fmodel, dst=model)
            optim_copy(dopt=diffopt, opt=inner_opt)
            val_loss = compute_vaLoss(model=fmodel, dl_val_it=dl_val_it)

            #Needed to reset higher (fast_params accumulate even with track_higher_grads = False)
            fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
            diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(), fmodel=fmodel, track_higher_grads=high_grad_track)

        accuracy = test(model)
        model.train()

        #### Print ####
        if(print_freq and epoch%print_freq==0):
            print('-'*9)
            print('Epoch : %d/%d'%(epoch, epochs))
            print('Time : %.00f ms'%(tf - t0))
            print('Train loss :', loss.item(), '/ val loss', val_loss.item())
            print('Accuracy :', accuracy)
            print('Data Augmentation : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start))
            print('TF Proba :', model['data_aug']['prob'].data)
            #print('proba grad',aug_model['data_aug']['prob'].grad)
        #############
        #### Log ####
        data = {
            "epoch": epoch,
            "train_loss": loss.item(),
            "val_loss": val_loss.item(),
            "acc": accuracy,
            "time": tf - t0,

            "param": [p.item() for p in model['data_aug']['prob']],
        }
        log.append(data)
        #############
        if val_loss_monitor :
            val_loss_monitor.register(val_loss.item())
            if val_loss_monitor.end_training(): break #Stop training

        if not model.is_augmenting() and (epoch == dataug_epoch_start or (val_loss_monitor and val_loss_monitor.limit_reached()==1)):
            print('Starting Data Augmentation...')
            dataug_epoch_start = epoch
            model.augment(mode=True)
            if inner_it != 0: high_grad_track = True

    print("Copy ", countcopy)
    return log

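# Summary of the bi-level scheme above (informal sketch):
#   inner step :  theta_{t+1} = theta_t - lr * d L_train(theta_t, prob) / d theta
#   meta step  :  prob <- prob - meta_lr * d L_val(theta_{t+1}) / d prob
# The dependence of theta_{t+1} on `prob` is kept by higher's differentiable
# optimizer, and adjust_prob() re-projects `prob` onto the probability simplex.
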
##########################################
if __name__ == "__main__":

    n_inner_iter = 0
    epochs = 2
    dataug_epoch_start = 0

    #### Classic ####
    '''
    model = LeNet(3,10).to(device)
    #model = torchvision.models.resnet18()
    #model = Augmented_model(Data_augV3(mix_dist=0.0), LeNet(3,10)).to(device)
    #model.augment(mode=False)

    print(str(model), 'on', device_name)
    log= train_classic_higher(model=model, epochs=epochs)

    ####
    plot_res(log, fig_name="res/{}-{} epochs".format(str(model),epochs))
    print('-'*9)
    times = [x["time"] for x in log]
    out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times)), "Device": device_name, "Log": log}
    print(str(model),": acc", out["Accuracy"], "in (ms):", out["Time"][0], "+/-", out["Time"][1])
    with open("res/log/%s.json" % "{}-{} epochs".format(str(model),epochs), "w+") as f:
        json.dump(out, f, indent=True)
        print('Log :\"',f.name, '\" saved !')
    print('-'*9)
    '''
    #### Augmented Model ####
    #'''
    aug_model = Augmented_model(Data_augV4(TF_dict=TF.TF_dict, mix_dist=0.0), LeNet(3, 10)).to(device)
    print(str(aug_model), 'on', device_name)
    #run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
    log = run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, loss_patience=10)

    ####
    plot_res(log, fig_name="res/{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model), epochs, dataug_epoch_start, n_inner_iter))
    print('-'*9)
    times = [x["time"] for x in log]
    out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times), np.std(times)), "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
    print(str(aug_model), ": acc", out["Accuracy"], "in (ms):", out["Time"][0], "+/-", out["Time"][1])
    with open("res/log/%s.json" % "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model), epochs, dataug_epoch_start, n_inner_iter), "w+") as f:
        json.dump(out, f, indent=True)
        print('Log :\"', f.name, '\" saved !')
    print('-'*9)
    #'''

    #### Comparison ####
    '''
    files=[
        #"res/log/LeNet-100 epochs.json",
        #"res/log/Aug_mod(Data_augV4(Uniform-4 TF)-LeNet)-100 epochs (dataug:0)- 0 in_it.json",
        #"res/log/Aug_mod(Data_augV4(Uniform-4 TF)-LeNet)-100 epochs (dataug:50)- 0 in_it.json",
        #"res/log/Aug_mod(Data_augV4(Uniform-3 TF)-LeNet)-100 epochs (dataug:0)- 0 in_it.json",
        #"res/log/Aug_mod(Data_augV3(Uniform-3 TF)-LeNet)-100 epochs (dataug:50)- 10 in_it.json",
        #"res/log/Aug_mod(Data_augV4(Mix 0,5-3 TF)-LeNet)-100 epochs (dataug:0)- 1 in_it.json",
        #"res/log/Aug_mod(Data_augV4(Mix 0.5-3 TF)-LeNet)-100 epochs (dataug:50)- 10 in_it.json",
        #"res/log/Aug_mod(Data_augV4(Uniform-3 TF)-LeNet)-100 epochs (dataug:0)- 10 in_it.json",
        "res/log/Aug_mod(Data_augV4(Uniform-10 TF)-LeNet)-100 epochs (dataug:50)- 10 in_it.json",
        "res/log/Aug_mod(Data_augV4(Uniform-10 TF)-LeNet)-100 epochs (dataug:50)- 0 in_it.json",
    ]
    plot_compare(filenames=files, fig_name="res/compare")
    '''

higher/test_lr.py (new file, 150 lines)
import numpy as np
import json, math, time, os

import torch
import torch.nn.functional as F
import torchvision
from torch.utils.data import SubsetRandomSampler
import torch.optim as optim
import higher
from model import *

import copy

BATCH_SIZE = 300
TEST_SIZE = 300

mnist_train = torchvision.datasets.MNIST(
    "./data", train=True, download=True,
    transform=torchvision.transforms.Compose([
        #torchvision.transforms.RandomAffine(degrees=180, translate=None, scale=None, shear=None, resample=False, fillcolor=0),
        torchvision.transforms.ToTensor()
    ])
)

mnist_test = torchvision.datasets.MNIST(
    "./data", train=False, download=True, transform=torchvision.transforms.ToTensor()
)

#train_subset_indices=range(int(len(mnist_train)/2))
train_subset_indices = range(BATCH_SIZE)
val_subset_indices = range(int(len(mnist_train)/2), len(mnist_train))

dl_train = torch.utils.data.DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices))
dl_val = torch.utils.data.DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(val_subset_indices))
dl_test = torch.utils.data.DataLoader(mnist_test, batch_size=TEST_SIZE, shuffle=False)


def test(model):
    model.eval()
    for i, (features, labels) in enumerate(dl_test):
        pred = model.forward(features)
        return pred.argmax(dim=1).eq(labels).sum().item() / TEST_SIZE * 100 #Accuracy (%) on the first test batch only


def train_classic(model, optim, epochs=1):
    model.train()
    log = []
    for epoch in range(epochs):
        t0 = time.process_time()
        for i, (features, labels) in enumerate(dl_train):

            optim.zero_grad()
            pred = model.forward(features)
            loss = F.cross_entropy(pred, labels)
            loss.backward()
            optim.step()

        #### Log ####
        tf = time.process_time()
        data = {
            "time": tf - t0,
        }
        log.append(data)

    times = [x["time"] for x in log]
    print("Vanilla : acc", test(model), "in (s):", np.mean(times), "+/-", np.std(times))

##########################################
if __name__ == "__main__":

    device = torch.device('cpu')

    model = LeNet(1, 10)
    opt_param = {
        "lr": torch.tensor(1e-2).requires_grad_(),
        "momentum": torch.tensor(0.9).requires_grad_()
    }
    n_inner_iter = 1
    dl_train_it = iter(dl_train)
    dl_val_it = iter(dl_val)
    epoch = 0
    epochs = 10

    ####
    train_classic(model=model, optim=torch.optim.Adam(model.parameters(), lr=0.001), epochs=epochs)
    model = LeNet(1, 10)

    meta_opt = torch.optim.Adam(opt_param.values(), lr=1e-2)
    inner_opt = torch.optim.SGD(model.parameters(), lr=opt_param['lr'], momentum=opt_param['momentum'])
    #for xs_val, ys_val in dl_val:
    while epoch < epochs:
        #print(data_aug.params["mag"], data_aug.params["mag"].grad)
        meta_opt.zero_grad()
        model.train()
        with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=True, track_higher_grads=True) as (fmodel, diffopt): #effect of copy_initial_weights not entirely clear...

            for param_group in diffopt.param_groups:
                param_group['lr'] = opt_param['lr']
                param_group['momentum'] = opt_param['momentum']

            for i in range(n_inner_iter):
                try:
                    xs, ys = next(dl_train_it)
                except StopIteration: #End of the train epoch
                    epoch += 1
                    dl_train_it = iter(dl_train)
                    xs, ys = next(dl_train_it)

                    print('Epoch', epoch)
                    print('train loss', loss.item(), '/ val loss', val_loss.item())
                    print('acc', test(model))
                    print('opt : lr', opt_param['lr'].item(), 'momentum', opt_param['momentum'].item())
                    print('-'*9)
                    model.train()


                logits = fmodel(xs)  # modified `params` can also be passed as a kwarg
                loss = F.cross_entropy(logits, ys)  # no need to call loss.backwards()
                #print('loss',loss.item())
                diffopt.step(loss)  # note that `step` must take `loss` as an argument!
                # The line above gets P[t+1] from P[t] and loss[t]. `step` also returns
                # these new parameters, as an alternative to getting them from
                # `fmodel.fast_params` or `fmodel.parameters()` after calling
                # `diffopt.step`.

                # At this point, or at any point in the iteration, you can take the
                # gradient of `fmodel.parameters()` (or equivalently
                # `fmodel.fast_params`) w.r.t. `fmodel.parameters(time=0)` (equivalently
                # `fmodel.init_fast_params`). i.e. `fast_params` will always have
                # `grad_fn` as an attribute, and be part of the gradient tape.

            # At the end of your inner loop you can obtain these e.g. ...
            #grad_of_grads = torch.autograd.grad(
            #    meta_loss_fn(fmodel.parameters()), fmodel.parameters(time=0))
            try:
                xs_val, ys_val = next(dl_val_it)
            except StopIteration: #End of the val epoch
                dl_val_it = iter(dl_val)
                xs_val, ys_val = next(dl_val_it)

            val_logits = fmodel(xs_val)
            val_loss = F.cross_entropy(val_logits, ys_val)
            #print('val_loss',val_loss.item())

            val_loss.backward()
            #meta_grads = torch.autograd.grad(val_loss, opt_lr, allow_unused=True)
            #print(meta_grads)
            for param_group in diffopt.param_groups:
                print(param_group['lr'], '/', param_group['lr'].grad)
                print(param_group['momentum'], '/', param_group['momentum'].grad)

            #model=copy.deepcopy(fmodel)
            model.load_state_dict(fmodel.state_dict())

        meta_opt.step()
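    # Note (illustration only, not part of the original script): instead of calling
    # val_loss.backward() inside the innerloop_ctx block, the meta-gradients of the
    # hyper-parameters could be taken explicitly and assigned before meta_opt.step(), e.g.
    #   meta_grads = torch.autograd.grad(val_loss, list(opt_param.values()), allow_unused=True)
    #   for p, g in zip(opt_param.values(), meta_grads):
    #       p.grad = g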
205
higher/transformations.py
Normal file
@@ -0,0 +1,205 @@
import torch
import kornia
import random

### Available TF for Dataug ###
TF_dict = { #f(normalised magnitude) = actual magnitude
    ## Geometric TF ##
    'Identity' : (lambda mag: None),
    'FlipUD' : (lambda mag: None),
    'FlipLR' : (lambda mag: None),
    'Rotate': (lambda mag: random.randint(-int_parameter(mag, maxval=30), int_parameter(mag, maxval=30))),
    'TranslateX': (lambda mag: [random.randint(-int_parameter(mag, maxval=20), int_parameter(mag, maxval=20)), 0]),
    'TranslateY': (lambda mag: [0, random.randint(-int_parameter(mag, maxval=20), int_parameter(mag, maxval=20))]),
    'ShearX': (lambda mag: [random.uniform(-float_parameter(mag, maxval=0.3), float_parameter(mag, maxval=0.3)), 0]),
    'ShearY': (lambda mag: [0, random.uniform(-float_parameter(mag, maxval=0.3), float_parameter(mag, maxval=0.3))]),

    ## Color TF (Expect image in the range of [0, 1]) ##
    'Contrast': (lambda mag: random.uniform(0.1, float_parameter(mag, maxval=1.9))),
    'Color': (lambda mag: random.uniform(0.1, float_parameter(mag, maxval=1.9))),
    'Brightness': (lambda mag: random.uniform(1., float_parameter(mag, maxval=1.9))),
    'Sharpness': (lambda mag: random.uniform(0.1, float_parameter(mag, maxval=1.9))),
    'Posterize': (lambda mag: random.randint(4, int_parameter(mag, maxval=8))),
    'Solarize': (lambda mag: random.randint(1, int_parameter(mag, maxval=256))/256.), #=>Image in [0,1] #Not optimised for batches

    #Not functional yet
    #'Auto_Contrast': (lambda mag: None), #Not optimised for batches (very slow)
    #'Equalize': (lambda mag: None),
}
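
# Illustrative sketch (not part of the original file): each TF_dict entry maps a
# normalised magnitude in [0, PARAMETER_MAX] to the concrete parameter expected by
# the matching transformation function below; `int_parameter`/`float_parameter`
# are the scaling helpers defined further down.
def _example_tf_sampling(mag=5):
    angle = TF_dict['Rotate'](mag)        #random int in [-15, 15] for mag=5
    shear = TF_dict['ShearX'](mag)        #[random float in [-0.15, 0.15], 0]
    threshold = TF_dict['Solarize'](mag)  #random threshold in (0, 0.5]
    return angle, shear, threshold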

def int_image(float_image): #WARNING: slight loss of information (granularity: 1/256 = 0.0039)
    return (float_image*255.).type(torch.uint8)

def float_image(int_image):
    return int_image.type(torch.float)/255.

def rand_inverse(value):
    return value if random.random() < 0.5 else -value

#https://github.com/tensorflow/models/blob/fc2056bce6ab17eabdc139061fef8f4f2ee763ec/research/autoaugment/augmentation_transforms.py#L137
PARAMETER_MAX = 10  # What is the max 'level' a transform could be predicted

def float_parameter(level, maxval):
    """Helper function to scale `val` between 0 and maxval .
    Args:
        level: Level of the operation that will be between [0, `PARAMETER_MAX`].
        maxval: Maximum value that the operation can have. This will be scaled
            to level/PARAMETER_MAX.
    Returns:
        A float that results from scaling `maxval` according to `level`.
    """
    return float(level) * maxval / PARAMETER_MAX

def int_parameter(level, maxval):
    """Helper function to scale `val` between 0 and maxval .
    Args:
        level: Level of the operation that will be between [0, `PARAMETER_MAX`].
        maxval: Maximum value that the operation can have. This will be scaled
            to level/PARAMETER_MAX.
    Returns:
        An int that results from scaling `maxval` according to `level`.
    """
    return int(level * maxval / PARAMETER_MAX)
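
# Worked example of the scaling above (with PARAMETER_MAX = 10):
#   float_parameter(5, maxval=30)  -> 15.0
#   int_parameter(3, maxval=256)   -> 76   (int(3 * 256 / 10) = int(76.8))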

def flipLR(x):
    device = x.device
    (batch_size, channels, h, w) = x.shape

    M = torch.tensor([[[-1., 0., w-1],
                       [ 0., 1., 0.],
                       [ 0., 0., 1.]]], device=device).expand(batch_size, -1, -1)

    # warp the original image by the found transform
    return kornia.warp_perspective(x, M, dsize=(h, w))


def flipUD(x):
    device = x.device
    (batch_size, channels, h, w) = x.shape

    M = torch.tensor([[[ 1., 0., 0.],
                       [ 0., -1., h-1],
                       [ 0., 0., 1.]]], device=device).expand(batch_size, -1, -1)

    # warp the original image by the found transform
    return kornia.warp_perspective(x, M, dsize=(h, w))


def rotate(x, angle):
    return kornia.rotate(x, angle=angle.type(torch.float32)) #Kornia does not support int


def translate(x, translation):
    return kornia.translate(x, translation=translation.type(torch.float32)) #Kornia does not support int


def shear(x, shear):
    return kornia.shear(x, shear=shear)


def contrast(x, contrast_factor):
    return kornia.adjust_contrast(x, contrast_factor=contrast_factor) #Expect image in the range of [0, 1]


#https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageEnhance.py
def color(x, color_factor):
    (batch_size, channels, h, w) = x.shape

    gray_x = kornia.rgb_to_grayscale(x)
    gray_x = gray_x.repeat_interleave(channels, dim=1)
    return blend(gray_x, x, color_factor).clamp(min=0.0, max=1.0) #Expect image in the range of [0, 1]


def brightness(x, brightness_factor):
    device = x.device

    return blend(torch.zeros(x.size(), device=device), x, brightness_factor).clamp(min=0.0, max=1.0) #Expect image in the range of [0, 1]


def sharpeness(x, sharpness_factor):
    device = x.device
    (batch_size, channels, h, w) = x.shape

    k = torch.tensor([[[ 1., 1., 1.],
                       [ 1., 5., 1.],
                       [ 1., 1., 1.]]], device=device) #Smooth Filter : https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageFilter.py
    smooth_x = kornia.filter2D(x, kernel=k, border_type='reflect', normalized=True) #The alpha channel may need to be handled separately

    return blend(smooth_x, x, sharpness_factor).clamp(min=0.0, max=1.0) #Expect image in the range of [0, 1]


#https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py
def posterize(x, bits):
    x = int_image(x) #Expect image in the range of [0, 1]

    mask = ~(2 ** (8 - bits) - 1).type(torch.uint8)

    (batch_size, channels, h, w) = x.shape
    mask = mask.unsqueeze(dim=1).expand(-1, channels).unsqueeze(dim=2).expand(-1, channels, h).unsqueeze(dim=3).expand(-1, channels, h, w) #There is surely a simpler way...

    return float_image(x & mask)

def auto_contrast(x): #NOT OPTIMISED FOR BATCHES #VERY SLOW
    # Possible optimisation: efficient LUT application / histogram computed per batch/channel
    print("Warning : not checked yet !")
    (batch_size, channels, h, w) = x.shape
    x = int_image(x) #Expect image in the range of [0, 1]
    #print('Start',x[0])
    for im_idx, img in enumerate(x.chunk(batch_size, dim=0)): #Operation per image
        #print(img.shape)
        for chan_idx, chan in enumerate(img.chunk(channels, dim=1)): #Operation per channel
            #print(chan.shape)
            hist = torch.histc(chan, bins=256, min=0, max=255) #NOT DIFFERENTIABLE

            # find lowest/highest samples after preprocessing
            for lo in range(256):
                if hist[lo]:
                    break
            for hi in range(255, -1, -1):
                if hist[hi]:
                    break
            if hi <= lo:
                # don't bother
                pass
            else:
                scale = 255.0 / (hi - lo)
                offset = -lo * scale
                for ix in range(256):
                    n_ix = int(ix * scale + offset)
                    if n_ix < 0: n_ix = 0
                    elif n_ix > 255: n_ix = 255

                    chan[chan==ix] = n_ix
            x[im_idx, chan_idx] = chan

    #print('End',x[0])
    return float_image(x)


def equalize(x): #NOT OPTIMISED FOR BATCHES
    raise Exception("equalize not implemented")
    # Possible optimisation: efficient LUT application / histogram computed per batch/channel
    (batch_size, channels, h, w) = x.shape
    x = int_image(x) #Expect image in the range of [0, 1]
    #print('Start',x[0])
    for im_idx, img in enumerate(x.chunk(batch_size, dim=0)): #Operation per image
        #print(img.shape)
        for chan_idx, chan in enumerate(img.chunk(channels, dim=1)): #Operation per channel
            #print(chan.shape)
            hist = torch.histc(chan, bins=256, min=0, max=255) #NOT DIFFERENTIABLE

    return float_image(x)


def solarize(x, thresholds): #NOT OPTIMISED FOR BATCHES
    # Possible optimisation: apply the mask to the whole batch at once (a (B,C,H,W) mask built from (B) thresholds)
    for idx, t in enumerate(thresholds): #Operation per image
        mask = x[idx] > t.item()
        inv_x = 1 - x[idx][mask]
        x[idx][mask] = inv_x
    return x
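
# Possible vectorised variant of the per-image loop above (an illustrative sketch,
# not the original implementation): broadcast one threshold per image over the
# whole batch and invert the selected pixels in a single operation.
def _solarize_batched(x, thresholds):
    mask = x > thresholds.view(-1, 1, 1, 1) #(B,C,H,W) mask from (B,) thresholds
    return torch.where(mask, 1 - x, x)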

#https://github.com/python-pillow/Pillow/blob/9c78c3f97291bd681bc8637922d6a2fa9415916c/src/PIL/Image.py#L2818
def blend(x, y, alpha): #out = image1 * (1.0 - alpha) + image2 * alpha
    #return kornia.add_weighted(src1=x, alpha=(1-alpha), src2=y, beta=alpha, gamma=0) #out=src1*alpha+src2*beta+gamma #Does not work with a batch of alpha values

    if not isinstance(x, torch.Tensor):
        raise TypeError("x should be a tensor. Got {}".format(type(x)))

    if not isinstance(y, torch.Tensor):
        raise TypeError("y should be a tensor. Got {}".format(type(y)))

    (batch_size, channels, h, w) = x.shape
    alpha = alpha.unsqueeze(dim=1).expand(-1, channels).unsqueeze(dim=2).expand(-1, channels, h).unsqueeze(dim=3).expand(-1, channels, h, w) #There is surely a simpler way...
    res = x*(1-alpha) + y*alpha

    return res
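
# Illustrative usage sketch (not part of the original file): blend expects one
# alpha per image in the batch and broadcasts it over channels and pixels.
def _example_blend():
    x = torch.rand(4, 3, 8, 8)                   #batch of 4 RGB images in [0, 1]
    y = torch.zeros_like(x)
    alpha = torch.tensor([0.0, 0.25, 0.5, 1.0])  #one blending factor per image
    return blend(y, x, alpha)                    #image i becomes x[i] * alpha[i]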
184
higher/utils.py
Normal file
@@ -0,0 +1,184 @@
import numpy as np
import json, math, time, os
import matplotlib.pyplot as plt
import copy
import gc

from torchviz import make_dot

import torch
import torch.nn.functional as F


def print_graph(PyTorch_obj, fig_name='graph'):
    graph = make_dot(PyTorch_obj) #The loss gives the whole graph
    graph.format = 'svg' #https://graphviz.readthedocs.io/en/stable/manual.html#formats
    graph.render(fig_name)


def plot_res(log, fig_name='res'):

    epochs = [x["epoch"] for x in log]

    fig, ax = plt.subplots(ncols=3, figsize=(15, 3))

    ax[0].set_title('Loss')
    ax[0].plot(epochs, [x["train_loss"] for x in log], label='Train')
    ax[0].plot(epochs, [x["val_loss"] for x in log], label='Val')
    ax[0].legend()

    ax[1].set_title('Acc')
    ax[1].plot(epochs, [x["acc"] for x in log])

    if log[0]["param"] != None:
        if isinstance(log[0]["param"], float):
            ax[2].set_title('Mag')
            ax[2].plot(epochs, [x["param"] for x in log], label='Mag')
            ax[2].legend()
        else:
            ax[2].set_title('Prob')
            for idx, _ in enumerate(log[0]["param"]):
                ax[2].plot(epochs, [x["param"][idx] for x in log], label='P'+str(idx))
            ax[2].legend()
            #ax[2].legend(('P-0', 'P-45', 'P-180'))

    fig_name = fig_name.replace('.', ',')
    plt.savefig(fig_name)
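
# Expected log format (an illustration inferred from the keys read above):
#   log = [{"epoch": 1, "train_loss": 0.9, "val_loss": 1.0, "acc": 62.0, "param": [0.3, 0.3, 0.4]}, ...]
#   plot_res(log, fig_name="res/example")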

def plot_compare(filenames, fig_name='res'):

    all_data = []
    legend = ""
    for idx, file in enumerate(filenames):
        legend += str(idx) + '-' + file + '\n'
        with open(file) as json_file:
            data = json.load(json_file)
        all_data.append(data)

    fig, ax = plt.subplots(ncols=3, figsize=(30, 8))

    for data_idx, log in enumerate(all_data):
        log = log['Log']
        epochs = [x["epoch"] for x in log]

        ax[0].plot(epochs, [x["train_loss"] for x in log], label=str(data_idx)+'-Train')
        ax[0].plot(epochs, [x["val_loss"] for x in log], label=str(data_idx)+'-Val')

        ax[1].plot(epochs, [x["acc"] for x in log], label=str(data_idx))
        #ax[1].text(x=0.5,y=0,s=str(data_idx)+'-'+filenames[data_idx], transform=ax[1].transAxes)

        if log[0]["param"] != None:
            if isinstance(log[0]["param"], float):
                ax[2].plot(epochs, [x["param"] for x in log], label=str(data_idx)+'-Mag')
            else:
                for idx, _ in enumerate(log[0]["param"]):
                    ax[2].plot(epochs, [x["param"][idx] for x in log], label=str(data_idx)+'-P'+str(idx))

    fig.suptitle(legend)
    ax[0].set_title('Loss')
    ax[1].set_title('Acc')
    ax[2].set_title('Param')
    for a in ax: a.legend()
    fig_name = fig_name.replace('.', ',')

    plt.savefig(fig_name, bbox_inches='tight')


def viz_sample_data(imgs, labels, fig_name='data_sample'):

    sample = imgs[0:25,].permute(0, 2, 3, 1).squeeze().cpu()

    plt.figure(figsize=(10, 10))
    for i in range(25):
        plt.subplot(5, 5, i+1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(sample[i,], cmap=plt.cm.binary)
        plt.xlabel(labels[i].item())

    plt.savefig(fig_name)


def model_copy(src, dst, patch_copy=True, copy_grad=True):
    #model=copy.deepcopy(fmodel) #Not appropriate, we only want the weights/gradients (not the whole fmodel and its states)

    dst.load_state_dict(src.state_dict()) #Does not copy the gradients !

    if patch_copy:
        dst['model'].load_state_dict(src['model'].state_dict()) #Copy missing data ?
        dst['data_aug'].load_state_dict(src['data_aug'].state_dict())

    #Copy the gradients
    if copy_grad:
        for paramName, paramValue, in src.named_parameters():
            for netCopyName, netCopyValue, in dst.named_parameters():
                if paramName == netCopyName:
                    netCopyValue.grad = paramValue.grad
                    #netCopyValue=copy.deepcopy(paramValue)

    try: #Data_augV4
        dst['data_aug']._input_info = src['data_aug']._input_info
        dst['data_aug']._TF_matrix = src['data_aug']._TF_matrix
    except:
        pass


def optim_copy(dopt, opt):

    #inner_opt.load_state_dict(diffopt.state_dict()) #Would need to save the optimiser state (momentum, etc.) => Does not copy the state...
    #opt_param=higher.optim.get_trainable_opt_params(diffopt)

    for group_idx, group in enumerate(opt.param_groups):
        # print('gp idx',group_idx)
        for p_idx, p in enumerate(group['params']):
            opt.state[p] = dopt.state[group_idx][p_idx]


def print_torch_mem(add_info=''):

    nb = 0
    max_size = 0
    for obj in gc.get_objects():
        #print(type(obj))
        try:
            if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)): # and len(obj.size())>1:
                #print(i, type(obj), obj.size())
                size = np.sum(obj.size())
                if size > max_size: max_size = size
                nb += 1
        except:
            pass
    print(add_info, "-Pytorch tensor nb:", nb, " / Max dim:", max_size)

    #print(add_info, "-Garbage size :",len(gc.garbage))


class loss_monitor(): #See https://github.com/pytorch/ignite
    def __init__(self, patience, end_train=1):
        self.patience = patience
        self.end_train = end_train
        self.counter = 0
        self.best_score = None
        self.reached_limit = 0

    def register(self, loss):
        if self.best_score is None:
            self.best_score = loss
        elif loss > self.best_score:
            self.counter += 1
            #if not self.reached_limit:
            print("loss no improve counter", self.counter, self.reached_limit)
        else:
            self.best_score = loss
            self.counter = 0

    def limit_reached(self):
        if self.counter >= self.patience:
            self.counter = 0
            self.reached_limit += 1
            self.best_score = None
        return self.reached_limit

    def end_training(self):
        if self.limit_reached() >= self.end_train:
            return True
        else:
            return False

    def reset(self):
        self.__init__(self.patience, self.end_train)
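
# Illustrative usage sketch (not part of the original file):
#   monitor = loss_monitor(patience=5, end_train=2)
#   for val_loss in losses_per_epoch:
#       monitor.register(val_loss)
#       if monitor.end_training(): break  #stop once the patience limit has been hit `end_train` times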