import numpy as np
import json, math, time, os

import torch
import torchvision
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import SubsetRandomSampler
import higher
from model import *

import copy

BATCH_SIZE = 300
TEST_SIZE = 300

mnist_train = torchvision.datasets.MNIST(
    "./data", train=True, download=True,
    transform=torchvision.transforms.Compose([
        #torchvision.transforms.RandomAffine(degrees=180, translate=None, scale=None, shear=None, resample=False, fillcolor=0),
        torchvision.transforms.ToTensor()
    ])
)

mnist_test = torchvision.datasets.MNIST(
    "./data", train=False, download=True, transform=torchvision.transforms.ToTensor()
)

#train_subset_indices=range(int(len(mnist_train)/2))
train_subset_indices = range(BATCH_SIZE)
val_subset_indices = range(int(len(mnist_train)/2), len(mnist_train))

dl_train = torch.utils.data.DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices))
dl_val = torch.utils.data.DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(val_subset_indices))
dl_test = torch.utils.data.DataLoader(mnist_test, batch_size=TEST_SIZE, shuffle=False)


def test(model):
    # Accuracy (%) on the first test batch only (TEST_SIZE samples).
    model.eval()
    for i, (features, labels) in enumerate(dl_test):
        pred = model.forward(features)
        return pred.argmax(dim=1).eq(labels).sum().item() / TEST_SIZE * 100

def train_classic(model, optim, epochs=1):
    # Plain supervised training loop, used as the "vanilla" baseline.
    model.train()
    log = []
    for epoch in range(epochs):
        t0 = time.process_time()
        for i, (features, labels) in enumerate(dl_train):

            optim.zero_grad()
            pred = model.forward(features)
            loss = F.cross_entropy(pred, labels)
            loss.backward()
            optim.step()

        #### Log ####
        tf = time.process_time()
        data = {
            "time": tf - t0,
        }
        log.append(data)

    times = [x["time"] for x in log]
    print("Vanilla : acc", test(model), "in (s):", np.mean(times), "+/-", np.std(times))

##########################################
if __name__ == "__main__":

    device = torch.device('cpu')

    model = LeNet(1, 10)
    # Hyperparameters to be meta-learned: leaf tensors with requires_grad=True,
    # so the validation loss can be differentiated w.r.t. them.
    opt_param = {
        "lr": torch.tensor(1e-2).requires_grad_(),
        "momentum": torch.tensor(0.9).requires_grad_()
    }
    n_inner_iter = 1
    dl_train_it = iter(dl_train)
    dl_val_it = iter(dl_val)
    epoch = 0
    epochs = 10

    #### Baseline: vanilla training ####
    train_classic(model=model, optim=torch.optim.Adam(model.parameters(), lr=0.001), epochs=epochs)
    model = LeNet(1, 10)

    meta_opt = torch.optim.Adam(opt_param.values(), lr=1e-2)
    # The inner optimizer is seeded with the tensor-valued hyperparameters;
    # they are re-assigned to the differentiable optimizer inside the meta-loop below.
    inner_opt = torch.optim.SGD(model.parameters(), lr=opt_param['lr'], momentum=opt_param['momentum'])

    #for xs_val, ys_val in dl_val:
    while epoch < epochs:
        #print(data_aug.params["mag"], data_aug.params["mag"].grad)
        meta_opt.zero_grad()
        model.train()
        with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=True, track_higher_grads=True) as (fmodel, diffopt):  # effect of copy_initial_weights is not clear...

            # Re-inject the differentiable hyperparameters so the inner updates depend on them.
            for param_group in diffopt.param_groups:
                param_group['lr'] = opt_param['lr']
                param_group['momentum'] = opt_param['momentum']

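            # Hedged alternative (sketch, commented out): recent releases of `higher` also
            # accept an `override` argument on innerloop_ctx, which would replace the manual
            # param_group assignment above; not verified against the version pinned here.
            #with higher.innerloop_ctx(model, inner_opt,
            #                          override={'lr': [opt_param['lr']], 'momentum': [opt_param['momentum']]},
            #                          copy_initial_weights=True, track_higher_grads=True) as (fmodel, diffopt):
            #    ...
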
            for i in range(n_inner_iter):
                try:
                    xs, ys = next(dl_train_it)
                except StopIteration:  # end of the training epoch
                    epoch += 1
                    dl_train_it = iter(dl_train)
                    xs, ys = next(dl_train_it)

                    print('Epoch', epoch)
                    print('train loss', loss.item(), '/ val loss', val_loss.item())
                    print('acc', test(model))
                    print('opt : lr', opt_param['lr'].item(), 'momentum', opt_param['momentum'].item())
                    print('-'*9)
                    model.train()  # test() switched the base model to eval mode

                logits = fmodel(xs)  # modified `params` can also be passed as a kwarg
                loss = F.cross_entropy(logits, ys)  # no need to call loss.backwards()
                #print('loss',loss.item())
                diffopt.step(loss)  # note that `step` must take `loss` as an argument!
                # The line above gets P[t+1] from P[t] and loss[t]. `step` also returns
                # these new parameters, as an alternative to getting them from
                # `fmodel.fast_params` or `fmodel.parameters()` after calling
                # `diffopt.step`.

                # At this point, or at any point in the iteration, you can take the
                # gradient of `fmodel.parameters()` (or equivalently
                # `fmodel.fast_params`) w.r.t. `fmodel.parameters(time=0)` (equivalently
                # `fmodel.init_fast_params`). i.e. `fast_params` will always have
                # `grad_fn` as an attribute, and be part of the gradient tape.

            # At the end of your inner loop you can obtain these e.g. ...
            #grad_of_grads = torch.autograd.grad(
            #    meta_loss_fn(fmodel.parameters()), fmodel.parameters(time=0))

            try:
                xs_val, ys_val = next(dl_val_it)
            except StopIteration:  # end of the validation epoch
                dl_val_it = iter(dl_val)
                xs_val, ys_val = next(dl_val_it)

            val_logits = fmodel(xs_val)
            val_loss = F.cross_entropy(val_logits, ys_val)
            #print('val_loss',val_loss.item())

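            # Hedged sketch (commented out, illustration only): because the inner update went
            # through `diffopt` with tensor-valued lr/momentum, the hypergradients could also
            # be read off explicitly instead of relying on the backward call below.
            # `retain_graph=True` keeps the tape alive for that backward.
            #hyper_grads = torch.autograd.grad(val_loss, list(opt_param.values()),
            #                                  retain_graph=True, allow_unused=True)
            #print('d(val_loss)/d(lr), d(val_loss)/d(momentum):', hyper_grads)
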
            # Backprop through the unrolled inner step(s): populates .grad on the opt_param tensors.
            val_loss.backward()
            #meta_grads = torch.autograd.grad(val_loss, opt_lr, allow_unused=True)
            #print(meta_grads)

            for param_group in diffopt.param_groups:
                print(param_group['lr'], '/', param_group['lr'].grad)
                print(param_group['momentum'], '/', param_group['momentum'].grad)

            #model=copy.deepcopy(fmodel)
            model.load_state_dict(fmodel.state_dict())  # carry the fast weights over to the base model

        meta_opt.step()
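
    # Possible follow-up (sketch, commented out; not part of the original script): once the
    # meta-loop has finished, the learned scalars could be reused in an ordinary optimizer.
    #final_opt = torch.optim.SGD(model.parameters(),
    #                            lr=opt_param['lr'].item(),
    #                            momentum=opt_param['momentum'].item())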