Merge branch 'daden/presumm' of https://github.com/microsoft/nlp-recipes into daden/presumm
Commit fd84b4b32f
@@ -25,6 +25,12 @@ from utils_nlp.models.transformers.datasets import SummarizationNonIterableDatas
 from utils_nlp.eval.evaluate_summarization import get_rouge
 
 os.environ["NCCL_IB_DISABLE"] = "0"
+#os.environ["NCCL_DEBUG"] = "INFO"
+os.environ["NCCL_DEBUG_SUBSYS"] = "ALL"
+#os.environ["MASTER_PORT"] = "29952"
+#os.environ["MASTER_ADDR"] = "172.12.0.6"
+#os.environ['NCCL_SOCKET_IFNAME'] = 'lo'
+
 
 def shorten_dataset(dataset, top_n=-1):
     if top_n == -1:
@@ -56,6 +62,8 @@ parser.add_argument("--lr_dec", type=float, default=2e-1,
                     help="Learning rate for the decoder.")
 parser.add_argument("--batch_size", type=int, default=5,
                     help="batch size in terms of input token numbers in training")
+parser.add_argument("--max_pos", type=int, default=640,
+                    help="maximum input length in terms of input token numbers in training")
 parser.add_argument("--max_steps", type=int, default=5e4,
                     help="Maximum number of training steps run in training. If quick_run is set,\
                     it's not used.")
@@ -107,7 +115,7 @@ def main():
     print("data_dir is {}".format(args.data_dir))
     print("cache_dir is {}".format(args.cache_dir))
     ngpus_per_node = torch.cuda.device_count()
-    processor = AbsSumProcessor(cache_dir=args.cache_dir)
+    processor = AbsSumProcessor(cache_dir=args.cache_dir, max_src_len=args.max_pos)
     summarizer = AbsSum(
         processor, cache_dir=args.cache_dir
     )
@@ -120,13 +128,14 @@ def main_worker(local_rank, ngpus_per_node, summarizer, args):
     print("world_size is {}".format(world_size))
     print("local_rank is {} and rank is {}".format(local_rank, rank))
 
 
     torch.distributed.init_process_group(
         backend="nccl",
         init_method=args.dist_url,
         world_size=world_size,
         rank=rank,
     )
 
+    # return
     ## should not load checkpoint from this place, otherwise, huge memory increase
     if args.checkpoint_filename:
         checkpoint = os.path.join(args.cache_dir, args.checkpoint_filename)
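Note on the environment toggles and init_process_group above: this is the usual one-process-per-GPU NCCL setup driven by mp.spawn. A minimal single-node sketch of that launch pattern (the rendezvous address and worker body are placeholders, not taken from this change):

    import torch
    import torch.multiprocessing as mp

    def worker(local_rank, ngpus_per_node):
        # On a single node the global rank equals the local GPU index.
        torch.distributed.init_process_group(
            backend="nccl",                       # GPU collectives over NCCL
            init_method="tcp://127.0.0.1:29500",  # placeholder rendezvous address
            world_size=ngpus_per_node,
            rank=local_rank,
        )
        torch.cuda.set_device(local_rank)

    if __name__ == "__main__":
        n = torch.cuda.device_count()
        mp.spawn(worker, nprocs=n, args=(n,))     # one worker process per GPU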
@@ -25,7 +25,7 @@ from utils_nlp.models.transformers.datasets import SummarizationNonIterableDatas
 from utils_nlp.eval.evaluate_summarization import get_rouge
 
 CACHE_PATH = "/dadendev/nlp-recipes/examples/text_summarization/abstemp"
-DATA_PATH = "/dadendev/nlp-recipes/examples/text_summarization"
+DATA_PATH = "/dadendev/nlp-recipes/examples/text_summarization/abstemp"
 MODEL_PATH = "/dadendev/nlp-recipes/examples/text_summarization/abstemp"
 TOP_N = 10
 
@@ -259,10 +259,11 @@ def test_pretrained_model():
     checkpoint = torch.load(os.path.join(MODEL_PATH, "new_model_step_148000_torch1.4.0.pt"))
 
     #checkpoint = torch.load(os.path.join(MODEL_PATH, "summarizer_step20000_with_global_step.pt"))
-    checkpoint = torch.load(os.path.join(MODEL_PATH, "bert-base-uncased_step_400.pt"))
+    #checkpoint = torch.load(os.path.join(MODEL_PATH, "bert-base-uncased_step_400.pt"))
     summarizer = AbsSum(
         processor,
         cache_dir=CACHE_PATH,
+        max_pos=512,
     )
     summarizer.model.load_checkpoint(checkpoint['model'])
     """
@@ -284,14 +285,15 @@ def test_pretrained_model():
         return
     """
 
-    top_n = 8
+    top_n = 96*4
     src = test_sum_dataset.source[0:top_n]
     reference_summaries = ["".join(t).rstrip("\n") for t in test_sum_dataset.target[0:top_n]]
     print("start prediction")
     generated_summaries = summarizer.predict(
-        shorten_dataset(test_sum_dataset, top_n=top_n), batch_size=4, num_gpus=2
+        shorten_dataset(test_sum_dataset, top_n=top_n), batch_size=96+16, num_gpus=1, max_seq_length=512
     )
     print(generated_summaries[0])
+    print(len(generated_summaries))
     assert len(generated_summaries) == len(reference_summaries)
     RESULT_DIR = TemporaryDirectory().name
     rouge_score = get_rouge(generated_summaries, reference_summaries, RESULT_DIR)
@@ -303,6 +305,6 @@ def test_pretrained_model():
     #test_collate()
     #preprocess_cnndm_abs()
     #test_train_model()
-    #test_pretrained_model()
-if __name__ == "__main__":
-    main()
+    test_pretrained_model()
+#if __name__ == "__main__":
+#    main()
@@ -113,7 +113,7 @@ class AbsSumProcessor:
         model_name="bert-base-uncased",
         to_lower=False,
         cache_dir=".",
-        max_len=512,
+        max_src_len=640,
         max_target_len=140,
     ):
         """ Initialize the preprocessor.
@@ -144,14 +144,11 @@ class AbsSumProcessor:
         self.sep_token = "[SEP]"
         self.cls_token = "[CLS]"
         self.pad_token = "[PAD]"
-        self.tgt_bos = "[unused0]"
-        self.tgt_eos = "[unused1]"
+        self.tgt_bos = self.symbols["BOS"]
+        self.tgt_eos = self.symbols["EOS"]
 
-        self.sep_vid = self.tokenizer.vocab[self.sep_token]
-        self.cls_vid = self.tokenizer.vocab[self.cls_token]
-        self.pad_vid = self.tokenizer.vocab[self.pad_token]
 
-        self.max_len = max_len
+        self.max_src_len = max_src_len
         self.max_target_len = max_target_len
 
     @staticmethod
@@ -243,7 +240,8 @@ class AbsSumProcessor:
         if train_mode:
             encoded_summaries = torch.tensor(
                 [
-                    fit_to_block_size(summary, block_size, self.tokenizer.pad_token_id)
+                    [self.tgt_bos] + fit_to_block_size(summary, block_size-2, self.tokenizer.pad_token_id)
+                    + [self.tgt_eos]
                     for _, summary in encoded_text
                 ]
             )
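Note on the hunk above: reserving two slots out of block_size means the prepended BOS and appended EOS can never push an already full summary past the block budget, and every encoded row keeps an identical length, which torch.tensor() requires. A small self-contained illustration (placeholder ids; fit_to_block_size is re-sketched here as truncate-then-pad):

    def fit_to_block_size(sequence, block_size, pad_token_id):
        # Truncate long sequences, pad short ones (mirrors the helper used above).
        if len(sequence) > block_size:
            return sequence[:block_size]
        return sequence + [pad_token_id] * (block_size - len(sequence))

    bos, eos, pad = 1, 2, 0            # placeholder ids for BOS/EOS/[PAD]
    summary = list(range(10, 700))     # longer than the block on purpose
    row = [bos] + fit_to_block_size(summary, 512 - 2, pad) + [eos]
    assert len(row) == 512             # rows are uniform, so torch.tensor() stacks them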
@@ -308,7 +306,7 @@ class AbsSumProcessor:
             try:
                 if len(line) <= 0:
                     continue
-                story_lines_token_ids.append(self.tokenizer.encode(line, max_length=self.max_len))
+                story_lines_token_ids.append(self.tokenizer.encode(line, max_length=self.max_src_len))
             except:
                 print(line)
                 raise
@@ -325,9 +323,9 @@ class AbsSumProcessor:
             except:
                 print(line)
                 raise
             summary_token_ids = [
                 token for sentence in summary_lines_token_ids for token in sentence
             ]
             return story_token_ids, summary_token_ids
         else:
             return story_token_ids
@@ -366,6 +364,7 @@ class AbsSum(Transformer):
         cache_dir=".",
         label_smoothing=0.1,
         test=False,
+        max_pos=768,
     ):
         """Initialize a ExtractiveSummarizer.
 
@@ -397,6 +396,7 @@ class AbsSum(Transformer):
 
         self.model_class = MODEL_CLASS[model_name]
         self.cache_dir = cache_dir
+        self.max_pos = max_pos
 
         self.model = AbsSummarizer(
             temp_dir=cache_dir,
@@ -404,12 +404,12 @@ class AbsSum(Transformer):
             checkpoint=None,
             label_smoothing=label_smoothing,
             symbols=processor.symbols,
-            test=test
+            test=test,
+            max_pos=self.max_pos,
         )
         self.processor = processor
         self.optim_bert = None
         self.optim_dec = None
 
 
-
     @staticmethod
@@ -423,7 +423,7 @@ class AbsSum(Transformer):
         train_dataset,
         num_gpus=None,
         gpu_ids=None,
-        batch_size=140,
+        batch_size=4,
         local_rank=-1,
         max_steps=5e5,
         warmup_steps_bert=8000,
@@ -499,6 +499,7 @@ class AbsSum(Transformer):
         #"""
         self.optim_bert = model_builder.build_optim_bert(
             self.model,
+            optim=optimization_method,
             visible_gpus=None, #",".join([str(i) for i in range(num_gpus)]), #"0,1,2,3",
             lr_bert=learning_rate_bert,
             warmup_steps_bert=warmup_steps_bert,
@@ -506,6 +507,7 @@ class AbsSum(Transformer):
         )
         self.optim_dec = model_builder.build_optim_dec(
             self.model,
+            optim=optimization_method,
             visible_gpus=None, #",".join([str(i) for i in range(num_gpus)]), #"0,1,2,3"
             lr_dec=learning_rate_dec,
             warmup_steps_dec=warmup_steps_dec,
@@ -518,6 +520,7 @@ class AbsSum(Transformer):
         optimizers = [self.optim_bert, self.optim_dec]
         schedulers = [self.scheduler_bert, self.scheduler_dec]
 
+
         self.amp = get_amp(fp16)
         if self.amp:
             self.model, optim = self.amp.initialize(self.model, optimizers, opt_level=fp16_opt_level)
@@ -548,7 +551,7 @@ class AbsSum(Transformer):
         sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
 
         def collate_fn(data):
-            return self.processor.collate(data, block_size=512, device=device)
+            return self.processor.collate(data, block_size=self.max_pos, device=device)
 
         train_dataloader = DataLoader(
             train_dataset,
@@ -592,13 +595,13 @@ class AbsSum(Transformer):
         local_rank=-1,
         gpu_ids=None,
         batch_size=16,
-        # sentence_separator="<q>",
         alpha=0.6,
         beam_size=5,
         min_length=15,
         max_length=150,
         fp16=False,
         verbose=True,
+        max_seq_length=768,
     ):
         """
         Predict the summarization for the input data iterator.
@@ -625,11 +628,11 @@ class AbsSum(Transformer):
             List of strings which are the summaries
 
         """
-        device, num_gpus = get_device(num_gpus=num_gpus, local_rank=local_rank)
+        device, num_gpus = get_device(num_gpus=num_gpus, gpu_ids=gpu_ids, local_rank=local_rank)
         # move model to devices
         def this_model_move_callback(model, device):
             model = move_model_to_device(model, device)
-            return parallelize_model(model, device, num_gpus=num_gpus, gpu_ids=None, local_rank=local_rank)
+            return parallelize_model(model, device, num_gpus=num_gpus, gpu_ids=gpu_ids, local_rank=local_rank)
 
         if fp16:
             self.model = self.model.half()
@@ -647,13 +650,13 @@ class AbsSum(Transformer):
             min_length=min_length,
             max_length=max_length,
         )
-        predictor = predictor.move_to_device(device, this_model_move_callback)
+        predictor = this_model_move_callback(predictor, device)
 
 
         test_sampler = SequentialSampler(test_dataset)
 
         def collate_fn(data):
-            return self.processor.collate(data, 512, device, train_mode=False)
+            return self.processor.collate(data, max_seq_length, device, train_mode=False)
 
         test_dataloader = DataLoader(
             test_dataset, sampler=test_sampler, batch_size=batch_size, collate_fn=collate_fn,
@@ -663,7 +666,7 @@ class AbsSum(Transformer):
             """ Transforms the output of the `from_batch` function
             into nicely formatted summaries.
             """
-            raw_summary, _, = translation
+            raw_summary = translation
             summary = (
                 raw_summary.replace("[unused0]", "")
                 .replace("[unused3]", "")
@@ -677,16 +680,31 @@ class AbsSum(Transformer):
                 .strip()
             )
 
 
             return summary
 
+        def generate_summary_from_tokenid(preds, pred_score):
+            batch_size = preds.size()[0]  # batch.batch_size
+            translations = []
+            for b in range(batch_size):
+                if len(preds[b]) < 1:
+                    pred_sents = ""
+                else:
+                    pred_sents = self.processor.tokenizer.convert_ids_to_tokens([int(n) for n in preds[b] if int(n)!=0])
+                    pred_sents = " ".join(pred_sents).replace(" ##", "")
+                translations.append(pred_sents)
+            return translations
+
         generated_summaries = []
         from tqdm import tqdm
 
-        for batch in tqdm(test_dataloader):
-            batch_data = predictor.translate_batch(batch)
-            translations = predictor.from_batch(batch_data)
-            summaries = [format_summary(t) for t in translations]
-            generated_summaries += summaries
+        for batch in tqdm(test_dataloader, desc="Generating summary", disable=not verbose):
+            input = self.processor.get_inputs(batch, device, "bert", train_mode=False)
+            translations, scores = predictor(**input)
+            translations_text = generate_summary_from_tokenid(translations, scores)
+            summaries = [format_summary(t) for t in translations_text]
+            generated_summaries.extend(summaries)
         return generated_summaries
 
     def save_model(self, global_step=None, full_name=None):
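Note on the predict() rewrite above: generate_summary_from_tokenid is plain WordPiece detokenization. It drops the 0 ids used as padding, maps the remaining ids back to subword tokens, then stitches subwords together by deleting the " ##" continuation markers. A standalone illustration with hypothetical pieces:

    # Hypothetical WordPiece pieces for one decoded hypothesis.
    pieces = ["northern", "ireland", "play", "##maker", "stuart", "dall", "##as"]
    text = " ".join(pieces).replace(" ##", "")
    print(text)  # -> "northern ireland playmaker stuart dallas"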
@@ -196,7 +196,7 @@ class AbsSummarizer(nn.Module):
         self.bert.model = BertModel(bert_config)
 
         if max_pos > 512:
-            my_pos_embeddi = nn.Embedding(max_pos, self.bert.model.config.hidden_size)
+            my_pos_embeddings = nn.Embedding(max_pos, self.bert.model.config.hidden_size)
             my_pos_embeddings.weight.data[
                 :512
             ] = self.bert.model.embeddings.position_embeddings.weight.data
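Note on the hunk above: feeding the encoder more than bert-base's 512 learned positions means building a larger embedding table, copying the 512 pretrained rows in, and giving the extra rows a sensible starting value; the reference BertSumAbs implementation fills rows 512 and up by repeating the last learned position. A sketch of that pattern (hidden size 768 assumed for bert-base-uncased):

    import torch
    from torch import nn

    max_pos, hidden = 640, 768
    pretrained = nn.Embedding(512, hidden)   # stands in for the loaded BERT table
    extended = nn.Embedding(max_pos, hidden)
    extended.weight.data[:512] = pretrained.weight.data            # keep learned rows
    extended.weight.data[512:] = pretrained.weight.data[-1][None, :].repeat(
        max_pos - 512, 1                                           # repeat last row
    )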
@@ -6,7 +6,7 @@ import os
 import math
 
 import torch
+from torch import nn
 from tensorboardX import SummaryWriter
 
 # from others.utils import rouge_results_to_str, test_rouge, tile
@@ -55,7 +55,7 @@ def tile(x, count, dim=0):
     return x
 
 
-class Translator(object):
+class Translator(nn.Module):
     """
     Uses a model to translate a batch of sentences.
 
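Note on the hunk above: once Translator subclasses nn.Module, beam search becomes just another forward pass, so the same move_model_to_device/parallelize_model helpers used for the summarizer can wrap it in DataParallel and split prediction batches across GPUs. A toy sketch of the calling convention (TinyTranslator and all shapes are illustrative only, not part of this change):

    import torch
    from torch import nn

    class TinyTranslator(nn.Module):
        # Stand-in for Translator: tensor inputs in, tensor outputs out,
        # which is exactly what DataParallel can scatter and gather.
        def forward(self, src, segs, mask_src):
            return src[:, :3], src.float().sum(dim=1)   # fake (predictions, scores)

    predictor = TinyTranslator()
    if torch.cuda.device_count() > 1:
        predictor = nn.DataParallel(predictor.cuda())   # batch dim split across GPUs
    src = torch.ones(8, 16, dtype=torch.long)
    preds, scores = predictor(src=src, segs=torch.zeros_like(src), mask_src=torch.ones_like(src))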
@@ -70,14 +70,12 @@ class Translator(object):
        global_scores (:obj:`GlobalScorer`):
           object to rescore final translations
        copy_attn (bool): use copy attention during translation
-       cuda (bool): use cuda
        beam_trace (bool): trace beam search for debugging
        logger(logging.Logger): logger.
     """
 
     def __init__(
         self,
-        # args,
         beam_size,
         min_length,
         max_length,
@@ -89,10 +87,9 @@ class Translator(object):
         logger=None,
         dump_beam="",
     ):
+        super(Translator, self).__init__()
         self.logger = logger
-        self.cuda = 0  # args.visible_gpus != '-1'
 
-        # self.args = args
         self.model = model.module if hasattr(model, "module") else model
         self.generator = self.model.generator
         self.decoder = self.model.decoder
@@ -115,10 +112,6 @@ class Translator(object):
         self.beam_trace = self.dump_beam != ""
         self.beam_accum = None
 
-        # tensorboard_log_dir = args.model_path
-
-        # self.tensorboard_writer = SummaryWriter(tensorboard_log_dir, comment="Unmt")
-
         if self.beam_trace:
             self.beam_accum = {
                 "predicted_ids": [],
@@ -126,90 +119,15 @@ class Translator(object):
                 "scores": [],
                 "log_probs": [],
             }
+    """
-    def move_to_device(self, device, move_to_device_fn):
-        self.move_to_device_fn = move_to_device_fn
-        self.model = move_to_device_fn(self.model, device)
-        self.bert = move_to_device_fn(self.bert, device)
-        self.generator = move_to_device_fn(self.generator, device)
-        return self
 
     def eval(self):
         self.model.eval()
         self.bert.eval()
         self.decoder.eval()
         self.generator.eval()
+    """
 
-    def _build_target_tokens(self, pred):
-        # vocab = self.fields["tgt"].vocab
-        tokens = []
-        for tok in pred:
-            tok = int(tok)
-            tokens.append(tok)
-            if tokens[-1] == self.end_token:
-                tokens = tokens[:-1]
-                break
-        tokens = [t for t in tokens if t < len(self.vocab)]
-        tokens = self.vocab.DecodeIds(tokens).split(" ")
-        return tokens
-
-    def from_batch(self, translation_batch):
-        batch = translation_batch["batch"]
-        # assert len(translation_batch["gold_score"]) == len(translation_batch["predictions"])
-        batch_size = batch.batch_size
-
-        # preds, pred_score, gold_score, tgt_str, src = (
-        preds, pred_score, src = (
-            translation_batch["predictions"],
-            translation_batch["scores"],
-            # translation_batch["gold_score"],
-            # batch.tgt_str,
-            batch.src,
-        )
-        # print(preds)
-        # print(pred_score)
-        # print(batch.tgt_str)
-        # print(batch.src)
-
-        translations = []
-        for b in range(batch_size):
-
-            if len(preds[b]) < 1:
-                pred_sents = ""
-            else:
-                pred_sents = self.vocab.convert_ids_to_tokens([int(n) for n in preds[b][0]])
-                pred_sents = " ".join(pred_sents).replace(" ##", "")
-            # gold_sent = " ".join(tgt_str[b].split())
-            # translation = Translation(fname[b],src[:, b] if src is not None else None,
-            #                           src_raw, pred_sents,
-            #                           attn[b], pred_score[b], gold_sent,
-            #                           gold_score[b])
-            # src = self.spm.DecodeIds([int(t) for t in translation_batch['batch'].src[0][5] if int(t) != len(self.spm)])
-            raw_src = [self.vocab.ids_to_tokens[int(t)] for t in src[b]][:500]
-            raw_src = " ".join(raw_src)
-            # translation = (pred_sents, gold_sent, raw_src)
-            translation = (pred_sents, raw_src)
-            # translation = (pred_sents[0], gold_sent)
-            translations.append(translation)
-
-        return translations
-
-    def translate(self, batch, attn_debug=False):
-
-        #self.model.eval()
-        self.eval()
-        # pred_results, gold_results = [], []
-        with torch.no_grad():
-            batch_data = self.translate_batch(batch)
-            translations = self.from_batch(batch_data)
-
-            # for trans in translations:
-            #     pred, gold, src = trans
-            #     pred_str = pred.replace('[unused0]', '').replace('[unused3]', '').replace('[PAD]', '').replace('[unused1]', '').replace(r' +', ' ').replace(' [unused2] ', '<q>').replace('[unused2]', '').strip()
-            #     gold_str = gold.strip()
-        return translations
-
-    def translate_batch(self, batch, fast=False):
+    def forward(self, src, segs, mask_src):
         """
         Translate a batch of sentences.
 
@@ -224,22 +142,22 @@
         Shouldn't need the original dataset.
         """
         with torch.no_grad():
-            return self._fast_translate_batch(batch, self.max_length, min_length=self.min_length)
+            predictions, scores = self._fast_translate_batch(src, segs, mask_src, self.max_length, min_length=self.min_length)
+            return predictions, scores
 
 
-    def _fast_translate_batch(self, batch, max_length, min_length=0):
+    def _fast_translate_batch(self, src, segs, mask_src, max_length, min_length=0):
         # TODO: faster code path for beam_size == 1.
 
         # TODO: support these blacklisted features.
         assert not self.dump_beam
 
         beam_size = self.beam_size
-        batch_size = batch.batch_size
-        src = batch.src
-        segs = batch.segs
-        mask_src = batch.mask_src
+        batch_size = src.size()[0]  #32 #batch.batch_size
 
         src_features = self.bert(src, segs, mask_src)
-        dec_states = self.decoder.init_decoder_state(src, src_features, with_cache=True)
+        this_decoder = self.decoder.module if hasattr(self.decoder, "module") else self.decoder
+        dec_states = this_decoder.init_decoder_state(src, src_features, with_cache=True)
 
         device = src_features.device
 
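Note: the this_decoder lookup is the standard unwrap for DataParallel-wrapped submodules; custom methods such as init_decoder_state live on the inner module, not on the wrapper. A runnable illustration of the pattern:

    from torch import nn

    decoder = nn.Linear(4, 4)                 # any submodule
    wrapped = nn.DataParallel(decoder)
    # Attribute access must go through .module once wrapped:
    inner = wrapped.module if hasattr(wrapped, "module") else wrapped
    assert inner is decoder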
@@ -266,7 +184,7 @@
         results["predictions"] = [[] for _ in range(batch_size)]  # noqa: F812
         results["scores"] = [[] for _ in range(batch_size)]  # noqa: F812
         # results["gold_score"] = [0] * batch_size
-        results["batch"] = batch
+        #results["batch"] = batch
 
         for step in range(max_length):
             decoder_input = alive_seq[:, -1].view(1, -1)
@@ -274,7 +192,7 @@
             # Decoder forward.
             decoder_input = decoder_input.transpose(0, 1)
 
-            dec_out, dec_states = self.decoder(
+            dec_out, dec_states = this_decoder(
                 decoder_input, src_features, dec_states, step=step
             )
 
@@ -321,13 +239,11 @@
 
             # Resolve beam origin and true word ids.
             topk_beam_index = topk_ids.div(vocab_size)
-            # print("topk_beam_index.shape {}".format( topk_beam_index.size()))
             topk_ids = topk_ids.fmod(vocab_size)
 
             # Map beam_index to batch_index in the flat representation.
             batch_index = topk_beam_index + beam_offset[: topk_beam_index.size(0)].unsqueeze(1)
             select_indices = batch_index.view(-1)
-            # print("select_indices {}".format(select_indices))
 
             # Append last prediction.
             alive_seq = torch.cat(
@@ -335,15 +251,10 @@
             )
 
             is_finished = topk_ids.eq(self.end_token)
-            # print("is_finished {}".format(is_finished))
-            # print("is_finished size {}".format(is_finished.size()))
             if step + 1 == max_length:
-                # print("reached max_length {} at step {}".format(max_length, step))
                 is_finished.fill_(True)
-            # print("is_finished {}".format(is_finished))
             # End condition is top beam is finished.
             end_condition = is_finished[:, 0].eq(True)
-            # print("end_condition {}".format(end_condition))
             if step + 1 == max_length:
                 assert not any(end_condition.eq(False))
 
@@ -353,7 +264,6 @@
                 for i in range(is_finished.size(0)):
                     b = batch_offset[i]
                     if end_condition[i]:
-                        # print("batch offset {} finished".format(b))
                         is_finished[i].fill_(1)
                     finished_hyp = is_finished[i].nonzero().view(-1)
                     # Store finished hypotheses for this batch.
@@ -363,9 +273,6 @@
                     if end_condition[i]:
                         best_hyp = sorted(hypotheses[b], key=lambda x: x[0], reverse=True)
                         score, pred = best_hyp[0]
-                        # if len(pred) == 0:
-                        #     print("batch offset {} finished with empty prediction {}".format(b, pred))
-
                         results["scores"][b].append(score)
                         results["predictions"][b].append(pred)
                 non_finished = end_condition.eq(0).nonzero().view(-1)
@@ -383,63 +290,7 @@
             dec_states.map_batch_fn(lambda state, dim: state.index_select(dim, select_indices))
 
         empty_output = [len(results["predictions"][b]) <= 0 for b in batch_offset]
-        if any(empty_output):
-            print("there is empty output {}".format(empty_output))
-            print(batch_offset)
-            print(results)
-
-        # print("###########################################")
-        return results
-
-
-class Translation(object):
-    """
-    Container for a translated sentence.
-
-    Attributes:
-        src (`LongTensor`): src word ids
-        src_raw ([str]): raw src words
-
-        pred_sents ([[str]]): words from the n-best translations
-        pred_scores ([[float]]): log-probs of n-best translations
-        attns ([`FloatTensor`]) : attention dist for each translation
-        gold_sent ([str]): words from gold translation
-        gold_score ([float]): log-prob of gold translation
-
-    """
-
-    def __init__(
-        self, fname, src, src_raw, pred_sents, attn, pred_scores, tgt_sent=None, gold_score=0
-    ):
-        self.fname = fname
-        self.src = src
-        self.src_raw = src_raw
-        self.pred_sents = pred_sents
-        self.attns = attn
-        self.pred_scores = pred_scores
-        self.gold_sent = tgt_sent
-        self.gold_score = gold_score
-
-    def log(self, sent_number):
-        """
-        Log translation.
-        """
-
-        output = "\nSENT {}: {}\n".format(sent_number, self.src_raw)
-
-        best_pred = self.pred_sents[0]
-        best_score = self.pred_scores[0]
-        pred_sent = " ".join(best_pred)
-        output += "PRED {}: {}\n".format(sent_number, pred_sent)
-        output += "PRED SCORE: {:.4f}\n".format(best_score)
-
-        if self.gold_sent is not None:
-            tgt_sent = " ".join(self.gold_sent)
-            output += "GOLD {}: {}\n".format(sent_number, tgt_sent)
-            output += "GOLD SCORE: {:.4f}\n".format(self.gold_score)
-        if len(self.pred_sents) > 1:
-            output += "\nBEST HYP:\n"
-            for score, sent in zip(self.pred_scores, self.pred_sents):
-                output += "[{:.4f}] {}\n".format(score, sent)
-
-        return output
+        predictions = torch.tensor([i[0].tolist()[0:self.max_length]+[0]*(self.max_length-i[0].size()[0]) for i in results["predictions"]], device=device)
+        scores = torch.tensor([i[0].item() for i in results['scores']], device=device)
+        return predictions, scores
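Note on the hunk above: DataParallel gathers per-GPU outputs by concatenating tensors along the batch dimension, which the old dict of ragged Python lists could not support; hence the fixed-shape (batch, max_length) prediction tensor padded with 0 plus a flat score vector. That 0 padding is also why generate_summary_from_tokenid skips ids equal to 0 (BERT's [PAD]) when decoding. The padding logic in isolation:

    import torch

    max_length = 6
    best = [torch.tensor([5, 9, 2]), torch.tensor([7, 7, 7, 7, 7, 7, 7])]  # ragged hypotheses
    padded = torch.tensor(
        [h.tolist()[:max_length] + [0] * (max_length - h.size(0)) for h in best]
    )
    print(padded.shape)  # torch.Size([2, 6]): uniform shape, safe for DataParallel's gather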
@@ -43,7 +43,7 @@ TOKENIZER_CLASS.update(
 MAX_SEQ_LEN = 512
 
 logger = logging.getLogger(__name__)
-fh = logging.FileHandler("abssum_train.log")
+fh = logging.FileHandler("longer_input_abssum_train.log")
 logger.addHandler(fh)
 logger.setLevel(logging.INFO)
 