"""Parse flags and set up hyperparameters.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import argparse import random import tensorflow as tf from pba.augmentation_transforms_hp import NUM_HP_TRANSFORM def create_parser(state): """Create arg parser for flags.""" parser = argparse.ArgumentParser() parser.add_argument( '--model_name', default='wrn', choices=('wrn_28_10', 'wrn_40_2', 'shake_shake_32', 'shake_shake_96', 'shake_shake_112', 'pyramid_net', 'resnet', 'LeNet')) parser.add_argument( '--data_path', default='/tmp/datasets/', help='Directory where dataset is located.') parser.add_argument( '--dataset', default='cifar10', choices=('cifar10', 'cifar100', 'svhn', 'svhn-full', 'test')) parser.add_argument( '--recompute_dset_stats', action='store_true', help='Instead of using hardcoded mean/std, recompute from dataset.') parser.add_argument('--local_dir', type=str, default='/tmp/ray_results/', help='Ray directory.') parser.add_argument('--restore', type=str, default=None, help='If specified, tries to restore from given path.') parser.add_argument('--train_size', type=int, default=5000, help='Number of training examples.') parser.add_argument('--val_size', type=int, default=45000, help='Number of validation examples.') parser.add_argument('--checkpoint_freq', type=int, default=50, help='Checkpoint frequency.') parser.add_argument( '--cpu', type=float, default=4, help='Allocated by Ray') parser.add_argument( '--gpu', type=float, default=1, help='Allocated by Ray') parser.add_argument( '--aug_policy', type=str, default='cifar10', help= 'which augmentation policy to use (in augmentation_transforms_hp.py)') # search-use only parser.add_argument( '--explore', type=str, default='cifar10', help='which explore function to use') parser.add_argument( '--epochs', type=int, default=0, help='Number of epochs, or <=0 for default') parser.add_argument( '--no_cutout', action='store_true', help='turn off cutout') parser.add_argument('--lr', type=float, default=0.1, help='learning rate') parser.add_argument('--wd', type=float, default=0.0005, help='weight decay') parser.add_argument('--bs', type=int, default=128, help='batch size') parser.add_argument('--test_bs', type=int, default=25, help='test batch size') parser.add_argument('--num_samples', type=int, default=1, help='Number of Ray samples') if state == 'train': parser.add_argument( '--use_hp_policy', action='store_true', help='otherwise use autoaug policy') parser.add_argument( '--hp_policy', type=str, default=None, help='either a comma separated list of values or a file') parser.add_argument( '--hp_policy_epochs', type=int, default=200, help='number of epochs/iterations policy trained for') parser.add_argument( '--no_aug', action='store_true', help= 'no additional augmentation at all (besides cutout if not toggled)' ) parser.add_argument( '--flatten', action='store_true', help='randomly select aug policy from schedule') parser.add_argument('--name', type=str, default='autoaug') elif state == 'search': parser.add_argument('--perturbation_interval', type=int, default=10) parser.add_argument('--name', type=str, default='autoaug_pbt') else: raise ValueError('unknown state') args = parser.parse_args() tf.logging.info(str(args)) return args def create_hparams(state, FLAGS): # pylint: disable=invalid-name """Creates hyperparameters to pass into Ray config. Different options depending on search or eval mode. Args: state: a string, 'train' or 'search'. FLAGS: parsed command line flags. 
    Returns:
        tf.contrib.training.HParams object.
    """
    epochs = 0
    tf.logging.info('data path: {}'.format(FLAGS.data_path))
    hparams = tf.contrib.training.HParams(
        train_size=FLAGS.train_size,
        validation_size=FLAGS.val_size,
        dataset=FLAGS.dataset,
        data_path=FLAGS.data_path,
        batch_size=FLAGS.bs,
        gradient_clipping_by_global_norm=5.0,
        explore=FLAGS.explore,
        aug_policy=FLAGS.aug_policy,
        no_cutout=FLAGS.no_cutout,
        recompute_dset_stats=FLAGS.recompute_dset_stats,
        lr=FLAGS.lr,
        weight_decay_rate=FLAGS.wd,
        test_batch_size=FLAGS.test_bs)

    if state == 'train':
        hparams.add_hparam('no_aug', FLAGS.no_aug)
        hparams.add_hparam('use_hp_policy', FLAGS.use_hp_policy)
        if FLAGS.use_hp_policy:
            if FLAGS.hp_policy == 'random':
                tf.logging.info('RANDOM SEARCH')
                parsed_policy = []
                for i in range(NUM_HP_TRANSFORM * 4):
                    if i % 2 == 0:
                        parsed_policy.append(random.randint(0, 10))
                    else:
                        parsed_policy.append(random.randint(0, 9))
            elif FLAGS.hp_policy.endswith('.txt') or FLAGS.hp_policy.endswith(
                    '.p'):
                # Will be loaded in data_utils.
                parsed_policy = FLAGS.hp_policy
            else:
                # Parse input into a fixed augmentation policy.
                parsed_policy = [int(p) for p in FLAGS.hp_policy.split(',')]
            hparams.add_hparam('hp_policy', parsed_policy)
            hparams.add_hparam('hp_policy_epochs', FLAGS.hp_policy_epochs)
            hparams.add_hparam('flatten', FLAGS.flatten)
    elif state == 'search':
        hparams.add_hparam('no_aug', False)
        hparams.add_hparam('use_hp_policy', True)
        # Default start value of 0 for every hyperparameter.
        hparams.add_hparam('hp_policy',
                           [0 for _ in range(4 * NUM_HP_TRANSFORM)])
    else:
        raise ValueError('unknown state')

    if FLAGS.model_name == 'wrn_40_2':
        hparams.add_hparam('model_name', 'wrn')
        epochs = 200
        hparams.add_hparam('wrn_size', 32)
        hparams.add_hparam('wrn_depth', 40)
    elif FLAGS.model_name == 'wrn_28_10':
        hparams.add_hparam('model_name', 'wrn')
        epochs = 200
        hparams.add_hparam('wrn_size', 160)
        hparams.add_hparam('wrn_depth', 28)
    elif FLAGS.model_name == 'resnet':
        hparams.add_hparam('model_name', 'resnet')
        epochs = 200
        hparams.add_hparam('resnet_size', 20)
        hparams.add_hparam('num_filters', 32)
    elif FLAGS.model_name == 'shake_shake_32':
        hparams.add_hparam('model_name', 'shake_shake')
        epochs = 1800
        hparams.add_hparam('shake_shake_widen_factor', 2)
    elif FLAGS.model_name == 'shake_shake_96':
        hparams.add_hparam('model_name', 'shake_shake')
        epochs = 1800
        hparams.add_hparam('shake_shake_widen_factor', 6)
    elif FLAGS.model_name == 'shake_shake_112':
        hparams.add_hparam('model_name', 'shake_shake')
        epochs = 1800
        hparams.add_hparam('shake_shake_widen_factor', 7)
    elif FLAGS.model_name == 'pyramid_net':
        hparams.add_hparam('model_name', 'pyramid_net')
        epochs = 1800
        hparams.set_hparam('batch_size', 64)
    elif FLAGS.model_name == 'LeNet':
        hparams.add_hparam('model_name', 'LeNet')
        epochs = 200
    else:
        raise ValueError('Invalid model name: %s' % FLAGS.model_name)
    if FLAGS.epochs > 0:
        tf.logging.info('overwriting with custom epochs')
        epochs = FLAGS.epochs
    hparams.add_hparam('num_epochs', epochs)
    tf.logging.info('epochs: {}, lr: {}, wd: {}'.format(
        hparams.num_epochs, hparams.lr, hparams.weight_decay_rate))
    return hparams
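

# Example usage sketch: how create_parser and create_hparams might be wired
# together when launching a run. The choice of the 'train' state and logging
# the resulting HParams as JSON are illustrative assumptions; in practice
# these helpers are presumably invoked from the train/search entry scripts.
if __name__ == '__main__':
    FLAGS = create_parser('train')            # parses sys.argv into flags
    hparams = create_hparams('train', FLAGS)  # builds the HParams for the run
    tf.logging.info(hparams.to_json())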