From cd8d8f00e94ef36fb553354a1dd88f04ffa30af5 Mon Sep 17 00:00:00 2001 From: Daisy Deng Date: Wed, 26 Feb 2020 03:33:44 +0000 Subject: [PATCH 1/4] add start token and end token for encoded target --- utils_nlp/models/transformers/abssum.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/utils_nlp/models/transformers/abssum.py b/utils_nlp/models/transformers/abssum.py index 663d595..d6e5279 100644 --- a/utils_nlp/models/transformers/abssum.py +++ b/utils_nlp/models/transformers/abssum.py @@ -113,7 +113,7 @@ class AbsSumProcessor: model_name="bert-base-uncased", to_lower=False, cache_dir=".", - max_len=512, + max_src_len=512, max_target_len=140, ): """ Initialize the preprocessor. @@ -144,14 +144,11 @@ class AbsSumProcessor: self.sep_token = "[SEP]" self.cls_token = "[CLS]" self.pad_token = "[PAD]" - self.tgt_bos = "[unused0]" - self.tgt_eos = "[unused1]" + self.tgt_bos = self.symbols["BOS"] + self.tgt_eos = self.symbols["EOS"] - self.sep_vid = self.tokenizer.vocab[self.sep_token] - self.cls_vid = self.tokenizer.vocab[self.cls_token] - self.pad_vid = self.tokenizer.vocab[self.pad_token] - self.max_len = max_len + self.max_src_len = max_src_len self.max_target_len = max_target_len @staticmethod @@ -243,7 +240,8 @@ class AbsSumProcessor: if train_mode: encoded_summaries = torch.tensor( [ - fit_to_block_size(summary, block_size, self.tokenizer.pad_token_id) + [self.tgt_bos] + fit_to_block_size(summary, block_size-2, self.tokenizer.pad_token_id) + +[self.tgt_eos] for _, summary in encoded_text ] ) @@ -308,7 +306,7 @@ class AbsSumProcessor: try: if len(line) <= 0: continue - story_lines_token_ids.append(self.tokenizer.encode(line, max_length=self.max_len)) + story_lines_token_ids.append(self.tokenizer.encode(line, max_length=self.max_src_len)) except: print(line) raise @@ -325,9 +323,9 @@ class AbsSumProcessor: except: print(line) raise - summary_token_ids = [ + summary_token_ids = [ token for sentence in summary_lines_token_ids for token in sentence - ] + ] return story_token_ids, summary_token_ids else: return story_token_ids From 9f24c6bde7af1725638cb78cb067d96f7abeb63c Mon Sep 17 00:00:00 2001 From: Daisy Deng Date: Tue, 3 Mar 2020 05:57:37 +0000 Subject: [PATCH 2/4] enable customizable maximum input length --- .../bertabs_cnndm_distributed_train.py | 14 ++++++++++++-- utils_nlp/models/transformers/abssum.py | 13 +++++++++---- .../models/transformers/bertabs/model_builder.py | 2 +- 3 files changed, 22 insertions(+), 7 deletions(-) diff --git a/examples/text_summarization/bertabs_cnndm_distributed_train.py b/examples/text_summarization/bertabs_cnndm_distributed_train.py index 64c7886..dfb41dd 100644 --- a/examples/text_summarization/bertabs_cnndm_distributed_train.py +++ b/examples/text_summarization/bertabs_cnndm_distributed_train.py @@ -25,6 +25,12 @@ from utils_nlp.models.transformers.datasets import SummarizationNonIterableDatas from utils_nlp.eval.evaluate_summarization import get_rouge os.environ["NCCL_IB_DISABLE"] = "0" +#os.environ["NCCL_DEBUG"] = "INFO" +os.environ["NCCL_DEBUG_SUBSYS"] = "ALL" +#os.environ["MASTER_PORT"] = "29952" +#os.environ["MASTER_ADDR"] = "172.12.0.6" +#os.environ['NCCL_SOCKET_IFNAME'] = 'lo' + def shorten_dataset(dataset, top_n=-1): if top_n == -1: @@ -56,6 +62,8 @@ parser.add_argument("--lr_dec", type=float, default=2e-1, help="Learning rate for the decoder.") parser.add_argument("--batch_size", type=int, default=5, help="batch size in terms of input token numbers in training") +parser.add_argument("--max_pos", type=int, default=640, + help="maximum input length in terms of input token numbers in training") parser.add_argument("--max_steps", type=int, default=5e4, help="Maximum number of training steps run in training. If quick_run is set,\ it's not used.") @@ -109,7 +117,8 @@ def main(): ngpus_per_node = torch.cuda.device_count() processor = AbsSumProcessor(cache_dir=args.cache_dir) summarizer = AbsSum( - processor, cache_dir=args.cache_dir + processor, cache_dir=args.cache_dir, + max_pos=args.max_pos, ) mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, summarizer, args)) @@ -120,13 +129,14 @@ def main_worker(local_rank, ngpus_per_node, summarizer, args): print("world_size is {}".format(world_size)) print("local_rank is {} and rank is {}".format(local_rank, rank)) - torch.distributed.init_process_group( backend="nccl", init_method=args.dist_url, world_size=world_size, rank=rank, ) + + # return ## should not load checkpoint from this place, otherwise, huge memory increase if args.checkpoint_filename: checkpoint = os.path.join(args.cache_dir, args.checkpoint_filename) diff --git a/utils_nlp/models/transformers/abssum.py b/utils_nlp/models/transformers/abssum.py index d6e5279..2eeac79 100644 --- a/utils_nlp/models/transformers/abssum.py +++ b/utils_nlp/models/transformers/abssum.py @@ -364,6 +364,7 @@ class AbsSum(Transformer): cache_dir=".", label_smoothing=0.1, test=False, + max_pos=640, ): """Initialize a ExtractiveSummarizer. @@ -402,12 +403,13 @@ class AbsSum(Transformer): checkpoint=None, label_smoothing=label_smoothing, symbols=processor.symbols, - test=test + test=test, + max_pos=max_pos, ) self.processor = processor self.optim_bert = None self.optim_dec = None - + self.max_pos = max_pos @staticmethod @@ -421,7 +423,7 @@ class AbsSum(Transformer): train_dataset, num_gpus=None, gpu_ids=None, - batch_size=140, + batch_size=4, local_rank=-1, max_steps=5e5, warmup_steps_bert=8000, @@ -496,6 +498,7 @@ class AbsSum(Transformer): self.optim_bert = model_builder.build_optim_bert( self.model, + optim=optimization_method, visible_gpus=None, #",".join([str(i) for i in range(num_gpus)]), #"0,1,2,3", lr_bert=learning_rate_bert, warmup_steps_bert=warmup_steps_bert, @@ -503,12 +506,14 @@ class AbsSum(Transformer): ) self.optim_dec = model_builder.build_optim_dec( self.model, + optim=optimization_method, visible_gpus=None, #",".join([str(i) for i in range(num_gpus)]), #"0,1,2,3" lr_dec=learning_rate_dec, warmup_steps_dec=warmup_steps_dec, ) optimizers = [self.optim_bert, self.optim_dec] + self.amp = get_amp(fp16) if self.amp: self.model, optim = self.amp.initialize(self.model, optimizers, opt_level=fp16_opt_level) @@ -539,7 +544,7 @@ class AbsSum(Transformer): sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank) def collate_fn(data): - return self.processor.collate(data, block_size=512, device=device) + return self.processor.collate(data, block_size=self.max_pos, device=device) train_dataloader = DataLoader( train_dataset, diff --git a/utils_nlp/models/transformers/bertabs/model_builder.py b/utils_nlp/models/transformers/bertabs/model_builder.py index 46437b3..cf9ca43 100644 --- a/utils_nlp/models/transformers/bertabs/model_builder.py +++ b/utils_nlp/models/transformers/bertabs/model_builder.py @@ -196,7 +196,7 @@ class AbsSummarizer(nn.Module): self.bert.model = BertModel(bert_config) if max_pos > 512: - my_pos_embeddi = nn.Embedding(max_pos, self.bert.model.config.hidden_size) + my_pos_embeddings = nn.Embedding(max_pos, self.bert.model.config.hidden_size) my_pos_embeddings.weight.data[ :512 ] = self.bert.model.embeddings.position_embeddings.weight.data From 0b42481d8b3a0221188bfe174514fa89ff073a38 Mon Sep 17 00:00:00 2001 From: Daisy Deng Date: Tue, 3 Mar 2020 05:58:13 +0000 Subject: [PATCH 3/4] enable multiple schedulers --- utils_nlp/models/transformers/common.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/utils_nlp/models/transformers/common.py b/utils_nlp/models/transformers/common.py index a957b09..ce2a7b7 100755 --- a/utils_nlp/models/transformers/common.py +++ b/utils_nlp/models/transformers/common.py @@ -43,7 +43,7 @@ TOKENIZER_CLASS.update( MAX_SEQ_LEN = 512 logger = logging.getLogger(__name__) -fh = logging.FileHandler("abssum_train.log") +fh = logging.FileHandler("longer_input_abssum_train.log") logger.addHandler(fh) logger.setLevel(logging.INFO) @@ -292,13 +292,18 @@ class Transformer: accum_loss = 0 train_size = 0 start = end - if type(optimizer) == list: - for o in optimizer: - o.step() - else: - optimizer.step() + if optimizer: + if type(optimizer) == list: + for o in optimizer: + o.step() + else: + optimizer.step() if scheduler: - scheduler.step() + if type(scheduler) == list: + for s in scheduler: + s.step() + else: + scheduler.step() self.model.zero_grad() if save_every != -1 and global_step % save_every == 0 and verbose: From b527bdfcdb6ee7fbb0b55875e2df510b92719978 Mon Sep 17 00:00:00 2001 From: Daisy Deng Date: Fri, 6 Mar 2020 04:26:02 +0000 Subject: [PATCH 4/4] enable multi-gpu inferencing for DataParallel and enable longer input length --- .../bertabs_cnndm_distributed_train.py | 5 +- tests/unit/test_bertabs_abssum.py | 16 +- utils_nlp/models/transformers/abssum.py | 47 +++-- .../models/transformers/bertabs/predictor.py | 185 ++---------------- 4 files changed, 60 insertions(+), 193 deletions(-) diff --git a/examples/text_summarization/bertabs_cnndm_distributed_train.py b/examples/text_summarization/bertabs_cnndm_distributed_train.py index dfb41dd..affc4a8 100644 --- a/examples/text_summarization/bertabs_cnndm_distributed_train.py +++ b/examples/text_summarization/bertabs_cnndm_distributed_train.py @@ -115,10 +115,9 @@ def main(): print("data_dir is {}".format(args.data_dir)) print("cache_dir is {}".format(args.cache_dir)) ngpus_per_node = torch.cuda.device_count() - processor = AbsSumProcessor(cache_dir=args.cache_dir) + processor = AbsSumProcessor(cache_dir=args.cache_dir, max_src_len=max_pos) summarizer = AbsSum( - processor, cache_dir=args.cache_dir, - max_pos=args.max_pos, + processor, cache_dir=args.cache_dir ) mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, summarizer, args)) diff --git a/tests/unit/test_bertabs_abssum.py b/tests/unit/test_bertabs_abssum.py index fe8c4ad..d8436cf 100644 --- a/tests/unit/test_bertabs_abssum.py +++ b/tests/unit/test_bertabs_abssum.py @@ -25,7 +25,7 @@ from utils_nlp.models.transformers.datasets import SummarizationNonIterableDatas from utils_nlp.eval.evaluate_summarization import get_rouge CACHE_PATH = "/dadendev/nlp-recipes/examples/text_summarization/abstemp" -DATA_PATH = "/dadendev/nlp-recipes/examples/text_summarization" +DATA_PATH = "/dadendev/nlp-recipes/examples/text_summarization/abstemp" MODEL_PATH = "/dadendev/nlp-recipes/examples/text_summarization/abstemp" TOP_N = 10 @@ -259,10 +259,11 @@ def test_pretrained_model(): checkpoint = torch.load(os.path.join(MODEL_PATH, "new_model_step_148000_torch1.4.0.pt")) #checkpoint = torch.load(os.path.join(MODEL_PATH, "summarizer_step20000_with_global_step.pt")) - checkpoint = torch.load(os.path.join(MODEL_PATH, "bert-base-uncased_step_400.pt")) + #checkpoint = torch.load(os.path.join(MODEL_PATH, "bert-base-uncased_step_400.pt")) summarizer = AbsSum( processor, cache_dir=CACHE_PATH, + max_pos=512, ) summarizer.model.load_checkpoint(checkpoint['model']) """ @@ -284,14 +285,15 @@ def test_pretrained_model(): return """ - top_n = 8 + top_n = 96*4 src = test_sum_dataset.source[0:top_n] reference_summaries = ["".join(t).rstrip("\n") for t in test_sum_dataset.target[0:top_n]] print("start prediction") generated_summaries = summarizer.predict( - shorten_dataset(test_sum_dataset, top_n=top_n), batch_size=4, num_gpus=2 + shorten_dataset(test_sum_dataset, top_n=top_n), batch_size=96+16, num_gpus=1, max_seq_length=512 ) print(generated_summaries[0]) + print(len(generated_summaries)) assert len(generated_summaries) == len(reference_summaries) RESULT_DIR = TemporaryDirectory().name rouge_score = get_rouge(generated_summaries, reference_summaries, RESULT_DIR) @@ -303,6 +305,6 @@ def test_pretrained_model(): #test_collate() #preprocess_cnndm_abs() #test_train_model() -#test_pretrained_model() -if __name__ == "__main__": - main() +test_pretrained_model() +#if __name__ == "__main__": +# main() diff --git a/utils_nlp/models/transformers/abssum.py b/utils_nlp/models/transformers/abssum.py index 2eeac79..84db75d 100644 --- a/utils_nlp/models/transformers/abssum.py +++ b/utils_nlp/models/transformers/abssum.py @@ -113,7 +113,7 @@ class AbsSumProcessor: model_name="bert-base-uncased", to_lower=False, cache_dir=".", - max_src_len=512, + max_src_len=640, max_target_len=140, ): """ Initialize the preprocessor. @@ -364,7 +364,7 @@ class AbsSum(Transformer): cache_dir=".", label_smoothing=0.1, test=False, - max_pos=640, + max_pos=768, ): """Initialize a ExtractiveSummarizer. @@ -396,6 +396,7 @@ class AbsSum(Transformer): self.model_class = MODEL_CLASS[model_name] self.cache_dir = cache_dir + self.max_pos = max_pos self.model = AbsSummarizer( temp_dir=cache_dir, @@ -404,12 +405,11 @@ class AbsSum(Transformer): label_smoothing=label_smoothing, symbols=processor.symbols, test=test, - max_pos=max_pos, + max_pos=self.max_pos, ) self.processor = processor self.optim_bert = None self.optim_dec = None - self.max_pos = max_pos @staticmethod @@ -588,13 +588,13 @@ class AbsSum(Transformer): local_rank=-1, gpu_ids=None, batch_size=16, - # sentence_separator="", alpha=0.6, beam_size=5, min_length=15, max_length=150, fp16=False, verbose=True, + max_seq_length=768, ): """ Predict the summarization for the input data iterator. @@ -621,11 +621,11 @@ class AbsSum(Transformer): List of strings which are the summaries """ - device, num_gpus = get_device(num_gpus=num_gpus, local_rank=local_rank) - # move model to devices + device, num_gpus = get_device(num_gpus=num_gpus, gpu_ids=gpu_ids, local_rank=local_rank) + # move model to devices def this_model_move_callback(model, device): model = move_model_to_device(model, device) - return parallelize_model(model, device, num_gpus=num_gpus, gpu_ids=None, local_rank=local_rank) + return parallelize_model(model, device, num_gpus=num_gpus, gpu_ids=gpu_ids, local_rank=local_rank) if fp16: self.model = self.model.half() @@ -643,13 +643,13 @@ class AbsSum(Transformer): min_length=min_length, max_length=max_length, ) - predictor = predictor.move_to_device(device, this_model_move_callback) + predictor = this_model_move_callback(predictor, device) test_sampler = SequentialSampler(test_dataset) def collate_fn(data): - return self.processor.collate(data, 512, device, train_mode=False) + return self.processor.collate(data, max_seq_length, device, train_mode=False) test_dataloader = DataLoader( test_dataset, sampler=test_sampler, batch_size=batch_size, collate_fn=collate_fn, @@ -659,7 +659,7 @@ class AbsSum(Transformer): """ Transforms the output of the `from_batch` function into nicely formatted summaries. """ - raw_summary, _, = translation + raw_summary = translation summary = ( raw_summary.replace("[unused0]", "") .replace("[unused3]", "") @@ -673,16 +673,31 @@ class AbsSum(Transformer): .strip() ) + return summary + def generate_summary_from_tokenid(preds, pred_score): + batch_size = preds.size()[0] # batch.batch_size + translations = [] + for b in range(batch_size): + if len(preds[b]) < 1: + pred_sents = "" + else: + pred_sents = self.processor.tokenizer.convert_ids_to_tokens([int(n) for n in preds[b] if int(n)!=0]) + pred_sents = " ".join(pred_sents).replace(" ##", "") + translations.append(pred_sents) + return translations + generated_summaries = [] from tqdm import tqdm - for batch in tqdm(test_dataloader): - batch_data = predictor.translate_batch(batch) - translations = predictor.from_batch(batch_data) - summaries = [format_summary(t) for t in translations] - generated_summaries += summaries + for batch in tqdm(test_dataloader, desc="Generating summary", disable=not verbose): + input = self.processor.get_inputs(batch, device, "bert", train_mode=False) + translations, scores = predictor(**input) + + translations_text = generate_summary_from_tokenid(translations, scores) + summaries = [format_summary(t) for t in translations_text] + generated_summaries.extend(summaries) return generated_summaries def save_model(self, global_step=None, full_name=None): diff --git a/utils_nlp/models/transformers/bertabs/predictor.py b/utils_nlp/models/transformers/bertabs/predictor.py index 841cdc0..990fa6a 100644 --- a/utils_nlp/models/transformers/bertabs/predictor.py +++ b/utils_nlp/models/transformers/bertabs/predictor.py @@ -6,7 +6,7 @@ import os import math import torch - +from torch import nn from tensorboardX import SummaryWriter # from others.utils import rouge_results_to_str, test_rouge, tile @@ -55,7 +55,7 @@ def tile(x, count, dim=0): return x -class Translator(object): +class Translator(nn.Module): """ Uses a model to translate a batch of sentences. @@ -70,14 +70,12 @@ class Translator(object): global_scores (:obj:`GlobalScorer`): object to rescore final translations copy_attn (bool): use copy attention during translation - cuda (bool): use cuda beam_trace (bool): trace beam search for debugging logger(logging.Logger): logger. """ def __init__( self, - # args, beam_size, min_length, max_length, @@ -89,10 +87,9 @@ class Translator(object): logger=None, dump_beam="", ): + super(Translator, self).__init__() self.logger = logger - self.cuda = 0 # args.visible_gpus != '-1' - # self.args = args self.model = model.module if hasattr(model, "module") else model self.generator = self.model.generator self.decoder = self.model.decoder @@ -115,10 +112,6 @@ class Translator(object): self.beam_trace = self.dump_beam != "" self.beam_accum = None - # tensorboard_log_dir = args.model_path - - # self.tensorboard_writer = SummaryWriter(tensorboard_log_dir, comment="Unmt") - if self.beam_trace: self.beam_accum = { "predicted_ids": [], @@ -126,90 +119,15 @@ class Translator(object): "scores": [], "log_probs": [], } - - def move_to_device(self, device, move_to_device_fn): - self.move_to_device_fn = move_to_device_fn - self.model = move_to_device_fn(self.model, device) - self.bert = move_to_device_fn(self.bert, device) - self.generator = move_to_device_fn(self.generator, device) - return self - + """ def eval(self): self.model.eval() self.bert.eval() self.decoder.eval() self.generator.eval() + """ - def _build_target_tokens(self, pred): - # vocab = self.fields["tgt"].vocab - tokens = [] - for tok in pred: - tok = int(tok) - tokens.append(tok) - if tokens[-1] == self.end_token: - tokens = tokens[:-1] - break - tokens = [t for t in tokens if t < len(self.vocab)] - tokens = self.vocab.DecodeIds(tokens).split(" ") - return tokens - - def from_batch(self, translation_batch): - batch = translation_batch["batch"] - # assert len(translation_batch["gold_score"]) == len(translation_batch["predictions"]) - batch_size = batch.batch_size - - # preds, pred_score, gold_score, tgt_str, src = ( - preds, pred_score, src = ( - translation_batch["predictions"], - translation_batch["scores"], - # translation_batch["gold_score"], - # batch.tgt_str, - batch.src, - ) - # print(preds) - # print(pred_score) - # print(batch.tgt_str) - # print(batch.src) - - translations = [] - for b in range(batch_size): - - if len(preds[b]) < 1: - pred_sents = "" - else: - pred_sents = self.vocab.convert_ids_to_tokens([int(n) for n in preds[b][0]]) - pred_sents = " ".join(pred_sents).replace(" ##", "") - # gold_sent = " ".join(tgt_str[b].split()) - # translation = Translation(fname[b],src[:, b] if src is not None else None, - # src_raw, pred_sents, - # attn[b], pred_score[b], gold_sent, - # gold_score[b]) - # src = self.spm.DecodeIds([int(t) for t in translation_batch['batch'].src[0][5] if int(t) != len(self.spm)]) - raw_src = [self.vocab.ids_to_tokens[int(t)] for t in src[b]][:500] - raw_src = " ".join(raw_src) - # translation = (pred_sents, gold_sent, raw_src) - translation = (pred_sents, raw_src) - # translation = (pred_sents[0], gold_sent) - translations.append(translation) - - return translations - - def translate(self, batch, attn_debug=False): - - #self.model.eval() - self.eval() - # pred_results, gold_results = [], [] - with torch.no_grad(): - batch_data = self.translate_batch(batch) - translations = self.from_batch(batch_data) - - # for trans in translations: - # pred, gold, src = trans - # pred_str = pred.replace('[unused0]', '').replace('[unused3]', '').replace('[PAD]', '').replace('[unused1]', '').replace(r' +', ' ').replace(' [unused2] ', '').replace('[unused2]', '').strip() - # gold_str = gold.strip() - return translations - - def translate_batch(self, batch, fast=False): + def forward(self, src, segs, mask_src): """ Translate a batch of sentences. @@ -224,22 +142,22 @@ class Translator(object): Shouldn't need the original dataset. """ with torch.no_grad(): - return self._fast_translate_batch(batch, self.max_length, min_length=self.min_length) + predictions, scores = self._fast_translate_batch(src, segs, mask_src, self.max_length, min_length=self.min_length) + return predictions, scores + - def _fast_translate_batch(self, batch, max_length, min_length=0): + def _fast_translate_batch(self, src, segs, mask_src, max_length, min_length=0): # TODO: faster code path for beam_size == 1. # TODO: support these blacklisted features. assert not self.dump_beam beam_size = self.beam_size - batch_size = batch.batch_size - src = batch.src - segs = batch.segs - mask_src = batch.mask_src + batch_size = src.size()[0] #32 #batch.batch_size src_features = self.bert(src, segs, mask_src) - dec_states = self.decoder.init_decoder_state(src, src_features, with_cache=True) + this_decoder = self.decoder.module if hasattr(self.decoder, "module") else self.decoder + dec_states = this_decoder.init_decoder_state(src, src_features, with_cache=True) device = src_features.device @@ -266,7 +184,7 @@ class Translator(object): results["predictions"] = [[] for _ in range(batch_size)] # noqa: F812 results["scores"] = [[] for _ in range(batch_size)] # noqa: F812 # results["gold_score"] = [0] * batch_size - results["batch"] = batch + #results["batch"] = batch for step in range(max_length): decoder_input = alive_seq[:, -1].view(1, -1) @@ -274,7 +192,7 @@ class Translator(object): # Decoder forward. decoder_input = decoder_input.transpose(0, 1) - dec_out, dec_states = self.decoder( + dec_out, dec_states = this_decoder( decoder_input, src_features, dec_states, step=step ) @@ -321,13 +239,11 @@ class Translator(object): # Resolve beam origin and true word ids. topk_beam_index = topk_ids.div(vocab_size) - # print("topk_beam_index.shape {}".format( topk_beam_index.size())) topk_ids = topk_ids.fmod(vocab_size) # Map beam_index to batch_index in the flat representation. batch_index = topk_beam_index + beam_offset[: topk_beam_index.size(0)].unsqueeze(1) select_indices = batch_index.view(-1) - # print("select_indices {}".format(select_indices)) # Append last prediction. alive_seq = torch.cat( @@ -335,15 +251,10 @@ class Translator(object): ) is_finished = topk_ids.eq(self.end_token) - # print("is_finished {}".format(is_finished)) - # print("is_finished size {}".format(is_finished.size())) if step + 1 == max_length: - # print("reached max_length {} at step {}".format(max_length, step)) is_finished.fill_(True) - # print("is_finished {}".format(is_finished)) # End condition is top beam is finished. end_condition = is_finished[:, 0].eq(True) - # print("end_condition {}".format(end_condition)) if step + 1 == max_length: assert not any(end_condition.eq(False)) @@ -353,7 +264,6 @@ class Translator(object): for i in range(is_finished.size(0)): b = batch_offset[i] if end_condition[i]: - # print("batch offset {} finished".format(b)) is_finished[i].fill_(1) finished_hyp = is_finished[i].nonzero().view(-1) # Store finished hypotheses for this batch. @@ -363,9 +273,6 @@ class Translator(object): if end_condition[i]: best_hyp = sorted(hypotheses[b], key=lambda x: x[0], reverse=True) score, pred = best_hyp[0] - # if len(pred) == 0: - # print("batch offset {} finished with empty prediction {}".format(b, pred)) - results["scores"][b].append(score) results["predictions"][b].append(pred) non_finished = end_condition.eq(0).nonzero().view(-1) @@ -383,63 +290,7 @@ class Translator(object): dec_states.map_batch_fn(lambda state, dim: state.index_select(dim, select_indices)) empty_output = [len(results["predictions"][b]) <= 0 for b in batch_offset] - if any(empty_output): - print("there is empty output {}".format(empty_output)) - print(batch_offset) - print(results) + predictions = torch.tensor([i[0].tolist()[0:self.max_length]+[0]*(self.max_length-i[0].size()[0]) for i in results["predictions"]], device=device) + scores = torch.tensor([i[0].item() for i in results['scores']], device=device) + return predictions, scores - # print("###########################################") - return results - - -class Translation(object): - """ - Container for a translated sentence. - - Attributes: - src (`LongTensor`): src word ids - src_raw ([str]): raw src words - - pred_sents ([[str]]): words from the n-best translations - pred_scores ([[float]]): log-probs of n-best translations - attns ([`FloatTensor`]) : attention dist for each translation - gold_sent ([str]): words from gold translation - gold_score ([float]): log-prob of gold translation - - """ - - def __init__( - self, fname, src, src_raw, pred_sents, attn, pred_scores, tgt_sent=None, gold_score=0 - ): - self.fname = fname - self.src = src - self.src_raw = src_raw - self.pred_sents = pred_sents - self.attns = attn - self.pred_scores = pred_scores - self.gold_sent = tgt_sent - self.gold_score = gold_score - - def log(self, sent_number): - """ - Log translation. - """ - - output = "\nSENT {}: {}\n".format(sent_number, self.src_raw) - - best_pred = self.pred_sents[0] - best_score = self.pred_scores[0] - pred_sent = " ".join(best_pred) - output += "PRED {}: {}\n".format(sent_number, pred_sent) - output += "PRED SCORE: {:.4f}\n".format(best_score) - - if self.gold_sent is not None: - tgt_sent = " ".join(self.gold_sent) - output += "GOLD {}: {}\n".format(sent_number, tgt_sent) - output += "GOLD SCORE: {:.4f}\n".format(self.gold_score) - if len(self.pred_sents) > 1: - output += "\nBEST HYP:\n" - for score, sent in zip(self.pred_scores, self.pred_sents): - output += "[{:.4f}] {}\n".format(score, sent) - - return output