mirror of
https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Tidying up
parent ca3367d19f
commit 4166922c34
453 changed files with 9797 additions and 7 deletions
Old/PBA/LeNet.py (new executable file)
@@ -0,0 +1,73 @@
import numpy as np
import tensorflow as tf


## build the neural network
# weight initialization
def weight_variable(shape, name=None):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial, name=name)


# bias initialization
def bias_variable(shape, name=None):
    initial = tf.constant(0.1, shape=shape)  # positive bias
    return tf.Variable(initial, name=name)


# 2D convolution
def conv2d(x, W, name=None):
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='VALID', name=name)


# max pooling
def max_pool_2x2(x, name=None):
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
                          padding='SAME', name=name)


def LeNet(images, num_classes):
    # tunable hyperparameters for the network architecture
    s_f_conv1 = 5  # filter size of first convolution layer (default = 3)
    n_f_conv1 = 20  # number of features of first convolution layer (default = 36)
    s_f_conv2 = 5  # filter size of second convolution layer (default = 3)
    n_f_conv2 = 50  # number of features of second convolution layer (default = 36)
    n_n_fc1 = 500  # number of neurons of first fully connected layer (default = 576)
    n_n_fc2 = 500  # number of neurons of second fully connected layer (default = 576)

    #print(images.shape)
    # 1st layer: convolution + max pooling
    W_conv1_tf = weight_variable([s_f_conv1, s_f_conv1, int(images.shape[3]), n_f_conv1], name='W_conv1_tf')  # (5,5,1,32)
    b_conv1_tf = bias_variable([n_f_conv1], name='b_conv1_tf')  # (32)
    h_conv1_tf = tf.nn.relu(conv2d(images, W_conv1_tf) + b_conv1_tf, name='h_conv1_tf')  # (.,28,28,32)
    h_pool1_tf = max_pool_2x2(h_conv1_tf, name='h_pool1_tf')  # (.,14,14,32)
    #print(h_conv1_tf.shape)
    #print(h_pool1_tf.shape)

    # 2nd layer: convolution + max pooling
    W_conv2_tf = weight_variable([s_f_conv2, s_f_conv2, n_f_conv1, n_f_conv2], name='W_conv2_tf')
    b_conv2_tf = bias_variable([n_f_conv2], name='b_conv2_tf')
    h_conv2_tf = tf.nn.relu(conv2d(h_pool1_tf, W_conv2_tf) + b_conv2_tf, name='h_conv2_tf')  # (.,14,14,32)
    h_pool2_tf = max_pool_2x2(h_conv2_tf, name='h_pool2_tf')  # (.,7,7,32)

    #print(h_pool2_tf.shape)

    # 3rd layer: fully connected
    W_fc1_tf = weight_variable([5 * 5 * n_f_conv2, n_n_fc1], name='W_fc1_tf')  # (4*4*32, 1024)
    b_fc1_tf = bias_variable([n_n_fc1], name='b_fc1_tf')  # (1024)
    h_pool2_flat_tf = tf.reshape(h_pool2_tf, [int(h_pool2_tf.shape[0]), -1], name='h_pool3_flat_tf')  # (.,1024)
    h_fc1_tf = tf.nn.relu(tf.matmul(h_pool2_flat_tf, W_fc1_tf) + b_fc1_tf,
                          name='h_fc1_tf')  # (.,1024)

    # add dropout
    #keep_prob_tf = tf.placeholder(dtype=tf.float32, name='keep_prob_tf')
    #h_fc1_drop_tf = tf.nn.dropout(h_fc1_tf, keep_prob_tf, name='h_fc1_drop_tf')
    print(h_fc1_tf.shape)

    # 4th layer: fully connected
    W_fc2_tf = weight_variable([n_n_fc1, num_classes], name='W_fc2_tf')
    b_fc2_tf = bias_variable([num_classes], name='b_fc2_tf')
    z_pred_tf = tf.add(tf.matmul(h_fc1_tf, W_fc2_tf), b_fc2_tf, name='z_pred_tf')  # => (.,10)
    # predicted probabilities in one-hot encoding
    #y_pred_proba_tf = tf.nn.softmax(z_pred_tf, name='y_pred_proba_tf')

    # tensor of correct predictions
    #y_pred_correct_tf = tf.equal(tf.argmax(y_pred_proba_tf, 1),
    #                             tf.argmax(y_data_tf, 1),
    #                             name='y_pred_correct_tf')
    logits = z_pred_tf
    return logits  # y_pred_proba_tf
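For reference, a minimal usage sketch (not part of the committed file; the 32x32 RGB input size is an assumption based on the CIFAR-10 settings elsewhere in this commit). With the two 'VALID' 5x5 convolutions and two 2x2 poolings above, a 32x32 input yields the 5x5 feature map that W_fc1_tf expects, and the reshape requires a statically known batch size:

    # Hypothetical call site: 32x32 RGB images with a fixed batch size of 128.
    images = tf.placeholder(tf.float32, [128, 32, 32, 3])  # 32 -> 28 -> 14 -> 10 -> 5 through conv/pool
    logits = LeNet(images, num_classes=10)                 # Tensor of shape (128, 10)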
Old/PBA/model.py (new executable file)
@@ -0,0 +1,353 @@
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""PBA & AutoAugment Train/Eval module."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib
import os
import time

import numpy as np
import tensorflow as tf

import autoaugment.custom_ops as ops
from autoaugment.shake_drop import build_shake_drop_model
from autoaugment.shake_shake import build_shake_shake_model
import pba.data_utils as data_utils
import pba.helper_utils as helper_utils
from pba.wrn import build_wrn_model
from pba.resnet import build_resnet_model

from pba.LeNet import LeNet

arg_scope = tf.contrib.framework.arg_scope


def setup_arg_scopes(is_training):
    """Sets up the argscopes that will be used when building an image model.

    Args:
        is_training: Is the model training or not.

    Returns:
        Arg scopes to be put around the model being constructed.
    """
    batch_norm_decay = 0.9
    batch_norm_epsilon = 1e-5
    batch_norm_params = {
        # Decay for the moving averages.
        'decay': batch_norm_decay,
        # epsilon to prevent 0s in variance.
        'epsilon': batch_norm_epsilon,
        'scale': True,
        # collection containing the moving mean and moving variance.
        'is_training': is_training,
    }

    scopes = []
    scopes.append(arg_scope([ops.batch_norm], **batch_norm_params))
    return scopes


def build_model(inputs, num_classes, is_training, hparams):
    """Constructs the vision model being trained/evaled.

    Args:
        inputs: input features/images being fed to the image model being built.
        num_classes: number of output classes being predicted.
        is_training: is the model training or not.
        hparams: additional hyperparameters associated with the image model.

    Returns:
        The logits of the image model.
    """
    scopes = setup_arg_scopes(is_training)
    if len(scopes) != 1:
        raise ValueError('Nested scopes deprecated in py3.')
    with scopes[0]:
        if hparams.model_name == 'pyramid_net':
            logits = build_shake_drop_model(inputs, num_classes, is_training)
        elif hparams.model_name == 'wrn':
            logits = build_wrn_model(inputs, num_classes, hparams.wrn_size)
        elif hparams.model_name == 'shake_shake':
            logits = build_shake_shake_model(inputs, num_classes, hparams,
                                             is_training)
        elif hparams.model_name == 'resnet':
            logits = build_resnet_model(inputs, num_classes, hparams,
                                        is_training)
        elif hparams.model_name == 'LeNet':
            logits = LeNet(inputs, num_classes)
        else:
            raise ValueError("Unknown model name.")
    return logits


class Model(object):
    """Builds a model."""

    def __init__(self, hparams, num_classes, image_size):
        self.hparams = hparams
        self.num_classes = num_classes
        self.image_size = image_size

    def build(self, mode):
        """Construct the model."""
        assert mode in ['train', 'eval']
        self.mode = mode
        self._setup_misc(mode)
        self._setup_images_and_labels(self.hparams.dataset)
        self._build_graph(self.images, self.labels, mode)

        self.init = tf.group(tf.global_variables_initializer(),
                             tf.local_variables_initializer())

    def _setup_misc(self, mode):
        """Sets up miscellaneous model settings (learning rate, reuse, batch size)."""
        self.lr_rate_ph = tf.Variable(0.0, name='lrn_rate', trainable=False)
        self.reuse = None if (mode == 'train') else True
        self.batch_size = self.hparams.batch_size
        if mode == 'eval':
            self.batch_size = self.hparams.test_batch_size

    def _setup_images_and_labels(self, dataset):
        """Sets up image and label placeholders for the model."""
        if dataset == 'cifar10' or dataset == 'cifar100' or self.mode == 'train':
            self.images = tf.placeholder(tf.float32,
                                         [self.batch_size, self.image_size, self.image_size, 3])
            self.labels = tf.placeholder(tf.float32,
                                         [self.batch_size, self.num_classes])
        else:
            self.images = tf.placeholder(tf.float32, [None, self.image_size, self.image_size, 3])
            self.labels = tf.placeholder(tf.float32, [None, self.num_classes])

    def assign_epoch(self, session, epoch_value):
        session.run(
            self._epoch_update, feed_dict={self._new_epoch: epoch_value})

    def _build_graph(self, images, labels, mode):
        """Constructs the TF graph for the model.

        Args:
            images: A 4-D image Tensor.
            labels: A 2-D labels Tensor.
            mode: string indicating the run mode (e.g., 'train', 'valid', 'test').
        """
        is_training = 'train' in mode
        if is_training:
            self.global_step = tf.train.get_or_create_global_step()

        logits = build_model(images, self.num_classes, is_training,
                             self.hparams)
        self.predictions, self.cost = helper_utils.setup_loss(logits, labels)

        self._calc_num_trainable_params()

        # Adds L2 weight decay to the cost
        self.cost = helper_utils.decay_weights(self.cost,
                                               self.hparams.weight_decay_rate)

        if is_training:
            self._build_train_op()

        # Setup checkpointing for this child model
        # Keep 2 or more checkpoints around during training.
        with tf.device('/cpu:0'):
            self.saver = tf.train.Saver(max_to_keep=10)

        self.init = tf.group(tf.global_variables_initializer(),
                             tf.local_variables_initializer())

    def _calc_num_trainable_params(self):
        self.num_trainable_params = np.sum([
            np.prod(var.get_shape().as_list())
            for var in tf.trainable_variables()
        ])
        tf.logging.info('number of trainable params: {}'.format(
            self.num_trainable_params))

    def _build_train_op(self):
        """Builds the train op for the model."""
        hparams = self.hparams
        tvars = tf.trainable_variables()
        grads = tf.gradients(self.cost, tvars)
        if hparams.gradient_clipping_by_global_norm > 0.0:
            grads, norm = tf.clip_by_global_norm(
                grads, hparams.gradient_clipping_by_global_norm)
            tf.summary.scalar('grad_norm', norm)

        # Setup the initial learning rate
        initial_lr = self.lr_rate_ph
        optimizer = tf.train.MomentumOptimizer(
            initial_lr, 0.9, use_nesterov=True)

        self.optimizer = optimizer
        apply_op = optimizer.apply_gradients(
            zip(grads, tvars), global_step=self.global_step, name='train_step')
        train_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies([apply_op]):
            self.train_op = tf.group(*train_ops)


class ModelTrainer(object):
    """Trains an instance of the Model class."""

    def __init__(self, hparams):
        self._session = None
        self.hparams = hparams

        # Set the random seed to be sure the same validation set
        # is used for each model
        np.random.seed(0)
        self.data_loader = data_utils.DataSet(hparams)
        np.random.seed()  # Put the random seed back to random
        self.data_loader.reset()

        # extra stuff for ray
        self._build_models()
        self._new_session()
        self._session.__enter__()

    def save_model(self, checkpoint_dir, step=None):
        """Dumps the model into checkpoint_dir.

        Args:
            checkpoint_dir: directory the checkpoint is written to.
            step: If provided, creates a checkpoint with the given step
                number, instead of overwriting the existing checkpoints.
        """
        model_save_name = os.path.join(checkpoint_dir,
                                       'model.ckpt') + '-' + str(step)
        save_path = self.saver.save(self.session, model_save_name)
        tf.logging.info('Saved child model to {}'.format(save_path))
        return model_save_name

    def extract_model_spec(self, checkpoint_path):
        """Loads a checkpoint with the architecture structure stored in the name."""
        self.saver.restore(self.session, checkpoint_path)
        tf.logging.warning(
            'Loaded child model checkpoint from {}'.format(checkpoint_path))

    def eval_child_model(self, model, data_loader, mode):
        """Evaluate the child model.

        Args:
            model: image model that will be evaluated.
            data_loader: dataset object to extract eval data from.
            mode: whether the model is evaluated on the train, val or test split.

        Returns:
            Accuracy of the model on the specified dataset.
        """
        tf.logging.info('Evaluating child model in mode {}'.format(mode))
        while True:
            try:
                accuracy = helper_utils.eval_child_model(
                    self.session, model, data_loader, mode)
                tf.logging.info(
                    'Eval child model accuracy: {}'.format(accuracy))
                # If the eval ran without raising the errors caught below,
                # break from the loop.
                break
            except (tf.errors.AbortedError, tf.errors.UnavailableError) as e:
                tf.logging.info(
                    'Retryable error caught: {}. Retrying.'.format(e))

        return accuracy

    def _new_session(self):
        """Creates a new session for model m."""
        # Create a new session for this model, initialize
        # variables, and save / restore from checkpoint.
        sess_cfg = tf.ConfigProto(
            allow_soft_placement=True, log_device_placement=False)
        sess_cfg.gpu_options.allow_growth = True
        self._session = tf.Session('', config=sess_cfg)
        self._session.run([self.m.init, self.meval.init])
        return self._session

    def _build_models(self):
        """Builds the image models for train and eval."""
        # Build the train and eval versions of the model with shared variables.
        with tf.variable_scope('model', use_resource=False):
            m = Model(self.hparams, self.data_loader.num_classes, self.data_loader.image_size)
            m.build('train')
            self._num_trainable_params = m.num_trainable_params
            self._saver = m.saver
        with tf.variable_scope('model', reuse=True, use_resource=False):
            meval = Model(self.hparams, self.data_loader.num_classes, self.data_loader.image_size)
            meval.build('eval')
        self.m = m
        self.meval = meval

    def _run_training_loop(self, curr_epoch):
        """Trains the model `m` for one epoch."""
        start_time = time.time()
        while True:
            try:
                train_accuracy = helper_utils.run_epoch_training(
                    self.session, self.m, self.data_loader, curr_epoch)
                break
            except (tf.errors.AbortedError, tf.errors.UnavailableError) as e:
                tf.logging.info(
                    'Retryable error caught: {}. Retrying.'.format(e))
        tf.logging.info('Finished epoch: {}'.format(curr_epoch))
        tf.logging.info('Epoch time(min): {}'.format(
            (time.time() - start_time) / 60.0))
        return train_accuracy

    def _compute_final_accuracies(self, iteration):
        """Run once training is finished to compute final test accuracy."""
        if (iteration >= self.hparams.num_epochs - 1):
            test_accuracy = self.eval_child_model(self.meval, self.data_loader,
                                                  'test')
        else:
            test_accuracy = 0
        tf.logging.info('Test Accuracy: {}'.format(test_accuracy))
        return test_accuracy

    def run_model(self, epoch):
        """Trains and evaluates the image model."""
        valid_accuracy = 0.
        training_accuracy = self._run_training_loop(epoch)
        if self.hparams.validation_size > 0:
            valid_accuracy = self.eval_child_model(self.meval,
                                                   self.data_loader, 'val')
        tf.logging.info('Train Acc: {}, Valid Acc: {}'.format(
            training_accuracy, valid_accuracy))
        return training_accuracy, valid_accuracy

    def reset_config(self, new_hparams):
        self.hparams = new_hparams
        self.data_loader.reset_policy(new_hparams)
        return

    @property
    def saver(self):
        return self._saver

    @property
    def session(self):
        return self._session

    @property
    def num_trainable_params(self):
        return self._num_trainable_params
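For orientation, a hypothetical driver sketch (the actual entry points, pba/search.py and pba/train.py, are not included in this diff, so the exact call pattern is an assumption): ModelTrainer builds the train and eval graphs plus a session in its constructor and is then stepped one epoch at a time.

    # Hypothetical driver; hparams would come from pba.setup.create_hparams (see below).
    trainer = ModelTrainer(hparams)                      # builds Model('train') / Model('eval') and opens a session
    for epoch in range(hparams.num_epochs):
        train_acc, valid_acc = trainer.run_model(epoch)  # one training epoch + optional validation pass
    ckpt = trainer.save_model('/tmp/child_ckpt', step=hparams.num_epochs)  # '/tmp/child_ckpt' is a made-up path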
Old/PBA/search.sh (new executable file)
@@ -0,0 +1,59 @@
#!/bin/bash
export PYTHONPATH="$(pwd)"

cifar10_LeNet_search() {
    local_dir="$PWD/results/"
    data_path="$PWD/datasets/cifar-10-batches-py"

    python pba/search.py \
        --local_dir "$local_dir" \
        --model_name LeNet \
        --data_path "$data_path" --dataset cifar10 \
        --train_size 4000 --val_size 46000 \
        --checkpoint_freq 0 \
        --name "cifar10_search" --gpu 0.15 --cpu 2 \
        --num_samples 16 --perturbation_interval 3 --epochs 150 \
        --explore cifar10 --aug_policy cifar10 \
        --lr 0.1 --wd 0.0005
}

cifar10_search() {
    local_dir="$PWD/results/"
    data_path="$PWD/datasets/cifar-10-batches-py"

    python pba/search.py \
        --local_dir "$local_dir" \
        --model_name wrn_40_2 \
        --data_path "$data_path" --dataset cifar10 \
        --train_size 4000 --val_size 46000 \
        --checkpoint_freq 0 \
        --name "cifar10_search" --gpu 0.15 --cpu 2 \
        --num_samples 16 --perturbation_interval 3 --epochs 200 \
        --explore cifar10 --aug_policy cifar10 \
        --lr 0.1 --wd 0.0005
}

svhn_search() {
    local_dir="$PWD/results/"
    data_path="$PWD/datasets/"

    python pba/search.py \
        --local_dir "$local_dir" --data_path "$data_path" \
        --model_name wrn_40_2 --dataset svhn \
        --train_size 1000 --val_size 7325 \
        --checkpoint_freq 0 \
        --name "svhn_search" --gpu 0.19 --cpu 2 \
        --num_samples 16 --perturbation_interval 3 --epochs 160 \
        --explore cifar10 --aug_policy cifar10 --no_cutout \
        --lr 0.1 --wd 0.005
}

if [ "$1" = "rcifar10" ]; then
    cifar10_search
elif [ "$1" = "rsvhn" ]; then
    svhn_search
elif [ "$1" = "LeNet" ]; then
    cifar10_LeNet_search
else
    echo "invalid args"
fi
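Presumably the script is run from the repository root (PYTHONPATH is set to the current directory and the dataset paths are built from $PWD) with a single argument selecting the configuration: "./search.sh LeNet" for the LeNet search, "./search.sh rcifar10" for WRN-40-2 on CIFAR-10, or "./search.sh rsvhn" for SVHN; anything else prints "invalid args".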
Old/PBA/setup.py (new executable file)
@@ -0,0 +1,210 @@
"""Parse flags and set up hyperparameters."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import random
|
||||
import tensorflow as tf
|
||||
|
||||
from pba.augmentation_transforms_hp import NUM_HP_TRANSFORM
|
||||
|
||||
|
||||
def create_parser(state):
|
||||
"""Create arg parser for flags."""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--model_name',
|
||||
default='wrn',
|
||||
choices=('wrn_28_10', 'wrn_40_2', 'shake_shake_32', 'shake_shake_96',
|
||||
'shake_shake_112', 'pyramid_net', 'resnet', 'LeNet'))
|
||||
parser.add_argument(
|
||||
'--data_path',
|
||||
default='/tmp/datasets/',
|
||||
help='Directory where dataset is located.')
|
||||
parser.add_argument(
|
||||
'--dataset',
|
||||
default='cifar10',
|
||||
choices=('cifar10', 'cifar100', 'svhn', 'svhn-full', 'test'))
|
||||
parser.add_argument(
|
||||
'--recompute_dset_stats',
|
||||
action='store_true',
|
||||
help='Instead of using hardcoded mean/std, recompute from dataset.')
|
||||
parser.add_argument('--local_dir', type=str, default='/tmp/ray_results/', help='Ray directory.')
|
||||
parser.add_argument('--restore', type=str, default=None, help='If specified, tries to restore from given path.')
|
||||
parser.add_argument('--train_size', type=int, default=5000, help='Number of training examples.')
|
||||
parser.add_argument('--val_size', type=int, default=45000, help='Number of validation examples.')
|
||||
parser.add_argument('--checkpoint_freq', type=int, default=50, help='Checkpoint frequency.')
|
||||
parser.add_argument(
|
||||
'--cpu', type=float, default=4, help='Allocated by Ray')
|
||||
parser.add_argument(
|
||||
'--gpu', type=float, default=1, help='Allocated by Ray')
|
||||
parser.add_argument(
|
||||
'--aug_policy',
|
||||
type=str,
|
||||
default='cifar10',
|
||||
help=
|
||||
'which augmentation policy to use (in augmentation_transforms_hp.py)')
|
||||
# search-use only
|
||||
parser.add_argument(
|
||||
'--explore',
|
||||
type=str,
|
||||
default='cifar10',
|
||||
help='which explore function to use')
|
||||
parser.add_argument(
|
||||
'--epochs',
|
||||
type=int,
|
||||
default=0,
|
||||
help='Number of epochs, or <=0 for default')
|
||||
parser.add_argument(
|
||||
'--no_cutout', action='store_true', help='turn off cutout')
|
||||
parser.add_argument('--lr', type=float, default=0.1, help='learning rate')
|
||||
parser.add_argument('--wd', type=float, default=0.0005, help='weight decay')
|
||||
parser.add_argument('--bs', type=int, default=128, help='batch size')
|
||||
parser.add_argument('--test_bs', type=int, default=25, help='test batch size')
|
||||
parser.add_argument('--num_samples', type=int, default=1, help='Number of Ray samples')
|
||||
|
||||
if state == 'train':
|
||||
parser.add_argument(
|
||||
'--use_hp_policy',
|
||||
action='store_true',
|
||||
help='otherwise use autoaug policy')
|
||||
parser.add_argument(
|
||||
'--hp_policy',
|
||||
type=str,
|
||||
default=None,
|
||||
help='either a comma separated list of values or a file')
|
||||
parser.add_argument(
|
||||
'--hp_policy_epochs',
|
||||
type=int,
|
||||
default=200,
|
||||
help='number of epochs/iterations policy trained for')
|
||||
parser.add_argument(
|
||||
'--no_aug',
|
||||
action='store_true',
|
||||
help=
|
||||
'no additional augmentation at all (besides cutout if not toggled)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--flatten',
|
||||
action='store_true',
|
||||
help='randomly select aug policy from schedule')
|
||||
parser.add_argument('--name', type=str, default='autoaug')
|
||||
|
||||
elif state == 'search':
|
||||
parser.add_argument('--perturbation_interval', type=int, default=10)
|
||||
parser.add_argument('--name', type=str, default='autoaug_pbt')
|
||||
else:
|
||||
raise ValueError('unknown state')
|
||||
args = parser.parse_args()
|
||||
tf.logging.info(str(args))
|
||||
return args
|
||||
|
||||
|
||||
def create_hparams(state, FLAGS): # pylint: disable=invalid-name
|
||||
"""Creates hyperparameters to pass into Ray config.
|
||||
|
||||
Different options depending on search or eval mode.
|
||||
|
||||
Args:
|
||||
state: a string, 'train' or 'search'.
|
||||
FLAGS: parsed command line flags.
|
||||
|
||||
Returns:
|
||||
tf.hparams object.
|
||||
"""
|
||||
epochs = 0
|
||||
tf.logging.info('data path: {}'.format(FLAGS.data_path))
|
||||
hparams = tf.contrib.training.HParams(
|
||||
train_size=FLAGS.train_size,
|
||||
validation_size=FLAGS.val_size,
|
||||
dataset=FLAGS.dataset,
|
||||
data_path=FLAGS.data_path,
|
||||
batch_size=FLAGS.bs,
|
||||
gradient_clipping_by_global_norm=5.0,
|
||||
explore=FLAGS.explore,
|
||||
aug_policy=FLAGS.aug_policy,
|
||||
no_cutout=FLAGS.no_cutout,
|
||||
recompute_dset_stats=FLAGS.recompute_dset_stats,
|
||||
lr=FLAGS.lr,
|
||||
weight_decay_rate=FLAGS.wd,
|
||||
test_batch_size=FLAGS.test_bs)
|
||||
|
||||
if state == 'train':
|
||||
hparams.add_hparam('no_aug', FLAGS.no_aug)
|
||||
hparams.add_hparam('use_hp_policy', FLAGS.use_hp_policy)
|
||||
if FLAGS.use_hp_policy:
|
||||
if FLAGS.hp_policy == 'random':
|
||||
tf.logging.info('RANDOM SEARCH')
|
||||
parsed_policy = []
|
||||
for i in range(NUM_HP_TRANSFORM * 4):
|
||||
if i % 2 == 0:
|
||||
parsed_policy.append(random.randint(0, 10))
|
||||
else:
|
||||
parsed_policy.append(random.randint(0, 9))
|
||||
elif FLAGS.hp_policy.endswith('.txt') or FLAGS.hp_policy.endswith(
|
||||
'.p'):
|
||||
# will be loaded in in data_utils
|
||||
parsed_policy = FLAGS.hp_policy
|
||||
else:
|
||||
# parse input into a fixed augmentation policy
|
||||
parsed_policy = FLAGS.hp_policy.split(', ')
|
||||
parsed_policy = [int(p) for p in parsed_policy]
|
||||
hparams.add_hparam('hp_policy', parsed_policy)
|
||||
hparams.add_hparam('hp_policy_epochs', FLAGS.hp_policy_epochs)
|
||||
hparams.add_hparam('flatten', FLAGS.flatten)
|
||||
elif state == 'search':
|
||||
hparams.add_hparam('no_aug', False)
|
||||
hparams.add_hparam('use_hp_policy', True)
|
||||
# default start value of 0
|
||||
hparams.add_hparam('hp_policy',
|
||||
[0 for _ in range(4 * NUM_HP_TRANSFORM)])
|
||||
else:
|
||||
raise ValueError('unknown state')
|
||||
|
||||
if FLAGS.model_name == 'wrn_40_2':
|
||||
hparams.add_hparam('model_name', 'wrn')
|
||||
epochs = 200
|
||||
hparams.add_hparam('wrn_size', 32)
|
||||
hparams.add_hparam('wrn_depth', 40)
|
||||
elif FLAGS.model_name == 'wrn_28_10':
|
||||
hparams.add_hparam('model_name', 'wrn')
|
||||
epochs = 200
|
||||
hparams.add_hparam('wrn_size', 160)
|
||||
hparams.add_hparam('wrn_depth', 28)
|
||||
elif FLAGS.model_name == 'resnet':
|
||||
hparams.add_hparam('model_name', 'resnet')
|
||||
epochs = 200
|
||||
hparams.add_hparam('resnet_size', 20)
|
||||
hparams.add_hparam('num_filters', 32)
|
||||
elif FLAGS.model_name == 'shake_shake_32':
|
||||
hparams.add_hparam('model_name', 'shake_shake')
|
||||
epochs = 1800
|
||||
hparams.add_hparam('shake_shake_widen_factor', 2)
|
||||
elif FLAGS.model_name == 'shake_shake_96':
|
||||
hparams.add_hparam('model_name', 'shake_shake')
|
||||
epochs = 1800
|
||||
hparams.add_hparam('shake_shake_widen_factor', 6)
|
||||
elif FLAGS.model_name == 'shake_shake_112':
|
||||
hparams.add_hparam('model_name', 'shake_shake')
|
||||
epochs = 1800
|
||||
hparams.add_hparam('shake_shake_widen_factor', 7)
|
||||
elif FLAGS.model_name == 'pyramid_net':
|
||||
hparams.add_hparam('model_name', 'pyramid_net')
|
||||
epochs = 1800
|
||||
hparams.set_hparam('batch_size', 64)
|
||||
|
||||
elif FLAGS.model_name == 'LeNet':
|
||||
hparams.add_hparam('model_name', 'LeNet')
|
||||
epochs = 200
|
||||
|
||||
else:
|
||||
raise ValueError('Not Valid Model Name: %s' % FLAGS.model_name)
|
||||
if FLAGS.epochs > 0:
|
||||
tf.logging.info('overwriting with custom epochs')
|
||||
epochs = FLAGS.epochs
|
||||
hparams.add_hparam('num_epochs', epochs)
|
||||
tf.logging.info('epochs: {}, lr: {}, wd: {}'.format(
|
||||
hparams.num_epochs, hparams.lr, hparams.weight_decay_rate))
|
||||
return hparams
|
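A minimal sketch of how these two helpers are presumably chained by the entry scripts (the exact call sites are not part of this diff):

    # Hypothetical wiring: parse flags for search mode, then turn them into HParams.
    FLAGS = create_parser('search')            # argparse.Namespace with the flags defined above
    hparams = create_hparams('search', FLAGS)  # HParams carrying model_name, num_epochs, hp_policy, etc.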
Old/PBA/table_1_cifar10.sh (new executable file)
@@ -0,0 +1,41 @@
#!/bin/bash
export PYTHONPATH="$(pwd)"

# args: [model name] [lr] [wd]  (learning rate / weight decay)
eval_cifar10() {
    hp_policy="$PWD/schedules/rcifar10_16_wrn.txt"
    local_dir="$PWD/results/"
    data_path="$PWD/datasets/cifar-10-batches-py"

    size=50000
    dataset="cifar10"
    name="eval_cifar10_$1"  # has 8 cutout size

    python pba/train.py \
        --local_dir "$local_dir" --data_path "$data_path" \
        --model_name "$1" --dataset "$dataset" \
        --train_size "$size" --val_size 0 \
        --checkpoint_freq 25 --gpu 1 --cpu 4 \
        --use_hp_policy --hp_policy "$hp_policy" \
        --hp_policy_epochs 200 \
        --aug_policy cifar10 --name "$name" \
        --lr "$2" --wd "$3"
}

if [ "$@" = "wrn_28_10" ]; then
    eval_cifar10 wrn_28_10 0.1 0.0005
elif [ "$@" = "ss_32" ]; then
    eval_cifar10 shake_shake_32 0.01 0.001
elif [ "$@" = "ss_96" ]; then
    eval_cifar10 shake_shake_96 0.01 0.001
elif [ "$@" = "ss_112" ]; then
    eval_cifar10 shake_shake_112 0.01 0.001
elif [ "$@" = "pyramid_net" ]; then
    eval_cifar10 pyramid_net 0.05 0.00005
elif [ "$@" = "LeNet" ]; then
    eval_cifar10 LeNet 0.05 0.0
else
    echo "invalid args"
fi
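Like search.sh, this script presumably expects a single argument naming the model, e.g. "./table_1_cifar10.sh LeNet", and then trains on the full 50000-image CIFAR-10 training set (val_size 0) with the fixed schedule schedules/rcifar10_16_wrn.txt.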