Release
Parent
2671f436e2
Commit
6efe228d60
gpu_environment.py
@@ -0,0 +1,36 @@
# coding=utf-8
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
import numpy as np


def float32_variable_storage_getter(getter, name, shape=None, dtype=None,
                                    initializer=None, regularizer=None,
                                    trainable=True,
                                    *args, **kwargs):
  """Custom variable getter that forces trainable variables to be stored in
  float32 precision and then casts them to the training precision.
  """
  storage_dtype = tf.float32 if trainable else dtype
  variable = getter(name, shape, dtype=storage_dtype,
                    initializer=initializer, regularizer=regularizer,
                    trainable=trainable,
                    *args, **kwargs)
  if trainable and dtype != tf.float32:
    variable = tf.cast(variable, dtype)
  return variable


def get_custom_getter(compute_type):
  return float32_variable_storage_getter if compute_type == tf.float16 else None
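
For context, a minimal sketch (not part of this commit; scope name and shapes are illustrative) of how the getter is meant to be used under TensorFlow 1.x:

import tensorflow as tf
from gpu_environment import get_custom_getter

inputs = tf.placeholder(tf.float16, [None, 128])
with tf.variable_scope("layer", custom_getter=get_custom_getter(tf.float16)):
  # The getter stores "w" as a float32 master variable and returns a
  # float16 cast, so the matmul below runs in half precision.
  w = tf.get_variable("w", [128, 64], dtype=tf.float16)
  outputs = tf.matmul(inputs, w)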
modeling.py
@@ -26,6 +26,7 @@ import re
import numpy as np
import six
import tensorflow as tf
from gpu_environment import get_custom_getter


class BertConfig(object):
@@ -135,7 +136,8 @@ class BertModel(object):
               input_mask=None,
               token_type_ids=None,
               use_one_hot_embeddings=False,
               scope=None):
               scope=None,
               compute_type=tf.float32):
    """Constructor for BertModel.

    Args:
@@ -168,7 +170,7 @@ class BertModel(object):
    if token_type_ids is None:
      token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32)

    with tf.variable_scope(scope, default_name="bert"):
    with tf.variable_scope(scope, default_name="bert", custom_getter=get_custom_getter(compute_type)):
      with tf.variable_scope("embeddings"):
        # Perform embedding lookup on the word ids.
        (self.embedding_output, self.embedding_table) = embedding_lookup(
@@ -203,7 +205,7 @@ class BertModel(object):
      # Run the stacked transformer.
      # `sequence_output` shape = [batch_size, seq_length, hidden_size].
      self.all_encoder_layers = transformer_model(
          input_tensor=self.embedding_output,
          input_tensor=tf.saturate_cast(self.embedding_output, compute_type),
          attention_mask=attention_mask,
          hidden_size=config.hidden_size,
          num_hidden_layers=config.num_hidden_layers,
@@ -215,7 +217,7 @@ class BertModel(object):
          initializer_range=config.initializer_range,
          do_return_all_layers=True)

      self.sequence_output = self.all_encoder_layers[-1]
      self.sequence_output = tf.cast(self.all_encoder_layers[-1], tf.float32)
      # The "pooler" converts the encoded sequence tensor of shape
      # [batch_size, seq_length, hidden_size] to a tensor of shape
      # [batch_size, hidden_size]. This is necessary for segment-level
@@ -709,7 +711,7 @@ def attention_layer(from_tensor,
  # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
  # masked positions, this operation will create a tensor which is 0.0 for
  # positions we want to attend and -10000.0 for masked positions.
  adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0
  adder = (1.0 - tf.cast(attention_mask, attention_scores.dtype)) * -10000.0

  # Since we are adding it to the raw scores before the softmax, this is
  # effectively the same as removing these entirely.
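Note on the adder change: once the encoder runs in float16, attention_scores carries the compute dtype and a float32 penalty can no longer be added to it. A standalone illustration (shapes are made up):

import tensorflow as tf

attention_scores = tf.zeros([2, 12, 8, 8], dtype=tf.float16)  # illustrative
attention_mask = tf.ones([2, 1, 8, 8], dtype=tf.int32)        # 1 = attend, 0 = masked
# -10000.0 is still representable in float16 (max ~65504), so masked
# positions vanish after the softmax exactly as in float32.
adder = (1.0 - tf.cast(attention_mask, attention_scores.dtype)) * -10000.0
attention_probs = tf.nn.softmax(attention_scores + adder)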
optimization.py
@@ -21,26 +21,52 @@ from __future__ import print_function
import re
import tensorflow as tf

# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
# pylint: enable=g-direct-tensorflow-import

def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu):
import horovod.tensorflow as hvd


def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu, adjust_lr, use_hvd,
                     use_compression, use_fp16, clip, cos_decay, use_lamb=False,
                     previous_train_steps=0, post_train_steps=0):
  """Creates an optimizer training op."""
  global_step = tf.train.get_or_create_global_step()

  learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)

  # Implements linear decay of the learning rate.
  learning_rate = tf.train.polynomial_decay(
      learning_rate,
      global_step,
      num_train_steps,
      end_learning_rate=0.0,
      power=1.0,
      cycle=False)
  # if adjust_lr:
  #   # avoid step change in learning rate at end of warmup phase
  #   decayed_learning_rate_at_crossover_point = init_lr * (1.0-float(num_warmup_steps)/float(num_train_steps + previous_train_steps + post_train_steps))
  #   adjusted_init_lr = init_lr * (init_lr / decayed_learning_rate_at_crossover_point)
  #   print('decayed_learning_rate_at_crossover_point = %e, adjusted_init_lr = %e' % (decayed_learning_rate_at_crossover_point, adjusted_init_lr))

  #   learning_rate = tf.constant(value=adjusted_init_lr, shape=[], dtype=tf.float32)

  if cos_decay:
    # Implements cosine decay of the learning rate.
    learning_rate = tf.train.cosine_decay(
        learning_rate,
        global_step + previous_train_steps - (num_warmup_steps if adjust_lr else 0),
        num_train_steps + previous_train_steps + post_train_steps - (num_warmup_steps if adjust_lr else 0),
        alpha=0.0)
  else:
    # Implements linear decay of the learning rate.
    learning_rate = tf.train.polynomial_decay(
        learning_rate,
        global_step + previous_train_steps - (num_warmup_steps if adjust_lr else 0),
        num_train_steps + previous_train_steps + post_train_steps - (num_warmup_steps if adjust_lr else 0),
        end_learning_rate=0.0,
        power=1.0,
        cycle=False)

  # Implements linear warmup. I.e., if global_step < num_warmup_steps, the
  # learning rate will be `global_step/num_warmup_steps * init_lr`.
  if num_warmup_steps:
    global_steps_int = tf.cast(global_step, tf.int32)
    global_steps_int = tf.cast(global_step + previous_train_steps, tf.int32)
    warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)

    global_steps_float = tf.cast(global_steps_int, tf.float32)
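For intuition, the resulting schedule is linear warmup on global_step + previous_train_steps, then linear (or cosine) decay over the combined horizon. A simplified NumPy-free sketch of the linear case, ignoring the adjust_lr offset (so a rough approximation, not the exact graph computation):

def lr_at(step, init_lr, warmup_steps, train_steps, previous=0, post=0):
  g = step + previous
  horizon = train_steps + previous + post
  if warmup_steps and g < warmup_steps:
    return init_lr * float(g) / warmup_steps              # linear warmup
  return init_lr * max(0.0, 1.0 - float(g) / horizon)     # polynomial_decay, power=1.0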
@@ -56,7 +82,16 @@ def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu):
  # It is recommended that you use this optimizer for fine tuning, since this
  # is how the model was trained (note that the Adam m/v variables are NOT
  # loaded from init_checkpoint.)
  optimizer = AdamWeightDecayOptimizer(
  if use_lamb:
    optimizer = LAMBOptimizer(
        learning_rate=learning_rate,
        weight_decay_rate=0.01,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-6,
        exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
  else:
    optimizer = AdamWeightDecayOptimizer(
        learning_rate=learning_rate,
        weight_decay_rate=0.01,
        beta_1=0.9,
@@ -64,24 +99,51 @@ def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu):
        epsilon=1e-6,
        exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])

  tf.summary.scalar("learning_rate", optimizer.learning_rate)

  if use_hvd:
    optimizer = hvd.DistributedOptimizer(optimizer, sparse_as_dense=True, compression=hvd.Compression.fp16 if use_compression else hvd.Compression.none)

  if use_fp16:
    loss_scale_manager = tf.contrib.mixed_precision.ExponentialUpdateLossScaleManager(init_loss_scale=2**32, incr_every_n_steps=1000, decr_every_n_nan_or_inf=2, decr_ratio=0.5)
    optimizer = tf.contrib.mixed_precision.LossScaleOptimizer(optimizer, loss_scale_manager)

  if use_tpu:
    optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

  all_are_finite = tf.constant(True, dtype=tf.bool)
  tvars = tf.trainable_variables()
  grads = tf.gradients(loss, tvars)
  if use_hvd:
    grads_and_vars = optimizer.compute_gradients(loss, tvars)
    grads_and_vars = [(g, v) for g, v in grads_and_vars if g is not None]
    grads, tvars = list(zip(*grads_and_vars))
    all_are_finite = tf.reduce_all([tf.reduce_all(tf.is_finite(g)) for g in grads]) if use_fp16 else tf.constant(True, dtype=tf.bool)
  else:
    grads = tf.gradients(loss, tvars)

  # This is how the model was pre-trained.
  (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
  # ensure global norm is a finite number
  # to prevent clip_by_global_norm from having a hissy fit.
  (clipped_grads, _) = tf.clip_by_global_norm(
      grads, clip_norm=1.0,
      use_norm=tf.cond(
          all_are_finite,
          lambda: tf.global_norm(grads),
          lambda: tf.constant(1.0)))

  train_op = optimizer.apply_gradients(
      zip(grads, tvars), global_step=global_step)
      list(zip(clipped_grads, tvars)), global_step=global_step)

  # Normally the global step update is done inside of `apply_gradients`.
  # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use
  # a different optimizer, you should probably take this line out.
  new_global_step = global_step + 1
  if clip:
    new_global_step = global_step + 1
  else:
    new_global_step = tf.cond(all_are_finite, lambda: global_step+1, lambda: global_step)
  new_global_step = tf.identity(new_global_step, name='step_update')
  train_op = tf.group(train_op, [global_step.assign(new_global_step)])
  return train_op
  return train_op, learning_rate

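The loss-scale manager configured here performs dynamic loss scaling. A toy model of its update rule (not the TF implementation) makes the constants above concrete:

class ToyLossScale(object):
  """Mimics ExponentialUpdateLossScaleManager(init_loss_scale=2**32,
  incr_every_n_steps=1000, decr_every_n_nan_or_inf=2, decr_ratio=0.5)."""

  def __init__(self):
    self.scale = 2.0 ** 32
    self.good_steps = 0
    self.bad_steps = 0

  def update(self, grads_are_finite):
    if grads_are_finite:
      self.good_steps += 1
      self.bad_steps = 0
      if self.good_steps >= 1000:   # incr_every_n_steps
        self.scale *= 2.0
        self.good_steps = 0
    else:
      self.bad_steps += 1
      self.good_steps = 0
      if self.bad_steps >= 2:       # decr_every_n_nan_or_inf
        self.scale *= 0.5           # decr_ratio
        self.bad_steps = 0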
class AdamWeightDecayOptimizer(tf.train.Optimizer):

@@ -98,7 +160,7 @@ class AdamWeightDecayOptimizer(tf.train.Optimizer):
    """Constructs an AdamWeightDecayOptimizer."""
    super(AdamWeightDecayOptimizer, self).__init__(False, name)

    self.learning_rate = learning_rate
    self.learning_rate = tf.identity(learning_rate, name='learning_rate')
    self.weight_decay_rate = weight_decay_rate
    self.beta_1 = beta_1
    self.beta_2 = beta_2
@@ -172,3 +234,121 @@ class AdamWeightDecayOptimizer(tf.train.Optimizer):
    if m is not None:
      param_name = m.group(1)
    return param_name


class LAMBOptimizer(tf.train.Optimizer):
  """LAMB (Layer-wise Adaptive Moments optimizer for Batch training)."""
  # A new optimizer that includes correct L2 weight decay, adaptive
  # element-wise updating, and layer-wise justification. The LAMB optimizer
  # was proposed by Yang You, Jing Li, Jonathan Hseu, Xiaodan Song,
  # James Demmel, and Cho-Jui Hsieh in a paper titled Reducing BERT
  # Pre-Training Time from 3 Days to 76 Minutes (arxiv.org/abs/1904.00962)

  def __init__(self,
               learning_rate,
               weight_decay_rate=0.0,
               beta_1=0.9,
               beta_2=0.999,
               epsilon=1e-6,
               exclude_from_weight_decay=None,
               exclude_from_layer_adaptation=None,
               name="LAMBOptimizer"):
    """Constructs a LAMBOptimizer."""
    super(LAMBOptimizer, self).__init__(False, name)

    self.learning_rate = learning_rate
    self.weight_decay_rate = weight_decay_rate
    self.beta_1 = beta_1
    self.beta_2 = beta_2
    self.epsilon = epsilon
    self.exclude_from_weight_decay = exclude_from_weight_decay
    # exclude_from_layer_adaptation is set to exclude_from_weight_decay if the
    # arg is None.
    # TODO(jingli): validate if exclude_from_layer_adaptation is necessary.
    if exclude_from_layer_adaptation:
      self.exclude_from_layer_adaptation = exclude_from_layer_adaptation
    else:
      self.exclude_from_layer_adaptation = exclude_from_weight_decay

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """See base class."""
    assignments = []
    for (grad, param) in grads_and_vars:
      if grad is None or param is None:
        continue

      param_name = self._get_variable_name(param.name)

      m = tf.get_variable(
          name=param_name + "/adam_m",
          shape=param.shape.as_list(),
          dtype=tf.float32,
          trainable=False,
          initializer=tf.zeros_initializer())
      v = tf.get_variable(
          name=param_name + "/adam_v",
          shape=param.shape.as_list(),
          dtype=tf.float32,
          trainable=False,
          initializer=tf.zeros_initializer())

      # Standard Adam update.
      next_m = (
          tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
      next_v = (
          tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
                                                    tf.square(grad)))

      update = next_m / (tf.sqrt(next_v) + self.epsilon)

      # Just adding the square of the weights to the loss function is *not*
      # the correct way of using L2 regularization/weight decay with Adam,
      # since that will interact with the m and v parameters in strange ways.
      #
      # Instead we want to decay the weights in a manner that doesn't interact
      # with the m/v parameters. This is equivalent to adding the square
      # of the weights to the loss with plain (non-momentum) SGD.
      if self._do_use_weight_decay(param_name):
        update += self.weight_decay_rate * param

      ratio = 1.0
      if self._do_layer_adaptation(param_name):
        w_norm = linalg_ops.norm(param, ord=2)
        g_norm = linalg_ops.norm(update, ord=2)
        ratio = array_ops.where(math_ops.greater(w_norm, 0), array_ops.where(
            math_ops.greater(g_norm, 0), (w_norm / g_norm), 1.0), 1.0)

      update_with_lr = ratio * self.learning_rate * update

      next_param = param - update_with_lr

      assignments.extend(
          [param.assign(next_param),
           m.assign(next_m),
           v.assign(next_v)])
    return tf.group(*assignments, name=name)

  def _do_use_weight_decay(self, param_name):
    """Whether to use L2 weight decay for `param_name`."""
    if not self.weight_decay_rate:
      return False
    if self.exclude_from_weight_decay:
      for r in self.exclude_from_weight_decay:
        if re.search(r, param_name) is not None:
          return False
    return True

  def _do_layer_adaptation(self, param_name):
    """Whether to do layer-wise learning rate adaptation for `param_name`."""
    if self.exclude_from_layer_adaptation:
      for r in self.exclude_from_layer_adaptation:
        if re.search(r, param_name) is not None:
          return False
    return True

  def _get_variable_name(self, param_name):
    """Get the variable name from the tensor name."""
    m = re.match("^(.*):\\d+$", param_name)
    if m is not None:
      param_name = m.group(1)
    return param_name
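For reference, the per-tensor LAMB update above condensed into NumPy (bias correction is omitted, matching the code; names are ad hoc):

import numpy as np

def lamb_step(w, m, v, grad, lr, beta_1=0.9, beta_2=0.999, eps=1e-6, wd=0.01):
  m = beta_1 * m + (1.0 - beta_1) * grad
  v = beta_2 * v + (1.0 - beta_2) * grad ** 2
  update = m / (np.sqrt(v) + eps) + wd * w          # Adam step + decoupled decay
  w_norm = np.linalg.norm(w)
  g_norm = np.linalg.norm(update)
  ratio = w_norm / g_norm if (w_norm > 0 and g_norm > 0) else 1.0
  return w - ratio * lr * update, m, v              # layer-wise trust ratio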
run_classifier.py
@@ -25,6 +25,8 @@ import modeling
import optimization
import tokenization
import tensorflow as tf
import horovod.tensorflow as hvd
from tensorflow.python import debug as tf_debug

flags = tf.flags

@@ -36,6 +38,16 @@ flags.DEFINE_string(
    "The input data dir. Should contain the .tsv files (or other data files) "
    "for the task.")

flags.DEFINE_string(
    "validation_data_dir", None,
    "The input validation data dir. Should contain the .tsv files (or other data files) "
    "for the task.")

flags.DEFINE_string(
    "test_data_dir", None,
    "The input test data dir. Should contain the .tsv files (or other data files) "
    "for the task.")

flags.DEFINE_string(
    "bert_config_file", None,
    "The config json file corresponding to the pre-trained BERT model. "
@@ -71,6 +83,8 @@ flags.DEFINE_bool("do_train", False, "Whether to run training.")

flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")

flags.DEFINE_bool("do_train_eval", False, "Whether to run train with eval.")

flags.DEFINE_bool(
    "do_predict", False,
    "Whether to run the model in inference mode on the test set.")
@@ -91,7 +105,7 @@ flags.DEFINE_float(
    "Proportion of training to perform linear learning rate warmup for. "
    "E.g., 0.1 = 10% of training.")

flags.DEFINE_integer("save_checkpoints_steps", 1000,
flags.DEFINE_integer("save_checkpoints_steps", None,
                     "How often to save the model checkpoint.")

flags.DEFINE_integer("iterations_per_loop", 1000,
@@ -99,31 +113,98 @@ flags.DEFINE_integer("iterations_per_loop", 1000,

flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")

tf.flags.DEFINE_string(
flags.DEFINE_string(
    "tpu_name", None,
    "The Cloud TPU to use for training. This should be either the name "
    "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
    "url.")

tf.flags.DEFINE_string(
flags.DEFINE_string(
    "tpu_zone", None,
    "[Optional] GCE zone where the Cloud TPU is located in. If not "
    "specified, we will attempt to automatically detect the GCE project from "
    "metadata.")

tf.flags.DEFINE_string(
flags.DEFINE_string(
    "gcp_project", None,
    "[Optional] Project name for the Cloud TPU-enabled project. If not "
    "specified, we will attempt to automatically detect the GCE project from "
    "metadata.")

tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")
flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")

flags.DEFINE_integer(
    "num_tpu_cores", 8,
    "Only used if `use_tpu` is True. Total number of TPU cores to use.")


flags.DEFINE_bool(
    "do_export", False,
    "Whether to export the model.")

flags.DEFINE_string(
    "export_dir", None,
    "The dir where the exported model will be written.")

flags.DEFINE_string("label_list", None, "Label list.")

flags.DEFINE_integer("bad_label_num", 1, "Bad label num.")

flags.DEFINE_bool("add_header", False, "Add header.")

flags.DEFINE_bool("use_tfrecord", False, "Use tfrecord.")

flags.DEFINE_bool("use_validation_tfrecord", False, "Use validation tfrecord.")

flags.DEFINE_bool("use_test_tfrecord", False, "Use test tfrecord.")

flags.DEFINE_string("tfrecord_name", "train.tf_record", "tfrecord name.")

flags.DEFINE_string("validation_tfrecord_name", "eval.tf_record", "validation tfrecord name.")

flags.DEFINE_string("test_tfrecord_name", "predict.tf_record", "test tfrecord name.")

flags.DEFINE_bool("clean_tfrecord", False, "Clean tfrecord.")

flags.DEFINE_integer("train_examples_count", None, "Train examples count.")

flags.DEFINE_integer("hooking_frequence", 100, "Hooking frequency.")

flags.DEFINE_bool("reduce_log", False, "Reduce log.")

flags.DEFINE_integer("keep_checkpoint_max", None, "Keep checkpoint max.")

flags.DEFINE_bool("xla", True, "Whether to train with XLA optimization.")

flags.DEFINE_bool("adjust_lr", True, "Whether to adjust learning_rate.")

flags.DEFINE_integer("previous_train_steps", 0, "Previous train steps.")

flags.DEFINE_integer("post_train_steps", 0, "Post train steps.")

flags.DEFINE_bool("use_hvd", True, "Whether to use Horovod.")

flags.DEFINE_bool("use_compression", True, "Whether to use compression in Horovod.")

flags.DEFINE_bool("use_fp16", True, "Whether to use fp16.")

flags.DEFINE_bool("cos_decay", False, "Whether to use cos decay.")

flags.DEFINE_bool("use_lamb", False, "Whether to use lamb.")

flags.DEFINE_bool("auto_recover", False, "Whether to use auto recover.")

flags.DEFINE_string("recover_dir", None, "The output directory where the model checkpoints will be recovered.")

flags.DEFINE_integer("ckpt_no", None, "Checkpoint number of model to be recovered.")

flags.DEFINE_integer("ckpt_no_input", None, "Checkpoint number of input to be recovered.")

flags.DEFINE_bool("clip", False, "Whether to use clip.")

flags.DEFINE_bool("profile", False, "Whether to use profile.")


class InputExample(object):
  """A single training/test example for simple sequence classification."""

@@ -165,11 +246,13 @@ class InputFeatures(object):
               input_ids,
               input_mask,
               segment_ids,
               row_id,
               label_id,
               is_real_example=True):
    self.input_ids = input_ids
    self.input_mask = input_mask
    self.segment_ids = segment_ids
    self.row_id = row_id
    self.label_id = label_id
    self.is_real_example = is_real_example

@@ -203,6 +286,64 @@ class DataProcessor(object):
        lines.append(line)
      return lines

  @classmethod
  def _read_tsv_from_dir(cls, input_dir, quotechar=None):
    """Reads a tab separated value file."""
    input_files = [input_dir + "/" + i for i in tf.gfile.ListDirectory(input_dir)]
    lines = []
    for input_file in input_files:
      with tf.gfile.Open(input_file, "r") as f:
        reader = csv.reader((line.replace('\0', '') for line in f), delimiter="\t", quotechar=quotechar)
        for line in reader:
          lines.append(line)
    return lines

  @classmethod
  def _read_tsv_from_dir_by_name(cls, input_dir, quotechar=None, name='0'):
    """Reads a tab separated value file."""
    lines = []
    input_file = input_dir + "/" + name
    with tf.gfile.Open(input_file, "r") as f:
      reader = csv.reader((line.replace('\0', '') for line in f), delimiter="\t", quotechar=quotechar)
      for line in reader:
        lines.append(line)
    return lines


class QKProcessor(DataProcessor):
  """Processor for the QK data set."""

  def get_train_examples(self, data_dir):
    """See base class."""
    # return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
    return self._create_examples(self._read_tsv_from_dir(data_dir), "train")

  def get_dev_examples(self, data_dir):
    """See base class."""
    # return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
    return self._create_examples(self._read_tsv_from_dir(data_dir), "dev")

  def get_test_examples(self, data_dir):
    """See base class."""
    # return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
    return self._create_examples(self._read_tsv_from_dir(data_dir), "test")

  def get_labels(self):
    """See base class."""
    return ["0", "1"]

  def _create_examples(self, lines, set_type):
    """Creates examples for the training and dev sets."""
    examples = []
    for (i, line) in enumerate(lines):
      guid = tokenization.convert_to_unicode(line[0])
      label = tokenization.convert_to_unicode(line[1])
      text_a = tokenization.convert_to_unicode(line[2])
      text_b = tokenization.convert_to_unicode(line[3])

      examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples


class XnliProcessor(DataProcessor):
  """Processor for the XNLI data set."""
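_create_examples expects four tab-separated columns per line: row id (guid), label, text_a, text_b. A hypothetical input line and its mapping:

line = "42\t1\thow to reset a password\taccount recovery help".split("\t")
example = InputExample(guid=line[0], text_a=line[2], text_b=line[3], label=line[1])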
@@ -383,6 +524,7 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
        input_ids=[0] * max_seq_length,
        input_mask=[0] * max_seq_length,
        segment_ids=[0] * max_seq_length,
        row_id=0,
        label_id=0,
        is_real_example=False)

@@ -471,6 +613,7 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
      input_ids=input_ids,
      input_mask=input_mask,
      segment_ids=segment_ids,
      row_id=int(example.guid),
      label_id=label_id,
      is_real_example=True)
  return feature
@@ -497,6 +640,7 @@ def file_based_convert_examples_to_features(
    features["input_ids"] = create_int_feature(feature.input_ids)
    features["input_mask"] = create_int_feature(feature.input_mask)
    features["segment_ids"] = create_int_feature(feature.segment_ids)
    features["input_rowid"] = create_int_feature([feature.row_id])
    features["label_ids"] = create_int_feature([feature.label_id])
    features["is_real_example"] = create_int_feature(
        [int(feature.is_real_example)])
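create_int_feature is the usual BERT helper for packing int lists into tf.train.Example features; it is presumably unchanged by this commit and looks like:

def create_int_feature(values):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))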
@@ -507,13 +651,14 @@ def file_based_convert_examples_to_features(


def file_based_input_fn_builder(input_file, seq_length, is_training,
                                drop_remainder):
                                drop_remainder, batch_size=None, use_hvd=True):
  """Creates an `input_fn` closure to be passed to TPUEstimator."""

  name_to_features = {
      "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
      "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
      "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
      "input_rowid": tf.FixedLenFeature([], tf.int64),
      "label_ids": tf.FixedLenFeature([], tf.int64),
      "is_real_example": tf.FixedLenFeature([], tf.int64),
  }

@@ -526,7 +671,7 @@ def file_based_input_fn_builder(input_file, seq_length, is_training,
    # So cast all int64 to int32.
    for name in list(example.keys()):
      t = example[name]
      if t.dtype == tf.int64:
      if t.dtype == tf.int64 and name != "input_rowid":
        t = tf.to_int32(t)
      example[name] = t

@@ -534,12 +679,17 @@ def file_based_input_fn_builder(input_file, seq_length, is_training,

  def input_fn(params):
    """The actual input function."""
    batch_size = params["batch_size"]
    # batch_size = params["batch_size"]

    # For training, we want a lot of parallel reading and shuffling.
    # For eval, we want no shuffling and parallel reading doesn't matter.
    d = tf.data.TFRecordDataset(input_file)
    if is_training:

      if use_hvd:
        d = d.shard(hvd.size(), hvd.rank())  # TODO only for Horovod, shard to mimic single_GPU = False
        print("Data shard: %s %s" % (hvd.size(), hvd.rank()))

      d = d.repeat()
      d = d.shuffle(buffer_size=100)
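Because shard() runs before repeat() and shuffle(), each Horovod worker reads a disjoint, interleaved slice of the TFRecord file. A toy illustration of what shard() selects:

import tensorflow as tf

d = tf.data.Dataset.range(12)
# With hvd.size() == 4, the worker with hvd.rank() == 1 would see 1, 5, 9.
d = d.shard(num_shards=4, index=1)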
@@ -572,7 +722,7 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):


def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,
                 labels, num_labels, use_one_hot_embeddings):
                 labels, num_labels, use_one_hot_embeddings, use_fp16, clip):
  """Creates a classification model."""
  model = modeling.BertModel(
      config=bert_config,
@@ -580,7 +730,8 @@ def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,
      input_ids=input_ids,
      input_mask=input_mask,
      token_type_ids=segment_ids,
      use_one_hot_embeddings=use_one_hot_embeddings)
      use_one_hot_embeddings=use_one_hot_embeddings,
      compute_type=tf.float16 if use_fp16 else tf.float32)

  # In the demo, we are doing a simple classification task on the entire
  # segment.
@@ -605,20 +756,30 @@ def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,

    logits = tf.matmul(output_layer, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)
    probabilities = tf.nn.softmax(logits, axis=-1)
    log_probs = tf.nn.log_softmax(logits, axis=-1)
    if clip:
      probabilities = tf.clip_by_value(tf.nn.softmax(logits, axis=-1), 1e-6, 1.0 - 1e-6)
      log_probs = tf.log(probabilities)
    else:
      probabilities = tf.nn.softmax(logits, axis=-1)
      log_probs = tf.nn.log_softmax(logits, axis=-1)

    one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)

    per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
    loss = tf.reduce_mean(per_example_loss)

    p0 = tf.reduce_sum(probabilities[:, 0:FLAGS.bad_label_num], axis=-1)
    p1 = tf.subtract(1.0, p0)
    probabilities = tf.stack([p0, p1], axis=-1)

    return (loss, per_example_loss, logits, probabilities)


def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate,
                     num_train_steps, num_warmup_steps, use_tpu,
                     use_one_hot_embeddings):
                     use_one_hot_embeddings, adjust_lr, use_hvd,
                     use_compression, use_fp16, clip, cos_decay,
                     use_lamb, previous_train_steps, post_train_steps):
  """Returns `model_fn` closure for TPUEstimator."""

  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
@@ -631,6 +792,7 @@ def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate,
    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    input_rowid = features["input_rowid"]
    label_ids = features["label_ids"]
    is_real_example = None
    if "is_real_example" in features:
@@ -642,7 +804,7 @@ def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate,

    (total_loss, per_example_loss, logits, probabilities) = create_model(
        bert_config, is_training, input_ids, input_mask, segment_ids, label_ids,
        num_labels, use_one_hot_embeddings)
        num_labels, use_one_hot_embeddings, use_fp16, clip)

    tvars = tf.trainable_variables()
    initialized_variable_names = {}
@@ -670,39 +832,42 @@ def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate,

    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
      train_op, update_learning_rate = optimization.create_optimizer(
          total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu, adjust_lr, use_hvd,
          use_compression, use_fp16, clip, cos_decay, use_lamb, previous_train_steps, post_train_steps)

      train_op = optimization.create_optimizer(
          total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)

      output_spec = tf.contrib.tpu.TPUEstimatorSpec(
      logging_hook = tf.train.LoggingTensorHook({"loss": total_loss, "learning_rate": update_learning_rate}, every_n_iter=FLAGS.hooking_frequence)
      output_spec = tf.estimator.EstimatorSpec(
          mode=mode,
          loss=total_loss,
          train_op=train_op,
          scaffold_fn=scaffold_fn)
          training_hooks=[logging_hook])
    elif mode == tf.estimator.ModeKeys.EVAL:

      def metric_fn(per_example_loss, label_ids, logits, is_real_example):
      def metric_fn(per_example_loss, label_ids, logits, probabilities, is_real_example):
        predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
        accuracy = tf.metrics.accuracy(
            labels=label_ids, predictions=predictions, weights=is_real_example)
        loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example)
        rocauc = tf.metrics.auc(labels=label_ids, predictions=probabilities[:, 1], curve="ROC", summation_method="careful_interpolation", weights=is_real_example)
        prauc = tf.metrics.auc(labels=label_ids, predictions=probabilities[:, 1], curve="PR", summation_method="careful_interpolation", weights=is_real_example)
        return {
            "eval_accuracy": accuracy,
            "eval_loss": loss,
            "rocauc": rocauc,
            "prauc": prauc,
        }

      eval_metrics = (metric_fn,
                      [per_example_loss, label_ids, logits, is_real_example])
      output_spec = tf.contrib.tpu.TPUEstimatorSpec(
      eval_metrics = metric_fn(
          per_example_loss, label_ids, logits, probabilities, is_real_example)
      output_spec = tf.estimator.EstimatorSpec(
          mode=mode,
          loss=total_loss,
          eval_metrics=eval_metrics,
          scaffold_fn=scaffold_fn)
          eval_metric_ops=eval_metrics)
    else:
      output_spec = tf.contrib.tpu.TPUEstimatorSpec(
      output_spec = tf.estimator.EstimatorSpec(
          mode=mode,
          predictions={"probabilities": probabilities},
          scaffold_fn=scaffold_fn)
          predictions={"probabilities": probabilities, "labels": label_ids, "rowids": input_rowid})
    return output_spec

  return model_fn
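The two new AUC entries are ordinary tf.metrics streaming ops; a self-contained sketch on toy values:

import tensorflow as tf

labels = tf.constant([0, 1, 1, 0])
scores = tf.constant([0.1, 0.8, 0.6, 0.4])   # stands in for probabilities[:, 1]
rocauc = tf.metrics.auc(labels=labels, predictions=scores, curve="ROC",
                        summation_method="careful_interpolation")
prauc = tf.metrics.auc(labels=labels, predictions=scores, curve="PR",
                       summation_method="careful_interpolation")
# Each is a (value, update_op) pair, which is exactly the shape
# tf.estimator.EstimatorSpec expects in eval_metric_ops.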
@@ -783,17 +948,26 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)

  if FLAGS.use_hvd:
    hvd.init()

    if FLAGS.reduce_log and (hvd.rank() != 0):
      tf.logging.set_verbosity(tf.logging.ERROR)

    FLAGS.output_dir = FLAGS.output_dir if hvd.rank() == 0 else os.path.join(FLAGS.output_dir, str(hvd.rank()))

  processors = {
      "cola": ColaProcessor,
      "mnli": MnliProcessor,
      "mrpc": MrpcProcessor,
      "xnli": XnliProcessor,
      "qk": QKProcessor,
  }

  tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                FLAGS.init_checkpoint)

  if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
  if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict and not FLAGS.do_train_eval and not FLAGS.do_export:
    raise ValueError(
        "At least one of `do_train`, `do_eval` or `do_predict' must be True.")

@@ -807,6 +981,35 @@ def main(_):

  tf.gfile.MakeDirs(FLAGS.output_dir)

  if FLAGS.recover_dir is not None:
    if FLAGS.use_hvd:
      FLAGS.recover_dir = FLAGS.recover_dir if hvd.rank() == 0 else os.path.join(FLAGS.recover_dir, str(hvd.rank()))
    path_ckpt = os.path.join(FLAGS.output_dir, "checkpoint")
    path_ckpt_input = os.path.join(FLAGS.output_dir, "checkpoint_input")

    if FLAGS.ckpt_no is not None and not tf.gfile.Exists(path_ckpt):
      with tf.gfile.GFile(path_ckpt, "w") as writer:
        writer.write('model_checkpoint_path: "%s-%s"\n' % (os.path.join(FLAGS.recover_dir, "model.ckpt"), str(FLAGS.ckpt_no)))
        writer.write('all_model_checkpoint_paths: "%s-%s"\n' % (os.path.join(FLAGS.recover_dir, "model.ckpt"), str(FLAGS.ckpt_no)))

    if FLAGS.ckpt_no_input is not None and not tf.gfile.Exists(path_ckpt_input):
      with tf.gfile.GFile(path_ckpt_input, "w") as writer:
        writer.write('model_checkpoint_path: "%s-%s"\n' % (os.path.join(FLAGS.recover_dir, "input.ckpt"), str(FLAGS.ckpt_no_input)))
        writer.write('all_model_checkpoint_paths: "%s-%s"\n' % (os.path.join(FLAGS.recover_dir, "input.ckpt"), str(FLAGS.ckpt_no_input)))

  if FLAGS.use_hvd and hvd.rank() == 0 and (FLAGS.do_train or FLAGS.do_train_eval):
    (vpath, vname) = os.path.split(FLAGS.vocab_file)
    tf.gfile.Copy(FLAGS.vocab_file, os.path.join(FLAGS.output_dir, vname), True)

    (cpath, cname) = os.path.split(FLAGS.bert_config_file)
    tf.gfile.Copy(FLAGS.bert_config_file, os.path.join(FLAGS.output_dir, cname), True)

  if FLAGS.validation_data_dir is None:
    FLAGS.validation_data_dir = FLAGS.data_dir

  if FLAGS.test_data_dir is None:
    FLAGS.test_data_dir = FLAGS.validation_data_dir

  task_name = FLAGS.task_name.lower()

  if task_name not in processors:
@@ -816,33 +1019,59 @@ def main(_):

  label_list = processor.get_labels()

  if FLAGS.label_list:
    label_list = FLAGS.label_list.split(",")

  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

  tpu_cluster_resolver = None
  if FLAGS.use_tpu and FLAGS.tpu_name:
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

  is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
  run_config = tf.contrib.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      master=FLAGS.master,
      model_dir=FLAGS.output_dir,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      tpu_config=tf.contrib.tpu.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          num_shards=FLAGS.num_tpu_cores,
          per_host_input_for_training=is_per_host))

  num_gpu = 1 if not FLAGS.use_hvd else hvd.size()
  train_examples = None
  num_train_steps = None
  num_warmup_steps = None
  if FLAGS.do_train:
    train_examples = processor.get_train_examples(FLAGS.data_dir)
    num_train_steps = int(
        len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
  if FLAGS.do_train or FLAGS.do_train_eval:
    if FLAGS.use_tfrecord:
      if FLAGS.train_examples_count is None:
        FLAGS.train_examples_count = 0
        for record in tf.python_io.tf_record_iterator(os.path.join(FLAGS.data_dir, FLAGS.tfrecord_name)):
          FLAGS.train_examples_count += 1

      num_train_steps = int(
          FLAGS.train_examples_count / (FLAGS.train_batch_size * num_gpu) * FLAGS.num_train_epochs)
    else:
      train_examples = processor.get_train_examples(FLAGS.data_dir)
      num_train_steps = int(
          len(train_examples) / (FLAGS.train_batch_size * num_gpu) * FLAGS.num_train_epochs)
    num_warmup_steps = int((num_train_steps + FLAGS.previous_train_steps + FLAGS.post_train_steps) * FLAGS.warmup_proportion)

  if FLAGS.save_checkpoints_steps is None:
    FLAGS.save_checkpoints_steps = 1000 if num_train_steps is None else int(num_train_steps / FLAGS.num_train_epochs)

  if FLAGS.keep_checkpoint_max is None:
    FLAGS.keep_checkpoint_max = int(FLAGS.num_train_epochs + 1.0)

  config = tf.ConfigProto()
  if FLAGS.xla:
    config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
  if FLAGS.use_hvd:
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    config.gpu_options.allow_growth = True

  run_config = tf.estimator.RunConfig(
      model_dir=FLAGS.output_dir,
      keep_checkpoint_max=FLAGS.keep_checkpoint_max,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      log_step_count_steps=FLAGS.hooking_frequence,
      session_config=config)

  if FLAGS.use_hvd and hvd.rank() != 0 and not FLAGS.auto_recover:
    run_config = tf.estimator.RunConfig(
        model_dir=FLAGS.output_dir,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max,
        save_checkpoints_steps=None,
        save_checkpoints_secs=None,
        log_step_count_steps=FLAGS.hooking_frequence,
        session_config=config)

  model_fn = model_fn_builder(
      bert_config=bert_config,
@@ -852,53 +1081,95 @@ def main(_):
      num_train_steps=num_train_steps,
      num_warmup_steps=num_warmup_steps,
      use_tpu=FLAGS.use_tpu,
      use_one_hot_embeddings=FLAGS.use_tpu)
      use_one_hot_embeddings=FLAGS.use_tpu,
      adjust_lr=FLAGS.adjust_lr,
      use_hvd=FLAGS.use_hvd,
      use_compression=FLAGS.use_compression,
      use_fp16=FLAGS.use_fp16,
      clip=FLAGS.clip,
      cos_decay=FLAGS.cos_decay,
      use_lamb=FLAGS.use_lamb,
      previous_train_steps=FLAGS.previous_train_steps,
      post_train_steps=FLAGS.post_train_steps)

  hooks = []

  if FLAGS.use_hvd:
    hooks.append(hvd.BroadcastGlobalVariablesHook(0))

    if hvd.rank() == -1:  # if debug, set 0
      CLIDebugHook = tf_debug.LocalCLIDebugHook(ui_type='readline')
      CLIDebugHook.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
      hooks.append(CLIDebugHook)

  if FLAGS.profile and hvd.rank() == 0:
    ProfilerHook = tf.train.ProfilerHook(save_steps=FLAGS.hooking_frequence, output_dir=FLAGS.output_dir, show_dataflow=True, show_memory=True)
    hooks.append(ProfilerHook)

  # If TPU is not available, this will fall back to normal Estimator on CPU
  # or GPU.
  estimator = tf.contrib.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
  estimator = tf.estimator.Estimator(
      model_fn=model_fn,
      config=run_config,
      train_batch_size=FLAGS.train_batch_size,
      eval_batch_size=FLAGS.eval_batch_size,
      predict_batch_size=FLAGS.predict_batch_size)
      config=run_config)

  if FLAGS.do_train:
    train_file = os.path.join(FLAGS.output_dir, "train.tf_record")
    file_based_convert_examples_to_features(
        train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file)
    if FLAGS.use_tfrecord:
      train_file = os.path.join(FLAGS.data_dir, FLAGS.tfrecord_name)
    else:
      train_file = os.path.join(FLAGS.output_dir, "train.tf_record")
      file_based_convert_examples_to_features(
          train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file)
    tf.logging.info("***** Running training *****")
    tf.logging.info("  Num examples = %d", len(train_examples))
    tf.logging.info("  Num examples = %d", FLAGS.train_examples_count if train_examples is None else len(train_examples))
    tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
    tf.logging.info("  Num steps = %d", num_train_steps)
    train_input_fn = file_based_input_fn_builder(
        input_file=train_file,
        seq_length=FLAGS.max_seq_length,
        is_training=True,
        drop_remainder=True)
    estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
        drop_remainder=True,
        batch_size=FLAGS.train_batch_size,
        use_hvd=FLAGS.use_hvd)

    if FLAGS.auto_recover:
      hooks.append(tf.data.experimental.CheckpointInputPipelineHook(estimator))

    estimator.train(input_fn=train_input_fn, max_steps=num_train_steps, hooks=hooks)

    if FLAGS.clean_tfrecord and tf.gfile.Exists(train_file):
      tf.gfile.Remove(train_file)

  if FLAGS.do_eval:
    eval_examples = processor.get_dev_examples(FLAGS.data_dir)
    num_actual_eval_examples = len(eval_examples)
    if FLAGS.use_tpu:
      # TPU requires a fixed batch size for all batches, therefore the number
      # of examples must be a multiple of the batch size, or else examples
      # will get dropped. So we pad with fake examples which are ignored
      # later on. These do NOT count towards the metric (all tf.metrics
      # support a per-instance weight, and these get a weight of 0.0).
      while len(eval_examples) % FLAGS.eval_batch_size != 0:
        eval_examples.append(PaddingInputExample())
    if FLAGS.use_validation_tfrecord:
      num_actual_eval_examples = 0
      for record in tf.python_io.tf_record_iterator(os.path.join(FLAGS.validation_data_dir, FLAGS.validation_tfrecord_name)):
        num_actual_eval_examples += 1

    eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
    file_based_convert_examples_to_features(
        eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file)
      validation_examples_count = num_actual_eval_examples

      eval_file = os.path.join(FLAGS.validation_data_dir, FLAGS.validation_tfrecord_name)
    else:
      eval_examples = processor.get_dev_examples(FLAGS.validation_data_dir)
      num_actual_eval_examples = len(eval_examples)
      if FLAGS.use_tpu:
        # TPU requires a fixed batch size for all batches, therefore the number
        # of examples must be a multiple of the batch size, or else examples
        # will get dropped. So we pad with fake examples which are ignored
        # later on. These do NOT count towards the metric (all tf.metrics
        # support a per-instance weight, and these get a weight of 0.0).
        while len(eval_examples) % FLAGS.eval_batch_size != 0:
          eval_examples.append(PaddingInputExample())

      validation_examples_count = len(eval_examples)

      eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
      file_based_convert_examples_to_features(
          eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file)

    tf.logging.info("***** Running evaluation *****")
    tf.logging.info("  Num examples = %d (%d actual, %d padding)",
                    len(eval_examples), num_actual_eval_examples,
                    len(eval_examples) - num_actual_eval_examples)
                    validation_examples_count, num_actual_eval_examples,
                    validation_examples_count - num_actual_eval_examples)
    tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

    # This tells the estimator to run through the entire set.
@@ -906,15 +1177,17 @@ def main(_):
    # However, if running eval on the TPU, you will need to specify the
    # number of steps.
    if FLAGS.use_tpu:
      assert len(eval_examples) % FLAGS.eval_batch_size == 0
      eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size)
      assert validation_examples_count % FLAGS.eval_batch_size == 0
      eval_steps = int(validation_examples_count // FLAGS.eval_batch_size)

    eval_drop_remainder = True if FLAGS.use_tpu else False
    eval_input_fn = file_based_input_fn_builder(
        input_file=eval_file,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=eval_drop_remainder)
        drop_remainder=eval_drop_remainder,
        batch_size=FLAGS.eval_batch_size,
        use_hvd=FLAGS.use_hvd)

    result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)

@@ -925,26 +1198,40 @@ def main(_):
      tf.logging.info("  %s = %s", key, str(result[key]))
      writer.write("%s = %s\n" % (key, str(result[key])))

  if FLAGS.do_predict:
    predict_examples = processor.get_test_examples(FLAGS.data_dir)
    num_actual_predict_examples = len(predict_examples)
    if FLAGS.use_tpu:
      # TPU requires a fixed batch size for all batches, therefore the number
      # of examples must be a multiple of the batch size, or else examples
      # will get dropped. So we pad with fake examples which are ignored
      # later on.
      while len(predict_examples) % FLAGS.predict_batch_size != 0:
        predict_examples.append(PaddingInputExample())
    if FLAGS.clean_tfrecord and tf.gfile.Exists(eval_file):
      tf.gfile.Remove(eval_file)

    predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
    file_based_convert_examples_to_features(predict_examples, label_list,
                                            FLAGS.max_seq_length, tokenizer,
                                            predict_file)
  if FLAGS.do_predict:
    if FLAGS.use_test_tfrecord:
      num_actual_predict_examples = 0
      for record in tf.python_io.tf_record_iterator(os.path.join(FLAGS.test_data_dir, FLAGS.test_tfrecord_name)):
        num_actual_predict_examples += 1

      test_examples_count = num_actual_predict_examples

      predict_file = os.path.join(FLAGS.test_data_dir, FLAGS.test_tfrecord_name)
    else:
      predict_examples = processor.get_test_examples(FLAGS.test_data_dir)
      num_actual_predict_examples = len(predict_examples)
      if FLAGS.use_tpu:
        # TPU requires a fixed batch size for all batches, therefore the number
        # of examples must be a multiple of the batch size, or else examples
        # will get dropped. So we pad with fake examples which are ignored
        # later on.
        while len(predict_examples) % FLAGS.predict_batch_size != 0:
          predict_examples.append(PaddingInputExample())

      test_examples_count = len(predict_examples)

      predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
      file_based_convert_examples_to_features(predict_examples, label_list,
                                              FLAGS.max_seq_length, tokenizer,
                                              predict_file)

    tf.logging.info("***** Running prediction*****")
    tf.logging.info("  Num examples = %d (%d actual, %d padding)",
                    len(predict_examples), num_actual_predict_examples,
                    len(predict_examples) - num_actual_predict_examples)
                    test_examples_count, num_actual_predict_examples,
                    test_examples_count - num_actual_predict_examples)
    tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

    predict_drop_remainder = True if FLAGS.use_tpu else False
@@ -952,7 +1239,9 @@ def main(_):
        input_file=predict_file,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=predict_drop_remainder)
        drop_remainder=predict_drop_remainder,
        batch_size=FLAGS.predict_batch_size,
        use_hvd=FLAGS.use_hvd)

    result = estimator.predict(input_fn=predict_input_fn)

@@ -960,20 +1249,123 @@ def main(_):
    with tf.gfile.GFile(output_predict_file, "w") as writer:
      num_written_lines = 0
      tf.logging.info("***** Predict results *****")
      if FLAGS.add_header:
        writer.write("rowids\tprobabilities\tlabels\n")
      for (i, prediction) in enumerate(result):
        probabilities = prediction["probabilities"]
        if i >= num_actual_predict_examples:
          break
        output_line = "\t".join(
            str(class_probability)
            for class_probability in probabilities) + "\n"
        output_line = str(prediction["rowids"]) + "\t" + str(probabilities[1]) + "\t" + str(prediction["labels"]) + "\n"
        writer.write(output_line)
        num_written_lines += 1
    assert num_written_lines == num_actual_predict_examples

    if FLAGS.clean_tfrecord and tf.gfile.Exists(predict_file):
      tf.gfile.Remove(predict_file)

  if FLAGS.do_train_eval:
    if FLAGS.use_tfrecord:
      train_file = os.path.join(FLAGS.data_dir, FLAGS.tfrecord_name)
    else:
      train_file = os.path.join(FLAGS.output_dir, "train.tf_record")
      file_based_convert_examples_to_features(
          train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file)
    tf.logging.info("***** Running training *****")
    tf.logging.info("  Num examples = %d", FLAGS.train_examples_count if train_examples is None else len(train_examples))
    tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
    tf.logging.info("  Num steps = %d", num_train_steps)
    train_input_fn = file_based_input_fn_builder(
        input_file=train_file,
        seq_length=FLAGS.max_seq_length,
        is_training=True,
        drop_remainder=True,
        batch_size=FLAGS.train_batch_size,
        use_hvd=FLAGS.use_hvd)

    if FLAGS.use_validation_tfrecord:
      num_actual_eval_examples = 0
      for record in tf.python_io.tf_record_iterator(os.path.join(FLAGS.validation_data_dir, FLAGS.validation_tfrecord_name)):
        num_actual_eval_examples += 1

      validation_examples_count = num_actual_eval_examples

      eval_file = os.path.join(FLAGS.validation_data_dir, FLAGS.validation_tfrecord_name)
    else:
      eval_examples = processor.get_dev_examples(FLAGS.validation_data_dir)
      num_actual_eval_examples = len(eval_examples)
      if FLAGS.use_tpu:
        # TPU requires a fixed batch size for all batches, therefore the number
        # of examples must be a multiple of the batch size, or else examples
        # will get dropped. So we pad with fake examples which are ignored
        # later on. These do NOT count towards the metric (all tf.metrics
        # support a per-instance weight, and these get a weight of 0.0).
        while len(eval_examples) % FLAGS.eval_batch_size != 0:
          eval_examples.append(PaddingInputExample())

      validation_examples_count = len(eval_examples)

      eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
      file_based_convert_examples_to_features(
          eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file)

    tf.logging.info("***** Running evaluation *****")
    tf.logging.info("  Num examples = %d (%d actual, %d padding)",
                    validation_examples_count, num_actual_eval_examples,
                    validation_examples_count - num_actual_eval_examples)
    tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

    # This tells the estimator to run through the entire set.
    eval_steps = None
    # However, if running eval on the TPU, you will need to specify the
    # number of steps.
    if FLAGS.use_tpu:
      assert validation_examples_count % FLAGS.eval_batch_size == 0
      eval_steps = int(validation_examples_count // FLAGS.eval_batch_size)

    eval_drop_remainder = True if FLAGS.use_tpu else False
    eval_input_fn = file_based_input_fn_builder(
        input_file=eval_file,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=eval_drop_remainder,
        batch_size=FLAGS.eval_batch_size,
        use_hvd=FLAGS.use_hvd)

    if FLAGS.auto_recover:
      hooks.append(tf.data.experimental.CheckpointInputPipelineHook(estimator))

    train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=num_train_steps, hooks=hooks)
    eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn, steps=eval_steps)
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

    if FLAGS.clean_tfrecord and tf.gfile.Exists(train_file):
      tf.gfile.Remove(train_file)

    if FLAGS.clean_tfrecord and tf.gfile.Exists(eval_file):
      tf.gfile.Remove(eval_file)

  if FLAGS.do_export:
    def serving_input_fn():
      input_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_ids')
      input_mask = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_mask')
      segment_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='segment_ids')
      input_rowid = tf.placeholder(tf.int64, [None], name='input_rowid')
      label_ids = tf.placeholder(tf.int32, [None], name='label_ids')
      input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
          'input_ids': input_ids,
          'input_mask': input_mask,
          'segment_ids': segment_ids,
          'input_rowid': input_rowid,
          'label_ids': label_ids,
      })()
      return input_fn

    estimator._export_to_tpu = False
    estimator.export_savedmodel(FLAGS.export_dir, serving_input_fn)


if __name__ == "__main__":
  flags.mark_flag_as_required("data_dir")
  # flags.mark_flag_as_required("data_dir")
  flags.mark_flag_as_required("task_name")
  flags.mark_flag_as_required("vocab_file")
  flags.mark_flag_as_required("bert_config_file")
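A possible way to call the exported SavedModel afterwards, via tf.contrib.predictor (the path and sequence length below are hypothetical):

import numpy as np
from tensorflow.contrib import predictor

predict_fn = predictor.from_saved_model("/path/to/export_dir/1576000000")
outputs = predict_fn({
    "input_ids": np.zeros([1, 128], dtype=np.int32),
    "input_mask": np.ones([1, 128], dtype=np.int32),
    "segment_ids": np.zeros([1, 128], dtype=np.int32),
    "input_rowid": np.array([0], dtype=np.int64),
    "label_ids": np.array([0], dtype=np.int32),
})
print(outputs["probabilities"])  # the predictions dict also carries "labels" and "rowids"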
|
|
@ -22,6 +22,8 @@ import os
|
|||
import modeling
|
||||
import optimization
|
||||
import tensorflow as tf
|
||||
import horovod.tensorflow as hvd
|
||||
from tensorflow.python import debug as tf_debug
|
||||
|
||||
flags = tf.flags
|
||||
|
||||
|
@ -37,6 +39,18 @@ flags.DEFINE_string(
    "input_file", None,
    "Input TF example files (can be a glob or comma separated).")

flags.DEFINE_string(
    "validation_input_file", None,
    "Input validation TF example files (can be a glob or comma separated).")

flags.DEFINE_string(
    "input_dir", None,
    "Input TF example dir.")

flags.DEFINE_string(
    "validation_input_dir", None,
    "Input validation TF example dir.")

flags.DEFINE_string(
    "output_dir", None,
    "The output directory where the model checkpoints will be written.")

@ -61,6 +75,8 @@ flags.DEFINE_bool("do_train", False, "Whether to run training.")

flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")

flags.DEFINE_bool("do_train_eval", False, "Whether to run training with interleaved evaluation.")

flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")

flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.")

@ -77,7 +93,7 @@ flags.DEFINE_integer("save_checkpoints_steps", 1000,
flags.DEFINE_integer("iterations_per_loop", 1000,
                     "How many steps to make in each estimator call.")

flags.DEFINE_integer("max_eval_steps", 100, "Maximum number of eval steps.")
flags.DEFINE_integer("max_eval_steps", None, "Maximum number of eval steps.")

flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")

@ -106,9 +122,48 @@ flags.DEFINE_integer(
    "Only used if `use_tpu` is True. Total number of TPU cores to use.")


flags.DEFINE_integer("hooking_frequence", 100, "How often (in steps) the logging and profiler hooks fire.")

flags.DEFINE_bool("reduce_log", False, "Suppress most logging on non-zero Horovod ranks.")

flags.DEFINE_integer("keep_checkpoint_max", 1, "Maximum number of recent checkpoints to keep.")

flags.DEFINE_bool("xla", True, "Whether to train with XLA optimization.")

flags.DEFINE_bool("adjust_lr", True, "Whether to adjust the learning rate.")

flags.DEFINE_integer("previous_train_steps", 0, "Previous train steps.")

flags.DEFINE_integer("post_train_steps", 0, "Post train steps.")

flags.DEFINE_bool("use_hvd", True, "Whether to use Horovod.")

flags.DEFINE_bool("use_compression", True, "Whether to use gradient compression in Horovod.")

flags.DEFINE_bool("use_fp16", True, "Whether to use fp16.")

flags.DEFINE_bool("cos_decay", False, "Whether to use cosine learning-rate decay.")

flags.DEFINE_bool("use_lamb", False, "Whether to use the LAMB optimizer.")

flags.DEFINE_bool("auto_recover", False, "Whether to enable automatic recovery from checkpoints.")

flags.DEFINE_string("recover_dir", None, "The directory from which model and input checkpoints will be recovered.")

flags.DEFINE_integer("ckpt_no", None, "Checkpoint number of the model to be recovered.")

flags.DEFINE_integer("ckpt_no_input", None, "Checkpoint number of the input pipeline to be recovered.")

flags.DEFINE_bool("clip", False, "Whether to clip softmax probabilities for numerical stability.")

flags.DEFINE_bool("profile", False, "Whether to enable the profiler hook.")

def model_fn_builder(bert_config, init_checkpoint, learning_rate,
                     num_train_steps, num_warmup_steps, use_tpu,
                     use_one_hot_embeddings):
                     use_one_hot_embeddings, adjust_lr, use_hvd,
                     use_compression, use_fp16, clip, cos_decay,
                     use_lamb, previous_train_steps, post_train_steps):
  """Returns `model_fn` closure for TPUEstimator."""

  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument

@ -134,16 +189,17 @@ def model_fn_builder(bert_config, init_checkpoint, learning_rate,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=use_one_hot_embeddings)
        use_one_hot_embeddings=use_one_hot_embeddings,
        compute_type=tf.float16 if use_fp16 else tf.float32)

    (masked_lm_loss,
     masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output(
         bert_config, model.get_sequence_output(), model.get_embedding_table(),
         masked_lm_positions, masked_lm_ids, masked_lm_weights)
     masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output(
         bert_config, model.get_sequence_output(), model.get_embedding_table(),
         masked_lm_positions, masked_lm_ids, masked_lm_weights, clip)

    (next_sentence_loss, next_sentence_example_loss,
     next_sentence_log_probs) = get_next_sentence_output(
         bert_config, model.get_pooled_output(), next_sentence_labels)
     next_sentence_log_probs) = get_next_sentence_output(
         bert_config, model.get_pooled_output(), next_sentence_labels, clip)

    total_loss = masked_lm_loss + next_sentence_loss

@ -174,14 +230,16 @@ def model_fn_builder(bert_config, init_checkpoint, learning_rate,

    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
      train_op = optimization.create_optimizer(
          total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)
      train_op, update_learning_rate = optimization.create_optimizer(
          total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu, adjust_lr, use_hvd,
          use_compression, use_fp16, clip, cos_decay, use_lamb, previous_train_steps, post_train_steps)

      output_spec = tf.contrib.tpu.TPUEstimatorSpec(
      logging_hook = tf.train.LoggingTensorHook({"loss": total_loss, "learning_rate": update_learning_rate}, every_n_iter=FLAGS.hooking_frequence)
      output_spec = tf.estimator.EstimatorSpec(
          mode=mode,
          loss=total_loss,
          train_op=train_op,
          scaffold_fn=scaffold_fn)
          training_hooks=[logging_hook])
    elif mode == tf.estimator.ModeKeys.EVAL:

      def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,

@ -219,16 +277,15 @@ def model_fn_builder(bert_config, init_checkpoint, learning_rate,
            "next_sentence_loss": next_sentence_mean_loss,
        }

      eval_metrics = (metric_fn, [
      eval_metrics = metric_fn(
          masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
          masked_lm_weights, next_sentence_example_loss,
          next_sentence_log_probs, next_sentence_labels
      ])
      output_spec = tf.contrib.tpu.TPUEstimatorSpec(
      )
      output_spec = tf.estimator.EstimatorSpec(
          mode=mode,
          loss=total_loss,
          eval_metrics=eval_metrics,
          scaffold_fn=scaffold_fn)
          eval_metric_ops=eval_metrics)
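      # Note the API shift visible above: tf.estimator.Estimator expects
      # eval_metric_ops, a dict of (value, update_op) pairs produced by
      # calling metric_fn directly, whereas TPUEstimatorSpec took a deferred
      # (metric_fn, args) tuple.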
    else:
      raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode))

@ -238,7 +295,7 @@ def model_fn_builder(bert_config, init_checkpoint, learning_rate,


def get_masked_lm_output(bert_config, input_tensor, output_weights, positions,
                         label_ids, label_weights):
                         label_ids, label_weights, clip):
  """Get loss and log probs for the masked LM."""
  input_tensor = gather_indexes(input_tensor, positions)


@ -262,7 +319,10 @@ def get_masked_lm_output(bert_config, input_tensor, output_weights, positions,
        initializer=tf.zeros_initializer())
    logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)
    log_probs = tf.nn.log_softmax(logits, axis=-1)
    if clip:
      log_probs = tf.log(tf.clip_by_value(tf.nn.softmax(logits, axis=-1), 1e-6, 1.0 - 1e-6))
    else:
      log_probs = tf.nn.log_softmax(logits, axis=-1)

    label_ids = tf.reshape(label_ids, [-1])
    label_weights = tf.reshape(label_weights, [-1])
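
The `clip` branch exists because tf.log(tf.nn.softmax(...)) returns -inf as
soon as a probability underflows to zero, which then poisons the loss with
NaNs; clipping to [1e-6, 1 - 1e-6] keeps the log finite. A standalone NumPy
illustration (not from the commit; float16 is used to force the underflow):

    import numpy as np

    p = np.exp(np.float16(-20.0))                # underflows to 0.0 in float16
    with np.errstate(divide="ignore"):
        print(np.log(p))                         # -inf
    print(np.log(np.clip(p, 1e-6, 1.0 - 1e-6)))  # about -13.8, finite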

@ -282,7 +342,7 @@ def get_masked_lm_output(bert_config, input_tensor, output_weights, positions,
  return (loss, per_example_loss, log_probs)


def get_next_sentence_output(bert_config, input_tensor, labels):
def get_next_sentence_output(bert_config, input_tensor, labels, clip):
  """Get loss and log probs for the next sentence prediction."""

  # Simple binary classification. Note that 0 is "next sentence" and 1 is

@ -297,7 +357,10 @@ def get_next_sentence_output(bert_config, input_tensor, labels):

    logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)
    log_probs = tf.nn.log_softmax(logits, axis=-1)
    if clip:
      log_probs = tf.log(tf.clip_by_value(tf.nn.softmax(logits, axis=-1), 1e-6, 1.0 - 1e-6))
    else:
      log_probs = tf.nn.log_softmax(logits, axis=-1)
    labels = tf.reshape(labels, [-1])
    one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32)
    per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)

@ -325,12 +388,14 @@ def input_fn_builder(input_files,
                     max_seq_length,
                     max_predictions_per_seq,
                     is_training,
                     num_cpu_threads=4):
                     num_cpu_threads=4,
                     batch_size=None,
                     use_hvd=True):
  """Creates an `input_fn` closure to be passed to TPUEstimator."""

  def input_fn(params):
    """The actual input function."""
    batch_size = params["batch_size"]
    # batch_size = params["batch_size"]

    name_to_features = {
        "input_ids":

@ -353,6 +418,11 @@ def input_fn_builder(input_files,
    # For eval, we want no shuffling and parallel reading doesn't matter.
    if is_training:
      d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))

      if use_hvd:
        d = d.shard(hvd.size(), hvd.rank())  # Shard the input files across Horovod workers so each rank reads a disjoint subset.
        print("Data shard: %s %s" % (hvd.size(), hvd.rank()))

      d = d.repeat()
      d = d.shuffle(buffer_size=len(input_files))
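
A standalone sketch of the Dataset.shard semantics used above (the file names
and the session-based TF 1.x calls are illustrative only): with four workers,
the worker at index 1 keeps every fourth file starting at position 1, so the
ranks read disjoint subsets of the input:

    import tensorflow as tf

    files = tf.data.Dataset.from_tensor_slices(
        ["part-%d" % i for i in range(8)])
    worker = files.shard(num_shards=4, index=1)  # stands in for hvd.size(), hvd.rank()
    nxt = worker.make_one_shot_iterator().get_next()
    with tf.Session() as sess:
        print(sess.run(nxt))  # b'part-1'
        print(sess.run(nxt))  # b'part-5'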

@ -371,7 +441,7 @@ def input_fn_builder(input_files,
      d = tf.data.TFRecordDataset(input_files)
      # Since we evaluate for a fixed number of steps we don't want to encounter
      # out-of-range exceptions.
      d = d.repeat()
      # d = d.repeat()

      # We must `drop_remainder` on training because the TPU requires fixed
      # size dimensions. For eval, we assume we are evaluating on the CPU or GPU

@ -406,36 +476,90 @@ def _decode_record(record, name_to_features):
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)

  if not FLAGS.do_train and not FLAGS.do_eval:
  if FLAGS.use_hvd:
    hvd.init()

    if FLAGS.reduce_log and (hvd.rank() != 0):
      tf.logging.set_verbosity(tf.logging.ERROR)

    FLAGS.output_dir = FLAGS.output_dir if hvd.rank() == 0 else os.path.join(FLAGS.output_dir, str(hvd.rank()))

  if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_train_eval:
    raise ValueError("At least one of `do_train`, `do_eval` or `do_train_eval` must be True.")

  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

  tf.gfile.MakeDirs(FLAGS.output_dir)

  if FLAGS.recover_dir is not None:
    if FLAGS.use_hvd:
      FLAGS.recover_dir = FLAGS.recover_dir if hvd.rank() == 0 else os.path.join(FLAGS.recover_dir, str(hvd.rank()))
    path_ckpt = os.path.join(FLAGS.output_dir, "checkpoint")
    path_ckpt_input = os.path.join(FLAGS.output_dir, "checkpoint_input")

    if FLAGS.ckpt_no is not None and not tf.gfile.Exists(path_ckpt):
      with tf.gfile.GFile(path_ckpt, "w") as writer:
        writer.write('model_checkpoint_path: "%s-%s"\n' % (os.path.join(FLAGS.recover_dir, "model.ckpt"), str(FLAGS.ckpt_no)))
        writer.write('all_model_checkpoint_paths: "%s-%s"\n' % (os.path.join(FLAGS.recover_dir, "model.ckpt"), str(FLAGS.ckpt_no)))

    if FLAGS.ckpt_no_input is not None and not tf.gfile.Exists(path_ckpt_input):
      with tf.gfile.GFile(path_ckpt_input, "w") as writer:
        writer.write('model_checkpoint_path: "%s-%s"\n' % (os.path.join(FLAGS.recover_dir, "input.ckpt"), str(FLAGS.ckpt_no_input)))
        writer.write('all_model_checkpoint_paths: "%s-%s"\n' % (os.path.join(FLAGS.recover_dir, "input.ckpt"), str(FLAGS.ckpt_no_input)))
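
    # The files written above follow TensorFlow's plain-text CheckpointState
    # format, so tf.train.latest_checkpoint(FLAGS.output_dir) will resolve to
    # the recovered path, e.g. (step number hypothetical):
    #
    #   model_checkpoint_path: "<recover_dir>/model.ckpt-100000"
    #   all_model_checkpoint_paths: "<recover_dir>/model.ckpt-100000"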

  if FLAGS.use_hvd and hvd.rank() == 0 and (FLAGS.do_train or FLAGS.do_train_eval):
    (cpath, cname) = os.path.split(FLAGS.bert_config_file)
    tf.gfile.Copy(FLAGS.bert_config_file, os.path.join(FLAGS.output_dir, cname), True)

  input_files = []
  for input_pattern in FLAGS.input_file.split(","):
    input_files.extend(tf.gfile.Glob(input_pattern))
  if FLAGS.input_file is not None:
    for input_pattern in FLAGS.input_file.split(","):
      input_files.extend(tf.gfile.Glob(input_pattern))
  if FLAGS.input_dir is not None:
    for filename in tf.gfile.ListDirectory(FLAGS.input_dir):
      input_files.extend(tf.gfile.Glob(os.path.join(FLAGS.input_dir, filename)))

  tf.logging.info("*** Input Files ***")
  for input_file in input_files:
    tf.logging.info("  %s" % input_file)

  tpu_cluster_resolver = None
  if FLAGS.use_tpu and FLAGS.tpu_name:
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
  validation_input_files = []
  if FLAGS.validation_input_file is None and FLAGS.validation_input_dir is None:
    validation_input_files = input_files
  else:
    if FLAGS.validation_input_file is not None:
      for input_pattern in FLAGS.validation_input_file.split(","):
        validation_input_files.extend(tf.gfile.Glob(input_pattern))
    if FLAGS.validation_input_dir is not None:
      for filename in tf.gfile.ListDirectory(FLAGS.validation_input_dir):
        validation_input_files.extend(tf.gfile.Glob(os.path.join(FLAGS.validation_input_dir, filename)))

  is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
  run_config = tf.contrib.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      master=FLAGS.master,
  tf.logging.info("*** Input Validation Files ***")
  for input_file in validation_input_files:
    tf.logging.info("  %s" % input_file)

  config = tf.ConfigProto()
  if FLAGS.xla:
    config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
  if FLAGS.use_hvd:
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    config.gpu_options.allow_growth = True
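    # Standard Horovod convention: each process is pinned to the single GPU
    # matching its local rank, and allow_growth stops TensorFlow from
    # grabbing the device's entire memory up front.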

  run_config = tf.estimator.RunConfig(
      model_dir=FLAGS.output_dir,
      keep_checkpoint_max=FLAGS.keep_checkpoint_max,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      tpu_config=tf.contrib.tpu.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          num_shards=FLAGS.num_tpu_cores,
          per_host_input_for_training=is_per_host))
      log_step_count_steps=FLAGS.hooking_frequence,
      session_config=config)

  if FLAGS.use_hvd and hvd.rank() != 0 and not FLAGS.auto_recover:
    run_config = tf.estimator.RunConfig(
        model_dir=FLAGS.output_dir,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max,
        save_checkpoints_steps=None,
        save_checkpoints_secs=None,
        log_step_count_steps=FLAGS.hooking_frequence,
        session_config=config)
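
  # Only rank 0 keeps the regular checkpointing schedule; the other workers
  # disable both step- and time-based saving so they never race on writes
  # (each non-zero rank was also given its own output_dir subdirectory above).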

  model_fn = model_fn_builder(
      bert_config=bert_config,
|
@ -444,16 +568,36 @@ def main(_):
|
|||
num_train_steps=FLAGS.num_train_steps,
|
||||
num_warmup_steps=FLAGS.num_warmup_steps,
|
||||
use_tpu=FLAGS.use_tpu,
|
||||
use_one_hot_embeddings=FLAGS.use_tpu)
|
||||
use_one_hot_embeddings=FLAGS.use_tpu,
|
||||
adjust_lr=FLAGS.adjust_lr,
|
||||
use_hvd=FLAGS.use_hvd,
|
||||
use_compression=FLAGS.use_compression,
|
||||
use_fp16=FLAGS.use_fp16,
|
||||
clip=FLAGS.clip,
|
||||
cos_decay=FLAGS.cos_decay,
|
||||
use_lamb=FLAGS.use_lamb,
|
||||
previous_train_steps=FLAGS.previous_train_steps,
|
||||
post_train_steps=FLAGS.post_train_steps)
|
||||
|
||||
hooks = []
|
||||
|
||||
if FLAGS.use_hvd:
|
||||
hooks.append(hvd.BroadcastGlobalVariablesHook(0))
|
||||
|
||||
if hvd.rank() == -1: #if debug, set 0
|
||||
CLIDebugHook = tf_debug.LocalCLIDebugHook(ui_type='readline')
|
||||
CLIDebugHook.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
|
||||
hooks.append(CLIDebugHook)
|
||||
|
||||
if FLAGS.profile and hvd.rank() == 0:
|
||||
ProfilerHook = tf.train.ProfilerHook(save_steps=FLAGS.hooking_frequence, output_dir=FLAGS.output_dir, show_dataflow=True, show_memory=True)
|
||||
hooks.append(ProfilerHook)
|

  # If TPU is not available, this will fall back to normal Estimator on CPU
  # or GPU.
  estimator = tf.contrib.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
  estimator = tf.estimator.Estimator(
      model_fn=model_fn,
      config=run_config,
      train_batch_size=FLAGS.train_batch_size,
      eval_batch_size=FLAGS.eval_batch_size)
      config=run_config)

  if FLAGS.do_train:
    tf.logging.info("***** Running training *****")

@ -462,18 +606,26 @@ def main(_):
        input_files=input_files,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        is_training=True)
    estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps)
        is_training=True,
        batch_size=FLAGS.train_batch_size,
        use_hvd=FLAGS.use_hvd)

    if FLAGS.auto_recover:
      hooks.append(tf.data.experimental.CheckpointInputPipelineHook(estimator))

    estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps, hooks=hooks)

  if FLAGS.do_eval:
    tf.logging.info("***** Running evaluation *****")
    tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

    eval_input_fn = input_fn_builder(
        input_files=input_files,
        input_files=validation_input_files,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        is_training=False)
        is_training=False,
        batch_size=FLAGS.eval_batch_size,
        use_hvd=FLAGS.use_hvd)

    result = estimator.evaluate(
        input_fn=eval_input_fn, steps=FLAGS.max_eval_steps)

@ -485,9 +637,37 @@ def main(_):
        tf.logging.info("  %s = %s", key, str(result[key]))
        writer.write("%s = %s\n" % (key, str(result[key])))

  if FLAGS.do_train_eval:
    tf.logging.info("***** Running training *****")
    tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
    train_input_fn = input_fn_builder(
        input_files=input_files,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        is_training=True,
        batch_size=FLAGS.train_batch_size,
        use_hvd=FLAGS.use_hvd)

    tf.logging.info("***** Running evaluation *****")
    tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)
    eval_input_fn = input_fn_builder(
        input_files=validation_input_files,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        is_training=False,
        batch_size=FLAGS.eval_batch_size,
        use_hvd=FLAGS.use_hvd)

    if FLAGS.auto_recover:
      hooks.append(tf.data.experimental.CheckpointInputPipelineHook(estimator))

    train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps, hooks=hooks)
    eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn, steps=FLAGS.max_eval_steps)
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)


if __name__ == "__main__":
  flags.mark_flag_as_required("input_file")
  # flags.mark_flag_as_required("input_file")
  flags.mark_flag_as_required("bert_config_file")
  flags.mark_flag_as_required("output_dir")
  tf.app.run()