This commit is contained in:
Daisy Deng 2020-03-06 15:56:44 +00:00
Родитель c836249713 b527bdfcdb
Коммит fd84b4b32f
6 изменённых файлов: 85 добавлений и 205 удалений

Просмотреть файл

@ -25,6 +25,12 @@ from utils_nlp.models.transformers.datasets import SummarizationNonIterableDatas
from utils_nlp.eval.evaluate_summarization import get_rouge from utils_nlp.eval.evaluate_summarization import get_rouge
os.environ["NCCL_IB_DISABLE"] = "0" os.environ["NCCL_IB_DISABLE"] = "0"
#os.environ["NCCL_DEBUG"] = "INFO"
os.environ["NCCL_DEBUG_SUBSYS"] = "ALL"
#os.environ["MASTER_PORT"] = "29952"
#os.environ["MASTER_ADDR"] = "172.12.0.6"
#os.environ['NCCL_SOCKET_IFNAME'] = 'lo'
def shorten_dataset(dataset, top_n=-1): def shorten_dataset(dataset, top_n=-1):
if top_n == -1: if top_n == -1:
@ -56,6 +62,8 @@ parser.add_argument("--lr_dec", type=float, default=2e-1,
help="Learning rate for the decoder.") help="Learning rate for the decoder.")
parser.add_argument("--batch_size", type=int, default=5, parser.add_argument("--batch_size", type=int, default=5,
help="batch size in terms of input token numbers in training") help="batch size in terms of input token numbers in training")
parser.add_argument("--max_pos", type=int, default=640,
help="maximum input length in terms of input token numbers in training")
parser.add_argument("--max_steps", type=int, default=5e4, parser.add_argument("--max_steps", type=int, default=5e4,
help="Maximum number of training steps run in training. If quick_run is set,\ help="Maximum number of training steps run in training. If quick_run is set,\
it's not used.") it's not used.")
@ -107,7 +115,7 @@ def main():
print("data_dir is {}".format(args.data_dir)) print("data_dir is {}".format(args.data_dir))
print("cache_dir is {}".format(args.cache_dir)) print("cache_dir is {}".format(args.cache_dir))
ngpus_per_node = torch.cuda.device_count() ngpus_per_node = torch.cuda.device_count()
processor = AbsSumProcessor(cache_dir=args.cache_dir) processor = AbsSumProcessor(cache_dir=args.cache_dir, max_src_len=max_pos)
summarizer = AbsSum( summarizer = AbsSum(
processor, cache_dir=args.cache_dir processor, cache_dir=args.cache_dir
) )
@ -120,13 +128,14 @@ def main_worker(local_rank, ngpus_per_node, summarizer, args):
print("world_size is {}".format(world_size)) print("world_size is {}".format(world_size))
print("local_rank is {} and rank is {}".format(local_rank, rank)) print("local_rank is {} and rank is {}".format(local_rank, rank))
torch.distributed.init_process_group( torch.distributed.init_process_group(
backend="nccl", backend="nccl",
init_method=args.dist_url, init_method=args.dist_url,
world_size=world_size, world_size=world_size,
rank=rank, rank=rank,
) )
# return
## should not load checkpoint from this place, otherwise, huge memory increase ## should not load checkpoint from this place, otherwise, huge memory increase
if args.checkpoint_filename: if args.checkpoint_filename:
checkpoint = os.path.join(args.cache_dir, args.checkpoint_filename) checkpoint = os.path.join(args.cache_dir, args.checkpoint_filename)

Просмотреть файл

@ -25,7 +25,7 @@ from utils_nlp.models.transformers.datasets import SummarizationNonIterableDatas
from utils_nlp.eval.evaluate_summarization import get_rouge from utils_nlp.eval.evaluate_summarization import get_rouge
CACHE_PATH = "/dadendev/nlp-recipes/examples/text_summarization/abstemp" CACHE_PATH = "/dadendev/nlp-recipes/examples/text_summarization/abstemp"
DATA_PATH = "/dadendev/nlp-recipes/examples/text_summarization" DATA_PATH = "/dadendev/nlp-recipes/examples/text_summarization/abstemp"
MODEL_PATH = "/dadendev/nlp-recipes/examples/text_summarization/abstemp" MODEL_PATH = "/dadendev/nlp-recipes/examples/text_summarization/abstemp"
TOP_N = 10 TOP_N = 10
@ -259,10 +259,11 @@ def test_pretrained_model():
checkpoint = torch.load(os.path.join(MODEL_PATH, "new_model_step_148000_torch1.4.0.pt")) checkpoint = torch.load(os.path.join(MODEL_PATH, "new_model_step_148000_torch1.4.0.pt"))
#checkpoint = torch.load(os.path.join(MODEL_PATH, "summarizer_step20000_with_global_step.pt")) #checkpoint = torch.load(os.path.join(MODEL_PATH, "summarizer_step20000_with_global_step.pt"))
checkpoint = torch.load(os.path.join(MODEL_PATH, "bert-base-uncased_step_400.pt")) #checkpoint = torch.load(os.path.join(MODEL_PATH, "bert-base-uncased_step_400.pt"))
summarizer = AbsSum( summarizer = AbsSum(
processor, processor,
cache_dir=CACHE_PATH, cache_dir=CACHE_PATH,
max_pos=512,
) )
summarizer.model.load_checkpoint(checkpoint['model']) summarizer.model.load_checkpoint(checkpoint['model'])
""" """
@ -284,14 +285,15 @@ def test_pretrained_model():
return return
""" """
top_n = 8 top_n = 96*4
src = test_sum_dataset.source[0:top_n] src = test_sum_dataset.source[0:top_n]
reference_summaries = ["".join(t).rstrip("\n") for t in test_sum_dataset.target[0:top_n]] reference_summaries = ["".join(t).rstrip("\n") for t in test_sum_dataset.target[0:top_n]]
print("start prediction") print("start prediction")
generated_summaries = summarizer.predict( generated_summaries = summarizer.predict(
shorten_dataset(test_sum_dataset, top_n=top_n), batch_size=4, num_gpus=2 shorten_dataset(test_sum_dataset, top_n=top_n), batch_size=96+16, num_gpus=1, max_seq_length=512
) )
print(generated_summaries[0]) print(generated_summaries[0])
print(len(generated_summaries))
assert len(generated_summaries) == len(reference_summaries) assert len(generated_summaries) == len(reference_summaries)
RESULT_DIR = TemporaryDirectory().name RESULT_DIR = TemporaryDirectory().name
rouge_score = get_rouge(generated_summaries, reference_summaries, RESULT_DIR) rouge_score = get_rouge(generated_summaries, reference_summaries, RESULT_DIR)
@ -303,6 +305,6 @@ def test_pretrained_model():
#test_collate() #test_collate()
#preprocess_cnndm_abs() #preprocess_cnndm_abs()
#test_train_model() #test_train_model()
#test_pretrained_model() test_pretrained_model()
if __name__ == "__main__": #if __name__ == "__main__":
main() # main()

Просмотреть файл

@ -113,7 +113,7 @@ class AbsSumProcessor:
model_name="bert-base-uncased", model_name="bert-base-uncased",
to_lower=False, to_lower=False,
cache_dir=".", cache_dir=".",
max_len=512, max_src_len=640,
max_target_len=140, max_target_len=140,
): ):
""" Initialize the preprocessor. """ Initialize the preprocessor.
@ -144,14 +144,11 @@ class AbsSumProcessor:
self.sep_token = "[SEP]" self.sep_token = "[SEP]"
self.cls_token = "[CLS]" self.cls_token = "[CLS]"
self.pad_token = "[PAD]" self.pad_token = "[PAD]"
self.tgt_bos = "[unused0]" self.tgt_bos = self.symbols["BOS"]
self.tgt_eos = "[unused1]" self.tgt_eos = self.symbols["EOS"]
self.sep_vid = self.tokenizer.vocab[self.sep_token]
self.cls_vid = self.tokenizer.vocab[self.cls_token]
self.pad_vid = self.tokenizer.vocab[self.pad_token]
self.max_len = max_len self.max_src_len = max_src_len
self.max_target_len = max_target_len self.max_target_len = max_target_len
@staticmethod @staticmethod
@ -243,7 +240,8 @@ class AbsSumProcessor:
if train_mode: if train_mode:
encoded_summaries = torch.tensor( encoded_summaries = torch.tensor(
[ [
fit_to_block_size(summary, block_size, self.tokenizer.pad_token_id) [self.tgt_bos] + fit_to_block_size(summary, block_size-2, self.tokenizer.pad_token_id)
+[self.tgt_eos]
for _, summary in encoded_text for _, summary in encoded_text
] ]
) )
@ -308,7 +306,7 @@ class AbsSumProcessor:
try: try:
if len(line) <= 0: if len(line) <= 0:
continue continue
story_lines_token_ids.append(self.tokenizer.encode(line, max_length=self.max_len)) story_lines_token_ids.append(self.tokenizer.encode(line, max_length=self.max_src_len))
except: except:
print(line) print(line)
raise raise
@ -325,9 +323,9 @@ class AbsSumProcessor:
except: except:
print(line) print(line)
raise raise
summary_token_ids = [ summary_token_ids = [
token for sentence in summary_lines_token_ids for token in sentence token for sentence in summary_lines_token_ids for token in sentence
] ]
return story_token_ids, summary_token_ids return story_token_ids, summary_token_ids
else: else:
return story_token_ids return story_token_ids
@ -366,6 +364,7 @@ class AbsSum(Transformer):
cache_dir=".", cache_dir=".",
label_smoothing=0.1, label_smoothing=0.1,
test=False, test=False,
max_pos=768,
): ):
"""Initialize a ExtractiveSummarizer. """Initialize a ExtractiveSummarizer.
@ -397,6 +396,7 @@ class AbsSum(Transformer):
self.model_class = MODEL_CLASS[model_name] self.model_class = MODEL_CLASS[model_name]
self.cache_dir = cache_dir self.cache_dir = cache_dir
self.max_pos = max_pos
self.model = AbsSummarizer( self.model = AbsSummarizer(
temp_dir=cache_dir, temp_dir=cache_dir,
@ -404,12 +404,12 @@ class AbsSum(Transformer):
checkpoint=None, checkpoint=None,
label_smoothing=label_smoothing, label_smoothing=label_smoothing,
symbols=processor.symbols, symbols=processor.symbols,
test=test test=test,
max_pos=self.max_pos,
) )
self.processor = processor self.processor = processor
self.optim_bert = None self.optim_bert = None
self.optim_dec = None self.optim_dec = None
@staticmethod @staticmethod
@ -423,7 +423,7 @@ class AbsSum(Transformer):
train_dataset, train_dataset,
num_gpus=None, num_gpus=None,
gpu_ids=None, gpu_ids=None,
batch_size=140, batch_size=4,
local_rank=-1, local_rank=-1,
max_steps=5e5, max_steps=5e5,
warmup_steps_bert=8000, warmup_steps_bert=8000,
@ -499,6 +499,7 @@ class AbsSum(Transformer):
#""" #"""
self.optim_bert = model_builder.build_optim_bert( self.optim_bert = model_builder.build_optim_bert(
self.model, self.model,
optim=optimization_method,
visible_gpus=None, #",".join([str(i) for i in range(num_gpus)]), #"0,1,2,3", visible_gpus=None, #",".join([str(i) for i in range(num_gpus)]), #"0,1,2,3",
lr_bert=learning_rate_bert, lr_bert=learning_rate_bert,
warmup_steps_bert=warmup_steps_bert, warmup_steps_bert=warmup_steps_bert,
@ -506,6 +507,7 @@ class AbsSum(Transformer):
) )
self.optim_dec = model_builder.build_optim_dec( self.optim_dec = model_builder.build_optim_dec(
self.model, self.model,
optim=optimization_method,
visible_gpus=None, #",".join([str(i) for i in range(num_gpus)]), #"0,1,2,3" visible_gpus=None, #",".join([str(i) for i in range(num_gpus)]), #"0,1,2,3"
lr_dec=learning_rate_dec, lr_dec=learning_rate_dec,
warmup_steps_dec=warmup_steps_dec, warmup_steps_dec=warmup_steps_dec,
@ -518,6 +520,7 @@ class AbsSum(Transformer):
optimizers = [self.optim_bert, self.optim_dec] optimizers = [self.optim_bert, self.optim_dec]
schedulers = [self.scheduler_bert, self.scheduler_dec] schedulers = [self.scheduler_bert, self.scheduler_dec]
self.amp = get_amp(fp16) self.amp = get_amp(fp16)
if self.amp: if self.amp:
self.model, optim = self.amp.initialize(self.model, optimizers, opt_level=fp16_opt_level) self.model, optim = self.amp.initialize(self.model, optimizers, opt_level=fp16_opt_level)
@ -548,7 +551,7 @@ class AbsSum(Transformer):
sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank) sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
def collate_fn(data): def collate_fn(data):
return self.processor.collate(data, block_size=512, device=device) return self.processor.collate(data, block_size=self.max_pos, device=device)
train_dataloader = DataLoader( train_dataloader = DataLoader(
train_dataset, train_dataset,
@ -592,13 +595,13 @@ class AbsSum(Transformer):
local_rank=-1, local_rank=-1,
gpu_ids=None, gpu_ids=None,
batch_size=16, batch_size=16,
# sentence_separator="<q>",
alpha=0.6, alpha=0.6,
beam_size=5, beam_size=5,
min_length=15, min_length=15,
max_length=150, max_length=150,
fp16=False, fp16=False,
verbose=True, verbose=True,
max_seq_length=768,
): ):
""" """
Predict the summarization for the input data iterator. Predict the summarization for the input data iterator.
@ -625,11 +628,11 @@ class AbsSum(Transformer):
List of strings which are the summaries List of strings which are the summaries
""" """
device, num_gpus = get_device(num_gpus=num_gpus, local_rank=local_rank) device, num_gpus = get_device(num_gpus=num_gpus, gpu_ids=gpu_ids, local_rank=local_rank)
# move model to devices # move model to devices
def this_model_move_callback(model, device): def this_model_move_callback(model, device):
model = move_model_to_device(model, device) model = move_model_to_device(model, device)
return parallelize_model(model, device, num_gpus=num_gpus, gpu_ids=None, local_rank=local_rank) return parallelize_model(model, device, num_gpus=num_gpus, gpu_ids=gpu_ids, local_rank=local_rank)
if fp16: if fp16:
self.model = self.model.half() self.model = self.model.half()
@ -647,13 +650,13 @@ class AbsSum(Transformer):
min_length=min_length, min_length=min_length,
max_length=max_length, max_length=max_length,
) )
predictor = predictor.move_to_device(device, this_model_move_callback) predictor = this_model_move_callback(predictor, device)
test_sampler = SequentialSampler(test_dataset) test_sampler = SequentialSampler(test_dataset)
def collate_fn(data): def collate_fn(data):
return self.processor.collate(data, 512, device, train_mode=False) return self.processor.collate(data, max_seq_length, device, train_mode=False)
test_dataloader = DataLoader( test_dataloader = DataLoader(
test_dataset, sampler=test_sampler, batch_size=batch_size, collate_fn=collate_fn, test_dataset, sampler=test_sampler, batch_size=batch_size, collate_fn=collate_fn,
@ -663,7 +666,7 @@ class AbsSum(Transformer):
""" Transforms the output of the `from_batch` function """ Transforms the output of the `from_batch` function
into nicely formatted summaries. into nicely formatted summaries.
""" """
raw_summary, _, = translation raw_summary = translation
summary = ( summary = (
raw_summary.replace("[unused0]", "") raw_summary.replace("[unused0]", "")
.replace("[unused3]", "") .replace("[unused3]", "")
@ -677,16 +680,31 @@ class AbsSum(Transformer):
.strip() .strip()
) )
return summary return summary
def generate_summary_from_tokenid(preds, pred_score):
batch_size = preds.size()[0] # batch.batch_size
translations = []
for b in range(batch_size):
if len(preds[b]) < 1:
pred_sents = ""
else:
pred_sents = self.processor.tokenizer.convert_ids_to_tokens([int(n) for n in preds[b] if int(n)!=0])
pred_sents = " ".join(pred_sents).replace(" ##", "")
translations.append(pred_sents)
return translations
generated_summaries = [] generated_summaries = []
from tqdm import tqdm from tqdm import tqdm
for batch in tqdm(test_dataloader): for batch in tqdm(test_dataloader, desc="Generating summary", disable=not verbose):
batch_data = predictor.translate_batch(batch) input = self.processor.get_inputs(batch, device, "bert", train_mode=False)
translations = predictor.from_batch(batch_data) translations, scores = predictor(**input)
summaries = [format_summary(t) for t in translations]
generated_summaries += summaries translations_text = generate_summary_from_tokenid(translations, scores)
summaries = [format_summary(t) for t in translations_text]
generated_summaries.extend(summaries)
return generated_summaries return generated_summaries
def save_model(self, global_step=None, full_name=None): def save_model(self, global_step=None, full_name=None):

Просмотреть файл

@ -196,7 +196,7 @@ class AbsSummarizer(nn.Module):
self.bert.model = BertModel(bert_config) self.bert.model = BertModel(bert_config)
if max_pos > 512: if max_pos > 512:
my_pos_embeddi = nn.Embedding(max_pos, self.bert.model.config.hidden_size) my_pos_embeddings = nn.Embedding(max_pos, self.bert.model.config.hidden_size)
my_pos_embeddings.weight.data[ my_pos_embeddings.weight.data[
:512 :512
] = self.bert.model.embeddings.position_embeddings.weight.data ] = self.bert.model.embeddings.position_embeddings.weight.data

Просмотреть файл

@ -6,7 +6,7 @@ import os
import math import math
import torch import torch
from torch import nn
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
# from others.utils import rouge_results_to_str, test_rouge, tile # from others.utils import rouge_results_to_str, test_rouge, tile
@ -55,7 +55,7 @@ def tile(x, count, dim=0):
return x return x
class Translator(object): class Translator(nn.Module):
""" """
Uses a model to translate a batch of sentences. Uses a model to translate a batch of sentences.
@ -70,14 +70,12 @@ class Translator(object):
global_scores (:obj:`GlobalScorer`): global_scores (:obj:`GlobalScorer`):
object to rescore final translations object to rescore final translations
copy_attn (bool): use copy attention during translation copy_attn (bool): use copy attention during translation
cuda (bool): use cuda
beam_trace (bool): trace beam search for debugging beam_trace (bool): trace beam search for debugging
logger(logging.Logger): logger. logger(logging.Logger): logger.
""" """
def __init__( def __init__(
self, self,
# args,
beam_size, beam_size,
min_length, min_length,
max_length, max_length,
@ -89,10 +87,9 @@ class Translator(object):
logger=None, logger=None,
dump_beam="", dump_beam="",
): ):
super(Translator, self).__init__()
self.logger = logger self.logger = logger
self.cuda = 0 # args.visible_gpus != '-1'
# self.args = args
self.model = model.module if hasattr(model, "module") else model self.model = model.module if hasattr(model, "module") else model
self.generator = self.model.generator self.generator = self.model.generator
self.decoder = self.model.decoder self.decoder = self.model.decoder
@ -115,10 +112,6 @@ class Translator(object):
self.beam_trace = self.dump_beam != "" self.beam_trace = self.dump_beam != ""
self.beam_accum = None self.beam_accum = None
# tensorboard_log_dir = args.model_path
# self.tensorboard_writer = SummaryWriter(tensorboard_log_dir, comment="Unmt")
if self.beam_trace: if self.beam_trace:
self.beam_accum = { self.beam_accum = {
"predicted_ids": [], "predicted_ids": [],
@ -126,90 +119,15 @@ class Translator(object):
"scores": [], "scores": [],
"log_probs": [], "log_probs": [],
} }
"""
def move_to_device(self, device, move_to_device_fn):
self.move_to_device_fn = move_to_device_fn
self.model = move_to_device_fn(self.model, device)
self.bert = move_to_device_fn(self.bert, device)
self.generator = move_to_device_fn(self.generator, device)
return self
def eval(self): def eval(self):
self.model.eval() self.model.eval()
self.bert.eval() self.bert.eval()
self.decoder.eval() self.decoder.eval()
self.generator.eval() self.generator.eval()
"""
def _build_target_tokens(self, pred): def forward(self, src, segs, mask_src):
# vocab = self.fields["tgt"].vocab
tokens = []
for tok in pred:
tok = int(tok)
tokens.append(tok)
if tokens[-1] == self.end_token:
tokens = tokens[:-1]
break
tokens = [t for t in tokens if t < len(self.vocab)]
tokens = self.vocab.DecodeIds(tokens).split(" ")
return tokens
def from_batch(self, translation_batch):
batch = translation_batch["batch"]
# assert len(translation_batch["gold_score"]) == len(translation_batch["predictions"])
batch_size = batch.batch_size
# preds, pred_score, gold_score, tgt_str, src = (
preds, pred_score, src = (
translation_batch["predictions"],
translation_batch["scores"],
# translation_batch["gold_score"],
# batch.tgt_str,
batch.src,
)
# print(preds)
# print(pred_score)
# print(batch.tgt_str)
# print(batch.src)
translations = []
for b in range(batch_size):
if len(preds[b]) < 1:
pred_sents = ""
else:
pred_sents = self.vocab.convert_ids_to_tokens([int(n) for n in preds[b][0]])
pred_sents = " ".join(pred_sents).replace(" ##", "")
# gold_sent = " ".join(tgt_str[b].split())
# translation = Translation(fname[b],src[:, b] if src is not None else None,
# src_raw, pred_sents,
# attn[b], pred_score[b], gold_sent,
# gold_score[b])
# src = self.spm.DecodeIds([int(t) for t in translation_batch['batch'].src[0][5] if int(t) != len(self.spm)])
raw_src = [self.vocab.ids_to_tokens[int(t)] for t in src[b]][:500]
raw_src = " ".join(raw_src)
# translation = (pred_sents, gold_sent, raw_src)
translation = (pred_sents, raw_src)
# translation = (pred_sents[0], gold_sent)
translations.append(translation)
return translations
def translate(self, batch, attn_debug=False):
#self.model.eval()
self.eval()
# pred_results, gold_results = [], []
with torch.no_grad():
batch_data = self.translate_batch(batch)
translations = self.from_batch(batch_data)
# for trans in translations:
# pred, gold, src = trans
# pred_str = pred.replace('[unused0]', '').replace('[unused3]', '').replace('[PAD]', '').replace('[unused1]', '').replace(r' +', ' ').replace(' [unused2] ', '<q>').replace('[unused2]', '').strip()
# gold_str = gold.strip()
return translations
def translate_batch(self, batch, fast=False):
""" """
Translate a batch of sentences. Translate a batch of sentences.
@ -224,22 +142,22 @@ class Translator(object):
Shouldn't need the original dataset. Shouldn't need the original dataset.
""" """
with torch.no_grad(): with torch.no_grad():
return self._fast_translate_batch(batch, self.max_length, min_length=self.min_length) predictions, scores = self._fast_translate_batch(src, segs, mask_src, self.max_length, min_length=self.min_length)
return predictions, scores
def _fast_translate_batch(self, batch, max_length, min_length=0): def _fast_translate_batch(self, src, segs, mask_src, max_length, min_length=0):
# TODO: faster code path for beam_size == 1. # TODO: faster code path for beam_size == 1.
# TODO: support these blacklisted features. # TODO: support these blacklisted features.
assert not self.dump_beam assert not self.dump_beam
beam_size = self.beam_size beam_size = self.beam_size
batch_size = batch.batch_size batch_size = src.size()[0] #32 #batch.batch_size
src = batch.src
segs = batch.segs
mask_src = batch.mask_src
src_features = self.bert(src, segs, mask_src) src_features = self.bert(src, segs, mask_src)
dec_states = self.decoder.init_decoder_state(src, src_features, with_cache=True) this_decoder = self.decoder.module if hasattr(self.decoder, "module") else self.decoder
dec_states = this_decoder.init_decoder_state(src, src_features, with_cache=True)
device = src_features.device device = src_features.device
@ -266,7 +184,7 @@ class Translator(object):
results["predictions"] = [[] for _ in range(batch_size)] # noqa: F812 results["predictions"] = [[] for _ in range(batch_size)] # noqa: F812
results["scores"] = [[] for _ in range(batch_size)] # noqa: F812 results["scores"] = [[] for _ in range(batch_size)] # noqa: F812
# results["gold_score"] = [0] * batch_size # results["gold_score"] = [0] * batch_size
results["batch"] = batch #results["batch"] = batch
for step in range(max_length): for step in range(max_length):
decoder_input = alive_seq[:, -1].view(1, -1) decoder_input = alive_seq[:, -1].view(1, -1)
@ -274,7 +192,7 @@ class Translator(object):
# Decoder forward. # Decoder forward.
decoder_input = decoder_input.transpose(0, 1) decoder_input = decoder_input.transpose(0, 1)
dec_out, dec_states = self.decoder( dec_out, dec_states = this_decoder(
decoder_input, src_features, dec_states, step=step decoder_input, src_features, dec_states, step=step
) )
@ -321,13 +239,11 @@ class Translator(object):
# Resolve beam origin and true word ids. # Resolve beam origin and true word ids.
topk_beam_index = topk_ids.div(vocab_size) topk_beam_index = topk_ids.div(vocab_size)
# print("topk_beam_index.shape {}".format( topk_beam_index.size()))
topk_ids = topk_ids.fmod(vocab_size) topk_ids = topk_ids.fmod(vocab_size)
# Map beam_index to batch_index in the flat representation. # Map beam_index to batch_index in the flat representation.
batch_index = topk_beam_index + beam_offset[: topk_beam_index.size(0)].unsqueeze(1) batch_index = topk_beam_index + beam_offset[: topk_beam_index.size(0)].unsqueeze(1)
select_indices = batch_index.view(-1) select_indices = batch_index.view(-1)
# print("select_indices {}".format(select_indices))
# Append last prediction. # Append last prediction.
alive_seq = torch.cat( alive_seq = torch.cat(
@ -335,15 +251,10 @@ class Translator(object):
) )
is_finished = topk_ids.eq(self.end_token) is_finished = topk_ids.eq(self.end_token)
# print("is_finished {}".format(is_finished))
# print("is_finished size {}".format(is_finished.size()))
if step + 1 == max_length: if step + 1 == max_length:
# print("reached max_length {} at step {}".format(max_length, step))
is_finished.fill_(True) is_finished.fill_(True)
# print("is_finished {}".format(is_finished))
# End condition is top beam is finished. # End condition is top beam is finished.
end_condition = is_finished[:, 0].eq(True) end_condition = is_finished[:, 0].eq(True)
# print("end_condition {}".format(end_condition))
if step + 1 == max_length: if step + 1 == max_length:
assert not any(end_condition.eq(False)) assert not any(end_condition.eq(False))
@ -353,7 +264,6 @@ class Translator(object):
for i in range(is_finished.size(0)): for i in range(is_finished.size(0)):
b = batch_offset[i] b = batch_offset[i]
if end_condition[i]: if end_condition[i]:
# print("batch offset {} finished".format(b))
is_finished[i].fill_(1) is_finished[i].fill_(1)
finished_hyp = is_finished[i].nonzero().view(-1) finished_hyp = is_finished[i].nonzero().view(-1)
# Store finished hypotheses for this batch. # Store finished hypotheses for this batch.
@ -363,9 +273,6 @@ class Translator(object):
if end_condition[i]: if end_condition[i]:
best_hyp = sorted(hypotheses[b], key=lambda x: x[0], reverse=True) best_hyp = sorted(hypotheses[b], key=lambda x: x[0], reverse=True)
score, pred = best_hyp[0] score, pred = best_hyp[0]
# if len(pred) == 0:
# print("batch offset {} finished with empty prediction {}".format(b, pred))
results["scores"][b].append(score) results["scores"][b].append(score)
results["predictions"][b].append(pred) results["predictions"][b].append(pred)
non_finished = end_condition.eq(0).nonzero().view(-1) non_finished = end_condition.eq(0).nonzero().view(-1)
@ -383,63 +290,7 @@ class Translator(object):
dec_states.map_batch_fn(lambda state, dim: state.index_select(dim, select_indices)) dec_states.map_batch_fn(lambda state, dim: state.index_select(dim, select_indices))
empty_output = [len(results["predictions"][b]) <= 0 for b in batch_offset] empty_output = [len(results["predictions"][b]) <= 0 for b in batch_offset]
if any(empty_output): predictions = torch.tensor([i[0].tolist()[0:self.max_length]+[0]*(self.max_length-i[0].size()[0]) for i in results["predictions"]], device=device)
print("there is empty output {}".format(empty_output)) scores = torch.tensor([i[0].item() for i in results['scores']], device=device)
print(batch_offset) return predictions, scores
print(results)
# print("###########################################")
return results
class Translation(object):
"""
Container for a translated sentence.
Attributes:
src (`LongTensor`): src word ids
src_raw ([str]): raw src words
pred_sents ([[str]]): words from the n-best translations
pred_scores ([[float]]): log-probs of n-best translations
attns ([`FloatTensor`]) : attention dist for each translation
gold_sent ([str]): words from gold translation
gold_score ([float]): log-prob of gold translation
"""
def __init__(
self, fname, src, src_raw, pred_sents, attn, pred_scores, tgt_sent=None, gold_score=0
):
self.fname = fname
self.src = src
self.src_raw = src_raw
self.pred_sents = pred_sents
self.attns = attn
self.pred_scores = pred_scores
self.gold_sent = tgt_sent
self.gold_score = gold_score
def log(self, sent_number):
"""
Log translation.
"""
output = "\nSENT {}: {}\n".format(sent_number, self.src_raw)
best_pred = self.pred_sents[0]
best_score = self.pred_scores[0]
pred_sent = " ".join(best_pred)
output += "PRED {}: {}\n".format(sent_number, pred_sent)
output += "PRED SCORE: {:.4f}\n".format(best_score)
if self.gold_sent is not None:
tgt_sent = " ".join(self.gold_sent)
output += "GOLD {}: {}\n".format(sent_number, tgt_sent)
output += "GOLD SCORE: {:.4f}\n".format(self.gold_score)
if len(self.pred_sents) > 1:
output += "\nBEST HYP:\n"
for score, sent in zip(self.pred_scores, self.pred_sents):
output += "[{:.4f}] {}\n".format(score, sent)
return output

Просмотреть файл

@ -43,7 +43,7 @@ TOKENIZER_CLASS.update(
MAX_SEQ_LEN = 512 MAX_SEQ_LEN = 512
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
fh = logging.FileHandler("abssum_train.log") fh = logging.FileHandler("longer_input_abssum_train.log")
logger.addHandler(fh) logger.addHandler(fh)
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)