smart_augmentation/PBA/table_1_cifar10.sh
2019-11-14 17:59:53 -05:00

41 lines
1.1 KiB
Bash
Executable file

#!/bin/bash
# Runs pba/train.py on CIFAR-10 with a fixed hyperparameter policy schedule
# for the model selected by the command-line argument.
#
# Expose the current working directory on Python's module search path for the
# Python process started further down. All relative paths in this script are
# resolved against the directory the script is invoked from.
PYTHONPATH="$PWD"
export PYTHONPATH
#######################################
# Launch pba/train.py for one model on CIFAR-10, using the hyperparameter
# policy stored in schedules/rcifar10_16_wrn.txt.
# All paths are resolved against the current working directory ($PWD).
# Arguments:
#   $1 - model name (passed as --model_name and appended to the run name)
#   $2 - learning rate (passed as --lr)
#   $3 - weight decay (passed as --wd)
# Returns:
#   Exit status of the python command, or 2 when fewer than three
#   arguments are supplied (python is not started in that case).
#######################################
eval_cifar10() {
  if (( $# < 3 )); then
    printf 'usage: eval_cifar10 <model name> <lr> <wd>\n' >&2
    return 2
  fi

  # Keep every working variable local so nothing leaks into the caller.
  local hp_policy="$PWD/schedules/rcifar10_16_wrn.txt"
  local local_dir="$PWD/results/"
  local data_path="$PWD/datasets/cifar-10-batches-py"
  local size=50000
  local dataset="cifar10"
  # Original note on the run name: "has 8 cutout size".
  # NOTE(review): unclear what this refers to -- confirm against the policy.
  local name="eval_cifar10_$1"

  python pba/train.py \
    --local_dir "$local_dir" --data_path "$data_path" \
    --model_name "$1" --dataset "$dataset" \
    --train_size "$size" --val_size 0 \
    --checkpoint_freq 25 --gpu 1 --cpu 4 \
    --use_hp_policy --hp_policy "$hp_policy" \
    --hp_policy_epochs 200 \
    --aug_policy cifar10 --name "$name" \
    --lr "$2" --wd "$3"
}
# Pick the model to evaluate from the single command-line argument and run it
# with that model's learning rate and weight decay.
#
# The argument count is checked first and the value is matched with `case`;
# comparing "$@" inside `[ ... ]` expands to zero or many words and makes `[`
# print operator errors when the script is called with no or several arguments.
invalid_args() {
  echo "invalid args" >&2
  exit 1
}

if (( $# != 1 )); then
  invalid_args
fi

case "$1" in
  wrn_28_10)
    eval_cifar10 wrn_28_10 0.1 0.0005
    ;;
  ss_32)
    eval_cifar10 shake_shake_32 0.01 0.001
    ;;
  ss_96)
    eval_cifar10 shake_shake_96 0.01 0.001
    ;;
  ss_112)
    eval_cifar10 shake_shake_112 0.01 0.001
    ;;
  pyramid_net)
    eval_cifar10 pyramid_net 0.05 0.00005
    ;;
  LeNet)
    eval_cifar10 LeNet 0.05 0.0
    ;;
  *)
    invalid_args
    ;;
esac