mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Brutus
This commit is contained in:
parent
53bd421670
commit
e291bc2e44
9 changed files with 55 additions and 44 deletions
|
@ -31,12 +31,12 @@ tf_names = [
|
|||
'ShearY',
|
||||
|
||||
## Color TF (Expect image in the range of [0, 1]) ##
|
||||
'Contrast',
|
||||
'Color',
|
||||
'Brightness',
|
||||
'Sharpness',
|
||||
'Posterize',
|
||||
'Solarize', #=>Image entre [0,1] #Pas opti pour des batch
|
||||
#'Contrast',
|
||||
#'Color',
|
||||
#'Brightness',
|
||||
#'Sharpness',
|
||||
#'Posterize',
|
||||
#'Solarize', #=>Image entre [0,1] #Pas opti pour des batch
|
||||
]
|
||||
|
||||
class Lambda(nn.Module):
|
||||
|
@ -95,6 +95,7 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, mas
|
|||
|
||||
unsupp_coeff = 1
|
||||
loss = sup_loss + (aug_loss + KL_loss) * unsupp_coeff
|
||||
#print(sup_loss.item(), (aug_loss + KL_loss).item())
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
|
@ -210,7 +211,7 @@ def get_train_valid_loader(args, augment, random_seed, valid_size=0.1, shuffle=T
|
|||
split = int(np.floor(valid_size * num_train))
|
||||
|
||||
if shuffle:
|
||||
#np.random.seed(random_seed)
|
||||
np.random.seed(random_seed)
|
||||
np.random.shuffle(indices)
|
||||
|
||||
train_idx, valid_idx = indices[split:], indices[:split]
|
||||
|
@ -277,6 +278,8 @@ def main(args):
|
|||
model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)
|
||||
|
||||
if args.augment=='RandKL': Kldiv=True
|
||||
|
||||
model['data_aug']['mag'].data = model['data_aug']['mag'].data * args.magnitude
|
||||
print("Augmodel")
|
||||
|
||||
# model.fc = nn.Linear(model.fc.in_features, 2)
|
||||
|
@ -294,7 +297,7 @@ def main(args):
|
|||
optimizer,
|
||||
lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9)
|
||||
|
||||
es = utils.EarlyStopping()
|
||||
es = utils.EarlyStopping() if not (args.augment=='Rand' or args.augment=='RandKL') else utils.EarlyStopping(augmented_model=True)
|
||||
|
||||
if args.test_only:
|
||||
model.load_state_dict(torch.load('checkpoint.pt', map_location=lambda storage, loc: storage))
|
||||
|
@ -324,8 +327,8 @@ def main(args):
|
|||
|
||||
# print('Train')
|
||||
# print(train_confmat)
|
||||
print('Valid')
|
||||
print(valid_confmat)
|
||||
#print('Valid')
|
||||
#print(valid_confmat)
|
||||
|
||||
# if es.early_stop:
|
||||
# break
|
||||
|
@ -339,9 +342,9 @@ def parse_args():
|
|||
import argparse
|
||||
parser = argparse.ArgumentParser(description='PyTorch Classification Training')
|
||||
|
||||
parser.add_argument('--data-path', default='/Salvador', help='dataset')
|
||||
parser.add_argument('--data-path', default='/github/smart_augmentation/salvador/data', help='dataset')
|
||||
parser.add_argument('--model', default='resnet18', help='model') #'resnet18'
|
||||
parser.add_argument('--device', default='cuda:1', help='device')
|
||||
parser.add_argument('--device', default='cuda:0', help='device')
|
||||
parser.add_argument('-b', '--batch-size', default=8, type=int)
|
||||
parser.add_argument('--epochs', default=3, type=int, metavar='N',
|
||||
help='number of total epochs to run')
|
||||
|
@ -364,6 +367,10 @@ def parse_args():
|
|||
parser.add_argument('-a', '--augment', default='None', type=str,
|
||||
metavar='N', help='Data augment',
|
||||
dest='augment')
|
||||
parser.add_argument('-m', '--magnitude', default=1.0, type=float,
|
||||
metavar='N', help='Augmentation magnitude',
|
||||
dest='magnitude')
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue