smart_augmentation/PBA/search.sh
Harle, Antoine (Contracteur) 758d6e9b78 Fix LeNet Tensorflow
2019-11-21 12:29:17 -05:00

59 lines
1.6 KiB
Bash
Executable file

#!/bin/bash
export PYTHONPATH="$(pwd)"
cifar10_LeNet_search() {
local_dir="$PWD/results/"
data_path="$PWD/datasets/cifar-10-batches-py"
python pba/search.py \
--local_dir "$local_dir" \
--model_name LeNet \
--data_path "$data_path" --dataset cifar10 \
--train_size 4000 --val_size 46000 \
--checkpoint_freq 0 \
--name "cifar10_search" --gpu 0.15 --cpu 2 \
--num_samples 16 --perturbation_interval 3 --epochs 150 \
--explore cifar10 --aug_policy cifar10 \
--lr 0.1 --wd 0.0005
}
cifar10_search() {
local_dir="$PWD/results/"
data_path="$PWD/datasets/cifar-10-batches-py"
python pba/search.py \
--local_dir "$local_dir" \
--model_name wrn_40_2 \
--data_path "$data_path" --dataset cifar10 \
--train_size 4000 --val_size 46000 \
--checkpoint_freq 0 \
--name "cifar10_search" --gpu 0.15 --cpu 2 \
--num_samples 16 --perturbation_interval 3 --epochs 200 \
--explore cifar10 --aug_policy cifar10 \
--lr 0.1 --wd 0.0005
}
svhn_search() {
local_dir="$PWD/results/"
data_path="$PWD/datasets/"
python pba/search.py \
--local_dir "$local_dir" --data_path "$data_path" \
--model_name wrn_40_2 --dataset svhn \
--train_size 1000 --val_size 7325 \
--checkpoint_freq 0 \
--name "svhn_search" --gpu 0.19 --cpu 2 \
--num_samples 16 --perturbation_interval 3 --epochs 160 \
--explore cifar10 --aug_policy cifar10 --no_cutout \
--lr 0.1 --wd 0.005
}
if [ "$1" = "rcifar10" ]; then
cifar10_search
elif [ "$1" = "rsvhn" ]; then
svhn_search
elif [ "$1" = "LeNet" ]; then
cifar10_LeNet_search
else
echo "invalid args"
fi