fix the input for preprocessing
This commit is contained in:
Родитель
9b727e9fee
Коммит
804d88c451
|
@ -10,7 +10,7 @@ import pytest
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
torch.set_printoptions(threshold=5000)
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
from utils_nlp.models.transformers.datasets import (
|
||||
|
@ -32,22 +32,23 @@ TOP_N = 10
|
|||
# @pytest.fixture()
|
||||
def source_data():
|
||||
return [
|
||||
(
|
||||
"Boston, MA welcome to Microsoft/nlp. Welcome to text summarization."
|
||||
"Welcome to Microsoft NERD."
|
||||
[
|
||||
"Boston, MA welcome to Microsoft/nlp",
|
||||
"Welcome to text summarization.",
|
||||
"Welcome to Microsoft NERD.",
|
||||
"Look outside, what a beautiful Charlse River fall view."
|
||||
),
|
||||
("I am just another test case"),
|
||||
("want to test more"),
|
||||
],
|
||||
["I am just another test case"],
|
||||
["want to test more"],
|
||||
]
|
||||
|
||||
|
||||
# @pytest.fixture()
|
||||
def target_data():
|
||||
return [
|
||||
("welcome to microsoft/nlp." "Welcome to text summarization." "Welcome to Microsoft NERD."),
|
||||
("I am just another test summary"),
|
||||
("yest, I agree"),
|
||||
["welcome to microsoft/nlp.", "Welcome to text summarization.", "Welcome to Microsoft NERD."],
|
||||
["I am just another test summary"],
|
||||
["yest, I agree"],
|
||||
]
|
||||
|
||||
|
||||
|
@ -66,6 +67,15 @@ def test_preprocessing():
|
|||
print(batch)
|
||||
print(len(batch.src[0]))
|
||||
|
||||
def test_collate():
|
||||
test_data_path = os.path.join(DATA_PATH, "test_abssum_dataset_full.pt")
|
||||
test_sum_dataset = torch.load(test_data_path)
|
||||
temp = shorten_dataset(test_sum_dataset, top_n=2)
|
||||
processor = AbsSumProcessor()
|
||||
batch = processor.collate(temp, 512, "cuda:0")
|
||||
print(batch.tgt)
|
||||
print(batch.tgt_num_tokens)
|
||||
#print(len(batch.src[0]))
|
||||
|
||||
def shorten_dataset(dataset, top_n=-1):
|
||||
if top_n == -1:
|
||||
|
@ -149,7 +159,7 @@ def main_worker(local_rank, ngpus_per_node, summarizer, args):
|
|||
save_every = -1
|
||||
this_validate = None
|
||||
else:
|
||||
save_every = 100
|
||||
save_every = 400
|
||||
|
||||
#summarizer.model.load_checkpoint(checkpoint['model'])
|
||||
summarizer.fit(
|
||||
|
@ -158,17 +168,17 @@ def main_worker(local_rank, ngpus_per_node, summarizer, args):
|
|||
num_gpus=None,
|
||||
local_rank=local_rank,
|
||||
rank=rank,
|
||||
batch_size=8,
|
||||
batch_size=6,
|
||||
max_steps=50000/world_size,
|
||||
learning_rate_bert=0.002,
|
||||
learning_rate_dec=0.2,
|
||||
learning_rate_bert=0.003,
|
||||
learning_rate_dec=0.3,
|
||||
warmup_steps_bert=20000,
|
||||
warmup_steps_dec=10000,
|
||||
save_every=save_every,
|
||||
report_every=10,
|
||||
validation_function=this_validate,
|
||||
fp16=True,
|
||||
fp16_opt_level="O2",
|
||||
fp16_opt_level="O1",
|
||||
checkpoint=None
|
||||
)
|
||||
if rank == 0 or local_rank == -1:
|
||||
|
@ -192,18 +202,20 @@ def test_train_model():
|
|||
|
||||
train_sum_dataset, test_sum_dataset = preprocess_cnndm_abs()
|
||||
|
||||
#train_sum_dataset = shorten_dataset(train_sum_dataset, top_n=4) ## at lease gradient_accumulation * batch_size long
|
||||
|
||||
def this_validate(class_obj):
|
||||
return validate(class_obj, test_sum_dataset, CACHE_PATH)
|
||||
summarizer.fit(
|
||||
train_sum_dataset,
|
||||
batch_size=5,
|
||||
batch_size=8,
|
||||
max_steps=30000,
|
||||
local_rank=-1,
|
||||
learning_rate_bert=0.002,
|
||||
learning_rate_dec=0.2,
|
||||
warmup_steps_bert=20000,
|
||||
warmup_steps_dec=10000,
|
||||
num_gpus=1,
|
||||
num_gpus=2,
|
||||
report_every=10,
|
||||
save_every=100,
|
||||
validation_function=this_validate,
|
||||
|
@ -244,12 +256,10 @@ def test_pretrained_model():
|
|||
train_sum_dataset, test_sum_dataset = preprocess_cnndm_abs()
|
||||
|
||||
processor = AbsSumProcessor(cache_dir=CACHE_PATH)
|
||||
#checkpoint = torch.load(os.path.join(MODEL_PATH, "new_model_step_148000_torch1.4.0.pt"))
|
||||
#checkpoint = torch.load(os.path.join(MODEL_PATH, "bert-base-uncased_step_2900.pt"))
|
||||
checkpoint = torch.load(os.path.join(MODEL_PATH, "new_model_step_148000_torch1.4.0.pt"))
|
||||
|
||||
checkpoint = torch.load(os.path.join(MODEL_PATH, "summarizer_step20000_with_global_step.pt"))
|
||||
#checkpoint = torch.load(os.path.join(MODEL_PATH, "summarizer_step20000_with_global_step.pt"))
|
||||
|
||||
#checkpoint = torch.load(os.path.join(MODEL_PATH, "summarizer_step20000.pt"))
|
||||
summarizer = AbsSum(
|
||||
processor,
|
||||
cache_dir=CACHE_PATH,
|
||||
|
@ -289,8 +299,9 @@ def test_pretrained_model():
|
|||
|
||||
|
||||
#test_preprocessing()
|
||||
#test_collate()
|
||||
#preprocess_cnndm_abs()
|
||||
test_train_model()
|
||||
#test_train_model()
|
||||
#test_pretrained_model()
|
||||
#if __name__ == "__main__":
|
||||
# main()
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
Загрузка…
Ссылка в новой задаче