Testing pretrained BertAbs summarization model

This commit is contained in:
Daisy Deng 2020-01-08 04:56:50 +00:00
Parent 27825cd843
Commit c05df5648f
3 changed files with 1049 additions and 1 deletions

View file

@@ -0,0 +1,977 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"I0108 04:39:52.395967 139994149881664 file_utils.py:35] PyTorch version 1.2.0 available.\n"
]
}
],
"source": [
"#! /usr/bin/python3\n",
"import argparse\n",
"import logging\n",
"import os\n",
"import sys\n",
"from collections import namedtuple\n",
"\n",
"import torch\n",
"from torch.utils.data import DataLoader, SequentialSampler\n",
"from tqdm import tqdm\n",
"\n",
"#\n",
"from transformers import BertTokenizer\n",
"#from transformers import BertAbs\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"sys.path.insert(0, \"/dadendev/transformers/examples/summarization\")\n",
"from modeling_bertabs import BertAbs, build_predictor\n",
"\n",
"from utils_summarization import (\n",
" #SummarizationDataset,\n",
" build_mask,\n",
" compute_token_type_ids,\n",
" encode_for_summarization,\n",
" fit_to_block_size,\n",
")\n",
"from run_summarization import format_summary"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"I0108 04:39:54.246211 139994149881664 tokenization_utils.py:398] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /home/daden/.cache/torch/transformers/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084\n",
"I0108 04:39:54.417192 139994149881664 configuration_utils.py:185] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/remi/bertabs-finetuned-cnndm-extractive-abstractive-summarization-config.json from cache at /home/daden/.cache/torch/transformers/7ebb4ac81007d10b400cb6c2968d4c8f1275a3e0cc3bab7f20f81913198b542c.df616398f4c84def6fca83d755543b01cb445db4ddd218d3efeded8ded68332f\n",
"I0108 04:39:54.418124 139994149881664 configuration_utils.py:199] Model config {\n",
" \"dec_dropout\": 0.2,\n",
" \"dec_ff_size\": 2048,\n",
" \"dec_heads\": 8,\n",
" \"dec_hidden_size\": 768,\n",
" \"dec_layers\": 6,\n",
" \"enc_dropout\": 0.2,\n",
" \"enc_ff_size\": 512,\n",
" \"enc_heads\": 8,\n",
" \"enc_hidden_size\": 512,\n",
" \"enc_layers\": 6,\n",
" \"finetuning_task\": null,\n",
" \"id2label\": {\n",
" \"0\": \"LABEL_0\",\n",
" \"1\": \"LABEL_1\"\n",
" },\n",
" \"is_decoder\": false,\n",
" \"label2id\": {\n",
" \"LABEL_0\": 0,\n",
" \"LABEL_1\": 1\n",
" },\n",
" \"max_pos\": 512,\n",
" \"num_labels\": 2,\n",
" \"output_attentions\": false,\n",
" \"output_hidden_states\": false,\n",
" \"output_past\": true,\n",
" \"pruned_heads\": {},\n",
" \"torchscript\": false,\n",
" \"use_bfloat16\": false,\n",
" \"vocab_size\": 30522\n",
"}\n",
"\n",
"I0108 04:39:54.558089 139994149881664 modeling_utils.py:406] loading weights file https://s3.amazonaws.com/models.huggingface.co/bert/remi/bertabs-finetuned-cnndm-extractive-abstractive-summarization-pytorch_model.bin from cache at /home/daden/.cache/torch/transformers/6f1af625ee57a9fbf093ef0863fb774fbdae89fa99fea7a213c08ad26f0724c0.ef06f4d767c6fad3c61125520f9dbb0f219834539c0369980ee5ecb9d1ef5542\n",
"I0108 04:39:54.718922 139994149881664 configuration_utils.py:185] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json from cache at /home/daden/.cache/torch/transformers/4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.bf3b9ea126d8c0001ee8a1e8b92229871d06d36d8808208cc2449280da87785c\n",
"I0108 04:39:54.719776 139994149881664 configuration_utils.py:199] Model config {\n",
" \"attention_probs_dropout_prob\": 0.1,\n",
" \"finetuning_task\": null,\n",
" \"hidden_act\": \"gelu\",\n",
" \"hidden_dropout_prob\": 0.1,\n",
" \"hidden_size\": 768,\n",
" \"id2label\": {\n",
" \"0\": \"LABEL_0\",\n",
" \"1\": \"LABEL_1\"\n",
" },\n",
" \"initializer_range\": 0.02,\n",
" \"intermediate_size\": 3072,\n",
" \"is_decoder\": false,\n",
" \"label2id\": {\n",
" \"LABEL_0\": 0,\n",
" \"LABEL_1\": 1\n",
" },\n",
" \"layer_norm_eps\": 1e-12,\n",
" \"max_position_embeddings\": 512,\n",
" \"num_attention_heads\": 12,\n",
" \"num_hidden_layers\": 12,\n",
" \"num_labels\": 2,\n",
" \"output_attentions\": false,\n",
" \"output_hidden_states\": false,\n",
" \"output_past\": true,\n",
" \"pruned_heads\": {},\n",
" \"torchscript\": false,\n",
" \"type_vocab_size\": 2,\n",
" \"use_bfloat16\": false,\n",
" \"vocab_size\": 30522\n",
"}\n",
"\n"
]
}
],
"source": [
"tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\", do_lower_case=True)\n",
"model = BertAbs.from_pretrained(\"bertabs-finetuned-cnndm\")\n",
"model.to(\"cuda\")\n",
"model.eval()\n",
"\n",
"symbols = {\n",
" \"BOS\": tokenizer.vocab[\"[unused0]\"],\n",
" \"EOS\": tokenizer.vocab[\"[unused1]\"],\n",
" \"PAD\": tokenizer.vocab[\"[PAD]\"],\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"nlp_path = os.path.abspath(\"../../\")\n",
"if nlp_path not in sys.path:\n",
" sys.path.insert(0, nlp_path)\n",
"from utils_nlp.models.transformers.extractive_summarization import Bunch\n",
"args = Bunch({\"block_trigram\":True, \"alpha\": 0.95, \"beam_size\": 5, \"min_length\": 20, \"max_length\": 200})"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"predictor = build_predictor(args, tokenizer, symbols, model)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[nltk_data] Downloading package punkt to /home/daden/nltk_data...\n",
"[nltk_data] Package punkt is already up-to-date!\n"
]
}
],
"source": [
"import os\n",
"import shutil\n",
"import sys\n",
"from tempfile import TemporaryDirectory\n",
"import torch\n",
"\n",
"nlp_path = os.path.abspath(\"../../\")\n",
"if nlp_path not in sys.path:\n",
" sys.path.insert(0, nlp_path)\n",
"\n",
"from utils_nlp.common.pytorch_utils import get_device\n",
"from utils_nlp.dataset.cnndm import CNNDMBertSumProcessedData, CNNDMSummarizationDataset\n",
"from utils_nlp.eval.evaluate_summarization import get_rouge\n",
"from utils_nlp.models.transformers.extractive_summarization import (\n",
" ExtractiveSummarizer,\n",
" ExtSumProcessedData,\n",
" ExtSumProcessor,\n",
")\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import scrapbook as sb"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"#from utils_nlp.models.transformers.datasets import SummarizationDataset\n",
"from utils_nlp.dataset.cnndm import CNNDMAbsSumDataset, CNNDMSummarizationDataset\n",
"#def build_data_iterator(args, tokenizer):\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"DATA_PATH = '/tmp/tmpbvxzmv1v'"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"QUICK_RUN = False\n",
"# the data path used to save the downloaded data file\n",
"DATA_PATH = TemporaryDirectory().name\n",
"# The number of lines at the head of data file used for preprocessing. -1 means all the lines.\n",
"TOP_N = 4\n",
"CHUNK_SIZE=200\n",
"if not QUICK_RUN:\n",
" TOP_N = -1\n",
" CHUNK_SIZE = 2000"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import Dataset\n",
"class SummarizationDataset(Dataset):\n",
" def __init__(self, source, target=None):\n",
" self.source = source\n",
" self.target = target\n",
" def __len__(self):\n",
" return len(self.source)\n",
" def __getitem__(self, idx):\n",
" return self.source[idx], self.target[idx]"
]
},
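{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch (added for clarity, not executed in the original run):\n",
"# the wrapper above only pairs source and target lists, so indexing returns\n",
"# a (source, target) tuple and len() is the number of examples.\n",
"_toy = SummarizationDataset([[\"a first story sentence .\"]], [[\"a first summary sentence .\"]])\n",
"len(_toy), _toy[0]"
]
},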
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 489k/489k [00:07<00:00, 68.9kKB/s] \n",
"I0108 04:40:12.525922 139994149881664 utils.py:173] Opening tar file /tmp/tmpbbgje18p/cnndm.tar.gz.\n"
]
}
],
"source": [
"train_dataset, test_dataset = CNNDMAbsSumDataset(top_n=TOP_N, local_cache_path=DATA_PATH)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"data = list(test_dataset.get_source()), list(test_dataset.get_target())"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"11490"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(data[1])"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"test_sum_dataset = SummarizationDataset(data[0], data[1])"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"Batch = namedtuple(\"Batch\", [ \"batch_size\", \"src\", \"segs\", \"mask_src\", \"tgt_str\"])"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"def collate(data, tokenizer, block_size, device):\n",
" \"\"\" Collate formats the data passed to the data loader.\n",
" In particular we tokenize the data batch after batch to avoid keeping them\n",
" all in memory. We output the data as a namedtuple to fit the original BertAbs's\n",
" API.\n",
" \"\"\"\n",
" data = [x for x in data if not len(x[1]) == 0] # remove empty_files\n",
" #print(data)\n",
" #names = [name for name, _, _ in data]\n",
" # summaries = [\" \".join(summary_list) for _, _, summary_list in data]\n",
" summaries = [\" \".join(summary_list) for _, summary_list in data]\n",
" \n",
"\n",
" encoded_text = [encode_for_summarization(story, summary, tokenizer) for story, summary in data]\n",
" \n",
" \n",
" #\"\"\"\"\"\"\n",
" encoded_stories = torch.tensor(\n",
" [fit_to_block_size(story, block_size, tokenizer.pad_token_id) for story, _ in encoded_text]\n",
" )\n",
" encoder_token_type_ids = compute_token_type_ids(encoded_stories, tokenizer.cls_token_id)\n",
" encoder_mask = build_mask(encoded_stories, tokenizer.pad_token_id)\n",
" #\"\"\"\n",
" print(len(encoded_stories))\n",
"\n",
" batch = Batch(\n",
" #document_names=None,\n",
" batch_size=len(encoded_stories),\n",
" src=encoded_stories.to(device),\n",
" segs=encoder_token_type_ids.to(device),\n",
" mask_src=encoder_mask.to(device),\n",
" tgt_str=summaries,\n",
" )\n",
" return batch"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"def encode_for_summarization(story_lines, summary_lines, tokenizer, max_len=512):\n",
" \"\"\" Encode the story and summary lines, and join them\n",
" as specified in [1] by using `[SEP] [CLS]` tokens to separate\n",
" sentences.\n",
" \"\"\"\n",
" story_lines_token_ids = [tokenizer.encode(line, max_length=max_len) for line in story_lines]\n",
" story_token_ids = [token for sentence in story_lines_token_ids for token in sentence]\n",
" summary_lines_token_ids = [tokenizer.encode(line, max_length=max_len) for line in summary_lines]\n",
" summary_token_ids = [token for sentence in summary_lines_token_ids for token in sentence]\n",
"\n",
" return story_token_ids, summary_token_ids"
]
},
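{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch (added for clarity, not executed in the original run; it assumes\n",
"# `tokenizer` is the BertTokenizer loaded above). Each line is encoded separately, so\n",
"# every sentence gets its own [CLS]/[SEP] pair before the ids are flattened.\n",
"_story_ids, _summary_ids = encode_for_summarization(\n",
"    [\"the first sentence of the story .\", \"the second sentence .\"],\n",
"    [\"a one sentence summary .\"],\n",
"    tokenizer,\n",
")\n",
"len(_story_ids), len(_summary_ids)"
]
},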
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"def build_data_iterator(dataset, tokenizer, batch_size=16, device='cuda'):\n",
" \n",
" sampler = SequentialSampler(dataset)\n",
"\n",
" def collate_fn(data):\n",
" return collate(data, tokenizer, block_size=512, device=device)\n",
"\n",
" iterator = DataLoader(dataset, sampler=sampler, batch_size=batch_size, collate_fn=collate_fn,)\n",
"\n",
" return iterator"
]
},
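{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch (added for clarity, not executed in the original run): build an\n",
"# iterator over two test examples on CPU and inspect the Batch namedtuple that\n",
"# `collate` produces for the predictor.\n",
"_tiny_iterator = build_data_iterator(\n",
"    SummarizationDataset(data[0][:2], data[1][:2]), tokenizer, batch_size=2, device=\"cpu\"\n",
")\n",
"_tiny_batch = next(iter(_tiny_iterator))\n",
"_tiny_batch.batch_size, _tiny_batch.src.shape, _tiny_batch.segs.shape, _tiny_batch.mask_src.shape"
]
},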
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"from utils_nlp.common.pytorch_utils import get_device\n",
"device, num_gpus = get_device(num_gpus=4, local_rank=-1)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"data_iterator = build_data_iterator(train_sum_dataset, tokenizer, batch_size=64, device=device)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"reference_summaries = []\n",
"generated_summaries = []"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
" 0%| | 0/180 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"64\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
" 16%|█▌ | 28/180 [11:25<1:18:29, 30.98s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"64\n"
]
}
],
"source": [
"for batch in tqdm(data_iterator):\n",
" \n",
" batch_data = predictor.translate_batch(batch)\n",
" translations = predictor.from_batch(batch_data)\n",
" summaries = [format_summary(t) for t in translations]\n",
" #save_summaries(summaries, args.summaries_output_dir, batch.document_names)\n",
"\n",
" if True:\n",
" reference_summaries += batch.tgt_str\n",
" generated_summaries += summaries"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"reference_summaries[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"generated_summaries[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
" def _write_list_to_file(list_items, filename):\n",
" with open(filename, \"w\") as filehandle:\n",
" # for cnt, line in enumerate(filehandle):\n",
" for item in list_items:\n",
" filehandle.write(\"%s\\n\" % item)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"_write_list_to_file(generated_summaries, \"./generated_summaries\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "python3.6 cm3",
"language": "python",
"name": "cm3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View file

@@ -30,6 +30,71 @@ from utils_nlp.models.transformers.extractive_summarization import get_dataset,
def CNNDMAbsSumDataset(*args, **kwargs):
    """Load the CNN/Daily Mail dataset preprocessed by harvardnlp group."""
    REMAP = {"-lrb-": "(", "-rrb-": ")", "-lcb-": "{", "-rcb-": "}",
             "-lsb-": "[", "-rsb-": "]", "``": '"', "''": '"'}

    def _clean(x):
        return re.sub(
            r"-lrb-|-rrb-|-lcb-|-rcb-|-lsb-|-rsb-|``|''",
            lambda m: REMAP.get(m.group()), x)

    def _remove_ttags(line):
        line = re.sub(r"<t>", "", line)
        # change </t> to <q>
        # pyrouge test requires <q> as sentence splitter
        line = re.sub(r"</t>", "<q>", line)
        return line

    def _target_sentence_tokenization(line):
        return line.split("<q>")

    URLS = ["https://s3.amazonaws.com/opennmt-models/Summary/cnndm.tar.gz"]

    def _setup_datasets(url, top_n=-1, local_cache_path=".data"):
        FILE_NAME = "cnndm.tar.gz"
        maybe_download(url, FILE_NAME, local_cache_path)
        dataset_tar = os.path.join(local_cache_path, FILE_NAME)
        extracted_files = extract_archive(dataset_tar)
        for fname in extracted_files:
            if fname.endswith("train.txt.src"):
                train_source_file = fname
            if fname.endswith("train.txt.tgt.tagged"):
                train_target_file = fname
            if fname.endswith("test.txt.src"):
                test_source_file = fname
            if fname.endswith("test.txt.tgt.tagged"):
                test_target_file = fname

        return (
            SummarizationDataset(
                train_source_file,
                train_target_file,
                [_clean],
                [_clean, _remove_ttags,],
                None,
                top_n,
            ),
            SummarizationDataset(
                test_source_file,
                test_target_file,
                [_clean],
                [_clean, _remove_ttags,],
                None,
                top_n,
            ),
        )

    return _setup_datasets(*((URLS[0],) + args), **kwargs)


def CNNDMSummarizationDataset(*args, **kwargs):
    """Load the CNN/Daily Mail dataset preprocessed by harvardnlp group."""

View file

@@ -240,7 +240,10 @@ def _preprocess(param):
    sentences, preprocess_pipeline, word_tokenize = param
    for function in preprocess_pipeline:
        sentences = function(sentences)
    if word_tokenize:
        return [word_tokenize(sentence) for sentence in sentences]
    else:
        return [sentences]


def _create_data_from_iterator(iterator, preprocessing, word_tokenizer):

@@ -299,6 +302,9 @@ class SummarizationDataset(IterableDataset):
    def __iter__(self):
        for x in self._source:
            yield x

    def get_source(self):
        return self._source

    def get_target(self):
        return self._target
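
A minimal sketch with hypothetical inputs (not part of this commit) showing the branch added to `_preprocess`: with a word tokenizer each cleaned sentence is split into tokens, while `word_tokenize=None`, which previously raised a `TypeError`, now returns the cleaned sentences as-is.

import re

# hypothetical pipeline: lowercase the line, then split it into sentences
pipeline = [str.lower, lambda text: re.split(r"(?<=\.)\s+", text)]

_preprocess(("Hello there. Bye now.", pipeline, str.split))
# -> [['hello', 'there.'], ['bye', 'now.']]

_preprocess(("Hello there. Bye now.", pipeline, None))
# -> [['hello there.', 'bye now.']]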