Mirror of https://github.com/microsoft/HiTab.git
update code for data2text
This commit is contained in:
Parent
0b7cff3fe0
Commit
7cd255becc
@@ -1,2 +1,3 @@
.DS_STORE
.idea/
**/__pycache__/**

@@ -0,0 +1,46 @@
# Data-to-Text Generation

We explore four baseline models that generate meaningful text from the hierarchical tables in HiTab.
Three of them are transformer-based models: T5, BART, and BERT-to-BERT. The fourth is a Pointer-Generator Network built on an LSTM architecture.

## [0] Preliminaries
To start with, make sure to install the following requirements:
```
pip install openpyxl
pip install datasets
pip install transformers
```

## [1] Data Pre-processing
Read in `train_samples.jsonl`, `dev_samples.jsonl`, and `test_samples.jsonl` from the `./data/` directory.

Each sample is processed with: (1) its highlighted/linked table cells, and (2) optionally, additional operations and answer(s).
- The generation `target` label is the annotated `sub_sentence`.
- To create a serialized table input, we: (1) find all linked entity/quantity cells, (2) find all of their ascendants, then linearize their cell contents in a top-down, left-to-right order. If extra operational information is required, the answer formula and answer string are appended to the `source` as the final model input (see the sketch below).

This process creates source-target pairs for the train/dev/test sets.
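Below is a minimal sketch of the linearization idea, for illustration only: the title and cell texts are made up, and the actual logic lives in `do_preprocess.py`.

```python
# Highlighted cells keyed by (row, col): linked entity/quantity cells plus
# their ascendant header cells.
highlight_cells = {
    (0, 1): "2017 actual",   # top header cell
    (18, 0): "total",        # left header cell
    (18, 1): "154983",       # linked data cell
}

# Linearize top-down, left-to-right, and prepend the (hypothetical) table title.
ordered = [highlight_cells[c] for c in sorted(highlight_cells)]
source = "<title> fy 2017 r&d budget " + "<cell> " + " | ".join(ordered)
print(source)
# <title> fy 2017 r&d budget <cell> 2017 actual | total | 154983
```

By default the real pipeline further splits the highlighted cells into `<top>`/`<left>`/`<corner>`/`<data>` fields instead of a single `<cell>` field (see `--no_split_fields` in `do_preprocess.py`).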
To perform data pre-processing for the **cell highlight** setting, simply run:
```bash
python do_preprocess.py
```
To enable the **cell & calculation** setting, specify the additional argument:
```bash
python do_preprocess.py --add_aggr
```
Both will load the data from the `hitab/data/` directory and generate a processed version in `hitab/data2text/data/`.
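Each line of the processed `.jsonl` files is one JSON record with the fields returned by `prepare_model_input` in `do_preprocess.py` (the values below are made up for illustration):

```python
# Shape of one processed record; `table_parent` holds (attribute, value) pairs
# used later by the PARENT metric (serialized as lists in the JSON files).
record = {
    "source": "<title> ... <top> ... | ... <left> ... | ... <corner> ... <data> ...",
    "target": "the annotated sub_sentence used as the generation target",
    "table_parent": [["title", "..."], ["cell", "..."]],
}
```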

Note that the input samples require another layer of tokenization, using `hitab/data2text/experiment/pointer_generator/parse_sample.py`.

## [2] Experiment: Training and Evaluation

The `experiment` directory contains the code for training (`train_d2t.py`) and evaluation (`eval_d2t.py`).
T5, BART, and BERT-to-BERT directly call the training process from the installed [`transformers`](https://github.com/huggingface/transformers) library.
The Pointer-Generator Network (PGN) requires additional code modules, located in the `pointer_generator` directory.

To run the training pipeline, taking BART as an example:
```bash
python run_experiment.py --expr_name bart --do_train --do_eval --do_test
```
Set the `expr_name` argument to one of t5/bart/b2b/pgn to explore the different models.

@@ -0,0 +1,322 @@
|
|||
"""Preprocess train/dev/test data into source/target pairs. """
|
||||
|
||||
import os
|
||||
import json
|
||||
import argparse
|
||||
from typing import Any, Dict, Iterable, List, Tuple
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# %% table info
|
||||
|
||||
def load_table(table_id: str) -> Dict:
|
||||
table_path = os.path.join(args.dataset_dir, args.table_subdir, args.table_type, f"{table_id}.json")
|
||||
with open(table_path, 'r') as fr:
|
||||
table = json.load(fr)
|
||||
return table
|
||||
|
||||
def get_ascendant_cells(cell_coords: Tuple[int, int], table: Dict, field: str) -> Dict[Tuple, str]:
|
||||
"""From a cell coordinate, find all its ascendant cells.
|
||||
args:
|
||||
cell_coords: header cell (1, 1), or data cell (17, 4)
|
||||
rets:
|
||||
asc_dict: {(0, 2): 'cell_str2', ...}
|
||||
"""
|
||||
asc_cells = set([cell_coords])
|
||||
|
||||
# find leaf coords
|
||||
if field == 'top_root':
|
||||
base_col_idx = cell_coords[1]
|
||||
|
||||
n = table['top_header_rows_num']
|
||||
if cell_coords[0] >= n:
|
||||
base_row_idx = n - 1
|
||||
else:
|
||||
base_row_idx = cell_coords[0]
|
||||
elif field == 'left_root':
|
||||
base_row_idx = cell_coords[0]
|
||||
|
||||
n = table['left_header_columns_num']
|
||||
if cell_coords[1] >= n:
|
||||
base_col_idx = n - 1
|
||||
else:
|
||||
base_col_idx = cell_coords[1]
|
||||
else:
|
||||
base_row_idx, base_col_idx = cell_coords
|
||||
|
||||
dfs(table[field], (base_row_idx, base_col_idx), asc_cells)
|
||||
|
||||
asc_dict = {}
|
||||
for coords in asc_cells:
|
||||
if max(coords) > -1 and coords[0] < len(table['texts']) and coords[1] < len(table['texts'][coords[0]]):
|
||||
asc_dict[coords] = table['texts'][coords[0]][coords[1]]
|
||||
|
||||
return asc_dict
|
||||
|
||||
def dfs(node: Dict, ref_coords: Tuple[int, int], ascendants: List[Tuple[int, int]]) -> bool:
|
||||
"""Searching from the (current) node.
|
||||
If node coordinates match, return True to propagate back to ascendants.
|
||||
Else: continue to children if have any. Otherwise would terminate the path.
|
||||
"""
|
||||
if (node['row_index'] == ref_coords[0]) and (node['column_index'] == ref_coords[1]):
|
||||
ascendants.add(ref_coords)
|
||||
return True
|
||||
for child_node in node['children']:
|
||||
if dfs(child_node, ref_coords, ascendants):
|
||||
r, c = node['row_index'], node['column_index']
|
||||
ascendants.add( (r, c) )
|
||||
return True
|
||||
return False # no 'children' or not any that matches
|
||||
|
||||
|
||||
# table list for parent (metric) evaluation
|
||||
|
||||
def clean_text(text: str) -> str:
|
||||
"""Only has single blankspace as delimiters."""
|
||||
parts = text.split()
|
||||
parts = [p for part in parts for p in part.split('\t')]
|
||||
parts = [p for part in parts for p in part.split('\n')]
|
||||
cleaned_text = ' '.join(parts)
|
||||
return cleaned_text
|
||||
|
||||
def get_tuple(attr: str, text: str) -> Tuple[str, str]:
|
||||
"""Return table-parent entry: attr|||value """
|
||||
raw_value = clean_text(text)
|
||||
value = raw_value.replace('|', '-')
|
||||
return (attr, value)
|
||||
|
||||
def get_table_parent_list(linked_cells: Dict, table: Dict) -> List:
|
||||
"""Return a list of tuples as required by the PARENT metric.
|
||||
args:
|
||||
linked_cells: {'corner', 'top', 'left', 'data'}
|
||||
rets:
|
||||
*table_parent_array: List[Tuple(attribute, value)]
|
||||
table_parent_str: '\t'-separated
|
||||
"""
|
||||
table_parent_array = []
|
||||
|
||||
title_tuple = get_tuple('title', table['title'])
|
||||
table_parent_array.append(title_tuple)
|
||||
|
||||
for coords, cellstr in linked_cells.items():
|
||||
cell_tuple = get_tuple(attr='cell', text=str(cellstr))
|
||||
table_parent_array.append(cell_tuple)
|
||||
|
||||
return table_parent_array
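# e.g. (illustrative values):
#   [('title', 'fy 2017 r&d budget'), ('cell', '2017 actual'), ('cell', '154983.0')]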
|
||||
|
||||
|
||||
|
||||
# %% iterate
|
||||
|
||||
def iterate_entity_link_by_field(entity_link: Dict, field: str, return_text: bool = False) -> Iterable[Tuple]:
|
||||
"""Iterate the cells in the `entity_link` field.
|
||||
args:
|
||||
entity_link: {'
|
||||
'top': {'the fy 2017 r&d budget': {'(0, 1)': '2017 actual'}},
|
||||
'left': {'pre-production development activities': {'(18, 0)': 'total'}},
|
||||
'top_left_corner': {}
|
||||
}
|
||||
rets:
|
||||
Iterate(cell_coords): [(0,1), ...]
|
||||
"""
|
||||
field_links = entity_link[field]
|
||||
for text_span, ref_cells in field_links.items():
|
||||
for cell_coords, cell_text in ref_cells.items():
|
||||
cell_coords = eval(cell_coords)
|
||||
int_cell_coords = (int(cell_coords[0]), int(cell_coords[1]))
|
||||
if return_text: yield {int_cell_coords: cell_text}
|
||||
else: yield int_cell_coords
|
||||
|
||||
def iterate_entity_link(entity_link: Dict, return_text: bool = False) -> Iterable[Tuple]:
|
||||
"""Iterate the cells in the `entity_link` field.
|
||||
args:
|
||||
entity_link: {'
|
||||
'top': {'the fy 2017 r&d budget': {'(0, 1)': '2017 actual'}},
|
||||
'left': {'pre-production development activities': {'(18, 0)': 'total'}},
|
||||
'top_left_corner': {}
|
||||
}
|
||||
rets:
|
||||
Iterate(cell_coords): [(0,1), ...]
|
||||
"""
|
||||
for field in entity_link.keys(): # ['top', 'left', 'top_left_corner']
|
||||
for item in iterate_entity_link_by_field(entity_link, field, return_text):
|
||||
yield item
|
||||
|
||||
|
||||
def iterate_quantity_link(quantity_link: Dict, return_text: bool = True) -> Iterable[Tuple]:
|
||||
"""Iterate the cells in the `quantity_link` field.
|
||||
args:
|
||||
quantity_link: {'
|
||||
''125.3 billion': {'(17, 1)': 125289.0},
|
||||
'[ANSWER]': {'(18, 1)': 154983.0}
|
||||
}
|
||||
rets:
|
||||
Iterate(cell_coords): [(17,1), ...]
|
||||
"""
|
||||
for text_span, ref_cells in quantity_link.items():
|
||||
for cell_coords, cell_text in ref_cells.items():
|
||||
cell_coords = eval(cell_coords)
|
||||
int_cell_coords = (int(cell_coords[0]), int(cell_coords[1]))
|
||||
if return_text: yield {int_cell_coords: cell_text}
|
||||
else: yield int_cell_coords
|
||||
|
||||
|
||||
def iterate_cells_coords(highlighted_cells: Dict) -> List[Tuple[int, int]]:
|
||||
cell_coords = list(highlighted_cells.keys())
|
||||
return sorted(cell_coords, key=lambda x: (x[0], x[1]))
|
||||
|
||||
|
||||
|
||||
# %% cell string serialization
|
||||
|
||||
def join_cells(cell_strings: List[Any], cell_delimiter: str = '|') -> str:
|
||||
return f" {cell_delimiter} ".join([str(cs) for cs in cell_strings])
|
||||
|
||||
|
||||
def join_aggrs(aggregation: List[str], answer: List[Any]) -> str:
|
||||
return f"{' '.join(aggregation)} {' '.join([str(a) for a in answer])}"
|
||||
|
||||
|
||||
def add_tag(text: str, tag: str, do_head: bool = True, do_tail: bool = False):
|
||||
"""Add field tags to the text."""
|
||||
if do_head == True: prefix = f'{tag} '
|
||||
else: prefix = ''
|
||||
|
||||
if do_tail == True: suffix = f' {tag}'
|
||||
else: suffix = ''
|
||||
|
||||
return f'{prefix}{text}{suffix}'
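# e.g. add_tag("2017 actual | total", "<cell>") -> "<cell> 2017 actual | total"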
|
||||
|
||||
|
||||
|
||||
# %% main pipeline
|
||||
|
||||
def prepare_model_input(sample: Dict) -> Dict:
|
||||
table = load_table(sample['table_id'])
|
||||
|
||||
source_texts = []
|
||||
source_texts.append( add_tag(table['title'], '<title>') )
|
||||
|
||||
linked_cells = sample['linked_cells']
|
||||
|
||||
highlight_cells = {}
|
||||
if args.no_asc:
|
||||
for ent_cell_dict in iterate_entity_link_by_field(linked_cells['entity_link'], 'top', return_text=True):
|
||||
highlight_cells.update(ent_cell_dict)
|
||||
for ent_cell_dict in iterate_entity_link_by_field(linked_cells['entity_link'], 'left', return_text=True):
|
||||
highlight_cells.update(ent_cell_dict)
|
||||
else:
|
||||
for ent_cell_coords in iterate_entity_link_by_field(linked_cells['entity_link'], 'top', return_text=False):
|
||||
ent_cell_dict = get_ascendant_cells(ent_cell_coords, table, 'top_root')
|
||||
highlight_cells.update(ent_cell_dict)
|
||||
for ent_cell_coords in iterate_entity_link_by_field(linked_cells['entity_link'], 'left', return_text=False):
|
||||
ent_cell_dict = get_ascendant_cells(ent_cell_coords, table, 'left_root')
|
||||
highlight_cells.update(ent_cell_dict)
|
||||
|
||||
|
||||
for ent_cell_dict in iterate_entity_link_by_field(linked_cells['entity_link'], 'top_left_corner', return_text=True):
|
||||
highlight_cells.update(ent_cell_dict)
|
||||
|
||||
if args.no_asc:
|
||||
for qtt_cell_dict in iterate_quantity_link(linked_cells['quantity_link']):
|
||||
highlight_cells.update(qtt_cell_dict)
|
||||
else:
|
||||
for qtt_cell_coords in iterate_quantity_link(linked_cells['quantity_link'], return_text=False):
|
||||
top_cell_dict = get_ascendant_cells(qtt_cell_coords, table, 'top_root')
|
||||
highlight_cells.update(top_cell_dict)
|
||||
left_cell_dict = get_ascendant_cells(qtt_cell_coords, table, 'left_root')
|
||||
highlight_cells.update(left_cell_dict)
|
||||
|
||||
if args.no_split_fields:
|
||||
cell_strings = [highlight_cells[k] for k in iterate_cells_coords(highlight_cells)]
|
||||
source_texts.append( add_tag(join_cells(cell_strings), '<cell>') )
|
||||
else:
|
||||
top_header_rows_num = table['top_header_rows_num']
|
||||
left_header_columns_num = table['left_header_columns_num']
|
||||
top_strings, left_strings, corner_strings, data_strings = [], [], [], []
|
||||
for coords in iterate_cells_coords(highlight_cells):
|
||||
cell_str = highlight_cells[coords]
|
||||
r, c = coords
|
||||
if (r < top_header_rows_num) and (c < left_header_columns_num):
|
||||
corner_strings.append(cell_str)
|
||||
elif (r < top_header_rows_num):
|
||||
top_strings.append(cell_str)
|
||||
elif (c < left_header_columns_num):
|
||||
left_strings.append(cell_str)
|
||||
else:
|
||||
data_strings.append(cell_str)
|
||||
|
||||
source_texts.append( add_tag(join_cells(top_strings), '<top>') )
|
||||
source_texts.append( add_tag(join_cells(left_strings), '<left>') )
|
||||
source_texts.append( add_tag(join_cells(corner_strings), '<corner>') )
|
||||
source_texts.append( add_tag(join_cells(data_strings), '<data>') )
|
||||
|
||||
if args.add_aggr:
|
||||
aggr_str = join_aggrs(sample['aggregation'], sample['answer'])
|
||||
source_texts.append( add_tag(aggr_str, '<agg>') )
|
||||
|
||||
table_list = get_table_parent_list(highlight_cells, table)
|
||||
|
||||
return {
|
||||
'source': ' '.join(source_texts),
|
||||
'target': sample['sub_sentence'],
|
||||
'table_parent': table_list,
|
||||
}
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
for in_path, out_path in zip(args.input_paths, args.output_paths):
|
||||
print(f"from [{in_path}] >> to [{out_path}]")
|
||||
|
||||
with open(in_path, 'r') as fr:
|
||||
dataset = [json.loads(l.strip()) for l in fr]
|
||||
if args.test_topk:
|
||||
dataset = dataset[: args.test_topk]
|
||||
print(f"collected {len(dataset)} samples")
|
||||
|
||||
fw = open(out_path, 'w')
|
||||
for idx, sample in enumerate(dataset):
|
||||
if (idx + 1) % args.logging_steps == 0:
|
||||
print(f"finished processing {idx + 1} samples")
|
||||
|
||||
result = prepare_model_input(sample)
|
||||
fw.write(f"{json.dumps(result)}\n")
|
||||
|
||||
fw.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('--dataset_dir', type=str, default='../data')
|
||||
parser.add_argument('--file_names', type=str, nargs='+',
|
||||
default=['test_samples.jsonl', 'dev_samples.jsonl', 'train_samples.jsonl'])
|
||||
parser.add_argument('--output_dir', type=str, default='data')
|
||||
|
||||
parser.add_argument('--table_subdir', type=str, default='tables')
|
||||
parser.add_argument('--table_type',type=str, default='hmt', choices=['hmt', 'raw'],
|
||||
help='Use `raw` if including ascendant cells, otherwise use `hmt`.')
|
||||
|
||||
parser.add_argument('--no_asc', action='store_true')
|
||||
parser.add_argument('--add_aggr', action='store_true')
|
||||
parser.add_argument('--no_split_fields', action='store_true')
|
||||
|
||||
parser.add_argument('--logging_steps', type=int, default=500)
|
||||
parser.add_argument('--test_topk', type=int, default=None)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
args.input_paths = [os.path.join(args.dataset_dir, fn) for fn in args.file_names]
|
||||
args.output_paths = [os.path.join(args.output_dir, fn) for fn in args.file_names]
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
if args.no_asc: canonical_table_type = 'hmt'
|
||||
else: canonical_table_type = 'raw'
|
||||
if args.table_type != canonical_table_type:
|
||||
logging.info(f"Should use `{canonical_table_type}` version of data. ")
|
||||
args.table_type = canonical_table_type
|
||||
|
||||
main()
|
|
@@ -0,0 +1,83 @@
|
|||
"""Test with a specified evaluation metrics. """
|
||||
|
||||
|
||||
import os
|
||||
from .utils import prepare_tokenizer, get_testset, ModelTestDict
|
||||
from .pointer_generator import BeamSearch
|
||||
from .evaluation import EvalDict, DecodeDict
|
||||
|
||||
|
||||
|
||||
def run_test(args):
|
||||
"""Test a fine-tuned model. Using huggingface/transformers.
|
||||
Load the test set, prepare tokenizer, load tuned models.
|
||||
Then perform evaluation or decoding.
|
||||
"""
|
||||
testset = get_testset(data_files=args.test_outpath)
|
||||
tokenizer = prepare_tokenizer(name=args.tokenizer_name)
|
||||
|
||||
model = ModelTestDict[args.expr_name](
|
||||
run_dir=args.run_dir,
|
||||
path=args.model_path,
|
||||
name=args.model_name,
|
||||
device=args.device,
|
||||
)
|
||||
if args.do_test:
|
||||
for metric in args.metrics:
|
||||
print(f'Start evaluation with metrics [{metric}]')
|
||||
EvalDict[metric](args, testset, tokenizer, model)
|
||||
if args.do_decode:
|
||||
args.test_decode_path = os.path.join(args.run_dir, args.test_decode_name)
|
||||
DecodeDict[args.metrics[0]](args, testset, tokenizer, model)
|
||||
|
||||
|
||||
|
||||
# pointer generator network
|
||||
|
||||
def find_best_pgn_model_index(run_dir, main_metric_key='bleu-4'):
|
||||
"""Find the best model at testing. """
|
||||
detailed_run_dir = os.path.join(run_dir, 'train', 'models')
|
||||
decode_dirs = os.listdir(detailed_run_dir)
|
||||
decode_metrics = []
|
||||
for dd in decode_dirs:
|
||||
mfile = os.path.join(run_dir, dd, 'metrics')
|
||||
ckpt_metrics = {}
|
||||
with open(mfile, 'r') as fr:
|
||||
for line in fr:
|
||||
mkey, mval = line.strip().split('\t')
|
||||
ckpt_metrics[mkey] = float(mval)
|
||||
decode_metrics.append(ckpt_metrics)
|
||||
|
||||
best_ckpt_idx = -1
|
||||
best_ckpt_mval = 0.0
|
||||
for idx, mdict in enumerate(decode_metrics):
|
||||
mval = mdict[main_metric_key]
|
||||
if mval > best_ckpt_mval:
|
||||
best_ckpt_mval = mval
|
||||
best_ckpt_idx = idx
|
||||
return best_ckpt_idx
|
||||
|
||||
|
||||
def run_test_pgn(args):
|
||||
try:
|
||||
# best_ckpt_idx = find_best_pgn_model_index(args.run_dir)
|
||||
best_ckpt_idx = 99
|
||||
best_ckpt_path = os.path.join(args.run_dir, 'train', 'models', f'model_{best_ckpt_idx}')
|
||||
except:
|
||||
best_ckpt_path = None
|
||||
print(f'<<< Perform the Final Test ... (use model [{best_ckpt_path}]) >>>')
|
||||
tester = BeamSearch(args, best_ckpt_path, args.decode_data_path)
|
||||
tester.run(args.logging_steps)
|
||||
# tester.eval_parent(args.logging_steps)
|
||||
print(f'<<< Finished the Final Test ! >>>')
|
||||
|
||||
|
||||
|
||||
# collection
|
||||
|
||||
TestFunctionDict = {
|
||||
't5': run_test,
|
||||
'bart': run_test,
|
||||
'b2b': run_test,
|
||||
'pg': run_test_pgn,
|
||||
}
|
|
@@ -0,0 +1,14 @@
"""Quick imports of Evaluation Modules. """

from .eval import eval_with_bleu, eval_with_parent

EvalDict = {
    'bleu': eval_with_bleu,
    'parent': eval_with_parent,
}

from .decode import decode_with_bleu, decode_with_parent

DecodeDict = {
    'bleu': decode_with_bleu,
    'parent': decode_with_parent,
}

@@ -0,0 +1,89 @@
|
|||
"""Decode. """
|
||||
|
||||
from .utils import (
|
||||
beam_generate,
|
||||
rank_prediction_set_by_bleu,
|
||||
select_prediction_set_by_parent,
|
||||
)
|
||||
from ..utils import parent_scorer
|
||||
|
||||
|
||||
def decode_with_bleu(args, testset, tokenizer, model):
|
||||
"""Decode testset and write out, when BLEU metrics is specified. """
|
||||
|
||||
raw_predictions = [
|
||||
beam_generate(sample, tokenizer, model, args)
|
||||
for sample in testset
|
||||
]
|
||||
|
||||
references = [
|
||||
[tokenizer.tokenize(sample['target'])]
|
||||
for sample in testset
|
||||
]
|
||||
|
||||
ranked_predictions = rank_prediction_set_by_bleu(
|
||||
raw_predictions, references)
|
||||
|
||||
with open(args.test_decode_path, 'w') as fw:
|
||||
for idx, (pred_list, ref) in enumerate(zip(ranked_predictions, references)):
|
||||
fw.write(f"#{idx}\n")
|
||||
for ii, psent, pscore in pred_list:
|
||||
fw.write(f'[{ii}: {pscore:.4f}] {psent}\n')
|
||||
fw.write(f'{ref[0]}\n\n')
|
||||
print(f'Wrote {len(ranked_predictions)} prediction & reference instances into target file: [{args.test_decode_path}]')
|
||||
|
||||
return
|
||||
|
||||
|
||||
def decode_with_parent(args, testset, tokenizer, model):
|
||||
"""Do evaluation on the testset, when BLEU metrics is specified. """
|
||||
|
||||
raw_predictions = [ beam_generate(sample, tokenizer, model, args)
|
||||
for sample in testset]
|
||||
references = [ [tokenizer.tokenize(sample['target'])]
|
||||
for sample in testset]
|
||||
tokenized_tables = []
|
||||
for sample in testset:
|
||||
raw_table_parent = sample['table_parent']
|
||||
tokenized_table_parent = []
|
||||
for attr, value in raw_table_parent:
|
||||
value_tokens = tokenizer.tokenize(value)
|
||||
tokenized_table_parent.append( ([attr], value_tokens) )
|
||||
tokenized_tables.append(tokenized_table_parent)
|
||||
|
||||
pred_tokens_dict = {}
|
||||
for idx in range(args.num_return_sequences):
|
||||
pred_tokens_dict[idx] = [sample[idx]['tokens_clear'] for sample in raw_predictions]
|
||||
|
||||
for idx, predictions in pred_tokens_dict.items():
|
||||
(idx_p, idx_r, idx_f1, idx_all_f1) = parent_scorer(
|
||||
predictions=predictions,
|
||||
references=references,
|
||||
tables=tokenized_tables,
|
||||
return_dict=False,
|
||||
)
|
||||
print(f"Idx#{idx} - PARENT: {idx_p:.3f}, {idx_r:.3f}, {idx_f1:.3f}")
|
||||
|
||||
best_predictions = select_prediction_set_by_parent(
|
||||
raw_predictions, references, tokenized_tables)
|
||||
(avg_p, avg_r, avg_f, all_f) = parent_scorer(
|
||||
predictions=best_predictions,
|
||||
references=references,
|
||||
tables=tokenized_tables,
|
||||
return_dict=False
|
||||
)
|
||||
print(f"BEST PARENT: {avg_p: .3f}, {avg_r:.3f}, {avg_f:.3f}")
|
||||
|
||||
with open(args.test_decode_path, 'w') as fw:
|
||||
for idx, (pred, ref, tab) in enumerate(zip(best_predictions, references, tokenized_tables)):
|
||||
sample_parent = parent_scorer(
|
||||
predictions=[pred],
|
||||
references=[ref],
|
||||
tables=[tab],
|
||||
return_dict=True
|
||||
)
|
||||
fw.write(f"#{idx} BLEU: [{sample_parent['average_f1']:.4f}]\n")
|
||||
fw.write(f'{pred}\n{ref[0]}\n\n')
|
||||
print(f'Wrote {len(best_predictions)} prediction & reference pairs into target file: [{args.test_decode_path}]')
|
||||
|
||||
return
|
|
@@ -0,0 +1,86 @@
|
|||
"""Evaluation. """
|
||||
|
||||
from .utils import (
|
||||
beam_generate,
|
||||
select_prediction_set_by_bleu,
|
||||
select_prediction_set_by_parent,
|
||||
)
|
||||
from ..utils import bleu_scorer, parent_scorer
|
||||
|
||||
|
||||
def eval_with_bleu(args, testset, tokenizer, model):
|
||||
"""Do evaluation on the testset, when BLEU metrics is specified. """
|
||||
|
||||
raw_predictions = [
|
||||
beam_generate(sample, tokenizer, model, args)
|
||||
for sample in testset
|
||||
]
|
||||
|
||||
references = [
|
||||
[tokenizer.tokenize(sample['target'])]
|
||||
for sample in testset
|
||||
]
|
||||
|
||||
pred_tokens_dict = {}
|
||||
for idx in range(args.num_return_sequences):
|
||||
pred_tokens_dict[idx] = [sample[idx]['tokens_clear'] for sample in raw_predictions]
|
||||
|
||||
for idx, predictions in pred_tokens_dict.items():
|
||||
idx_results = bleu_scorer.compute(
|
||||
predictions=predictions,
|
||||
references=references,
|
||||
)
|
||||
print(f"Idx#{idx} - BLEU: {idx_results['bleu']: .3f}")
|
||||
|
||||
best_predictions = select_prediction_set_by_bleu(
|
||||
raw_predictions, references, bleu_scorer)
|
||||
best_results = bleu_scorer.compute(
|
||||
predictions=best_predictions,
|
||||
references=references
|
||||
)
|
||||
print(f"BEST BLEU: {best_results['bleu']: .3f}")
|
||||
|
||||
return
|
||||
|
||||
|
||||
|
||||
def eval_with_parent(args, testset, tokenizer, model):
|
||||
"""Do evaluation on the testset, when BLEU metrics is specified. """
|
||||
|
||||
raw_predictions = [ beam_generate(sample, tokenizer, model, args)
|
||||
for sample in testset]
|
||||
references = [ [tokenizer.tokenize(sample['target'])]
|
||||
for sample in testset]
|
||||
tokenized_tables = []
|
||||
for sample in testset:
|
||||
raw_table_parent = sample['table_parent']
|
||||
tokenized_table_parent = []
|
||||
for attr, value in raw_table_parent:
|
||||
value_tokens = tokenizer.tokenize(value)
|
||||
tokenized_table_parent.append( ([attr], value_tokens) )
|
||||
tokenized_tables.append(tokenized_table_parent)
|
||||
|
||||
pred_tokens_dict = {}
|
||||
for idx in range(args.num_return_sequences):
|
||||
pred_tokens_dict[idx] = [sample[idx]['tokens_clear'] for sample in raw_predictions]
|
||||
|
||||
for idx, predictions in pred_tokens_dict.items():
|
||||
(idx_p, idx_r, idx_f1, idx_all_f1) = parent_scorer(
|
||||
predictions=predictions,
|
||||
references=references,
|
||||
tables=tokenized_tables,
|
||||
return_dict=False,
|
||||
)
|
||||
print(f"Idx#{idx} - PARENT: {idx_p:.3f}, {idx_r:.3f}, {idx_f1:.3f}")
|
||||
|
||||
best_predictions = select_prediction_set_by_parent(
|
||||
raw_predictions, references, tokenized_tables)
|
||||
(avg_p, avg_r, avg_f, all_f) = parent_scorer(
|
||||
predictions=best_predictions,
|
||||
references=references,
|
||||
tables=tokenized_tables,
|
||||
return_dict=False
|
||||
)
|
||||
print(f"BEST PARENT: {avg_p: .3f}, {avg_r:.3f}, {avg_f:.3f}")
|
||||
|
||||
return
|
|
@@ -0,0 +1,183 @@
|
|||
"""Generation utility functions. """
|
||||
|
||||
|
||||
import torch
|
||||
from ..utils import special_tokens_map
|
||||
|
||||
|
||||
# %% beam generate
|
||||
|
||||
def tokenize_sample_test(sample, tokenizer, args, verbose=False):
|
||||
"""Tokenize on the sample source text, while testing."""
|
||||
|
||||
if verbose:
|
||||
print(f"[utils >> tknz_sample] has table {sample['table_id']} & subsent [{sample['sub_sent_id']}]")
|
||||
|
||||
cls_id = special_tokens_map[args.expr_name]['cls']
|
||||
sep_id = special_tokens_map[args.expr_name]['sep']
|
||||
|
||||
input_ids = [cls_id]
|
||||
position_ids = [0]
|
||||
for text_span in sample['source']:
|
||||
span_tokens = tokenizer.tokenize(text_span)
|
||||
span_token_ids = tokenizer.convert_tokens_to_ids(span_tokens)
|
||||
input_ids.extend(span_token_ids)
|
||||
input_ids.append(sep_id)
|
||||
position_ids.extend([i for i in range(len(span_token_ids) + 1)])
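# Note: position ids restart from 0 for every tagged text span, so each
# <title>/<cell>/... segment carries its own local positions (its tokens plus the trailing separator).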
|
||||
input_ids = input_ids[:args.input_maxlen]
|
||||
position_ids = position_ids[:args.input_maxlen]
|
||||
attention_mask = [1 for _ in input_ids]
|
||||
|
||||
input_ids = torch.LongTensor([input_ids])
|
||||
attention_mask = torch.LongTensor([attention_mask])
|
||||
position_ids = torch.LongTensor([position_ids])
|
||||
input_features = {
|
||||
'input_ids': input_ids.to(args.device),
|
||||
'attention_mask': attention_mask.to(args.device),
|
||||
'position_ids': position_ids.to(args.device)
|
||||
}
|
||||
return input_features
|
||||
|
||||
|
||||
def clear_tokens(token_list, tokenizer):
|
||||
"""Clean a token sequence by remove <pad>s.
|
||||
Skip special tokens noted as f'<{}>'.
|
||||
"""
|
||||
valid_token_list = [
|
||||
token for token in token_list
|
||||
if token not in tokenizer.all_special_tokens
|
||||
]
|
||||
return valid_token_list
|
||||
|
||||
|
||||
def beam_generate(sample, tokenizer, model, args, verbose=False):
|
||||
"""Generate outputs from a model with beam search decoding.
|
||||
|
||||
args:
|
||||
sample: {'table_id', 'sub_sent_id', 'source', 'target'}
|
||||
rets:
|
||||
generation: List[str]
|
||||
"""
|
||||
|
||||
# generate vocab ids
|
||||
sample_features = tokenize_sample_test(sample, tokenizer, args)
|
||||
if args.expr_name == 'b2b':
|
||||
gen_ids = model.generate(
|
||||
input_ids=sample_features['input_ids'],
|
||||
attention_mask=sample_features['attention_mask'],
|
||||
position_ids=sample_features['position_ids'],
|
||||
max_length=args.decode_maxlen,
|
||||
num_beams=args.num_beams,
|
||||
num_return_sequences=args.num_return_sequences
|
||||
)
|
||||
else:
|
||||
gen_ids = model.generate(
|
||||
input_ids=sample_features['input_ids'],
|
||||
attention_mask=sample_features['attention_mask'],
|
||||
max_length=args.decode_maxlen,
|
||||
num_beams=args.num_beams,
|
||||
num_return_sequences=args.num_return_sequences
|
||||
)
|
||||
if verbose == True:
|
||||
print(f'[beam_gen] has GEN-IDS with size {gen_ids.size()}')
|
||||
|
||||
gen_features = dict()
for iret, seq_ids in enumerate(gen_ids):
gen_tokens = tokenizer.convert_ids_to_tokens(seq_ids)
gen_tokens_clear = clear_tokens(gen_tokens, tokenizer)
gen_sentence = tokenizer.convert_tokens_to_string(gen_tokens_clear)

gen_features[iret] = {
'ids': seq_ids,
'tokens': gen_tokens,
'tokens_clear': gen_tokens_clear,
'sentence': gen_sentence
}
|
||||
|
||||
return gen_features
|
||||
|
||||
|
||||
# %% select optimal set
|
||||
|
||||
from ..utils import bleu_scorer, parent_scorer
|
||||
|
||||
|
||||
def select_prediction_set_by_bleu(
|
||||
prediction_dicts, references, return_index=False):
|
||||
"""Select sequence-wise-ly from predictions the best predset against references."""
|
||||
predictions = []
|
||||
indices = []
|
||||
|
||||
for sample_pred_dict, ref_list in zip(prediction_dicts, references):
|
||||
max_idx = 0
|
||||
max_score = 0.0
|
||||
|
||||
for idx, d in sample_pred_dict.items():
|
||||
res = bleu_scorer.compute(
|
||||
predictions=[d['tokens_clear']],
|
||||
references=[ref_list]
|
||||
)
|
||||
score = res['bleu']
|
||||
|
||||
if score > max_score:
|
||||
max_idx = idx
|
||||
max_score = score
|
||||
|
||||
# print(f'[utils >> select_predset] sample max score: [{max_score}]')
|
||||
predictions.append(sample_pred_dict[max_idx]['tokens_clear'])
|
||||
indices.append(max_idx)
|
||||
|
||||
if return_index: return predictions, indices
|
||||
return predictions
|
||||
|
||||
|
||||
def select_prediction_set_by_parent(prediction_dicts, references, tables, return_index=False):
|
||||
"""Select sequence-wise-ly from predictions the best predset against references."""
|
||||
predictions = []
|
||||
indices = []
|
||||
|
||||
for sample_pred_dict, ref_list, table in zip(prediction_dicts, references, tables):
|
||||
max_idx = 0
|
||||
max_score = 0.0
|
||||
|
||||
for idx, d in sample_pred_dict.items():
|
||||
p, r, f1, all_f1 = parent_scorer(
|
||||
predictions=[d['tokens_clear']],
|
||||
references=[ref_list],
|
||||
tables=[table],
|
||||
return_dict=False
|
||||
)
|
||||
|
||||
if f1 > max_score:
|
||||
max_idx = idx
|
||||
max_score = f1
|
||||
|
||||
# print(f'[utils >> select_predset] sample max score: [{max_score}]')
|
||||
predictions.append(sample_pred_dict[max_idx]['tokens_clear'])
|
||||
indices.append(max_idx)
|
||||
|
||||
if return_index: return predictions, indices
|
||||
return predictions
|
||||
|
||||
|
||||
|
||||
# sort / rank multiple predictions
|
||||
|
||||
def rank_prediction_set_by_bleu(prediction_dicts, references): # return_scores=True
|
||||
"""Rank sequence-wise-ly from predictions the best predset against references."""
|
||||
from experiment.utils.metrics import bleu_scorer
|
||||
|
||||
sorted_predictions = []
|
||||
for sample_pred_dict, ref_list in zip(prediction_dicts, references):
|
||||
pred_score_pairs = []
|
||||
for idx, d in sample_pred_dict.items():
|
||||
res = bleu_scorer.compute(
|
||||
predictions=[d['tokens_clear']],
|
||||
references=[ref_list]
|
||||
)
|
||||
pred_score_pairs.append( (idx, d['sentence'], res['bleu']) )
|
||||
|
||||
pred_score_pairs = sorted(pred_score_pairs, key=lambda x: x[2])
|
||||
sorted_predictions.append(pred_score_pairs)
|
||||
|
||||
return sorted_predictions
|
|
@@ -0,0 +1,5 @@
"""Quick imports of Pointer Generator modules. """

from .data import Vocab, Batcher
from .model.model import Model
from .decode import BeamSearch

@@ -0,0 +1,60 @@
|
|||
"""Initialize the config dictionary for varied experiments.
|
||||
|
||||
Default model configurations.
|
||||
"""
|
||||
|
||||
|
||||
# %% tokens
|
||||
SENTENCE_STA = '<s>'
|
||||
SENTENCE_END = '</s>'
|
||||
|
||||
UNK = 0
|
||||
PAD = 1
|
||||
BOS = 2
|
||||
EOS = 3
|
||||
|
||||
PAD_TOKEN = '[PAD]'
|
||||
UNK_TOKEN = '[UNK]'
|
||||
BOS_TOKEN = '[BOS]'
|
||||
EOS_TOKEN = '[EOS]'
|
||||
|
||||
|
||||
# %% model
|
||||
emb_dim = 128
|
||||
hidden_dim = 256
|
||||
vocab_size = 30000
|
||||
|
||||
beam_size = 6
|
||||
max_enc_steps = 512
|
||||
max_dec_steps = 40
|
||||
max_tes_steps = 60
|
||||
min_dec_steps = 8
|
||||
|
||||
|
||||
# batch_size = 64
|
||||
# lr = 5e-5
|
||||
cov_loss_wt = 1.0
|
||||
pointer_gen = True
|
||||
is_coverage = True
|
||||
|
||||
max_grad_norm = 2.0
|
||||
adagrad_init_acc = 0.1
|
||||
rand_unif_init_mag = 0.02
|
||||
trunc_norm_init_std = 1e-4
|
||||
|
||||
eps = 1e-12
|
||||
use_gpu = True
|
||||
# lr_coverage = 5e-5
|
||||
max_iterations = 500
|
||||
|
||||
|
||||
# %% transformer
|
||||
tran = False
|
||||
# d_k = 64
|
||||
# d_v = 64
|
||||
# n_head = 6
|
||||
# dropout = 0.1
|
||||
# n_layers = 6
|
||||
# d_model = 128
|
||||
# d_inner = 512
|
||||
# n_warmup_steps = 4000
|
|
@@ -0,0 +1,402 @@
|
|||
"""Classes for data.
|
||||
* Vocab: vocabulary instance initialized from the vocab file
|
||||
* Example: an source-target(-table-parent) example
|
||||
* Batch: a list of batched and masked examples
|
||||
* Batcher: load examples and batch them for model input
|
||||
"""
|
||||
|
||||
|
||||
# %% Vocabulary
|
||||
|
||||
import csv
|
||||
|
||||
# <s> and </s> are used in the data files to segment the abstracts into sentences. They don't receive vocab ids.
|
||||
SENTENCE_STA = '<s>'
|
||||
SENTENCE_END = '</s>'
|
||||
|
||||
PAD_TOKEN = '[PAD]' # This has a vocab id, which is used to pad the encoder input, decoder input and target sequence
|
||||
UNK_TOKEN = '[UNK]' # This has a vocab id, which is used to represent out-of-vocabulary words
|
||||
BOS_TOKEN = '[BOS]' # This has a vocab id, which is used at the start of every decoder input sequence
|
||||
EOS_TOKEN = '[EOS]' # This has a vocab id, which is used at the end of untruncated target sequences
|
||||
# Note: none of <s>, </s>, [PAD], [UNK], [BOS], [EOS] should appear in the vocab file.
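# Expected vocab file format, as implied by the parsing in Vocab.__init__ below
# (one whitespace-separated "word count" pair per line), e.g.:
#   the 123456
#   total 9876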
|
||||
|
||||
|
||||
class Vocab(object):
|
||||
"""Vocabulary class. """
|
||||
|
||||
def __init__(self, file: str, max_size: int):
|
||||
self.word2idx = {}
|
||||
self.idx2word = {}
|
||||
self.count = 0 # keeps track of total number of words in the Vocab
|
||||
|
||||
# [UNK], [PAD], [BOS] and [EOS] get the ids 0,1,2,3.
|
||||
for w in [UNK_TOKEN, PAD_TOKEN, BOS_TOKEN, EOS_TOKEN]:
|
||||
self.word2idx[w] = self.count
|
||||
self.idx2word[self.count] = w
|
||||
self.count += 1
|
||||
|
||||
# Read the vocab file and add words up to max_size
|
||||
with open(file, 'r') as fin:
|
||||
for line in fin:
|
||||
items = line.split()
|
||||
if len(items) != 2:
|
||||
print('Warning: incorrectly formatted line in vocabulary file: %s' % line.strip())
|
||||
continue
|
||||
w = items[0]
|
||||
if w in [SENTENCE_STA, SENTENCE_END, UNK_TOKEN, PAD_TOKEN, BOS_TOKEN, EOS_TOKEN]:
|
||||
raise Exception(
|
||||
'<s>, </s>, [UNK], [PAD], [BOS] and [EOS] shouldn\'t be in the vocab file, but %s is' % w)
|
||||
if w in self.word2idx:
|
||||
raise Exception('Duplicated word in vocabulary file: %s' % w)
|
||||
self.word2idx[w] = self.count
|
||||
self.idx2word[self.count] = w
|
||||
self.count += 1
|
||||
if max_size != 0 and self.count >= max_size:
|
||||
break
|
||||
print("Finished constructing vocabulary of %i total words. Last word added: %s" % (
|
||||
self.count, self.idx2word[self.count - 1]))
|
||||
|
||||
def word2id(self, word):
|
||||
if word not in self.word2idx:
|
||||
return self.word2idx[UNK_TOKEN]
|
||||
return self.word2idx[word]
|
||||
|
||||
def id2word(self, word_id):
|
||||
if word_id not in self.idx2word:
|
||||
raise ValueError('Id not found in vocab: %d' % word_id)
|
||||
return self.idx2word[word_id]
|
||||
|
||||
def size(self):
|
||||
return self.count
|
||||
|
||||
def write_metadata(self, path):
|
||||
print( "Writing word embedding metadata file to %s..." % (path))
|
||||
with open(path, "w") as f:
|
||||
fieldnames = ['word']
|
||||
writer = csv.DictWriter(f, delimiter="\t", fieldnames=fieldnames)
|
||||
for i in range(self.size()):
|
||||
writer.writerow({"word": self.idx2word[i]})
|
||||
|
||||
|
||||
|
||||
# %% Example
|
||||
|
||||
from typing import List
|
||||
from experiment.pointer_generator import config, utils
|
||||
|
||||
|
||||
class Example(object):
|
||||
"""A hmt-table example. """
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
source: List[List[str]],
|
||||
target: List[List[str]],
|
||||
table_parent: List,
|
||||
vocab: Vocab,
|
||||
config=config,
|
||||
):
|
||||
"""Initialize the example with source and target texts.
|
||||
args:
|
||||
source: List[List[str]], list of parsed tokens/words-list.
|
||||
target: List[str], a single list of parsed target text.
|
||||
table_parent: List[List[List[str], List[str]]], parsed attributes and fields.
|
||||
"""
|
||||
# Get ids of special tokens
|
||||
bos_decoding = vocab.word2id(BOS_TOKEN)
|
||||
eos_decoding = vocab.word2id(EOS_TOKEN)
|
||||
|
||||
# process the source input
|
||||
src_words = [sword for sitem in source for sword in sitem]
|
||||
if len(src_words) > config.max_enc_steps:
|
||||
src_words = src_words[: config.max_enc_steps]
|
||||
self.enc_len = len(src_words)
|
||||
self.enc_inp = [vocab.word2id(w) for w in src_words]
|
||||
|
||||
# process the target text
|
||||
tgt_words = target
|
||||
tgt_ids = [vocab.word2id(w) for w in tgt_words]
|
||||
# get the decoder input sequence and target sequence
|
||||
self.dec_inp, self.tgt = self.get_dec_seq(tgt_ids,
|
||||
config.max_dec_steps, bos_decoding, eos_decoding)
|
||||
self.dec_len = len(self.dec_inp)
|
||||
|
||||
# if using pg mode, need to store some extra info
|
||||
if config.pointer_gen:
|
||||
# Store a version of the enc_input where in-article OOVs are represented by their temporary OOV id;
|
||||
# also store the in-article OOVs words themselves
|
||||
self.enc_inp_extend_vocab, self.article_oovs = utils.article2ids(src_words, vocab, config)
|
||||
|
||||
# Get a version of the reference target where in-article OOVs are represented by their temporary article OOV id
|
||||
abs_ids_extend_vocab = utils.abstract2ids(tgt_words, vocab, self.article_oovs, config)
|
||||
|
||||
# Overwrite decoder target sequence so it uses the temp article OOV ids
|
||||
_, self.tgt = self.get_dec_seq(abs_ids_extend_vocab, config.max_dec_steps, bos_decoding, eos_decoding)
|
||||
|
||||
# store the original strings
|
||||
self.original_source = source
|
||||
self.original_target = target
|
||||
self.original_table_parent = table_parent
|
||||
|
||||
def get_dec_seq(self, sequence: List[int], max_len: int, start_id: int, stop_id: int):
|
||||
"""Perform decoding seuqence processing, add special tokens and do truncation. """
|
||||
src = [start_id] + sequence[:]
|
||||
tgt = sequence[:]
|
||||
if len(src) > max_len: # truncate
|
||||
src = src[: max_len]
|
||||
tgt = tgt[: max_len] # no end_token
|
||||
else: # no truncation
|
||||
tgt.append(stop_id) # end token
|
||||
assert len(src) == len(tgt)
|
||||
return src, tgt
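# e.g. with sequence=[w1, w2, w3], start_id=BOS, stop_id=EOS and max_len >= 4:
#   dec_inp (src) = [BOS, w1, w2, w3]
#   tgt           = [w1, w2, w3, EOS]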
|
||||
|
||||
def pad_enc_seq(self, max_len: int, pad_id: int) -> None:
|
||||
"""Pad the encoding sequence to config-specified max length. """
|
||||
while len(self.enc_inp) < max_len:
|
||||
self.enc_inp.append(pad_id)
|
||||
if config.pointer_gen:
|
||||
while len(self.enc_inp_extend_vocab) < max_len:
|
||||
self.enc_inp_extend_vocab.append(pad_id)
|
||||
|
||||
def pad_dec_seq(self, max_len: int, pad_id: int) -> None:
|
||||
"""Pad the decoding sequence to config-specified max length. """
|
||||
while len(self.dec_inp) < max_len:
|
||||
self.dec_inp.append(pad_id)
|
||||
while len(self.tgt) < max_len:
|
||||
self.tgt.append(pad_id)
|
||||
|
||||
|
||||
|
||||
# %% Batch
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Batch(object):
|
||||
def __init__(self, example_list: List[Example], vocab: Vocab, batch_size: int):
|
||||
self.batch_size = batch_size
|
||||
self.pad_id = vocab.word2id(PAD_TOKEN) # id of the PAD token used to pad sequences
|
||||
self.init_encoder_seq(example_list) # initialize the input to the encoder
|
||||
self.init_decoder_seq(example_list) # initialize the input and targets for the decoder
|
||||
self.store_orig_strings(example_list) # store the original strings
|
||||
|
||||
def init_encoder_seq(self, example_list: List[Example]):
|
||||
"""Create self enc_batch/enc_lens/enc_padding_mask from the list of examples. """
|
||||
|
||||
# Determine the maximum length of the encoder input sequence in this batch
|
||||
max_enc_seq_len = max([ex.enc_len for ex in example_list])
|
||||
|
||||
# Pad the encoder input sequences up to the length of the longest sequence
|
||||
for ex in example_list:
|
||||
ex.pad_enc_seq(max_enc_seq_len, self.pad_id)
|
||||
|
||||
# Initialize the numpy arrays
|
||||
# Note: our enc_batch can have different length (second dimension) for each batch
|
||||
# because we use dynamic_rnn for the encoder.
|
||||
self.enc_batch = np.zeros((self.batch_size, max_enc_seq_len), dtype=np.int32)
|
||||
self.enc_lens = np.zeros((self.batch_size), dtype=np.int32)
|
||||
self.enc_padding_mask = np.zeros((self.batch_size, max_enc_seq_len), dtype=np.float32)
|
||||
|
||||
# Fill in the numpy arrays
|
||||
for i, ex in enumerate(example_list):
|
||||
self.enc_batch[i, :] = ex.enc_inp[:]
|
||||
self.enc_lens[i] = ex.enc_len
|
||||
for j in range(ex.enc_len):
|
||||
self.enc_padding_mask[i][j] = 1
|
||||
|
||||
# For pointer-generator mode, need to store some extra info
|
||||
if config.pointer_gen:
|
||||
# Determine the max number of in-article OOVs in this batch
|
||||
self.max_art_oovs = max([len(ex.article_oovs) for ex in example_list])
|
||||
# Store the in-article OOVs themselves
|
||||
self.art_oovs = [ex.article_oovs for ex in example_list]
|
||||
# Store the version of the enc_batch that uses the article OOV ids
|
||||
self.enc_batch_extend_vocab = np.zeros((self.batch_size, max_enc_seq_len), dtype=np.int32)
|
||||
for i, ex in enumerate(example_list):
|
||||
self.enc_batch_extend_vocab[i, :] = ex.enc_inp_extend_vocab[:]
|
||||
|
||||
def init_decoder_seq(self, example_list: List[Example]):
|
||||
"""Create self dec_batch/tgt_batch/dec_lens/dec_padding_mask from the list of examples. """
|
||||
# Pad the inputs and targets
|
||||
for ex in example_list:
|
||||
ex.pad_dec_seq(config.max_dec_steps, self.pad_id)
|
||||
|
||||
# Initialize the numpy arrays.
|
||||
self.dec_batch = np.zeros((self.batch_size, config.max_dec_steps), dtype=np.int32)
|
||||
self.tgt_batch = np.zeros((self.batch_size, config.max_dec_steps), dtype=np.int32)
|
||||
self.dec_padding_mask = np.zeros((self.batch_size, config.max_dec_steps), dtype=np.float32)
|
||||
self.dec_lens = np.zeros((self.batch_size), dtype=np.int32)
|
||||
|
||||
# Fill in the numpy arrays
|
||||
for i, ex in enumerate(example_list):
|
||||
self.dec_batch[i, :] = ex.dec_inp[:]
|
||||
self.tgt_batch[i, :] = ex.tgt[:]
|
||||
self.dec_lens[i] = ex.dec_len
|
||||
for j in range(ex.dec_len):
|
||||
self.dec_padding_mask[i][j] = 1
|
||||
|
||||
def store_orig_strings(self, example_list: List[Example]):
|
||||
self.original_sources = [ex.original_source for ex in example_list] # list of lists
|
||||
self.original_targets = [ex.original_target for ex in example_list] # list of lists
|
||||
self.original_table_parents = [ex.original_table_parent for ex in example_list]
|
||||
|
||||
|
||||
|
||||
# %% Batcher
|
||||
|
||||
import glob
|
||||
import json
|
||||
import time
|
||||
import queue
|
||||
import random
|
||||
from threading import Thread
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
|
||||
class Batcher(object):
|
||||
BATCH_QUEUE_MAX = 100 # max number of batches the batch_queue can hold
|
||||
|
||||
def __init__(
|
||||
self, vocab: Vocab, data_path: str,  # attributes intended as internal are prefixed with '_'
|
||||
batch_size: int, single_pass: bool, mode: str,
|
||||
):
|
||||
self._vocab = vocab
|
||||
self._data_path = data_path
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.single_pass = single_pass
|
||||
self.mode = mode
|
||||
|
||||
# Initialize a queue of Batches waiting to be used, and a queue of Examples waiting to be batched
|
||||
self._batch_queue = queue.Queue(self.BATCH_QUEUE_MAX)
|
||||
self._example_queue = queue.Queue(self.BATCH_QUEUE_MAX * self.batch_size)
|
||||
|
||||
# Different settings depending on whether we're in single_pass mode or not
|
||||
if single_pass:
|
||||
self._num_example_q_threads = 1 # just one thread, so we read through the dataset just once
|
||||
self._num_batch_q_threads = 1 # just one thread to batch examples
|
||||
self._bucketing_cache_size = 1 # only load one batch's worth of examples before bucketing
|
||||
self._finished_reading = False # this will tell us when we're finished reading the dataset
|
||||
else:
|
||||
self._num_example_q_threads = 1 # num threads to fill example queue
|
||||
self._num_batch_q_threads = 1 # num threads to fill batch queue
|
||||
self._bucketing_cache_size = 1 # how many batches-worth of examples to load into cache before bucketing
|
||||
|
||||
# Start the threads that load the queues
|
||||
self._example_q_threads = []
|
||||
for _ in range(self._num_example_q_threads):
|
||||
self._example_q_threads.append(Thread(target=self.fill_example_queue))
|
||||
self._example_q_threads[-1].daemon = True
|
||||
self._example_q_threads[-1].start()
|
||||
self._batch_q_threads = []
|
||||
for _ in range(self._num_batch_q_threads):
|
||||
self._batch_q_threads.append(Thread(target=self.fill_batch_queue))
|
||||
self._batch_q_threads[-1].daemon = True
|
||||
self._batch_q_threads[-1].start()
|
||||
|
||||
# Start a thread that watches the other threads and restarts them if they're dead
|
||||
if not single_pass: # We don't want a watcher in single_pass mode because the threads shouldn't run forever
|
||||
self._watch_thread = Thread(target=self.watch_threads)
|
||||
self._watch_thread.daemon = True
|
||||
self._watch_thread.start()
|
||||
|
||||
def next_batch(self):
|
||||
# If the batch queue is empty, print a warning
|
||||
if self._batch_queue.qsize() == 0:
|
||||
tf.logging.warning(
|
||||
'Bucket input queue is empty when calling next_batch. Bucket queue size: %i, Input queue size: %i',
|
||||
self._batch_queue.qsize(), self._example_queue.qsize())
|
||||
if self.single_pass and self._finished_reading:
|
||||
tf.logging.info("Finished reading dataset in single_pass mode.")
|
||||
return None
|
||||
|
||||
batch = self._batch_queue.get() # get the next Batch
|
||||
return batch
|
||||
|
||||
def pair_generator(self, data_path: str, single_pass: bool, verbose: bool = False):
|
||||
"""Generate hmt text pairs to construct examples.
|
||||
Yield (source text, target text, and table parent list) for each turn.
|
||||
"""
|
||||
if verbose: print(f'[pair-generator] from data-path [{data_path}]')
|
||||
|
||||
while True:
|
||||
filelist = glob.glob(data_path)
|
||||
assert filelist, ('Error: Empty filelist at %s' % data_path) # check filelist isn't empty
|
||||
if single_pass: filelist = sorted(filelist)
|
||||
else: random.shuffle(filelist)
|
||||
|
||||
for f in filelist:
|
||||
print(f'[pair-gen] reading from file: {f}')
|
||||
reader = open(f, 'r')
|
||||
for line in reader:
|
||||
tabdict = json.loads(line.strip())
|
||||
if verbose: print(f"\n[pair-gen] got sample: \n{tabdict['source']}\n{tabdict['target']}")
|
||||
yield (tabdict['source'], tabdict['target'], tabdict['table_parent'])
|
||||
if single_pass:
|
||||
print("example_generator completed reading all datafiles. No more data.")
|
||||
break
|
||||
|
||||
def fill_example_queue(self):
|
||||
input_generator = self.pair_generator(self._data_path, self.single_pass)
|
||||
|
||||
while True:
|
||||
try:
|
||||
(source, target, table_parent) = input_generator.__next__()
|
||||
except StopIteration: # if there are no more examples:
|
||||
tf.logging.info("The example generator for this example queue filling thread has exhausted data.")
|
||||
if self.single_pass:
|
||||
tf.logging.info(
|
||||
"single_pass mode is on, so we've finished reading dataset. This thread is stopping.")
|
||||
self._finished_reading = True
|
||||
break
|
||||
else:
|
||||
raise Exception("single_pass mode is off but the example generator is out of data; error.")
|
||||
|
||||
example = Example(source, target, table_parent, self._vocab)
|
||||
self._example_queue.put(example)
|
||||
|
||||
def fill_batch_queue(self):
|
||||
while True:
|
||||
if self.mode == 'decode':
|
||||
# beam search decode mode single example repeated in the batch
|
||||
ex = self._example_queue.get()
|
||||
b = [ex for _ in range(self.batch_size)]
|
||||
self._batch_queue.put(Batch(b, self._vocab, self.batch_size))
|
||||
else:
|
||||
# Get bucketing_cache_size-many batches of Examples into a list, then sort
|
||||
inputs = []
|
||||
for _ in range(self.batch_size * self._bucketing_cache_size):
|
||||
inputs.append(self._example_queue.get())
|
||||
inputs = sorted(inputs, key=lambda inp: inp.enc_len, reverse=True) # sort by length of encoder sequence
|
||||
|
||||
# Group the sorted Examples into batches, optionally shuffle the batches, and place in the batch queue.
|
||||
batches = []
|
||||
for i in range(0, len(inputs), self.batch_size):
|
||||
batches.append(inputs[i:i + self.batch_size])
|
||||
if not self.single_pass:
|
||||
random.shuffle(batches)
|
||||
for b in batches: # each b is a list of Example objects
|
||||
batch = Batch(b, self._vocab, self.batch_size)
|
||||
self._batch_queue.put(batch)
|
||||
|
||||
def watch_threads(self):
|
||||
while True:
|
||||
tf.logging.info(
|
||||
'Bucket queue size: %i, Input queue size: %i',
|
||||
self._batch_queue.qsize(), self._example_queue.qsize())
|
||||
|
||||
time.sleep(60)
|
||||
for idx, t in enumerate(self._example_q_threads):
|
||||
if not t.is_alive(): # if the thread is dead
|
||||
tf.logging.error('Found example queue thread dead. Restarting.')
|
||||
new_t = Thread(target=self.fill_example_queue)
|
||||
self._example_q_threads[idx] = new_t
|
||||
new_t.daemon = True
|
||||
new_t.start()
|
||||
for idx, t in enumerate(self._batch_q_threads):
|
||||
if not t.is_alive(): # if the thread is dead
|
||||
tf.logging.error('Found batch queue thread dead. Restarting.')
|
||||
new_t = Thread(target=self.fill_batch_queue)
|
||||
self._batch_q_threads[idx] = new_t
|
||||
new_t.daemon = True
|
||||
new_t.start()
|
|
@@ -0,0 +1,312 @@
|
|||
"""Run Test.
|
||||
* Beam
|
||||
* BeamSearch
|
||||
"""
|
||||
|
||||
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
from typing import List
|
||||
|
||||
from . import config
|
||||
from .model.model import Model
|
||||
from .data import Vocab, Batch, Batcher, BOS_TOKEN, EOS_TOKEN
|
||||
from .utils import outputids2words, get_input_from_batch
|
||||
|
||||
|
||||
class Beam(object):
|
||||
"""A beam searched with probabilities and states. """
|
||||
|
||||
def __init__(self, tokens, log_probs, state, context, coverage):
|
||||
self.tokens = tokens
|
||||
self.log_probs = log_probs
|
||||
self.state = state
|
||||
self.context = context
|
||||
self.coverage = coverage
|
||||
|
||||
def extend(self, token, log_prob, state, context, coverage):
|
||||
return Beam(
|
||||
tokens=self.tokens + [token],
|
||||
log_probs=self.log_probs + [log_prob],
|
||||
state=state,
|
||||
context=context,
|
||||
coverage=coverage
|
||||
)
|
||||
|
||||
@property
|
||||
def latest_token(self):
|
||||
return self.tokens[-1]
|
||||
|
||||
@property
|
||||
def avg_log_prob(self):
|
||||
return sum(self.log_probs) / len(self.tokens)
|
||||
|
||||
|
||||
|
||||
# %% Beam Search
|
||||
|
||||
from datasets import load_metric
|
||||
bleu_scorer = load_metric('bleu')
|
||||
from ..utils.metrics import parent
|
||||
|
||||
|
||||
class BeamSearch(object):
|
||||
"""Beam search with loaded model to generate texts. """
|
||||
|
||||
def __init__(self, args, model_path: str, file_path: str):
|
||||
"""Initialize an instance to perform beam search.
|
||||
args:
|
||||
model_path: str, path of model to load from
|
||||
file_path: str, path of the test dataset (parsed)
|
||||
"""
|
||||
model_name = os.path.basename(model_path)
|
||||
self._test_dir = os.path.join(args.run_dir, f'decode_{model_name}')
|
||||
if not os.path.exists(self._test_dir): os.makedirs(self._test_dir)
|
||||
self._test_metrics_path = os.path.join(self._test_dir, 'metrics')
|
||||
|
||||
self.vocab = Vocab(args.vocab_path, args.vocab_size)
|
||||
self.batcher = Batcher(
|
||||
vocab=self.vocab,
|
||||
data_path=file_path,
|
||||
batch_size=config.beam_size,
|
||||
single_pass=True,
|
||||
mode='decode',
|
||||
)
|
||||
time.sleep(15)
|
||||
|
||||
self.model = Model(
|
||||
config=config,
|
||||
model_path=model_path,
|
||||
is_eval=True,
|
||||
is_transformer=False,
|
||||
)
|
||||
|
||||
self.use_cuda = config.use_gpu and torch.cuda.is_available()
|
||||
self.config = config
|
||||
|
||||
def sort_beams(self, beams: List[Beam]) -> List[Beam]:
|
||||
return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)
|
||||
|
||||
def beam_search(self, batch: Batch):
|
||||
# single example repeated across the batch
|
||||
(
|
||||
enc_batch, enc_lens, enc_pos, enc_padding_mask,
|
||||
enc_batch_extend_vocab, extra_zeros, c_t, coverage
|
||||
) = get_input_from_batch(batch, self.use_cuda, self.config)
|
||||
enc_out, enc_fea, enc_h = self.model.encoder(enc_batch, enc_lens)
|
||||
s_t = self.model.reduce_state(enc_h)
|
||||
|
||||
dec_h, dec_c = s_t # b x hidden_dim
|
||||
dec_h = dec_h.squeeze()
|
||||
dec_c = dec_c.squeeze()
|
||||
|
||||
# decoder batch preparation,
|
||||
# it initially has beam_size hypotheses; everything is repeated across them
|
||||
beams = [
|
||||
Beam(
|
||||
tokens=[self.vocab.word2id(BOS_TOKEN)],
|
||||
log_probs=[0.0],
|
||||
state=(dec_h[0], dec_c[0]),
|
||||
context=c_t[0],
|
||||
coverage=(coverage[0] if self.config.is_coverage else None)
|
||||
)
|
||||
for _ in range(self.config.beam_size)
|
||||
]
|
||||
|
||||
steps = 0
|
||||
results = []
|
||||
while steps < self.config.max_dec_steps and len(results) < self.config.beam_size:
|
||||
latest_tokens = [h.latest_token for h in beams]
|
||||
latest_tokens = [
|
||||
t if (t < self.vocab.size())
|
||||
else self.vocab.word2id(self.config.UNK_TOKEN)
|
||||
for t in latest_tokens
|
||||
]
|
||||
y_t = torch.autograd.Variable(torch.LongTensor(latest_tokens))
|
||||
if self.use_cuda: y_t = y_t.cuda()
|
||||
all_state_h = [h.state[0] for h in beams]
|
||||
all_state_c = [h.state[1] for h in beams]
|
||||
all_context = [h.context for h in beams]
|
||||
|
||||
s_t = (
|
||||
torch.stack(all_state_h, 0).unsqueeze(0),
|
||||
torch.stack(all_state_c, 0).unsqueeze(0)
|
||||
)
|
||||
c_t = torch.stack(all_context, 0)
|
||||
|
||||
coverage_t = None
|
||||
if self.config.is_coverage:
|
||||
all_coverage = [h.coverage for h in beams]
|
||||
coverage_t = torch.stack(all_coverage, 0)
|
||||
|
||||
final_dist, s_t, c_t, attn_dist, p_gen, coverage_t = self.model.decoder(
|
||||
y_t, s_t, enc_out, enc_fea, enc_padding_mask, c_t,
|
||||
extra_zeros, enc_batch_extend_vocab, coverage_t, steps
|
||||
)
|
||||
log_probs = torch.log(final_dist)
|
||||
topk_log_probs, topk_ids = torch.topk(log_probs, self.config.beam_size * 2)
|
||||
|
||||
dec_h, dec_c = s_t
|
||||
dec_h = dec_h.squeeze()
|
||||
dec_c = dec_c.squeeze()
|
||||
|
||||
all_beams = []
|
||||
# On the first step, we only had one original hypothesis (the initial hypothesis).
|
||||
# On subsequent steps, all original hypotheses are distinct.
|
||||
num_orig_beams = 1 if steps == 0 else len(beams)
|
||||
for i in range(num_orig_beams):
|
||||
h = beams[i]
|
||||
state_i = (dec_h[i], dec_c[i])
|
||||
context_i = c_t[i]
|
||||
coverage_i = (coverage_t[i] if self.config.is_coverage else None)  # carry forward the decoder-updated coverage
|
||||
|
||||
# for each of the top 2*beam_size hyps:
|
||||
for j in range(self.config.beam_size * 2):
|
||||
new_beam = h.extend(
|
||||
token=topk_ids[i, j].item(),
|
||||
log_prob=topk_log_probs[i, j].item(),
|
||||
state=state_i,
|
||||
context=context_i,
|
||||
coverage=coverage_i
|
||||
)
|
||||
all_beams.append(new_beam)
|
||||
|
||||
beams = []
|
||||
for h in self.sort_beams(all_beams):
|
||||
if h.latest_token == self.vocab.word2id(EOS_TOKEN):
|
||||
if steps >= self.config.min_dec_steps:
|
||||
results.append(h)
|
||||
else:
|
||||
beams.append(h)
|
||||
if len(beams) == self.config.beam_size or len(results) == self.config.beam_size:
|
||||
break
|
||||
|
||||
steps += 1
|
||||
|
||||
if len(results) == 0:
|
||||
results = beams
|
||||
|
||||
beams_sorted = self.sort_beams(results)
|
||||
|
||||
return beams_sorted[0] # best_summary
|
||||
|
||||
def run(self, interval: int = 1000):
|
||||
"""Run beam-search on each test sample.
|
||||
interval: number of batch steps for logging info.
|
||||
"""
|
||||
counter = 0
|
||||
start = time.time()
|
||||
|
||||
all_pred_tokens, all_ref_tokens = [], []
|
||||
|
||||
batch = self.batcher.next_batch()
|
||||
while batch: # not None or not Empty
|
||||
# run beam search to get best Hypothesis
|
||||
try: best_summary = self.beam_search(batch)
|
||||
except: break # RuntimeError: Cannot pack empty tensors.
|
||||
|
||||
# extract the output ids from the hypothesis and convert back to words
|
||||
output_ids = [int(t) for t in best_summary.tokens[1:]]
|
||||
article_oovs = batch.art_oovs[0] if self.config.pointer_gen else None
|
||||
decoded_words = outputids2words(
|
||||
id_list=output_ids,
|
||||
vocab=self.vocab,
|
||||
article_oovs=article_oovs,
|
||||
)
|
||||
|
||||
# remove the [STOP] token from decoded_words, if necessary
|
||||
try:
|
||||
fst_stop_idx = decoded_words.index(EOS_TOKEN)
|
||||
decoded_words = decoded_words[: fst_stop_idx]
|
||||
except ValueError:
|
||||
decoded_words = decoded_words
|
||||
|
||||
all_pred_tokens.append(decoded_words)
|
||||
all_ref_tokens.append([
|
||||
tgt for tgt in batch.original_targets # [: 1]
|
||||
])
|
||||
|
||||
counter += 1
|
||||
if counter % interval == 0:
|
||||
print(f'{counter:d} example in {int(time.time()-start):d} sec')
|
||||
start = time.time()
|
||||
print(f'ORG: {all_ref_tokens[-1]}')
|
||||
print(f'GEN: {decoded_words}')
|
||||
|
||||
batch = self.batcher.next_batch()
|
||||
|
||||
print(f'Decoder has finished reading dataset for single_pass.')
|
||||
print(f'Now starting BLEU eval...')
|
||||
results_dict = bleu_scorer.compute(
|
||||
predictions=all_pred_tokens,
|
||||
references=all_ref_tokens,
|
||||
)
|
||||
print(f"BLEU: {results_dict['bleu']:.4f}")
|
||||
|
||||
with open(self._test_metrics_path, 'w') as fw:
|
||||
mline = f"bleu-4\t{results_dict['bleu']:.4f}\n"
|
||||
fw.write(mline)
|
||||
for idx, (pred, ref) in enumerate(zip(all_pred_tokens, all_ref_tokens)):
|
||||
iscore = bleu_scorer.compute(
|
||||
predictions=[pred],
|
||||
references=[ref]
|
||||
)
|
||||
fw.write(f"#{idx}: {iscore['bleu']:.4f}\n")
|
||||
fw.write(f'{pred}\n{ref[0]}\n\n')
|
||||
|
||||
def eval_parent(self, interval=1000):
|
||||
"""Run beam-search on each test sample."""
|
||||
counter = 0
|
||||
start = time.time()
|
||||
|
||||
all_pred_tokens, all_ref_tokens = [], []
|
||||
all_table_tokens = []
|
||||
|
||||
batch = self.batcher.next_batch()
|
||||
while batch: # not None or not Empty
|
||||
# run beam search to get best Hypothesis
|
||||
try: best_summary = self.beam_search(batch)
|
||||
except: break # RuntimeError: Cannot pack empty tensors.
|
||||
|
||||
# extract the output ids from the hypothesis and convert back to words
|
||||
output_ids = [int(t) for t in best_summary.tokens[1:]]
|
||||
article_oovs = batch.art_oovs[0] if self.config.pointer_gen else None
|
||||
decoded_words = outputids2words(
|
||||
id_list=output_ids,
|
||||
vocab=self.vocab,
|
||||
article_oovs=article_oovs,
|
||||
)
|
||||
|
||||
# remove the [STOP] token from decoded_words, if necessary
|
||||
try:
|
||||
fst_stop_idx = decoded_words.index(EOS_TOKEN)
|
||||
decoded_words = decoded_words[: fst_stop_idx]
|
||||
except ValueError:
|
||||
decoded_words = decoded_words
|
||||
|
||||
all_pred_tokens.append(decoded_words)
|
||||
all_ref_tokens.append([
|
||||
tgt for tgt in batch.original_targets[: 1]
|
||||
])
|
||||
all_table_tokens.append(batch.original_table_parents[0]) # @zhiruow
|
||||
|
||||
counter += 1
|
||||
if counter % interval == 0:
|
||||
print(f'{counter:d} example in {int(time.time()-start):d} sec')
|
||||
start = time.time()
|
||||
print(f'TGT: {all_ref_tokens[-1]}')
|
||||
print(f'GEN: {decoded_words}')
|
||||
# print(f'TAB: {all_table_tokens[-1]}')
|
||||
|
||||
batch = self.batcher.next_batch()
|
||||
|
||||
print(f'Decoder has finished reading dataset for single_pass.')
|
||||
print(f'Now starting PARENT eval...')
|
||||
results_dict = parent(
|
||||
predictions=all_pred_tokens,
|
||||
references=all_ref_tokens,
|
||||
tables=all_table_tokens,
|
||||
return_dict=True,
|
||||
)
|
||||
print(f"PARENT: {results_dict['average_f1']:.4f}")
|
|
@ -0,0 +1,39 @@
"""Initialize the suite of pointer-generator network models.

With a 'Basic Module' that supports model-weight initializations.
"""

import math
import torch.nn as nn


# %% Basic Module for models

class BasicModule(nn.Module):
    """Initialization for models."""

    def __init__(self, init_method='uniform'):
        """Initialize model weights with uniform distribution by default.
        init_method: choices = ['uniform', 'normal']
        future-to-add: 'truncated_normal'
        """
        super(BasicModule, self).__init__()
        self.init_method = init_method

    def init_params(self, init_range=0.05):
        """Initialize self weights/parameters with the specified method."""

        if not (self.init_method in ['uniform', 'normal']):  # 'truncated_normal'
            print(f'[BasicModule >> init_params] not supporting init_method: {self.init_method}')
            return

        for param in self.parameters():
            if (not param.requires_grad) or len(param.shape) == 0:
                continue
            if self.init_method == 'uniform':
                nn.init.uniform_(param, a=-init_range, b=init_range)
            else:
                stddev = 1 / math.sqrt(param.shape[0])
                if self.init_method == 'normal':
                    nn.init.normal_(param, std=stddev)

@ -0,0 +1,69 @@
"""Attention Module."""

import torch
import torch.nn as nn
import torch.nn.functional as F

from experiment.pointer_generator.model import BasicModule


class Attention(BasicModule):
    """Encoder-Decoder Attention Module."""

    def __init__(self, config):
        """Initialize with config (hidden-dim)."""
        super(Attention, self).__init__()
        self.hidden_dim = config.hidden_dim
        self.is_coverage = config.is_coverage

        self.fc = nn.Linear(2 * config.hidden_dim, 1, bias=False)
        self.dec_fc = nn.Linear(2 * config.hidden_dim, 2 * config.hidden_dim)
        if self.is_coverage:
            self.con_fc = nn.Linear(1, 2 * config.hidden_dim, bias=False)

    def forward(self, s_t, enc_out, enc_fea, enc_padding_mask, coverage):
        """
        b == batch-size, l == seq-len, n == 2 * hidden_dim

        args:
            s_t: [batch-size, 2 * hidden-dim]
            enc_out: [batch-size, seq-len, 2 * hidden-dim]
            enc_fea: [batch-size * seq-len, 2 * hidden-dim]
            enc_padding_mask: [batch-size, seq-len]
            coverage: [batch-size, seq-len]
        rets:
            c_t: [batch-size, 2 * hidden-dim]
            attn_dist: [batch-size, seq-len]
            coverage: [batch-size, seq-len]
        """
        b, l, n = list(enc_out.size())

        dec_fea = self.dec_fc(s_t)  # [b, 2*hid]
        dec_fea_expanded = dec_fea.unsqueeze(1).expand(b, l, n).contiguous()  # [b,l,2*hid]
        dec_fea_expanded = dec_fea_expanded.view(-1, n)  # [b*l,2*hid]

        att_features = enc_fea + dec_fea_expanded  # [b*l,2*hid]
        if self.is_coverage:
            coverage_inp = coverage.view(-1, 1)  # [b*l, 1]
            coverage_fea = self.con_fc(coverage_inp)  # [b*l,2*hid]
            att_features = att_features + coverage_fea  # [b*l,2*hid]

        e = torch.tanh(att_features)  # [b*l,2*hid]
        scores = self.fc(e)  # [b*l,1]
        scores = scores.view(-1, l)  # [b,l]

        attn_dist_ = F.softmax(scores, dim=-1) * enc_padding_mask  # [b,l]
        normalization_factor_ = attn_dist_.sum(1, keepdim=True)  # [b,1]
        attn_dist = attn_dist_ / (normalization_factor_ + 1e-6)  # [b,l]

        attn_dist = attn_dist.unsqueeze(1)  # [b,1,l]
        c_t = torch.bmm(attn_dist, enc_out)  # [b,1,n]
        c_t = c_t.view(-1, 2 * self.hidden_dim)  # [b,2*hid]

        attn_dist = attn_dist.view(-1, l)  # [b,l]

        if self.is_coverage:
            coverage = coverage.view(-1, l)
            coverage = coverage + attn_dist

        return c_t, attn_dist, coverage

@ -0,0 +1,197 @@
"""Component/Layers of Model: Encoder, ReduceState, and Decoder."""


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_packed_sequence
from torch.nn.utils.rnn import pack_padded_sequence

from experiment.pointer_generator.model import BasicModule
from experiment.pointer_generator.model.attention import Attention


# %% Encoder
class Encoder(BasicModule):
    """For encoding."""

    def __init__(self, config):
        """Initialize the model encoder with config."""
        super(Encoder, self).__init__()
        self.src_word_emb = nn.Embedding(config.vocab_size, config.emb_dim)
        self.lstm = nn.LSTM(config.emb_dim, config.hidden_dim,
                            batch_first=True, bidirectional=True)
        self.fc = nn.Linear(2 * config.hidden_dim, 2 * config.hidden_dim, bias=False)

        self.hidden_dim = config.hidden_dim

        self.init_params()

    def forward(self, enc_input, seq_lens):
        """
        n == 2 * hidden-dim
        args:
            enc_input: [batch-size, seq-len, vocab-size-index]
            seq_lens: [batch-size, ]
        rets:
            encoder_outputs: [batch-size, seq-len, 2 * hidden-dim]
            encoder_feature: [batch-size * seq-len, 2 * hidden-dim]
            hidden: [2, batch-size, hidden-dim] = h, c of [batch-size, hidden-dim]
        Notes: 'seq_lens' should be in descending order.
        """
        embedded = self.src_word_emb(enc_input)  # [batch-size, seq-len, emb-dim]

        packed = pack_padded_sequence(embedded, seq_lens, batch_first=True)
        output, hidden = self.lstm(packed)

        encoder_outputs, _ = pad_packed_sequence(output, batch_first=True)  # [b, l, n]
        encoder_outputs = encoder_outputs.contiguous()  # [b, l, n]

        encoder_feature = encoder_outputs.view(-1, 2 * self.hidden_dim)  # [b*l, n]
        encoder_feature = self.fc(encoder_feature)  # [b*l, n]

        return encoder_outputs, encoder_feature, hidden


# %% ReduceState
class ReduceState(BasicModule):

    def __init__(self, config):
        """Initialize the reduce module with config (hidden-dim)."""
        super(ReduceState, self).__init__()

        self.reduce_h = nn.Linear(2 * config.hidden_dim, config.hidden_dim)
        self.reduce_c = nn.Linear(2 * config.hidden_dim, config.hidden_dim)

        self.hidden_dim = config.hidden_dim

        self.init_params()

    def forward(self, hidden):
        """
        args:
            hidden: [2, batch-size, hidden-dim]
        rets:
            hidden_reduced_h: [1, batch-size, hidden-dim]
            hidden_reduced_c: [1, batch-size, hidden-dim]
        """
        h, c = hidden  # [batch-size, seq_len, hidden-dim]
        h_in = h.transpose(0, 1).contiguous().view(-1, self.hidden_dim * 2)
        hidden_reduced_h = F.relu(self.reduce_h(h_in))
        ret_h = hidden_reduced_h.unsqueeze(0)  # [1, batch-size, hidden-dim]
        c_in = c.transpose(0, 1).contiguous().view(-1, self.hidden_dim * 2)
        hidden_reduced_c = F.relu(self.reduce_c(c_in))
        ret_c = hidden_reduced_c.unsqueeze(0)  # [1, batch-size, hidden-dim]
        return (ret_h, ret_c)


# %% Decoder

class Decoder(BasicModule):
    def __init__(self, config):
        super(Decoder, self).__init__()

        self.attention_network = Attention(config)

        # decoder
        self.tgt_word_emb = nn.Embedding(config.vocab_size, config.emb_dim)
        self.con_fc = nn.Linear(2 * config.hidden_dim + config.emb_dim, config.emb_dim)
        self.lstm = nn.LSTM(config.emb_dim, config.hidden_dim,
                            batch_first=True, bidirectional=False)

        if config.pointer_gen:
            self.p_gen_fc = nn.Linear(4 * config.hidden_dim + config.emb_dim, 1)

        # p_vocab
        self.fc1 = nn.Linear(3 * config.hidden_dim, config.hidden_dim)
        self.fc2 = nn.Linear(config.hidden_dim, config.vocab_size)

        self.hidden_dim = config.hidden_dim
        self.pointer_gen = config.pointer_gen

        self.init_params()

    def forward(
        self, y_t, s_t,
        enc_out, enc_fea, enc_padding_mask,
        c_t, extra_zeros, enc_batch_extend_vocab, coverage, step
    ):
        """
        args:
            y_t: [batch-size, vocab-size-index]
            s_t: h & c states, ([batch-size, hidden-dim], [batch-size, hidden-dim])
            enc_out: [batch-size, seq-len, 2 * hidden-dim]
            enc_fea: [batch-size * seq-len, 2 * hidden-dim]
            enc_padding_mask: [batch-size, seq-len]
            c_t: [batch-size, 2 * hidden-dim]
            extra_zeros:
            enc_batch_extend_vocab:
            coverage: [batch-size, seq-len]
            step: int
        rets:
            c_t: [batch-size, 2 * hidden-dim]
            attn_dist: output of attention-network, [batch-size, seq-len]
            p_gen: generation probability of the pointer-network, [batch-size, 1]
            coverage: coverage over the input words, [batch-size, seq-len]
        """
        if (not self.training) and (step == 0):
            dec_h, dec_c = s_t  # [batch-size, hidden-dim]
            s_t_hat = torch.cat(
                tensors=(
                    dec_h.view(-1, self.hidden_dim),
                    dec_c.view(-1, self.hidden_dim)
                ),
                dim=1
            )  # [batch-size, 2 * hidden-dim]
            c_t, _, coverage_next = self.attention_network(
                s_t_hat, enc_out, enc_fea, enc_padding_mask, coverage)
            coverage = coverage_next  # [batch-size, seq-len]

        y_t_embed = self.tgt_word_emb(y_t)  # [b, emb-dim]  [16, 128]
        x = self.con_fc(torch.cat((c_t, y_t_embed), dim=1))  # [b,2*hid+emb] >> [b,emb]
        lstm_out, s_t = self.lstm(x.unsqueeze(1), s_t)

        dec_h, dec_c = s_t  # [b, hid]
        s_t_hat = torch.cat(
            tensors=(
                dec_h.view(-1, self.hidden_dim),
                dec_c.view(-1, self.hidden_dim)
            ),
            dim=1
        )  # [b,2*hid]
        c_t, attn_dist, coverage_next = self.attention_network(
            s_t_hat, enc_out, enc_fea, enc_padding_mask, coverage)

        if self.training or (step > 0):
            coverage = coverage_next

        p_gen = None
        if self.pointer_gen:
            p_gen_inp = torch.cat((c_t, s_t_hat, x), 1)  # [b, 2*hid+2*hid+emb]
            p_gen = self.p_gen_fc(p_gen_inp)  # [b, 1]
            p_gen = torch.sigmoid(p_gen)

        output = torch.cat(
            tensors=(lstm_out.view(-1, self.hidden_dim), c_t),
            dim=1
        )  # [b, hid+2*hid]
        output = self.fc1(output)  # [b, hid]
        # output = F.relu(output)

        output = self.fc2(output)  # [b, vocab-size]
        vocab_dist = F.softmax(output, dim=1)

        if self.pointer_gen:
            vocab_dist_ = p_gen * vocab_dist
            attn_dist_ = (1 - p_gen) * attn_dist

            if extra_zeros is not None:
                vocab_dist_ = torch.cat((vocab_dist_, extra_zeros), dim=1)

            final_dist = vocab_dist_.scatter_add(1, enc_batch_extend_vocab, attn_dist_)
        else:
            final_dist = vocab_dist

        return final_dist, s_t, c_t, attn_dist, p_gen, coverage

@ -0,0 +1,42 @@
"""Overall Model Architecture."""

import torch
from experiment.pointer_generator.model.layers import Encoder, ReduceState, Decoder


class Model(object):
    """Model class consists of an encoder, a reduce-state, and a decoder."""

    def __init__(self, config, model_path=None, is_eval=False, is_transformer=False):
        super(Model, self).__init__()

        encoder = Encoder(config)
        decoder = Decoder(config)
        reduce_state = ReduceState(config)
        if is_transformer:
            print(f'Transformer Encoder is not yet available.')

        # share the embedding between encoder and decoder
        decoder.tgt_word_emb.weight = encoder.src_word_emb.weight

        if is_eval:
            encoder = encoder.eval()
            decoder = decoder.eval()
            reduce_state = reduce_state.eval()

        self.use_cuda = config.use_gpu and torch.cuda.is_available()
        if self.use_cuda:
            encoder = encoder.cuda()
            decoder = decoder.cuda()
            reduce_state = reduce_state.cuda()

        self.encoder = encoder
        self.decoder = decoder
        self.reduce_state = reduce_state

        if model_path is not None:
            state = torch.load(model_path, map_location=lambda storage, location: storage)
            self.encoder.load_state_dict(state['encoder_state_dict'])
            self.decoder.load_state_dict(state['decoder_state_dict'], strict=False)
            self.reduce_state.load_state_dict(state['reduce_state_dict'])

@ -0,0 +1,98 @@
"""Parse a sample using the Stanford stanza tool."""

import stanza
# stanza.download('en')
nlp = stanza.Pipeline('en')

def parse_str_item(s, vocab_counter=None):
    doc = nlp(s.strip())
    doc_words = [w.text for sent in doc.sentences
                 for w in sent.words]
    doc_words = [dw.strip().lower() for dw in doc_words]
    doc_words = [dw for dw in doc_words if dw != '']
    if vocab_counter is not None:
        vocab_counter.update(doc_words)
    return doc_words

def parse_str_list(string_list, vocab_counter=None):
    parsed_string_list = []
    for string in string_list:
        doc_words = parse_str_item(string, vocab_counter)
        parsed_string_list.append(doc_words)
    return parsed_string_list

def parse_fielded_list(fielded_list, vocab_counter=None):
    parsed_fielded_list = []
    for attr, value in fielded_list:
        value_words = parse_str_item(value, vocab_counter)
        parsed_fielded_list.append((attr, value_words))
    return parsed_fielded_list


from typing import Dict

def parse_sample_dict(sample: Dict, vocab_counter: Dict = None) -> Dict:
    """Parse a processed sample with pointer-generator vocab.
    args:
        sample = Dict{
            'source': str,
            'target': str,
            'table_parent': List[List[str,str]]
        }
    rets:
        parsed_sample = Dict{
            'source': List[str],
            'target': List[str],
            'table_parent': List[List[List[str], List[str]]]
        }
    """
    source_words = parse_str_item(sample['source'], vocab_counter)
    target_words = parse_str_item(sample['target'], vocab_counter)
    parent_words = parse_fielded_list(sample['table_parent'], vocab_counter)
    parsed_sample = {
        'source': source_words,
        'target': target_words,
        'table_parent': parent_words,
    }
    return parsed_sample


import json

def parse_datafile(infile: str, outfile: str, vocab_counter: Dict = None, report_steps: int = 1000) -> None:
    """Parse the in-file dataset, write into the out-file, update the vocab-counter."""

    output_instances = []
    with open(infile, 'r', encoding='utf-8') as fr:
        for idx, line in enumerate(fr):
            inins = json.loads(line.strip())
            outins = parse_sample_dict(inins, vocab_counter)
            output_instances.append(outins)

            if (idx + 1) % report_steps == 0:
                print(f'successfully parsed {idx+1} samples..')

    with open(outfile, 'w', encoding='utf-8') as fw:
        for idx, outins in enumerate(output_instances):
            outline = json.dumps(outins)
            fw.write(outline + '\n')

            if (idx + 1) % report_steps == 0:
                print(f'successfully wrote {idx+1} samples..')

    print(f'Finished! from [{infile}] to [{outfile}]')


import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--infile_list', type=str, nargs='+', required=True)
    parser.add_argument('--outfile_list', type=str, nargs='+', required=True)
    args = parser.parse_args()

    if len(args.infile_list) != len(args.outfile_list):
        print(f'unmatched {len(args.infile_list)} inputs and {len(args.outfile_list)} outputs. ')

    for infile, outfile in zip(args.infile_list, args.outfile_list):
        parse_datafile(infile, outfile)

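# Example (illustrative only; the sample values below are made up, a processed
# sample from do_preprocess.py carries exactly these three keys):
#
#     sample = {
#         'source': 'unemployment rate | 2013 | 7.4 percent',
#         'target': 'the unemployment rate reached 7.4 percent in 2013 .',
#         'table_parent': [['year', '2013'], ['unemployment rate', '7.4 percent']],
#     }
#     parsed = parse_sample_dict(sample)
#     # parsed['source'] / parsed['target'] are token lists; parsed['table_parent']
#     # keeps (attribute, value-token-list) pairs used later by the PARENT metric.
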
@ -0,0 +1,205 @@
"""Trainer for the Pointer-Generator Network. """

import os
import time
import torch
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_

import tensorflow.compat.v1 as tf

from . import config, utils
from .data import Batch, Vocab, Batcher
from .model.model import Model


class Trainer(object):
    """Epoch-wise train and validation."""

    def __init__(self, args):
        """Initialize the trainer with config.
        (vocab_path, vocab_size, train_data_path, batch_size, log_root)
        """
        self.args = args
        self.use_cuda = config.use_gpu and torch.cuda.is_available()

        self.vocab = Vocab(args.vocab_path, args.vocab_size)
        print(
            f'model load data in batch ({args.per_device_train_batch_size}) ',
            f'from path: {args.train_data_path}'
        )
        self.batcher = Batcher(
            vocab=self.vocab,
            data_path=args.train_data_path,
            batch_size=args.per_device_train_batch_size,
            single_pass=True,
            mode='train',
        )
        time.sleep(args.train_sleep_time)

        train_dir = os.path.join(args.run_dir, 'train')  # f'train_{int(time.time())}'
        if not os.path.exists(train_dir):
            os.makedirs(train_dir)

        self.model_dir = os.path.join(train_dir, 'models')
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)

        with tf.Graph().as_default():
            self.summary_writer = tf.summary.FileWriter(train_dir)

        self.running_avg_loss = None

    def save_model(self, running_avg_loss, iepoch: int) -> None:
        """Save the model state in path.
        Includes (encoder/decoder/reduce-state), optimizer dict, and current loss.
        """
        state = {
            'iepoch': iepoch,
            'encoder_state_dict': self.model.encoder.state_dict(),
            'decoder_state_dict': self.model.decoder.state_dict(),
            'reduce_state_dict': self.model.reduce_state.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_loss
        }
        model_save_path = os.path.join(self.model_dir, f'model_{iepoch:d}.bin')
        torch.save(state, model_save_path)

    def setup_train(self, model_path: str = None) -> float:
        """Set up the starting iteration index and loss before activating the training process."""
        self.model = Model(config, model_path)
        # initial_lr = config.lr_coverage if config.is_coverage else config.lr
        initial_lr = self.args.learning_rate

        params = list(self.model.encoder.parameters()) + \
                 list(self.model.decoder.parameters()) + \
                 list(self.model.reduce_state.parameters())
        total_params = sum([param.nelement() for param in params])  # count every element of each parameter
        print(f'The Number of params of model: {total_params/1e3:.3f} k')
        self.optimizer = optim.Adagrad(params, lr=initial_lr, initial_accumulator_value=config.adagrad_init_acc)
        print(f'trainer.optimizer with lr: {initial_lr:.6f}, acc_val: {config.adagrad_init_acc:.6f}')

        start_loss = 0.

        if model_path is not None:
            state = torch.load(model_path, map_location=lambda storage, location: storage)
            start_loss = state['current_loss']

            if not config.is_coverage:
                self.optimizer.load_state_dict(state['optimizer'])
                if self.use_cuda:
                    for state in self.optimizer.state.values():
                        for k, v in state.items():
                            if torch.is_tensor(v):
                                state[k] = v.cuda()

        return start_loss

    def train_one_batch(self, batch: Batch):
        """Execute one training step with a batch of data."""
        (
            enc_batch, enc_lens, enc_pos, enc_padding_mask,
            enc_batch_extend_vocab, extra_zeros, c_t, coverage
        ) = utils.get_input_from_batch(batch, self.use_cuda, config)
        dec_batch, dec_lens, dec_pos, dec_padding_mask, max_dec_len, tgt_batch = \
            utils.get_output_from_batch(batch, self.use_cuda, config)

        self.optimizer.zero_grad()

        if not config.tran:
            enc_out, enc_fea, enc_h = self.model.encoder(enc_batch, enc_lens)
        else:
            enc_out, enc_fea, enc_h = self.model.encoder(enc_batch, enc_pos)

        s_t = self.model.reduce_state(enc_h)

        step_losses, cove_losses = [], []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t = dec_batch[:, di]  # Teacher forcing
            final_dist, s_t, c_t, attn_dist, p_gen, next_coverage = \
                self.model.decoder(
                    y_t, s_t, enc_out, enc_fea, enc_padding_mask,
                    c_t, extra_zeros, enc_batch_extend_vocab, coverage, di
                )
            tgt = tgt_batch[:, di]
            step_mask = dec_padding_mask[:, di]
            gold_probs = torch.gather(final_dist, 1, tgt.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                cove_losses.append(step_coverage_loss * step_mask)
                coverage = next_coverage

            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses / dec_lens
        loss = torch.mean(batch_avg_loss)

        loss.backward()

        clip_grad_norm_(self.model.encoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(), config.max_grad_norm)

        self.optimizer.step()

        if config.is_coverage:
            cove_losses = torch.sum(torch.stack(cove_losses, 1), 1)
            batch_cove_loss = cove_losses / dec_lens
            batch_cove_loss = torch.mean(batch_cove_loss)
            if loss.item() == float('nan') or batch_cove_loss.item() == float('nan'):
                print('nan')
            return loss.item(), batch_cove_loss.item()

        return loss.item(), 0.

    def run_one_epoch(
        self, iepoch: int,
        model_path: str = None,
        interval: int = 1000,
        save_model: bool = True,
    ):
        if (iepoch == 0) or (self.running_avg_loss is None):
            self.running_avg_loss = self.setup_train(model_path)
            print(f'no.epoch {iepoch}, self avg loss: {self.running_avg_loss}')
            print(f'setup training loss {self.running_avg_loss} from model path: {model_path}')

        self.batcher = Batcher(
            vocab=self.vocab,
            data_path=self.args.train_data_path,
            batch_size=self.args.per_device_train_batch_size,
            single_pass=True,
            mode='train',
        )
        time.sleep(self.args.train_sleep_time)

        start = time.time()

        i_iter = 0
        while True:
            batch = self.batcher.next_batch()
            if batch is None: break
            loss, cove_loss = self.train_one_batch(batch)
            self.running_avg_loss = utils.calc_running_avg_loss(
                loss, self.running_avg_loss, self.summary_writer, i_iter)
            if self.running_avg_loss == float('nan'):
                print(f'get NaN')
                break
            i_iter += 1

            if i_iter % interval == 0:
                self.summary_writer.flush()
                time_period = time.time() - start
                print(f'step: {i_iter:d}, second: {time_period:.2f}, '
                      f'loss: {loss:.3f}, cover_loss: {cove_loss:.3f}')
                start = time.time()

        if save_model == True:
            self.save_model(self.running_avg_loss, iepoch)

    def run(self, num_epochs):
        for iepoch in range(num_epochs):
            self.run_one_epoch(iepoch)

@ -0,0 +1,156 @@
"""Utility functions for process and load data."""

import torch
import numpy as np
import tensorflow.compat.v1 as tf

from torch.autograd import Variable

# %% get i/o from batch

def get_input_from_batch(batch, use_cuda, config):
    extra_zeros = None
    enc_lens = batch.enc_lens
    max_enc_len = np.max(enc_lens)
    enc_batch_extend_vocab = None
    batch_size = len(batch.enc_lens)
    enc_batch = Variable(torch.from_numpy(batch.enc_batch).long())
    enc_padding_mask = Variable(torch.from_numpy(batch.enc_padding_mask)).float()

    if config.pointer_gen:
        enc_batch_extend_vocab = Variable(torch.from_numpy(batch.enc_batch_extend_vocab).long())
        # max_art_oovs is the max over all the article oov list in the batch
        if batch.max_art_oovs > 0:
            extra_zeros = Variable(torch.zeros((batch_size, batch.max_art_oovs)))

    c_t = Variable(torch.zeros((batch_size, 2 * config.hidden_dim)))

    coverage = None
    if config.is_coverage:
        coverage = Variable(torch.zeros(enc_batch.size()))

    enc_pos = np.zeros((batch_size, max_enc_len))
    for i, inst in enumerate(batch.enc_batch):
        for j, w_i in enumerate(inst):
            if w_i != config.PAD:
                enc_pos[i, j] = (j + 1)
            else:
                break
    enc_pos = Variable(torch.from_numpy(enc_pos).long())

    if use_cuda:
        c_t = c_t.cuda()
        enc_pos = enc_pos.cuda()
        enc_batch = enc_batch.cuda()
        enc_padding_mask = enc_padding_mask.cuda()

        if coverage is not None:
            coverage = coverage.cuda()

        if extra_zeros is not None:
            extra_zeros = extra_zeros.cuda()

        if enc_batch_extend_vocab is not None:
            enc_batch_extend_vocab = enc_batch_extend_vocab.cuda()

    return enc_batch, enc_lens, enc_pos, enc_padding_mask, enc_batch_extend_vocab, extra_zeros, c_t, coverage


def get_output_from_batch(batch, use_cuda, config):
    dec_lens = batch.dec_lens
    max_dec_len = np.max(dec_lens)
    batch_size = len(batch.dec_lens)
    dec_lens = Variable(torch.from_numpy(dec_lens)).float()
    tgt_batch = Variable(torch.from_numpy(batch.tgt_batch)).long()
    dec_batch = Variable(torch.from_numpy(batch.dec_batch).long())
    dec_padding_mask = Variable(torch.from_numpy(batch.dec_padding_mask)).float()

    dec_pos = np.zeros((batch_size, config.max_dec_steps))
    for i, inst in enumerate(batch.dec_batch):
        for j, w_i in enumerate(inst):
            if w_i != config.PAD:
                dec_pos[i, j] = (j + 1)
            else:
                break
    dec_pos = Variable(torch.from_numpy(dec_pos).long())

    if use_cuda:
        dec_lens = dec_lens.cuda()
        tgt_batch = tgt_batch.cuda()
        dec_batch = dec_batch.cuda()
        dec_padding_mask = dec_padding_mask.cuda()
        dec_pos = dec_pos.cuda()

    return dec_batch, dec_lens, dec_pos, dec_padding_mask, max_dec_len, tgt_batch


# calc
def calc_running_avg_loss(loss, running_avg_loss, summary_writer, step, decay=0.99):
    if int(running_avg_loss) == 0:  # on the first iteration just take the loss
        running_avg_loss = loss
    else:
        running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
    running_avg_loss = min(running_avg_loss, 12)  # clip

    loss_sum = tf.Summary()
    tag_name = 'running_avg_loss/decay=%f' % (decay)
    loss_sum.value.add(tag=tag_name, simple_value=running_avg_loss)
    summary_writer.add_summary(loss_sum, step)

    return running_avg_loss


# %% article / abstract <=> id

def article2ids(article_words, vocab, config):
    ids = []
    oov = []
    unk_id = vocab.word2id(config.UNK_TOKEN)
    for w in article_words:
        i = vocab.word2id(w)
        if i == unk_id:  # If w is OOV
            if w not in oov:  # Add to list of OOVs
                oov.append(w)
            oov_num = oov.index(w)  # This is 0 for the first article OOV, 1 for the second article OOV...
            ids.append(vocab.size() + oov_num)  # This is e.g. 50000 for the first article OOV, 50001 for the second...
        else:
            ids.append(i)
    return ids, oov


def abstract2ids(abstract_words, vocab, article_oovs, config):
    ids = []
    unk_id = vocab.word2id(config.UNK_TOKEN)
    for w in abstract_words:
        i = vocab.word2id(w)
        if i == unk_id:  # If w is an OOV word
            if w in article_oovs:  # If w is an in-article OOV
                vocab_idx = vocab.size() + article_oovs.index(w)  # Map to its temporary article OOV number
                ids.append(vocab_idx)
            else:  # If w is an out-of-article OOV
                ids.append(unk_id)  # Map to the UNK token id
        else:
            ids.append(i)
    return ids


# %% decode

def outputids2words(id_list, vocab, article_oovs):
    words = []
    for i in id_list:
        try:
            w = vocab.id2word(i)  # might be [UNK]
        except ValueError as e:  # w is OOV
            assert article_oovs is not None, \
                ("Error: models produced a word ID that isn't in the vocabulary. "
                 "This should not happen in baseline (no pointer-generator) mode. ")
            article_oov_idx = i - vocab.size()
            try:
                w = article_oovs[article_oov_idx]
            except ValueError as e:  # i doesn't correspond to an article oov
                raise ValueError(
                    'Error: models produced word ID %i which corresponds to article OOV %i but this example only has %i article OOVs' % (
                        i, article_oov_idx, len(article_oovs)))
        words.append(w)
    return words

The diff for this file is not shown because it is too large.
@ -0,0 +1,153 @@
"""Training Functions.

- `run_train` for T5, BART, and BERT-to-BERT (huggingface/transformers supported)
- `run_train_pgn` for Pointer-Generator Network
"""

import os
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

from .utils import (
    prepare_tokenizer, get_dataset,
    ModelPrepareDict, MetricsBuildDict
)
from .pointer_generator.trainer import Trainer as PgnTrainer
from .pointer_generator.decode import BeamSearch


# %% huggingface-supported models: t5, bart, bert-to-bert

def prepare_training_arguments(args):
    train_args = Seq2SeqTrainingArguments(
        output_dir=args.run_dir,
        do_train=args.do_train,
        do_eval=args.do_eval,
        evaluation_strategy='epoch',
        save_strategy='epoch',
        logging_steps=args.logging_steps,
        # optimization args, the trainer uses the Adam optimizer
        # and has a linear warmup for the learning rates
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        num_train_epochs=args.num_train_epochs,
        warmup_steps=args.warmup_steps,
        # misc args
        seed=args.seed,
        disable_tqdm=False,
        load_best_model_at_end=True,
        metric_for_best_model='bleu-4',
        # generation
        predict_with_generate=True
    )
    return train_args


def run_train(args):
    """A general training script with huggingface/transformers.
    1. prepare training and validation sets
    2. load model and organize trainer (and arguments)
    """
    tokenizer = prepare_tokenizer(name=args.tokenizer_name)
    trainset = get_dataset(
        expr_name=args.expr_name,
        data_files=args.train_outpath,
        tokenizer=tokenizer,
        args=args
    )
    validset = get_dataset(
        expr_name=args.expr_name,
        data_files=args.dev_outpath,
        tokenizer=tokenizer,
        args=args
    )

    train_args = prepare_training_arguments(args)
    model = ModelPrepareDict[args.expr_name](
        name=args.model_name,
        path=args.model_path,
        device=args.device
    )

    metric_fn = MetricsBuildDict[args.metrics[0]](tokenizer)
    trainer = Seq2SeqTrainer(
        model=model,
        args=train_args,
        train_dataset=trainset, eval_dataset=validset,
        tokenizer=tokenizer, compute_metrics=metric_fn,
    )

    trainer._max_length = args.decode_maxlen
    trainer._num_beams = args.num_beams

    trainer.train()


# %% pointer-generator network

def find_latest_pgn_model_path(model_dir):
    """Find the path/filename of the latest model within the given directory."""
    filenames = os.listdir(model_dir)
    if len(filenames) == 0: return

    indices = []
    for fn in filenames:
        model_name = fn.split('.')[0]
        model_index = int(model_name.split('_')[-1])
        indices.append(model_index)
    max_index = indices.index(max(indices))
    max_file = filenames[max_index]

    latest_model_path = os.path.join(model_dir, max_file)
    return latest_model_path


def run_train_pgn(args):
    trainer = PgnTrainer(args)
    if args.latest_model_path is not None:
        model_path = args.latest_model_path
    else:
        model_path = args.model_path
    print(f'run with model from [{model_path}]')

    for iepoch in range(args.start_iepoch, args.start_iepoch + args.num_train_epochs):
        print(f'\n <<< START of the #{iepoch} EPOCH >>>')
        if (iepoch + 1) % args.num_eval_epochs == 0:
            do_eval = True
        else:
            do_eval = False

        if (iepoch + 1) % args.num_save_model_epochs == 0:
            do_save_model = True
        else:
            do_save_model = False

        trainer.run_one_epoch(
            iepoch=iepoch,
            model_path=model_path,
            interval=args.logging_steps,
            save_model=do_save_model,
        )
        args.latest_model_path = find_latest_pgn_model_path(trainer.model_dir)

        if (do_eval == True) and (args.latest_model_path is not None):
            print(f'EVAL using model [{args.latest_model_path}]')
            tester = BeamSearch(args, args.latest_model_path, args.eval_data_path)
            tester.run(args.logging_steps)
        print(f' <<< END of the #{iepoch} EPOCH >>>\n')


# %% collection

TrainFunctionDict = {
    't5': run_train,
    'bart': run_train,
    'b2b': run_train,
    'pgn': run_train_pgn,
}

@ -0,0 +1,10 @@
"""Quick import of experiment utility functions.

preparation of tokenizer, train/dev/test dataset
setup models, training arguments and metrics
"""

from .tokenizer import prepare_tokenizer
from .dataset import get_datasets, get_dataset, get_testset, special_tokens_map
from .model import ModelPrepareDict, ModelTestDict
from .metrics import MetricsBuildDict, MetricsDict, bleu_scorer, parent_scorer

@ -0,0 +1,148 @@
"""Load train/dev and test sets for experiments."""

from typing import Dict
from datasets import load_dataset


special_tokens_map = {
    't5': {'cls': 1, 'sep': 1},       # eos_token_id
    'bart': {'cls': 0, 'sep': 2},     # cls_token_id, sep_token_id
    'b2b': {'cls': 101, 'sep': 102},  # cls_token_id, sep_token_id
}


# %% train

def get_sample_naive(sample, tokenizer, args):
    """Tokenize a sample to generate towards a model input.

    args:
        sample: {'source': List[str], 'target': str, ...}
        tokenizer: AutoTokenizer class
        args: >= {'input_maxlen', 'decode_maxlen'}
    rets:
        features: {'input_ids', 'attention_mask', 'labels'}
    """

    cls_id = special_tokens_map[args.expr_name]['cls']
    sep_id = special_tokens_map[args.expr_name]['sep']

    input_ids = [cls_id]
    for text_span in sample['source']:
        span_tokens = tokenizer.tokenize(text_span)
        span_token_ids = tokenizer.convert_tokens_to_ids(span_tokens)
        input_ids.extend(span_token_ids)
        input_ids.append(sep_id)
    input_ids = input_ids[:args.input_maxlen]
    attention_mask = [1 for _ in input_ids]

    while len(input_ids) < args.input_maxlen:
        input_ids.append(tokenizer.pad_token_id)
        attention_mask.append(0)

    target_inputs = tokenizer(
        text=sample['target'],
        padding='max_length',
        truncation=True,
        max_length=args.decode_maxlen
    )
    sample_features = {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'decoder_attention_mask': target_inputs['attention_mask'],
        'labels': target_inputs['input_ids']
    }
    return sample_features


def get_sample_b2b(sample, tokenizer, args):
    """Tokenize a sample to generate towards a model input.

    args:
        sample: {'table_id', 'source', 'target'}
        tokenizer: AutoTokenizer class
        args: >= {'input_maxlen', 'decode_maxlen'}
    rets:
        features: {'input_ids', 'attention_mask', 'labels'}
    """
    cls_id = special_tokens_map[args.expr_name]['cls']
    sep_id = special_tokens_map[args.expr_name]['sep']

    # concatenation
    input_ids = [cls_id]
    position_ids = [0]
    for text_span in sample['source']:
        span_tokens = tokenizer.tokenize(text_span)
        span_token_ids = tokenizer.convert_tokens_to_ids(span_tokens)
        input_ids.extend(span_token_ids)
        input_ids.append(sep_id)
        position_ids.extend([i for i in range(len(span_token_ids) + 1)])
    # truncation
    input_ids = input_ids[:args.input_maxlen]
    position_ids = position_ids[:args.input_maxlen]
    attention_mask = [1 for _ in input_ids]
    # 'max_length' padding
    while len(input_ids) < args.input_maxlen:
        input_ids.append(tokenizer.pad_token_id)
        attention_mask.append(0)

    target_inputs = tokenizer(
        text=sample['target'],
        padding='max_length',
        truncation=True,
        max_length=args.decode_maxlen,
        return_tensors='pt',
    )
    sample_features = {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'position_ids': position_ids,
        'decoder_input_ids': target_inputs['input_ids'][0],
        'decoder_attention_mask': target_inputs['attention_mask'][0],
        'labels': target_inputs['input_ids'][0],
    }
    return sample_features


SamplePrepareDict = {
    't5': get_sample_naive,
    'bart': get_sample_naive,
    'b2b': get_sample_b2b,
}


def get_dataset(
    expr_name: str,
    data_files, tokenizer, args,
    file_type='json',
):
    # datasets.arrow_dataset.Dataset
    raw_data = load_dataset(file_type, data_files=data_files)['train']
    tokenized_data = raw_data.map(
        lambda sample: SamplePrepareDict[expr_name](sample, tokenizer, args)
    )
    return tokenized_data

def get_datasets(
    expr_name: str,
    data_dict: Dict,
    tokenizer,
    args,
    file_type: str = 'json'
):
    dataset = load_dataset(file_type, data_files=data_dict)
    for split in data_dict.keys():
        dataset[split].map(
            lambda sample: SamplePrepareDict[expr_name](sample, tokenizer, args)
        )
    return dataset


# %% test

def get_testset(data_files, file_type='json'):
    test_data = load_dataset(file_type, data_files=data_files)
    testset = test_data['train']
    return testset

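# Example (illustrative only; the path and the fields of `args` are hypothetical,
# `args` only needs expr_name / input_maxlen / decode_maxlen here):
#
#     from transformers import AutoTokenizer
#     tok = AutoTokenizer.from_pretrained('facebook/bart-base')
#     trainset = get_dataset('bart', data_files='path/to/train.parsed.jsonl',
#                            tokenizer=tok, args=args)
#     # each row now carries 'input_ids', 'attention_mask', and 'labels'
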
@ -0,0 +1,19 @@
"""Quick imports of all evaluation metrics.

BLEU: require 'source' and 'target' text tokens.
PARENT: require 'source' & 'target' tokens, and 'table_parent' list of tuples.
"""

from .bleu import bleu_metric_builder, bleu_scorer
from .parent import parent_metric_builder, parent_scorer

MetricsDict = {
    'bleu': bleu_scorer,
    'parent': parent_scorer,
}

MetricsBuildDict = {
    'bleu': bleu_metric_builder,
    'parent': parent_metric_builder,
}

@ -0,0 +1,45 @@
"""BLEU(-4 by default) evaluation metric.
- bleu_scorer: input `predictions` and list of `references` to calculate scores.
- bleu_metric_builder: a function that performs evaluation with a paired tokenizer.
"""


from datasets import load_metric

bleu_scorer = load_metric('bleu')


def bleu_metric_builder(tokenizer, bleu_scorer=bleu_scorer):
    """A builder of the BLEU Metrics."""

    def compute_bleu_metrics(pred, verbose=False):
        """Utility to compute BLEU during training."""
        labels_ids = pred.label_ids
        labels_ids[labels_ids == -100] = tokenizer.pad_token_id
        label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
        label_tokens = [[tokenizer.tokenize(s)] for s in label_str]  # a list of reference token-lists per sample

        pred_ids = pred.predictions
        pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        pred_tokens = [tokenizer.tokenize(s) for s in pred_str]

        # compute the metric.
        # ['bleu', 'precisions', 'brevity_penalty', 'length_ratio', 'translation_length', 'reference_length']
        bleu_results = bleu_scorer.compute(
            predictions=pred_tokens,
            references=label_tokens,
            smooth=False
        )

        if verbose == True:
            print(f'\n\nBLEU Results:')
            print(f"bleu: {bleu_results['bleu']:.4f}")
            print(f"precisions: {[round(item, 4) for item in bleu_results['precisions']]}")
            print(f"brevity_penalty: {bleu_results['brevity_penalty']:.4f}")
            print(f"length_ratio: {bleu_results['length_ratio']:.4f}")
            print(f"translation_length: {bleu_results['translation_length']}")
            print(f"reference_length: {bleu_results['reference_length']}\n\n")

        return {'bleu-4': round(bleu_results['bleu'], 4)}

    return compute_bleu_metrics

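# Example (illustrative only; tokens are made up): bleu_scorer.compute expects
# token lists, with one list of reference token-lists per prediction.
#
#     preds = [['the', 'rate', 'rose', 'in', '2013']]
#     refs = [[['the', 'rate', 'increased', 'in', '2013']]]
#     print(bleu_scorer.compute(predictions=preds, references=refs)['bleu'])
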
@ -0,0 +1,237 @@
"""PARENT evaluation metrics, which additionally take the table into consideration.
- parent_scorer
- parent_metric_builder
"""

import math
import collections


# %% utility functions

def overlap_probability(ngram, table, smoothing=0.0, stopwords=None):
    """Returns the probability that the given n-gram overlaps with the table."""
    # pylint: disable=g-complex-comprehension
    if len(table[0]) == 2:
        table_values = set([tok for _, value in table for tok in value])
    else:
        table_values = set([tok for head, _, tail in table for tok in head + tail])

    overlap = 0
    for token in ngram:
        if stopwords is not None and token in stopwords:
            overlap += 1
            continue
        if token in table_values:
            overlap += 1
    return float(overlap + smoothing) / float(len(ngram) + smoothing)

def _lcs(x, y):
    """Computes the LCS length table between two seqs."""
    n, m = len(x), len(y)
    table = dict()
    for i in range(n + 1):
        for j in range(m + 1):
            if i == 0 or j == 0:
                table[i, j] = 0
            elif x[i - 1] == y[j - 1]:
                table[i, j] = table[i - 1, j - 1] + 1
            else:
                table[i, j] = max(table[i - 1, j], table[i, j - 1])
    return table

def _len_lcs(x, y):
    """Returns the length of the Longest Common Subsequence between two seqs."""
    table = _lcs(x, y)
    n, m = len(x), len(y)
    return table[n, m]

def _mention_probability(table_entry, sentence, smoothing=1e-6):  # smoothing=0.0
    """Returns the probability that the table entry is mentioned in the sentence."""
    if len(table_entry) == 2:
        value = table_entry[1]
    else:
        value = table_entry[0] + table_entry[2]
    overlap = _len_lcs(value, sentence)
    return float(overlap + smoothing) / float(len(value) + smoothing)

def _ngrams(sequence, order):
    """Yields all ngrams of given order in sequence."""
    assert order >= 1
    for n in range(order, len(sequence) + 1):
        yield tuple(sequence[n - order: n])

def _ngram_counts(sequence, order):
    """Returns count of all ngrams of given order in sequence."""
    if len(sequence) < order:
        return collections.Counter()
    return collections.Counter(_ngrams(sequence, order))


# %% metrics calculation and builder

def parent_scorer(
    predictions, references, tables,
    lambda_weight=0.5, smoothing=1e-5, max_order=4,
    entailment_fn=overlap_probability, mention_fn=_mention_probability,
    return_dict=True
):
    """Metric for comparing predictions to references given tables.
    args:
        predictions: List[str]
        references: List[ List[str] ]
        tables: List[ List[Tuple(str-field, str-cell)] ]
        ...
    rets:
        result: Dict{
            'average_precision', 'average_recall',
            'average_f1', 'all_f1_scores'
        }
    """

    precisions, recalls, all_f_scores = [], [], []
    reference_recalls, table_recalls = [], []
    all_lambdas = []

    for prediction, list_of_references, table in zip(predictions, references, tables):
        c_prec, c_rec, c_f = [], [], []
        ref_rec, table_rec = [], []

        for reference in list_of_references:
            # Weighted ngram precisions and recalls for each order.
            ngram_prec, ngram_rec = [], []
            for order in range(1, max_order + 1):
                # Collect n-grams and their entailment probabilities.
                pred_ngram_counts = _ngram_counts(prediction, order)
                pred_ngram_weights = {ngram: entailment_fn(ngram, table)
                                      for ngram in pred_ngram_counts}
                ref_ngram_counts = _ngram_counts(reference, order)
                ref_ngram_weights = {ngram: entailment_fn(ngram, table)
                                     for ngram in ref_ngram_counts}

                # Precision.
                numerator, denominator = 0., 0.
                for ngram, count in pred_ngram_counts.items():
                    denominator += count
                    prob_ngram_in_ref = min(
                        1., float(ref_ngram_counts.get(ngram, 0) / count))
                    numerator += count * (
                        prob_ngram_in_ref +
                        (1. - prob_ngram_in_ref) * pred_ngram_weights[ngram])
                if denominator == 0.:
                    # Set precision to 0.
                    ngram_prec.append(0.0)
                else:
                    ngram_prec.append(numerator / denominator)

                # Recall.
                numerator, denominator = 0., 0.
                for ngram, count in ref_ngram_counts.items():
                    prob_ngram_in_pred = min(
                        1., float(pred_ngram_counts.get(ngram, 0) / count))
                    denominator += count * ref_ngram_weights[ngram]
                    numerator += count * ref_ngram_weights[ngram] * prob_ngram_in_pred
                if denominator == 0.:
                    # Set recall to 1.
                    ngram_rec.append(1.0)
                else:
                    ngram_rec.append(numerator / denominator)

            # Compute recall against table fields.
            table_mention_probs = [mention_fn(entry, prediction) for entry in table]
            table_rec.append(sum(table_mention_probs) / len(table))

            # Smoothing.
            for order in range(1, max_order):
                if ngram_prec[order] == 0.:
                    ngram_prec[order] = smoothing
                if ngram_rec[order] == 0.:
                    ngram_rec[order] = smoothing

            # Compute geometric averages of precision and recall for all orders.
            w = 1. / max_order
            if any(prec == 0. for prec in ngram_prec):
                c_prec.append(0.)
            else:
                sp = (w * math.log(p_i) for p_i in ngram_prec)
                c_prec.append(math.exp(math.fsum(sp)))
            if any(rec == 0. for rec in ngram_rec):
                ref_rec.append(smoothing)
            else:
                sr = [w * math.log(r_i) for r_i in ngram_rec]
                ref_rec.append(math.exp(math.fsum(sr)))

            # Combine reference and table recalls.
            if table_rec[-1] == 0.:
                table_rec[-1] = smoothing
            if ref_rec[-1] == 0. or table_rec[-1] == 0.:
                c_rec.append(0.)
            else:
                if lambda_weight is None:
                    lw = sum([mention_fn(entry, reference) for entry in table]) / len(table)
                    lw = 1. - lw
                else:
                    lw = lambda_weight
                all_lambdas.append(lw)
                c_rec.append(
                    math.exp((1. - lw) * math.log(ref_rec[-1]) +
                             (lw) * math.log(table_rec[-1])))

            # F-score.
            c_f.append((2. * c_prec[-1] * c_rec[-1]) / (c_prec[-1] + c_rec[-1] + 1e-8))

        # Get index of best F-score.
        max_i = max(enumerate(c_f), key=lambda x: x[1])[0]
        precisions.append(c_prec[max_i])
        recalls.append(c_rec[max_i])
        all_f_scores.append(c_f[max_i])
        reference_recalls.append(ref_rec[max_i])
        table_recalls.append(table_rec[max_i])

    avg_precision = sum(precisions) / len(precisions)
    avg_recall = sum(recalls) / len(recalls)
    avg_f_score = sum(all_f_scores) / len(all_f_scores)

    if return_dict:
        result_dict = {
            'average_precision': avg_precision,
            'average_recall': avg_recall,
            'average_f1': avg_f_score,
            'all_f1_scores': all_f_scores
        }
        return result_dict
    else:
        return avg_precision, avg_recall, avg_f_score, all_f_scores


def parent_metric_builder(tokenizer, parent_scorer=parent_scorer):
    """A builder of the PARENT metrics given a compatible tokenizer."""

    def compute_parent_metrics(pred, tables, verbose=False):
        labels_ids = pred.label_ids
        labels_ids[labels_ids == -100] = tokenizer.pad_token_id
        label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
        label_tokens = [[tokenizer.tokenize(s)] for s in label_str]

        pred_ids = pred.predictions
        pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        pred_tokens = [tokenizer.tokenize(s) for s in pred_str]

        parent_results = parent_scorer(
            predictions=pred_tokens,
            references=label_tokens,
            tables=tables,
            return_dict=True
        )

        if verbose == True:
            n = len(parent_results['all_f1_scores'])
            p = parent_results['average_precision']
            r = parent_results['average_recall']
            f = parent_results['average_f1']
            print(f'[{n} instances]: avg precision: {p:.3f}, recall: {r:.3f}, f1: {f:.3f}')
        return parent_results

    return compute_parent_metrics

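# Example (illustrative only; tokens are made up): each table is a list of
# (attribute-tokens, value-tokens) pairs, mirroring the parsed 'table_parent' field.
#
#     preds = [['the', 'rate', 'was', '7.4', 'percent', 'in', '2013']]
#     refs = [[['the', 'rate', 'reached', '7.4', 'percent', 'in', '2013']]]
#     tabs = [[(['year'], ['2013']), (['rate'], ['7.4', 'percent'])]]
#     print(parent_scorer(preds, refs, tabs, return_dict=True)['average_f1'])
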
@ -0,0 +1,141 @@
|
|||
"""Build or Load, Pre-trained or Tuned Models."""
|
||||
|
||||
|
||||
import os
|
||||
import json
|
||||
from transformers import (
|
||||
AutoModelForSeq2SeqLM, BertGenerationEncoder,
|
||||
BertGenerationDecoder, EncoderDecoderModel,
|
||||
)
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# %% train
|
||||
|
||||
def prepare_model_naive(name: str, path: str, device: str = 'cuda'):
|
||||
"""Load target model from the specified name or model-file path."""
|
||||
try:
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(name)
|
||||
return model.to(device)
|
||||
except:
|
||||
try:
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(path)
|
||||
return model.to(device)
|
||||
except:
|
||||
logger.error(f'[utils >> prep_model] fails with name [{name}] and path [{path}]')
|
||||
|
||||
|
||||
def prepare_b2b_model(name: str, path: str, device: str = 'cuda'):
|
||||
"""Prepare a EncoderDecoderModel class from BertGenerationEncoder + BertGenerationDecoder."""
|
||||
if path is not None:
|
||||
bert2bert = EncoderDecoderModel.from_pretrained(path)
|
||||
elif name is not None:
|
||||
encoder = BertGenerationEncoder.from_pretrained(
|
||||
name, bos_token_id=101, eos_token_id=102)
|
||||
decoder = BertGenerationDecoder.from_pretrained(
|
||||
name, bos_token_id=101, eos_token_id=102,
|
||||
add_cross_attention=True, is_decoder=True)
|
||||
bert2bert = EncoderDecoderModel(encoder=encoder, decoder=decoder)
|
||||
|
||||
# adjust default configs
|
||||
bert2bert.config.encoder.max_length = 512
|
||||
bert2bert.config.decoder.max_length = 60
|
||||
return bert2bert.to(device)
|
||||
|
||||
|
||||
|
||||
ModelPrepareDict = {
|
||||
't5': prepare_model_naive,
|
||||
'bart': prepare_model_naive,
|
||||
'b2b': prepare_b2b_model
|
||||
}


# %% test

def find_best_model(run_dir, load_last: bool = True):
    """Locate the last (or recorded best) checkpoint directory under `run_dir`."""
    if run_dir is None: return None
    model_ckpts = [rd for rd in os.listdir(run_dir) if rd.startswith('checkpoint')]
    if len(model_ckpts) == 0: return None

    print(f"RUN-DIR: {run_dir}")
    print(f"MODEL-CKPTS: {model_ckpts}")

    # checkpoint directories are named 'checkpoint-<global_step>'
    iters = [int(dirname.split('-')[-1]) for dirname in model_ckpts]
    index = iters.index(max(iters))
    model_path = os.path.join(run_dir, model_ckpts[index])
    if load_last: return model_path

    # otherwise, follow the best checkpoint recorded by the trainer
    trainer_state_file = os.path.join(model_path, 'trainer_state.json')
    with open(trainer_state_file, 'r') as fr:
        states = json.load(fr)
    best_model_path = states['best_model_checkpoint']
    return best_model_path
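
# --- illustrative example (not part of the original file) ---
# With run_dir containing ['checkpoint-500', 'checkpoint-1000']:
#   find_best_model(run_dir)                  -> '<run_dir>/checkpoint-1000'
#   find_best_model(run_dir, load_last=False) -> the 'best_model_checkpoint' path
#                                                recorded in that checkpoint's
#                                                trainer_state.json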


def load_model_test_naive(run_dir, path, name, device):
    """Load the model from 1) the run directory, 2) a specified path, or 3) a library model name."""
    best_model_path = find_best_model(run_dir)
    if best_model_path is not None:
        logger.info(f'[utils >> load_model] from tuned checkpoint [{best_model_path}]')
        model = AutoModelForSeq2SeqLM.from_pretrained(best_model_path)
        return model.to(device)

    logger.info(f'[utils >> load_model] fails import from run-dir [{run_dir}]')
    try:
        model_path = path
        model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
        logger.info(f'[utils >> load_model] from original path [{model_path}]')
        return model.to(device)
    except Exception:
        logger.warning(f'[utils >> load_model] fails import from path [{path}]')

    try:
        model = AutoModelForSeq2SeqLM.from_pretrained(name)
        logger.info(f'[utils >> load_model] from name [{name}]')
        return model.to(device)
    except Exception:
        logger.warning(f'[utils >> load_model] fails import from name [{name}]')

    return None


def load_model_test_b2b(run_dir, path, name, device):
    """Load the model from 1) the run directory, 2) a specified path, or 3) a library model name."""
    best_model_path = find_best_model(run_dir)
    if best_model_path is not None:
        logger.info(f'[utils >> load_model] from tuned checkpoint [{best_model_path}]')
        model = EncoderDecoderModel.from_pretrained(best_model_path)
        return model.to(device)

    logger.info(f'[utils >> load_model] fails import from run-dir [{run_dir}]')
    try:
        model_path = path
        model = EncoderDecoderModel.from_pretrained(model_path)
        logger.info(f'[utils >> load_model] from original path [{model_path}]')
        return model.to(device)
    except Exception:
        logger.warning(f'[utils >> load_model] fails import from path [{path}]')

    try:
        encoder = BertGenerationEncoder.from_pretrained(name)
        decoder = BertGenerationDecoder.from_pretrained(name)
        model = EncoderDecoderModel(encoder=encoder, decoder=decoder)
        logger.info(f'[utils >> load_model] from name [{name}]')
        return model.to(device)
    except Exception:
        logger.warning(f'[utils >> load_model] fails import from name [{name}]')

    return None


ModelTestDict = {
    't5': load_model_test_naive,
    'bart': load_model_test_naive,
    'b2b': load_model_test_b2b
}
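
# --- illustrative usage (not part of the original file) ---
# Assuming the same argument names as in run_experiment.py, test-time loading
# could be dispatched as:
#   model = ModelTestDict[args.expr_name](
#       run_dir=args.run_dir, path=args.model_path,
#       name=args.model_name, device=args.device)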
@ -0,0 +1,41 @@
"""Tokenizer utilities for experiments.
|
||||
|
||||
Possible tokenizer names are:
|
||||
- t5-base, t5-large
|
||||
- bart-base, bart-large
|
||||
- bert-base-uncased, bert-large-uncased
|
||||
"""
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
special_tokens_dict = {
|
||||
'sep_token': '<sep>',
|
||||
'cls_token': '<cls>',
|
||||
'mask_token': '<mask>'
|
||||
}
|
||||
|
||||
new_tokens = ['<title>', '<cell>', '<agg>', '<top>', '<left>', '<corner>', '<data>']
|
||||
|
||||
|
||||
def prepare_tokenizer(name: str, verbose: bool = False) -> AutoTokenizer:
|
||||
"""Prepare the loaded tokenizer class given the (model) name.
|
||||
args:
|
||||
name: <str>, key of the specified tokenizer
|
||||
choices = ['t5-base', 'bart-base', 'bert-base-uncased']
|
||||
verbose: <bool>, whether in a verbose mode
|
||||
rets:
|
||||
tokenzier: an automatically identified tokenizer class.
|
||||
"""
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(name)
|
||||
# tokenizer.add_special_tokens(special_tokens_dict)
|
||||
# tokenizer.add_tokens(new_tokens)
|
||||
|
||||
if verbose == True:
|
||||
logger.info(f'[utils >> prepare_tokenizer] gets tokenizer from name [{name}]')
|
||||
# logger.info(f'[utils >> prepare_tokenizer] adds special tokens {list(special_tokens_dict.keys())}')
|
||||
|
||||
return tokenizer
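
# --- illustrative usage (not part of the original file) ---
#   tokenizer = prepare_tokenizer('facebook/bart-base', verbose=True)
#
# Note (assumption): if the commented-out special/new tokens above are enabled,
# the vocabulary grows, so the paired model would also need
#   model.resize_token_embeddings(len(tokenizer))
# before training.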
@ -0,0 +1,141 @@
"""Train & Test Pipeline for HiTab Data-to-Text Generation.
|
||||
|
||||
Available Models:
|
||||
- (t5) T5: base size by default
|
||||
- (bart) BART: base size by default
|
||||
- (b2b) BERT-to-BERT: base size by default
|
||||
- (pgn) Pointer-Generator Network
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
|
||||
from experiment.train_d2t import TrainFunctionDict
|
||||
from experiment.eval_d2t import TestFunctionDict
|
||||
|
||||
|
||||
def main():
|
||||
if args.do_train or args.do_eval:
|
||||
TrainFunctionDict[args.expr_name](args)
|
||||
|
||||
if args.do_test or args.do_decode:
|
||||
TestFunctionDict[args.expr_name](args)
|
||||
|
||||
|
||||
|
||||
ExperimentSuite = {
|
||||
't5': {
|
||||
'model_name': 't5-large',
|
||||
'tokenizer_name': 't5-large',
|
||||
'per_device_train_batch_size': 2,
|
||||
'per_device_eval_batch_size': 2,
|
||||
'learning_rate': 1e-4,
|
||||
'num_train_epochs': 50,
|
||||
},
|
||||
'bart': {
|
||||
'model_name': 'facebook/bart-base',
|
||||
'tokenizer_name': 'facebook/bart-base',
|
||||
'per_device_train_batch_size': 8,
|
||||
'per_device_eval_batch_size': 8,
|
||||
'learning_rate': 1e-4,
|
||||
'num_train_epochs': 50,
|
||||
},
|
||||
'b2b': {
|
||||
'model_name': 'bert-large-uncased',
|
||||
'tokenizer_name': 'bert-large-uncased',
|
||||
'per_device_train_batch_size': 8,
|
||||
'per_device_eval_batch_size': 8,
|
||||
'learning_rate': 1e-4,
|
||||
'num_train_epochs': 50,
|
||||
},
|
||||
'pgn': {
|
||||
'model_name': None,
|
||||
'tokenizer_name': None,
|
||||
'per_device_train_batch_size': 2,
|
||||
'per_device_eval_batch_size': 2,
|
||||
'learning_rate': 1e-3,
|
||||
'num_train_epochs': 100,
|
||||
'train_sleep_time': 15,
|
||||
'vocab_path': os.path.join(os.getcwd(), 'experiment/pointer_generator/vocab'),
|
||||
'vocab_size': 30000,
|
||||
'test_decode_name': 'decoded_test.log',
|
||||
}
|
||||
}
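
# --- illustrative note (not part of the original file) ---
# Each entry above is copied onto `args` via setattr in update_arguments(),
# so per-model hyper-parameters can be changed here without touching the CLI,
# e.g. lowering the BART learning rate:
#   ExperimentSuite['bart']['learning_rate'] = 5e-5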


def update_arguments():
    # misc
    args.run_dir = os.path.join(os.getcwd(), args.run_subdir, args.expr_name)
    if not os.path.exists(args.run_dir): os.makedirs(args.run_dir)

    # data
    args.train_outpath = args.train_data_path = os.path.join(args.data_dir, 'train_samples.jsonl')
    args.dev_outpath = args.eval_data_path = os.path.join(args.data_dir, 'dev_samples.jsonl')
    args.test_outpath = args.decode_data_path = os.path.join(args.data_dir, 'test_samples.jsonl')

    # model
    for k, v in ExperimentSuite[args.expr_name].items():
        setattr(args, k, v)
    args.latest_model_path = args.model_path

    return args
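
# --- illustrative note (not part of the original file) ---
# With the defaults above, checkpoints for an experiment land under
#   <cwd>/runs/<expr_name>/
# e.g. './runs/bart/' when running with --expr_name bart.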


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # data
    parser.add_argument('--data_dir', type=str, default='data',
                        help="Directory containing the processed train/dev/test_samples.jsonl files.")

    # model
    parser.add_argument('--expr_name', type=str, default='t5',
                        choices=['t5', 'bart', 'b2b', 'pgn'], help="Model name (abbr.).")
    parser.add_argument('--model_path', type=str, default=None,
                        help="Path of the model checkpoint used for weight initialization.")

    # training
    parser.add_argument('--logging_steps', type=int, default=100)
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1)
    parser.add_argument('--warmup_steps', type=int, default=100)
    parser.add_argument('--learning_rate', type=float, default=1e-3)

    parser.add_argument('--start_iepoch', type=int, default=0,
                        help='Index of the starting epoch.')
    parser.add_argument('--num_train_epochs', type=int, default=5,
                        help='Number of epochs for continual tuning.')
    parser.add_argument('--num_eval_epochs', type=int, default=1,
                        help='Number of epochs per validation.')
    parser.add_argument('--num_save_model_epochs', type=int, default=1,
                        help='Number of epochs between model checkpoints.')

    parser.add_argument('--input_maxlen', type=int, default=512,
                        help='Max number of tokens of input sequences.')
    parser.add_argument('--decode_maxlen', type=int, default=100,
                        help='Max number of tokens of generated sequences.')
    parser.add_argument('--num_beams', type=int, default=5,
                        help='Beam size for sequence generation.')
    parser.add_argument('--num_return_sequences', type=int, default=3,
                        help='Number of generated sentences for comparison.')

    # evaluation
    parser.add_argument('--metrics', type=str, nargs='+', default=['bleu'])

    # misc
    parser.add_argument('--run_subdir', type=str, default='runs')
    parser.add_argument('--log_subdir', type=str, default='logs')

    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--seed', type=int, default=47)
    parser.add_argument('--do_train', action='store_true')
    parser.add_argument('--do_eval', action='store_true')
    parser.add_argument('--do_test', action='store_true')
    parser.add_argument('--do_decode', action='store_true')

    args = parser.parse_args()

    args = update_arguments()
    main()