Merge branch 'daden/presumm' of https://github.com/microsoft/nlp-recipes into daden/presumm
This commit is contained in:
Коммит
fd84b4b32f
|
@ -25,6 +25,12 @@ from utils_nlp.models.transformers.datasets import SummarizationNonIterableDatas
|
|||
from utils_nlp.eval.evaluate_summarization import get_rouge
|
||||
|
||||
os.environ["NCCL_IB_DISABLE"] = "0"
|
||||
#os.environ["NCCL_DEBUG"] = "INFO"
|
||||
os.environ["NCCL_DEBUG_SUBSYS"] = "ALL"
|
||||
#os.environ["MASTER_PORT"] = "29952"
|
||||
#os.environ["MASTER_ADDR"] = "172.12.0.6"
|
||||
#os.environ['NCCL_SOCKET_IFNAME'] = 'lo'
|
||||
|
||||
|
||||
def shorten_dataset(dataset, top_n=-1):
|
||||
if top_n == -1:
|
||||
|
@ -56,6 +62,8 @@ parser.add_argument("--lr_dec", type=float, default=2e-1,
|
|||
help="Learning rate for the decoder.")
|
||||
parser.add_argument("--batch_size", type=int, default=5,
|
||||
help="batch size in terms of input token numbers in training")
|
||||
parser.add_argument("--max_pos", type=int, default=640,
|
||||
help="maximum input length in terms of input token numbers in training")
|
||||
parser.add_argument("--max_steps", type=int, default=5e4,
|
||||
help="Maximum number of training steps run in training. If quick_run is set,\
|
||||
it's not used.")
|
||||
|
@ -107,7 +115,7 @@ def main():
|
|||
print("data_dir is {}".format(args.data_dir))
|
||||
print("cache_dir is {}".format(args.cache_dir))
|
||||
ngpus_per_node = torch.cuda.device_count()
|
||||
processor = AbsSumProcessor(cache_dir=args.cache_dir)
|
||||
processor = AbsSumProcessor(cache_dir=args.cache_dir, max_src_len=max_pos)
|
||||
summarizer = AbsSum(
|
||||
processor, cache_dir=args.cache_dir
|
||||
)
|
||||
|
@ -120,13 +128,14 @@ def main_worker(local_rank, ngpus_per_node, summarizer, args):
|
|||
print("world_size is {}".format(world_size))
|
||||
print("local_rank is {} and rank is {}".format(local_rank, rank))
|
||||
|
||||
|
||||
torch.distributed.init_process_group(
|
||||
backend="nccl",
|
||||
init_method=args.dist_url,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
# return
|
||||
## should not load checkpoint from this place, otherwise, huge memory increase
|
||||
if args.checkpoint_filename:
|
||||
checkpoint = os.path.join(args.cache_dir, args.checkpoint_filename)
|
||||
|
|
|
@ -25,7 +25,7 @@ from utils_nlp.models.transformers.datasets import SummarizationNonIterableDatas
|
|||
from utils_nlp.eval.evaluate_summarization import get_rouge
|
||||
|
||||
CACHE_PATH = "/dadendev/nlp-recipes/examples/text_summarization/abstemp"
|
||||
DATA_PATH = "/dadendev/nlp-recipes/examples/text_summarization"
|
||||
DATA_PATH = "/dadendev/nlp-recipes/examples/text_summarization/abstemp"
|
||||
MODEL_PATH = "/dadendev/nlp-recipes/examples/text_summarization/abstemp"
|
||||
TOP_N = 10
|
||||
|
||||
|
@ -259,10 +259,11 @@ def test_pretrained_model():
|
|||
checkpoint = torch.load(os.path.join(MODEL_PATH, "new_model_step_148000_torch1.4.0.pt"))
|
||||
|
||||
#checkpoint = torch.load(os.path.join(MODEL_PATH, "summarizer_step20000_with_global_step.pt"))
|
||||
checkpoint = torch.load(os.path.join(MODEL_PATH, "bert-base-uncased_step_400.pt"))
|
||||
#checkpoint = torch.load(os.path.join(MODEL_PATH, "bert-base-uncased_step_400.pt"))
|
||||
summarizer = AbsSum(
|
||||
processor,
|
||||
cache_dir=CACHE_PATH,
|
||||
max_pos=512,
|
||||
)
|
||||
summarizer.model.load_checkpoint(checkpoint['model'])
|
||||
"""
|
||||
|
@ -284,14 +285,15 @@ def test_pretrained_model():
|
|||
return
|
||||
"""
|
||||
|
||||
top_n = 8
|
||||
top_n = 96*4
|
||||
src = test_sum_dataset.source[0:top_n]
|
||||
reference_summaries = ["".join(t).rstrip("\n") for t in test_sum_dataset.target[0:top_n]]
|
||||
print("start prediction")
|
||||
generated_summaries = summarizer.predict(
|
||||
shorten_dataset(test_sum_dataset, top_n=top_n), batch_size=4, num_gpus=2
|
||||
shorten_dataset(test_sum_dataset, top_n=top_n), batch_size=96+16, num_gpus=1, max_seq_length=512
|
||||
)
|
||||
print(generated_summaries[0])
|
||||
print(len(generated_summaries))
|
||||
assert len(generated_summaries) == len(reference_summaries)
|
||||
RESULT_DIR = TemporaryDirectory().name
|
||||
rouge_score = get_rouge(generated_summaries, reference_summaries, RESULT_DIR)
|
||||
|
@ -303,6 +305,6 @@ def test_pretrained_model():
|
|||
#test_collate()
|
||||
#preprocess_cnndm_abs()
|
||||
#test_train_model()
|
||||
#test_pretrained_model()
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
test_pretrained_model()
|
||||
#if __name__ == "__main__":
|
||||
# main()
|
||||
|
|
|
@ -113,7 +113,7 @@ class AbsSumProcessor:
|
|||
model_name="bert-base-uncased",
|
||||
to_lower=False,
|
||||
cache_dir=".",
|
||||
max_len=512,
|
||||
max_src_len=640,
|
||||
max_target_len=140,
|
||||
):
|
||||
""" Initialize the preprocessor.
|
||||
|
@ -144,14 +144,11 @@ class AbsSumProcessor:
|
|||
self.sep_token = "[SEP]"
|
||||
self.cls_token = "[CLS]"
|
||||
self.pad_token = "[PAD]"
|
||||
self.tgt_bos = "[unused0]"
|
||||
self.tgt_eos = "[unused1]"
|
||||
self.tgt_bos = self.symbols["BOS"]
|
||||
self.tgt_eos = self.symbols["EOS"]
|
||||
|
||||
self.sep_vid = self.tokenizer.vocab[self.sep_token]
|
||||
self.cls_vid = self.tokenizer.vocab[self.cls_token]
|
||||
self.pad_vid = self.tokenizer.vocab[self.pad_token]
|
||||
|
||||
self.max_len = max_len
|
||||
self.max_src_len = max_src_len
|
||||
self.max_target_len = max_target_len
|
||||
|
||||
@staticmethod
|
||||
|
@ -243,7 +240,8 @@ class AbsSumProcessor:
|
|||
if train_mode:
|
||||
encoded_summaries = torch.tensor(
|
||||
[
|
||||
fit_to_block_size(summary, block_size, self.tokenizer.pad_token_id)
|
||||
[self.tgt_bos] + fit_to_block_size(summary, block_size-2, self.tokenizer.pad_token_id)
|
||||
+[self.tgt_eos]
|
||||
for _, summary in encoded_text
|
||||
]
|
||||
)
|
||||
|
@ -308,7 +306,7 @@ class AbsSumProcessor:
|
|||
try:
|
||||
if len(line) <= 0:
|
||||
continue
|
||||
story_lines_token_ids.append(self.tokenizer.encode(line, max_length=self.max_len))
|
||||
story_lines_token_ids.append(self.tokenizer.encode(line, max_length=self.max_src_len))
|
||||
except:
|
||||
print(line)
|
||||
raise
|
||||
|
@ -366,6 +364,7 @@ class AbsSum(Transformer):
|
|||
cache_dir=".",
|
||||
label_smoothing=0.1,
|
||||
test=False,
|
||||
max_pos=768,
|
||||
):
|
||||
"""Initialize a ExtractiveSummarizer.
|
||||
|
||||
|
@ -397,6 +396,7 @@ class AbsSum(Transformer):
|
|||
|
||||
self.model_class = MODEL_CLASS[model_name]
|
||||
self.cache_dir = cache_dir
|
||||
self.max_pos = max_pos
|
||||
|
||||
self.model = AbsSummarizer(
|
||||
temp_dir=cache_dir,
|
||||
|
@ -404,14 +404,14 @@ class AbsSum(Transformer):
|
|||
checkpoint=None,
|
||||
label_smoothing=label_smoothing,
|
||||
symbols=processor.symbols,
|
||||
test=test
|
||||
test=test,
|
||||
max_pos=self.max_pos,
|
||||
)
|
||||
self.processor = processor
|
||||
self.optim_bert = None
|
||||
self.optim_dec = None
|
||||
|
||||
|
||||
|
||||
@staticmethod
|
||||
def list_supported_models():
|
||||
return list(MODEL_CLASS.keys())
|
||||
|
@ -423,7 +423,7 @@ class AbsSum(Transformer):
|
|||
train_dataset,
|
||||
num_gpus=None,
|
||||
gpu_ids=None,
|
||||
batch_size=140,
|
||||
batch_size=4,
|
||||
local_rank=-1,
|
||||
max_steps=5e5,
|
||||
warmup_steps_bert=8000,
|
||||
|
@ -499,6 +499,7 @@ class AbsSum(Transformer):
|
|||
#"""
|
||||
self.optim_bert = model_builder.build_optim_bert(
|
||||
self.model,
|
||||
optim=optimization_method,
|
||||
visible_gpus=None, #",".join([str(i) for i in range(num_gpus)]), #"0,1,2,3",
|
||||
lr_bert=learning_rate_bert,
|
||||
warmup_steps_bert=warmup_steps_bert,
|
||||
|
@ -506,6 +507,7 @@ class AbsSum(Transformer):
|
|||
)
|
||||
self.optim_dec = model_builder.build_optim_dec(
|
||||
self.model,
|
||||
optim=optimization_method,
|
||||
visible_gpus=None, #",".join([str(i) for i in range(num_gpus)]), #"0,1,2,3"
|
||||
lr_dec=learning_rate_dec,
|
||||
warmup_steps_dec=warmup_steps_dec,
|
||||
|
@ -518,6 +520,7 @@ class AbsSum(Transformer):
|
|||
optimizers = [self.optim_bert, self.optim_dec]
|
||||
schedulers = [self.scheduler_bert, self.scheduler_dec]
|
||||
|
||||
|
||||
self.amp = get_amp(fp16)
|
||||
if self.amp:
|
||||
self.model, optim = self.amp.initialize(self.model, optimizers, opt_level=fp16_opt_level)
|
||||
|
@ -548,7 +551,7 @@ class AbsSum(Transformer):
|
|||
sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
|
||||
|
||||
def collate_fn(data):
|
||||
return self.processor.collate(data, block_size=512, device=device)
|
||||
return self.processor.collate(data, block_size=self.max_pos, device=device)
|
||||
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
|
@ -592,13 +595,13 @@ class AbsSum(Transformer):
|
|||
local_rank=-1,
|
||||
gpu_ids=None,
|
||||
batch_size=16,
|
||||
# sentence_separator="<q>",
|
||||
alpha=0.6,
|
||||
beam_size=5,
|
||||
min_length=15,
|
||||
max_length=150,
|
||||
fp16=False,
|
||||
verbose=True,
|
||||
max_seq_length=768,
|
||||
):
|
||||
"""
|
||||
Predict the summarization for the input data iterator.
|
||||
|
@ -625,11 +628,11 @@ class AbsSum(Transformer):
|
|||
List of strings which are the summaries
|
||||
|
||||
"""
|
||||
device, num_gpus = get_device(num_gpus=num_gpus, local_rank=local_rank)
|
||||
device, num_gpus = get_device(num_gpus=num_gpus, gpu_ids=gpu_ids, local_rank=local_rank)
|
||||
# move model to devices
|
||||
def this_model_move_callback(model, device):
|
||||
model = move_model_to_device(model, device)
|
||||
return parallelize_model(model, device, num_gpus=num_gpus, gpu_ids=None, local_rank=local_rank)
|
||||
return parallelize_model(model, device, num_gpus=num_gpus, gpu_ids=gpu_ids, local_rank=local_rank)
|
||||
|
||||
if fp16:
|
||||
self.model = self.model.half()
|
||||
|
@ -647,13 +650,13 @@ class AbsSum(Transformer):
|
|||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
)
|
||||
predictor = predictor.move_to_device(device, this_model_move_callback)
|
||||
predictor = this_model_move_callback(predictor, device)
|
||||
|
||||
|
||||
test_sampler = SequentialSampler(test_dataset)
|
||||
|
||||
def collate_fn(data):
|
||||
return self.processor.collate(data, 512, device, train_mode=False)
|
||||
return self.processor.collate(data, max_seq_length, device, train_mode=False)
|
||||
|
||||
test_dataloader = DataLoader(
|
||||
test_dataset, sampler=test_sampler, batch_size=batch_size, collate_fn=collate_fn,
|
||||
|
@ -663,7 +666,7 @@ class AbsSum(Transformer):
|
|||
""" Transforms the output of the `from_batch` function
|
||||
into nicely formatted summaries.
|
||||
"""
|
||||
raw_summary, _, = translation
|
||||
raw_summary = translation
|
||||
summary = (
|
||||
raw_summary.replace("[unused0]", "")
|
||||
.replace("[unused3]", "")
|
||||
|
@ -677,16 +680,31 @@ class AbsSum(Transformer):
|
|||
.strip()
|
||||
)
|
||||
|
||||
|
||||
return summary
|
||||
|
||||
def generate_summary_from_tokenid(preds, pred_score):
|
||||
batch_size = preds.size()[0] # batch.batch_size
|
||||
translations = []
|
||||
for b in range(batch_size):
|
||||
if len(preds[b]) < 1:
|
||||
pred_sents = ""
|
||||
else:
|
||||
pred_sents = self.processor.tokenizer.convert_ids_to_tokens([int(n) for n in preds[b] if int(n)!=0])
|
||||
pred_sents = " ".join(pred_sents).replace(" ##", "")
|
||||
translations.append(pred_sents)
|
||||
return translations
|
||||
|
||||
generated_summaries = []
|
||||
from tqdm import tqdm
|
||||
|
||||
for batch in tqdm(test_dataloader):
|
||||
batch_data = predictor.translate_batch(batch)
|
||||
translations = predictor.from_batch(batch_data)
|
||||
summaries = [format_summary(t) for t in translations]
|
||||
generated_summaries += summaries
|
||||
for batch in tqdm(test_dataloader, desc="Generating summary", disable=not verbose):
|
||||
input = self.processor.get_inputs(batch, device, "bert", train_mode=False)
|
||||
translations, scores = predictor(**input)
|
||||
|
||||
translations_text = generate_summary_from_tokenid(translations, scores)
|
||||
summaries = [format_summary(t) for t in translations_text]
|
||||
generated_summaries.extend(summaries)
|
||||
return generated_summaries
|
||||
|
||||
def save_model(self, global_step=None, full_name=None):
|
||||
|
|
|
@ -196,7 +196,7 @@ class AbsSummarizer(nn.Module):
|
|||
self.bert.model = BertModel(bert_config)
|
||||
|
||||
if max_pos > 512:
|
||||
my_pos_embeddi = nn.Embedding(max_pos, self.bert.model.config.hidden_size)
|
||||
my_pos_embeddings = nn.Embedding(max_pos, self.bert.model.config.hidden_size)
|
||||
my_pos_embeddings.weight.data[
|
||||
:512
|
||||
] = self.bert.model.embeddings.position_embeddings.weight.data
|
||||
|
|
|
@ -6,7 +6,7 @@ import os
|
|||
import math
|
||||
|
||||
import torch
|
||||
|
||||
from torch import nn
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
# from others.utils import rouge_results_to_str, test_rouge, tile
|
||||
|
@ -55,7 +55,7 @@ def tile(x, count, dim=0):
|
|||
return x
|
||||
|
||||
|
||||
class Translator(object):
|
||||
class Translator(nn.Module):
|
||||
"""
|
||||
Uses a model to translate a batch of sentences.
|
||||
|
||||
|
@ -70,14 +70,12 @@ class Translator(object):
|
|||
global_scores (:obj:`GlobalScorer`):
|
||||
object to rescore final translations
|
||||
copy_attn (bool): use copy attention during translation
|
||||
cuda (bool): use cuda
|
||||
beam_trace (bool): trace beam search for debugging
|
||||
logger(logging.Logger): logger.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
# args,
|
||||
beam_size,
|
||||
min_length,
|
||||
max_length,
|
||||
|
@ -89,10 +87,9 @@ class Translator(object):
|
|||
logger=None,
|
||||
dump_beam="",
|
||||
):
|
||||
super(Translator, self).__init__()
|
||||
self.logger = logger
|
||||
self.cuda = 0 # args.visible_gpus != '-1'
|
||||
|
||||
# self.args = args
|
||||
self.model = model.module if hasattr(model, "module") else model
|
||||
self.generator = self.model.generator
|
||||
self.decoder = self.model.decoder
|
||||
|
@ -115,10 +112,6 @@ class Translator(object):
|
|||
self.beam_trace = self.dump_beam != ""
|
||||
self.beam_accum = None
|
||||
|
||||
# tensorboard_log_dir = args.model_path
|
||||
|
||||
# self.tensorboard_writer = SummaryWriter(tensorboard_log_dir, comment="Unmt")
|
||||
|
||||
if self.beam_trace:
|
||||
self.beam_accum = {
|
||||
"predicted_ids": [],
|
||||
|
@ -126,90 +119,15 @@ class Translator(object):
|
|||
"scores": [],
|
||||
"log_probs": [],
|
||||
}
|
||||
|
||||
def move_to_device(self, device, move_to_device_fn):
|
||||
self.move_to_device_fn = move_to_device_fn
|
||||
self.model = move_to_device_fn(self.model, device)
|
||||
self.bert = move_to_device_fn(self.bert, device)
|
||||
self.generator = move_to_device_fn(self.generator, device)
|
||||
return self
|
||||
|
||||
"""
|
||||
def eval(self):
|
||||
self.model.eval()
|
||||
self.bert.eval()
|
||||
self.decoder.eval()
|
||||
self.generator.eval()
|
||||
"""
|
||||
|
||||
def _build_target_tokens(self, pred):
|
||||
# vocab = self.fields["tgt"].vocab
|
||||
tokens = []
|
||||
for tok in pred:
|
||||
tok = int(tok)
|
||||
tokens.append(tok)
|
||||
if tokens[-1] == self.end_token:
|
||||
tokens = tokens[:-1]
|
||||
break
|
||||
tokens = [t for t in tokens if t < len(self.vocab)]
|
||||
tokens = self.vocab.DecodeIds(tokens).split(" ")
|
||||
return tokens
|
||||
|
||||
def from_batch(self, translation_batch):
|
||||
batch = translation_batch["batch"]
|
||||
# assert len(translation_batch["gold_score"]) == len(translation_batch["predictions"])
|
||||
batch_size = batch.batch_size
|
||||
|
||||
# preds, pred_score, gold_score, tgt_str, src = (
|
||||
preds, pred_score, src = (
|
||||
translation_batch["predictions"],
|
||||
translation_batch["scores"],
|
||||
# translation_batch["gold_score"],
|
||||
# batch.tgt_str,
|
||||
batch.src,
|
||||
)
|
||||
# print(preds)
|
||||
# print(pred_score)
|
||||
# print(batch.tgt_str)
|
||||
# print(batch.src)
|
||||
|
||||
translations = []
|
||||
for b in range(batch_size):
|
||||
|
||||
if len(preds[b]) < 1:
|
||||
pred_sents = ""
|
||||
else:
|
||||
pred_sents = self.vocab.convert_ids_to_tokens([int(n) for n in preds[b][0]])
|
||||
pred_sents = " ".join(pred_sents).replace(" ##", "")
|
||||
# gold_sent = " ".join(tgt_str[b].split())
|
||||
# translation = Translation(fname[b],src[:, b] if src is not None else None,
|
||||
# src_raw, pred_sents,
|
||||
# attn[b], pred_score[b], gold_sent,
|
||||
# gold_score[b])
|
||||
# src = self.spm.DecodeIds([int(t) for t in translation_batch['batch'].src[0][5] if int(t) != len(self.spm)])
|
||||
raw_src = [self.vocab.ids_to_tokens[int(t)] for t in src[b]][:500]
|
||||
raw_src = " ".join(raw_src)
|
||||
# translation = (pred_sents, gold_sent, raw_src)
|
||||
translation = (pred_sents, raw_src)
|
||||
# translation = (pred_sents[0], gold_sent)
|
||||
translations.append(translation)
|
||||
|
||||
return translations
|
||||
|
||||
def translate(self, batch, attn_debug=False):
|
||||
|
||||
#self.model.eval()
|
||||
self.eval()
|
||||
# pred_results, gold_results = [], []
|
||||
with torch.no_grad():
|
||||
batch_data = self.translate_batch(batch)
|
||||
translations = self.from_batch(batch_data)
|
||||
|
||||
# for trans in translations:
|
||||
# pred, gold, src = trans
|
||||
# pred_str = pred.replace('[unused0]', '').replace('[unused3]', '').replace('[PAD]', '').replace('[unused1]', '').replace(r' +', ' ').replace(' [unused2] ', '<q>').replace('[unused2]', '').strip()
|
||||
# gold_str = gold.strip()
|
||||
return translations
|
||||
|
||||
def translate_batch(self, batch, fast=False):
|
||||
def forward(self, src, segs, mask_src):
|
||||
"""
|
||||
Translate a batch of sentences.
|
||||
|
||||
|
@ -224,22 +142,22 @@ class Translator(object):
|
|||
Shouldn't need the original dataset.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
return self._fast_translate_batch(batch, self.max_length, min_length=self.min_length)
|
||||
predictions, scores = self._fast_translate_batch(src, segs, mask_src, self.max_length, min_length=self.min_length)
|
||||
return predictions, scores
|
||||
|
||||
def _fast_translate_batch(self, batch, max_length, min_length=0):
|
||||
|
||||
def _fast_translate_batch(self, src, segs, mask_src, max_length, min_length=0):
|
||||
# TODO: faster code path for beam_size == 1.
|
||||
|
||||
# TODO: support these blacklisted features.
|
||||
assert not self.dump_beam
|
||||
|
||||
beam_size = self.beam_size
|
||||
batch_size = batch.batch_size
|
||||
src = batch.src
|
||||
segs = batch.segs
|
||||
mask_src = batch.mask_src
|
||||
batch_size = src.size()[0] #32 #batch.batch_size
|
||||
|
||||
src_features = self.bert(src, segs, mask_src)
|
||||
dec_states = self.decoder.init_decoder_state(src, src_features, with_cache=True)
|
||||
this_decoder = self.decoder.module if hasattr(self.decoder, "module") else self.decoder
|
||||
dec_states = this_decoder.init_decoder_state(src, src_features, with_cache=True)
|
||||
|
||||
device = src_features.device
|
||||
|
||||
|
@ -266,7 +184,7 @@ class Translator(object):
|
|||
results["predictions"] = [[] for _ in range(batch_size)] # noqa: F812
|
||||
results["scores"] = [[] for _ in range(batch_size)] # noqa: F812
|
||||
# results["gold_score"] = [0] * batch_size
|
||||
results["batch"] = batch
|
||||
#results["batch"] = batch
|
||||
|
||||
for step in range(max_length):
|
||||
decoder_input = alive_seq[:, -1].view(1, -1)
|
||||
|
@ -274,7 +192,7 @@ class Translator(object):
|
|||
# Decoder forward.
|
||||
decoder_input = decoder_input.transpose(0, 1)
|
||||
|
||||
dec_out, dec_states = self.decoder(
|
||||
dec_out, dec_states = this_decoder(
|
||||
decoder_input, src_features, dec_states, step=step
|
||||
)
|
||||
|
||||
|
@ -321,13 +239,11 @@ class Translator(object):
|
|||
|
||||
# Resolve beam origin and true word ids.
|
||||
topk_beam_index = topk_ids.div(vocab_size)
|
||||
# print("topk_beam_index.shape {}".format( topk_beam_index.size()))
|
||||
topk_ids = topk_ids.fmod(vocab_size)
|
||||
|
||||
# Map beam_index to batch_index in the flat representation.
|
||||
batch_index = topk_beam_index + beam_offset[: topk_beam_index.size(0)].unsqueeze(1)
|
||||
select_indices = batch_index.view(-1)
|
||||
# print("select_indices {}".format(select_indices))
|
||||
|
||||
# Append last prediction.
|
||||
alive_seq = torch.cat(
|
||||
|
@ -335,15 +251,10 @@ class Translator(object):
|
|||
)
|
||||
|
||||
is_finished = topk_ids.eq(self.end_token)
|
||||
# print("is_finished {}".format(is_finished))
|
||||
# print("is_finished size {}".format(is_finished.size()))
|
||||
if step + 1 == max_length:
|
||||
# print("reached max_length {} at step {}".format(max_length, step))
|
||||
is_finished.fill_(True)
|
||||
# print("is_finished {}".format(is_finished))
|
||||
# End condition is top beam is finished.
|
||||
end_condition = is_finished[:, 0].eq(True)
|
||||
# print("end_condition {}".format(end_condition))
|
||||
if step + 1 == max_length:
|
||||
assert not any(end_condition.eq(False))
|
||||
|
||||
|
@ -353,7 +264,6 @@ class Translator(object):
|
|||
for i in range(is_finished.size(0)):
|
||||
b = batch_offset[i]
|
||||
if end_condition[i]:
|
||||
# print("batch offset {} finished".format(b))
|
||||
is_finished[i].fill_(1)
|
||||
finished_hyp = is_finished[i].nonzero().view(-1)
|
||||
# Store finished hypotheses for this batch.
|
||||
|
@ -363,9 +273,6 @@ class Translator(object):
|
|||
if end_condition[i]:
|
||||
best_hyp = sorted(hypotheses[b], key=lambda x: x[0], reverse=True)
|
||||
score, pred = best_hyp[0]
|
||||
# if len(pred) == 0:
|
||||
# print("batch offset {} finished with empty prediction {}".format(b, pred))
|
||||
|
||||
results["scores"][b].append(score)
|
||||
results["predictions"][b].append(pred)
|
||||
non_finished = end_condition.eq(0).nonzero().view(-1)
|
||||
|
@ -383,63 +290,7 @@ class Translator(object):
|
|||
dec_states.map_batch_fn(lambda state, dim: state.index_select(dim, select_indices))
|
||||
|
||||
empty_output = [len(results["predictions"][b]) <= 0 for b in batch_offset]
|
||||
if any(empty_output):
|
||||
print("there is empty output {}".format(empty_output))
|
||||
print(batch_offset)
|
||||
print(results)
|
||||
predictions = torch.tensor([i[0].tolist()[0:self.max_length]+[0]*(self.max_length-i[0].size()[0]) for i in results["predictions"]], device=device)
|
||||
scores = torch.tensor([i[0].item() for i in results['scores']], device=device)
|
||||
return predictions, scores
|
||||
|
||||
# print("###########################################")
|
||||
return results
|
||||
|
||||
|
||||
class Translation(object):
|
||||
"""
|
||||
Container for a translated sentence.
|
||||
|
||||
Attributes:
|
||||
src (`LongTensor`): src word ids
|
||||
src_raw ([str]): raw src words
|
||||
|
||||
pred_sents ([[str]]): words from the n-best translations
|
||||
pred_scores ([[float]]): log-probs of n-best translations
|
||||
attns ([`FloatTensor`]) : attention dist for each translation
|
||||
gold_sent ([str]): words from gold translation
|
||||
gold_score ([float]): log-prob of gold translation
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, fname, src, src_raw, pred_sents, attn, pred_scores, tgt_sent=None, gold_score=0
|
||||
):
|
||||
self.fname = fname
|
||||
self.src = src
|
||||
self.src_raw = src_raw
|
||||
self.pred_sents = pred_sents
|
||||
self.attns = attn
|
||||
self.pred_scores = pred_scores
|
||||
self.gold_sent = tgt_sent
|
||||
self.gold_score = gold_score
|
||||
|
||||
def log(self, sent_number):
|
||||
"""
|
||||
Log translation.
|
||||
"""
|
||||
|
||||
output = "\nSENT {}: {}\n".format(sent_number, self.src_raw)
|
||||
|
||||
best_pred = self.pred_sents[0]
|
||||
best_score = self.pred_scores[0]
|
||||
pred_sent = " ".join(best_pred)
|
||||
output += "PRED {}: {}\n".format(sent_number, pred_sent)
|
||||
output += "PRED SCORE: {:.4f}\n".format(best_score)
|
||||
|
||||
if self.gold_sent is not None:
|
||||
tgt_sent = " ".join(self.gold_sent)
|
||||
output += "GOLD {}: {}\n".format(sent_number, tgt_sent)
|
||||
output += "GOLD SCORE: {:.4f}\n".format(self.gold_score)
|
||||
if len(self.pred_sents) > 1:
|
||||
output += "\nBEST HYP:\n"
|
||||
for score, sent in zip(self.pred_scores, self.pred_sents):
|
||||
output += "[{:.4f}] {}\n".format(score, sent)
|
||||
|
||||
return output
|
||||
|
|
|
@ -43,7 +43,7 @@ TOKENIZER_CLASS.update(
|
|||
MAX_SEQ_LEN = 512
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
fh = logging.FileHandler("abssum_train.log")
|
||||
fh = logging.FileHandler("longer_input_abssum_train.log")
|
||||
logger.addHandler(fh)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче