# coding=utf-8 # Copyright 2019 The Google UDA Team Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """UDA on CIFAR-10 and SVHN. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import contextlib import os import time import json import numpy as np from absl import flags import absl.logging as _logging # pylint: disable=unused-import import tensorflow as tf from randaugment import custom_ops as ops import data import utils from randaugment.wrn import build_wrn_model from randaugment.shake_drop import build_shake_drop_model from randaugment.shake_shake import build_shake_shake_model from randaugment.LeNet import LeNet # TPU related flags.DEFINE_string( "master", default=None, help="the TPU address. This should be set when using Cloud TPU") flags.DEFINE_string( "tpu", default=None, help="The Cloud TPU to use for training. This should be either the name " "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 url.") flags.DEFINE_string( "gcp_project", default=None, help="Project name for the Cloud TPU-enabled project. If not specified, " "we will attempt to automatically detect the GCE project from metadata.") flags.DEFINE_string( "tpu_zone", default=None, help="GCE zone where the Cloud TPU is located in. If not specified, we " "will attempt to automatically detect the GCE project from metadata.") flags.DEFINE_bool( "use_tpu", default=False, help="Use TPUs rather than GPU/CPU.") flags.DEFINE_enum( "task_name", "cifar10", enum_values=["cifar10", "svhn"], help="The task to use") # UDA config: flags.DEFINE_integer( "sup_size", default=4000, help="Number of supervised pairs to use. " "-1: all training samples. 4000: 4000 supervised examples.") flags.DEFINE_integer( "aug_copy", default=0, help="Number of different augmented data generated.") flags.DEFINE_integer( "unsup_ratio", default=0, help="The ratio between batch size of unlabeled data and labeled data, " "i.e., unsup_ratio * train_batch_size is the batch_size for unlabeled data." "Do not use the unsupervised objective if set to 0.") flags.DEFINE_enum( "tsa", "", enum_values=["", "linear_schedule", "log_schedule", "exp_schedule"], help="anneal schedule of training signal annealing. " "tsa='' means not using TSA. See the paper for other schedules.") flags.DEFINE_float( "uda_confidence_thresh", default=-1, help="The threshold on predicted probability on unsupervised data. If set," "UDA loss will only be calculated on unlabeled examples whose largest" "probability is larger than the threshold") flags.DEFINE_float( "uda_softmax_temp", -1, help="The temperature of the Softmax when making prediction on unlabeled" "examples. -1 means to use normal Softmax") flags.DEFINE_float( "ent_min_coeff", default=0, help="") flags.DEFINE_integer( "unsup_coeff", default=1, help="The coefficient on the UDA loss. " "setting unsup_coeff to 1 works for most settings. " "When you have extermely few samples, consider increasing unsup_coeff") # Experiment (data/checkpoint/directory) config flags.DEFINE_string( "data_dir", default=None, help="Path to data directory containing `*.tfrecords`.") flags.DEFINE_string( "model_dir", default=None, help="model dir of the saved checkpoints.") flags.DEFINE_bool( "do_train", default=True, help="Whether to run training.") flags.DEFINE_bool( "do_eval", default=False, help="Whether to run eval on the test set.") flags.DEFINE_integer( "dev_size", default=-1, help="dev set size.") flags.DEFINE_bool( "verbose", default=False, help="Whether to print additional information.") # Training config flags.DEFINE_integer( "train_batch_size", default=32, help="Size of train batch.") flags.DEFINE_integer( "eval_batch_size", default=8, help="Size of evalation batch.") flags.DEFINE_integer( "train_steps", default=100000, help="Total number of training steps.") flags.DEFINE_integer( "iterations", default=10000, help="Number of iterations per repeat loop.") flags.DEFINE_integer( "save_steps", default=10000, help="number of steps for model checkpointing.") flags.DEFINE_integer( "max_save", default=10, help="Maximum number of checkpoints to save.") # Model config flags.DEFINE_enum( "model_name", default="wrn", enum_values=["wrn", "shake_shake_32", "shake_shake_96", "shake_shake_112", "pyramid_net", "LeNet"], help="Name of the model") flags.DEFINE_integer( "num_classes", default=10, help="Number of categories for classification.") flags.DEFINE_integer( "wrn_size", default=32, help="The size of WideResNet. It should be set to 32 for WRN-28-2" "and should be set to 160 for WRN-28-10") # Optimization config flags.DEFINE_float( "learning_rate", default=0.03, help="Maximum learning rate.") flags.DEFINE_float( "weight_decay_rate", default=5e-4, help="Weight decay rate.") flags.DEFINE_float( "min_lr_ratio", default=0.004, help="Minimum ratio learning rate.") flags.DEFINE_integer( "warmup_steps", default=20000, help="Number of steps for linear lr warmup.") FLAGS = tf.flags.FLAGS arg_scope = tf.contrib.framework.arg_scope def get_tsa_threshold(schedule, global_step, num_train_steps, start, end): step_ratio = tf.to_float(global_step) / tf.to_float(num_train_steps) if schedule == "linear_schedule": coeff = step_ratio elif schedule == "exp_schedule": scale = 5 # [exp(-5), exp(0)] = [1e-2, 1] coeff = tf.exp((step_ratio - 1) * scale) elif schedule == "log_schedule": scale = 5 # [1 - exp(0), 1 - exp(-5)] = [0, 0.99] coeff = 1 - tf.exp((-step_ratio) * scale) return coeff * (end - start) + start def setup_arg_scopes(is_training): """Sets up the argscopes that will be used when building an image model. Args: is_training: Is the model training or not. Returns: Arg scopes to be put around the model being constructed. """ batch_norm_decay = 0.9 batch_norm_epsilon = 1e-5 batch_norm_params = { # Decay for the moving averages. "decay": batch_norm_decay, # epsilon to prevent 0s in variance. "epsilon": batch_norm_epsilon, "scale": True, # collection containing the moving mean and moving variance. "is_training": is_training, } scopes = [] scopes.append(arg_scope([ops.batch_norm], **batch_norm_params)) return scopes def build_model(inputs, num_classes, is_training, update_bn, hparams): """Constructs the vision model being trained/evaled. Args: inputs: input features/images being fed to the image model build built. num_classes: number of output classes being predicted. is_training: is the model training or not. hparams: additional hyperparameters associated with the image model. Returns: The logits of the image model. """ scopes = setup_arg_scopes(is_training) try: from contextlib import nested except ImportError: from contextlib import ExitStack, contextmanager @contextmanager def nested(*contexts): with ExitStack() as stack: for ctx in contexts: stack.enter_context(ctx) yield contexts with nested(*scopes): if hparams.model_name == "pyramid_net": logits = build_shake_drop_model( inputs, num_classes, is_training) elif hparams.model_name == "wrn": logits = build_wrn_model( inputs, num_classes, hparams.wrn_size, update_bn) elif hparams.model_name == "shake_shake": logits = build_shake_shake_model( inputs, num_classes, hparams, is_training) elif hparams.model_name == "LeNet": logits = LeNet(inputs, num_classes) return logits def _kl_divergence_with_logits(p_logits, q_logits): p = tf.nn.softmax(p_logits) log_p = tf.nn.log_softmax(p_logits) log_q = tf.nn.log_softmax(q_logits) kl = tf.reduce_sum(p * (log_p - log_q), -1) return kl def anneal_sup_loss(sup_logits, sup_labels, sup_loss, global_step, metric_dict): tsa_start = 1. / FLAGS.num_classes eff_train_prob_threshold = get_tsa_threshold( FLAGS.tsa, global_step, FLAGS.train_steps, tsa_start, end=1) one_hot_labels = tf.one_hot( sup_labels, depth=FLAGS.num_classes, dtype=tf.float32) sup_probs = tf.nn.softmax(sup_logits, axis=-1) correct_label_probs = tf.reduce_sum( one_hot_labels * sup_probs, axis=-1) larger_than_threshold = tf.greater( correct_label_probs, eff_train_prob_threshold) loss_mask = 1 - tf.cast(larger_than_threshold, tf.float32) loss_mask = tf.stop_gradient(loss_mask) sup_loss = sup_loss * loss_mask avg_sup_loss = (tf.reduce_sum(sup_loss) / tf.maximum(tf.reduce_sum(loss_mask), 1)) metric_dict["sup/sup_trained_ratio"] = tf.reduce_mean(loss_mask) metric_dict["sup/eff_train_prob_threshold"] = eff_train_prob_threshold return sup_loss, avg_sup_loss def get_ent(logits, return_mean=True): log_prob = tf.nn.log_softmax(logits, axis=-1) prob = tf.exp(log_prob) ent = tf.reduce_sum(-prob * log_prob, axis=-1) if return_mean: ent = tf.reduce_mean(ent) return ent def get_model_fn(hparams): def model_fn(features, labels, mode, params): sup_labels = tf.reshape(features["label"], [-1]) #### Configuring the optimizer global_step = tf.train.get_global_step() metric_dict = {} is_training = (mode == tf.estimator.ModeKeys.TRAIN) if FLAGS.unsup_ratio > 0 and is_training: all_images = tf.concat([features["image"], features["ori_image"], features["aug_image"]], 0) else: all_images = features["image"] with tf.variable_scope("model", reuse=tf.AUTO_REUSE): all_logits = build_model( inputs=all_images, num_classes=FLAGS.num_classes, is_training=is_training, update_bn=True and is_training, hparams=hparams, ) sup_bsz = tf.shape(features["image"])[0] sup_logits = all_logits[:sup_bsz] sup_loss = tf.nn.sparse_softmax_cross_entropy_with_logits( labels=sup_labels, logits=sup_logits) sup_prob = tf.nn.softmax(sup_logits, axis=-1) metric_dict["sup/pred_prob"] = tf.reduce_mean( tf.reduce_max(sup_prob, axis=-1)) if FLAGS.tsa: sup_loss, avg_sup_loss = anneal_sup_loss(sup_logits, sup_labels, sup_loss, global_step, metric_dict) else: avg_sup_loss = tf.reduce_mean(sup_loss) total_loss = avg_sup_loss if FLAGS.unsup_ratio > 0 and is_training: aug_bsz = tf.shape(features["ori_image"])[0] ori_logits = all_logits[sup_bsz : sup_bsz + aug_bsz] aug_logits = all_logits[sup_bsz + aug_bsz:] if FLAGS.uda_softmax_temp != -1: ori_logits_tgt = ori_logits / FLAGS.uda_softmax_temp else: ori_logits_tgt = ori_logits ori_prob = tf.nn.softmax(ori_logits, axis=-1) aug_prob = tf.nn.softmax(aug_logits, axis=-1) metric_dict["unsup/ori_prob"] = tf.reduce_mean( tf.reduce_max(ori_prob, axis=-1)) metric_dict["unsup/aug_prob"] = tf.reduce_mean( tf.reduce_max(aug_prob, axis=-1)) aug_loss = _kl_divergence_with_logits( p_logits=tf.stop_gradient(ori_logits_tgt), q_logits=aug_logits) if FLAGS.uda_confidence_thresh != -1: ori_prob = tf.nn.softmax(ori_logits, axis=-1) largest_prob = tf.reduce_max(ori_prob, axis=-1) loss_mask = tf.cast(tf.greater( largest_prob, FLAGS.uda_confidence_thresh), tf.float32) metric_dict["unsup/high_prob_ratio"] = tf.reduce_mean(loss_mask) loss_mask = tf.stop_gradient(loss_mask) aug_loss = aug_loss * loss_mask metric_dict["unsup/high_prob_loss"] = tf.reduce_mean(aug_loss) if FLAGS.ent_min_coeff > 0: ent_min_coeff = FLAGS.ent_min_coeff metric_dict["unsup/ent_min_coeff"] = ent_min_coeff per_example_ent = get_ent(ori_logits) ent_min_loss = tf.reduce_mean(per_example_ent) total_loss = total_loss + ent_min_coeff * ent_min_loss avg_unsup_loss = tf.reduce_mean(aug_loss) total_loss += FLAGS.unsup_coeff * avg_unsup_loss metric_dict["unsup/loss"] = avg_unsup_loss total_loss = utils.decay_weights( total_loss, FLAGS.weight_decay_rate) #### Check model parameters num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()]) tf.logging.info("#params: {}".format(num_params)) if FLAGS.verbose: format_str = "{{:<{0}s}}\t{{}}".format( max([len(v.name) for v in tf.trainable_variables()])) for v in tf.trainable_variables(): tf.logging.info(format_str.format(v.name, v.get_shape())) #### Evaluation mode if mode == tf.estimator.ModeKeys.EVAL: #### Metric function for classification def metric_fn(per_example_loss, label_ids, logits): # classification loss & accuracy loss = tf.metrics.mean(per_example_loss) predictions = tf.argmax(logits, axis=-1, output_type=tf.int32) accuracy = tf.metrics.accuracy(label_ids, predictions) ret_dict = { "eval/classify_loss": loss, "eval/classify_accuracy": accuracy } return ret_dict eval_metrics = (metric_fn, [sup_loss, sup_labels, sup_logits]) #### Constucting evaluation TPUEstimatorSpec. eval_spec = tf.contrib.tpu.TPUEstimatorSpec( mode=mode, loss=total_loss, eval_metrics=eval_metrics) return eval_spec # increase the learning rate linearly if FLAGS.warmup_steps > 0: warmup_lr = tf.to_float(global_step) / tf.to_float(FLAGS.warmup_steps) \ * FLAGS.learning_rate else: warmup_lr = 0.0 # decay the learning rate using the cosine schedule decay_lr = tf.train.cosine_decay( FLAGS.learning_rate, global_step=global_step-FLAGS.warmup_steps, decay_steps=FLAGS.train_steps-FLAGS.warmup_steps, alpha=FLAGS.min_lr_ratio) learning_rate = tf.where(global_step < FLAGS.warmup_steps, warmup_lr, decay_lr) optimizer = tf.train.MomentumOptimizer( learning_rate=learning_rate, momentum=0.9, use_nesterov=True) if FLAGS.use_tpu: optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) grads_and_vars = optimizer.compute_gradients(total_loss) gradients, variables = zip(*grads_and_vars) update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops): train_op = optimizer.apply_gradients( zip(gradients, variables), global_step=tf.train.get_global_step()) #### Creating training logging hook # compute accuracy sup_pred = tf.argmax(sup_logits, axis=-1, output_type=sup_labels.dtype) is_correct = tf.to_float(tf.equal(sup_pred, sup_labels)) acc = tf.reduce_mean(is_correct) metric_dict["sup/sup_loss"] = avg_sup_loss metric_dict["training/loss"] = total_loss metric_dict["sup/acc"] = acc metric_dict["training/lr"] = learning_rate metric_dict["training/step"] = global_step if not FLAGS.use_tpu: log_info = ("step [{training/step}] lr {training/lr:.6f} " "loss {training/loss:.4f} " "sup/acc {sup/acc:.4f} sup/loss {sup/sup_loss:.6f} ") if FLAGS.unsup_ratio > 0: log_info += "unsup/loss {unsup/loss:.6f} " formatter = lambda kwargs: log_info.format(**kwargs) logging_hook = tf.train.LoggingTensorHook( tensors=metric_dict, every_n_iter=FLAGS.iterations, formatter=formatter) training_hooks = [logging_hook] #### Constucting training TPUEstimatorSpec. train_spec = tf.contrib.tpu.TPUEstimatorSpec( mode=mode, loss=total_loss, train_op=train_op, training_hooks=training_hooks) else: #### Constucting training TPUEstimatorSpec. host_call = utils.construct_scalar_host_call( metric_dict=metric_dict, model_dir=params["model_dir"], prefix="", reduce_fn=tf.reduce_mean) train_spec = tf.contrib.tpu.TPUEstimatorSpec( mode=mode, loss=total_loss, train_op=train_op, host_call=host_call) return train_spec return model_fn def train(hparams): ##### Create input function if FLAGS.unsup_ratio == 0: FLAGS.aug_copy = 0 if FLAGS.dev_size != -1: FLAGS.do_train = True FLAGS.do_eval = True if FLAGS.do_train: train_input_fn = data.get_input_fn( data_dir=FLAGS.data_dir, split="train", task_name=FLAGS.task_name, sup_size=FLAGS.sup_size, unsup_ratio=FLAGS.unsup_ratio, aug_copy=FLAGS.aug_copy, ) if FLAGS.do_eval: if FLAGS.dev_size != -1: eval_input_fn = data.get_input_fn( data_dir=FLAGS.data_dir, split="dev", task_name=FLAGS.task_name, sup_size=FLAGS.dev_size, unsup_ratio=0, aug_copy=0) eval_size = FLAGS.dev_size else: eval_input_fn = data.get_input_fn( data_dir=FLAGS.data_dir, split="test", task_name=FLAGS.task_name, sup_size=-1, unsup_ratio=0, aug_copy=0) if FLAGS.task_name == "cifar10": eval_size = 10000 elif FLAGS.task_name == "svhn": eval_size = 26032 else: assert False, "You need to specify the size of your test set." eval_steps = eval_size // FLAGS.eval_batch_size ##### Get model function model_fn = get_model_fn(hparams) estimator = utils.get_TPU_estimator(FLAGS, model_fn) #### Training if FLAGS.dev_size != -1: tf.logging.info("***** Running training and validation *****") tf.logging.info(" Supervised batch size = %d", FLAGS.train_batch_size) tf.logging.info(" Unsupervised batch size = %d", FLAGS.train_batch_size * FLAGS.unsup_ratio) tf.logging.info(" Num train steps = %d", FLAGS.train_steps) curr_step = 0 while True: if curr_step >= FLAGS.train_steps: break tf.logging.info("Current step {}".format(curr_step)) train_step = min(FLAGS.save_steps, FLAGS.train_steps - curr_step) estimator.train(input_fn=train_input_fn, steps=train_step) estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps) curr_step += FLAGS.save_steps else: if FLAGS.do_train: tf.logging.info("***** Running training *****") tf.logging.info(" Supervised batch size = %d", FLAGS.train_batch_size) tf.logging.info(" Unsupervised batch size = %d", FLAGS.train_batch_size * FLAGS.unsup_ratio) estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_steps) if FLAGS.do_eval: tf.logging.info("***** Running evaluation *****") results = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps) tf.logging.info(">> Results:") for key in results.keys(): tf.logging.info(" %s = %s", key, str(results[key])) results[key] = results[key].item() acc = results["eval/classify_accuracy"] with tf.gfile.Open("{}/results.txt".format(FLAGS.model_dir), "w") as ouf: ouf.write(str(acc)) def main(_): if FLAGS.do_train: tf.gfile.MakeDirs(FLAGS.model_dir) flags_dict = tf.app.flags.FLAGS.flag_values_dict() with tf.gfile.Open(os.path.join(FLAGS.model_dir, "FLAGS.json"), "w") as ouf: json.dump(flags_dict, ouf) hparams = tf.contrib.training.HParams() if FLAGS.model_name == "wrn": hparams.add_hparam("model_name", "wrn") hparams.add_hparam("wrn_size", FLAGS.wrn_size) elif FLAGS.model_name == "shake_shake_32": hparams.add_hparam("model_name", "shake_shake") hparams.add_hparam("shake_shake_widen_factor", 2) elif FLAGS.model_name == "shake_shake_96": hparams.add_hparam("model_name", "shake_shake") hparams.add_hparam("shake_shake_widen_factor", 6) elif FLAGS.model_name == "shake_shake_112": hparams.add_hparam("model_name", "shake_shake") hparams.add_hparam("shake_shake_widen_factor", 7) elif FLAGS.model_name == "pyramid_net": hparams.add_hparam("model_name", "pyramid_net") elif FLAGS.model_name == "LeNet": hparams.add_hparam("model_name", "LeNet") else: raise ValueError("Not Valid Model Name: %s" % FLAGS.model_name) train(hparams) if __name__ == "__main__": tf.logging.set_verbosity(tf.logging.INFO) tf.app.run()