mirror of https://github.com/AntoineHX/smart_augmentation.git
synced 2025-05-04 12:10:45 +02:00
Fix LeNet Tensorflow
parent 0e7ec8b5b0
commit 758d6e9b78
40 changed files with 103882 additions and 444 deletions
PBA/LeNet.py (38 changed lines)
@@ -14,7 +14,7 @@ def bias_variable(shape, name = None):
 
 # 2D convolution
 def conv2d(x, W, name = None):
-    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME', name = name)
+    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='VALID', name = name)
 
 # max pooling
 def max_pool_2x2(x, name = None):
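The SAME to VALID switch changes every convolution's output size: with stride 1, SAME zero-pads the input so the spatial size is preserved, while VALID shrinks it by filter_size - 1. A minimal sketch of TensorFlow's output-size rules, assuming the 32x32 CIFAR-10 input and the 5x5 filters this LeNet uses (conv_out_size is an illustrative helper, not part of the repo):

    # Spatial output size of tf.nn.conv2d, per TensorFlow's padding rules.
    import math

    def conv_out_size(in_size, filter_size, stride, padding):
        if padding == 'SAME':
            return math.ceil(in_size / stride)                      # input is zero-padded
        if padding == 'VALID':
            return math.ceil((in_size - filter_size + 1) / stride)  # no padding
        raise ValueError(padding)

    print(conv_out_size(32, 5, 1, 'SAME'))   # 32: size preserved
    print(conv_out_size(32, 5, 1, 'VALID'))  # 28: shrinks by filter_size - 1 = 4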
@@ -30,44 +30,38 @@ def LeNet(images, num_classes):
     n_n_fc1 = 500; # number of neurons of first fully connected layer (default = 576)
     n_n_fc2 = 500; # number of neurons of second fully connected layer (default = 576)
 
     #print(images.shape)
     # 1.layer: convolution + max pooling
-    W_conv1_tf = weight_variable([s_f_conv1, s_f_conv1, images.shape[3], n_f_conv1], name = 'W_conv1_tf') # (5,5,1,32)
+    W_conv1_tf = weight_variable([s_f_conv1, s_f_conv1, int(images.shape[3]), n_f_conv1], name = 'W_conv1_tf') # (5,5,1,32)
     b_conv1_tf = bias_variable([n_f_conv1], name = 'b_conv1_tf') # (32)
-    h_conv1_tf = tf.nn.relu(conv2d(images,
-                                   W_conv1_tf) + b_conv1_tf,
-                            name = 'h_conv1_tf') # (.,28,28,32)
-    h_pool1_tf = max_pool_2x2(h_conv1_tf,
-                              name = 'h_pool1_tf') # (.,14,14,32)
+    h_conv1_tf = tf.nn.relu(conv2d(images, W_conv1_tf) + b_conv1_tf, name = 'h_conv1_tf') # (.,28,28,32)
+    h_pool1_tf = max_pool_2x2(h_conv1_tf, name = 'h_pool1_tf') # (.,14,14,32)
     #print(h_conv1_tf.shape)
     #print(h_pool1_tf.shape)
     # 2.layer: convolution + max pooling
-    W_conv2_tf = weight_variable([s_f_conv2, s_f_conv2,
-                                  n_f_conv1, n_f_conv2],
-                                 name = 'W_conv2_tf')
+    W_conv2_tf = weight_variable([s_f_conv2, s_f_conv2, n_f_conv1, n_f_conv2], name = 'W_conv2_tf')
     b_conv2_tf = bias_variable([n_f_conv2], name = 'b_conv2_tf')
-    h_conv2_tf = tf.nn.relu(conv2d(h_pool1_tf,
-                                   W_conv2_tf) + b_conv2_tf,
-                            name ='h_conv2_tf') #(.,14,14,32)
+    h_conv2_tf = tf.nn.relu(conv2d(h_pool1_tf, W_conv2_tf) + b_conv2_tf, name ='h_conv2_tf') #(.,14,14,32)
     h_pool2_tf = max_pool_2x2(h_conv2_tf, name = 'h_pool2_tf') #(.,7,7,32)
 
     #print(h_pool2_tf.shape)
 
     # 4.layer: fully connected
-    W_fc1_tf = weight_variable([5*5*n_f_conv2,n_n_fc1],
-                               name = 'W_fc1_tf') # (4*4*32, 1024)
+    W_fc1_tf = weight_variable([5*5*n_f_conv2,n_n_fc1], name = 'W_fc1_tf') # (4*4*32, 1024)
     b_fc1_tf = bias_variable([n_n_fc1], name = 'b_fc1_tf') # (1024)
-    h_pool2_flat_tf = tf.reshape(h_pool2_tf, [-1,5*5*n_f_conv2],
-                                 name = 'h_pool3_flat_tf') # (.,1024)
-    h_fc1_tf = tf.nn.relu(tf.matmul(h_pool2_flat_tf,
-                                    W_fc1_tf) + b_fc1_tf,
+    h_pool2_flat_tf = tf.reshape(h_pool2_tf, [int(h_pool2_tf.shape[0]), -1], name = 'h_pool3_flat_tf') # (.,1024)
+    h_fc1_tf = tf.nn.relu(tf.matmul(h_pool2_flat_tf, W_fc1_tf) + b_fc1_tf,
                           name = 'h_fc1_tf') # (.,1024)
 
     # add dropout
     #keep_prob_tf = tf.placeholder(dtype=tf.float32, name = 'keep_prob_tf')
     #h_fc1_drop_tf = tf.nn.dropout(h_fc1_tf, keep_prob_tf, name = 'h_fc1_drop_tf')
     print(h_fc1_tf.shape)
 
     # 5.layer: fully connected
     W_fc2_tf = weight_variable([n_n_fc1, num_classes], name = 'W_fc2_tf')
     b_fc2_tf = bias_variable([num_classes], name = 'b_fc2_tf')
-    z_pred_tf = tf.add(tf.matmul(h_fc1_tf, W_fc2_tf),
-                       b_fc2_tf, name = 'z_pred_tf')# => (.,10)
+    z_pred_tf = tf.add(tf.matmul(h_fc1_tf, W_fc2_tf), b_fc2_tf, name = 'z_pred_tf')# => (.,10)
     # predicted probabilities in one-hot encoding
     #y_pred_proba_tf = tf.nn.softmax(z_pred_tf, name='y_pred_proba_tf')
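With VALID convolutions the feature maps now shrink at every stage, which is what makes the 5*5*n_f_conv2 flatten size correct; the trailing shape comments in the file (e.g. (.,7,7,32) for h_pool2_tf) predate this change. A quick arithmetic check, assuming the 32x32 CIFAR-10 input, 5x5 filters, and 2x2 stride-2 pools defined above:

    size = 32            # CIFAR-10 input width/height
    size = size - 5 + 1  # conv1, 5x5 VALID: 32 -> 28
    size = size // 2     # max_pool_2x2:     28 -> 14
    size = size - 5 + 1  # conv2, 5x5 VALID: 14 -> 10
    size = size // 2     # max_pool_2x2:     10 -> 5
    print(size * size)   # 25 spatial positions, i.e. 5*5*n_f_conv2 flattened features

Both reshape forms flatten h_pool2_tf to rank 2; the new [int(h_pool2_tf.shape[0]), -1] pins the batch dimension and lets TensorFlow infer the per-image feature count, rather than the reverse.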
PBA/search.sh (59 changed lines, new executable file)
@@ -0,0 +1,59 @@
+#!/bin/bash
+export PYTHONPATH="$(pwd)"
+
+cifar10_LeNet_search() {
+    local_dir="$PWD/results/"
+    data_path="$PWD/datasets/cifar-10-batches-py"
+
+    python pba/search.py \
+        --local_dir "$local_dir" \
+        --model_name LeNet \
+        --data_path "$data_path" --dataset cifar10 \
+        --train_size 4000 --val_size 46000 \
+        --checkpoint_freq 0 \
+        --name "cifar10_search" --gpu 0.15 --cpu 2 \
+        --num_samples 16 --perturbation_interval 3 --epochs 150 \
+        --explore cifar10 --aug_policy cifar10 \
+        --lr 0.1 --wd 0.0005
+}
+
+cifar10_search() {
+    local_dir="$PWD/results/"
+    data_path="$PWD/datasets/cifar-10-batches-py"
+
+    python pba/search.py \
+        --local_dir "$local_dir" \
+        --model_name wrn_40_2 \
+        --data_path "$data_path" --dataset cifar10 \
+        --train_size 4000 --val_size 46000 \
+        --checkpoint_freq 0 \
+        --name "cifar10_search" --gpu 0.15 --cpu 2 \
+        --num_samples 16 --perturbation_interval 3 --epochs 200 \
+        --explore cifar10 --aug_policy cifar10 \
+        --lr 0.1 --wd 0.0005
+}
+
+svhn_search() {
+    local_dir="$PWD/results/"
+    data_path="$PWD/datasets/"
+
+    python pba/search.py \
+        --local_dir "$local_dir" --data_path "$data_path" \
+        --model_name wrn_40_2 --dataset svhn \
+        --train_size 1000 --val_size 7325 \
+        --checkpoint_freq 0 \
+        --name "svhn_search" --gpu 0.19 --cpu 2 \
+        --num_samples 16 --perturbation_interval 3 --epochs 160 \
+        --explore cifar10 --aug_policy cifar10 --no_cutout \
+        --lr 0.1 --wd 0.005
+}
+
+if [ "$1" = "rcifar10" ]; then
+    cifar10_search
+elif [ "$1" = "rsvhn" ]; then
+    svhn_search
+elif [ "$1" = "LeNet" ]; then
+    cifar10_LeNet_search
+else
+    echo "invalid args"
+fi
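The script dispatches on its first argument, so the new LeNet search would presumably be launched as search.sh LeNet from inside PBA/ (it resolves pba/search.py, datasets/, and results/ relative to the working directory, which it also exports as PYTHONPATH); rcifar10 and rsvhn select the existing wrn_40_2 runs.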