fix the input for preprocessing

Daisy Deng 2020-02-20 20:03:19 +00:00
Parent 9b727e9fee
Commit 804d88c451
1 changed file with 35 additions and 24 deletions


@@ -10,7 +10,7 @@ import pytest
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
+torch.set_printoptions(threshold=5000)
 from tempfile import TemporaryDirectory
 from utils_nlp.models.transformers.datasets import (
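
The torch.set_printoptions(threshold=5000) call added above raises the element count at which tensor printing switches to a summarized repr, so the debug prints added later in this file show batches in full. A minimal illustration (not part of the diff):

import torch

torch.set_printoptions(threshold=5000)   # summarize only tensors with more than 5000 elements
print(torch.arange(2000))                # printed in full; the default threshold of 1000 would elide it with "..."
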
@@ -32,22 +32,23 @@ TOP_N = 10
 # @pytest.fixture()
 def source_data():
     return [
-        (
-            "Boston, MA welcome to Microsoft/nlp. Welcome to text summarization."
-            "Welcome to Microsoft NERD."
+        [
+            "Boston, MA welcome to Microsoft/nlp",
+            "Welcome to text summarization.",
+            "Welcome to Microsoft NERD.",
             "Look outside, what a beautiful Charlse River fall view."
-        ),
-        ("I am just another test case"),
-        ("want to test more"),
+        ],
+        ["I am just another test case"],
+        ["want to test more"],
     ]
 
 # @pytest.fixture()
 def target_data():
     return [
-        ("welcome to microsoft/nlp." "Welcome to text summarization." "Welcome to Microsoft NERD."),
-        ("I am just another test summary"),
-        ("yest, I agree"),
+        ["welcome to microsoft/nlp.", "Welcome to text summarization.", "Welcome to Microsoft NERD."],
+        ["I am just another test summary"],
+        ["yest, I agree"],
     ]
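
The fixture change above is the substance of the fix: in the old version, adjacent string literals were implicitly concatenated and parentheses without a trailing comma did not create tuples, so each example reached the preprocessor as one long string instead of a list of sentences. A minimal plain-Python sketch (not part of the diff):

old_style = (
    "Boston, MA welcome to Microsoft/nlp. Welcome to text summarization."
    "Welcome to Microsoft NERD."
)
print(type(old_style))          # <class 'str'>: adjacent literals concatenate into one string
print(type(("one sentence")))   # <class 'str'>: parentheses alone do not make a tuple

new_style = [
    "Boston, MA welcome to Microsoft/nlp",
    "Welcome to text summarization.",
    "Welcome to Microsoft NERD.",
]
print(type(new_style), len(new_style))   # <class 'list'> 3: one string per sentence
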
@@ -66,6 +67,15 @@ def test_preprocessing():
     print(batch)
     print(len(batch.src[0]))
 
+def test_collate():
+    test_data_path = os.path.join(DATA_PATH, "test_abssum_dataset_full.pt")
+    test_sum_dataset = torch.load(test_data_path)
+    temp = shorten_dataset(test_sum_dataset, top_n=2)
+    processor = AbsSumProcessor()
+    batch = processor.collate(temp, 512, "cuda:0")
+    print(batch.tgt)
+    print(batch.tgt_num_tokens)
+    #print(len(batch.src[0]))
 
 def shorten_dataset(dataset, top_n=-1):
     if top_n == -1:
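
For reference, a hedged sketch of the new collate test with a CPU fallback. It assumes DATA_PATH, AbsSumProcessor, and shorten_dataset are in scope exactly as in this test module and takes the collate(data, block_size, device) call shape from the diff above; the CPU fallback itself is an assumption, not part of the commit:

import os
import torch

device = "cuda:0" if torch.cuda.is_available() else "cpu"   # assumption: allow CPU-only runs
test_data_path = os.path.join(DATA_PATH, "test_abssum_dataset_full.pt")
test_sum_dataset = torch.load(test_data_path)
batch = AbsSumProcessor().collate(shorten_dataset(test_sum_dataset, top_n=2), 512, device)
print(batch.tgt)
print(batch.tgt_num_tokens)
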
@@ -149,7 +159,7 @@ def main_worker(local_rank, ngpus_per_node, summarizer, args):
         save_every = -1
         this_validate = None
     else:
-        save_every = 100
+        save_every = 400
 
     #summarizer.model.load_checkpoint(checkpoint['model'])
     summarizer.fit(
@@ -158,17 +168,17 @@ def main_worker(local_rank, ngpus_per_node, summarizer, args):
         num_gpus=None,
         local_rank=local_rank,
         rank=rank,
-        batch_size=8,
+        batch_size=6,
         max_steps=50000/world_size,
-        learning_rate_bert=0.002,
-        learning_rate_dec=0.2,
+        learning_rate_bert=0.003,
+        learning_rate_dec=0.3,
         warmup_steps_bert=20000,
         warmup_steps_dec=10000,
         save_every=save_every,
         report_every=10,
         validation_function=this_validate,
         fp16=True,
-        fp16_opt_level="O2",
+        fp16_opt_level="O1",
         checkpoint=None
     )
     if rank == 0 or local_rank == -1:
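
One detail in the fit call above worth flagging: max_steps=50000/world_size is true division in Python 3 and therefore a float. If the trainer expects an integer step count (an assumption, not something this diff shows), integer division keeps the per-process budget whole:

world_size = 4                    # example value; the real value comes from the distributed launcher
max_steps = 50000 // world_size   # 12500 whole steps per process, roughly 50k steps globally
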
@@ -192,18 +202,20 @@ def test_train_model():
     train_sum_dataset, test_sum_dataset = preprocess_cnndm_abs()
+    #train_sum_dataset = shorten_dataset(train_sum_dataset, top_n=4) ## at lease gradient_accumulation * batch_size long
 
     def this_validate(class_obj):
         return validate(class_obj, test_sum_dataset, CACHE_PATH)
 
     summarizer.fit(
         train_sum_dataset,
-        batch_size=5,
+        batch_size=8,
         max_steps=30000,
         local_rank=-1,
         learning_rate_bert=0.002,
         learning_rate_dec=0.2,
         warmup_steps_bert=20000,
         warmup_steps_dec=10000,
-        num_gpus=1,
+        num_gpus=2,
         report_every=10,
         save_every=100,
         validation_function=this_validate,
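
For context on batch_size=5 → 8 together with num_gpus=1 → 2, a back-of-the-envelope under the assumption (not confirmed by this diff) that the trainer is data-parallel and feeds batch_size examples to each GPU:

batch_size, num_gpus = 8, 2
print(batch_size * num_gpus)   # 16 examples per optimizer step, up from 5 * 1 = 5 before the change
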
@@ -244,12 +256,10 @@ def test_pretrained_model():
     train_sum_dataset, test_sum_dataset = preprocess_cnndm_abs()
     processor = AbsSumProcessor(cache_dir=CACHE_PATH)
-    #checkpoint = torch.load(os.path.join(MODEL_PATH, "new_model_step_148000_torch1.4.0.pt"))
-    #checkpoint = torch.load(os.path.join(MODEL_PATH, "bert-base-uncased_step_2900.pt"))
-    checkpoint = torch.load(os.path.join(MODEL_PATH, "summarizer_step20000_with_global_step.pt"))
-    #checkpoint = torch.load(os.path.join(MODEL_PATH, "summarizer_step20000.pt"))
+    checkpoint = torch.load(os.path.join(MODEL_PATH, "new_model_step_148000_torch1.4.0.pt"))
+    #checkpoint = torch.load(os.path.join(MODEL_PATH, "summarizer_step20000_with_global_step.pt"))
     summarizer = AbsSum(
         processor,
         cache_dir=CACHE_PATH,
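
If the checkpoint enabled above was saved from a GPU run, a hedged variant of the load that also works on CPU-only machines is standard torch.load usage; MODEL_PATH and the file name are taken from the diff, and the map_location argument is the only addition:

import os
import torch

checkpoint = torch.load(
    os.path.join(MODEL_PATH, "new_model_step_148000_torch1.4.0.pt"),
    map_location="cpu",   # remap CUDA tensors to CPU so no GPU is needed to inspect the checkpoint
)
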
@@ -289,8 +299,9 @@ def test_pretrained_model():
 
 #test_preprocessing()
+#test_collate()
 #preprocess_cnndm_abs()
-test_train_model()
+#test_train_model()
 #test_pretrained_model()
 
-#if __name__ == "__main__":
-# main()
+if __name__ == "__main__":
+    main()