This commit is contained in:
Harle, Antoine (Contracteur) 2020-01-13 18:02:36 -05:00
parent 53bd421670
commit e291bc2e44
9 changed files with 55 additions and 44 deletions

Binary file not shown.

View file

@ -31,12 +31,12 @@ tf_names = [
'ShearY',
## Color TF (Expect image in the range of [0, 1]) ##
'Contrast',
'Color',
'Brightness',
'Sharpness',
'Posterize',
'Solarize', #=>Image entre [0,1] #Pas opti pour des batch
#'Contrast',
#'Color',
#'Brightness',
#'Sharpness',
#'Posterize',
#'Solarize', #=>Image entre [0,1] #Pas opti pour des batch
]
class Lambda(nn.Module):
@ -95,6 +95,7 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, mas
unsupp_coeff = 1
loss = sup_loss + (aug_loss + KL_loss) * unsupp_coeff
#print(sup_loss.item(), (aug_loss + KL_loss).item())
optimizer.zero_grad()
loss.backward()
@ -210,7 +211,7 @@ def get_train_valid_loader(args, augment, random_seed, valid_size=0.1, shuffle=T
split = int(np.floor(valid_size * num_train))
if shuffle:
#np.random.seed(random_seed)
np.random.seed(random_seed)
np.random.shuffle(indices)
train_idx, valid_idx = indices[split:], indices[:split]
@ -277,6 +278,8 @@ def main(args):
model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)
if args.augment=='RandKL': Kldiv=True
model['data_aug']['mag'].data = model['data_aug']['mag'].data * args.magnitude
print("Augmodel")
# model.fc = nn.Linear(model.fc.in_features, 2)
@ -294,7 +297,7 @@ def main(args):
optimizer,
lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9)
es = utils.EarlyStopping()
es = utils.EarlyStopping() if not (args.augment=='Rand' or args.augment=='RandKL') else utils.EarlyStopping(augmented_model=True)
if args.test_only:
model.load_state_dict(torch.load('checkpoint.pt', map_location=lambda storage, loc: storage))
@ -324,8 +327,8 @@ def main(args):
# print('Train')
# print(train_confmat)
print('Valid')
print(valid_confmat)
#print('Valid')
#print(valid_confmat)
# if es.early_stop:
# break
@ -339,9 +342,9 @@ def parse_args():
import argparse
parser = argparse.ArgumentParser(description='PyTorch Classification Training')
parser.add_argument('--data-path', default='/Salvador', help='dataset')
parser.add_argument('--data-path', default='/github/smart_augmentation/salvador/data', help='dataset')
parser.add_argument('--model', default='resnet18', help='model') #'resnet18'
parser.add_argument('--device', default='cuda:1', help='device')
parser.add_argument('--device', default='cuda:0', help='device')
parser.add_argument('-b', '--batch-size', default=8, type=int)
parser.add_argument('--epochs', default=3, type=int, metavar='N',
help='number of total epochs to run')
@ -364,6 +367,10 @@ def parse_args():
parser.add_argument('-a', '--augment', default='None', type=str,
metavar='N', help='Data augment',
dest='augment')
parser.add_argument('-m', '--magnitude', default=1.0, type=float,
metavar='N', help='Augmentation magnitude',
dest='magnitude')
args = parser.parse_args()

View file

@ -549,10 +549,10 @@ def parse_args():
import argparse
parser = argparse.ArgumentParser(description='PyTorch Classification Training')
parser.add_argument('--data-path', default='/Salvador', help='dataset')
parser.add_argument('--model', default='resnet50', help='model')
parser.add_argument('--device', default='cuda:1', help='device')
parser.add_argument('-b', '--batch-size', default=4, type=int)
parser.add_argument('--data-path', default='/github/smart_augmentation/salvador/data', help='dataset')
parser.add_argument('--model', default='resnet18', help='model') #'resnet18'
parser.add_argument('--device', default='cuda:0', help='device')
parser.add_argument('-b', '--batch-size', default=8, type=int)
parser.add_argument('--epochs', default=3, type=int, metavar='N',
help='number of total epochs to run')
parser.add_argument('-j', '--workers', default=0, type=int, metavar='N',

View file

@ -157,7 +157,7 @@ def accuracy(output, target, topk=(1,)):
class EarlyStopping:
"""Early stops the training if validation loss doesn't improve after a given patience."""
def __init__(self, patience=7, verbose=False, delta=0):
def __init__(self, patience=7, verbose=False, delta=0, augmented_model=False):
"""
Args:
patience (int): How long to wait after last time validation loss improved.
@ -175,6 +175,8 @@ class EarlyStopping:
self.val_loss_min = np.Inf
self.delta = delta
self.augmented_model = augmented_model
def __call__(self, val_loss, model):
score = -val_loss
@ -196,5 +198,5 @@ class EarlyStopping:
'''Saves model when validation loss decrease.'''
if self.verbose:
print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
torch.save(model.state_dict(), 'checkpoint.pt')
torch.save(model.state_dict(), 'checkpoint.pt') if not self.augmented_model else torch.save(model['model'].state_dict(), 'checkpoint.pt')
self.val_loss_min = val_loss