testing pretrained bertabs summarization model
This commit is contained in:
Parent 27825cd843
Commit c05df5648f

@@ -0,0 +1,977 @@
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%load_ext autoreload"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%autoreload 2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"I0108 04:39:52.395967 139994149881664 file_utils.py:35] PyTorch version 1.2.0 available.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"#! /usr/bin/python3\n",
|
||||
"import argparse\n",
|
||||
"import logging\n",
|
||||
"import os\n",
|
||||
"import sys\n",
|
||||
"from collections import namedtuple\n",
|
||||
"\n",
|
||||
"import torch\n",
|
||||
"from torch.utils.data import DataLoader, SequentialSampler\n",
|
||||
"from tqdm import tqdm\n",
|
||||
"\n",
|
||||
"#\n",
|
||||
"from transformers import BertTokenizer\n",
|
||||
"#from transformers import BertAbs\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sys.path.insert(0, \"/dadendev/transformers/examples/summarization\")\n",
|
||||
"from modeling_bertabs import BertAbs, build_predictor\n",
|
||||
"\n",
|
||||
"from utils_summarization import (\n",
|
||||
" #SummarizationDataset,\n",
|
||||
" build_mask,\n",
|
||||
" compute_token_type_ids,\n",
|
||||
" encode_for_summarization,\n",
|
||||
" fit_to_block_size,\n",
|
||||
")\n",
|
||||
"from run_summarization import format_summary"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"I0108 04:39:54.246211 139994149881664 tokenization_utils.py:398] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /home/daden/.cache/torch/transformers/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084\n",
|
||||
"I0108 04:39:54.417192 139994149881664 configuration_utils.py:185] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/remi/bertabs-finetuned-cnndm-extractive-abstractive-summarization-config.json from cache at /home/daden/.cache/torch/transformers/7ebb4ac81007d10b400cb6c2968d4c8f1275a3e0cc3bab7f20f81913198b542c.df616398f4c84def6fca83d755543b01cb445db4ddd218d3efeded8ded68332f\n",
|
||||
"I0108 04:39:54.418124 139994149881664 configuration_utils.py:199] Model config {\n",
|
||||
" \"dec_dropout\": 0.2,\n",
|
||||
" \"dec_ff_size\": 2048,\n",
|
||||
" \"dec_heads\": 8,\n",
|
||||
" \"dec_hidden_size\": 768,\n",
|
||||
" \"dec_layers\": 6,\n",
|
||||
" \"enc_dropout\": 0.2,\n",
|
||||
" \"enc_ff_size\": 512,\n",
|
||||
" \"enc_heads\": 8,\n",
|
||||
" \"enc_hidden_size\": 512,\n",
|
||||
" \"enc_layers\": 6,\n",
|
||||
" \"finetuning_task\": null,\n",
|
||||
" \"id2label\": {\n",
|
||||
" \"0\": \"LABEL_0\",\n",
|
||||
" \"1\": \"LABEL_1\"\n",
|
||||
" },\n",
|
||||
" \"is_decoder\": false,\n",
|
||||
" \"label2id\": {\n",
|
||||
" \"LABEL_0\": 0,\n",
|
||||
" \"LABEL_1\": 1\n",
|
||||
" },\n",
|
||||
" \"max_pos\": 512,\n",
|
||||
" \"num_labels\": 2,\n",
|
||||
" \"output_attentions\": false,\n",
|
||||
" \"output_hidden_states\": false,\n",
|
||||
" \"output_past\": true,\n",
|
||||
" \"pruned_heads\": {},\n",
|
||||
" \"torchscript\": false,\n",
|
||||
" \"use_bfloat16\": false,\n",
|
||||
" \"vocab_size\": 30522\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"I0108 04:39:54.558089 139994149881664 modeling_utils.py:406] loading weights file https://s3.amazonaws.com/models.huggingface.co/bert/remi/bertabs-finetuned-cnndm-extractive-abstractive-summarization-pytorch_model.bin from cache at /home/daden/.cache/torch/transformers/6f1af625ee57a9fbf093ef0863fb774fbdae89fa99fea7a213c08ad26f0724c0.ef06f4d767c6fad3c61125520f9dbb0f219834539c0369980ee5ecb9d1ef5542\n",
|
||||
"I0108 04:39:54.718922 139994149881664 configuration_utils.py:185] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json from cache at /home/daden/.cache/torch/transformers/4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.bf3b9ea126d8c0001ee8a1e8b92229871d06d36d8808208cc2449280da87785c\n",
|
||||
"I0108 04:39:54.719776 139994149881664 configuration_utils.py:199] Model config {\n",
|
||||
" \"attention_probs_dropout_prob\": 0.1,\n",
|
||||
" \"finetuning_task\": null,\n",
|
||||
" \"hidden_act\": \"gelu\",\n",
|
||||
" \"hidden_dropout_prob\": 0.1,\n",
|
||||
" \"hidden_size\": 768,\n",
|
||||
" \"id2label\": {\n",
|
||||
" \"0\": \"LABEL_0\",\n",
|
||||
" \"1\": \"LABEL_1\"\n",
|
||||
" },\n",
|
||||
" \"initializer_range\": 0.02,\n",
|
||||
" \"intermediate_size\": 3072,\n",
|
||||
" \"is_decoder\": false,\n",
|
||||
" \"label2id\": {\n",
|
||||
" \"LABEL_0\": 0,\n",
|
||||
" \"LABEL_1\": 1\n",
|
||||
" },\n",
|
||||
" \"layer_norm_eps\": 1e-12,\n",
|
||||
" \"max_position_embeddings\": 512,\n",
|
||||
" \"num_attention_heads\": 12,\n",
|
||||
" \"num_hidden_layers\": 12,\n",
|
||||
" \"num_labels\": 2,\n",
|
||||
" \"output_attentions\": false,\n",
|
||||
" \"output_hidden_states\": false,\n",
|
||||
" \"output_past\": true,\n",
|
||||
" \"pruned_heads\": {},\n",
|
||||
" \"torchscript\": false,\n",
|
||||
" \"type_vocab_size\": 2,\n",
|
||||
" \"use_bfloat16\": false,\n",
|
||||
" \"vocab_size\": 30522\n",
|
||||
"}\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\", do_lower_case=True)\n",
|
||||
"model = BertAbs.from_pretrained(\"bertabs-finetuned-cnndm\")\n",
|
||||
"model.to(\"cuda\")\n",
|
||||
"model.eval()\n",
|
||||
"\n",
|
||||
"symbols = {\n",
|
||||
" \"BOS\": tokenizer.vocab[\"[unused0]\"],\n",
|
||||
" \"EOS\": tokenizer.vocab[\"[unused1]\"],\n",
|
||||
" \"PAD\": tokenizer.vocab[\"[PAD]\"],\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"nlp_path = os.path.abspath(\"../../\")\n",
|
||||
"if nlp_path not in sys.path:\n",
|
||||
" sys.path.insert(0, nlp_path)\n",
|
||||
"from utils_nlp.models.transformers.extractive_summarization import Bunch\n",
|
||||
"args = Bunch({\"block_trigram\":True, \"alpha\": 0.95, \"beam_size\": 5, \"min_length\": 20, \"max_length\": 200})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"predictor = build_predictor(args, tokenizer, symbols, model)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[nltk_data] Downloading package punkt to /home/daden/nltk_data...\n",
|
||||
"[nltk_data] Package punkt is already up-to-date!\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import shutil\n",
|
||||
"import sys\n",
|
||||
"from tempfile import TemporaryDirectory\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"nlp_path = os.path.abspath(\"../../\")\n",
|
||||
"if nlp_path not in sys.path:\n",
|
||||
" sys.path.insert(0, nlp_path)\n",
|
||||
"\n",
|
||||
"from utils_nlp.common.pytorch_utils import get_device\n",
|
||||
"from utils_nlp.dataset.cnndm import CNNDMBertSumProcessedData, CNNDMSummarizationDataset\n",
|
||||
"from utils_nlp.eval.evaluate_summarization import get_rouge\n",
|
||||
"from utils_nlp.models.transformers.extractive_summarization import (\n",
|
||||
" ExtractiveSummarizer,\n",
|
||||
" ExtSumProcessedData,\n",
|
||||
" ExtSumProcessor,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"import scrapbook as sb"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#from utils_nlp.models.transformers.datasets import SummarizationDataset\n",
|
||||
"from utils_nlp.dataset.cnndm import CNNDMAbsSumDataset, CNNDMSummarizationDataset\n",
|
||||
"#def build_data_iterator(args, tokenizer):\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"DATA_PATH = '/tmp/tmpbvxzmv1v'"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"QUICK_RUN = False\n",
|
||||
"# the data path used to save the downloaded data file\n",
|
||||
"DATA_PATH = TemporaryDirectory().name\n",
|
||||
"# The number of lines at the head of data file used for preprocessing. -1 means all the lines.\n",
|
||||
"TOP_N = 4\n",
|
||||
"CHUNK_SIZE=200\n",
|
||||
"if not QUICK_RUN:\n",
|
||||
" TOP_N = -1\n",
|
||||
" CHUNK_SIZE = 2000"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from torch.utils.data import Dataset\n",
|
||||
"class SummarizationDataset(Dataset):\n",
|
||||
" def __init__(self, source, target=None):\n",
|
||||
" self.source = source\n",
|
||||
" self.target = target\n",
|
||||
" def __len__(self):\n",
|
||||
" return len(self.source)\n",
|
||||
" def __getitem__(self, idx):\n",
|
||||
" return self.source[idx], self.target[idx]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 489k/489k [00:07<00:00, 68.9kKB/s] \n",
|
||||
"I0108 04:40:12.525922 139994149881664 utils.py:173] Opening tar file /tmp/tmpbbgje18p/cnndm.tar.gz.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"train_dataset, test_dataset = CNNDMAbsSumDataset(top_n=TOP_N, local_cache_path=DATA_PATH)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"data = list(test_dataset.get_source()), list(test_dataset.get_target())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"11490"
|
||||
]
|
||||
},
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"len(data[1])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"test_sum_dataset = SummarizationDataset(data[0], data[1])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"Batch = namedtuple(\"Batch\", [ \"batch_size\", \"src\", \"segs\", \"mask_src\", \"tgt_str\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def collate(data, tokenizer, block_size, device):\n",
|
||||
" \"\"\" Collate formats the data passed to the data loader.\n",
|
||||
" In particular we tokenize the data batch after batch to avoid keeping them\n",
|
||||
" all in memory. We output the data as a namedtuple to fit the original BertAbs's\n",
|
||||
" API.\n",
|
||||
" \"\"\"\n",
|
||||
" data = [x for x in data if not len(x[1]) == 0] # remove empty_files\n",
|
||||
" #print(data)\n",
|
||||
" #names = [name for name, _, _ in data]\n",
|
||||
" # summaries = [\" \".join(summary_list) for _, _, summary_list in data]\n",
|
||||
" summaries = [\" \".join(summary_list) for _, summary_list in data]\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" encoded_text = [encode_for_summarization(story, summary, tokenizer) for story, summary in data]\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" #\"\"\"\"\"\"\n",
|
||||
" encoded_stories = torch.tensor(\n",
|
||||
" [fit_to_block_size(story, block_size, tokenizer.pad_token_id) for story, _ in encoded_text]\n",
|
||||
" )\n",
|
||||
" encoder_token_type_ids = compute_token_type_ids(encoded_stories, tokenizer.cls_token_id)\n",
|
||||
" encoder_mask = build_mask(encoded_stories, tokenizer.pad_token_id)\n",
|
||||
" #\"\"\"\n",
|
||||
" print(len(encoded_stories))\n",
|
||||
"\n",
|
||||
" batch = Batch(\n",
|
||||
" #document_names=None,\n",
|
||||
" batch_size=len(encoded_stories),\n",
|
||||
" src=encoded_stories.to(device),\n",
|
||||
" segs=encoder_token_type_ids.to(device),\n",
|
||||
" mask_src=encoder_mask.to(device),\n",
|
||||
" tgt_str=summaries,\n",
|
||||
" )\n",
|
||||
" return batch"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def encode_for_summarization(story_lines, summary_lines, tokenizer, max_len=512):\n",
|
||||
" \"\"\" Encode the story and summary lines, and join them\n",
|
||||
" as specified in [1] by using `[SEP] [CLS]` tokens to separate\n",
|
||||
" sentences.\n",
|
||||
" \"\"\"\n",
|
||||
" story_lines_token_ids = [tokenizer.encode(line, max_length=max_len) for line in story_lines]\n",
|
||||
" story_token_ids = [token for sentence in story_lines_token_ids for token in sentence]\n",
|
||||
" summary_lines_token_ids = [tokenizer.encode(line, max_length=max_len) for line in summary_lines]\n",
|
||||
" summary_token_ids = [token for sentence in summary_lines_token_ids for token in sentence]\n",
|
||||
"\n",
|
||||
" return story_token_ids, summary_token_ids"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def build_data_iterator(dataset, tokenizer, batch_size=16, device='cuda'):\n",
|
||||
" \n",
|
||||
" sampler = SequentialSampler(dataset)\n",
|
||||
"\n",
|
||||
" def collate_fn(data):\n",
|
||||
" return collate(data, tokenizer, block_size=512, device=device)\n",
|
||||
"\n",
|
||||
" iterator = DataLoader(dataset, sampler=sampler, batch_size=batch_size, collate_fn=collate_fn,)\n",
|
||||
"\n",
|
||||
" return iterator"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from utils_nlp.common.pytorch_utils import get_device\n",
|
||||
"device, num_gpus = get_device(num_gpus=4, local_rank=-1)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"data_iterator = build_data_iterator(train_sum_dataset, tokenizer, batch_size=64, device=device)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"reference_summaries = []\n",
|
||||
"generated_summaries = []"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
" 0%| | 0/180 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"64\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
" 16%|█▌ | 28/180 [11:25<1:18:29, 30.98s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"64\n"
]
}
|
||||
],
|
||||
"source": [
|
||||
"for batch in tqdm(data_iterator):\n",
|
||||
" \n",
|
||||
" batch_data = predictor.translate_batch(batch)\n",
|
||||
" translations = predictor.from_batch(batch_data)\n",
|
||||
" summaries = [format_summary(t) for t in translations]\n",
|
||||
" #save_summaries(summaries, args.summaries_output_dir, batch.document_names)\n",
|
||||
"\n",
|
||||
" if True:\n",
|
||||
" reference_summaries += batch.tgt_str\n",
|
||||
" generated_summaries += summaries"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"reference_summaries[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"generated_summaries[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
" def _write_list_to_file(list_items, filename):\n",
|
||||
" with open(filename, \"w\") as filehandle:\n",
|
||||
" # for cnt, line in enumerate(filehandle):\n",
|
||||
" for item in list_items:\n",
|
||||
" filehandle.write(\"%s\\n\" % item)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"_write_list_to_file(generated_summaries, \"./generated_summaries\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "python3.6 cm3",
|
||||
"language": "python",
|
||||
"name": "cm3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.6.8"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
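The notebook imports get_rouge from utils_nlp.eval.evaluate_summarization but never calls it; the run above stops after writing the generated summaries to disk. Below is a minimal sketch of the missing scoring step. The get_rouge(predictions, targets, temp_dir) call signature and the temporary-directory handling are assumptions inferred from the import, not behaviour confirmed by this commit.

# Sketch only: assumes get_rouge(predictions, targets, temp_dir) from utils_nlp.
import os
from tempfile import TemporaryDirectory

from utils_nlp.eval.evaluate_summarization import get_rouge


def score_summaries(generated_summaries, reference_summaries):
    # Compute ROUGE between generated and reference summaries.
    temp_dir = TemporaryDirectory().name
    os.makedirs(temp_dir, exist_ok=True)
    return get_rouge(generated_summaries, reference_summaries, temp_dir)


# rouge_scores = score_summaries(generated_summaries, reference_summaries)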
@@ -30,6 +30,71 @@ from utils_nlp.models.transformers.extractive_summarization import get_dataset,


def CNNDMAbsSumDataset(*args, **kwargs):
    """Load the CNN/Daily Mail dataset preprocessed by harvardnlp group."""

    REMAP = {"-lrb-": "(", "-rrb-": ")", "-lcb-": "{", "-rcb-": "}",
             "-lsb-": "[", "-rsb-": "]", "``": '"', "''": '"'}

    def _clean(x):
        return re.sub(
            r"-lrb-|-rrb-|-lcb-|-rcb-|-lsb-|-rsb-|``|''",
            lambda m: REMAP.get(m.group()), x)

    def _remove_ttags(line):
        line = re.sub(r"<t>", "", line)
        # change </t> to <q>
        # pyrouge test requires <q> as sentence splitter
        line = re.sub(r"</t>", "<q>", line)
        return line

    def _target_sentence_tokenization(line):
        return line.split("<q>")

    URLS = ["https://s3.amazonaws.com/opennmt-models/Summary/cnndm.tar.gz"]

    def _setup_datasets(url, top_n=-1, local_cache_path=".data"):
        FILE_NAME = "cnndm.tar.gz"
        maybe_download(url, FILE_NAME, local_cache_path)
        dataset_tar = os.path.join(local_cache_path, FILE_NAME)
        extracted_files = extract_archive(dataset_tar)
        for fname in extracted_files:
            if fname.endswith("train.txt.src"):
                train_source_file = fname
            if fname.endswith("train.txt.tgt.tagged"):
                train_target_file = fname
            if fname.endswith("test.txt.src"):
                test_source_file = fname
            if fname.endswith("test.txt.tgt.tagged"):
                test_target_file = fname

        return (
            SummarizationDataset(
                train_source_file,
                train_target_file,
                [_clean],
                [_clean, _remove_ttags,],
                None,
                top_n,
            ),
            SummarizationDataset(
                test_source_file,
                test_target_file,
                [_clean],
                [_clean, _remove_ttags,],
                None,
                top_n,
            ),
        )

    return _setup_datasets(*((URLS[0],) + args), **kwargs)


def CNNDMSummarizationDataset(*args, **kwargs):
    """Load the CNN/Daily Mail dataset preprocessed by harvardnlp group."""
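For context, the notebook above exercises the new loader roughly as follows; the cache path and top_n value are illustrative rather than prescribed by this diff.

# Illustrative use of CNNDMAbsSumDataset, mirroring the notebook above.
from tempfile import TemporaryDirectory

from utils_nlp.dataset.cnndm import CNNDMAbsSumDataset

DATA_PATH = TemporaryDirectory().name  # cache directory for cnndm.tar.gz
train_dataset, test_dataset = CNNDMAbsSumDataset(top_n=-1, local_cache_path=DATA_PATH)

# Each returned object is a SummarizationDataset; its raw text is reachable
# through the get_source()/get_target() accessors added further below.
sources = list(test_dataset.get_source())
targets = list(test_dataset.get_target())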
@@ -240,7 +240,10 @@ def _preprocess(param):
     sentences, preprocess_pipeline, word_tokenize = param
     for function in preprocess_pipeline:
         sentences = function(sentences)
-    return [word_tokenize(sentence) for sentence in sentences]
+    if word_tokenize:
+        return [word_tokenize(sentence) for sentence in sentences]
+    else:
+        return [sentences]
 
 
 def _create_data_from_iterator(iterator, preprocessing, word_tokenizer):
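The change above makes word-level tokenization optional so the abstractive path can keep raw sentence strings. A standalone copy of the patched function with toy pipeline steps illustrates the two code paths; the toy cleaning and splitting functions are placeholders, not part of the library.

# Standalone copy of the patched _preprocess, for illustration only.
def _preprocess(param):
    sentences, preprocess_pipeline, word_tokenize = param
    for function in preprocess_pipeline:
        sentences = function(sentences)
    if word_tokenize:
        return [word_tokenize(sentence) for sentence in sentences]
    else:
        return [sentences]


def lowercase(text):
    return text.lower()  # toy cleaning step


def split_sentences(text):
    return text.split(" . ")  # toy sentence splitter


document = "Hello there . General Kenobi ."

# Extractive-style call: split into sentences, then word-tokenize each one.
print(_preprocess((document, [lowercase, split_sentences], str.split)))
# -> [['hello', 'there'], ['general', 'kenobi', '.']]

# Abstractive-style call (word_tokenize=None): the cleaned text is returned untokenized.
print(_preprocess((document, [lowercase], None)))
# -> ['hello there . general kenobi .']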
@@ -299,6 +302,9 @@ class SummarizationDataset(IterableDataset):
     def __iter__(self):
         for x in self._source:
             yield x
 
+    def get_source(self):
+        return self._source
+
     def get_target(self):
         return self._target
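get_source() and get_target() let callers materialize the iterable dataset: the notebook above pulls the raw source/target lists out and rewraps them in a small map-style Dataset so a DataLoader can index and batch them. The sketch below mirrors that pattern; the wrapper class name and the stand-in lists are illustrative, while the notebook simply reuses the name SummarizationDataset inline.

# Sketch of how the new accessors are consumed (map-style rewrap for a DataLoader).
from torch.utils.data import DataLoader, Dataset, SequentialSampler


class MapStyleSummarizationDataset(Dataset):
    def __init__(self, source, target=None):
        self.source = source
        self.target = target

    def __len__(self):
        return len(self.source)

    def __getitem__(self, idx):
        return self.source[idx], self.target[idx]


# In the notebook these lists come from dataset.get_source()/dataset.get_target();
# plain lists stand in here so the sketch runs on its own.
sources = ["first test document .", "second test document ."]
targets = ["first summary", "second summary"]

test_sum_dataset = MapStyleSummarizationDataset(sources, targets)
loader = DataLoader(test_sum_dataset, sampler=SequentialSampler(test_sum_dataset), batch_size=2)
print(next(iter(loader)))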