Yuefeng Zhan 2021-03-05 01:52:07 +08:00
Parent 2671f436e2
Commit 6efe228d60
5 changed files with 964 additions and 174 deletions

36
gpu_environment.py Normal file
View file

@@ -0,0 +1,36 @@
# coding=utf-8
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
import numpy as np
def float32_variable_storage_getter(getter, name, shape=None, dtype=None,
initializer=None, regularizer=None,
trainable=True,
*args, **kwargs):
"""Custom variable getter that forces trainable variables to be stored in
float32 precision and then casts them to the training precision.
"""
storage_dtype = tf.float32 if trainable else dtype
variable = getter(name, shape, dtype=storage_dtype,
initializer=initializer, regularizer=regularizer,
trainable=trainable,
*args, **kwargs)
if trainable and dtype != tf.float32:
variable = tf.cast(variable, dtype)
return variable
def get_custom_getter(compute_type):
return float32_variable_storage_getter if compute_type == tf.float16 else None
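For reference, a minimal sketch of how this getter is meant to be attached to a scope (mirroring the modeling.py change below; the scope name and shapes here are illustrative, not part of the commit):

import tensorflow as tf
from gpu_environment import get_custom_getter

compute_type = tf.float16
with tf.variable_scope("demo", custom_getter=get_custom_getter(compute_type)):
    # Requesting a float16 variable inside this scope actually creates a
    # float32 master copy and hands back a float16 cast of it.
    w = tf.get_variable("w", shape=[128, 64], dtype=compute_type)
x = tf.placeholder(tf.float16, shape=[None, 128])
y = tf.matmul(x, w)  # half-precision compute; master weights stay float32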

View file

@@ -26,6 +26,7 @@ import re
import numpy as np
import six
import tensorflow as tf
from gpu_environment import get_custom_getter
class BertConfig(object):
@@ -135,7 +136,8 @@ class BertModel(object):
input_mask=None,
token_type_ids=None,
use_one_hot_embeddings=False,
scope=None):
scope=None,
compute_type=tf.float32):
"""Constructor for BertModel.
Args:
@@ -168,7 +170,7 @@
if token_type_ids is None:
token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32)
with tf.variable_scope(scope, default_name="bert"):
with tf.variable_scope(scope, default_name="bert", custom_getter=get_custom_getter(compute_type)):
with tf.variable_scope("embeddings"):
# Perform embedding lookup on the word ids.
(self.embedding_output, self.embedding_table) = embedding_lookup(
@@ -203,7 +205,7 @@
# Run the stacked transformer.
# `sequence_output` shape = [batch_size, seq_length, hidden_size].
self.all_encoder_layers = transformer_model(
input_tensor=self.embedding_output,
input_tensor=tf.saturate_cast(self.embedding_output, compute_type),
attention_mask=attention_mask,
hidden_size=config.hidden_size,
num_hidden_layers=config.num_hidden_layers,
@@ -215,7 +217,7 @@
initializer_range=config.initializer_range,
do_return_all_layers=True)
self.sequence_output = self.all_encoder_layers[-1]
self.sequence_output = tf.cast(self.all_encoder_layers[-1], tf.float32)
# The "pooler" converts the encoded sequence tensor of shape
# [batch_size, seq_length, hidden_size] to a tensor of shape
# [batch_size, hidden_size]. This is necessary for segment-level
@@ -709,7 +711,7 @@ def attention_layer(from_tensor,
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0
adder = (1.0 - tf.cast(attention_mask, attention_scores.dtype)) * -10000.0
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
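The dtype change above matters because under fp16 the adder must match the precision of attention_scores; the masking trick itself is unchanged. A tiny NumPy illustration of why -10000.0 still works in float16 (values made up):

import numpy as np

scores = np.array([2.0, 1.0, 3.0], dtype=np.float16)  # raw attention scores
mask = np.array([1, 1, 0], dtype=np.float16)          # attend to the first two positions
adder = (1.0 - mask) * np.float16(-10000.0)           # 0 where attended, -10000 where masked (finite in fp16)
exp = np.exp((scores + adder).astype(np.float32))
print(exp / exp.sum())                                # the masked slot gets ~0 probability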

View file

@@ -21,26 +21,52 @@ from __future__ import print_function
import re
import tensorflow as tf
# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
# pylint: enable=g-direct-tensorflow-import
def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu):
import horovod.tensorflow as hvd
def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu, adjust_lr, use_hvd,
use_compression, use_fp16, clip, cos_decay, use_lamb=False,
previous_train_steps=0, post_train_steps=0):
"""Creates an optimizer training op."""
global_step = tf.train.get_or_create_global_step()
learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
# Implements linear decay of the learning rate.
learning_rate = tf.train.polynomial_decay(
learning_rate,
global_step,
num_train_steps,
end_learning_rate=0.0,
power=1.0,
cycle=False)
# if adjust_lr:
# # avoid step change in learning rate at end of warmup phase
# decayed_learning_rate_at_crossover_point = init_lr * (1.0-float(num_warmup_steps)/float(num_train_steps + previous_train_steps + post_train_steps))
# adjusted_init_lr = init_lr * (init_lr / decayed_learning_rate_at_crossover_point)
# print('decayed_learning_rate_at_crossover_point = %e, adjusted_init_lr = %e' % (decayed_learning_rate_at_crossover_point, adjusted_init_lr))
# learning_rate = tf.constant(value=adjusted_init_lr, shape=[], dtype=tf.float32)
if cos_decay:
# Implements cosine decay of the learning rate.
learning_rate = tf.train.cosine_decay(
learning_rate,
global_step + previous_train_steps - (num_warmup_steps if adjust_lr else 0),
num_train_steps + previous_train_steps + post_train_steps - (num_warmup_steps if adjust_lr else 0),
alpha=0.0)
else:
# Implements linear decay of the learning rate.
learning_rate = tf.train.polynomial_decay(
learning_rate,
global_step + previous_train_steps - (num_warmup_steps if adjust_lr else 0),
num_train_steps + previous_train_steps + post_train_steps - (num_warmup_steps if adjust_lr else 0),
end_learning_rate=0.0,
power=1.0,
cycle=False)
# Implements linear warmup. I.e., if global_step < num_warmup_steps, the
# learning rate will be `global_step/num_warmup_steps * init_lr`.
if num_warmup_steps:
global_steps_int = tf.cast(global_step, tf.int32)
global_steps_int = tf.cast(global_step + previous_train_steps, tf.int32)
warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)
global_steps_float = tf.cast(global_steps_int, tf.float32)
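Concretely, the warmup scales the learning rate linearly with the (shifted) step count; with illustrative numbers:

init_lr, num_warmup_steps = 5e-5, 1000     # illustrative values
step = 100                                 # global_step + previous_train_steps
print(init_lr * step / num_warmup_steps)   # 5e-06: still warming up at step 100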
@@ -56,7 +82,16 @@ def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu):
# It is recommended that you use this optimizer for fine tuning, since this
# is how the model was trained (note that the Adam m/v variables are NOT
# loaded from init_checkpoint.)
optimizer = AdamWeightDecayOptimizer(
if use_lamb:
optimizer = LAMBOptimizer(
learning_rate=learning_rate,
weight_decay_rate=0.01,
beta_1=0.9,
beta_2=0.999,
epsilon=1e-6,
exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
else:
optimizer = AdamWeightDecayOptimizer(
learning_rate=learning_rate,
weight_decay_rate=0.01,
beta_1=0.9,
@@ -64,24 +99,51 @@ def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu):
epsilon=1e-6,
exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
tf.summary.scalar("learning_rate", optimizer.learning_rate)
if use_hvd:
optimizer = hvd.DistributedOptimizer(optimizer, sparse_as_dense=True, compression=hvd.Compression.fp16 if use_compression else hvd.Compression.none)
if use_fp16:
loss_scale_manager = tf.contrib.mixed_precision.ExponentialUpdateLossScaleManager(init_loss_scale=2**32, incr_every_n_steps=1000, decr_every_n_nan_or_inf=2, decr_ratio=0.5)
optimizer = tf.contrib.mixed_precision.LossScaleOptimizer(optimizer, loss_scale_manager)
if use_tpu:
optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
all_are_finite = tf.constant(True, dtype=tf.bool)
tvars = tf.trainable_variables()
grads = tf.gradients(loss, tvars)
if use_hvd:
grads_and_vars = optimizer.compute_gradients(loss, tvars)
grads_and_vars = [(g,v) for g,v in grads_and_vars if g is not None]
grads, tvars = list(zip(*grads_and_vars))
all_are_finite = tf.reduce_all([tf.reduce_all(tf.is_finite(g)) for g in grads]) if use_fp16 else tf.constant(True, dtype=tf.bool)
else:
grads = tf.gradients(loss, tvars)
# This is how the model was pre-trained.
(grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
# ensure global norm is a finite number
# to prevent clip_by_global_norm from having a hissy fit.
(clipped_grads, _) = tf.clip_by_global_norm(
grads, clip_norm=1.0,
use_norm=tf.cond(
all_are_finite,
lambda: tf.global_norm(grads),
lambda: tf.constant(1.0)))
train_op = optimizer.apply_gradients(
zip(grads, tvars), global_step=global_step)
list(zip(clipped_grads, tvars)), global_step=global_step)
# Normally the global step update is done inside of `apply_gradients`.
# However, `AdamWeightDecayOptimizer` doesn't do this. But if you use
# a different optimizer, you should probably take this line out.
new_global_step = global_step + 1
if clip:
new_global_step = global_step + 1
else:
new_global_step = tf.cond(all_are_finite, lambda: global_step+1, lambda: global_step)
new_global_step = tf.identity(new_global_step, name='step_update')
train_op = tf.group(train_op, [global_step.assign(new_global_step)])
return train_op
return train_op, learning_rate
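Taken together, the use_fp16 branch is the standard dynamic loss-scaling loop: scale the loss, unscale and check the gradients, skip the update (holding global_step fixed, unless clip is set) when anything overflows, and let the scale manager adapt. A toy model of the ExponentialUpdateLossScaleManager schedule with the same constants as above (a paraphrase of its documented behavior, not the commit's code):

class DynamicLossScale(object):
    """Toy schedule: start at 2**32, halve after 2 bad steps, double after 1000 good ones."""
    def __init__(self, init_scale=2.0 ** 32, incr_every_n_steps=1000,
                 decr_every_n_nan_or_inf=2, decr_ratio=0.5):
        self.scale = init_scale
        self.good = 0
        self.bad = 0
        self.incr_every_n_steps = incr_every_n_steps
        self.decr_every_n_nan_or_inf = decr_every_n_nan_or_inf
        self.decr_ratio = decr_ratio

    def update(self, grads_all_finite):
        if not grads_all_finite:
            self.good = 0
            self.bad += 1
            if self.bad >= self.decr_every_n_nan_or_inf:
                self.scale *= self.decr_ratio  # shrink on repeated overflow
                self.bad = 0
            return False                       # skip this step
        self.bad = 0
        self.good += 1
        if self.good >= self.incr_every_n_steps:
            self.scale *= 2.0                  # grow once training looks stable
            self.good = 0
        return True                            # apply this step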
class AdamWeightDecayOptimizer(tf.train.Optimizer):
@@ -98,7 +160,7 @@ class AdamWeightDecayOptimizer(tf.train.Optimizer):
"""Constructs a AdamWeightDecayOptimizer."""
super(AdamWeightDecayOptimizer, self).__init__(False, name)
self.learning_rate = learning_rate
self.learning_rate = tf.identity(learning_rate, name='learning_rate')
self.weight_decay_rate = weight_decay_rate
self.beta_1 = beta_1
self.beta_2 = beta_2
@@ -172,3 +234,121 @@ class AdamWeightDecayOptimizer(tf.train.Optimizer):
if m is not None:
param_name = m.group(1)
return param_name
class LAMBOptimizer(tf.train.Optimizer):
"""LAMB (Layer-wise Adaptive Moments optimizer for Batch training)."""
# A new optimizer that includes correct L2 weight decay, adaptive
# element-wise updating, and layer-wise justification. The LAMB optimizer
# was proposed by Yang You, Jing Li, Jonathan Hseu, Xiaodan Song,
# James Demmel, and Cho-Jui Hsieh in a paper titled Reducing BERT
# Pre-Training Time from 3 Days to 76 Minutes (arxiv.org/abs/1904.00962)
def __init__(self,
learning_rate,
weight_decay_rate=0.0,
beta_1=0.9,
beta_2=0.999,
epsilon=1e-6,
exclude_from_weight_decay=None,
exclude_from_layer_adaptation=None,
name="LAMBOptimizer"):
"""Constructs a LAMBOptimizer."""
super(LAMBOptimizer, self).__init__(False, name)
self.learning_rate = learning_rate
self.weight_decay_rate = weight_decay_rate
self.beta_1 = beta_1
self.beta_2 = beta_2
self.epsilon = epsilon
self.exclude_from_weight_decay = exclude_from_weight_decay
# exclude_from_layer_adaptation is set to exclude_from_weight_decay if the
# arg is None.
# TODO(jingli): validate if exclude_from_layer_adaptation is necessary.
if exclude_from_layer_adaptation:
self.exclude_from_layer_adaptation = exclude_from_layer_adaptation
else:
self.exclude_from_layer_adaptation = exclude_from_weight_decay
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
"""See base class."""
assignments = []
for (grad, param) in grads_and_vars:
if grad is None or param is None:
continue
param_name = self._get_variable_name(param.name)
m = tf.get_variable(
name=param_name + "/adam_m",
shape=param.shape.as_list(),
dtype=tf.float32,
trainable=False,
initializer=tf.zeros_initializer())
v = tf.get_variable(
name=param_name + "/adam_v",
shape=param.shape.as_list(),
dtype=tf.float32,
trainable=False,
initializer=tf.zeros_initializer())
# Standard Adam update.
next_m = (
tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
next_v = (
tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
tf.square(grad)))
update = next_m / (tf.sqrt(next_v) + self.epsilon)
# Just adding the square of the weights to the loss function is *not*
# the correct way of using L2 regularization/weight decay with Adam,
# since that will interact with the m and v parameters in strange ways.
#
# Instead we want to decay the weights in a manner that doesn't interact
# with the m/v parameters. This is equivalent to adding the square
# of the weights to the loss with plain (non-momentum) SGD.
if self._do_use_weight_decay(param_name):
update += self.weight_decay_rate * param
ratio = 1.0
if self._do_layer_adaptation(param_name):
w_norm = linalg_ops.norm(param, ord=2)
g_norm = linalg_ops.norm(update, ord=2)
ratio = array_ops.where(math_ops.greater(w_norm, 0), array_ops.where(
math_ops.greater(g_norm, 0), (w_norm / g_norm), 1.0), 1.0)
update_with_lr = ratio * self.learning_rate * update
next_param = param - update_with_lr
assignments.extend(
[param.assign(next_param),
m.assign(next_m),
v.assign(next_v)])
return tf.group(*assignments, name=name)
def _do_use_weight_decay(self, param_name):
"""Whether to use L2 weight decay for `param_name`."""
if not self.weight_decay_rate:
return False
if self.exclude_from_weight_decay:
for r in self.exclude_from_weight_decay:
if re.search(r, param_name) is not None:
return False
return True
def _do_layer_adaptation(self, param_name):
"""Whether to do layer-wise learning rate adaptation for `param_name`."""
if self.exclude_from_layer_adaptation:
for r in self.exclude_from_layer_adaptation:
if re.search(r, param_name) is not None:
return False
return True
def _get_variable_name(self, param_name):
"""Get the variable name from the tensor name."""
m = re.match("^(.*):\\d+$", param_name)
if m is not None:
param_name = m.group(1)
return param_name
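The guarded trust ratio in apply_gradients above is the heart of LAMB; a NumPy paraphrase of just that piece (not the commit's code):

import numpy as np

def lamb_trust_ratio(param, update):
    # Mirrors the array_ops.where chain above: scale each layer's update
    # by ||w|| / ||u||, falling back to 1.0 if either norm is zero.
    w_norm = np.linalg.norm(param)
    u_norm = np.linalg.norm(update)
    if w_norm > 0 and u_norm > 0:
        return w_norm / u_norm
    return 1.0

# A layer with large weights but a tiny update gets a boosted step:
print(lamb_trust_ratio(np.ones(100), 0.001 * np.ones(100)))  # 1000.0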

View file

@@ -25,6 +25,8 @@ import modeling
import optimization
import tokenization
import tensorflow as tf
import horovod.tensorflow as hvd
from tensorflow.python import debug as tf_debug
flags = tf.flags
@@ -36,6 +38,16 @@ flags.DEFINE_string(
"The input data dir. Should contain the .tsv files (or other data files) "
"for the task.")
flags.DEFINE_string(
"validation_data_dir", None,
"The input validation data dir. Should contain the .tsv files (or other data files) "
"for the task.")
flags.DEFINE_string(
"test_data_dir", None,
"The input test data dir. Should contain the .tsv files (or other data files) "
"for the task.")
flags.DEFINE_string(
"bert_config_file", None,
"The config json file corresponding to the pre-trained BERT model. "
@@ -71,6 +83,8 @@ flags.DEFINE_bool("do_train", False, "Whether to run training.")
flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")
flags.DEFINE_bool("do_train_eval", False, "Whether to run train with eval.")
flags.DEFINE_bool(
"do_predict", False,
"Whether to run the model in inference mode on the test set.")
@@ -91,7 +105,7 @@ flags.DEFINE_float(
"Proportion of training to perform linear learning rate warmup for. "
"E.g., 0.1 = 10% of training.")
flags.DEFINE_integer("save_checkpoints_steps", 1000,
flags.DEFINE_integer("save_checkpoints_steps", None,
"How often to save the model checkpoint.")
flags.DEFINE_integer("iterations_per_loop", 1000,
@@ -99,31 +113,98 @@ flags.DEFINE_integer("iterations_per_loop", 1000,
flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
tf.flags.DEFINE_string(
flags.DEFINE_string(
"tpu_name", None,
"The Cloud TPU to use for training. This should be either the name "
"used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
"url.")
tf.flags.DEFINE_string(
flags.DEFINE_string(
"tpu_zone", None,
"[Optional] GCE zone where the Cloud TPU is located in. If not "
"specified, we will attempt to automatically detect the GCE project from "
"metadata.")
tf.flags.DEFINE_string(
flags.DEFINE_string(
"gcp_project", None,
"[Optional] Project name for the Cloud TPU-enabled project. If not "
"specified, we will attempt to automatically detect the GCE project from "
"metadata.")
tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")
flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")
flags.DEFINE_integer(
"num_tpu_cores", 8,
"Only used if `use_tpu` is True. Total number of TPU cores to use.")
flags.DEFINE_bool(
"do_export", False,
"Whether to export the model.")
flags.DEFINE_string(
"export_dir", None,
"The dir where the exported model will be written.")
flags.DEFINE_string("label_list", None, "Label list.")
flags.DEFINE_integer("bad_label_num", 1, "Bad label num.")
flags.DEFINE_bool("add_header", False, "Add header.")
flags.DEFINE_bool("use_tfrecord", False, "Use tfrecord.")
flags.DEFINE_bool("use_validation_tfrecord", False, "Use validation tfrecord.")
flags.DEFINE_bool("use_test_tfrecord", False, "Use test tfrecord.")
flags.DEFINE_string("tfrecord_name", "train.tf_record", "tfrecord name.")
flags.DEFINE_string("validation_tfrecord_name", "eval.tf_record", "validation tfrecord name.")
flags.DEFINE_string("test_tfrecord_name", "predict.tf_record", "test tfrecord name.")
flags.DEFINE_bool("clean_tfrecord", False, "Clean tfrecord.")
flags.DEFINE_integer("train_examples_count", None, "Train examples count.")
flags.DEFINE_integer("hooking_frequence", 100, "Hooking frequence.")
flags.DEFINE_bool("reduce_log", False, "Reduce log.")
flags.DEFINE_integer("keep_checkpoint_max", None, "Keep checkpoint max.")
flags.DEFINE_bool("xla", True, "Whether to train with XLA optimization.")
flags.DEFINE_bool("adjust_lr", True, "Whether to adjust learning_rate.")
flags.DEFINE_integer("previous_train_steps", 0, "Previous train steps.")
flags.DEFINE_integer("post_train_steps", 0, "Post train steps.")
flags.DEFINE_bool("use_hvd", True, "Whether to use Horovod.")
flags.DEFINE_bool("use_compression", True, "Whether to use compression in Horovod.")
flags.DEFINE_bool("use_fp16", True, "Whether to use fp16.")
flags.DEFINE_bool("cos_decay", False, "Whether to use cos decay.")
flags.DEFINE_bool("use_lamb", False, "Whether to use lamb.")
flags.DEFINE_bool("auto_recover", False, "Whether to use auto recover.")
flags.DEFINE_string("recover_dir", None, "The output directory where the model checkpoints will be recovered.")
flags.DEFINE_integer("ckpt_no", None, "Checkpoint number of model to be recovered.")
flags.DEFINE_integer("ckpt_no_input", None, "Checkpoint number of input to be recovered.")
flags.DEFINE_bool("clip", False, "Whether to use clip.")
flags.DEFINE_bool("profile", False, "Whether to use profile.")
class InputExample(object):
"""A single training/test example for simple sequence classification."""
@@ -165,11 +246,13 @@ class InputFeatures(object):
input_ids,
input_mask,
segment_ids,
row_id,
label_id,
is_real_example=True):
self.input_ids = input_ids
self.input_mask = input_mask
self.segment_ids = segment_ids
self.row_id = row_id
self.label_id = label_id
self.is_real_example = is_real_example
@@ -203,6 +286,64 @@ class DataProcessor(object):
lines.append(line)
return lines
@classmethod
def _read_tsv_from_dir(cls, input_dir, quotechar=None):
"""Reads a tab separated value file."""
input_files = [input_dir + "/" + i for i in tf.gfile.ListDirectory(input_dir)]
lines = []
for input_file in input_files:
with tf.gfile.Open(input_file, "r") as f:
reader = csv.reader((line.replace('\0', '') for line in f), delimiter="\t", quotechar=quotechar)
for line in reader:
lines.append(line)
return lines
@classmethod
def _read_tsv_from_dir_by_name(cls, input_dir, quotechar=None, name='0'):
"""Reads a tab separated value file."""
lines = []
input_file = input_dir + "/" + name
with tf.gfile.Open(input_file, "r") as f:
reader = csv.reader((line.replace('\0', '') for line in f), delimiter="\t", quotechar=quotechar)
for line in reader:
lines.append(line)
return lines
class QKProcessor(DataProcessor):
"""Processor for the MRPC data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
# return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
return self._create_examples(self._read_tsv_from_dir(data_dir), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
# return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
return self._create_examples(self._read_tsv_from_dir(data_dir), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
# return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
return self._create_examples(self._read_tsv_from_dir(data_dir), "test")
def get_labels(self):
"""See base class."""
return ["0", "1"]
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
guid = tokenization.convert_to_unicode(line[0])
label = tokenization.convert_to_unicode(line[1])
text_a = tokenization.convert_to_unicode(line[2])
text_b = tokenization.convert_to_unicode(line[3])
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
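From _create_examples above, each input row must carry four tab-separated columns: guid, label, text_a, text_b. A made-up sample for illustration:

12345	1	cheap flights to tokyo	tokyo flight deals
12346	0	cheap flights to tokyo	how to grow tomatoes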
class XnliProcessor(DataProcessor):
"""Processor for the XNLI data set."""
@@ -383,6 +524,7 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
input_ids=[0] * max_seq_length,
input_mask=[0] * max_seq_length,
segment_ids=[0] * max_seq_length,
row_id=0,
label_id=0,
is_real_example=False)
@@ -471,6 +613,7 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
input_ids=input_ids,
input_mask=input_mask,
segment_ids=segment_ids,
row_id=int(example.guid),
label_id=label_id,
is_real_example=True)
return feature
@@ -497,6 +640,7 @@ def file_based_convert_examples_to_features(
features["input_ids"] = create_int_feature(feature.input_ids)
features["input_mask"] = create_int_feature(feature.input_mask)
features["segment_ids"] = create_int_feature(feature.segment_ids)
features["input_rowid"] = create_int_feature([feature.row_id])
features["label_ids"] = create_int_feature([feature.label_id])
features["is_real_example"] = create_int_feature(
[int(feature.is_real_example)])
@@ -507,13 +651,14 @@ def file_based_convert_examples_to_features(
def file_based_input_fn_builder(input_file, seq_length, is_training,
drop_remainder):
drop_remainder, batch_size=None, use_hvd=True):
"""Creates an `input_fn` closure to be passed to TPUEstimator."""
name_to_features = {
"input_ids": tf.FixedLenFeature([seq_length], tf.int64),
"input_mask": tf.FixedLenFeature([seq_length], tf.int64),
"segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
"input_rowid": tf.FixedLenFeature([], tf.int64),
"label_ids": tf.FixedLenFeature([], tf.int64),
"is_real_example": tf.FixedLenFeature([], tf.int64),
}
@@ -526,7 +671,7 @@ def file_based_input_fn_builder(input_file, seq_length, is_training,
# So cast all int64 to int32.
for name in list(example.keys()):
t = example[name]
if t.dtype == tf.int64:
if t.dtype == tf.int64 and name != "input_rowid":
t = tf.to_int32(t)
example[name] = t
@@ -534,12 +679,17 @@ def file_based_input_fn_builder(input_file, seq_length, is_training,
def input_fn(params):
"""The actual input function."""
batch_size = params["batch_size"]
# batch_size = params["batch_size"]
# For training, we want a lot of parallel reading and shuffling.
# For eval, we want no shuffling and parallel reading doesn't matter.
d = tf.data.TFRecordDataset(input_file)
if is_training:
if use_hvd:
d = d.shard(hvd.size(), hvd.rank())  # TODO: Horovod only; shard so each rank reads a distinct slice (mimics single_GPU = False)
print("Data shard: %s %s" % (hvd.size(), hvd.rank()))
d = d.repeat()
d = d.shuffle(buffer_size=100)
@@ -572,7 +722,7 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,
labels, num_labels, use_one_hot_embeddings):
labels, num_labels, use_one_hot_embeddings, use_fp16, clip):
"""Creates a classification model."""
model = modeling.BertModel(
config=bert_config,
@@ -580,7 +730,8 @@ def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,
input_ids=input_ids,
input_mask=input_mask,
token_type_ids=segment_ids,
use_one_hot_embeddings=use_one_hot_embeddings)
use_one_hot_embeddings=use_one_hot_embeddings,
compute_type=tf.float16 if use_fp16 else tf.float32)
# In the demo, we are doing a simple classification task on the entire
# segment.
@@ -605,20 +756,30 @@ def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,
logits = tf.matmul(output_layer, output_weights, transpose_b=True)
logits = tf.nn.bias_add(logits, output_bias)
probabilities = tf.nn.softmax(logits, axis=-1)
log_probs = tf.nn.log_softmax(logits, axis=-1)
if clip:
probabilities = tf.clip_by_value(tf.nn.softmax(logits, axis=-1), 1e-6, 1.0 - 1e-6)
log_probs = tf.log(probabilities)
else:
probabilities = tf.nn.softmax(logits, axis=-1)
log_probs = tf.nn.log_softmax(logits, axis=-1)
one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)
per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
loss = tf.reduce_mean(per_example_loss)
p0 = tf.reduce_sum(probabilities[:, 0:FLAGS.bad_label_num], axis=-1)
p1 = tf.subtract(1.0, p0)
probabilities = tf.stack([p0, p1], axis=-1)
return (loss, per_example_loss, logits, probabilities)
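The last three lines fold the n-way softmax into a binary (bad, good) pair by pooling the first bad_label_num classes. Numerically, with illustrative values:

import numpy as np

probabilities = np.array([0.2, 0.3, 0.5])  # 3-way softmax, bad_label_num = 2
p0 = probabilities[:2].sum()               # mass on the "bad" labels -> 0.5
p1 = 1.0 - p0                              # remaining "good" mass    -> 0.5
print(np.stack([p0, p1]))                  # what predictions and metrics see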
def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate,
num_train_steps, num_warmup_steps, use_tpu,
use_one_hot_embeddings):
use_one_hot_embeddings, adjust_lr, use_hvd,
use_compression, use_fp16, clip, cos_decay,
use_lamb, previous_train_steps, post_train_steps):
"""Returns `model_fn` closure for TPUEstimator."""
def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
@@ -631,6 +792,7 @@ def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate,
input_ids = features["input_ids"]
input_mask = features["input_mask"]
segment_ids = features["segment_ids"]
input_rowid = features["input_rowid"]
label_ids = features["label_ids"]
is_real_example = None
if "is_real_example" in features:
@@ -642,7 +804,7 @@ def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate,
(total_loss, per_example_loss, logits, probabilities) = create_model(
bert_config, is_training, input_ids, input_mask, segment_ids, label_ids,
num_labels, use_one_hot_embeddings)
num_labels, use_one_hot_embeddings, use_fp16, clip)
tvars = tf.trainable_variables()
initialized_variable_names = {}
@@ -670,39 +832,42 @@ def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate,
output_spec = None
if mode == tf.estimator.ModeKeys.TRAIN:
train_op, update_learning_rate = optimization.create_optimizer(
total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu, adjust_lr, use_hvd,
use_compression, use_fp16, clip, cos_decay, use_lamb, previous_train_steps, post_train_steps)
train_op = optimization.create_optimizer(
total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)
output_spec = tf.contrib.tpu.TPUEstimatorSpec(
logging_hook = tf.train.LoggingTensorHook({"loss": total_loss, "learning_rate": update_learning_rate}, every_n_iter=FLAGS.hooking_frequence)
output_spec = tf.estimator.EstimatorSpec(
mode=mode,
loss=total_loss,
train_op=train_op,
scaffold_fn=scaffold_fn)
training_hooks=[logging_hook])
elif mode == tf.estimator.ModeKeys.EVAL:
def metric_fn(per_example_loss, label_ids, logits, is_real_example):
def metric_fn(per_example_loss, label_ids, logits, probabilities, is_real_example):
predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
accuracy = tf.metrics.accuracy(
labels=label_ids, predictions=predictions, weights=is_real_example)
loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example)
rocauc = tf.metrics.auc(labels=label_ids, predictions=probabilities[:, 1], curve="ROC", summation_method="careful_interpolation", weights=is_real_example)
prauc = tf.metrics.auc(labels=label_ids, predictions=probabilities[:, 1], curve="PR", summation_method="careful_interpolation", weights=is_real_example)
return {
"eval_accuracy": accuracy,
"eval_loss": loss,
"rocauc": rocauc,
"prauc": prauc,
}
eval_metrics = (metric_fn,
[per_example_loss, label_ids, logits, is_real_example])
output_spec = tf.contrib.tpu.TPUEstimatorSpec(
eval_metrics = metric_fn(
per_example_loss, label_ids, logits, probabilities, is_real_example)
output_spec = tf.estimator.EstimatorSpec(
mode=mode,
loss=total_loss,
eval_metrics=eval_metrics,
scaffold_fn=scaffold_fn)
eval_metric_ops=eval_metrics)
else:
output_spec = tf.contrib.tpu.TPUEstimatorSpec(
output_spec = tf.estimator.EstimatorSpec(
mode=mode,
predictions={"probabilities": probabilities},
scaffold_fn=scaffold_fn)
predictions={"probabilities": probabilities, "labels": label_ids, "rowids": input_rowid})
return output_spec
return model_fn
@@ -783,17 +948,26 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
if FLAGS.use_hvd:
hvd.init()
if FLAGS.reduce_log and (hvd.rank() != 0):
tf.logging.set_verbosity(tf.logging.ERROR)
FLAGS.output_dir = FLAGS.output_dir if hvd.rank() == 0 else os.path.join(FLAGS.output_dir, str(hvd.rank()))
processors = {
"cola": ColaProcessor,
"mnli": MnliProcessor,
"mrpc": MrpcProcessor,
"xnli": XnliProcessor,
"qk": QKProcessor,
}
tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
FLAGS.init_checkpoint)
if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict and not FLAGS.do_train_eval and not FLAGS.do_export:
raise ValueError(
"At least one of `do_train`, `do_eval` or `do_predict' must be True.")
@@ -807,6 +981,35 @@ def main(_):
tf.gfile.MakeDirs(FLAGS.output_dir)
if FLAGS.recover_dir is not None:
if FLAGS.use_hvd:
FLAGS.recover_dir = FLAGS.recover_dir if hvd.rank() == 0 else os.path.join(FLAGS.recover_dir, str(hvd.rank()))
path_ckpt = os.path.join(FLAGS.output_dir, "checkpoint")
path_ckpt_input = os.path.join(FLAGS.output_dir, "checkpoint_input")
if FLAGS.ckpt_no is not None and not tf.gfile.Exists(path_ckpt):
with tf.gfile.GFile(path_ckpt, "w") as writer:
writer.write('model_checkpoint_path: "%s-%s"\n' % (os.path.join(FLAGS.recover_dir, "model.ckpt"), str(FLAGS.ckpt_no)))
writer.write('all_model_checkpoint_paths: "%s-%s"\n' % (os.path.join(FLAGS.recover_dir, "model.ckpt"), str(FLAGS.ckpt_no)))
if FLAGS.ckpt_no_input is not None and not tf.gfile.Exists(path_ckpt_input):
with tf.gfile.GFile(path_ckpt_input, "w") as writer:
writer.write('model_checkpoint_path: "%s-%s"\n' % (os.path.join(FLAGS.recover_dir, "input.ckpt"), str(FLAGS.ckpt_no_input)))
writer.write('all_model_checkpoint_paths: "%s-%s"\n' % (os.path.join(FLAGS.recover_dir, "input.ckpt"), str(FLAGS.ckpt_no_input)))
if FLAGS.use_hvd and hvd.rank() == 0 and (FLAGS.do_train or FLAGS.do_train_eval):
(vpath, vname) = os.path.split(FLAGS.vocab_file)
tf.gfile.Copy(FLAGS.vocab_file, os.path.join(FLAGS.output_dir, vname), True)
(cpath, cname) = os.path.split(FLAGS.bert_config_file)
tf.gfile.Copy(FLAGS.bert_config_file, os.path.join(FLAGS.output_dir, cname), True)
if FLAGS.validation_data_dir is None:
FLAGS.validation_data_dir = FLAGS.data_dir
if FLAGS.test_data_dir is None:
FLAGS.test_data_dir = FLAGS.validation_data_dir
task_name = FLAGS.task_name.lower()
if task_name not in processors:
@@ -816,33 +1019,59 @@ def main(_):
label_list = processor.get_labels()
if FLAGS.label_list:
label_list = FLAGS.label_list.split(",")
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
tpu_cluster_resolver = None
if FLAGS.use_tpu and FLAGS.tpu_name:
tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
run_config = tf.contrib.tpu.RunConfig(
cluster=tpu_cluster_resolver,
master=FLAGS.master,
model_dir=FLAGS.output_dir,
save_checkpoints_steps=FLAGS.save_checkpoints_steps,
tpu_config=tf.contrib.tpu.TPUConfig(
iterations_per_loop=FLAGS.iterations_per_loop,
num_shards=FLAGS.num_tpu_cores,
per_host_input_for_training=is_per_host))
num_gpu = 1 if not FLAGS.use_hvd else hvd.size()
train_examples = None
num_train_steps = None
num_warmup_steps = None
if FLAGS.do_train:
train_examples = processor.get_train_examples(FLAGS.data_dir)
num_train_steps = int(
len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
if FLAGS.do_train or FLAGS.do_train_eval:
if FLAGS.use_tfrecord:
if FLAGS.train_examples_count is None:
FLAGS.train_examples_count = 0
for record in tf.python_io.tf_record_iterator(os.path.join(FLAGS.data_dir, FLAGS.tfrecord_name)):
FLAGS.train_examples_count += 1
num_train_steps = int(
FLAGS.train_examples_count / (FLAGS.train_batch_size * num_gpu) * FLAGS.num_train_epochs)
else:
train_examples = processor.get_train_examples(FLAGS.data_dir)
num_train_steps = int(
len(train_examples) / (FLAGS.train_batch_size * num_gpu) * FLAGS.num_train_epochs)
num_warmup_steps = int((num_train_steps + FLAGS.previous_train_steps + FLAGS.post_train_steps) * FLAGS.warmup_proportion)
if FLAGS.save_checkpoints_steps is None:
FLAGS.save_checkpoints_steps = 1000 if num_train_steps is None else int(num_train_steps / FLAGS.num_train_epochs)
if FLAGS.keep_checkpoint_max is None:
FLAGS.keep_checkpoint_max = int(FLAGS.num_train_epochs + 1.0)
config = tf.ConfigProto()
if FLAGS.xla:
config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
if FLAGS.use_hvd:
config.gpu_options.visible_device_list = str(hvd.local_rank())
config.gpu_options.allow_growth=True
run_config = tf.estimator.RunConfig(
model_dir=FLAGS.output_dir,
keep_checkpoint_max=FLAGS.keep_checkpoint_max,
save_checkpoints_steps=FLAGS.save_checkpoints_steps,
log_step_count_steps=FLAGS.hooking_frequence,
session_config=config)
if FLAGS.use_hvd and hvd.rank() != 0 and not FLAGS.auto_recover:
run_config = tf.estimator.RunConfig(
model_dir=FLAGS.output_dir,
keep_checkpoint_max=FLAGS.keep_checkpoint_max,
save_checkpoints_steps=None,
save_checkpoints_secs=None,
log_step_count_steps=FLAGS.hooking_frequence,
session_config=config)
model_fn = model_fn_builder(
bert_config=bert_config,
@@ -852,53 +1081,95 @@ def main(_):
num_train_steps=num_train_steps,
num_warmup_steps=num_warmup_steps,
use_tpu=FLAGS.use_tpu,
use_one_hot_embeddings=FLAGS.use_tpu)
use_one_hot_embeddings=FLAGS.use_tpu,
adjust_lr=FLAGS.adjust_lr,
use_hvd=FLAGS.use_hvd,
use_compression=FLAGS.use_compression,
use_fp16=FLAGS.use_fp16,
clip=FLAGS.clip,
cos_decay=FLAGS.cos_decay,
use_lamb=FLAGS.use_lamb,
previous_train_steps=FLAGS.previous_train_steps,
post_train_steps=FLAGS.post_train_steps)
hooks = []
if FLAGS.use_hvd:
hooks.append(hvd.BroadcastGlobalVariablesHook(0))
if hvd.rank() == -1:  # to debug, change -1 to 0
CLIDebugHook = tf_debug.LocalCLIDebugHook(ui_type='readline')
CLIDebugHook.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
hooks.append(CLIDebugHook)
if FLAGS.profile and hvd.rank() == 0:
ProfilerHook = tf.train.ProfilerHook(save_steps=FLAGS.hooking_frequence, output_dir=FLAGS.output_dir, show_dataflow=True, show_memory=True)
hooks.append(ProfilerHook)
# If TPU is not available, this will fall back to normal Estimator on CPU
# or GPU.
estimator = tf.contrib.tpu.TPUEstimator(
use_tpu=FLAGS.use_tpu,
estimator = tf.estimator.Estimator(
model_fn=model_fn,
config=run_config,
train_batch_size=FLAGS.train_batch_size,
eval_batch_size=FLAGS.eval_batch_size,
predict_batch_size=FLAGS.predict_batch_size)
config=run_config)
if FLAGS.do_train:
train_file = os.path.join(FLAGS.output_dir, "train.tf_record")
file_based_convert_examples_to_features(
train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file)
if FLAGS.use_tfrecord:
train_file = os.path.join(FLAGS.data_dir, FLAGS.tfrecord_name)
else:
train_file = os.path.join(FLAGS.output_dir, "train.tf_record")
file_based_convert_examples_to_features(
train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file)
tf.logging.info("***** Running training *****")
tf.logging.info(" Num examples = %d", len(train_examples))
tf.logging.info(" Num examples = %d", FLAGS.train_examples_count if train_examples is None else len(train_examples))
tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
tf.logging.info(" Num steps = %d", num_train_steps)
train_input_fn = file_based_input_fn_builder(
input_file=train_file,
seq_length=FLAGS.max_seq_length,
is_training=True,
drop_remainder=True)
estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
drop_remainder=True,
batch_size=FLAGS.train_batch_size,
use_hvd=FLAGS.use_hvd)
if FLAGS.auto_recover:
hooks.append(tf.data.experimental.CheckpointInputPipelineHook(estimator))
estimator.train(input_fn=train_input_fn, max_steps=num_train_steps, hooks=hooks)
if FLAGS.clean_tfrecord and tf.gfile.Exists(train_file):
tf.gfile.Remove(train_file)
if FLAGS.do_eval:
eval_examples = processor.get_dev_examples(FLAGS.data_dir)
num_actual_eval_examples = len(eval_examples)
if FLAGS.use_tpu:
# TPU requires a fixed batch size for all batches, therefore the number
# of examples must be a multiple of the batch size, or else examples
# will get dropped. So we pad with fake examples which are ignored
# later on. These do NOT count towards the metric (all tf.metrics
# support a per-instance weight, and these get a weight of 0.0).
while len(eval_examples) % FLAGS.eval_batch_size != 0:
eval_examples.append(PaddingInputExample())
if FLAGS.use_validation_tfrecord:
num_actual_eval_examples = 0
for record in tf.python_io.tf_record_iterator(os.path.join(FLAGS.validation_data_dir, FLAGS.validation_tfrecord_name)):
num_actual_eval_examples += 1
eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
file_based_convert_examples_to_features(
eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file)
validation_examples_count = num_actual_eval_examples
eval_file = os.path.join(FLAGS.validation_data_dir, FLAGS.validation_tfrecord_name)
else:
eval_examples = processor.get_dev_examples(FLAGS.validation_data_dir)
num_actual_eval_examples = len(eval_examples)
if FLAGS.use_tpu:
# TPU requires a fixed batch size for all batches, therefore the number
# of examples must be a multiple of the batch size, or else examples
# will get dropped. So we pad with fake examples which are ignored
# later on. These do NOT count towards the metric (all tf.metrics
# support a per-instance weight, and these get a weight of 0.0).
while len(eval_examples) % FLAGS.eval_batch_size != 0:
eval_examples.append(PaddingInputExample())
validation_examples_count = len(eval_examples)
eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
file_based_convert_examples_to_features(
eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file)
tf.logging.info("***** Running evaluation *****")
tf.logging.info(" Num examples = %d (%d actual, %d padding)",
len(eval_examples), num_actual_eval_examples,
len(eval_examples) - num_actual_eval_examples)
validation_examples_count, num_actual_eval_examples,
validation_examples_count - num_actual_eval_examples)
tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size)
# This tells the estimator to run through the entire set.
@@ -906,15 +1177,17 @@ def main(_):
# However, if running eval on the TPU, you will need to specify the
# number of steps.
if FLAGS.use_tpu:
assert len(eval_examples) % FLAGS.eval_batch_size == 0
eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size)
assert validation_examples_count % FLAGS.eval_batch_size == 0
eval_steps = int(validation_examples_count // FLAGS.eval_batch_size)
eval_drop_remainder = True if FLAGS.use_tpu else False
eval_input_fn = file_based_input_fn_builder(
input_file=eval_file,
seq_length=FLAGS.max_seq_length,
is_training=False,
drop_remainder=eval_drop_remainder)
drop_remainder=eval_drop_remainder,
batch_size=FLAGS.eval_batch_size,
use_hvd=FLAGS.use_hvd)
result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
@@ -925,26 +1198,40 @@ def main(_):
tf.logging.info(" %s = %s", key, str(result[key]))
writer.write("%s = %s\n" % (key, str(result[key])))
if FLAGS.do_predict:
predict_examples = processor.get_test_examples(FLAGS.data_dir)
num_actual_predict_examples = len(predict_examples)
if FLAGS.use_tpu:
# TPU requires a fixed batch size for all batches, therefore the number
# of examples must be a multiple of the batch size, or else examples
# will get dropped. So we pad with fake examples which are ignored
# later on.
while len(predict_examples) % FLAGS.predict_batch_size != 0:
predict_examples.append(PaddingInputExample())
if FLAGS.clean_tfrecord and tf.gfile.Exists(eval_file):
tf.gfile.Remove(eval_file)
predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
file_based_convert_examples_to_features(predict_examples, label_list,
FLAGS.max_seq_length, tokenizer,
predict_file)
if FLAGS.do_predict:
if FLAGS.use_test_tfrecord:
num_actual_predict_examples = 0
for record in tf.python_io.tf_record_iterator(os.path.join(FLAGS.test_data_dir, FLAGS.test_tfrecord_name)):
num_actual_predict_examples += 1
test_examples_count = num_actual_predict_examples
predict_file = os.path.join(FLAGS.test_data_dir, FLAGS.test_tfrecord_name)
else:
predict_examples = processor.get_test_examples(FLAGS.test_data_dir)
num_actual_predict_examples = len(predict_examples)
if FLAGS.use_tpu:
# TPU requires a fixed batch size for all batches, therefore the number
# of examples must be a multiple of the batch size, or else examples
# will get dropped. So we pad with fake examples which are ignored
# later on.
while len(predict_examples) % FLAGS.predict_batch_size != 0:
predict_examples.append(PaddingInputExample())
test_examples_count = len(predict_examples)
predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
file_based_convert_examples_to_features(predict_examples, label_list,
FLAGS.max_seq_length, tokenizer,
predict_file)
tf.logging.info("***** Running prediction*****")
tf.logging.info(" Num examples = %d (%d actual, %d padding)",
len(predict_examples), num_actual_predict_examples,
len(predict_examples) - num_actual_predict_examples)
test_examples_count, num_actual_predict_examples,
test_examples_count - num_actual_predict_examples)
tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size)
predict_drop_remainder = True if FLAGS.use_tpu else False
@@ -952,7 +1239,9 @@
input_file=predict_file,
seq_length=FLAGS.max_seq_length,
is_training=False,
drop_remainder=predict_drop_remainder)
drop_remainder=predict_drop_remainder,
batch_size=FLAGS.predict_batch_size,
use_hvd=FLAGS.use_hvd)
result = estimator.predict(input_fn=predict_input_fn)
@@ -960,20 +1249,123 @@
with tf.gfile.GFile(output_predict_file, "w") as writer:
num_written_lines = 0
tf.logging.info("***** Predict results *****")
if FLAGS.add_header:
writer.write("rowids\tprobabilities\tlabels\n")
for (i, prediction) in enumerate(result):
probabilities = prediction["probabilities"]
if i >= num_actual_predict_examples:
break
output_line = "\t".join(
str(class_probability)
for class_probability in probabilities) + "\n"
output_line = str(prediction["rowids"]) + "\t" + str(probabilities[1]) + "\t" + str(prediction["labels"]) + "\n"
writer.write(output_line)
num_written_lines += 1
assert num_written_lines == num_actual_predict_examples
if FLAGS.clean_tfrecord and tf.gfile.Exists(predict_file):
tf.gfile.Remove(predict_file)
if FLAGS.do_train_eval:
if FLAGS.use_tfrecord:
train_file = os.path.join(FLAGS.data_dir, FLAGS.tfrecord_name)
else:
train_file = os.path.join(FLAGS.output_dir, "train.tf_record")
file_based_convert_examples_to_features(
train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file)
tf.logging.info("***** Running training *****")
tf.logging.info(" Num examples = %d", FLAGS.train_examples_count if train_examples is None else len(train_examples))
tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
tf.logging.info(" Num steps = %d", num_train_steps)
train_input_fn = file_based_input_fn_builder(
input_file=train_file,
seq_length=FLAGS.max_seq_length,
is_training=True,
drop_remainder=True,
batch_size=FLAGS.train_batch_size,
use_hvd=FLAGS.use_hvd)
if FLAGS.use_validation_tfrecord:
num_actual_eval_examples = 0
for record in tf.python_io.tf_record_iterator(os.path.join(FLAGS.validation_data_dir, FLAGS.validation_tfrecord_name)):
num_actual_eval_examples += 1
validation_examples_count = num_actual_eval_examples
eval_file = os.path.join(FLAGS.validation_data_dir, FLAGS.validation_tfrecord_name)
else:
eval_examples = processor.get_dev_examples(FLAGS.validation_data_dir)
num_actual_eval_examples = len(eval_examples)
if FLAGS.use_tpu:
# TPU requires a fixed batch size for all batches, therefore the number
# of examples must be a multiple of the batch size, or else examples
# will get dropped. So we pad with fake examples which are ignored
# later on. These do NOT count towards the metric (all tf.metrics
# support a per-instance weight, and these get a weight of 0.0).
while len(eval_examples) % FLAGS.eval_batch_size != 0:
eval_examples.append(PaddingInputExample())
validation_examples_count = len(eval_examples)
eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
file_based_convert_examples_to_features(
eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file)
tf.logging.info("***** Running evaluation *****")
tf.logging.info(" Num examples = %d (%d actual, %d padding)",
validation_examples_count, num_actual_eval_examples,
validation_examples_count - num_actual_eval_examples)
tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size)
# This tells the estimator to run through the entire set.
eval_steps = None
# However, if running eval on the TPU, you will need to specify the
# number of steps.
if FLAGS.use_tpu:
assert validation_examples_count % FLAGS.eval_batch_size == 0
eval_steps = int(validation_examples_count // FLAGS.eval_batch_size)
eval_drop_remainder = True if FLAGS.use_tpu else False
eval_input_fn = file_based_input_fn_builder(
input_file=eval_file,
seq_length=FLAGS.max_seq_length,
is_training=False,
drop_remainder=eval_drop_remainder,
batch_size=FLAGS.eval_batch_size,
use_hvd=FLAGS.use_hvd)
if FLAGS.auto_recover:
hooks.append(tf.data.experimental.CheckpointInputPipelineHook(estimator))
train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=num_train_steps, hooks=hooks)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn, steps=eval_steps)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
if FLAGS.clean_tfrecord and tf.gfile.Exists(train_file):
tf.gfile.Remove(train_file)
if FLAGS.clean_tfrecord and tf.gfile.Exists(eval_file):
tf.gfile.Remove(eval_file)
if FLAGS.do_export:
def serving_input_fn():
input_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_ids')
input_mask = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_mask')
segment_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='segment_ids')
input_rowid = tf.placeholder(tf.int64, [None], name='input_rowid')
label_ids = tf.placeholder(tf.int32, [None], name='label_ids')
input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
'input_ids': input_ids,
'input_mask': input_mask,
'segment_ids': segment_ids,
'input_rowid': input_rowid,
'label_ids': label_ids,
})()
return input_fn
estimator._export_to_tpu = False
estimator.export_savedmodel(FLAGS.export_dir, serving_input_fn)
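A SavedModel exported this way can be exercised with TF 1.x's predictor utility; a hedged sketch (the export path and feed values here are made up):

import numpy as np
import tensorflow as tf

# Hypothetical timestamped directory produced by export_savedmodel above.
predictor = tf.contrib.predictor.from_saved_model("export_dir/1614000000")
seq_len = 128
outputs = predictor({
    "input_ids": np.zeros([1, seq_len], dtype=np.int32),
    "input_mask": np.zeros([1, seq_len], dtype=np.int32),
    "segment_ids": np.zeros([1, seq_len], dtype=np.int32),
    "input_rowid": np.zeros([1], dtype=np.int64),
    "label_ids": np.zeros([1], dtype=np.int32),
})
print(outputs["probabilities"])  # (bad, good) pair after the bad_label_num folding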
if __name__ == "__main__":
flags.mark_flag_as_required("data_dir")
# flags.mark_flag_as_required("data_dir")
flags.mark_flag_as_required("task_name")
flags.mark_flag_as_required("vocab_file")
flags.mark_flag_as_required("bert_config_file")

View file

@@ -22,6 +22,8 @@ import os
import modeling
import optimization
import tensorflow as tf
import horovod.tensorflow as hvd
from tensorflow.python import debug as tf_debug
flags = tf.flags
@@ -37,6 +39,18 @@ flags.DEFINE_string(
"input_file", None,
"Input TF example files (can be a glob or comma separated).")
flags.DEFINE_string(
"validation_input_file", None,
"Input validation TF example files (can be a glob or comma separated).")
flags.DEFINE_string(
"input_dir", None,
"Input TF example dir.")
flags.DEFINE_string(
"validation_input_dir", None,
"Input validation TF example dir.")
flags.DEFINE_string(
"output_dir", None,
"The output directory where the model checkpoints will be written.")
@@ -61,6 +75,8 @@ flags.DEFINE_bool("do_train", False, "Whether to run training.")
flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")
flags.DEFINE_bool("do_train_eval", False, "Whether to run train with eval.")
flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.")
@@ -77,7 +93,7 @@ flags.DEFINE_integer("save_checkpoints_steps", 1000,
flags.DEFINE_integer("iterations_per_loop", 1000,
"How many steps to make in each estimator call.")
flags.DEFINE_integer("max_eval_steps", 100, "Maximum number of eval steps.")
flags.DEFINE_integer("max_eval_steps", None, "Maximum number of eval steps.")
flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
@@ -106,9 +122,48 @@ flags.DEFINE_integer(
"Only used if `use_tpu` is True. Total number of TPU cores to use.")
flags.DEFINE_integer("hooking_frequence", 100, "Hooking frequence.")
flags.DEFINE_bool("reduce_log", False, "Reduce log.")
flags.DEFINE_integer("keep_checkpoint_max", 1, "Keep checkpoint max.")
flags.DEFINE_bool("xla", True, "Whether to train with XLA optimization.")
flags.DEFINE_bool("adjust_lr", True, "Whether to adjust learning_rate.")
flags.DEFINE_integer("previous_train_steps", 0, "Previous train steps.")
flags.DEFINE_integer("post_train_steps", 0, "Post train steps.")
flags.DEFINE_bool("use_hvd", True, "Whether to use Horovod.")
flags.DEFINE_bool("use_compression", True, "Whether to use compression in Horovod.")
flags.DEFINE_bool("use_fp16", True, "Whether to use fp16.")
flags.DEFINE_bool("cos_decay", False, "Whether to use cos decay.")
flags.DEFINE_bool("use_lamb", False, "Whether to use lamb.")
flags.DEFINE_bool("auto_recover", False, "Whether to use auto recover.")
flags.DEFINE_string("recover_dir", None, "The output directory where the model checkpoints will be recovered.")
flags.DEFINE_integer("ckpt_no", None, "Checkpoint number of model to be recovered.")
flags.DEFINE_integer("ckpt_no_input", None, "Checkpoint number of input to be recovered.")
flags.DEFINE_bool("clip", False, "Whether to use clip.")
flags.DEFINE_bool("profile", False, "Whether to use profile.")
def model_fn_builder(bert_config, init_checkpoint, learning_rate,
num_train_steps, num_warmup_steps, use_tpu,
use_one_hot_embeddings):
use_one_hot_embeddings, adjust_lr, use_hvd,
use_compression, use_fp16, clip, cos_decay,
use_lamb, previous_train_steps, post_train_steps):
"""Returns `model_fn` closure for TPUEstimator."""
def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
@@ -134,16 +189,17 @@ def model_fn_builder(bert_config, init_checkpoint, learning_rate,
input_ids=input_ids,
input_mask=input_mask,
token_type_ids=segment_ids,
use_one_hot_embeddings=use_one_hot_embeddings)
use_one_hot_embeddings=use_one_hot_embeddings,
compute_type=tf.float16 if use_fp16 else tf.float32)
(masked_lm_loss,
masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output(
bert_config, model.get_sequence_output(), model.get_embedding_table(),
masked_lm_positions, masked_lm_ids, masked_lm_weights)
masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output(
bert_config, model.get_sequence_output(), model.get_embedding_table(),
masked_lm_positions, masked_lm_ids, masked_lm_weights, clip)
(next_sentence_loss, next_sentence_example_loss,
next_sentence_log_probs) = get_next_sentence_output(
bert_config, model.get_pooled_output(), next_sentence_labels)
next_sentence_log_probs) = get_next_sentence_output(
bert_config, model.get_pooled_output(), next_sentence_labels, clip)
total_loss = masked_lm_loss + next_sentence_loss
@@ -174,14 +230,16 @@ def model_fn_builder(bert_config, init_checkpoint, learning_rate,
output_spec = None
if mode == tf.estimator.ModeKeys.TRAIN:
train_op = optimization.create_optimizer(
total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)
train_op, update_learning_rate = optimization.create_optimizer(
total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu, adjust_lr, use_hvd,
use_compression, use_fp16, clip, cos_decay, use_lamb, previous_train_steps, post_train_steps)
output_spec = tf.contrib.tpu.TPUEstimatorSpec(
logging_hook = tf.train.LoggingTensorHook({"loss": total_loss, "learning_rate": update_learning_rate}, every_n_iter=FLAGS.hooking_frequence)
output_spec = tf.estimator.EstimatorSpec(
mode=mode,
loss=total_loss,
train_op=train_op,
scaffold_fn=scaffold_fn)
training_hooks=[logging_hook])
elif mode == tf.estimator.ModeKeys.EVAL:
def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
@@ -219,16 +277,15 @@ def model_fn_builder(bert_config, init_checkpoint, learning_rate,
"next_sentence_loss": next_sentence_mean_loss,
}
eval_metrics = (metric_fn, [
eval_metrics = metric_fn(
masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
masked_lm_weights, next_sentence_example_loss,
next_sentence_log_probs, next_sentence_labels
])
output_spec = tf.contrib.tpu.TPUEstimatorSpec(
)
output_spec = tf.estimator.EstimatorSpec(
mode=mode,
loss=total_loss,
eval_metrics=eval_metrics,
scaffold_fn=scaffold_fn)
eval_metric_ops=eval_metrics)
else:
raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode))
@@ -238,7 +295,7 @@ def model_fn_builder(bert_config, init_checkpoint, learning_rate,
def get_masked_lm_output(bert_config, input_tensor, output_weights, positions,
label_ids, label_weights):
label_ids, label_weights, clip):
"""Get loss and log probs for the masked LM."""
input_tensor = gather_indexes(input_tensor, positions)
@@ -262,7 +319,10 @@ def get_masked_lm_output(bert_config, input_tensor, output_weights, positions,
initializer=tf.zeros_initializer())
logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
logits = tf.nn.bias_add(logits, output_bias)
log_probs = tf.nn.log_softmax(logits, axis=-1)
if clip:
log_probs = tf.log(tf.clip_by_value(tf.nn.softmax(logits, axis=-1), 1e-6, 1.0 - 1e-6))
else:
log_probs = tf.nn.log_softmax(logits, axis=-1)
label_ids = tf.reshape(label_ids, [-1])
label_weights = tf.reshape(label_weights, [-1])
@@ -282,7 +342,7 @@ def get_masked_lm_output(bert_config, input_tensor, output_weights, positions,
return (loss, per_example_loss, log_probs)
def get_next_sentence_output(bert_config, input_tensor, labels):
def get_next_sentence_output(bert_config, input_tensor, labels, clip):
"""Get loss and log probs for the next sentence prediction."""
# Simple binary classification. Note that 0 is "next sentence" and 1 is
@@ -297,7 +357,10 @@ def get_next_sentence_output(bert_config, input_tensor, labels):
logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
logits = tf.nn.bias_add(logits, output_bias)
log_probs = tf.nn.log_softmax(logits, axis=-1)
if clip:
log_probs = tf.log(tf.clip_by_value(tf.nn.softmax(logits, axis=-1), 1e-6, 1.0 - 1e-6))
else:
log_probs = tf.nn.log_softmax(logits, axis=-1)
labels = tf.reshape(labels, [-1])
one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32)
per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
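The clip branch in both loss heads exists to keep log() finite when a softmax saturates under fp16 training; the failure mode in two lines (illustrative):

import numpy as np

probs = np.array([1.0, 0.0])                     # a fully saturated softmax
print(np.log(probs))                             # [0., -inf] -> NaN losses downstream
print(np.log(np.clip(probs, 1e-6, 1.0 - 1e-6)))  # [-1e-06, -13.8...]: finite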
@@ -325,12 +388,14 @@ def input_fn_builder(input_files,
max_seq_length,
max_predictions_per_seq,
is_training,
num_cpu_threads=4):
num_cpu_threads=4,
batch_size=None,
use_hvd=True):
"""Creates an `input_fn` closure to be passed to TPUEstimator."""
def input_fn(params):
"""The actual input function."""
batch_size = params["batch_size"]
# batch_size = params["batch_size"]
name_to_features = {
"input_ids":
@@ -353,6 +418,11 @@ def input_fn_builder(input_files,
# For eval, we want no shuffling and parallel reading doesn't matter.
if is_training:
d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))
if use_hvd:
d = d.shard(hvd.size(), hvd.rank())  # TODO: Horovod only; shard so each rank reads a distinct slice (mimics single_GPU = False)
print("Data shard: %s %s" % (hvd.size(), hvd.rank()))
d = d.repeat()
d = d.shuffle(buffer_size=len(input_files))
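Sharding before repeat()/shuffle() is what gives each Horovod rank a disjoint slice of the input files; a toy tf.data illustration (file names made up):

import tensorflow as tf

files = tf.constant(["part-0", "part-1", "part-2", "part-3"])
d = tf.data.Dataset.from_tensor_slices(files)
d = d.shard(num_shards=2, index=1)  # rank 1 of 2 sees "part-1" and "part-3"
it = d.make_one_shot_iterator().get_next()
with tf.Session() as sess:
    print(sess.run(it), sess.run(it))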
@@ -371,7 +441,7 @@ def input_fn_builder(input_files,
d = tf.data.TFRecordDataset(input_files)
# Since we evaluate for a fixed number of steps we don't want to encounter
# out-of-range exceptions.
d = d.repeat()
# d = d.repeat()
# We must `drop_remainder` on training because the TPU requires fixed
# size dimensions. For eval, we assume we are evaluating on the CPU or GPU
@@ -406,36 +476,90 @@ def _decode_record(record, name_to_features):
def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
if not FLAGS.do_train and not FLAGS.do_eval:
if FLAGS.use_hvd:
hvd.init()
if FLAGS.reduce_log and (hvd.rank() != 0):
tf.logging.set_verbosity(tf.logging.ERROR)
FLAGS.output_dir = FLAGS.output_dir if hvd.rank() == 0 else os.path.join(FLAGS.output_dir, str(hvd.rank()))
if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_train_eval:
raise ValueError("At least one of `do_train` or `do_eval` must be True.")
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
tf.gfile.MakeDirs(FLAGS.output_dir)
if FLAGS.recover_dir is not None:
if FLAGS.use_hvd:
FLAGS.recover_dir = FLAGS.recover_dir if hvd.rank() == 0 else os.path.join(FLAGS.recover_dir, str(hvd.rank()))
path_ckpt = os.path.join(FLAGS.output_dir, "checkpoint")
path_ckpt_input = os.path.join(FLAGS.output_dir, "checkpoint_input")
if FLAGS.ckpt_no is not None and not tf.gfile.Exists(path_ckpt):
with tf.gfile.GFile(path_ckpt, "w") as writer:
writer.write('model_checkpoint_path: "%s-%s"\n' % (os.path.join(FLAGS.recover_dir, "model.ckpt"), str(FLAGS.ckpt_no)))
writer.write('all_model_checkpoint_paths: "%s-%s"\n' % (os.path.join(FLAGS.recover_dir, "model.ckpt"), str(FLAGS.ckpt_no)))
if FLAGS.ckpt_no_input is not None and not tf.gfile.Exists(path_ckpt_input):
with tf.gfile.GFile(path_ckpt_input, "w") as writer:
writer.write('model_checkpoint_path: "%s-%s"\n' % (os.path.join(FLAGS.recover_dir, "input.ckpt"), str(FLAGS.ckpt_no_input)))
writer.write('all_model_checkpoint_paths: "%s-%s"\n' % (os.path.join(FLAGS.recover_dir, "input.ckpt"), str(FLAGS.ckpt_no_input)))
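The `checkpoint` and `checkpoint_input` files written here use TensorFlow's CheckpointState text format, which is what the estimator consults on startup; writing them by hand points recovery at the checkpoints under recover_dir. A minimal sketch of reading such a file back (the directory path is a placeholder):

import tensorflow as tf

# Parses <dir>/checkpoint and returns a CheckpointState proto whose
# model_checkpoint_path is e.g. "/recover_dir/model.ckpt-100000".
state = tf.train.get_checkpoint_state("/path/to/output_dir")
if state is not None:
    print(state.model_checkpoint_path)
    print(state.all_model_checkpoint_paths)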
if FLAGS.use_hvd and hvd.rank() == 0 and (FLAGS.do_train or FLAGS.do_train_eval):
cname = os.path.basename(FLAGS.bert_config_file)
tf.gfile.Copy(FLAGS.bert_config_file, os.path.join(FLAGS.output_dir, cname), True)
input_files = []
for input_pattern in FLAGS.input_file.split(","):
input_files.extend(tf.gfile.Glob(input_pattern))
if FLAGS.input_file is not None:
for input_pattern in FLAGS.input_file.split(","):
input_files.extend(tf.gfile.Glob(input_pattern))
if FLAGS.input_dir is not None:
for filename in tf.gfile.ListDirectory(FLAGS.input_dir):
input_files.extend(tf.gfile.Glob(os.path.join(FLAGS.input_dir, filename)))
tf.logging.info("*** Input Files ***")
for input_file in input_files:
tf.logging.info(" %s" % input_file)
tpu_cluster_resolver = None
if FLAGS.use_tpu and FLAGS.tpu_name:
tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
validation_input_files = []
if FLAGS.validation_input_file is None and FLAGS.validation_input_dir is None:
validation_input_files = input_files
else:
if FLAGS.validation_input_file is not None:
for input_pattern in FLAGS.validation_input_file.split(","):
validation_input_files.extend(tf.gfile.Glob(input_pattern))
if FLAGS.validation_input_dir is not None:
for filename in tf.gfile.ListDirectory(FLAGS.validation_input_dir):
validation_input_files.extend(tf.gfile.Glob(os.path.join(FLAGS.validation_input_dir, filename)))
is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
run_config = tf.contrib.tpu.RunConfig(
cluster=tpu_cluster_resolver,
master=FLAGS.master,
tf.logging.info("*** Input Validation Files ***")
for input_file in validation_input_files:
tf.logging.info(" %s" % input_file)
config = tf.ConfigProto()
if FLAGS.xla:
config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
if FLAGS.use_hvd:
config.gpu_options.visible_device_list = str(hvd.local_rank())
config.gpu_options.allow_growth = True
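Setting visible_device_list to the local rank pins each process to a single physical GPU and remaps it to "/gpu:0" inside that process; allow_growth stops each process from reserving the whole card up front. A minimal sketch of the same session config outside the estimator, assuming one process per GPU:

import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()
config = tf.ConfigProto()
# Physical GPU number hvd.local_rank() becomes this process's "/gpu:0".
config.gpu_options.visible_device_list = str(hvd.local_rank())
config.gpu_options.allow_growth = True  # allocate GPU memory on demand
with tf.Session(config=config) as sess:
    print(sess.run(tf.constant("rank %d ready" % hvd.rank())))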
run_config = tf.estimator.RunConfig(
model_dir=FLAGS.output_dir,
keep_checkpoint_max=FLAGS.keep_checkpoint_max,
save_checkpoints_steps=FLAGS.save_checkpoints_steps,
tpu_config=tf.contrib.tpu.TPUConfig(
iterations_per_loop=FLAGS.iterations_per_loop,
num_shards=FLAGS.num_tpu_cores,
per_host_input_for_training=is_per_host))
log_step_count_steps=FLAGS.hooking_frequence,
session_config=config)
if FLAGS.use_hvd and hvd.rank() != 0 and not FLAGS.auto_recover:
run_config = tf.estimator.RunConfig(
model_dir=FLAGS.output_dir,
keep_checkpoint_max=FLAGS.keep_checkpoint_max,
save_checkpoints_steps=None,
save_checkpoints_secs=None,
log_step_count_steps=FLAGS.hooking_frequence,
session_config=config)
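The second RunConfig turns checkpoint writing off for every rank except the chief, which is the usual Horovod/Estimator pattern: with all ranks sharing one output_dir, concurrent writers would clobber each other's checkpoints. A condensed sketch of the same branching (flag values are placeholders):

import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()
is_chief = hvd.rank() == 0
run_config = tf.estimator.RunConfig(
    model_dir="/path/to/output_dir",  # placeholder
    keep_checkpoint_max=5,
    # Only the chief persists checkpoints; passing None for both the
    # step-based and time-based triggers disables saving on workers.
    save_checkpoints_steps=1000 if is_chief else None,
    save_checkpoints_secs=None,
    log_step_count_steps=100)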
model_fn = model_fn_builder(
bert_config=bert_config,
@@ -444,16 +568,36 @@ def main(_):
num_train_steps=FLAGS.num_train_steps,
num_warmup_steps=FLAGS.num_warmup_steps,
use_tpu=FLAGS.use_tpu,
use_one_hot_embeddings=FLAGS.use_tpu)
use_one_hot_embeddings=FLAGS.use_tpu,
adjust_lr=FLAGS.adjust_lr,
use_hvd=FLAGS.use_hvd,
use_compression=FLAGS.use_compression,
use_fp16=FLAGS.use_fp16,
clip=FLAGS.clip,
cos_decay=FLAGS.cos_decay,
use_lamb=FLAGS.use_lamb,
previous_train_steps=FLAGS.previous_train_steps,
post_train_steps=FLAGS.post_train_steps)
hooks = []
if FLAGS.use_hvd:
hooks.append(hvd.BroadcastGlobalVariablesHook(0))
if hvd.rank() == -1:  # debug-only branch: rank is never -1; change to 0 to attach the CLI debugger
CLIDebugHook = tf_debug.LocalCLIDebugHook(ui_type='readline')
CLIDebugHook.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
hooks.append(CLIDebugHook)
if FLAGS.profile and hvd.rank() == 0:
ProfilerHook = tf.train.ProfilerHook(save_steps=FLAGS.hooking_frequence, output_dir=FLAGS.output_dir, show_dataflow=True, show_memory=True)
hooks.append(ProfilerHook)
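BroadcastGlobalVariablesHook(0) is what makes the data-parallel setup consistent: after variables are initialized (or restored) on rank 0, the hook broadcasts them so every worker starts from identical weights. A minimal sketch of assembling the same hook list (paths and step counts are placeholders):

import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()
hooks = [
    # Push rank 0's initial/restored variables to all other ranks
    # before the first training step.
    hvd.BroadcastGlobalVariablesHook(0),
]
if hvd.rank() == 0:
    # Trace timelines on the chief only, one trace every save_steps.
    hooks.append(tf.train.ProfilerHook(save_steps=100,
                                       output_dir="/tmp/profile",
                                       show_dataflow=True,
                                       show_memory=True))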
# If TPU is not available, this will fall back to normal Estimator on CPU
# or GPU.
estimator = tf.contrib.tpu.TPUEstimator(
use_tpu=FLAGS.use_tpu,
estimator = tf.estimator.Estimator(
model_fn=model_fn,
config=run_config,
train_batch_size=FLAGS.train_batch_size,
eval_batch_size=FLAGS.eval_batch_size)
config=run_config)
if FLAGS.do_train:
tf.logging.info("***** Running training *****")
@@ -462,18 +606,26 @@ def main(_):
input_files=input_files,
max_seq_length=FLAGS.max_seq_length,
max_predictions_per_seq=FLAGS.max_predictions_per_seq,
is_training=True)
estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps)
is_training=True,
batch_size=FLAGS.train_batch_size,
use_hvd=FLAGS.use_hvd)
if FLAGS.auto_recover:
hooks.append(tf.data.experimental.CheckpointInputPipelineHook(estimator))
estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps, hooks=hooks)
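CheckpointInputPipelineHook saves the tf.data iterator state alongside the model checkpoint, so an auto-recovered job resumes reading mid-stream instead of replaying the dataset from the beginning. A self-contained toy sketch of the mechanism (everything below is illustrative, not the script's real model or pipeline):

import tensorflow as tf

def toy_input_fn():
    d = tf.data.Dataset.range(1000)
    return d.map(lambda x: ({"x": tf.cast(x, tf.float32) / 1000.0},
                            tf.zeros([]))).batch(8)

def toy_model_fn(features, labels, mode):
    w = tf.get_variable("w", [], initializer=tf.ones_initializer())
    loss = tf.reduce_mean(tf.square(w * features["x"] - labels))
    train_op = tf.train.GradientDescentOptimizer(1e-4).minimize(
        loss, global_step=tf.train.get_or_create_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

est = tf.estimator.Estimator(toy_model_fn, model_dir="/tmp/pipeline_demo")
# The hook snapshots the input iterator at every checkpoint; on restart,
# training continues from where the iterator stopped, not from batch 0.
hook = tf.data.experimental.CheckpointInputPipelineHook(est)
est.train(toy_input_fn, steps=50, hooks=[hook])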
if FLAGS.do_eval:
tf.logging.info("***** Running evaluation *****")
tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size)
eval_input_fn = input_fn_builder(
input_files=input_files,
input_files=validation_input_files,
max_seq_length=FLAGS.max_seq_length,
max_predictions_per_seq=FLAGS.max_predictions_per_seq,
is_training=False)
is_training=False,
batch_size=FLAGS.eval_batch_size,
use_hvd=FLAGS.use_hvd)
result = estimator.evaluate(
input_fn=eval_input_fn, steps=FLAGS.max_eval_steps)
@@ -485,9 +637,37 @@ def main(_):
tf.logging.info(" %s = %s", key, str(result[key]))
writer.write("%s = %s\n" % (key, str(result[key])))
if FLAGS.do_train_eval:
tf.logging.info("***** Running training *****")
tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
train_input_fn = input_fn_builder(
input_files=input_files,
max_seq_length=FLAGS.max_seq_length,
max_predictions_per_seq=FLAGS.max_predictions_per_seq,
is_training=True,
batch_size=FLAGS.train_batch_size,
use_hvd=FLAGS.use_hvd)
tf.logging.info("***** Running evaluation *****")
tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size)
eval_input_fn = input_fn_builder(
input_files=validation_input_files,
max_seq_length=FLAGS.max_seq_length,
max_predictions_per_seq=FLAGS.max_predictions_per_seq,
is_training=False,
batch_size=FLAGS.eval_batch_size,
use_hvd=FLAGS.use_hvd)
if FLAGS.auto_recover:
hooks.append(tf.data.experimental.CheckpointInputPipelineHook(estimator))
train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps, hooks=hooks)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn, steps=FLAGS.max_eval_steps)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
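tf.estimator.train_and_evaluate alternates the two phases in one process: it trains until a new checkpoint is written, evaluates that checkpoint, and repeats until TrainSpec.max_steps is reached. A minimal sketch of the wiring, assuming estimator, train_input_fn and eval_input_fn as built above (step counts are placeholders):

import tensorflow as tf

train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn,
                                    max_steps=100000)
# throttle_secs rate-limits evaluation: a new eval pass starts only if
# at least this many seconds have passed since the previous one.
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn,
                                  steps=100,
                                  throttle_secs=600)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)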
if __name__ == "__main__":
flags.mark_flag_as_required("input_file")
# flags.mark_flag_as_required("input_file")
flags.mark_flag_as_required("bert_config_file")
flags.mark_flag_as_required("output_dir")
tf.app.run()