notebook works for preprocessed data

This commit is contained in:
Daisy Deng 2020-03-18 14:57:24 +00:00
Родитель 868266117b
Коммит fda15d37ba
3 изменённых файлов: 205 добавлений и 285 удалений

Просмотреть файл

@ -68,7 +68,7 @@
"## Set QUICK_RUN = True to run the notebook on a small subset of data and a smaller number of epochs.\n",
"QUICK_RUN = True\n",
"## Set USE_PREPROCSSED_DATA = True to skip the data preprocessing\n",
"USE_PREPROCSSED_DATA = False"
"USE_PREPROCSSED_DATA = True"
]
},
{
@ -221,6 +221,15 @@
"MODEL_NAME = \"distilbert-base-uncased\""
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"processor = ExtSumProcessor(model_name=MODEL_NAME)"
]
},
{
"cell_type": "markdown",
"metadata": {},
@ -261,7 +270,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
@ -277,7 +286,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"metadata": {
"scrolled": true
},
@ -293,29 +302,11 @@
"Preprocess the data and save the data to disk."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"processor = ExtSumProcessor(model_name=MODEL_NAME)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"run me\n",
"run me\n"
]
}
],
"outputs": [],
"source": [
"\n",
"ext_sum_train = processor.preprocess(train_dataset.get_source(), train_dataset.get_target(), oracle_mode=\"greedy\")\n",
@ -426,13 +417,13 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"if USE_PREPROCSSED_DATA:\n",
" download_path = CNNDMBertSumProcessedData.download(local_path=PROCESSED_DATA_PATH)\n",
" ext_sum_train, ext_sum_test = ExtSumProcessedData().splits(root=download_path)\n",
" ext_sum_train, ext_sum_test = ExtSumProcessedData().splits(root=download_path, train_iterable=True)\n",
" "
]
},
@ -513,36 +504,14 @@
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d837ccfb27684124b30c7340df196daa",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Downloading', max=546, style=ProgressStyle(description_width=…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"outputs": [],
"source": [
"summarizer = ExtractiveSummarizer(processor, MODEL_NAME, ENCODER, MAX_POS_LENGTH, CACHE_DIR)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 21,
"metadata": {
"scrolled": true
},
@ -551,18 +520,17 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Iteration: 0%| | 0/200 [00:00<?, ?it/s]/dadendev/anaconda3/envs/cm3/lib/python3.6/site-packages/torch/nn/parallel/_functions.py:61: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
"Iteration: 0it [00:00, ?it/s]/dadendev/anaconda3/envs/cm3/lib/python3.6/site-packages/torch/nn/parallel/_functions.py:61: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
"Iteration: 100%|██████████| 200/200 [01:08<00:00, 3.63it/s]\n",
"Iteration: 0%| | 0/200 [00:00<?, ?it/s]"
"Iteration: 201it [00:46, 6.07it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"timestamp: 18/03/2020 05:59:09, average loss: 12.685018, time duration: 68.321994,\n",
" number of examples in current reporting: 994, step 100\n",
"timestamp: 18/03/2020 14:43:03, average loss: 12.168449, time duration: 46.319392,\n",
" number of examples in current reporting: 1005, step 100\n",
" out of total 500\n"
]
},
@ -570,16 +538,15 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Iteration: 100%|██████████| 200/200 [00:55<00:00, 3.53it/s]\n",
"Iteration: 0%| | 0/200 [00:00<?, ?it/s]"
"Iteration: 401it [01:28, 1.09s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"timestamp: 18/03/2020 06:00:04, average loss: 12.165642, time duration: 55.452829,\n",
" number of examples in current reporting: 994, step 200\n",
"timestamp: 18/03/2020 14:43:46, average loss: 10.760865, time duration: 42.219547,\n",
" number of examples in current reporting: 1009, step 200\n",
" out of total 500\n"
]
},
@ -587,16 +554,15 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Iteration: 100%|██████████| 200/200 [00:55<00:00, 3.62it/s]\n",
"Iteration: 0%| | 0/200 [00:00<?, ?it/s]"
"Iteration: 601it [02:02, 6.25it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"timestamp: 18/03/2020 06:01:00, average loss: 12.017096, time duration: 55.258818,\n",
" number of examples in current reporting: 994, step 300\n",
"timestamp: 18/03/2020 14:44:19, average loss: 10.480665, time duration: 33.348272,\n",
" number of examples in current reporting: 1008, step 300\n",
" out of total 500\n"
]
},
@ -604,16 +570,15 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Iteration: 100%|██████████| 200/200 [00:54<00:00, 3.69it/s]\n",
"Iteration: 0%| | 0/200 [00:00<?, ?it/s]"
"Iteration: 801it [02:36, 5.23it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"timestamp: 18/03/2020 06:01:54, average loss: 11.911538, time duration: 54.852058,\n",
" number of examples in current reporting: 994, step 400\n",
"timestamp: 18/03/2020 14:44:54, average loss: 10.546902, time duration: 34.703021,\n",
" number of examples in current reporting: 1008, step 400\n",
" out of total 500\n"
]
},
@ -621,15 +586,15 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Iteration: 100%|██████████| 200/200 [00:54<00:00, 3.65it/s]"
"Iteration: 1001it [03:10, 6.07it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"timestamp: 18/03/2020 06:02:49, average loss: 11.925171, time duration: 54.771416,\n",
" number of examples in current reporting: 994, step 500\n",
"timestamp: 18/03/2020 14:45:27, average loss: 10.645204, time duration: 33.513287,\n",
" number of examples in current reporting: 1000, step 500\n",
" out of total 500\n"
]
},
@ -642,6 +607,7 @@
}
],
"source": [
"MAX_STEPS=5e2\n",
"summarizer.fit(\n",
" ext_sum_train,\n",
" num_gpus=NUM_GPUS,\n",
@ -653,12 +619,13 @@
" verbose=True,\n",
" report_every=REPORT_EVERY,\n",
" clip_grad_norm=False,\n",
" use_preprocessed_data=USE_PREPROCSSED_DATA,\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 23,
"metadata": {},
"outputs": [
{
@ -682,7 +649,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 24,
"metadata": {},
"outputs": [
{
@ -691,7 +658,7 @@
"<All keys matched successfully>"
]
},
"execution_count": 23,
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
@ -721,16 +688,16 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"dict_keys(['oracle_ids', 'source', 'target', 'src_txt'])"
"dict_keys(['src', 'labels', 'segs', 'clss', 'src_txt', 'tgt_txt'])"
]
},
"execution_count": 24,
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
@ -741,16 +708,16 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1000"
"11489"
]
},
"execution_count": 25,
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
@ -761,13 +728,13 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"if \"segs\" in ext_sum_test[0]: # preprocessed_data\n",
" source = [i['src_txt'] for i in test_dataset]\n",
" target = [i['tgt_txt'] for i in test_dataset]\n",
" source = [i['src_txt'] for i in ext_sum_test]\n",
" target = [i['tgt_txt'] for i in ext_sum_test]\n",
"else:\n",
" source = []\n",
" target = []\n",
@ -778,16 +745,16 @@
},
{
"cell_type": "code",
"execution_count": 27,
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1000"
"11489"
]
},
"execution_count": 27,
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
@ -798,182 +765,13 @@
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"new_target = [''.join(i) for i in list(target)]"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'marseille prosecutor says `` so far no videos were used in the crash investigation `` despite media reports .journalists at bild and paris match are `` very confident `` the video clip is real , an editor says .andreas lubitz had informed his lufthansa training school of an episode of severe depression , airline says .'"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"new_target[0]"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"# clear cache\n",
"import gc; gc.collect()\n",
"torch.cuda.empty_cache()"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Scoring: 100%|██████████| 11/11 [00:15<00:00, 1.28s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1, 0, 3] [0, 1, 3]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 4] [0, 1, 4]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 3] [0, 1, 3]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 3] [0, 1, 3]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 3] [0, 1, 3]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 3] [0, 1, 3]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 3] [0, 1, 3]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"[1, 0, 2] [0, 1, 2]\n",
"CPU times: user 21.8 s, sys: 4.77 s, total: 26.5 s\n",
"Wall time: 15.4 s\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"%%time\n",
"prediction = summarizer.predict(ext_sum_test, num_gpus=NUM_GPUS, batch_size=96)"
]
},
{
"cell_type": "code",
"execution_count": 32,
@ -982,7 +780,7 @@
{
"data": {
"text/plain": [
"1000"
"'andrew mogni , 20 , from glen ellyn , illinois , had only just arrived for a semester program when the incident happened in january<q>he was flown back to chicago via air on march 20 but he died on sunday<q>initial police reports indicated the fall was an accident but authorities are investigating the possibility that mogni was robbed<q>his cousin claims he was attacked and thrown 40ft from a bridge'"
]
},
"execution_count": 32,
@ -990,13 +788,72 @@
"output_type": "execute_result"
}
],
"source": [
"new_target[0]"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"# clear cache\n",
"import gc; gc.collect()\n",
"torch.cuda.empty_cache()"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Scoring: 100%|██████████| 120/120 [00:23<00:00, 5.53it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 56.8 s, sys: 24.1 s, total: 1min 20s\n",
"Wall time: 24.4 s\n"
]
}
],
"source": [
"%%time\n",
"prediction = summarizer.predict(ext_sum_test, num_gpus=NUM_GPUS, batch_size=96)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"11489"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(prediction)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"execution_count": 36,
"metadata": {
"scrolled": true
},
@ -1005,17 +862,17 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Number of candidates: 1000\n",
"Number of references: 1000\n",
"{'rouge-1': {'f': 0.2913884877129372,\n",
" 'p': 0.22226840793392896,\n",
" 'r': 0.4674816151214348},\n",
" 'rouge-2': {'f': 0.10959347065888589,\n",
" 'p': 0.08214785364002329,\n",
" 'r': 0.18135539304426254},\n",
" 'rouge-l': {'f': 0.19622803575213676,\n",
" 'p': 0.14959018961944281,\n",
" 'r': 0.316327499245177}}\n"
"Number of candidates: 11489\n",
"Number of references: 11489\n",
"{'rouge-1': {'f': 0.4130683230423722,\n",
" 'p': 0.36433051425982743,\n",
" 'r': 0.5163535052789319},\n",
" 'rouge-2': {'f': 0.17487013092190415,\n",
" 'p': 0.15476697073043347,\n",
" 'r': 0.21792905995234627},\n",
" 'rouge-l': {'f': 0.25772796734227815,\n",
" 'p': 0.22731730255675572,\n",
" 'r': 0.32272796615510924}}\n"
]
}
],
@ -1026,36 +883,88 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 46,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"'andrew mogni , 20 , from glen ellyn , illinois , had only just arrived for a semester program when the incident happened in january<q>he was flown back to chicago via air on march 20 but he died on sunday<q>initial police reports indicated the fall was an accident but authorities are investigating the possibility that mogni was robbed<q>his cousin claims he was attacked and thrown 40ft from a bridge'"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"new_target[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 47,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"'The person would not be held for any length of time in an American facility .Although they advised that details could change before the announcement , administration officials said the measure was needed to avert what they fear could be a systemwide outbreak of the coronavirus inside detention facilities along the border .Such an outbreak could spread quickly through the immigrant population and could infect large numbers of Border Patrol agents , leaving the southwestern border defenses weakened , the officials argued .'"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"prediction[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 48,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"'\\n'"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"source[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 49,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"application/scrapbook.scrap.json+json": {
"data": 0.17487013092190415,
"encoder": "json",
"name": "rouge_2_f_score",
"version": 1
}
},
"metadata": {
"scrapbook": {
"data": true,
"display": false,
"name": "rouge_2_f_score"
}
},
"output_type": "display_data"
}
],
"source": [
"# for testing\n",
"sb.glue(\"rouge_2_f_score\", rouge_scores['rouge-2']['f'])"
@ -1070,7 +979,7 @@
},
{
"cell_type": "code",
"execution_count": 56,
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
@ -1086,7 +995,7 @@
},
{
"cell_type": "code",
"execution_count": 57,
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
@ -1102,7 +1011,7 @@
},
{
"cell_type": "code",
"execution_count": 58,
"execution_count": 52,
"metadata": {},
"outputs": [
{
@ -1111,7 +1020,7 @@
"dict_keys(['source', 'src_txt'])"
]
},
"execution_count": 58,
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
@ -1122,14 +1031,14 @@
},
{
"cell_type": "code",
"execution_count": 59,
"execution_count": 53,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Scoring: 100%|██████████| 1/1 [00:00<00:00, 4.09it/s]\n"
"Scoring: 100%|██████████| 1/1 [00:00<00:00, 3.63it/s]\n"
]
}
],
@ -1139,16 +1048,16 @@
},
{
"cell_type": "code",
"execution_count": 60,
"execution_count": 54,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['But under the new rule , set to be announced in the next 48 hours , Border Patrol agents would immediately return anyone to Mexico — without any detainment and without any due process — who attempts to cross the southwestern border between the legal ports of entry .The person would not be held for any length of time in an American facility .Although they advised that details could change before the announcement , administration officials said the measure was needed to avert what they fear could be a systemwide outbreak of the coronavirus inside detention facilities along the border .']"
"['The person would not be held for any length of time in an American facility .Although they advised that details could change before the announcement , administration officials said the measure was needed to avert what they fear could be a systemwide outbreak of the coronavirus inside detention facilities along the border .Such an outbreak could spread quickly through the immigrant population and could infect large numbers of Border Patrol agents , leaving the southwestern border defenses weakened , the officials argued .']"
]
},
"execution_count": 60,
"execution_count": 54,
"metadata": {},
"output_type": "execute_result"
}
@ -1166,7 +1075,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 55,
"metadata": {},
"outputs": [],
"source": [

Просмотреть файл

@ -321,6 +321,11 @@ class Transformer:
break
if fp16 and amp:
self.amp_state_dict = amp.state_dict()
# release GPU memories
self.model.cpu()
torch.cuda.empty_cache()
return global_step, tr_loss / global_step
def predict(self, eval_dataloader, get_inputs, num_gpus, gpu_ids, verbose=True):

Просмотреть файл

@ -705,6 +705,7 @@ class ExtractiveSummarizer(Transformer):
save_every=-1,
world_size=1,
rank=0,
use_preprocessed_data=False,
**kwargs,
):
"""
@ -787,7 +788,7 @@ class ExtractiveSummarizer(Transformer):
)
# batch_size is the number of tokens in a batch
if False: #use_preprocessed_data:
if use_preprocessed_data:
train_dataloader = get_dataloader(
train_dataset.get_stream(),
is_labeled=True,
@ -928,6 +929,11 @@ class ExtractiveSummarizer(Transformer):
top_n=top_n,
)
prediction.extend(temp_pred)
# release GPU memories
self.model.cpu()
torch.cuda.empty_cache()
return prediction
def predict_scores(self, test_dataloader, num_gpus=1, gpu_ids=None, verbose=True):