diff --git a/gpu_environment.py b/gpu_environment.py new file mode 100644 index 0000000..948c3fa --- /dev/null +++ b/gpu_environment.py @@ -0,0 +1,36 @@ +# coding=utf-8 +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf +import numpy as np + +def float32_variable_storage_getter(getter, name, shape=None, dtype=None, + initializer=None, regularizer=None, + trainable=True, + *args, **kwargs): + """Custom variable getter that forces trainable variables to be stored in + float32 precision and then casts them to the training precision. + """ + storage_dtype = tf.float32 if trainable else dtype + variable = getter(name, shape, dtype=storage_dtype, + initializer=initializer, regularizer=regularizer, + trainable=trainable, + *args, **kwargs) + if trainable and dtype != tf.float32: + variable = tf.cast(variable, dtype) + return variable + +def get_custom_getter(compute_type): + return float32_variable_storage_getter if compute_type == tf.float16 else None diff --git a/modeling.py b/modeling.py index fed5259..6ef7684 100644 --- a/modeling.py +++ b/modeling.py @@ -26,6 +26,7 @@ import re import numpy as np import six import tensorflow as tf +from gpu_environment import get_custom_getter class BertConfig(object): @@ -135,7 +136,8 @@ class BertModel(object): input_mask=None, token_type_ids=None, use_one_hot_embeddings=False, - scope=None): + scope=None, + compute_type=tf.float32): """Constructor for BertModel. Args: @@ -168,7 +170,7 @@ class BertModel(object): if token_type_ids is None: token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32) - with tf.variable_scope(scope, default_name="bert"): + with tf.variable_scope(scope, default_name="bert", custom_getter=get_custom_getter(compute_type)): with tf.variable_scope("embeddings"): # Perform embedding lookup on the word ids. (self.embedding_output, self.embedding_table) = embedding_lookup( @@ -203,7 +205,7 @@ class BertModel(object): # Run the stacked transformer. # `sequence_output` shape = [batch_size, seq_length, hidden_size]. self.all_encoder_layers = transformer_model( - input_tensor=self.embedding_output, + input_tensor=tf.saturate_cast(self.embedding_output, compute_type), attention_mask=attention_mask, hidden_size=config.hidden_size, num_hidden_layers=config.num_hidden_layers, @@ -215,7 +217,7 @@ class BertModel(object): initializer_range=config.initializer_range, do_return_all_layers=True) - self.sequence_output = self.all_encoder_layers[-1] + self.sequence_output = tf.cast(self.all_encoder_layers[-1], tf.float32) # The "pooler" converts the encoded sequence tensor of shape # [batch_size, seq_length, hidden_size] to a tensor of shape # [batch_size, hidden_size]. 
This is necessary for segment-level @@ -709,7 +711,7 @@ def attention_layer(from_tensor, # Since attention_mask is 1.0 for positions we want to attend and 0.0 for # masked positions, this operation will create a tensor which is 0.0 for # positions we want to attend and -10000.0 for masked positions. - adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0 + adder = (1.0 - tf.cast(attention_mask, attention_scores.dtype)) * -10000.0 # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. diff --git a/optimization.py b/optimization.py index d33dabd..ac1fb3d 100644 --- a/optimization.py +++ b/optimization.py @@ -21,26 +21,52 @@ from __future__ import print_function import re import tensorflow as tf +# pylint: disable=g-direct-tensorflow-import +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +# pylint: enable=g-direct-tensorflow-import -def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): +import horovod.tensorflow as hvd + + +def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu, adjust_lr, use_hvd, + use_compression, use_fp16, clip, cos_decay, use_lamb=False, + previous_train_steps=0, post_train_steps=0): """Creates an optimizer training op.""" global_step = tf.train.get_or_create_global_step() learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) - # Implements linear decay of the learning rate. - learning_rate = tf.train.polynomial_decay( - learning_rate, - global_step, - num_train_steps, - end_learning_rate=0.0, - power=1.0, - cycle=False) + # if adjust_lr: + # # avoid step change in learning rate at end of warmup phase + # decayed_learning_rate_at_crossover_point = init_lr * (1.0-float(num_warmup_steps)/float(num_train_steps + previous_train_steps + post_train_steps)) + # adjusted_init_lr = init_lr * (init_lr / decayed_learning_rate_at_crossover_point) + # print('decayed_learning_rate_at_crossover_point = %e, adjusted_init_lr = %e' % (decayed_learning_rate_at_crossover_point, adjusted_init_lr)) + + # learning_rate = tf.constant(value=adjusted_init_lr, shape=[], dtype=tf.float32) + + if cos_decay: + # Implements cosine decay of the learning rate. + learning_rate = tf.train.cosine_decay( + learning_rate, + global_step + previous_train_steps - (num_warmup_steps if adjust_lr else 0), + num_train_steps + previous_train_steps + post_train_steps - (num_warmup_steps if adjust_lr else 0), + alpha=0.0) + else: + # Implements linear decay of the learning rate. + learning_rate = tf.train.polynomial_decay( + learning_rate, + global_step + previous_train_steps - (num_warmup_steps if adjust_lr else 0), + num_train_steps + previous_train_steps + post_train_steps - (num_warmup_steps if adjust_lr else 0), + end_learning_rate=0.0, + power=1.0, + cycle=False) # Implements linear warmup. I.e., if global_step < num_warmup_steps, the # learning rate will be `global_step/num_warmup_steps * init_lr`. 
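For reference, a minimal pure-Python sketch of the schedule this builds (linear-decay branch, with adjust_lr=False assumed; the warmup logic it mirrors follows just below):

def effective_lr(step, init_lr, num_train_steps, num_warmup_steps,
                 previous_train_steps=0, post_train_steps=0):
  # Mirrors tf.train.polynomial_decay(power=1.0): a linear ramp down to 0
  # over the combined step budget of all runs.
  global_step = step + previous_train_steps
  total_steps = num_train_steps + previous_train_steps + post_train_steps
  decayed = init_lr * max(0.0, 1.0 - float(global_step) / total_steps)
  if num_warmup_steps and global_step < num_warmup_steps:
    # Linear warmup: global_step/num_warmup_steps * init_lr.
    return init_lr * float(global_step) / float(num_warmup_steps)
  return decayed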
if num_warmup_steps: - global_steps_int = tf.cast(global_step, tf.int32) + global_steps_int = tf.cast(global_step + previous_train_steps, tf.int32) warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) global_steps_float = tf.cast(global_steps_int, tf.float32) @@ -56,7 +82,16 @@ def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): # It is recommended that you use this optimizer for fine tuning, since this # is how the model was trained (note that the Adam m/v variables are NOT # loaded from init_checkpoint.) - optimizer = AdamWeightDecayOptimizer( + if use_lamb: + optimizer = LAMBOptimizer( + learning_rate=learning_rate, + weight_decay_rate=0.01, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-6, + exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) + else: + optimizer = AdamWeightDecayOptimizer( learning_rate=learning_rate, weight_decay_rate=0.01, beta_1=0.9, @@ -64,24 +99,51 @@ def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): epsilon=1e-6, exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) + tf.summary.scalar("learning_rate", optimizer.learning_rate) + + if use_hvd: + optimizer = hvd.DistributedOptimizer(optimizer, sparse_as_dense=True, compression=hvd.Compression.fp16 if use_compression else hvd.Compression.none) + + if use_fp16: + loss_scale_manager = tf.contrib.mixed_precision.ExponentialUpdateLossScaleManager(init_loss_scale=2**32, incr_every_n_steps=1000, decr_every_n_nan_or_inf=2, decr_ratio=0.5) + optimizer = tf.contrib.mixed_precision.LossScaleOptimizer(optimizer, loss_scale_manager) + if use_tpu: optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) + all_are_finite = tf.constant(True, dtype=tf.bool) tvars = tf.trainable_variables() - grads = tf.gradients(loss, tvars) + if use_hvd: + grads_and_vars = optimizer.compute_gradients(loss, tvars) + grads_and_vars = [(g,v) for g,v in grads_and_vars if g is not None] + grads, tvars = list(zip(*grads_and_vars)) + all_are_finite = tf.reduce_all([tf.reduce_all(tf.is_finite(g)) for g in grads]) if use_fp16 else tf.constant(True, dtype=tf.bool) + else: + grads = tf.gradients(loss, tvars) # This is how the model was pre-trained. - (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) + # Ensure the global norm is a finite number + # to prevent clip_by_global_norm from having a hissy fit. + (clipped_grads, _) = tf.clip_by_global_norm( + grads, clip_norm=1.0, + use_norm=tf.cond( + all_are_finite, + lambda: tf.global_norm(grads), + lambda: tf.constant(1.0))) train_op = optimizer.apply_gradients( - zip(grads, tvars), global_step=global_step) + list(zip(clipped_grads, tvars)), global_step=global_step) # Normally the global step update is done inside of `apply_gradients`. # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use # a different optimizer, you should probably take this line out.
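The guarded clipping above pairs with the conditional global-step update just below: when any gradient is non-finite (an expected event under fp16 loss scaling), a dummy norm keeps clip_by_global_norm from emitting NaNs, and the step is skipped. A self-contained sketch of the pattern (TF1-style):

def guarded_clip(grads, clip_norm=1.0):
  # True only if every gradient is free of NaN/Inf values.
  all_are_finite = tf.reduce_all(
      [tf.reduce_all(tf.is_finite(g)) for g in grads])
  # On overflow, feed a dummy use_norm of 1.0 so the clip op itself
  # stays NaN-free; the parameter update is then skipped via tf.cond
  # on the global step.
  clipped, _ = tf.clip_by_global_norm(
      grads, clip_norm,
      use_norm=tf.cond(all_are_finite,
                       lambda: tf.global_norm(grads),
                       lambda: tf.constant(1.0)))
  return clipped, all_are_finite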
- new_global_step = global_step + 1 + if clip: + new_global_step = global_step + 1 + else: + new_global_step = tf.cond(all_are_finite, lambda: global_step+1, lambda: global_step) + new_global_step = tf.identity(new_global_step, name='step_update') train_op = tf.group(train_op, [global_step.assign(new_global_step)]) - return train_op + return train_op, learning_rate class AdamWeightDecayOptimizer(tf.train.Optimizer): @@ -98,7 +160,7 @@ class AdamWeightDecayOptimizer(tf.train.Optimizer): """Constructs a AdamWeightDecayOptimizer.""" super(AdamWeightDecayOptimizer, self).__init__(False, name) - self.learning_rate = learning_rate + self.learning_rate = tf.identity(learning_rate, name='learning_rate') self.weight_decay_rate = weight_decay_rate self.beta_1 = beta_1 self.beta_2 = beta_2 @@ -172,3 +234,121 @@ class AdamWeightDecayOptimizer(tf.train.Optimizer): if m is not None: param_name = m.group(1) return param_name + + +class LAMBOptimizer(tf.train.Optimizer): + """LAMB (Layer-wise Adaptive Moments optimizer for Batch training).""" + # A new optimizer that includes correct L2 weight decay, adaptive + # element-wise updating, and layer-wise adaptation. The LAMB optimizer + # was proposed by Yang You, Jing Li, Jonathan Hseu, Xiaodan Song, + # James Demmel, and Cho-Jui Hsieh in the paper "Reducing BERT + # Pre-Training Time from 3 Days to 76 Minutes" (arxiv.org/abs/1904.00962). + + def __init__(self, + learning_rate, + weight_decay_rate=0.0, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-6, + exclude_from_weight_decay=None, + exclude_from_layer_adaptation=None, + name="LAMBOptimizer"): + """Constructs a LAMBOptimizer.""" + super(LAMBOptimizer, self).__init__(False, name) + + self.learning_rate = learning_rate + self.weight_decay_rate = weight_decay_rate + self.beta_1 = beta_1 + self.beta_2 = beta_2 + self.epsilon = epsilon + self.exclude_from_weight_decay = exclude_from_weight_decay + # exclude_from_layer_adaptation is set to exclude_from_weight_decay if the + # arg is None. + # TODO(jingli): validate if exclude_from_layer_adaptation is necessary. + if exclude_from_layer_adaptation: + self.exclude_from_layer_adaptation = exclude_from_layer_adaptation + else: + self.exclude_from_layer_adaptation = exclude_from_weight_decay + + def apply_gradients(self, grads_and_vars, global_step=None, name=None): + """See base class.""" + assignments = [] + for (grad, param) in grads_and_vars: + if grad is None or param is None: + continue + + param_name = self._get_variable_name(param.name) + + m = tf.get_variable( + name=param_name + "/adam_m", + shape=param.shape.as_list(), + dtype=tf.float32, + trainable=False, + initializer=tf.zeros_initializer()) + v = tf.get_variable( + name=param_name + "/adam_v", + shape=param.shape.as_list(), + dtype=tf.float32, + trainable=False, + initializer=tf.zeros_initializer()) + + # Standard Adam update. + next_m = ( + tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) + next_v = ( + tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, + tf.square(grad))) + + update = next_m / (tf.sqrt(next_v) + self.epsilon) + + # Just adding the square of the weights to the loss function is *not* + # the correct way of using L2 regularization/weight decay with Adam, + # since that will interact with the m and v parameters in strange ways. + # + # Instead we want to decay the weights in a manner that doesn't interact + # with the m/v parameters. This is equivalent to adding the square + # of the weights to the loss with plain (non-momentum) SGD.
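Combining the decoupled decay described above with the trust ratio computed in the lines that follow, the per-parameter LAMB step reduces to this NumPy sketch (no bias correction, matching the implementation here):

import numpy as np

def lamb_step(w, g, m, v, lr, beta1=0.9, beta2=0.999, eps=1e-6, wd=0.01):
  m = beta1 * m + (1.0 - beta1) * g        # first moment
  v = beta2 * v + (1.0 - beta2) * g * g    # second moment
  update = m / (np.sqrt(v) + eps)
  update += wd * w                         # decoupled weight decay
  # Layer-wise trust ratio ||w|| / ||update||, guarded against zeros.
  w_norm, u_norm = np.linalg.norm(w), np.linalg.norm(update)
  ratio = w_norm / u_norm if w_norm > 0 and u_norm > 0 else 1.0
  return w - lr * ratio * update, m, v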
+ if self._do_use_weight_decay(param_name): + update += self.weight_decay_rate * param + + ratio = 1.0 + if self._do_layer_adaptation(param_name): + w_norm = linalg_ops.norm(param, ord=2) + g_norm = linalg_ops.norm(update, ord=2) + ratio = array_ops.where(math_ops.greater(w_norm, 0), array_ops.where( + math_ops.greater(g_norm, 0), (w_norm / g_norm), 1.0), 1.0) + + update_with_lr = ratio * self.learning_rate * update + + next_param = param - update_with_lr + + assignments.extend( + [param.assign(next_param), + m.assign(next_m), + v.assign(next_v)]) + return tf.group(*assignments, name=name) + + def _do_use_weight_decay(self, param_name): + """Whether to use L2 weight decay for `param_name`.""" + if not self.weight_decay_rate: + return False + if self.exclude_from_weight_decay: + for r in self.exclude_from_weight_decay: + if re.search(r, param_name) is not None: + return False + return True + + def _do_layer_adaptation(self, param_name): + """Whether to do layer-wise learning rate adaptation for `param_name`.""" + if self.exclude_from_layer_adaptation: + for r in self.exclude_from_layer_adaptation: + if re.search(r, param_name) is not None: + return False + return True + + def _get_variable_name(self, param_name): + """Get the variable name from the tensor name.""" + m = re.match("^(.*):\\d+$", param_name) + if m is not None: + param_name = m.group(1) + return param_name diff --git a/run_classifier.py b/run_classifier.py index 817b147..e69e41d 100644 --- a/run_classifier.py +++ b/run_classifier.py @@ -25,6 +25,8 @@ import modeling import optimization import tokenization import tensorflow as tf +import horovod.tensorflow as hvd +from tensorflow.python import debug as tf_debug flags = tf.flags @@ -36,6 +38,16 @@ flags.DEFINE_string( "The input data dir. Should contain the .tsv files (or other data files) " "for the task.") +flags.DEFINE_string( + "validation_data_dir", None, + "The input validation data dir. Should contain the .tsv files (or other data files) " + "for the task.") + +flags.DEFINE_string( + "test_data_dir", None, + "The input test data dir. Should contain the .tsv files (or other data files) " + "for the task.") + flags.DEFINE_string( "bert_config_file", None, "The config json file corresponding to the pre-trained BERT model. " @@ -71,6 +83,8 @@ flags.DEFINE_bool("do_train", False, "Whether to run training.") flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.") +flags.DEFINE_bool("do_train_eval", False, "Whether to run train with eval.") + flags.DEFINE_bool( "do_predict", False, "Whether to run the model in inference mode on the test set.") @@ -91,7 +105,7 @@ flags.DEFINE_float( "Proportion of training to perform linear learning rate warmup for. " "E.g., 0.1 = 10% of training.") -flags.DEFINE_integer("save_checkpoints_steps", 1000, +flags.DEFINE_integer("save_checkpoints_steps", None, "How often to save the model checkpoint.") flags.DEFINE_integer("iterations_per_loop", 1000, @@ -99,31 +113,98 @@ flags.DEFINE_integer("iterations_per_loop", 1000, flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") -tf.flags.DEFINE_string( +flags.DEFINE_string( "tpu_name", None, "The Cloud TPU to use for training. This should be either the name " "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " "url.") -tf.flags.DEFINE_string( +flags.DEFINE_string( "tpu_zone", None, "[Optional] GCE zone where the Cloud TPU is located in. 
If not " "specified, we will attempt to automatically detect the GCE project from " "metadata.") -tf.flags.DEFINE_string( +flags.DEFINE_string( "gcp_project", None, "[Optional] Project name for the Cloud TPU-enabled project. If not " "specified, we will attempt to automatically detect the GCE project from " "metadata.") -tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") +flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") flags.DEFINE_integer( "num_tpu_cores", 8, "Only used if `use_tpu` is True. Total number of TPU cores to use.") +flags.DEFINE_bool( + "do_export", False, + "Whether to export the model.") + +flags.DEFINE_string( + "export_dir", None, + "The dir where the exported model will be written.") + +flags.DEFINE_string("label_list", None, "Label list.") + +flags.DEFINE_integer("bad_label_num", 1, "Bad label num.") + +flags.DEFINE_bool("add_header", False, "Add header.") + +flags.DEFINE_bool("use_tfrecord", False, "Use tfrecord.") + +flags.DEFINE_bool("use_validation_tfrecord", False, "Use validation tfrecord.") + +flags.DEFINE_bool("use_test_tfrecord", False, "Use test tfrecord.") + +flags.DEFINE_string("tfrecord_name", "train.tf_record", "tfrecord name.") + +flags.DEFINE_string("validation_tfrecord_name", "eval.tf_record", "validation tfrecord name.") + +flags.DEFINE_string("test_tfrecord_name", "predict.tf_record", "test tfrecord name.") + +flags.DEFINE_bool("clean_tfrecord", False, "Clean tfrecord.") + +flags.DEFINE_integer("train_examples_count", None, "Train examples count.") + +flags.DEFINE_integer("hooking_frequence", 100, "Hooking frequence.") + +flags.DEFINE_bool("reduce_log", False, "Reduce log.") + +flags.DEFINE_integer("keep_checkpoint_max", None, "Keep checkpoint max.") + +flags.DEFINE_bool("xla", True, "Whether to train with XLA optimization.") + +flags.DEFINE_bool("adjust_lr", True, "Whether to adjust learning_rate.") + +flags.DEFINE_integer("previous_train_steps", 0, "Previous train steps.") + +flags.DEFINE_integer("post_train_steps", 0, "Post train steps.") + +flags.DEFINE_bool("use_hvd", True, "Whether to use Horovod.") + +flags.DEFINE_bool("use_compression", True, "Whether to use compression in Horovod.") + +flags.DEFINE_bool("use_fp16", True, "Whether to use fp16.") + +flags.DEFINE_bool("cos_decay", False, "Whether to use cos decay.") + +flags.DEFINE_bool("use_lamb", False, "Whether to use lamb.") + +flags.DEFINE_bool("auto_recover", False, "Whether to use auto recover.") + +flags.DEFINE_string("recover_dir", None, "The output directory where the model checkpoints will be recovered.") + +flags.DEFINE_integer("ckpt_no", None, "Checkpoint number of model to be recovered.") + +flags.DEFINE_integer("ckpt_no_input", None, "Checkpoint number of input to be recovered.") + +flags.DEFINE_bool("clip", False, "Whether to use clip.") + +flags.DEFINE_bool("profile", False, "Whether to use profile.") + + class InputExample(object): """A single training/test example for simple sequence classification.""" @@ -165,11 +246,13 @@ class InputFeatures(object): input_ids, input_mask, segment_ids, + row_id, label_id, is_real_example=True): self.input_ids = input_ids self.input_mask = input_mask self.segment_ids = segment_ids + self.row_id = row_id self.label_id = label_id self.is_real_example = is_real_example @@ -203,6 +286,64 @@ class DataProcessor(object): lines.append(line) return lines + @classmethod + def _read_tsv_from_dir(cls, input_dir, quotechar=None): + """Reads a tab separated value file.""" + input_files = [input_dir + "/" + 
i for i in tf.gfile.ListDirectory(input_dir)] + lines = [] + for input_file in input_files: + with tf.gfile.Open(input_file, "r") as f: + reader = csv.reader((line.replace('\0', '') for line in f), delimiter="\t", quotechar=quotechar) + for line in reader: + lines.append(line) + return lines + + @classmethod + def _read_tsv_from_dir_by_name(cls, input_dir, quotechar=None, name='0'): + """Reads a single tab separated value file by name.""" + lines = [] + input_file = input_dir + "/" + name + with tf.gfile.Open(input_file, "r") as f: + reader = csv.reader((line.replace('\0', '') for line in f), delimiter="\t", quotechar=quotechar) + for line in reader: + lines.append(line) + return lines + + +class QKProcessor(DataProcessor): + """Processor for the QK data set (TSV rows of guid, label, text_a, text_b).""" + + def get_train_examples(self, data_dir): + """See base class.""" + # return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") + return self._create_examples(self._read_tsv_from_dir(data_dir), "train") + + def get_dev_examples(self, data_dir): + """See base class.""" + # return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") + return self._create_examples(self._read_tsv_from_dir(data_dir), "dev") + + def get_test_examples(self, data_dir): + """See base class.""" + # return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") + return self._create_examples(self._read_tsv_from_dir(data_dir), "test") + + def get_labels(self): + """See base class.""" + return ["0", "1"] + + def _create_examples(self, lines, set_type): + """Creates examples for the training, dev, and test sets.""" + examples = [] + for (i, line) in enumerate(lines): + guid = tokenization.convert_to_unicode(line[0]) + label = tokenization.convert_to_unicode(line[1]) + text_a = tokenization.convert_to_unicode(line[2]) + text_b = tokenization.convert_to_unicode(line[3]) + + examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) + return examples + class XnliProcessor(DataProcessor): """Processor for the XNLI data set.""" @@ -383,6 +524,7 @@ def convert_single_example(ex_index, example, label_list, max_seq_length, input_ids=[0] * max_seq_length, input_mask=[0] * max_seq_length, segment_ids=[0] * max_seq_length, + row_id=0, label_id=0, is_real_example=False) @@ -471,6 +613,7 @@ def convert_single_example(ex_index, example, label_list, max_seq_length, input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids, + row_id=int(example.guid), label_id=label_id, is_real_example=True) return feature @@ -497,6 +640,7 @@ def file_based_convert_examples_to_features( features["input_ids"] = create_int_feature(feature.input_ids) features["input_mask"] = create_int_feature(feature.input_mask) features["segment_ids"] = create_int_feature(feature.segment_ids) + features["input_rowid"] = create_int_feature([feature.row_id]) features["label_ids"] = create_int_feature([feature.label_id]) features["is_real_example"] = create_int_feature( [int(feature.is_real_example)]) @@ -507,13 +651,14 @@ def file_based_convert_examples_to_features( def file_based_input_fn_builder(input_file, seq_length, is_training, - drop_remainder): + drop_remainder, batch_size=None, use_hvd=True): """Creates an `input_fn` closure to be passed to TPUEstimator.""" name_to_features = { "input_ids": tf.FixedLenFeature([seq_length], tf.int64), "input_mask": tf.FixedLenFeature([seq_length], tf.int64), "segment_ids": tf.FixedLenFeature([seq_length], tf.int64), + "input_rowid": tf.FixedLenFeature([], 
tf.int64), "label_ids": tf.FixedLenFeature([], tf.int64), "is_real_example": tf.FixedLenFeature([], tf.int64), } @@ -526,7 +671,7 @@ def file_based_input_fn_builder(input_file, seq_length, is_training, # So cast all int64 to int32. for name in list(example.keys()): t = example[name] - if t.dtype == tf.int64: + if t.dtype == tf.int64 and name != "input_rowid": t = tf.to_int32(t) example[name] = t @@ -534,12 +679,17 @@ def file_based_input_fn_builder(input_file, seq_length, is_training, def input_fn(params): """The actual input function.""" - batch_size = params["batch_size"] + # batch_size = params["batch_size"] # For training, we want a lot of parallel reading and shuffling. # For eval, we want no shuffling and parallel reading doesn't matter. d = tf.data.TFRecordDataset(input_file) if is_training: + + if use_hvd: + d = d.shard(hvd.size(), hvd.rank()) # TODO: Horovod only; shard so each rank reads a distinct slice of the data + print("Data shard: %s %s" % (hvd.size(), hvd.rank())) + d = d.repeat() d = d.shuffle(buffer_size=100) @@ -572,7 +722,7 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length): def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, - labels, num_labels, use_one_hot_embeddings): + labels, num_labels, use_one_hot_embeddings, use_fp16, clip): """Creates a classification model.""" model = modeling.BertModel( config=bert_config, @@ -580,7 +730,8 @@ def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, input_ids=input_ids, input_mask=input_mask, token_type_ids=segment_ids, - use_one_hot_embeddings=use_one_hot_embeddings) + use_one_hot_embeddings=use_one_hot_embeddings, + compute_type=tf.float16 if use_fp16 else tf.float32) # In the demo, we are doing a simple classification task on the entire # segment.
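create_model threads compute_type down to BertModel, which installs the float32_variable_storage_getter from gpu_environment.py as its variable-scope custom getter. A minimal sketch of the pattern in isolation (fp32 master weights, fp16 compute; scope name is illustrative):

import tensorflow as tf
from gpu_environment import get_custom_getter

with tf.variable_scope("demo", custom_getter=get_custom_getter(tf.float16)):
  x = tf.random_normal([8, 16], dtype=tf.float16)
  # The getter stores `w` as a float32 master weight and hands back
  # a float16 cast of it, so the matmul below runs in half precision.
  w = tf.get_variable("w", shape=[16, 4], dtype=tf.float16)
  y = tf.matmul(x, w)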
@@ -605,20 +756,30 @@ def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, logits = tf.matmul(output_layer, output_weights, transpose_b=True) logits = tf.nn.bias_add(logits, output_bias) - probabilities = tf.nn.softmax(logits, axis=-1) - log_probs = tf.nn.log_softmax(logits, axis=-1) + if clip: + probabilities = tf.clip_by_value(tf.nn.softmax(logits, axis=-1), 1e-6, 1.0 - 1e-6) + log_probs = tf.log(probabilities) + else: + probabilities = tf.nn.softmax(logits, axis=-1) + log_probs = tf.nn.log_softmax(logits, axis=-1) one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32) per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) loss = tf.reduce_mean(per_example_loss) + p0 = tf.reduce_sum(probabilities[:, 0:FLAGS.bad_label_num], axis=-1) + p1 = tf.subtract(1.0, p0) + probabilities = tf.stack([p0, p1], axis=-1) + return (loss, per_example_loss, logits, probabilities) def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate, num_train_steps, num_warmup_steps, use_tpu, - use_one_hot_embeddings): + use_one_hot_embeddings, adjust_lr, use_hvd, + use_compression, use_fp16, clip, cos_decay, + use_lamb, previous_train_steps, post_train_steps): """Returns `model_fn` closure for TPUEstimator.""" def model_fn(features, labels, mode, params): # pylint: disable=unused-argument @@ -631,6 +792,7 @@ def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate, input_ids = features["input_ids"] input_mask = features["input_mask"] segment_ids = features["segment_ids"] + input_rowid = features["input_rowid"] label_ids = features["label_ids"] is_real_example = None if "is_real_example" in features: @@ -642,7 +804,7 @@ def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate, (total_loss, per_example_loss, logits, probabilities) = create_model( bert_config, is_training, input_ids, input_mask, segment_ids, label_ids, - num_labels, use_one_hot_embeddings) + num_labels, use_one_hot_embeddings, use_fp16, clip) tvars = tf.trainable_variables() initialized_variable_names = {} @@ -670,39 +832,42 @@ def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate, output_spec = None if mode == tf.estimator.ModeKeys.TRAIN: + train_op, update_learning_rate = optimization.create_optimizer( + total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu, adjust_lr, use_hvd, + use_compression, use_fp16, clip, cos_decay, use_lamb, previous_train_steps, post_train_steps) - train_op = optimization.create_optimizer( - total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) - - output_spec = tf.contrib.tpu.TPUEstimatorSpec( + logging_hook = tf.train.LoggingTensorHook({"loss": total_loss, "learning_rate": update_learning_rate}, every_n_iter=FLAGS.hooking_frequence) + output_spec = tf.estimator.EstimatorSpec( mode=mode, loss=total_loss, train_op=train_op, - scaffold_fn=scaffold_fn) + training_hooks=[logging_hook]) elif mode == tf.estimator.ModeKeys.EVAL: - def metric_fn(per_example_loss, label_ids, logits, is_real_example): + def metric_fn(per_example_loss, label_ids, logits, probabilities, is_real_example): predictions = tf.argmax(logits, axis=-1, output_type=tf.int32) accuracy = tf.metrics.accuracy( labels=label_ids, predictions=predictions, weights=is_real_example) loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example) + rocauc = tf.metrics.auc(labels=label_ids, predictions=probabilities[:, 1], curve="ROC", summation_method="careful_interpolation", 
weights=is_real_example) + prauc = tf.metrics.auc(labels=label_ids, predictions=probabilities[:, 1], curve="PR", summation_method="careful_interpolation", weights=is_real_example) return { "eval_accuracy": accuracy, "eval_loss": loss, + "rocauc": rocauc, + "prauc": prauc, } - eval_metrics = (metric_fn, - [per_example_loss, label_ids, logits, is_real_example]) - output_spec = tf.contrib.tpu.TPUEstimatorSpec( + eval_metrics = metric_fn( + per_example_loss, label_ids, logits, probabilities, is_real_example) + output_spec = tf.estimator.EstimatorSpec( mode=mode, loss=total_loss, - eval_metrics=eval_metrics, - scaffold_fn=scaffold_fn) + eval_metric_ops=eval_metrics) else: - output_spec = tf.contrib.tpu.TPUEstimatorSpec( + output_spec = tf.estimator.EstimatorSpec( mode=mode, - predictions={"probabilities": probabilities}, - scaffold_fn=scaffold_fn) + predictions={"probabilities": probabilities, "labels": label_ids, "rowids": input_rowid}) return output_spec return model_fn @@ -783,17 +948,26 @@ def convert_examples_to_features(examples, label_list, max_seq_length, def main(_): tf.logging.set_verbosity(tf.logging.INFO) + if FLAGS.use_hvd: + hvd.init() + + if FLAGS.reduce_log and (hvd.rank() != 0): + tf.logging.set_verbosity(tf.logging.ERROR) + + FLAGS.output_dir = FLAGS.output_dir if hvd.rank() == 0 else os.path.join(FLAGS.output_dir, str(hvd.rank())) + processors = { "cola": ColaProcessor, "mnli": MnliProcessor, "mrpc": MrpcProcessor, "xnli": XnliProcessor, + "qk": QKProcessor, } tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case, FLAGS.init_checkpoint) - if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict: + if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict and not FLAGS.do_train_eval and not FLAGS.do_export: raise ValueError( "At least one of `do_train`, `do_eval` or `do_predict' must be True.") @@ -807,6 +981,35 @@ def main(_): tf.gfile.MakeDirs(FLAGS.output_dir) + if FLAGS.recover_dir is not None: + if FLAGS.use_hvd: + FLAGS.recover_dir = FLAGS.recover_dir if hvd.rank() == 0 else os.path.join(FLAGS.recover_dir, str(hvd.rank())) + path_ckpt = os.path.join(FLAGS.output_dir, "checkpoint") + path_ckpt_input = os.path.join(FLAGS.output_dir, "checkpoint_input") + + if FLAGS.ckpt_no is not None and not tf.gfile.Exists(path_ckpt): + with tf.gfile.GFile(path_ckpt, "w") as writer: + writer.write('model_checkpoint_path: "%s-%s"\n' % (os.path.join(FLAGS.recover_dir, "model.ckpt"), str(FLAGS.ckpt_no))) + writer.write('all_model_checkpoint_paths: "%s-%s"\n' % (os.path.join(FLAGS.recover_dir, "model.ckpt"), str(FLAGS.ckpt_no))) + + if FLAGS.ckpt_no_input is not None and not tf.gfile.Exists(path_ckpt_input): + with tf.gfile.GFile(path_ckpt_input, "w") as writer: + writer.write('model_checkpoint_path: "%s-%s"\n' % (os.path.join(FLAGS.recover_dir, "input.ckpt"), str(FLAGS.ckpt_no_input))) + writer.write('all_model_checkpoint_paths: "%s-%s"\n' % (os.path.join(FLAGS.recover_dir, "input.ckpt"), str(FLAGS.ckpt_no_input))) + + if FLAGS.use_hvd and hvd.rank() == 0 and (FLAGS.do_train or FLAGS.do_train_eval): + (vpath, vname) = os.path.split(FLAGS.vocab_file) + tf.gfile.Copy(FLAGS.vocab_file, os.path.join(FLAGS.output_dir, vname), True) + + (cpath, cname) = os.path.split(FLAGS.bert_config_file) + tf.gfile.Copy(FLAGS.bert_config_file, os.path.join(FLAGS.output_dir, cname), True) + + if FLAGS.validation_data_dir is None: + FLAGS.validation_data_dir = FLAGS.data_dir + + if FLAGS.test_data_dir is None: + FLAGS.test_data_dir = FLAGS.validation_data_dir + task_name = 
FLAGS.task_name.lower() if task_name not in processors: @@ -816,33 +1019,59 @@ def main(_): label_list = processor.get_labels() + if FLAGS.label_list: + label_list = FLAGS.label_list.split(",") + tokenizer = tokenization.FullTokenizer( vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) - tpu_cluster_resolver = None - if FLAGS.use_tpu and FLAGS.tpu_name: - tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( - FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) - - is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 - run_config = tf.contrib.tpu.RunConfig( - cluster=tpu_cluster_resolver, - master=FLAGS.master, - model_dir=FLAGS.output_dir, - save_checkpoints_steps=FLAGS.save_checkpoints_steps, - tpu_config=tf.contrib.tpu.TPUConfig( - iterations_per_loop=FLAGS.iterations_per_loop, - num_shards=FLAGS.num_tpu_cores, - per_host_input_for_training=is_per_host)) - + num_gpu = 1 if not FLAGS.use_hvd else hvd.size() train_examples = None num_train_steps = None num_warmup_steps = None - if FLAGS.do_train: - train_examples = processor.get_train_examples(FLAGS.data_dir) - num_train_steps = int( - len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs) - num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion) + if FLAGS.do_train or FLAGS.do_train_eval: + if FLAGS.use_tfrecord: + if FLAGS.train_examples_count is None: + FLAGS.train_examples_count = 0 + for record in tf.python_io.tf_record_iterator(os.path.join(FLAGS.data_dir, FLAGS.tfrecord_name)): + FLAGS.train_examples_count += 1 + + num_train_steps = int( + FLAGS.train_examples_count / (FLAGS.train_batch_size * num_gpu) * FLAGS.num_train_epochs) + else: + train_examples = processor.get_train_examples(FLAGS.data_dir) + num_train_steps = int( + len(train_examples) / (FLAGS.train_batch_size * num_gpu) * FLAGS.num_train_epochs) + num_warmup_steps = int((num_train_steps + FLAGS.previous_train_steps + FLAGS.post_train_steps) * FLAGS.warmup_proportion) + + if FLAGS.save_checkpoints_steps is None: + FLAGS.save_checkpoints_steps = 1000 if num_train_steps is None else int(num_train_steps / FLAGS.num_train_epochs) + + if FLAGS.keep_checkpoint_max is None: + FLAGS.keep_checkpoint_max = int(FLAGS.num_train_epochs + 1.0) + + config = tf.ConfigProto() + if FLAGS.xla: + config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1 + if FLAGS.use_hvd: + config.gpu_options.visible_device_list = str(hvd.local_rank()) + config.gpu_options.allow_growth=True + + run_config = tf.estimator.RunConfig( + model_dir=FLAGS.output_dir, + keep_checkpoint_max=FLAGS.keep_checkpoint_max, + save_checkpoints_steps=FLAGS.save_checkpoints_steps, + log_step_count_steps=FLAGS.hooking_frequence, + session_config=config) + + if FLAGS.use_hvd and hvd.rank() != 0 and not FLAGS.auto_recover: + run_config = tf.estimator.RunConfig( + model_dir=FLAGS.output_dir, + keep_checkpoint_max=FLAGS.keep_checkpoint_max, + save_checkpoints_steps=None, + save_checkpoints_secs=None, + log_step_count_steps=FLAGS.hooking_frequence, + session_config=config) model_fn = model_fn_builder( bert_config=bert_config, @@ -852,53 +1081,95 @@ def main(_): num_train_steps=num_train_steps, num_warmup_steps=num_warmup_steps, use_tpu=FLAGS.use_tpu, - use_one_hot_embeddings=FLAGS.use_tpu) + use_one_hot_embeddings=FLAGS.use_tpu, + adjust_lr=FLAGS.adjust_lr, + use_hvd=FLAGS.use_hvd, + use_compression=FLAGS.use_compression, + use_fp16=FLAGS.use_fp16, + clip=FLAGS.clip, + cos_decay=FLAGS.cos_decay, + use_lamb=FLAGS.use_lamb, + 
previous_train_steps=FLAGS.previous_train_steps, + post_train_steps=FLAGS.post_train_steps) + + hooks = [] + + if FLAGS.use_hvd: + hooks.append(hvd.BroadcastGlobalVariablesHook(0)) + + if hvd.rank() == -1: # for interactive debugging, change -1 to 0 + CLIDebugHook = tf_debug.LocalCLIDebugHook(ui_type='readline') + CLIDebugHook.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan) + hooks.append(CLIDebugHook) + + if FLAGS.profile and hvd.rank() == 0: + ProfilerHook = tf.train.ProfilerHook(save_steps=FLAGS.hooking_frequence, output_dir=FLAGS.output_dir, show_dataflow=True, show_memory=True) + hooks.append(ProfilerHook) # If TPU is not available, this will fall back to normal Estimator on CPU # or GPU. - estimator = tf.contrib.tpu.TPUEstimator( - use_tpu=FLAGS.use_tpu, + estimator = tf.estimator.Estimator( model_fn=model_fn, - config=run_config, - train_batch_size=FLAGS.train_batch_size, - eval_batch_size=FLAGS.eval_batch_size, - predict_batch_size=FLAGS.predict_batch_size) + config=run_config) if FLAGS.do_train: - train_file = os.path.join(FLAGS.output_dir, "train.tf_record") - file_based_convert_examples_to_features( - train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file) + if FLAGS.use_tfrecord: + train_file = os.path.join(FLAGS.data_dir, FLAGS.tfrecord_name) + else: + train_file = os.path.join(FLAGS.output_dir, "train.tf_record") + file_based_convert_examples_to_features( + train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file) tf.logging.info("***** Running training *****") - tf.logging.info(" Num examples = %d", len(train_examples)) + tf.logging.info(" Num examples = %d", FLAGS.train_examples_count if train_examples is None else len(train_examples)) tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) tf.logging.info(" Num steps = %d", num_train_steps) train_input_fn = file_based_input_fn_builder( input_file=train_file, seq_length=FLAGS.max_seq_length, is_training=True, - drop_remainder=True) - estimator.train(input_fn=train_input_fn, max_steps=num_train_steps) + drop_remainder=True, + batch_size=FLAGS.train_batch_size, + use_hvd=FLAGS.use_hvd) + + if FLAGS.auto_recover: + hooks.append(tf.data.experimental.CheckpointInputPipelineHook(estimator)) + + estimator.train(input_fn=train_input_fn, max_steps=num_train_steps, hooks=hooks) + + if FLAGS.clean_tfrecord and tf.gfile.Exists(train_file): + tf.gfile.Remove(train_file) if FLAGS.do_eval: - eval_examples = processor.get_dev_examples(FLAGS.data_dir) - num_actual_eval_examples = len(eval_examples) - if FLAGS.use_tpu: - # TPU requires a fixed batch size for all batches, therefore the number - # of examples must be a multiple of the batch size, or else examples - # will get dropped. So we pad with fake examples which are ignored - # later on. These do NOT count towards the metric (all tf.metrics - # support a per-instance weight, and these get a weight of 0.0).
- while len(eval_examples) % FLAGS.eval_batch_size != 0: - eval_examples.append(PaddingInputExample()) + if FLAGS.use_validation_tfrecord: + num_actual_eval_examples = 0 + for record in tf.python_io.tf_record_iterator(os.path.join(FLAGS.validation_data_dir, FLAGS.validation_tfrecord_name)): + num_actual_eval_examples += 1 - eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record") - file_based_convert_examples_to_features( - eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file) + validation_examples_count = num_actual_eval_examples + + eval_file = os.path.join(FLAGS.validation_data_dir, FLAGS.validation_tfrecord_name) + else: + eval_examples = processor.get_dev_examples(FLAGS.validation_data_dir) + num_actual_eval_examples = len(eval_examples) + if FLAGS.use_tpu: + # TPU requires a fixed batch size for all batches, therefore the number + # of examples must be a multiple of the batch size, or else examples + # will get dropped. So we pad with fake examples which are ignored + # later on. These do NOT count towards the metric (all tf.metrics + # support a per-instance weight, and these get a weight of 0.0). + while len(eval_examples) % FLAGS.eval_batch_size != 0: + eval_examples.append(PaddingInputExample()) + + validation_examples_count = len(eval_examples) + + eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record") + file_based_convert_examples_to_features( + eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file) tf.logging.info("***** Running evaluation *****") tf.logging.info(" Num examples = %d (%d actual, %d padding)", - len(eval_examples), num_actual_eval_examples, - len(eval_examples) - num_actual_eval_examples) + validation_examples_count, num_actual_eval_examples, + validation_examples_count - num_actual_eval_examples) tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size) # This tells the estimator to run through the entire set. @@ -906,15 +1177,17 @@ def main(_): # However, if running eval on the TPU, you will need to specify the # number of steps. if FLAGS.use_tpu: - assert len(eval_examples) % FLAGS.eval_batch_size == 0 - eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size) + assert validation_examples_count % FLAGS.eval_batch_size == 0 + eval_steps = int(validation_examples_count // FLAGS.eval_batch_size) eval_drop_remainder = True if FLAGS.use_tpu else False eval_input_fn = file_based_input_fn_builder( input_file=eval_file, seq_length=FLAGS.max_seq_length, is_training=False, - drop_remainder=eval_drop_remainder) + drop_remainder=eval_drop_remainder, + batch_size=FLAGS.eval_batch_size, + use_hvd=FLAGS.use_hvd) result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps) @@ -925,26 +1198,40 @@ def main(_): tf.logging.info(" %s = %s", key, str(result[key])) writer.write("%s = %s\n" % (key, str(result[key]))) - if FLAGS.do_predict: - predict_examples = processor.get_test_examples(FLAGS.data_dir) - num_actual_predict_examples = len(predict_examples) - if FLAGS.use_tpu: - # TPU requires a fixed batch size for all batches, therefore the number - # of examples must be a multiple of the batch size, or else examples - # will get dropped. So we pad with fake examples which are ignored - # later on. 
- while len(predict_examples) % FLAGS.predict_batch_size != 0: - predict_examples.append(PaddingInputExample()) + if FLAGS.clean_tfrecord and tf.gfile.Exists(eval_file): + tf.gfile.Remove(eval_file) - predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record") - file_based_convert_examples_to_features(predict_examples, label_list, - FLAGS.max_seq_length, tokenizer, - predict_file) + if FLAGS.do_predict: + if FLAGS.use_test_tfrecord: + num_actual_predict_examples = 0 + for record in tf.python_io.tf_record_iterator(os.path.join(FLAGS.test_data_dir, FLAGS.test_tfrecord_name)): + num_actual_predict_examples += 1 + + test_examples_count = num_actual_predict_examples + + predict_file = os.path.join(FLAGS.test_data_dir, FLAGS.test_tfrecord_name) + else: + predict_examples = processor.get_test_examples(FLAGS.test_data_dir) + num_actual_predict_examples = len(predict_examples) + if FLAGS.use_tpu: + # TPU requires a fixed batch size for all batches, therefore the number + # of examples must be a multiple of the batch size, or else examples + # will get dropped. So we pad with fake examples which are ignored + # later on. + while len(predict_examples) % FLAGS.predict_batch_size != 0: + predict_examples.append(PaddingInputExample()) + + test_examples_count = len(predict_examples) + + predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record") + file_based_convert_examples_to_features(predict_examples, label_list, + FLAGS.max_seq_length, tokenizer, + predict_file) tf.logging.info("***** Running prediction*****") tf.logging.info(" Num examples = %d (%d actual, %d padding)", - len(predict_examples), num_actual_predict_examples, - len(predict_examples) - num_actual_predict_examples) + test_examples_count, num_actual_predict_examples, + test_examples_count - num_actual_predict_examples) tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size) predict_drop_remainder = True if FLAGS.use_tpu else False @@ -952,7 +1239,9 @@ def main(_): input_file=predict_file, seq_length=FLAGS.max_seq_length, is_training=False, - drop_remainder=predict_drop_remainder) + drop_remainder=predict_drop_remainder, + batch_size=FLAGS.predict_batch_size, + use_hvd=FLAGS.use_hvd) result = estimator.predict(input_fn=predict_input_fn) @@ -960,20 +1249,123 @@ def main(_): with tf.gfile.GFile(output_predict_file, "w") as writer: num_written_lines = 0 tf.logging.info("***** Predict results *****") + if FLAGS.add_header: + writer.write("rowids\tprobabilities\tlabels\n") for (i, prediction) in enumerate(result): probabilities = prediction["probabilities"] if i >= num_actual_predict_examples: break - output_line = "\t".join( - str(class_probability) - for class_probability in probabilities) + "\n" + output_line = str(prediction["rowids"]) + "\t" + str(probabilities[1]) + "\t" + str(prediction["labels"]) + "\n" writer.write(output_line) num_written_lines += 1 assert num_written_lines == num_actual_predict_examples + if FLAGS.clean_tfrecord and tf.gfile.Exists(predict_file): + tf.gfile.Remove(predict_file) + + if FLAGS.do_train_eval: + if FLAGS.use_tfrecord: + train_file = os.path.join(FLAGS.data_dir, FLAGS.tfrecord_name) + else: + train_file = os.path.join(FLAGS.output_dir, "train.tf_record") + file_based_convert_examples_to_features( + train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file) + tf.logging.info("***** Running training *****") + tf.logging.info(" Num examples = %d", FLAGS.train_examples_count if train_examples is None else len(train_examples)) + tf.logging.info(" Batch size = %d", 
FLAGS.train_batch_size) + tf.logging.info(" Num steps = %d", num_train_steps) + train_input_fn = file_based_input_fn_builder( + input_file=train_file, + seq_length=FLAGS.max_seq_length, + is_training=True, + drop_remainder=True, + batch_size=FLAGS.train_batch_size, + use_hvd=FLAGS.use_hvd) + + if FLAGS.use_validation_tfrecord: + num_actual_eval_examples = 0 + for record in tf.python_io.tf_record_iterator(os.path.join(FLAGS.validation_data_dir, FLAGS.validation_tfrecord_name)): + num_actual_eval_examples += 1 + + validation_examples_count = num_actual_eval_examples + + eval_file = os.path.join(FLAGS.validation_data_dir, FLAGS.validation_tfrecord_name) + else: + eval_examples = processor.get_dev_examples(FLAGS.validation_data_dir) + num_actual_eval_examples = len(eval_examples) + if FLAGS.use_tpu: + # TPU requires a fixed batch size for all batches, therefore the number + # of examples must be a multiple of the batch size, or else examples + # will get dropped. So we pad with fake examples which are ignored + # later on. These do NOT count towards the metric (all tf.metrics + # support a per-instance weight, and these get a weight of 0.0). + while len(eval_examples) % FLAGS.eval_batch_size != 0: + eval_examples.append(PaddingInputExample()) + + validation_examples_count = len(eval_examples) + + eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record") + file_based_convert_examples_to_features( + eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file) + + tf.logging.info("***** Running evaluation *****") + tf.logging.info(" Num examples = %d (%d actual, %d padding)", + validation_examples_count, num_actual_eval_examples, + validation_examples_count - num_actual_eval_examples) + tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size) + + # This tells the estimator to run through the entire set. + eval_steps = None + # However, if running eval on the TPU, you will need to specify the + # number of steps. 
+ if FLAGS.use_tpu: + assert validation_examples_count % FLAGS.eval_batch_size == 0 + eval_steps = int(validation_examples_count // FLAGS.eval_batch_size) + + eval_drop_remainder = True if FLAGS.use_tpu else False + eval_input_fn = file_based_input_fn_builder( + input_file=eval_file, + seq_length=FLAGS.max_seq_length, + is_training=False, + drop_remainder=eval_drop_remainder, + batch_size=FLAGS.eval_batch_size, + use_hvd=FLAGS.use_hvd) + + if FLAGS.auto_recover: + hooks.append(tf.data.experimental.CheckpointInputPipelineHook(estimator)) + + train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=num_train_steps, hooks=hooks) + eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn, steps=eval_steps) + tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) + + if FLAGS.clean_tfrecord and tf.gfile.Exists(train_file): + tf.gfile.Remove(train_file) + + if FLAGS.clean_tfrecord and tf.gfile.Exists(eval_file): + tf.gfile.Remove(eval_file) + + if FLAGS.do_export: + def serving_input_fn(): + input_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_ids') + input_mask = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_mask') + segment_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='segment_ids') + input_rowid = tf.placeholder(tf.int64, [None], name='input_rowid') + label_ids = tf.placeholder(tf.int32, [None], name='label_ids') + input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({ + 'input_ids': input_ids, + 'input_mask': input_mask, + 'segment_ids': segment_ids, + 'input_rowid': input_rowid, + 'label_ids': label_ids, + })() + return input_fn + + estimator._export_to_tpu = False + estimator.export_savedmodel(FLAGS.export_dir, serving_input_fn) + if __name__ == "__main__": - flags.mark_flag_as_required("data_dir") + # flags.mark_flag_as_required("data_dir") flags.mark_flag_as_required("task_name") flags.mark_flag_as_required("vocab_file") flags.mark_flag_as_required("bert_config_file") diff --git a/run_pretraining.py b/run_pretraining.py index b118f62..791d685 100644 --- a/run_pretraining.py +++ b/run_pretraining.py @@ -22,6 +22,8 @@ import os import modeling import optimization import tensorflow as tf +import horovod.tensorflow as hvd +from tensorflow.python import debug as tf_debug flags = tf.flags @@ -37,6 +39,18 @@ flags.DEFINE_string( "input_file", None, "Input TF example files (can be a glob or comma separated).") +flags.DEFINE_string( + "validation_input_file", None, + "Input validation TF example files (can be a glob or comma separated).") + +flags.DEFINE_string( + "input_dir", None, + "Input TF example dir.") + +flags.DEFINE_string( + "validation_input_dir", None, + "Input validation TF example dir.") + flags.DEFINE_string( "output_dir", None, "The output directory where the model checkpoints will be written.") @@ -61,6 +75,8 @@ flags.DEFINE_bool("do_train", False, "Whether to run training.") flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.") +flags.DEFINE_bool("do_train_eval", False, "Whether to run train with eval.") + flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.") flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.") @@ -77,7 +93,7 @@ flags.DEFINE_integer("save_checkpoints_steps", 1000, flags.DEFINE_integer("iterations_per_loop", 1000, "How many steps to make in each estimator call.") -flags.DEFINE_integer("max_eval_steps", 100, "Maximum number of eval steps.") +flags.DEFINE_integer("max_eval_steps", None, "Maximum 
number of eval steps.") flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") @@ -106,9 +122,48 @@ flags.DEFINE_integer( "Only used if `use_tpu` is True. Total number of TPU cores to use.") +flags.DEFINE_integer("hooking_frequence", 100, "Hooking frequence.") + +flags.DEFINE_bool("reduce_log", False, "Reduce log.") + +flags.DEFINE_integer("keep_checkpoint_max", 1, "Keep checkpoint max.") + +flags.DEFINE_bool("xla", True, "Whether to train with XLA optimization.") + +flags.DEFINE_bool("adjust_lr", True, "Whether to adjust learning_rate.") + +flags.DEFINE_integer("previous_train_steps", 0, "Previous train steps.") + +flags.DEFINE_integer("post_train_steps", 0, "Post train steps.") + +flags.DEFINE_bool("use_hvd", True, "Whether to use Horovod.") + +flags.DEFINE_bool("use_compression", True, "Whether to use compression in Horovod.") + +flags.DEFINE_bool("use_fp16", True, "Whether to use fp16.") + +flags.DEFINE_bool("cos_decay", False, "Whether to use cos decay.") + +flags.DEFINE_bool("use_lamb", False, "Whether to use lamb.") + +flags.DEFINE_bool("auto_recover", False, "Whether to use auto recover.") + +flags.DEFINE_string("recover_dir", None, "The output directory where the model checkpoints will be recovered.") + +flags.DEFINE_integer("ckpt_no", None, "Checkpoint number of model to be recovered.") + +flags.DEFINE_integer("ckpt_no_input", None, "Checkpoint number of input to be recovered.") + +flags.DEFINE_bool("clip", False, "Whether to use clip.") + +flags.DEFINE_bool("profile", False, "Whether to use profile.") + + def model_fn_builder(bert_config, init_checkpoint, learning_rate, num_train_steps, num_warmup_steps, use_tpu, - use_one_hot_embeddings): + use_one_hot_embeddings, adjust_lr, use_hvd, + use_compression, use_fp16, clip, cos_decay, + use_lamb, previous_train_steps, post_train_steps): """Returns `model_fn` closure for TPUEstimator.""" def model_fn(features, labels, mode, params): # pylint: disable=unused-argument @@ -134,16 +189,17 @@ def model_fn_builder(bert_config, init_checkpoint, learning_rate, input_ids=input_ids, input_mask=input_mask, token_type_ids=segment_ids, - use_one_hot_embeddings=use_one_hot_embeddings) + use_one_hot_embeddings=use_one_hot_embeddings, + compute_type=tf.float16 if use_fp16 else tf.float32) (masked_lm_loss, - masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output( - bert_config, model.get_sequence_output(), model.get_embedding_table(), - masked_lm_positions, masked_lm_ids, masked_lm_weights) + masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output( + bert_config, model.get_sequence_output(), model.get_embedding_table(), + masked_lm_positions, masked_lm_ids, masked_lm_weights, clip) (next_sentence_loss, next_sentence_example_loss, - next_sentence_log_probs) = get_next_sentence_output( - bert_config, model.get_pooled_output(), next_sentence_labels) + next_sentence_log_probs) = get_next_sentence_output( + bert_config, model.get_pooled_output(), next_sentence_labels, clip) total_loss = masked_lm_loss + next_sentence_loss @@ -174,14 +230,16 @@ def model_fn_builder(bert_config, init_checkpoint, learning_rate, output_spec = None if mode == tf.estimator.ModeKeys.TRAIN: - train_op = optimization.create_optimizer( - total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) + train_op, update_learning_rate = optimization.create_optimizer( + total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu, adjust_lr, use_hvd, + use_compression, use_fp16, clip, cos_decay, use_lamb, previous_train_steps, 
post_train_steps) - output_spec = tf.contrib.tpu.TPUEstimatorSpec( + logging_hook = tf.train.LoggingTensorHook({"loss": total_loss, "learning_rate": update_learning_rate}, every_n_iter=FLAGS.hooking_frequence) + output_spec = tf.estimator.EstimatorSpec( mode=mode, loss=total_loss, train_op=train_op, - scaffold_fn=scaffold_fn) + training_hooks=[logging_hook]) elif mode == tf.estimator.ModeKeys.EVAL: def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids, @@ -219,16 +277,15 @@ def model_fn_builder(bert_config, init_checkpoint, learning_rate, "next_sentence_loss": next_sentence_mean_loss, } - eval_metrics = (metric_fn, [ + eval_metrics = metric_fn( masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids, masked_lm_weights, next_sentence_example_loss, next_sentence_log_probs, next_sentence_labels - ]) - output_spec = tf.contrib.tpu.TPUEstimatorSpec( + ) + output_spec = tf.estimator.EstimatorSpec( mode=mode, loss=total_loss, - eval_metrics=eval_metrics, - scaffold_fn=scaffold_fn) + eval_metric_ops=eval_metrics) else: raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode)) @@ -238,7 +295,7 @@ def model_fn_builder(bert_config, init_checkpoint, learning_rate, def get_masked_lm_output(bert_config, input_tensor, output_weights, positions, - label_ids, label_weights): + label_ids, label_weights, clip): """Get loss and log probs for the masked LM.""" input_tensor = gather_indexes(input_tensor, positions) @@ -262,7 +319,10 @@ def get_masked_lm_output(bert_config, input_tensor, output_weights, positions, initializer=tf.zeros_initializer()) logits = tf.matmul(input_tensor, output_weights, transpose_b=True) logits = tf.nn.bias_add(logits, output_bias) - log_probs = tf.nn.log_softmax(logits, axis=-1) + if clip: + log_probs = tf.log(tf.clip_by_value(tf.nn.softmax(logits, axis=-1), 1e-6, 1.0 - 1e-6)) + else: + log_probs = tf.nn.log_softmax(logits, axis=-1) label_ids = tf.reshape(label_ids, [-1]) label_weights = tf.reshape(label_weights, [-1]) @@ -282,7 +342,7 @@ def get_masked_lm_output(bert_config, input_tensor, output_weights, positions, return (loss, per_example_loss, log_probs) -def get_next_sentence_output(bert_config, input_tensor, labels): +def get_next_sentence_output(bert_config, input_tensor, labels, clip): """Get loss and log probs for the next sentence prediction.""" # Simple binary classification. Note that 0 is "next sentence" and 1 is @@ -297,7 +357,10 @@ def get_next_sentence_output(bert_config, input_tensor, labels): logits = tf.matmul(input_tensor, output_weights, transpose_b=True) logits = tf.nn.bias_add(logits, output_bias) - log_probs = tf.nn.log_softmax(logits, axis=-1) + if clip: + log_probs = tf.log(tf.clip_by_value(tf.nn.softmax(logits, axis=-1), 1e-6, 1.0 - 1e-6)) + else: + log_probs = tf.nn.log_softmax(logits, axis=-1) labels = tf.reshape(labels, [-1]) one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32) per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) @@ -325,12 +388,14 @@ def input_fn_builder(input_files, max_seq_length, max_predictions_per_seq, is_training, - num_cpu_threads=4): + num_cpu_threads=4, + batch_size=None, + use_hvd=True): """Creates an `input_fn` closure to be passed to TPUEstimator.""" def input_fn(params): """The actual input function.""" - batch_size = params["batch_size"] + # batch_size = params["batch_size"] name_to_features = { "input_ids": @@ -353,6 +418,11 @@ def input_fn_builder(input_files, # For eval, we want no shuffling and parallel reading doesn't matter. 
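The clip branch above is the fp16-stability counterpart in the loss: clamp the softmax away from 0 before taking the log, so underflow can never produce log(0) = -inf in the cross-entropy. As a standalone sketch:

def clipped_log_probs(logits, eps=1e-6):
  # Equivalent to log_softmax except probabilities are first clipped
  # to [eps, 1 - eps], trading a small bias for NaN/Inf safety.
  probs = tf.clip_by_value(tf.nn.softmax(logits, axis=-1), eps, 1.0 - eps)
  return tf.log(probs)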
     if is_training:
       d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))
+
+      if use_hvd:
+        d = d.shard(hvd.size(), hvd.rank())  # Horovod: shard files so each rank reads a distinct subset (mimics single_GPU = False).
+        tf.logging.info("Data shard: %s %s", hvd.size(), hvd.rank())
+
       d = d.repeat()
       d = d.shuffle(buffer_size=len(input_files))
@@ -371,7 +441,7 @@ def input_fn_builder(input_files,
       d = tf.data.TFRecordDataset(input_files)
-      # Since we evaluate for a fixed number of steps we don't want to encounter
-      # out-of-range exceptions.
-      d = d.repeat()
+      # Do not repeat the eval dataset, so each evaluation pass sees every
+      # record at most once.
+      # d = d.repeat()
 
       # We must `drop_remainder` on training because the TPU requires fixed
       # size dimensions. For eval, we assume we are evaluating on the CPU or GPU
@@ -406,36 +476,90 @@ def _decode_record(record, name_to_features):
 
 def main(_):
   tf.logging.set_verbosity(tf.logging.INFO)
 
-  if not FLAGS.do_train and not FLAGS.do_eval:
+  if FLAGS.use_hvd:
+    hvd.init()
+
+    if FLAGS.reduce_log and (hvd.rank() != 0):
+      tf.logging.set_verbosity(tf.logging.ERROR)
+
+    FLAGS.output_dir = FLAGS.output_dir if hvd.rank() == 0 else os.path.join(FLAGS.output_dir, str(hvd.rank()))
+
+  if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_train_eval:
-    raise ValueError("At least one of `do_train` or `do_eval` must be True.")
+    raise ValueError("At least one of `do_train`, `do_eval` or `do_train_eval` must be True.")
 
   bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
 
   tf.gfile.MakeDirs(FLAGS.output_dir)
 
+  if FLAGS.recover_dir is not None:
+    if FLAGS.use_hvd:
+      FLAGS.recover_dir = FLAGS.recover_dir if hvd.rank() == 0 else os.path.join(FLAGS.recover_dir, str(hvd.rank()))
+    path_ckpt = os.path.join(FLAGS.output_dir, "checkpoint")
+    path_ckpt_input = os.path.join(FLAGS.output_dir, "checkpoint_input")
+
+    if FLAGS.ckpt_no is not None and not tf.gfile.Exists(path_ckpt):
+      with tf.gfile.GFile(path_ckpt, "w") as writer:
+        writer.write('model_checkpoint_path: "%s-%s"\n' % (os.path.join(FLAGS.recover_dir, "model.ckpt"), str(FLAGS.ckpt_no)))
+        writer.write('all_model_checkpoint_paths: "%s-%s"\n' % (os.path.join(FLAGS.recover_dir, "model.ckpt"), str(FLAGS.ckpt_no)))
+
+    if FLAGS.ckpt_no_input is not None and not tf.gfile.Exists(path_ckpt_input):
+      with tf.gfile.GFile(path_ckpt_input, "w") as writer:
+        writer.write('model_checkpoint_path: "%s-%s"\n' % (os.path.join(FLAGS.recover_dir, "input.ckpt"), str(FLAGS.ckpt_no_input)))
+        writer.write('all_model_checkpoint_paths: "%s-%s"\n' % (os.path.join(FLAGS.recover_dir, "input.ckpt"), str(FLAGS.ckpt_no_input)))
+
+  if FLAGS.use_hvd and hvd.rank() == 0 and (FLAGS.do_train or FLAGS.do_train_eval):
+    (cpath, cname) = os.path.split(FLAGS.bert_config_file)
+    tf.gfile.Copy(FLAGS.bert_config_file, os.path.join(FLAGS.output_dir, cname), True)
+
   input_files = []
-  for input_pattern in FLAGS.input_file.split(","):
-    input_files.extend(tf.gfile.Glob(input_pattern))
+  if FLAGS.input_file is not None:
+    for input_pattern in FLAGS.input_file.split(","):
+      input_files.extend(tf.gfile.Glob(input_pattern))
+  if FLAGS.input_dir is not None:
+    for filename in tf.gfile.ListDirectory(FLAGS.input_dir):
+      input_files.extend(tf.gfile.Glob(os.path.join(FLAGS.input_dir, filename)))
 
   tf.logging.info("*** Input Files ***")
   for input_file in input_files:
     tf.logging.info("  %s" % input_file)
 
-  tpu_cluster_resolver = None
-  if FLAGS.use_tpu and FLAGS.tpu_name:
-    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
-        FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
+  validation_input_files = []
+  if FLAGS.validation_input_file is None and FLAGS.validation_input_dir is None:
+    validation_input_files = input_files
+  else:
+    if FLAGS.validation_input_file is not None:
+      for input_pattern in FLAGS.validation_input_file.split(","):
+        validation_input_files.extend(tf.gfile.Glob(input_pattern))
+    if FLAGS.validation_input_dir is not None:
+      for filename in tf.gfile.ListDirectory(FLAGS.validation_input_dir):
+        validation_input_files.extend(tf.gfile.Glob(os.path.join(FLAGS.validation_input_dir, filename)))
 
-  is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
-  run_config = tf.contrib.tpu.RunConfig(
-      cluster=tpu_cluster_resolver,
-      master=FLAGS.master,
+  tf.logging.info("*** Input Validation Files ***")
+  for input_file in validation_input_files:
+    tf.logging.info("  %s" % input_file)
+
+  config = tf.ConfigProto()
+  if FLAGS.xla:
+    config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
+  if FLAGS.use_hvd:
+    config.gpu_options.visible_device_list = str(hvd.local_rank())
+    config.gpu_options.allow_growth = True
+
+  run_config = tf.estimator.RunConfig(
       model_dir=FLAGS.output_dir,
+      keep_checkpoint_max=FLAGS.keep_checkpoint_max,
       save_checkpoints_steps=FLAGS.save_checkpoints_steps,
-      tpu_config=tf.contrib.tpu.TPUConfig(
-          iterations_per_loop=FLAGS.iterations_per_loop,
-          num_shards=FLAGS.num_tpu_cores,
-          per_host_input_for_training=is_per_host))
+      log_step_count_steps=FLAGS.hooking_frequence,
+      session_config=config)
+
+  if FLAGS.use_hvd and hvd.rank() != 0 and not FLAGS.auto_recover:
+    run_config = tf.estimator.RunConfig(
+        model_dir=FLAGS.output_dir,
+        keep_checkpoint_max=FLAGS.keep_checkpoint_max,
+        save_checkpoints_steps=None,
+        save_checkpoints_secs=None,
+        log_step_count_steps=FLAGS.hooking_frequence,
+        session_config=config)
 
   model_fn = model_fn_builder(
       bert_config=bert_config,
@@ -444,16 +568,36 @@ def main(_):
       num_train_steps=FLAGS.num_train_steps,
       num_warmup_steps=FLAGS.num_warmup_steps,
       use_tpu=FLAGS.use_tpu,
-      use_one_hot_embeddings=FLAGS.use_tpu)
+      use_one_hot_embeddings=FLAGS.use_tpu,
+      adjust_lr=FLAGS.adjust_lr,
+      use_hvd=FLAGS.use_hvd,
+      use_compression=FLAGS.use_compression,
+      use_fp16=FLAGS.use_fp16,
+      clip=FLAGS.clip,
+      cos_decay=FLAGS.cos_decay,
+      use_lamb=FLAGS.use_lamb,
+      previous_train_steps=FLAGS.previous_train_steps,
+      post_train_steps=FLAGS.post_train_steps)
+
+  hooks = []
+
+  if FLAGS.use_hvd:
+    hooks.append(hvd.BroadcastGlobalVariablesHook(0))
+
+    if hvd.rank() == -1:  # For debugging, change -1 to 0 to attach the CLI debugger on rank 0.
+      CLIDebugHook = tf_debug.LocalCLIDebugHook(ui_type='readline')
+      CLIDebugHook.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
+      hooks.append(CLIDebugHook)
+
+  if FLAGS.profile and (not FLAGS.use_hvd or hvd.rank() == 0):
+    ProfilerHook = tf.train.ProfilerHook(save_steps=FLAGS.hooking_frequence, output_dir=FLAGS.output_dir, show_dataflow=True, show_memory=True)
+    hooks.append(ProfilerHook)
 
   # If TPU is not available, this will fall back to normal Estimator on CPU
   # or GPU.
- estimator = tf.contrib.tpu.TPUEstimator( - use_tpu=FLAGS.use_tpu, + estimator = tf.estimator.Estimator( model_fn=model_fn, - config=run_config, - train_batch_size=FLAGS.train_batch_size, - eval_batch_size=FLAGS.eval_batch_size) + config=run_config) if FLAGS.do_train: tf.logging.info("***** Running training *****") @@ -462,18 +606,26 @@ def main(_): input_files=input_files, max_seq_length=FLAGS.max_seq_length, max_predictions_per_seq=FLAGS.max_predictions_per_seq, - is_training=True) - estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps) + is_training=True, + batch_size=FLAGS.train_batch_size, + use_hvd=FLAGS.use_hvd) + + if FLAGS.auto_recover: + hooks.append(tf.data.experimental.CheckpointInputPipelineHook(estimator)) + + estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps, hooks=hooks) if FLAGS.do_eval: tf.logging.info("***** Running evaluation *****") tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size) eval_input_fn = input_fn_builder( - input_files=input_files, + input_files=validation_input_files, max_seq_length=FLAGS.max_seq_length, max_predictions_per_seq=FLAGS.max_predictions_per_seq, - is_training=False) + is_training=False, + batch_size=FLAGS.eval_batch_size, + use_hvd=FLAGS.use_hvd) result = estimator.evaluate( input_fn=eval_input_fn, steps=FLAGS.max_eval_steps) @@ -485,9 +637,37 @@ def main(_): tf.logging.info(" %s = %s", key, str(result[key])) writer.write("%s = %s\n" % (key, str(result[key]))) + if FLAGS.do_train_eval: + tf.logging.info("***** Running training *****") + tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) + train_input_fn = input_fn_builder( + input_files=input_files, + max_seq_length=FLAGS.max_seq_length, + max_predictions_per_seq=FLAGS.max_predictions_per_seq, + is_training=True, + batch_size=FLAGS.train_batch_size, + use_hvd=FLAGS.use_hvd) + + tf.logging.info("***** Running evaluation *****") + tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size) + eval_input_fn = input_fn_builder( + input_files=validation_input_files, + max_seq_length=FLAGS.max_seq_length, + max_predictions_per_seq=FLAGS.max_predictions_per_seq, + is_training=False, + batch_size=FLAGS.eval_batch_size, + use_hvd=FLAGS.use_hvd) + + if FLAGS.auto_recover: + hooks.append(tf.data.experimental.CheckpointInputPipelineHook(estimator)) + + train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps, hooks=hooks) + eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn, steps=FLAGS.max_eval_steps) + tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) + if __name__ == "__main__": - flags.mark_flag_as_required("input_file") + # flags.mark_flag_as_required("input_file") flags.mark_flag_as_required("bert_config_file") flags.mark_flag_as_required("output_dir") tf.app.run()
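
Note on the `clip` flag: `tf.nn.log_softmax` returns -inf wherever the softmax underflows to zero, and a -inf entry in `log_probs` turns the cross-entropy loss into NaN (0 * -inf = NaN); underflow becomes far more likely once fp16 compute feeds the output heads. The patch therefore clamps probabilities before taking the log. A minimal standalone sketch of the same idea (TF 1.x; the helper name is illustrative, not part of the patch):

import tensorflow as tf

def clipped_log_softmax(logits, eps=1e-6):
  # Clamp probabilities to [eps, 1 - eps] so log() stays finite even when
  # the softmax saturates to exactly 0.0 or 1.0.
  probs = tf.nn.softmax(logits, axis=-1)
  return tf.log(tf.clip_by_value(probs, eps, 1.0 - eps))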
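The file-level `Dataset.shard` added to `input_fn` gives each Horovod worker every `hvd.size()`-th input file starting at its own rank index, so ranks consume disjoint slices of the corpus instead of duplicating work; sharding before `repeat()` and `shuffle()`, as the patch does, keeps each rank's subset fixed across epochs. A minimal sketch (file names hypothetical):

import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()
files = tf.data.Dataset.from_tensor_slices(
    tf.constant(["part-0.tfrecord", "part-1.tfrecord",
                 "part-2.tfrecord", "part-3.tfrecord"]))
# With two workers: rank 0 keeps part-0 and part-2,
# rank 1 keeps part-1 and part-3.
files = files.shard(hvd.size(), hvd.rank())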
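The recovery block in `main` copies no checkpoint data. It only writes standard CheckpointState text files into `output_dir` that point at the checkpoints under `recover_dir`, which the Estimator's normal restore path then follows. With the hypothetical values `--recover_dir=/recover --ckpt_no=100000`, the generated `checkpoint` file would contain:

model_checkpoint_path: "/recover/model.ckpt-100000"
all_model_checkpoint_paths: "/recover/model.ckpt-100000"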
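When `auto_recover` is set, `tf.data.experimental.CheckpointInputPipelineHook(estimator)` saves and restores the input pipeline's iterator state alongside the model checkpoints (the `input.ckpt-*` files and `checkpoint_input` state file that the recovery block reconstructs), so a restarted job resumes reading where it stopped rather than from the beginning of the dataset. A minimal usage sketch (TF 1.x; `my_model_fn` and `my_input_fn` are placeholders for a real model and input function):

import tensorflow as tf

estimator = tf.estimator.Estimator(model_fn=my_model_fn, model_dir="/tmp/model")
hooks = [tf.data.experimental.CheckpointInputPipelineHook(estimator)]
estimator.train(input_fn=my_input_fn, max_steps=1000, hooks=hooks)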