233 строки
13 KiB
Python
Executable File
233 строки
13 KiB
Python
Executable File
from __future__ import absolute_import
|
|
from __future__ import print_function
|
|
|
|
from model.pointer_net import *
|
|
from model.pointer_net_helper import *
|
|
|
|
from pprint import pprint
|
|
|
|
# Tensor names that ``build_graph`` creates when it runs under its default
# "graph" variable scope, keyed by the same logical names as the ``var_dict``
# that ``build_graph`` returns.
default_var_dict = dict(
    (logical_name, "graph/{}:0".format(logical_name))
    for logical_name in (
        "train_inputs",
        "train_outputs",
        "train_input_mask",
        "train_output_mask",
        "seq2seq_feed_previous",
        "token_accuracy",
        "sentence_accuracy",
        "predicted_labels",
        "total_loss",
        "type_masks",
    )
)
|
|
|
|
def build_graph(decoder_type, explicit_pointer, value_based_loss,
                hyper_param, pnet_vocab, pretrained_enc_embedding=None,
                multi_encoders=None, old_graph=None, scope=None):
    """Build the pointer-network graph and return it with its tensor names.

    Args:
        decoder_type: per-output-position decoder descriptors. Each item has
            a ``.ty`` attribute (``DecoderType.Projector`` or
            ``DecoderType.Pointer``); pointer items also carry ``.mask_name``,
            which is used as an index into the first axis of ``type_masks``.
        explicit_pointer: if true, accuracy compares predicted pointer
            positions with the pointer targets; otherwise it compares the
            input symbols those pointers resolve to.
        value_based_loss: "sum_vloss", "max_vloss" or "ploss".
            "sum_vloss"/"max_vloss" train on the pointer distribution moved
            onto input-vocabulary symbols (positions sharing a symbol are
            combined by sum / max). "ploss" trains on the pointer logits
            directly; the symbol transfer (by max) is then only used for
            predictions.
        hyper_param: mapping read for "embedding_size", "n_hidden",
            "num_layers", "dropout_keep_prob" and "encoder_merge_method".
        pnet_vocab: object whose ``get_all()`` returns
            ``(input_vocab, output_vocab, X_maxlen, Y_maxlen)``.
        pretrained_enc_embedding: optional initial value for the encoder
            embedding matrix; when given, that embedding is not trainable.
        multi_encoders: forwarded unchanged to ``pointer_network``.
        old_graph: unused; kept for call compatibility.
        scope: variable-scope name; defaults to "graph".

    Returns:
        ``(graph, var_dict)``: the new ``tf.Graph`` and a dict mapping the
        logical names of the feed/fetch tensors to their tensor names.

    Raises:
        ValueError: if ``value_based_loss`` is not one of the names above.
    """
    # Segment reduction used to move scores from encoder positions onto the
    # input-vocabulary symbols found at those positions. Validating here makes
    # an unknown loss name fail fast instead of surfacing as an undefined
    # variable deep inside graph construction.
    segment_reducers = {
        "sum_vloss": tf.unsorted_segment_sum,
        "max_vloss": tf.unsorted_segment_max,
        "ploss": tf.unsorted_segment_max,
    }
    if value_based_loss not in segment_reducers:
        raise ValueError("Loss function ({}) can not be recognized".format(value_based_loss))
    segment_reduce = segment_reducers[value_based_loss]

    # hyper parameters
    embedding_size = hyper_param["embedding_size"]
    n_hidden = hyper_param["n_hidden"]
    num_layers = hyper_param["num_layers"]
    dropout_keep_prob = hyper_param["dropout_keep_prob"]
    encoder_merge_method = hyper_param["encoder_merge_method"]

    input_vocab, output_vocab, X_maxlen, Y_maxlen = pnet_vocab.get_all()

    graph = tf.Graph()

    # The variable scope must be opened inside the graph context so that every
    # variable and op below belongs to ``graph``.
    with graph.as_default(), tf.variable_scope(scope or "graph"):
        # The batch dimension is None so batches of any size can be fed.
        train_inputs = tf.placeholder(tf.int32, [None, X_maxlen], name="train_inputs")
        # the mask for identifying padding tokens from the sentence
        train_input_mask = tf.placeholder(tf.bool, [None, X_maxlen], name="train_input_mask")
        train_outputs = tf.placeholder(tf.int32, [None, Y_maxlen], name="train_outputs")
        train_output_mask = tf.placeholder(tf.bool, [None, Y_maxlen], name="train_output_mask")
        type_masks = tf.placeholder(tf.bool, [2, None, X_maxlen], name="type_masks")

        # Dynamic batch size, needed to build per-batch tensors such as <GO>.
        batch_size = tf.shape(train_inputs)[0]

        seq2seq_feed_previous = tf.placeholder(tf.bool, name="seq2seq_feed_previous")

        if pretrained_enc_embedding is not None:
            enc_embedding = tf.get_variable(
                "enc_embedding",
                shape=[input_vocab.size, embedding_size],
                initializer=tf.constant_initializer(pretrained_enc_embedding),
                trainable=False)
        else:
            enc_embedding = tf.get_variable(
                "enc_embedding",
                initializer=tf.random_uniform([input_vocab.size, embedding_size], -1.0, 1))

        fw_enc_cell = tf.contrib.rnn.MultiRNNCell([tf.contrib.rnn.LSTMCell(n_hidden) for _ in range(num_layers)])
        bw_enc_cell = tf.contrib.rnn.MultiRNNCell([tf.contrib.rnn.LSTMCell(n_hidden) for _ in range(num_layers)])
        dec_cell = tf.contrib.rnn.MultiRNNCell([tf.contrib.rnn.LSTMCell(n_hidden) for _ in range(num_layers)])

        if dropout_keep_prob < 1:
            # Dropout is disabled (keep prob 1) whenever the decoder feeds back
            # its own predictions.
            t_dropout_keep_prob = tf.cond(seq2seq_feed_previous, lambda: 1., lambda: dropout_keep_prob)
            fw_enc_cell = tf.contrib.rnn.DropoutWrapper(fw_enc_cell, output_keep_prob=t_dropout_keep_prob)
            bw_enc_cell = tf.contrib.rnn.DropoutWrapper(bw_enc_cell, output_keep_prob=t_dropout_keep_prob)
            dec_cell = tf.contrib.rnn.DropoutWrapper(dec_cell, output_keep_prob=t_dropout_keep_prob)

        # One [batch] id tensor per encoder time step.
        encoder_inputs = [train_inputs[:, k] for k in range(X_maxlen)]

        # The first decoder input is always <GO>; the last target token is
        # never fed as a decoder input.
        decoder_inputs = ([tf.fill([batch_size], output_vocab.word_to_index(Vocabulary.GO_TOK))]
                          + [train_outputs[:, k] for k in range(Y_maxlen - 1)])

        # Embed every encoder token into an embedding_size vector.
        enc_embeds = [tf.nn.embedding_lookup(enc_embedding, enc_input) for enc_input in encoder_inputs]
        # [batch, embedding_size, X_maxlen]: a one-hot over positions can then
        # pick out one encoder embedding with a matmul.
        stacked_enc_embeds = tf.transpose(tf.stack(enc_embeds), [1, 2, 0])

        # Embedding table for output-vocabulary symbols (including <GO>).
        dec_embedding = tf.get_variable(
            "dec_embedding",
            initializer=tf.random_uniform([output_vocab.size, embedding_size], -1.0, 1))

        def _embed_decoder_input(step, token):
            """Embed one decoder input according to the type of the previous step.

            <GO> (step 0) and tokens produced by a Projector step are looked up
            in ``dec_embedding``; tokens produced by a Pointer step are encoder
            positions, so they select the encoder embedding at that position.
            """
            if step == 0 or decoder_type[step - 1].ty is DecoderType.Projector:
                return tf.nn.embedding_lookup(dec_embedding, token)
            position_one_hot = tf.one_hot(token, stacked_enc_embeds.get_shape()[-1])
            return tf.squeeze(tf.matmul(stacked_enc_embeds, tf.expand_dims(position_one_hot, -1)), -1)

        dec_embeds = [_embed_decoder_input(i, x) for i, x in enumerate(decoder_inputs)]

        print("Loss function is {}".format(value_based_loss))

        # Per-time-step [batch] boolean encoder masks.
        reshaped_input_mask = tf.unstack(tf.transpose(train_input_mask, [1, 0]))

        def _to_symbol_distrib(scores_and_ids):
            """Move a (scores, input ids) pair onto input-vocabulary symbols.

            ``scores_and_ids[0]`` holds one score per encoder position and
            ``scores_and_ids[1]`` the (float-cast) input id at that position;
            scores of positions sharing an id are combined by
            ``segment_reduce``.
            """
            return segment_reduce(scores_and_ids[0],
                                  tf.cast(scores_and_ids[1], tf.int32),
                                  input_vocab.size)

        def feed_prev_func(prev, current_decoder_type):
            """Embed the previous decoder output for use as the next decoder input."""
            if current_decoder_type.ty is DecoderType.Projector:
                # prev is the projection output; its argmax indexes dec_embedding.
                return tf.nn.embedding_lookup(dec_embedding, tf.argmax(prev, 1))
            if current_decoder_type.ty is DecoderType.Pointer:
                # prev holds pointer energies over encoder positions: pair them
                # with the input ids, move them onto input symbols and embed
                # the highest-scoring symbol with the encoder embedding.
                logits_n_inputs = tf.concat(
                    [tf.expand_dims(prev, -2),
                     tf.cast(tf.expand_dims(train_inputs, -2), tf.float32)],
                    -2)
                transferred_distrib = tf.map_fn(_to_symbol_distrib, logits_n_inputs)
                return tf.nn.embedding_lookup(enc_embedding, tf.argmax(transferred_distrib, -1))
            raise Exception('not a expected type')

        outs = pointer_network(enc_embeds, dec_embeds,
                               fw_enc_cell, bw_enc_cell, dec_cell,
                               decoder_type, output_vocab.size,
                               encoder_masks=reshaped_input_mask,
                               feed_prev=seq2seq_feed_previous,
                               loop_function=feed_prev_func,
                               multi_encoders=multi_encoders,
                               encoder_merge_method=encoder_merge_method)

        # Split targets and masks into pointer steps and projection steps;
        # pointer targets are encoder positions.
        pointer_train_outs, proj_train_outs = split_by_type(train_outputs, decoder_type, axis=1)
        pointer_out_mask, proj_out_mask = split_by_type(train_output_mask, decoder_type, axis=1)
        pntr_dec_types = [x for x in decoder_type if x.ty == DecoderType.Pointer]

        def _process_outs(step_outs, target_type, depth):
            """Select the decoder outputs of the steps whose type is ``target_type``.

            Returns a list with one output per matching step (callers stack it
            via ``tf.transpose(..., perm=[1, 0, 2])`` into
            [batch, n_steps, depth]). When no step matches, returns an empty
            float tensor of shape [0, batch_size, depth] so the same transpose
            yields [batch, 0, depth]. The outputs of the seq2seq model are
            already logits.
            """
            selected = [out for i, out in enumerate(step_outs) if decoder_type[i].ty is target_type]
            if not selected:
                return tf.reshape([], [0, batch_size, depth])
            return selected

        # [batch, n_projector_steps, output_vocab.size] logits
        proj_outs = tf.transpose(_process_outs(outs, DecoderType.Projector, output_vocab.size), perm=[1, 0, 2])

        pointer_outs = _process_outs(outs, DecoderType.Pointer, X_maxlen)
        # Multiply each pointer step's logits by its boolean type mask, so
        # masked positions get logit 0.
        # NOTE(review): a zero logit still receives softmax mass; masking with
        # a large negative value may be what is intended — confirm.
        pointer_outs = [pointer_outs[i] * tf.cast(type_masks[dec.mask_name], tf.float32)
                        for i, dec in enumerate(pntr_dec_types)]
        pointer_outs = tf.transpose(pointer_outs, perm=[1, 0, 2])

        # predictions made by the neural network
        pointer_predictions = tf.nn.softmax(pointer_outs)
        proj_predictions = tf.nn.softmax(proj_outs)

        # Most likely output-vocabulary symbol for every projection step.
        proj_predicted_labels = tf.cast(tf.argmax(proj_predictions, axis=-1), tf.int32)

        # Repeat the input ids once per pointer step and pair them with the
        # pointer distributions: [batch, n_pointer_steps, 2, X_maxlen].
        n_pointer_steps = int(pointer_train_outs.get_shape()[-1])
        copied_train_inputs = tf.transpose(tf.stack([train_inputs for _ in range(n_pointer_steps)]), [1, 0, 2])
        merged_logits_inputs = tf.concat(
            [tf.expand_dims(pointer_predictions, -2),
             tf.cast(tf.expand_dims(copied_train_inputs, -2), tf.float32)],
            -2)

        # Transfer each step's distribution over pointers to a distribution
        # over encoder (input-vocabulary) symbols.
        distrib_over_encoder_symbols = tf.map_fn(
            lambda steps: tf.map_fn(_to_symbol_distrib, steps),
            merged_logits_inputs)

        labels_by_pointers = tf.cast(tf.argmax(distrib_over_encoder_symbols, axis=-1), tf.int32)

        predicted_labels = assemble_by_type(labels_by_pointers, proj_predicted_labels, decoder_type, axis=1)

        if explicit_pointer:
            pointer_predicted_labels = tf.cast(tf.argmax(pointer_predictions, axis=-1), tf.int32)
            # compute accuracy with pointers, since pointers are explicitly provided in the dataset
            token_accuracy, sentence_accuracy = compute_accuracy(pointer_predicted_labels, pointer_train_outs,
                                                                 proj_predicted_labels, proj_train_outs,
                                                                 pointer_out_mask, proj_out_mask)
        else:
            # Compute accuracy on the concrete symbols the pointers resolve to
            # rather than on pointer positions, since only the final result matters.
            token_accuracy, sentence_accuracy = compute_accuracy(labels_by_pointers,
                                                                 pointer_to_label(pointer_train_outs, train_inputs),
                                                                 proj_predicted_labels, proj_train_outs,
                                                                 pointer_out_mask, proj_out_mask)

        token_accuracy = tf.identity(token_accuracy, name="token_accuracy")
        sentence_accuracy = tf.identity(sentence_accuracy, name="sentence_accuracy")
        predicted_labels = tf.identity(predicted_labels, name="predicted_labels")

        # Pointer loss (value_based_loss was validated at the top).
        if value_based_loss == "ploss":
            pntr_losses = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=pointer_train_outs,
                                                                         logits=pointer_outs)
        else:
            # "sum_vloss" / "max_vloss": cross-entropy of the target symbol
            # under the distribution transferred onto input symbols.
            one_hot_train_out = tf.one_hot(pointer_to_label(pointer_train_outs, train_inputs),
                                           input_vocab.size, axis=-1)
            # Clip before the log so empty/zero entries do not produce NaN.
            pntr_losses = -tf.reduce_sum(
                one_hot_train_out * tf.log(tf.clip_by_value(distrib_over_encoder_symbols, 1e-10, 1.0)), -1)

        proj_losses = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=proj_train_outs, logits=proj_outs)

        # Losses of padding cells are removed by the output masks; the masked
        # losses are summed (not averaged over the unmasked cells).
        total_loss = tf.add(tf.reduce_sum(tf.multiply(pntr_losses, tf.cast(pointer_out_mask, tf.float32))),
                            tf.reduce_sum(tf.multiply(proj_losses, tf.cast(proj_out_mask, tf.float32))),
                            name="total_loss")

        var_dict = {
            "train_inputs": train_inputs.name,
            "train_outputs": train_outputs.name,
            "train_input_mask": train_input_mask.name,
            "train_output_mask": train_output_mask.name,
            "seq2seq_feed_previous": seq2seq_feed_previous.name,
            "token_accuracy": token_accuracy.name,
            "sentence_accuracy": sentence_accuracy.name,
            "predicted_labels": predicted_labels.name,
            "total_loss": total_loss.name,
            "type_masks": type_masks.name
        }

    return graph, var_dict
|
|
|