mirror of https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00

Salvador tests

This commit is contained in:
parent 6c0597e7ea
commit aade27011a

57 changed files with 29210 additions and 0 deletions
585 salvador/train_dataug.py Normal file
@@ -0,0 +1,585 @@
import datetime
import os
import time
import sys
import random

import numpy as np

import torch
import torch.utils.data
import torch.nn.functional as F  # F.cross_entropy / F.log_softmax are used throughout
from torch import nn

import torchvision
from torchvision import transforms
from PIL import ImageEnhance

import utils
from fastprogress import master_bar, progress_bar

## DATA AUG ##
import higher
from dataug import *
from dataug_utils import *

tf_names = [
    ## Geometric TF ##
    'Identity',
    'FlipUD',
    'FlipLR',
    'Rotate',
    'TranslateX',
    'TranslateY',
    'ShearX',
    'ShearY',

    ## Color TF (expect images in the range [0, 1]) ##
    'Contrast',
    'Color',
    'Brightness',
    'Sharpness',
    'Posterize',
    'Solarize',  # expects images in [0, 1]; not optimized for batches
]

def compute_vaLoss(model, dl_it, dl):
    device = next(model.parameters()).device
    try:
        xs, ys = next(dl_it)
    except StopIteration:  # end of the validation epoch: restart the iterator
        dl_it = iter(dl)
        xs, ys = next(dl_it)
    xs, ys = xs.to(device), ys.to(device)

    model.eval()  # validate without transformations!

    return F.cross_entropy(model(xs), ys)

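# Note on the restart above (comment only): rebinding dl_it in the except
# branch only changes the local name; the caller's iterator stays exhausted,
# so each later call rebuilds a fresh iterator from dl. This works, at the
# cost of re-creating the iterator once per pass over the validation set.
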
def model_copy(src, dst, patch_copy=True, copy_grad=True):
    # model = copy.deepcopy(fmodel) is not appropriate: we only want the
    # weights/grads, not the whole fmodel and its internal state.

    dst.load_state_dict(src.state_dict())  # does not copy gradients!

    if patch_copy:
        dst['model'].load_state_dict(src['model'].state_dict())  # copy data missed above?
        dst['data_aug'].load_state_dict(src['data_aug'].state_dict())

    # copy the gradients
    if copy_grad:
        for paramName, paramValue in src.named_parameters():
            for netCopyName, netCopyValue in dst.named_parameters():
                if paramName == netCopyName:
                    netCopyValue.grad = paramValue.grad
                    #netCopyValue = copy.deepcopy(paramValue)

    try:  # Data_augV4 only exposes these attributes
        dst['data_aug']._input_info = src['data_aug']._input_info
        dst['data_aug']._TF_matrix = src['data_aug']._TF_matrix
    except (AttributeError, KeyError):
        pass

def optim_copy(dopt, opt):
    # inner_opt.load_state_dict(diffopt.state_dict()) would be needed to save
    # the optimizer state (momentum, etc.) => it does not copy the state...
    # opt_param = higher.optim.get_trainable_opt_params(diffopt)

    for group_idx, group in enumerate(opt.param_groups):
        # print('gp idx', group_idx)
        for p_idx, p in enumerate(group['params']):
            opt.state[p] = dopt.state[group_idx][p_idx]

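# Illustrative sketch (hypothetical helper, never called): after a window of
# differentiable inner steps with `higher`, the fast weights live in fmodel,
# so both copies are needed before stepping the real optimizer again.
def _sync_back_demo(fmodel, model, diffopt, inner_opt):
    model_copy(src=fmodel, dst=model)        # weights + grads into the real model
    optim_copy(dopt=diffopt, opt=inner_opt)  # per-parameter state (e.g. Adam moments)
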
#############


class Lambda(nn.Module):
    "Create a layer that simply calls `func` with `x`"
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        return self.func(x)

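# Hypothetical sanity check (defined for illustration only, never called):
# Lambda lets a plain function sit inside nn.Sequential.
def _lambda_demo():
    flatten = Lambda(lambda x: torch.flatten(x, 1))
    assert flatten(torch.zeros(8, 512, 1, 1)).shape == (8, 512)
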
class SubsetSampler(torch.utils.data.SubsetRandomSampler):
    """Deterministic variant of SubsetRandomSampler: yields the given indices in order."""
    def __init__(self, indices):
        super().__init__(indices)

    def __iter__(self):
        return (self.indices[i] for i in range(len(self.indices)))

    def __len__(self):
        return len(self.indices)

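# Why the deterministic order matters (comment only): evaluate() runs with
# batch_size=1 and uses its loop index i to recover the file path of a
# misclassified image via data_loader.sampler.indices[i]; that lookup is only
# valid because SubsetSampler yields indices in a fixed order.
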
def sharpness(img, factor):
    """Apply a PIL sharpness enhancement with a factor drawn uniformly from [1, factor]."""
    sharpness_factor = random.uniform(1, factor)
    sharp = ImageEnhance.Sharpness(img)
    sharped = sharp.enhance(sharpness_factor)
    return sharped

def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, master_bar):
    model.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    confmat = utils.ConfusionMatrix(num_classes=len(data_loader.dataset.classes))
    header = 'Epoch: {}'.format(epoch)
    for _, (image, target) in metric_logger.log_every(data_loader, header=header, parent=master_bar):

        image, target = image.to(device), target.to(device)
        output = model(image)
        loss = criterion(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc1 = utils.accuracy(output, target)[0]
        batch_size = image.shape[0]
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.update(loss=loss.item())

        confmat.update(target.flatten(), output.argmax(1).flatten())

    return metric_logger.loss.global_avg, confmat

def evaluate(model, criterion, data_loader, device):
    model.eval()
    metric_logger = utils.MetricLogger(delimiter=" ")
    confmat = utils.ConfusionMatrix(num_classes=len(data_loader.dataset.classes))
    header = 'Test:'
    missed = []
    with torch.no_grad():
        for i, (image, target) in metric_logger.log_every(data_loader, leave=False, header=header, parent=None):
            image, target = image.to(device), target.to(device)
            output = model(image)
            loss = criterion(output, target)
            # batch_size is expected to be 1 here: target.item() and the
            # sampler-index lookup below are only valid for single-image batches
            if target.item() != output.topk(1)[1].item():
                missed.append(data_loader.dataset.imgs[data_loader.sampler.indices[i]])

            confmat.update(target.flatten(), output.argmax(1).flatten())

            acc1 = utils.accuracy(output, target)[0]
            batch_size = image.shape[0]
            metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
            metric_logger.update(loss=loss.item())

    return metric_logger.loss.global_avg, missed, confmat

def get_train_valid_loader(args, augment, random_seed, train_size=0.5, test_size=0.1, shuffle=True, num_workers=4, pin_memory=True):
    """
    Utility function for loading and returning train, valid and test
    multi-process iterators over an ImageFolder dataset.
    If using CUDA, num_workers should be set to 1 and pin_memory to True.

    Params
    ------
    - args: parsed arguments; args.data_path, args.batch_size and
      args.test_only are used here.
    - augment: whether to apply the data augmentation scheme. Only
      applied on the train split.
    - random_seed: fix seed for reproducibility.
    - train_size: fraction of the non-test data used for training; the
      remainder becomes the validation set. Should be a float in [0, 1].
    - test_size: fraction of the whole dataset held out as the test set.
      Should be a float in [0, 1].
    - shuffle: whether to shuffle the indices before splitting.
    - num_workers: number of subprocesses to use when loading the dataset.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
      True if using GPU.

    Returns
    -------
    - train_loader: training set iterator.
    - valid_loader: validation set iterator.
    - test_loader: test set iterator.
    """
    error_msg = "[!] test_size should be in the range [0, 1]."
    assert ((test_size >= 0) and (test_size <= 1)), error_msg

    # normalize = transforms.Normalize(
    #     mean=[0.4914, 0.4822, 0.4465],
    #     std=[0.2023, 0.1994, 0.2010],
    # )

    # define transforms
    if augment:
        train_transform = transforms.Compose([
            # transforms.ColorJitter(brightness=0.3),
            # transforms.Lambda(lambda img: sharpness(img, 5)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            # normalize,
        ])

        valid_transform = transforms.Compose([
            # transforms.ColorJitter(brightness=0.3),
            # transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            # normalize,
        ])
    else:
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            # normalize,
        ])

        valid_transform = transforms.Compose([
            transforms.ToTensor(),
            # normalize,
        ])

    # load the dataset
    train_dataset = torchvision.datasets.ImageFolder(
        root=args.data_path, transform=train_transform
    )

    test_dataset = torchvision.datasets.ImageFolder(
        root=args.data_path, transform=valid_transform
    )

    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(test_size * num_train))

    if shuffle:
        np.random.seed(random_seed)
        np.random.shuffle(indices)

    train_idx, test_idx = indices[split:], indices[:split]
    train_idx, valid_idx = train_idx[:int(len(train_idx)*train_size)], train_idx[int(len(train_idx)*train_size):]
    print("\nTrain", len(train_idx), "\nValid", len(valid_idx), "\nTest", len(test_idx))
    # deterministic samplers in test-only mode so evaluate() can map batch
    # indices back to file paths
    train_sampler = torch.utils.data.SubsetRandomSampler(train_idx) if not args.test_only else SubsetSampler(train_idx)
    valid_sampler = torch.utils.data.SubsetRandomSampler(valid_idx) if not args.test_only else SubsetSampler(valid_idx)
    test_sampler = SubsetSampler(test_idx)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size if not args.test_only else 1, sampler=train_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )
    valid_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size if not args.test_only else 1, sampler=valid_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=1, sampler=test_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )

    imgs = np.asarray(train_dataset.imgs)

    # print('Train')
    # print(imgs[train_idx])
    # print('Valid')
    # print(imgs[valid_idx])

    return (train_loader, valid_loader, test_loader)

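# Split arithmetic, worked example (hypothetical sizes): with 1000 images,
# test_size=0.2 and train_size=0.99, split = 200, so test gets indices[:200];
# of the remaining 800, train gets the first 792 (0.99 * 800) and valid the
# last 8.
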
def main(args):
    print(args)

    device = torch.device(args.device)

    torch.backends.cudnn.benchmark = True

    #augment = True if not args.test_only else False
    augment = False

    data_loader, dl_val, data_loader_test = get_train_valid_loader(args=args, pin_memory=True, augment=augment,
        num_workers=args.workers, train_size=0.99, test_size=0.2, random_seed=999)

    print("Creating model")
    model = torchvision.models.__dict__[args.model](pretrained=True)
    flat = list(model.children())

    # split the pretrained network into a convolutional body and a new head
    # whose final linear layer matches the number of dataset classes
    body, head = nn.Sequential(*flat[:-2]), nn.Sequential(flat[-2], Lambda(func=lambda x: torch.flatten(x, 1)), nn.Linear(flat[-1].in_features, len(data_loader.dataset.classes)))
    model = nn.Sequential(body, head)

    # model.fc = nn.Linear(model.fc.in_features, 2)
    # import ipdb; ipdb.set_trace()

    criterion = nn.CrossEntropyLoss().to(device)

    # optimizer = torch.optim.SGD(
    #     model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    '''
    optimizer = torch.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9)
    '''
    es = utils.EarlyStopping()

    if args.test_only:
        model.load_state_dict(torch.load('checkpoint.pt', map_location=lambda storage, loc: storage))
        model = model.to(device)
        print('TEST')
        _, missed, _ = evaluate(model, criterion, data_loader_test, device=device)
        print(missed)
        print('TRAIN')
        _, missed, _ = evaluate(model, criterion, data_loader, device=device)
        print(missed)
        return

    model = model.to(device)

    print("Start training")
    start_time = time.time()
    mb = master_bar(range(args.epochs))
    """
    for epoch in mb:
        _, train_confmat = train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, mb)
        lr_scheduler.step( (epoch+1)*len(data_loader) )
        val_loss, _, valid_confmat = evaluate(model, criterion, data_loader_test, device=device)
        es(val_loss, model)

        # print('Valid Missed')
        # print(valid_missed)

        # print('Train')
        # print(train_confmat)
        print('Valid')
        print(valid_confmat)

        # if es.early_stop:
        #     break

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
    """

    #######

    inner_it = args.inner_it
    dataug_epoch_start = 0
    print_freq = 1
    KLdiv = False

    tf_dict = {k: TF.TF_dict[k] for k in tf_names}
    model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=False, shared_mag=False), model).to(device)
    #model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)

    val_loss = torch.tensor(0)  # needed if no meta-step happens during an epoch
    dl_val_it = iter(dl_val)
    countcopy = 0

    #if inner_it != 0:
    meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=args.lr)  # lr=1e-2
    #inner_opt = torch.optim.SGD(model['model'].parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)  # lr=1e-2 / momentum=0.9
    inner_opt = torch.optim.Adam(model['model'].parameters(), lr=args.lr, weight_decay=args.weight_decay)

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        inner_opt,
        lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9)

    high_grad_track = True
    if inner_it == 0:
        high_grad_track = False

    model.train()
    model.augment(mode=False)

    fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
    diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(), fmodel=fmodel, track_higher_grads=high_grad_track)

    i = 0

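    # Bi-level optimization in brief (comment only): fmodel is a functional
    # ("monkeypatched") copy of the augmented model and diffopt a
    # differentiable optimizer, so the inner Adam steps stay in the autograd
    # graph; every inner_it steps a validation loss is backpropagated through
    # that window to update the augmentation parameters (prob/mag) via meta_opt.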
    for epoch in mb:

        metric_logger = utils.MetricLogger(delimiter=" ")
        confmat = utils.ConfusionMatrix(num_classes=len(data_loader.dataset.classes))
        header = 'Epoch: {}'.format(epoch)

        t0 = time.process_time()
        for _, (image, target) in metric_logger.log_every(data_loader, header=header, parent=mb):
            #for i, (xs, ys) in enumerate(dl_train):
            #print_torch_mem("it"+str(i))
            i += 1
            image, target = image.to(device), target.to(device)

            if(not KLdiv):
                # uniform method
                logits = fmodel(image)  # modified `params` can also be passed as a kwarg
                output = F.log_softmax(logits, dim=1)
                # cross_entropy expects raw logits (it applies log_softmax itself)
                loss = F.cross_entropy(logits, target, reduction='none')  # no need to call loss.backward()

                if fmodel._data_augmentation:  # weight the loss
                    w_loss = fmodel['data_aug'].loss_weight()  #.to(device)
                    loss = loss * w_loss
                loss = loss.mean()

            else:
                # KL-divergence method
                fmodel.augment(mode=False)
                sup_logits = fmodel(image)
                log_sup = F.log_softmax(sup_logits, dim=1)
                fmodel.augment(mode=True)
                loss = F.cross_entropy(sup_logits, target)
                output = sup_logits  # for the accuracy/confusion-matrix bookkeeping below

                if fmodel._data_augmentation:
                    aug_logits = fmodel(image)
                    log_aug = F.log_softmax(aug_logits, dim=1)
                    w_loss = fmodel['data_aug'].loss_weight()  # loss weighting (needed in both branches below)
                    aug_loss = 0
                    if epoch > 50:  # delayed start?
                        # KL div w/ logits: similarity of the predicted distributions
                        aug_loss = F.softmax(sup_logits, dim=1)*(log_sup-log_aug)
                        aug_loss = aug_loss.sum(dim=-1)
                        #aug_loss = F.kl_div(aug_logits, sup_logits, reduction='none')
                        aug_loss = (w_loss * aug_loss).mean()

                    aug_loss += (F.cross_entropy(aug_logits, target, reduction='none') * w_loss).mean()
                    #print(aug_loss)
                    unsupp_coeff = 1
                    loss += aug_loss * unsupp_coeff

            diffopt.step(loss)  # performs opt.zero_grad, loss.backward, opt.step in one differentiable call

            if(high_grad_track and i % inner_it == 0):  # perform a meta-step
                #print("meta")
                # of little use if high_grad_track = False
                val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val) + fmodel['data_aug'].reg_loss()
                #print_graph(val_loss)

                val_loss.backward()

                countcopy += 1
                model_copy(src=fmodel, dst=model)
                optim_copy(dopt=diffopt, opt=inner_opt)

                #if epoch > 50:
                meta_opt.step()
                model['data_aug'].adjust_param(soft=False)  # constraint: sum(prob) = 1
                #model['data_aug'].next_TF_set()

                fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
                diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(), fmodel=fmodel, track_higher_grads=high_grad_track)

            acc1 = utils.accuracy(output, target)[0]
            batch_size = image.shape[0]
            metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
            metric_logger.update(loss=loss.item())

            confmat.update(target.flatten(), output.argmax(1).flatten())

            if(not high_grad_track and (torch.cuda.memory_cached()/1024.0**2) > 20000):
                countcopy += 1
                print_torch_mem("copy")
                model_copy(src=fmodel, dst=model)
                optim_copy(dopt=diffopt, opt=inner_opt)
                val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val)

                # needed to reset higher (fast_params accumulate even with track_higher_grads = False)
                fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
                diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(), fmodel=fmodel, track_higher_grads=high_grad_track)
                print_torch_mem("copy")

        if(not high_grad_track):
            countcopy += 1
            print_torch_mem("end copy")
            model_copy(src=fmodel, dst=model)
            optim_copy(dopt=diffopt, opt=inner_opt)
            val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val)

            # needed to reset higher (fast_params accumulate even with track_higher_grads = False)
            fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
            diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(), fmodel=fmodel, track_higher_grads=high_grad_track)
            print_torch_mem("end copy")

        tf = time.process_time()

        #### Print ####
        if(print_freq and epoch % print_freq == 0):
            print('-'*9)
            print('Epoch : %d' % (epoch))
            print('Time : %.00f' % (tf - t0))
            print('Train loss :', loss.item(), '/ val loss', val_loss.item())
            print('Data Augmentation : {} (Epoch {})'.format(model._data_augmentation, dataug_epoch_start))
            print('TF Proba :', model['data_aug']['prob'].data)
            #print('proba grad', model['data_aug']['prob'].grad)
            print('TF Mag :', model['data_aug']['mag'].data)
            #print('Mag grad', model['data_aug']['mag'].grad)
            #print('Reg loss:', model['data_aug'].reg_loss().item())
            #print('Aug loss', aug_loss.item())
        #############
        #### Log ####
        #print(type(model['data_aug']) is dataug.Data_augV5)
        '''
        param = [{'p': p.item(), 'm': model['data_aug']['mag'].item()} for p in model['data_aug']['prob']] if model['data_aug']._shared_mag else [{'p': p.item(), 'm': m.item()} for p, m in zip(model['data_aug']['prob'], model['data_aug']['mag'])]
        data = {
            "epoch": epoch,
            "train_loss": loss.item(),
            "val_loss": val_loss.item(),
            "acc": accuracy,
            "time": tf - t0,

            "param": param  #if isinstance(model['data_aug'], Data_augV5)
                            #else [p.item() for p in model['data_aug']['prob']],
        }
        log.append(data)
        '''
        #############

        train_confmat = confmat
        lr_scheduler.step( (epoch+1)*len(data_loader) )

        test_loss, _, test_confmat = evaluate(model, criterion, data_loader_test, device=device)
        es(test_loss, model)

        # print('Valid Missed')
        # print(valid_missed)

        # print('Train')
        # print(train_confmat)
        print('Test')
        print(test_confmat)

        # if es.early_stop:
        #     break

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))

def parse_args():
    import argparse
    parser = argparse.ArgumentParser(description='PyTorch Classification Training')

    parser.add_argument('--data-path', default='/Salvador', help='dataset')
    parser.add_argument('--model', default='resnet50', help='model')
    parser.add_argument('--device', default='cuda:1', help='device')
    parser.add_argument('-b', '--batch-size', default=4, type=int)
    parser.add_argument('--epochs', default=3, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-j', '--workers', default=0, type=int, metavar='N',
                        help='number of data loading workers (default: 0)')
    parser.add_argument('--lr', default=0.001, type=float, help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=4e-5, type=float,
                        metavar='W', help='weight decay (default: 4e-5)',
                        dest='weight_decay')

    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )

    parser.add_argument('--in_it', '--inner_it', default=0, type=int,
                        metavar='N', help='number of inner (higher) iterations between meta-steps',
                        dest='inner_it')

    args = parser.parse_args()

    return args

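# Example invocation (paths and device are assumptions, adapt to your setup):
#   python salvador/train_dataug.py --data-path /Salvador --model resnet50 \
#       --epochs 3 --in_it 1 --device cuda:0
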
if __name__ == "__main__":
    args = parse_args()
    main(args)