initialize scripts for building several datasets, including WikiSQL, WTQ, SQA, etc.

This commit is contained in:
SivilTaram 2021-07-11 14:23:39 +08:00
Parent ef1c078a43
Commit fe690930c3
14 changed files with 1271 additions and 0 deletions

View File

@@ -36,6 +36,8 @@ contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additio
# 📝 License
The code and pre-trained models are open-sourced under [MIT License](LICENSE-Code), while the pre-training corpus is released under [CC BY-SA 4.0](LICENSE-Data).
# ™️ Trademarks
This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft

28
common/table_linearize.py Normal file
View File

@@ -0,0 +1,28 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Utils for linearizing the table content into a flattened sequence
"""
from typing import Dict
def concat_input_with_table(_table_content: Dict, input_query: str, start_row_idx: int = 0):
compact_str = input_query.lower().strip() + " " + linearize_schema(_table_content, start_row_idx)
return compact_str.strip()
def linearize_schema(_table_content: Dict, start_row_idx: int):
"""
Data format: col : col1 | col2 | col3 row 1 : val1 | val2 | val3 row 2 : ...
"""
_table_str = "col : " + " | ".join(_table_content["header"]) + " "
_table_str = _table_str.lower()
for i, row_example in enumerate(_table_content["rows"]):
_table_str += "row " + str(start_row_idx + i + 1) + " : "
row_cell_values = [str(cell_value) if isinstance(cell_value, int) else cell_value.lower()
for cell_value in row_example]
_table_str += " | ".join(row_cell_values) + " "
return _table_str
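For reference, a minimal sketch of what the linearization above produces, using an invented two-row table and query (not taken from any dataset):

toy_table = {"header": ["City", "Population"],
             "rows": [["Beijing", 21540000], ["Paris", 2148000]]}
print(concat_input_with_table(toy_table, "Which city is larger?"))
# -> which city is larger? col : city | population row 1 : beijing | 21540000 row 2 : paris | 2148000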

257
common/table_transform.py Normal file
View File

@@ -0,0 +1,257 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import math
import random
from typing import List, Dict
import logging
from transformers import AutoTokenizer
from table_linearize import concat_input_with_table
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="facebook/bart-large")
logger = logging.getLogger(__name__)
TGT_DEL = ", "
CHUNK_TOKEN = " <chunk> "
def build_example(_args, _question: str, _answer: List, _table: Dict,
_table_name: str, is_train: bool, max_sen_length):
input_sources = []
output_targets = []
truncate_table_content, table_mapping = truncate_database_values(_table,
_args.max_cell_length,
_args.max_cell_truncate)
# the chunk size for validation can be scaled to a larger upper bound
max_chunk_size = _args.max_chunk_size if is_train else _args.max_chunk_size_valid
if len(_answer) > 0:
for i, case in enumerate(_answer):
if case in table_mapping.keys():
_answer[i] = table_mapping[case]
flatten_output = TGT_DEL.join([str(case).lower() for case in _answer])
# we should split into multiple chunks to save memory
chunk_table_contents, cover_ratio = split_long_table(truncate_table_content,
_question,
max_sen_length=max_sen_length)
# if the table still yields too many chunks, drop unrelated rows so the example can still be recalled during training
if len(chunk_table_contents) > max_chunk_size:
maximum_val = 1.1 if max_chunk_size != 1 else 1.05
drop_ratio = maximum_val - cover_ratio
# truncate the table
small_table_content = truncate_training_database(_table_name,
truncate_table_content,
_question,
drop_ratio,
_answers=_answer if is_train else None)
if max_chunk_size == 1:
split_mode = "greedy"
else:
split_mode = "average"
chunk_table_contents, _ = split_long_table(small_table_content,
_question,
max_sen_length=max_sen_length,
split_mode=split_mode)
flatten_inputs = []
# the initial value of row
row_idx = 0
early_stop = False
for _chunk_id, _table_content in enumerate(chunk_table_contents):
if _chunk_id >= max_chunk_size:
logger.warning("Table: {} has {} chunks which is too large to handle, truncate it.".format(
_table_name, len(chunk_table_contents)))
early_stop = True
if early_stop:
break
flatten_inputs.append(concat_input_with_table(_table_content, _question, start_row_idx=row_idx))
row_idx = len(_table_content["rows"]) + 1
flatten_input = CHUNK_TOKEN.join(flatten_inputs).strip()
input_sources.append(flatten_input)
output_targets.append(flatten_output)
else:
logger.warning("Empty answer in case: {}".format(_question))
return input_sources, output_targets
def setup_parser(_parser):
_parser.add_argument(
"--max-cell-length",
type=int,
default=15,
help="if the cell's length is larger that this, it should be processed",
)
_parser.add_argument(
"--max-cell-truncate",
type=int,
default=15,
help="truncate cell's length into a value less than this one",
)
return _parser
def permute_table(_wtq_table_content: Dict):
# shuffle header orders
header = _wtq_table_content["header"]
header_content = list(map(list, zip(*_wtq_table_content["rows"])))
header_num = len(_wtq_table_content["header"])
header_range = list(range(header_num))
random.shuffle(header_range)
# map from the original to the new
origin_to_shuffle = {i: header_range[i] for i in range(header_num)}
shuffle_header = [header[origin_to_shuffle[i]] for i in range(header_num)]
shuffle_content = [header_content[origin_to_shuffle[i]] for i in range(header_num)]
shuffle_rows = list(map(list, zip(*shuffle_content)))
# random.shuffle(shuffle_rows)
return {
"header": shuffle_header,
"rows": shuffle_rows
}
def split_long_table(_normalized_table: Dict, input_query: str, max_sen_length, split_mode="average"):
assert "header" in _normalized_table
assert "rows" in _normalized_table
number_of_rows = len(_normalized_table["rows"])
# TODO: split the table into chunks of roughly equal token length
query_tokens = tokenizer.tokenize(input_query, add_special_tokens=True)
header_string = "col : " + " | ".join(_normalized_table["header"]) + " "
header_tokens = tokenizer.tokenize(header_string, add_special_tokens=False)
# split all cell values into tokens and see how many of them fit
used_token_len = len(query_tokens) + len(header_tokens)
# remaining length
remain_token_len = max_sen_length - 2 - used_token_len
value_string = ""
for _, row_example in enumerate(_normalized_table["rows"]):
# use a fixed dummy row index here; this pass only estimates the total token length of the rows
value_string += "row " + str(100) + " : "
row_cell_values = [str(cell_value) if isinstance(cell_value, int) else cell_value.lower()
for cell_value in row_example]
value_string += " | ".join(row_cell_values) + " "
value_token_len = len(tokenizer.tokenize(value_string))
# used to estimate the cover ratio of the whole table
whole_token_len = used_token_len + value_token_len
# maximum chunk size
chunk_size = math.ceil(value_token_len / remain_token_len)
if chunk_size == 1:
return [_normalized_table], 1.0
if split_mode == "greedy":
pass  # greedy mode: keep the full remaining token budget for every chunk
elif split_mode == "average":
remain_token_len = min(remain_token_len, 100 + math.ceil(value_token_len / chunk_size))
else:
raise Exception("Do not support split_mode {}".format(split_mode))
current_chunk_remain_size = remain_token_len
current_chunk_row = 0
split_table_contents = []
for ind, row_example in enumerate(_normalized_table["rows"]):
value_string = "row " + str(ind) + " : "
row_cell_values = [str(cell_value) if isinstance(cell_value, int) else cell_value.lower()
for cell_value in row_example]
value_string += " | ".join(row_cell_values)
value_token_len = len(tokenizer.tokenize(value_string))
# over the size limit, and take action
if value_token_len > current_chunk_remain_size:
split_table_contents.append({
"header": _normalized_table["header"],
"rows": _normalized_table["rows"][current_chunk_row: ind]
})
# reset everything for the new chunk
current_chunk_row = ind
current_chunk_remain_size = remain_token_len
current_chunk_remain_size -= value_token_len
if current_chunk_row != (number_of_rows - 1):
split_table_contents.append({
"header": _normalized_table["header"],
"rows": _normalized_table["rows"][current_chunk_row:]
})
return split_table_contents, float(max_sen_length / whole_token_len)
def truncate_training_database(_table_id, _table_content, _question, _drop_ratio: float, _answers=None, _sql=None):
truncated_unrelated_indices = []
related_indices = []
if _answers is None:
answer_set = set([])
else:
answer_set = set([ans_ex.lower() for ans_ex in _answers])
# add _sql into answer_set
if _sql is not None:
answer_set.update(_sql.split())
question_set = set(_question.strip("?!.,").split(" "))
row_max_len = len(_table_content["rows"])
for _row_idx, row in enumerate(_table_content["rows"]):
lower_row = set([str(cell).lower() for cell in row])
if len(lower_row & answer_set) == 0 and len(lower_row & question_set) == 0:
truncated_unrelated_indices.append(_row_idx)
else:
# add neighbours to preserve information aggressively
related_indices.extend([_row_idx - 2, _row_idx - 1, _row_idx, _row_idx + 1, _row_idx + 2])
# exclude related rows and their neighbours from the drop candidates
truncated_unrelated_indices = [_row_idx for _row_idx in truncated_unrelated_indices
if _row_idx not in related_indices]
# select some cases to drop
drop_items = min(len(truncated_unrelated_indices), int(len(_table_content["rows"]) * _drop_ratio))
drop_row_indices = random.choices(truncated_unrelated_indices, k=drop_items)
for _row_idx in reversed(range(row_max_len)):
if _row_idx in drop_row_indices:
del _table_content["rows"][_row_idx]
# log a warning only when the drop ratio is relatively large.
if _drop_ratio >= 0.1:
logger.warning("Drop {} rows (drop ratio {:.2f}) in table {}".format(len(drop_row_indices), _drop_ratio, _table_id))
return _table_content
def truncate_database_values(_table_content, max_cell_length, max_cell_truncate):
"""
Truncate over-long cell values in the table content to avoid overly long input sequences.
:param _table_content: `Dict` with `header` and `rows` keys
:param max_cell_length: `int`, cells longer than this (in tokens) get truncated
:param max_cell_truncate: `int`, maximum number of tokens kept for a truncated cell
:return: the table_content with truncated cells and a mapping from original to truncated cell values
"""
def _truncate_cell(cell_value):
# do not process numeric cells
if isinstance(cell_value, int) or isinstance(cell_value, float):
return cell_value
if cell_value.strip() != "":
try_tokens = tokenizer.tokenize(cell_value)
if len(try_tokens) >= max_cell_length:
retain_tokens = try_tokens[:max_cell_truncate]
retain_cell_value = tokenizer.convert_tokens_to_string(retain_tokens)
return retain_cell_value
else:
return None
else:
return cell_value
cell_mapping = {}
for row in _table_content["rows"]:
for i, cell in enumerate(row):
truncate_cell = _truncate_cell(cell)
if truncate_cell is not None:
cell_mapping[cell] = truncate_cell
row[i] = truncate_cell
return _table_content, cell_mapping
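As a rough usage sketch of the pipeline above, `build_example` can be driven with a plain `argparse.Namespace`; the argument values and toy data below are illustrative only, not the settings used for the released datasets:

from argparse import Namespace

toy_args = Namespace(max_cell_length=15, max_cell_truncate=15,
                     max_chunk_size=1, max_chunk_size_valid=5)
toy_table = {"header": ["city", "population"],
             "rows": [["beijing", "21,540,000"], ["paris", "2,148,000"]]}
sources, targets = build_example(toy_args, "how many people live in paris?", ["2,148,000"],
                                 toy_table, "toy_table", is_train=False, max_sen_length=1024)
# sources[0] holds the flattened "question + table" string fed to the model,
# targets[0] the comma-joined answer string ("2,148,000").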

0
dataset/.gitkeep Normal file
View File

117
download_dataset.py Normal file
View File

@@ -0,0 +1,117 @@
import os
import requests
import tarfile
import zipfile
import shutil
"""
Download datasets used in our paper and move them into `raw_dataset` folder.
Now the script supports downloading the following datasets for different usage:
=== Fine-tuning ===
1. WikiSQL
2. WikiTableQuestions
3. SQA
4. TabFact
=== Pre-training ===
1. Squall
"""
RAW_DATASET_FOLDER = "raw_dataset"
def download_file(url):
"""
Download file into local file system from url
"""
local_filename = url.split('/')[-1]
with requests.get(url, stream=True) as r:
with open(local_filename, 'wb') as f:
shutil.copyfileobj(r.raw, f)
return local_filename
def download_wikisql():
"""
Download WikiSQL dataset and unzip the files
"""
WIKISQL_URL = "https://raw.github.com/salesforce/WikiSQL/master/data.tar.bz2"
wikisql_raw_path = os.path.join(RAW_DATASET_FOLDER, "wikisql")
wikisql_tar_file = download_file(WIKISQL_URL)
# unzip and move it into raw_dataset folder
tar = tarfile.open(wikisql_tar_file, "r:bz2")
tar.extractall(wikisql_raw_path)
tar.close()
# remove the original file
os.remove(wikisql_tar_file)
def download_wikitablequestion():
"""
Download the WikiTableQuestions dataset and unzip the files
"""
WTQ_URL = "https://github.com/ppasupat/WikiTableQuestions/releases/download/v1.0.2/WikiTableQuestions-1.0.2-compact.zip"
wtq_raw_path = os.path.join(RAW_DATASET_FOLDER, "wtq")
wtq_zip_file = download_file(WTQ_URL)
# unzip and move it into raw_dataset folder
with zipfile.ZipFile(wtq_zip_file) as zf:
zf.extractall(RAW_DATASET_FOLDER)
unzip_wtq_path = os.path.join(RAW_DATASET_FOLDER, "WikiTableQuestions")
shutil.move(unzip_wtq_path, wtq_raw_path)
# remove the original file
os.remove(wtq_zip_file)
def download_sqa():
"""
Download the SQA dataset and unzip the files
"""
SQA_URL = "https://download.microsoft.com/download/1/D/C/1DC270D2-1B53-4A61-A2E3-88AB3E4E6E1F/SQA%20Release%201.0.zip"
sqa_raw_path = os.path.join(RAW_DATASET_FOLDER, "sqa")
sqa_zip_file = download_file(SQA_URL)
# unzip and move it into raw_dataset folder
with zipfile.ZipFile(sqa_zip_file) as zf:
zf.extractall(RAW_DATASET_FOLDER)
unzip_sqa_path = os.path.join(RAW_DATASET_FOLDER, "SQA Release 1.0")
shutil.move(unzip_sqa_path, sqa_raw_path)
# remove the original file
os.remove(sqa_zip_file)
def download_tabfact():
"""
Download the TabFact dataset and unzip the files
"""
# note: this assumes the standard GitHub archive URL for the public Table-Fact-Checking repository
TABFACT_URL = "https://github.com/wenhuchen/Table-Fact-Checking/archive/refs/heads/master.zip"
tabfact_raw_path = os.path.join(RAW_DATASET_FOLDER, "tabfact")
tabfact_zip_file = download_file(TABFACT_URL)
# unzip and move it into raw_dataset folder
with zipfile.ZipFile(tabfact_zip_file) as zf:
zf.extractall(RAW_DATASET_FOLDER)
unzip_tabfact_path = os.path.join(RAW_DATASET_FOLDER, "Table-Fact-Checking-master")
shutil.move(unzip_tabfact_path, tabfact_raw_path)
# remove the original file
os.remove(tabfact_zip_file)
def download_squall():
"""
Download the Squall dataset and unzip the files
"""
SQUALL_URL = "https://github.com/tzshi/squall/archive/refs/heads/main.zip"
squall_raw_path = os.path.join(RAW_DATASET_FOLDER, "squall")
squall_zip_file = download_file(SQUALL_URL)
# unzip and move it into raw_dataset folder
with zipfile.ZipFile(squall_zip_file) as zf:
zf.extractall(RAW_DATASET_FOLDER)
unzip_squall_path = os.path.join(RAW_DATASET_FOLDER, "squall-main")
shutil.move(unzip_squall_path, squall_raw_path)
# remove the original file
os.remove(squall_zip_file)
if __name__ == '__main__':
# download_wikisql()
# download_wikitablequestion()
download_sqa()
# download_squall()

View File

@@ -0,0 +1,85 @@
import os
from argparse import ArgumentParser
import numpy
import pandas as pd
from tqdm import tqdm
from common.table_transform import *
random.seed(42)
numpy.random.seed(42)
TABLE_PATH = os.path.join("raw_dataset", "sqa")
def read_table_from_file(_wtq_table_name: str):
rows = []
assert ".csv" in _wtq_table_name
table_data = pd.read_csv(os.path.join(TABLE_PATH, _wtq_table_name))
# the first line is header
header = list(table_data.columns)
for row_data in table_data.values:
rows.append([str(_) for _ in list(row_data)])
return {
"header": header,
"rows": rows
}
if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument('--mode', help='dataset split to process', type=str,
default='test')
parser.add_argument('--source-file', help='source file for the prediction', type=str,
default='sqa/test.tsv')
parser.add_argument('--data-dir', help='data directory to store the dataset', type=str,
default='dataset/sqa_chunk')
parser = setup_parser(parser)
args = parser.parse_args()
folder = args.data_dir
if not os.path.exists(folder):
os.makedirs(folder)
mode = args.mode
input_f = open("{}/{}.src".format(folder, mode), "w", encoding="utf8")
output_f = open("{}/{}.tgt".format(folder, mode), "w", encoding="utf8")
table_content_map = {}
db_engine_map = {}
history = ""
with open(args.source_file, "r", encoding="utf8") as fs:
examples = fs.readlines()
idx = 0
for example in tqdm(examples[1:]):
try:
anno_id, _, position, question, table_file, _, answer_text = example.strip("\n").split("\t")
answer_text = answer_text.replace("\"\"", "\"").strip("\"'")
if position == "0":
# reset history
history = ""
question = question.lower()
if history:
question = history + " " + question
answer = eval(answer_text)
table_content = read_table_from_file(table_file)
input_sources, output_targets = build_example(args, question, answer,
table_content, table_file,
is_train=mode == "train",
max_sen_length=1024)
for input_s, output_t in zip(input_sources, output_targets):
input_f.write(input_s + "\n")
output_f.write(output_t + "\n")
# carry the accumulated question forward as the history for the next turn
history = question
idx += 1
except Exception as e:
print("Error case on Line: {}, {}".format(idx, question))
print(e)
input_f.close()
output_f.close()
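Since SQA questions form a conversation, each follow-up question above is prefixed with the earlier turns before being linearized with the table. A standalone sketch of that history handling, with made-up turns:

turns = [("0", "Who are the players?"),
         ("1", "Which of them are from Spain?"),
         ("2", "Which one scored the most points?")]
history = ""
for position, question in turns:
    if position == "0":
        history = ""  # a new sequence starts, so drop the previous turns
    question = question.lower()
    if history:
        question = history + " " + question
    print(question)  # this accumulated question is what gets paired with the table
    history = question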

View File

@@ -0,0 +1,98 @@
# coding=utf8
import os
import json
from typing import List, Dict
from tqdm import tqdm
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="facebook/bart-large")
def read_jsonl(path: str) -> List[Dict]:
data = list()
with open(path, 'r', encoding="utf-8") as f:
for line in f:
line = line.strip()
data.append(json.loads(line))
return data
def flatten_table(table: List[List[str]]) -> str:
str_rows = list()
for rid, row in enumerate(table):
values = " | ".join([value.lower() for value in row]) + " "
if rid == 0:
str_rows.append("col : %s" % values)
else:
str_rows.append("row %d : %s" % (rid, values))
return "".join(str_rows)
def main(data_path: str, splits: List[str], task: str = "classification") -> None:
for split in splits:
path = os.path.join(data_path, "%s.jsonl" % split)
examples = read_jsonl(path)
if task == 'classification':
save_input0_path = os.path.join(data_path, "%s.raw.input0" % split)
save_label_path = os.path.join(data_path, "%s.label" % split)
else:
save_input0_path = os.path.join(data_path, "%s.src" % split)
save_label_path = os.path.join(data_path, "%s.tgt" % split)
with open(save_input0_path, "w", encoding="utf8") as input0_f, open(save_label_path, "w", encoding="utf8") as label_f:
for ex in tqdm(examples):
sentence = ex['statement']
label = ex['label']
table_text = flatten_table(ex['table_text'])
input_line = "%s %s" % (sentence.lower(), table_text)
try_input_tokens = tokenizer.tokenize(input_line)
input_token_num = len(try_input_tokens)
if input_token_num > 1020:
try_input_tokens = try_input_tokens[: 1020]
input_line = tokenizer.convert_tokens_to_string(try_input_tokens)
print("Warning: truncate the source from {} tokens to 1020 tokens.".format(input_token_num))
input0_f.write(input_line + "\n")
if task == 'classification':
label_f.write("%d\n" % label)
else:
answer = "yes" if label == 1 else "no"
label_f.write("%s\n" % answer)
def split_test_data(test_data_path, split_json_files, out_dir):
examples = read_jsonl(test_data_path)
for split_file in split_json_files:
split_mode = os.path.split(split_file)[-1][:-5]
valid_id_list = json.load(open(split_file, "r", encoding="utf8"))
valid_examples = [example for example in examples if example["table_id"] in valid_id_list]
save_input0_path = os.path.join(out_dir, "%s.raw.input0" % ("test_" + split_mode))
save_label_path = os.path.join(out_dir, "%s.label" % ("test_" + split_mode))
with open(save_input0_path, "w", encoding="utf8") as input0_f, open(save_label_path, "w", encoding="utf8") as label_f:
for ex in tqdm(valid_examples):
sentence = ex['statement']
label = ex['label']
table_text = flatten_table(ex['table_text'])
input_line = "%s %s" % (sentence.lower(), table_text)
try_input_tokens = tokenizer.tokenize(input_line)
if len(try_input_tokens) > 1020:
try_input_tokens = try_input_tokens[: 1020]
input_line = tokenizer.convert_tokens_to_string(try_input_tokens)
print("Warning: truncate the source")
input0_f.write(input_line + "\n")
label_f.write("%d\n" % label)
if __name__ == '__main__':
data_path = "dataset/tabfact_classification"
if not os.path.exists(data_path):
os.makedirs(data_path)
splits = ["train", "valid", "test"]
main(data_path, splits, task="classification")
split_test_data("tabfact/test.jsonl",
["tabfact/complex.json",
"tabfact/small.json",
"tabfact/simple.json"],
"dataset/tabfact_classification")

View File

@@ -0,0 +1,77 @@
#!/usr/bin/env python
import json
import os
from argparse import ArgumentParser
from copy import deepcopy
from tqdm import tqdm
from common.table_transform import *
from wikisql_utils.executor import retrieve_wikisql_query_answer_tapas, _TYPE_CONVERTER
from wikisql_utils.wikisql_common import count_lines
TGT_DEL = ", "
def _parse_table(table):
"""Runs the type converter over the table cells."""
ret_table = deepcopy(table)
types = ret_table['types']
ret_table['real_rows'] = ret_table['rows']
typed_rows = []
for row in ret_table['rows']:
typed_row = []
for column, cell_value in enumerate(row):
typed_row.append(_TYPE_CONVERTER[types[column]](cell_value))
typed_rows.append(typed_row)
ret_table['rows'] = typed_rows
return ret_table
if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument('--mode', help='dataset split to process', type=str,
default='train')
parser.add_argument('--source-file', help='source file for the prediction', type=str, default='wikisql/data/train.jsonl')
parser.add_argument('--db-file', help='source database for the prediction', type=str, default='wikisql/data/train.db')
parser.add_argument('--table-file', help='source table content for the prediction', type=str,
default='wikisql/data/train.tables.jsonl')
parser.add_argument('--data-dir', help='data directory to store the dataset', type=str,
default='dataset/wikisql')
parser = setup_parser(parser)
args = parser.parse_args()
mode = args.mode
# record table id to table content
table_content_map = {}
for json_line in open(args.table_file, "r", encoding="utf8"):
content = json.loads(json_line)
table_content_map[content["id"]] = content
folder = args.data_dir
if not os.path.exists(folder):
os.makedirs(folder)
input_f = open("{}/{}.src".format(folder, mode), "w", encoding="utf8")
output_f = open("{}/{}.tgt".format(folder, mode), "w", encoding="utf8")
with open(args.source_file) as fs:
for ls in tqdm(fs, total=count_lines(args.source_file)):
example = json.loads(ls)
table_id = example["table_id"]
table_content = table_content_map[table_id]
question = example["question"].lower()
tapas_table = _parse_table(table_content)
answer = retrieve_wikisql_query_answer_tapas(tapas_table, example)
input_sources, output_targets = build_example(args, question, answer,
table_content, table_id,
is_train=mode == "train",
max_sen_length=1024)
for input_s, output_t in zip(input_sources, output_targets):
input_f.write(input_s + "\n")
output_f.write(output_t + "\n")
input_f.close()
output_f.close()

View File

@@ -0,0 +1,82 @@
import logging
import os
from argparse import ArgumentParser
import numpy
from tqdm import tqdm
from typing import Tuple
from table_transform import *
TABLE_PATH = "wtq_origin"
logger = logging.getLogger(__name__)
def read_table_from_file(_wtq_table_name: str):
def _extract_content(_line: str):
_vals = [_.replace("\n", " ").strip() for _ in _line.strip("\n").split("\t")]
# _vals = ["empty" if _ == "" else _ for _ in _vals]
return _vals
rows = []
assert ".csv" in _wtq_table_name
# use the normalized table file
_wtq_table_name = _wtq_table_name.replace(".csv", ".tsv")
with open(os.path.join(TABLE_PATH, _wtq_table_name), "r", encoding="utf8") as table_f:
table_lines = table_f.readlines()
# the first line is header
header = _extract_content(table_lines[0])
for line in table_lines[1:]:
rows.append(_extract_content(line))
return {
"header": header,
"rows": rows
}
if __name__ == '__main__':
random.seed(42)
numpy.random.seed(42)
parser = ArgumentParser()
parser.add_argument('--mode', help='dataset split to process', type=str,
default='train')
parser.add_argument('--source-file', help='source file for the prediction', type=str,
default='wtq_origin/train.tsv')
parser.add_argument('--data-dir', help='data directory to store the dataset', type=str,
default='dataset/wtq_chunk')
parser.add_argument('--max-sen-len', help='maximum number of tokens in one flattened input sequence', type=int)
# setup up table transformation operations
parser = setup_parser(parser)
args = parser.parse_args()
mode = args.mode
folder = args.data_dir
if not os.path.exists(folder):
os.makedirs(folder)
input_f = open("{}/{}.src".format(folder, mode), "w", encoding="utf8")
output_f = open("{}/{}.tgt".format(folder, mode), "w", encoding="utf8")
table_content_map = {}
db_engine_map = {}
with open(args.source_file, "r", encoding="utf8") as fs:
examples = fs.readlines()
for example in tqdm(examples[1:]):
_, question, table_name, answer = example.strip("\n").split("\t")
question = question.lower()
answer = answer.split("|")
# must contain rows and header keys
table_content = read_table_from_file(table_name)
input_sources, output_targets = build_example(args, question, answer, table_content,
table_name, mode == "train", args.max_sen_len)
for input_s, output_t in zip(input_sources, output_targets):
input_f.write(input_s + "\n")
output_f.write(output_t + "\n")
input_f.close()
output_f.close()

View File

@@ -0,0 +1,25 @@
import os
import json
def load_file_and_convert(fairseq_data_folder, huggingface_data_folder):
train_src_file = os.path.join(fairseq_data_folder, "train.src")
train_tgt_file = os.path.join(fairseq_data_folder, "train.tgt")
dev_src_file = os.path.join(fairseq_data_folder, "valid.src")
dev_tgt_file = os.path.join(fairseq_data_folder, "valid.tgt")
out_train = open(os.path.join(huggingface_data_folder, "train.json"), "w", encoding="utf8")
out_dev = open(os.path.join(huggingface_data_folder, "valid.json"), "w", encoding="utf8")
with open(train_src_file, "r", encoding="utf8") as train_src_f, \
open(train_tgt_file, "r", encoding="utf8") as train_tgt_f:
src_lines = train_src_f.readlines()
tgt_lines = train_tgt_f.readlines()
for src_line, tgt_line in zip(src_lines, tgt_lines):
out_train.write(json.dumps({"input": src_line.strip(), "output": tgt_line.strip()}) + "\n")
with open(dev_src_file, "r", encoding="utf8") as dev_src_f, \
open(dev_tgt_file, "r", encoding="utf8") as dev_tgt_f:
src_lines = dev_src_f.readlines()
tgt_lines = dev_tgt_f.readlines()
for src_line, tgt_line in zip(src_lines, tgt_lines):
out_dev.write(json.dumps({"input": src_line.strip(), "output": tgt_line.strip()}) + "\n")
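Each output line is one JSON object pairing a source line with its target line. A hedged usage sketch (the folder names are hypothetical, and the target folder must already exist):

os.makedirs("dataset/wtq_hf", exist_ok=True)
load_file_and_convert("dataset/wtq_chunk", "dataset/wtq_hf")
# every line of the resulting train.json / valid.json then looks like:
# {"input": "how many people live in paris? col : city | population row 1 : ...", "output": "2,148,000"}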

View File

@@ -0,0 +1,250 @@
# The following script is adapted from the script of TaPas.
# Original: https://github.com/google-research/tapas/master/wikisql_utils.py
import dataclasses
import enum
import functools
import math
import re
from typing import Text, Any
import six
EMPTY_ANSWER = "none"
EMPTY_ANSWER_AGG = "none"
def _split_thousands(delimiter, value):
split = value.split(delimiter)
return len(split) > 1 and any(map(lambda x: len(x) == 3, split))
def convert_to_float(value):
"""Converts value to a float using a series of increasingly complex heuristics.
Args:
value: object that needs to be converted. Allowed types include
float/int/strings.
Returns:
A float interpretation of value.
Raises:
ValueError if the float conversion of value fails.
"""
if isinstance(value, float):
return value
if isinstance(value, int):
return float(value)
if not isinstance(value, six.string_types):
raise ValueError("Argument value is not a string. Can't parse it as float")
sanitized = value
try:
# Example: 1,000.7
if "." in sanitized and "," in sanitized:
return float(sanitized.replace(",", ""))
# 1,000
if "," in sanitized and _split_thousands(",", sanitized):
return float(sanitized.replace(",", ""))
# 5,5556
if "," in sanitized and sanitized.count(",") == 1 and not _split_thousands(
",", sanitized):
return float(sanitized.replace(",", "."))
# 0.0.0.1
if sanitized.count(".") > 1:
return float(sanitized.replace(".", ""))
# 0,0,0,1
if sanitized.count(",") > 1:
return float(sanitized.replace(",", ""))
return float(sanitized)
except ValueError:
# Avoid adding the sanitized value in the error message.
raise ValueError("Unable to convert value to float")
def _normalize_float(answer):
if answer is None:
return None
try:
value = convert_to_float(answer)
if isinstance(value, float) and math.isnan(value):
return None
return value
except ValueError:
return answer.lower()
_TYPE_CONVERTER = {
'text': lambda x: x,
'real': convert_to_float,
}
class _Aggregation(enum.Enum):
"""Aggregations as defined by WikiSQL. Indexes match the data."""
NONE = 0
MAX = 1
MIN = 2
COUNT = 3
SUM = 4
AVERAGE = 5
class _Operator(enum.Enum):
"""The boolean operators used by WikiSQL. Indexes match the data."""
EQUALS = 0
GREATER = 1
LESSER = 2
@dataclasses.dataclass
class _Condition:
"""Represents an SQL where clauses (e.g A = "a" or B > 5)."""
column: Text
operator: _Operator
cmp_value: Any
_TOKENIZER = re.compile(r'\w+|[^\w\s]+', re.UNICODE | re.MULTILINE | re.DOTALL)
def _normalize_for_match(x):
return [t for t in _TOKENIZER.findall(x.lower())]
def _compare(operator, src, tgt):
if operator == _Operator.EQUALS:
return src == tgt
elif operator == _Operator.GREATER:
return src > tgt
elif operator == _Operator.LESSER:
return src < tgt
raise ValueError(f'Unknown operator: {operator}')
def _parse_value(table, column,
cell_value):
"""Convert numeric values to floats and keeps everything else as string."""
types = table['types']
return _TYPE_CONVERTER[types[column]](cell_value)
def _is_string(x):
return isinstance(x, str)
def _respect_conditions(table, row,
conditions):
"""True if 'row' satisfies all 'conditions'."""
for cond in conditions:
table_value = row[cond.column]
cmp_value = _parse_value(table, cond.column, cond.cmp_value)
if _is_string(table_value) and _is_string(cmp_value):
table_value = _normalize_for_match(table_value)
cmp_value = _normalize_for_match(cmp_value)
if not isinstance(table_value, type(cmp_value)):
raise ValueError('Type difference {} != {}'.format(
type(table_value), type(cmp_value)))
if not _compare(cond.operator, table_value, cmp_value):
return False
return True
def _get_float_answer(table,
answer_coordinates,
aggregation_op):
"""Applies operation to produce reference float answer."""
if not answer_coordinates:
if aggregation_op == _Aggregation.COUNT:
return 0.0
else:
return EMPTY_ANSWER_AGG
# Count can support non numeric answers.
if aggregation_op == _Aggregation.COUNT:
return float(len(answer_coordinates))
# If we have just one answer, return its float conversion when possible.
values = [table['rows'][i][j] for (i, j) in answer_coordinates]
if len(answer_coordinates) == 1:
try:
return convert_to_float(values[0])
except ValueError as e:
if aggregation_op != _Aggregation.NONE:
raise e
if aggregation_op == _Aggregation.NONE:
return None
# Other aggregations only support numeric values. Bail out if we have strings.
if not all((isinstance(v, (int, float)) for v in values)):
return None
if aggregation_op == _Aggregation.SUM:
return float(sum(values))
elif aggregation_op == _Aggregation.AVERAGE:
return sum(values) / len(answer_coordinates)
else:
raise ValueError(f'Unknown aggregation: {aggregation_op}')
def _get_answer_coordinates(table, example):
"""Retrieves references coordinates by executing SQL."""
# MAX and MIN are automatically supported by the model.
aggregation_op_index = example['sql']['agg']
if aggregation_op_index >= 3:
aggregation_op = _Aggregation(aggregation_op_index)
else:
aggregation_op = _Aggregation.NONE
target_column = example['sql']['sel']
conditions = [
_Condition(column, _Operator(operator), cmp_value)
for column, operator, cmp_value in example['sql']['conds']
]
indices = []
for row in range(len(table['rows'])):
if _respect_conditions(table, table['rows'][row], conditions):
indices.append((row, target_column))
if not indices:
return [], aggregation_op
if len(indices) == 1:
return indices, aggregation_op
# Parsing of MIN/MAX.
if aggregation_op_index in (1, 2):
operators = {2: min, 1: max}
values = [
(table['rows'][i][j], index) for index, (i, j) in enumerate(indices)
]
reduced = functools.reduce(operators[example['sql']['agg']], values)
ret = [indices[reduced[1]]]
return ret, _Aggregation.NONE
return indices, aggregation_op
def _get_answer_text(table,
answer_coordinates,
float_answer):
if float_answer is not None:
return [str(float_answer)]
return [str(table['real_rows'][r][c]) for r, c in answer_coordinates]
def retrieve_wikisql_query_answer_tapas(table, example):
answer_coordinates, aggregation_op = \
_get_answer_coordinates(table, example)
float_answer = _get_float_answer(table, answer_coordinates,
aggregation_op)
answer_text = _get_answer_text(table, answer_coordinates, float_answer)
# keep the output format consistent with the original TaPas data
if len(answer_text) == 0:
answer_text = [EMPTY_ANSWER]
return answer_text
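An end-to-end sketch of how the WikiSQL build script uses this executor; the table and SQL annotation below are invented, and the typed table (with `real_rows` and float cells) is constructed by hand rather than via `_parse_table`:

toy_table = {"header": ["player", "score"],
             "types": ["text", "real"],
             "real_rows": [["alice", "12"], ["bob", "7"]],
             "rows": [["alice", 12.0], ["bob", 7.0]]}
toy_example = {"sql": {"sel": 1, "agg": 0, "conds": [[0, 0, "bob"]]}}  # SELECT score WHERE player = bob
print(retrieve_wikisql_query_answer_tapas(toy_table, toy_example))  # -> ['7.0']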

View File

@@ -0,0 +1,240 @@
import re
from copy import deepcopy
from wikisql_common import detokenize
re_whitespace = re.compile(r'\s+', flags=re.UNICODE)
class Query:
agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
cond_ops = ['=', '>', '<', 'OP']
syms = ['SELECT', 'WHERE', 'AND', 'COL', 'TABLE', 'CAPTION', 'PAGE', 'SECTION', 'OP', 'COND', 'QUESTION', 'AGG',
'AGGOPS', 'CONDOPS']
def __init__(self, sel_index, agg_index, conditions=tuple(), ordered=False):
self.sel_index = sel_index
self.agg_index = agg_index
self.conditions = list(conditions)
self.ordered = ordered
def __eq__(self, other):
if isinstance(other, self.__class__):
indices = self.sel_index == other.sel_index and self.agg_index == other.agg_index
if other.ordered:
conds = [(col, op, str(cond).lower()) for col, op, cond in self.conditions] == [
(col, op, str(cond).lower()) for col, op, cond in other.conditions]
else:
conds = set([(col, op, str(cond).lower()) for col, op, cond in self.conditions]) == set(
[(col, op, str(cond).lower()) for col, op, cond in other.conditions])
return indices and conds
return NotImplemented
def __ne__(self, other):
if isinstance(other, self.__class__):
return not self.__eq__(other)
return NotImplemented
def __hash__(self):
return hash(tuple(sorted(self.__dict__.items())))
def __repr__(self):
rep = 'SELECT {agg} {sel} FROM table'.format(
agg=self.agg_ops[self.agg_index],
sel='col{}'.format(self.sel_index),
)
if self.conditions:
rep += ' WHERE ' + ' AND '.join(
['{} {} {}'.format('col{}'.format(i), self.cond_ops[o], v) for i, o, v in self.conditions])
return rep
def to_dict(self):
return {'sel': self.sel_index, 'agg': self.agg_index, 'conds': self.conditions}
def lower(self):
conds = []
for col, op, cond in self.conditions:
conds.append([col, op, cond.lower()])
return self.__class__(self.sel_index, self.agg_index, conds)
@classmethod
def from_dict(cls, d, ordered=False):
return cls(sel_index=d['sel'], agg_index=d['agg'], conditions=d['conds'], ordered=ordered)
@classmethod
def from_tokenized_dict(cls, d):
conds = []
for col, op, val in d['conds']:
conds.append([col, op, detokenize(val)])
return cls(d['sel'], d['agg'], conds)
@classmethod
def from_generated_dict(cls, d):
conds = []
for col, op, val in d['conds']:
end = len(val['words'])
conds.append([col, op, detokenize(val)])
return cls(d['sel'], d['agg'], conds)
@classmethod
def from_sequence(cls, sequence, table, lowercase=True):
sequence = deepcopy(sequence)
if 'symend' in sequence['words']:
end = sequence['words'].index('symend')
for k, v in sequence.items():
sequence[k] = v[:end]
terms = [{'gloss': g, 'word': w, 'after': a} for g, w, a in
zip(sequence['gloss'], sequence['words'], sequence['after'])]
headers = [detokenize(h) for h in table['header']]
# lowercase everything and truncate sequence
if lowercase:
headers = [h.lower() for h in headers]
for i, t in enumerate(terms):
for k, v in t.items():
t[k] = v.lower()
headers_no_whitespace = [re.sub(re_whitespace, '', h) for h in headers]
# get select
if 'symselect' != terms.pop(0)['word']:
raise Exception('Missing symselect operator')
# get aggregation
if 'symagg' != terms.pop(0)['word']:
raise Exception('Missing symagg operator')
agg_op = terms.pop(0)['word']
if agg_op == 'symcol':
agg_op = ''
else:
if 'symcol' != terms.pop(0)['word']:
raise Exception('Missing aggregation column')
try:
agg_op = cls.agg_ops.index(agg_op.upper())
except Exception as e:
raise Exception('Invalid agg op {}'.format(agg_op))
def find_column(name):
return headers_no_whitespace.index(re.sub(re_whitespace, '', name))
def flatten(tokens):
ret = {'words': [], 'after': [], 'gloss': []}
for t in tokens:
ret['words'].append(t['word'])
ret['after'].append(t['after'])
ret['gloss'].append(t['gloss'])
return ret
where_index = [i for i, t in enumerate(terms) if t['word'] == 'symwhere']
where_index = where_index[0] if where_index else len(terms)
flat = flatten(terms[:where_index])
try:
agg_col = find_column(detokenize(flat))
except Exception as e:
raise Exception('Cannot find aggregation column {}'.format(flat['words']))
where_terms = terms[where_index + 1:]
# get conditions
conditions = []
while where_terms:
t = where_terms.pop(0)
flat = flatten(where_terms)
if t['word'] != 'symcol':
raise Exception('Missing conditional column {}'.format(flat['words']))
try:
op_index = flat['words'].index('symop')
col_tokens = flatten(where_terms[:op_index])
except Exception as e:
raise Exception('Missing conditional operator {}'.format(flat['words']))
cond_op = where_terms[op_index + 1]['word']
try:
cond_op = cls.cond_ops.index(cond_op.upper())
except Exception as e:
raise Exception('Invalid cond op {}'.format(cond_op))
try:
cond_col = find_column(detokenize(col_tokens))
except Exception as e:
raise Exception('Cannot find conditional column {}'.format(col_tokens['words']))
try:
val_index = flat['words'].index('symcond')
except Exception as e:
raise Exception('Cannot find conditional value {}'.format(flat['words']))
where_terms = where_terms[val_index + 1:]
flat = flatten(where_terms)
val_end_index = flat['words'].index('symand') if 'symand' in flat['words'] else len(where_terms)
cond_val = detokenize(flatten(where_terms[:val_end_index]))
conditions.append([cond_col, cond_op, cond_val])
where_terms = where_terms[val_end_index + 1:]
q = cls(agg_col, agg_op, conditions)
return q
@classmethod
def from_partial_sequence(cls, agg_col, agg_op, sequence, table, lowercase=True):
sequence = deepcopy(sequence)
if 'symend' in sequence['words']:
end = sequence['words'].index('symend')
for k, v in sequence.items():
sequence[k] = v[:end]
terms = [{'gloss': g, 'word': w, 'after': a} for g, w, a in
zip(sequence['gloss'], sequence['words'], sequence['after'])]
headers = [detokenize(h) for h in table['header']]
# lowercase everything and truncate sequence
if lowercase:
headers = [h.lower() for h in headers]
for i, t in enumerate(terms):
for k, v in t.items():
t[k] = v.lower()
headers_no_whitespace = [re.sub(re_whitespace, '', h) for h in headers]
def find_column(name):
return headers_no_whitespace.index(re.sub(re_whitespace, '', name))
def flatten(tokens):
ret = {'words': [], 'after': [], 'gloss': []}
for t in tokens:
ret['words'].append(t['word'])
ret['after'].append(t['after'])
ret['gloss'].append(t['gloss'])
return ret
where_index = [i for i, t in enumerate(terms) if t['word'] == 'symwhere']
where_index = where_index[0] if where_index else len(terms)
where_terms = terms[where_index + 1:]
# get conditions
conditions = []
while where_terms:
t = where_terms.pop(0)
flat = flatten(where_terms)
if t['word'] != 'symcol':
raise Exception('Missing conditional column {}'.format(flat['words']))
try:
op_index = flat['words'].index('symop')
col_tokens = flatten(where_terms[:op_index])
except Exception as e:
raise Exception('Missing conditional operator {}'.format(flat['words']))
cond_op = where_terms[op_index + 1]['word']
try:
cond_op = cls.cond_ops.index(cond_op.upper())
except Exception as e:
raise Exception('Invalid cond op {}'.format(cond_op))
try:
cond_col = find_column(detokenize(col_tokens))
except Exception as e:
raise Exception('Cannot find conditional column {}'.format(col_tokens['words']))
try:
val_index = flat['words'].index('symcond')
except Exception as e:
raise Exception('Cannot find conditional value {}'.format(flat['words']))
where_terms = where_terms[val_index + 1:]
flat = flatten(where_terms)
val_end_index = flat['words'].index('symand') if 'symand' in flat['words'] else len(where_terms)
cond_val = detokenize(flatten(where_terms[:val_end_index]))
conditions.append([cond_col, cond_op, cond_val])
where_terms = where_terms[val_end_index + 1:]
q = cls(agg_col, agg_op, conditions)
return q
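A small sketch of constructing a `Query` directly; the column indices and condition value are arbitrary:

q = Query(sel_index=0, agg_index=3, conditions=[[1, 0, "beijing"]])
print(q)            # -> SELECT COUNT col0 FROM table WHERE col1 = beijing
print(q.to_dict())  # -> {'sel': 0, 'agg': 3, 'conds': [[1, 0, 'beijing']]}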

View File

@@ -0,0 +1,10 @@
def count_lines(fname):
with open(fname, "r", encoding="utf8") as f:
return sum(1 for line in f)
def detokenize(tokens):
ret = ''
for g, a in zip(tokens['gloss'], tokens['after']):
ret += g + a
return ret.strip()
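`detokenize` re-joins a tokenized span using the whitespace stored in `after`; a toy token dict in the WikiSQL annotation format:

toy_tokens = {"gloss": ["New", "York", "City"], "after": [" ", " ", ""]}
print(detokenize(toy_tokens))  # -> New York City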

0
raw_dataset/.gitkeep Normal file
View File