diff --git a/PBA/LeNet.py b/PBA/LeNet.py
new file mode 100644
index 0000000..659badf
--- /dev/null
+++ b/PBA/LeNet.py
@@ -0,0 +1,79 @@
+import numpy as np
+import tensorflow as tf
+
+## helpers for building the LeNet graph
+# weight initialization
+def weight_variable(shape, name = None):
+    initial = tf.truncated_normal(shape, stddev=0.1)
+    return tf.Variable(initial, name = name)
+
+# bias initialization
+def bias_variable(shape, name = None):
+    initial = tf.constant(0.1, shape=shape)  # positive bias
+    return tf.Variable(initial, name = name)
+
+# 2D convolution
+def conv2d(x, W, name = None):
+    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME', name = name)
+
+# max pooling
+def max_pool_2x2(x, name = None):
+    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
+                          padding='SAME', name = name)
+
+def LeNet(images, num_classes):
+    # tunable hyperparameters of the network architecture
+    s_f_conv1 = 5   # filter size of first convolution layer
+    n_f_conv1 = 20  # number of filters of first convolution layer
+    s_f_conv2 = 5   # filter size of second convolution layer
+    n_f_conv2 = 50  # number of filters of second convolution layer
+    n_n_fc1 = 500   # number of neurons of first fully connected layer
+    n_n_fc2 = 500   # number of neurons of second fully connected layer (unused; the output layer is sized by num_classes)
+
+    # 1st layer: convolution + max pooling
+    W_conv1_tf = weight_variable([s_f_conv1, s_f_conv1, 1, n_f_conv1], name = 'W_conv1_tf')  # (5, 5, 1, 20)
+    b_conv1_tf = bias_variable([n_f_conv1], name = 'b_conv1_tf')  # (20)
+    h_conv1_tf = tf.nn.relu(conv2d(images,
+                                   W_conv1_tf) + b_conv1_tf,
+                            name = 'h_conv1_tf')  # (., h, w, 20)
+    h_pool1_tf = max_pool_2x2(h_conv1_tf,
+                              name = 'h_pool1_tf')  # (., h/2, w/2, 20)
+
+    # 2nd layer: convolution + max pooling
+    W_conv2_tf = weight_variable([s_f_conv2, s_f_conv2,
+                                  n_f_conv1, n_f_conv2],
+                                 name = 'W_conv2_tf')  # (5, 5, 20, 50)
+    b_conv2_tf = bias_variable([n_f_conv2], name = 'b_conv2_tf')  # (50)
+    h_conv2_tf = tf.nn.relu(conv2d(h_pool1_tf,
+                                   W_conv2_tf) + b_conv2_tf,
+                            name ='h_conv2_tf')  # (., h/2, w/2, 50)
+    h_pool2_tf = max_pool_2x2(h_conv2_tf, name = 'h_pool2_tf')  # (., h/4, w/4, 50)
+
+    # 3rd layer: fully connected (the flatten assumes a 5x5 pooled feature map)
+    W_fc1_tf = weight_variable([5*5*n_f_conv2, n_n_fc1],
+                               name = 'W_fc1_tf')  # (5*5*50, 500)
+    b_fc1_tf = bias_variable([n_n_fc1], name = 'b_fc1_tf')  # (500)
+    h_pool2_flat_tf = tf.reshape(h_pool2_tf, [-1, 5*5*n_f_conv2],
+                                 name = 'h_pool2_flat_tf')  # (., 5*5*50)
+    h_fc1_tf = tf.nn.relu(tf.matmul(h_pool2_flat_tf,
+                                    W_fc1_tf) + b_fc1_tf,
+                          name = 'h_fc1_tf')  # (., 500)
+
+    # add dropout
+    #keep_prob_tf = tf.placeholder(dtype=tf.float32, name = 'keep_prob_tf')
+    #h_fc1_drop_tf = tf.nn.dropout(h_fc1_tf, keep_prob_tf, name = 'h_fc1_drop_tf')
+
+    # 4th layer: fully connected (output)
+    W_fc2_tf = weight_variable([n_n_fc1, num_classes], name = 'W_fc2_tf')
+    b_fc2_tf = bias_variable([num_classes], name = 'b_fc2_tf')
+    z_pred_tf = tf.add(tf.matmul(h_fc1_tf, W_fc2_tf),
+                       b_fc2_tf, name = 'z_pred_tf')  # => (., num_classes)
+    # predicted probabilities in one-hot encoding
+    #y_pred_proba_tf = tf.nn.softmax(z_pred_tf, name='y_pred_proba_tf')
+
+    # tensor of correct predictions
+    #y_pred_correct_tf = tf.equal(tf.argmax(y_pred_proba_tf, 1),
+    #                             tf.argmax(y_data_tf, 1),
+    #                             name = 'y_pred_correct_tf')
+    logits = z_pred_tf
+    return logits  #y_pred_proba_tf
\ No newline at end of file
diff --git a/PBA/model.py b/PBA/model.py
new file mode 100755
index 0000000..47a0aa9
--- /dev/null
+++ b/PBA/model.py
@@ -0,0 +1,353 @@
+# Copyright 2018 The TensorFlow Authors All Rights
Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""PBA & AutoAugment Train/Eval module. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib +import os +import time + +import numpy as np +import tensorflow as tf + +import autoaugment.custom_ops as ops +from autoaugment.shake_drop import build_shake_drop_model +from autoaugment.shake_shake import build_shake_shake_model +import pba.data_utils as data_utils +import pba.helper_utils as helper_utils +from pba.wrn import build_wrn_model +from pba.resnet import build_resnet_model + +from pba.LeNet import LeNet + +arg_scope = tf.contrib.framework.arg_scope + + +def setup_arg_scopes(is_training): + """Sets up the argscopes that will be used when building an image model. + + Args: + is_training: Is the model training or not. + + Returns: + Arg scopes to be put around the model being constructed. + """ + + batch_norm_decay = 0.9 + batch_norm_epsilon = 1e-5 + batch_norm_params = { + # Decay for the moving averages. + 'decay': batch_norm_decay, + # epsilon to prevent 0s in variance. + 'epsilon': batch_norm_epsilon, + 'scale': True, + # collection containing the moving mean and moving variance. + 'is_training': is_training, + } + + scopes = [] + + scopes.append(arg_scope([ops.batch_norm], **batch_norm_params)) + return scopes + + +def build_model(inputs, num_classes, is_training, hparams): + """Constructs the vision model being trained/evaled. + + Args: + inputs: input features/images being fed to the image model build built. + num_classes: number of output classes being predicted. + is_training: is the model training or not. + hparams: additional hyperparameters associated with the image model. + + Returns: + The logits of the image model. 
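+
+    Raises:
+        ValueError: if hparams.model_name does not match one of the supported
+            architectures.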
+ """ + scopes = setup_arg_scopes(is_training) + if len(scopes) != 1: + raise ValueError('Nested scopes depreciated in py3.') + with scopes[0]: + if hparams.model_name == 'pyramid_net': + logits = build_shake_drop_model(inputs, num_classes, is_training) + elif hparams.model_name == 'wrn': + logits = build_wrn_model(inputs, num_classes, hparams.wrn_size) + elif hparams.model_name == 'shake_shake': + logits = build_shake_shake_model(inputs, num_classes, hparams, + is_training) + elif hparams.model_name == 'resnet': + logits = build_resnet_model(inputs, num_classes, hparams, + is_training) + elif hparams.model_name == 'LeNet': + logits = LeNet(inputs, num_classes) + else: + raise ValueError("Unknown model name.") + return logits + + +class Model(object): + """Builds an model.""" + + def __init__(self, hparams, num_classes, image_size): + self.hparams = hparams + self.num_classes = num_classes + self.image_size = image_size + + def build(self, mode): + """Construct the model.""" + assert mode in ['train', 'eval'] + self.mode = mode + self._setup_misc(mode) + self._setup_images_and_labels(self.hparams.dataset) + self._build_graph(self.images, self.labels, mode) + + self.init = tf.group(tf.global_variables_initializer(), + tf.local_variables_initializer()) + + def _setup_misc(self, mode): + """Sets up miscellaneous in the model constructor.""" + self.lr_rate_ph = tf.Variable(0.0, name='lrn_rate', trainable=False) + self.reuse = None if (mode == 'train') else True + self.batch_size = self.hparams.batch_size + if mode == 'eval': + self.batch_size = self.hparams.test_batch_size + + def _setup_images_and_labels(self, dataset): + """Sets up image and label placeholders for the model.""" + if dataset == 'cifar10' or dataset == 'cifar100' or self.mode == 'train': + self.images = tf.placeholder(tf.float32, + [self.batch_size, self.image_size, self.image_size, 3]) + self.labels = tf.placeholder(tf.float32, + [self.batch_size, self.num_classes]) + else: + self.images = tf.placeholder(tf.float32, [None, self.image_size, self.image_size, 3]) + self.labels = tf.placeholder(tf.float32, [None, self.num_classes]) + + def assign_epoch(self, session, epoch_value): + session.run( + self._epoch_update, feed_dict={self._new_epoch: epoch_value}) + + def _build_graph(self, images, labels, mode): + """Constructs the TF graph for the model. + + Args: + images: A 4-D image Tensor + labels: A 2-D labels Tensor. + mode: string indicating training mode ( e.g., 'train', 'valid', 'test'). + """ + is_training = 'train' in mode + if is_training: + self.global_step = tf.train.get_or_create_global_step() + + logits = build_model(images, self.num_classes, is_training, + self.hparams) + self.predictions, self.cost = helper_utils.setup_loss(logits, labels) + + self._calc_num_trainable_params() + + # Adds L2 weight decay to the cost + self.cost = helper_utils.decay_weights(self.cost, + self.hparams.weight_decay_rate) + + if is_training: + self._build_train_op() + + # Setup checkpointing for this child model + # Keep 2 or more checkpoints around during training. 
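+        # Save/restore ops are pinned to the CPU so that checkpointing never
+        # competes with the model for GPU memory.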
+ with tf.device('/cpu:0'): + self.saver = tf.train.Saver(max_to_keep=10) + + self.init = tf.group(tf.global_variables_initializer(), + tf.local_variables_initializer()) + + def _calc_num_trainable_params(self): + self.num_trainable_params = np.sum([ + np.prod(var.get_shape().as_list()) + for var in tf.trainable_variables() + ]) + tf.logging.info('number of trainable params: {}'.format( + self.num_trainable_params)) + + def _build_train_op(self): + """Builds the train op for the model.""" + hparams = self.hparams + tvars = tf.trainable_variables() + grads = tf.gradients(self.cost, tvars) + if hparams.gradient_clipping_by_global_norm > 0.0: + grads, norm = tf.clip_by_global_norm( + grads, hparams.gradient_clipping_by_global_norm) + tf.summary.scalar('grad_norm', norm) + + # Setup the initial learning rate + initial_lr = self.lr_rate_ph + optimizer = tf.train.MomentumOptimizer( + initial_lr, 0.9, use_nesterov=True) + + self.optimizer = optimizer + apply_op = optimizer.apply_gradients( + zip(grads, tvars), global_step=self.global_step, name='train_step') + train_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) + with tf.control_dependencies([apply_op]): + self.train_op = tf.group(*train_ops) + + +class ModelTrainer(object): + """Trains an instance of the Model class.""" + + def __init__(self, hparams): + self._session = None + self.hparams = hparams + + # Set the random seed to be sure the same validation set + # is used for each model + np.random.seed(0) + self.data_loader = data_utils.DataSet(hparams) + np.random.seed() # Put the random seed back to random + self.data_loader.reset() + + # extra stuff for ray + self._build_models() + self._new_session() + self._session.__enter__() + + def save_model(self, checkpoint_dir, step=None): + """Dumps model into the backup_dir. + + Args: + step: If provided, creates a checkpoint with the given step + number, instead of overwriting the existing checkpoints. + """ + model_save_name = os.path.join(checkpoint_dir, + 'model.ckpt') + '-' + str(step) + save_path = self.saver.save(self.session, model_save_name) + tf.logging.info('Saved child model') + return model_save_name + + def extract_model_spec(self, checkpoint_path): + """Loads a checkpoint with the architecture structure stored in the name.""" + self.saver.restore(self.session, checkpoint_path) + tf.logging.warning( + 'Loaded child model checkpoint from {}'.format(checkpoint_path)) + + def eval_child_model(self, model, data_loader, mode): + """Evaluate the child model. + + Args: + model: image model that will be evaluated. + data_loader: dataset object to extract eval data from. + mode: will the model be evalled on train, val or test. + + Returns: + Accuracy of the model on the specified dataset. + """ + tf.logging.info('Evaluating child model in mode {}'.format(mode)) + while True: + try: + accuracy = helper_utils.eval_child_model( + self.session, model, data_loader, mode) + tf.logging.info( + 'Eval child model accuracy: {}'.format(accuracy)) + # If epoch trained without raising the below errors, break + # from loop. + break + except (tf.errors.AbortedError, tf.errors.UnavailableError) as e: + tf.logging.info( + 'Retryable error caught: {}. Retrying.'.format(e)) + + return accuracy + + @contextlib.contextmanager + def _new_session(self): + """Creates a new session for model m.""" + # Create a new session for this model, initialize + # variables, and save / restore from checkpoint. 
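+        # allow_soft_placement falls back to the CPU for ops without a GPU
+        # kernel, and allow_growth claims GPU memory on demand rather than
+        # reserving it all up front.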
+ sess_cfg = tf.ConfigProto( + allow_soft_placement=True, log_device_placement=False) + sess_cfg.gpu_options.allow_growth = True + self._session = tf.Session('', config=sess_cfg) + self._session.run([self.m.init, self.meval.init]) + return self._session + + def _build_models(self): + """Builds the image models for train and eval.""" + # Determine if we should build the train and eval model. When using + # distributed training we only want to build one or the other and not both. + with tf.variable_scope('model', use_resource=False): + m = Model(self.hparams, self.data_loader.num_classes, self.data_loader.image_size) + m.build('train') + self._num_trainable_params = m.num_trainable_params + self._saver = m.saver + with tf.variable_scope('model', reuse=True, use_resource=False): + meval = Model(self.hparams, self.data_loader.num_classes, self.data_loader.image_size) + meval.build('eval') + self.m = m + self.meval = meval + + def _run_training_loop(self, curr_epoch): + """Trains the model `m` for one epoch.""" + start_time = time.time() + while True: + try: + train_accuracy = helper_utils.run_epoch_training( + self.session, self.m, self.data_loader, curr_epoch) + break + except (tf.errors.AbortedError, tf.errors.UnavailableError) as e: + tf.logging.info( + 'Retryable error caught: {}. Retrying.'.format(e)) + tf.logging.info('Finished epoch: {}'.format(curr_epoch)) + tf.logging.info('Epoch time(min): {}'.format( + (time.time() - start_time) / 60.0)) + return train_accuracy + + def _compute_final_accuracies(self, iteration): + """Run once training is finished to compute final test accuracy.""" + if (iteration >= self.hparams.num_epochs - 1): + test_accuracy = self.eval_child_model(self.meval, self.data_loader, + 'test') + else: + test_accuracy = 0 + tf.logging.info('Test Accuracy: {}'.format(test_accuracy)) + return test_accuracy + + def run_model(self, epoch): + """Trains and evalutes the image model.""" + valid_accuracy = 0. 
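+        # Train for one epoch, then evaluate on the held-out validation split
+        # if one was configured; the test set is only touched in
+        # _compute_final_accuracies.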
+ training_accuracy = self._run_training_loop(epoch) + if self.hparams.validation_size > 0: + valid_accuracy = self.eval_child_model(self.meval, + self.data_loader, 'val') + tf.logging.info('Train Acc: {}, Valid Acc: {}'.format( + training_accuracy, valid_accuracy)) + return training_accuracy, valid_accuracy + + def reset_config(self, new_hparams): + self.hparams = new_hparams + self.data_loader.reset_policy(new_hparams) + return + + @property + def saver(self): + return self._saver + + @property + def session(self): + return self._session + + @property + def num_trainable_params(self): + return self._num_trainable_params diff --git a/PBA/setup.py b/PBA/setup.py new file mode 100755 index 0000000..cc9b38b --- /dev/null +++ b/PBA/setup.py @@ -0,0 +1,210 @@ +"""Parse flags and set up hyperparameters.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import random +import tensorflow as tf + +from pba.augmentation_transforms_hp import NUM_HP_TRANSFORM + + +def create_parser(state): + """Create arg parser for flags.""" + parser = argparse.ArgumentParser() + parser.add_argument( + '--model_name', + default='wrn', + choices=('wrn_28_10', 'wrn_40_2', 'shake_shake_32', 'shake_shake_96', + 'shake_shake_112', 'pyramid_net', 'resnet', 'LeNet')) + parser.add_argument( + '--data_path', + default='/tmp/datasets/', + help='Directory where dataset is located.') + parser.add_argument( + '--dataset', + default='cifar10', + choices=('cifar10', 'cifar100', 'svhn', 'svhn-full', 'test')) + parser.add_argument( + '--recompute_dset_stats', + action='store_true', + help='Instead of using hardcoded mean/std, recompute from dataset.') + parser.add_argument('--local_dir', type=str, default='/tmp/ray_results/', help='Ray directory.') + parser.add_argument('--restore', type=str, default=None, help='If specified, tries to restore from given path.') + parser.add_argument('--train_size', type=int, default=5000, help='Number of training examples.') + parser.add_argument('--val_size', type=int, default=45000, help='Number of validation examples.') + parser.add_argument('--checkpoint_freq', type=int, default=50, help='Checkpoint frequency.') + parser.add_argument( + '--cpu', type=float, default=4, help='Allocated by Ray') + parser.add_argument( + '--gpu', type=float, default=1, help='Allocated by Ray') + parser.add_argument( + '--aug_policy', + type=str, + default='cifar10', + help= + 'which augmentation policy to use (in augmentation_transforms_hp.py)') + # search-use only + parser.add_argument( + '--explore', + type=str, + default='cifar10', + help='which explore function to use') + parser.add_argument( + '--epochs', + type=int, + default=0, + help='Number of epochs, or <=0 for default') + parser.add_argument( + '--no_cutout', action='store_true', help='turn off cutout') + parser.add_argument('--lr', type=float, default=0.1, help='learning rate') + parser.add_argument('--wd', type=float, default=0.0005, help='weight decay') + parser.add_argument('--bs', type=int, default=128, help='batch size') + parser.add_argument('--test_bs', type=int, default=25, help='test batch size') + parser.add_argument('--num_samples', type=int, default=1, help='Number of Ray samples') + + if state == 'train': + parser.add_argument( + '--use_hp_policy', + action='store_true', + help='otherwise use autoaug policy') + parser.add_argument( + '--hp_policy', + type=str, + default=None, + help='either a comma separated list of values or a file') + parser.add_argument( 
+ '--hp_policy_epochs', + type=int, + default=200, + help='number of epochs/iterations policy trained for') + parser.add_argument( + '--no_aug', + action='store_true', + help= + 'no additional augmentation at all (besides cutout if not toggled)' + ) + parser.add_argument( + '--flatten', + action='store_true', + help='randomly select aug policy from schedule') + parser.add_argument('--name', type=str, default='autoaug') + + elif state == 'search': + parser.add_argument('--perturbation_interval', type=int, default=10) + parser.add_argument('--name', type=str, default='autoaug_pbt') + else: + raise ValueError('unknown state') + args = parser.parse_args() + tf.logging.info(str(args)) + return args + + +def create_hparams(state, FLAGS): # pylint: disable=invalid-name + """Creates hyperparameters to pass into Ray config. + + Different options depending on search or eval mode. + + Args: + state: a string, 'train' or 'search'. + FLAGS: parsed command line flags. + + Returns: + tf.hparams object. + """ + epochs = 0 + tf.logging.info('data path: {}'.format(FLAGS.data_path)) + hparams = tf.contrib.training.HParams( + train_size=FLAGS.train_size, + validation_size=FLAGS.val_size, + dataset=FLAGS.dataset, + data_path=FLAGS.data_path, + batch_size=FLAGS.bs, + gradient_clipping_by_global_norm=5.0, + explore=FLAGS.explore, + aug_policy=FLAGS.aug_policy, + no_cutout=FLAGS.no_cutout, + recompute_dset_stats=FLAGS.recompute_dset_stats, + lr=FLAGS.lr, + weight_decay_rate=FLAGS.wd, + test_batch_size=FLAGS.test_bs) + + if state == 'train': + hparams.add_hparam('no_aug', FLAGS.no_aug) + hparams.add_hparam('use_hp_policy', FLAGS.use_hp_policy) + if FLAGS.use_hp_policy: + if FLAGS.hp_policy == 'random': + tf.logging.info('RANDOM SEARCH') + parsed_policy = [] + for i in range(NUM_HP_TRANSFORM * 4): + if i % 2 == 0: + parsed_policy.append(random.randint(0, 10)) + else: + parsed_policy.append(random.randint(0, 9)) + elif FLAGS.hp_policy.endswith('.txt') or FLAGS.hp_policy.endswith( + '.p'): + # will be loaded in in data_utils + parsed_policy = FLAGS.hp_policy + else: + # parse input into a fixed augmentation policy + parsed_policy = FLAGS.hp_policy.split(', ') + parsed_policy = [int(p) for p in parsed_policy] + hparams.add_hparam('hp_policy', parsed_policy) + hparams.add_hparam('hp_policy_epochs', FLAGS.hp_policy_epochs) + hparams.add_hparam('flatten', FLAGS.flatten) + elif state == 'search': + hparams.add_hparam('no_aug', False) + hparams.add_hparam('use_hp_policy', True) + # default start value of 0 + hparams.add_hparam('hp_policy', + [0 for _ in range(4 * NUM_HP_TRANSFORM)]) + else: + raise ValueError('unknown state') + + if FLAGS.model_name == 'wrn_40_2': + hparams.add_hparam('model_name', 'wrn') + epochs = 200 + hparams.add_hparam('wrn_size', 32) + hparams.add_hparam('wrn_depth', 40) + elif FLAGS.model_name == 'wrn_28_10': + hparams.add_hparam('model_name', 'wrn') + epochs = 200 + hparams.add_hparam('wrn_size', 160) + hparams.add_hparam('wrn_depth', 28) + elif FLAGS.model_name == 'resnet': + hparams.add_hparam('model_name', 'resnet') + epochs = 200 + hparams.add_hparam('resnet_size', 20) + hparams.add_hparam('num_filters', 32) + elif FLAGS.model_name == 'shake_shake_32': + hparams.add_hparam('model_name', 'shake_shake') + epochs = 1800 + hparams.add_hparam('shake_shake_widen_factor', 2) + elif FLAGS.model_name == 'shake_shake_96': + hparams.add_hparam('model_name', 'shake_shake') + epochs = 1800 + hparams.add_hparam('shake_shake_widen_factor', 6) + elif FLAGS.model_name == 'shake_shake_112': + 
        hparams.add_hparam('model_name', 'shake_shake')
+        epochs = 1800
+        hparams.add_hparam('shake_shake_widen_factor', 7)
+    elif FLAGS.model_name == 'pyramid_net':
+        hparams.add_hparam('model_name', 'pyramid_net')
+        epochs = 1800
+        hparams.set_hparam('batch_size', 64)
+
+    elif FLAGS.model_name == 'LeNet':
+        hparams.add_hparam('model_name', 'LeNet')
+        epochs = 200
+
+    else:
+        raise ValueError('Invalid model name: %s' % FLAGS.model_name)
+    if FLAGS.epochs > 0:
+        tf.logging.info('overriding epochs with value from --epochs flag')
+        epochs = FLAGS.epochs
+    hparams.add_hparam('num_epochs', epochs)
+    tf.logging.info('epochs: {}, lr: {}, wd: {}'.format(
+        hparams.num_epochs, hparams.lr, hparams.weight_decay_rate))
+    return hparams
diff --git a/PBA/table_1_cifar10.sh b/PBA/table_1_cifar10.sh
new file mode 100755
index 0000000..4d35bd6
--- /dev/null
+++ b/PBA/table_1_cifar10.sh
@@ -0,0 +1,41 @@
+#!/bin/bash
+export PYTHONPATH="$(pwd)"
+
+# args: [model name] [learning rate] [weight decay]
+eval_cifar10() {
+    hp_policy="$PWD/schedules/rcifar10_16_wrn.txt"
+    local_dir="$PWD/results/"
+    data_path="$PWD/datasets/cifar-10-batches-py"
+
+    size=50000
+    dataset="cifar10"
+    name="eval_cifar10_$1"  # has 8 cutout size
+
+    python pba/train.py \
+        --local_dir "$local_dir" --data_path "$data_path" \
+        --model_name "$1" --dataset "$dataset" \
+        --train_size "$size" --val_size 0 \
+        --checkpoint_freq 25 --gpu 1 --cpu 4 \
+        --use_hp_policy --hp_policy "$hp_policy" \
+        --hp_policy_epochs 200 \
+        --aug_policy cifar10 --name "$name" \
+        --lr "$2" --wd "$3"
+}
+
+if [ "$1" = "wrn_28_10" ]; then
+    eval_cifar10 wrn_28_10 0.1 0.0005
+elif [ "$1" = "ss_32" ]; then
+    eval_cifar10 shake_shake_32 0.01 0.001
+elif [ "$1" = "ss_96" ]; then
+    eval_cifar10 shake_shake_96 0.01 0.001
+elif [ "$1" = "ss_112" ]; then
+    eval_cifar10 shake_shake_112 0.01 0.001
+elif [ "$1" = "pyramid_net" ]; then
+    eval_cifar10 pyramid_net 0.05 0.00005
+
+elif [ "$1" = "LeNet" ]; then
+    eval_cifar10 LeNet 0.05 0.0
+
+else
+    echo "invalid args: expected one of wrn_28_10, ss_32, ss_96, ss_112, pyramid_net, LeNet"
+fi
diff --git a/higher/test_dataug.py b/higher/test_dataug.py
index de65d64..bd0c023 100644
--- a/higher/test_dataug.py
+++ b/higher/test_dataug.py
@@ -91,7 +91,7 @@ if __name__ == "__main__":
     #'''
     res_folder="res/TF_nb_tests/"
     epochs= 100
-    inner_its = [0, 10]
+    inner_its = [0, 1, 10]
     dataug_epoch_starts= [0]
     TF_nb = [len(TF.TF_dict)] #range(10,len(TF.TF_dict)+1) #[len(TF.TF_dict)]
     N_seq_TF= [2, 3, 4, 6]
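A minimal standalone sketch of the new LeNet graph, useful for checking the logits shape outside the training pipeline. It assumes 20x20 single-channel inputs (the hard-coded filter depth of 1 and the 5*5*n_f_conv2 flatten in LeNet.py require the two SAME-padded 2x2 max-pools to leave a 5x5 map), a batch of 4, and 10 classes; all of these values are illustrative, not fixed by this diff.

import numpy as np
import tensorflow as tf

from pba.LeNet import LeNet

# 20x20x1 inputs shrink to 10x10 and then 5x5 across the two max-pools,
# which lines up with the reshape to [-1, 5*5*50] inside LeNet().
images = tf.placeholder(tf.float32, [None, 20, 20, 1], name='images')
logits = LeNet(images, num_classes=10)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(logits, {images: np.zeros((4, 20, 20, 1), np.float32)})
    print(out.shape)  # expected: (4, 10)

Within the pipeline itself the model is selected through --model_name LeNet (wired up in setup.py and build_model above); the new branch in table_1_cifar10.sh runs it end to end via bash table_1_cifar10.sh LeNet, passing the 0.05 learning rate and 0.0 weight decay defined there.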