Initialize scripts for building several datasets, including WikiSQL, WTQ and SQA.
This commit is contained in:
Родитель
ef1c078a43
Коммит
fe690930c3
|
@ -36,6 +36,8 @@ contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additio
|
|||
|
||||
# 📝 License
|
||||
|
||||
The code and pre-trained models are open-sourced under [MIT License](LICENSE-Code), while the pre-training corpus is released under [CC BY-SA 4.0](LICENSE-Data).
|
||||
|
||||
# ™️ Trademarks
|
||||
|
||||
This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
"""
|
||||
Utils for linearizing the table content into a flattened sequence
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
|
||||
|
||||
def concat_input_with_table(_table_content: Dict, input_query: str, start_row_idx: int = 0):
    """
    Join the lower-cased, stripped query and the linearized table with a single space.

    :param _table_content: `Dict` with `header` and `rows`, forwarded to `linearize_schema`
    :param input_query: the natural language query placed in front of the table
    :param start_row_idx: row number offset forwarded to `linearize_schema`
    :return: the combined string with surrounding whitespace removed
    """
    normalized_query = input_query.lower().strip()
    flattened_table = linearize_schema(_table_content, start_row_idx)
    return " ".join((normalized_query, flattened_table)).strip()
|
||||
|
||||
|
||||
def linearize_schema(_table_content: Dict, start_row_idx: int):
    """
    Flatten a table into one lower-cased string.

    Data format: col : col1 | col2 | col3 row 1 : val1 | val2 | val3 row 2 : ...

    :param _table_content: `Dict` with a `header` list of strings and a `rows` list of lists
    :param start_row_idx: offset added to the 1-based row numbers, so a later chunk of a
        split table can continue the numbering of the previous chunk
    :return: the flattened string; every segment (header and each row) ends with one space
    """
    segments = [("col : " + " | ".join(_table_content["header"]) + " ").lower()]
    for i, row_example in enumerate(_table_content["rows"]):
        # str() first so that any non-string cell (int, float, ...) can be rendered;
        # calling .lower() directly on a float cell used to raise AttributeError
        row_cell_values = [str(cell_value).lower() for cell_value in row_example]
        segments.append("row " + str(start_row_idx + i + 1) + " : " + " | ".join(row_cell_values) + " ")
    # one join instead of repeated string concatenation
    return "".join(segments)
|
|
@ -0,0 +1,257 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import math
|
||||
import random
|
||||
from typing import List, Dict
|
||||
import logging
|
||||
from transformers import AutoTokenizer
|
||||
from table_linearize import concat_input_with_table
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="facebook/bart-large")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TGT_DEL = ", "
|
||||
CHUNK_TOKEN = " <chunk> "
|
||||
|
||||
|
||||
def build_example(_args, _question: str, _answer: List, _table: Dict,
                  _table_name: str, is_train: bool, max_sen_length):
    """
    Build the flattened (source, target) pair for one question over one table.

    :param _args: parsed arguments; `max_cell_length` and `max_cell_truncate` are required,
        `max_chunk_size` / `max_chunk_size_valid` are read when present
    :param _question: question text that is prepended to every table chunk
    :param _answer: `List` of answers; an empty list produces no example
    :param _table: `Dict` with `header` and `rows`; its cells are truncated in place
    :param _table_name: table identifier, only used in log messages
    :param is_train: selects the training chunk limit and lets the row dropping see the answers
    :param max_sen_length: token budget forwarded to `split_long_table`
    :return: two parallel lists (sources, targets), each with zero or one entry
    """
    input_sources = []
    output_targets = []

    truncate_table_content, table_mapping = truncate_database_values(_table,
                                                                     _args.max_cell_length,
                                                                     _args.max_cell_truncate)
    # the chunk size of validation can be scaled to a larger upper bound
    # NOTE(review): the fallback of 1 is an assumption for parsers that do not register
    # the chunk options - confirm the intended default
    if is_train:
        max_chunk_size = getattr(_args, "max_chunk_size", 1)
    else:
        max_chunk_size = getattr(_args, "max_chunk_size_valid", 1)

    if len(_answer) == 0:
        logger.warning("Empty answer in case: {}".format(_question))
        return input_sources, output_targets

    # answers that match a truncated cell must use the truncated text as well;
    # work on a copy so the caller's answer list is left untouched
    answers = [table_mapping[case] if case in table_mapping else case for case in _answer]
    flatten_output = TGT_DEL.join([str(case).lower() for case in answers])

    # we should split into multiple chunks to save memory
    chunk_table_contents, cover_ratio = split_long_table(truncate_table_content,
                                                         _question,
                                                         max_sen_length=max_sen_length)
    # too many chunks: drop rows that look unrelated and split again
    if len(chunk_table_contents) > max_chunk_size:
        maximum_val = 1.1 if max_chunk_size != 1 else 1.05
        drop_ratio = maximum_val - cover_ratio
        # the answers are only allowed to guide the row dropping during training
        small_table_content = truncate_training_database(_table_name,
                                                         truncate_table_content,
                                                         _question,
                                                         drop_ratio,
                                                         _answers=answers if is_train else None)
        split_mode = "greedy" if max_chunk_size == 1 else "average"
        chunk_table_contents, _ = split_long_table(small_table_content,
                                                   _question,
                                                   max_sen_length=max_sen_length,
                                                   split_mode=split_mode)

    if len(chunk_table_contents) > max_chunk_size:
        logger.warning("Table: {} has {} chunks which is too large to handle, truncate it.".format(
            _table_name, len(chunk_table_contents)))

    flatten_inputs = []
    # number of rows already emitted by previous chunks
    row_idx = 0
    for _table_content in chunk_table_contents[:max_chunk_size]:
        flatten_inputs.append(concat_input_with_table(_table_content, _question, start_row_idx=row_idx))
        # accumulate, so the next chunk continues the numbering exactly where this one ended
        row_idx += len(_table_content["rows"])

    flatten_input = CHUNK_TOKEN.join(flatten_inputs).strip()
    input_sources.append(flatten_input)
    output_targets.append(flatten_output)

    return input_sources, output_targets
|
||||
|
||||
|
||||
def setup_parser(_parser):
    """
    Register the table transformation options on an argument parser.

    :param _parser: an `argparse.ArgumentParser` (or compatible) instance
    :return: the same parser, to allow chaining
    """
    _parser.add_argument(
        "--max-cell-length",
        type=int,
        default=15,
        help="if the cell's length is larger than this, it should be processed",
    )

    _parser.add_argument(
        "--max-cell-truncate",
        type=int,
        default=15,
        help="truncate cell's length into a value less than this one",
    )

    # `build_example` reads both chunk limits from the parsed arguments
    # NOTE(review): the defaults of 1 are an assumption - confirm the intended values
    _parser.add_argument(
        "--max-chunk-size",
        type=int,
        default=1,
        help="maximum number of table chunks kept for a training example",
    )

    _parser.add_argument(
        "--max-chunk-size-valid",
        type=int,
        default=1,
        help="maximum number of table chunks kept for a validation / test example",
    )

    return _parser
|
||||
|
||||
|
||||
def permute_table(_wtq_table_content: Dict):
    """
    Return a copy of the table whose columns are in a random order.

    The row order is kept and the input table is not modified. The permutation comes from
    one `random.shuffle` call, so it is reproducible under a fixed seed.

    :param _wtq_table_content: `Dict` with `header` and `rows`
    :return: a new `Dict` with the permuted `header` and `rows`
    """
    header = _wtq_table_content["header"]
    rows = _wtq_table_content["rows"]
    header_num = len(header)

    # column-major view of the table; zip(*rows) yields nothing for a table without
    # rows, so the empty columns are created explicitly to keep the indexing below valid
    if rows:
        header_content = list(map(list, zip(*rows)))
    else:
        header_content = [[] for _ in range(header_num)]

    # header_range[i] is the original index of the column placed at position i
    header_range = list(range(header_num))
    random.shuffle(header_range)

    shuffle_header = [header[origin] for origin in header_range]
    shuffle_content = [header_content[origin] for origin in header_range]
    # back to row-major layout
    shuffle_rows = list(map(list, zip(*shuffle_content)))

    return {
        "header": shuffle_header,
        "rows": shuffle_rows
    }
|
||||
|
||||
|
||||
def split_long_table(_normalized_table: Dict, input_query: str, max_sen_length, split_mode="average"):
    """
    Split a table into consecutive row chunks that each fit the token budget.

    :param _normalized_table: `Dict` with `header` and `rows`
    :param input_query: the question, whose tokens count against the budget of every chunk
    :param max_sen_length: maximum number of tokens of one flattened input
    :param split_mode: `greedy` fills each chunk up to the budget, `average` lowers the
        per-chunk budget so the chunks get a similar size
    :return: a tuple of (list of chunk tables sharing the header, cover ratio); the ratio is
        1.0 when the table fits in one chunk, otherwise `max_sen_length` divided by the
        estimated total token count
    """
    assert "header" in _normalized_table
    assert "rows" in _normalized_table
    header = _normalized_table["header"]
    rows = _normalized_table["rows"]

    def _row_cells(row_example):
        # str() first: numeric cells (including floats) have no .lower()
        return [str(cell_value).lower() for cell_value in row_example]

    query_tokens = tokenizer.tokenize(input_query, add_special_tokens=True)
    header_string = "col : " + " | ".join(header) + " "
    header_tokens = tokenizer.tokenize(header_string, add_special_tokens=False)
    # tokens that every chunk has to spend on the query and the header
    used_token_len = len(query_tokens) + len(header_tokens)
    # remaining length for the row values of one chunk
    remain_token_len = max_sen_length - 2 - used_token_len
    if remain_token_len <= 0:
        # query and header alone exhaust the budget; fall back to a minimal budget
        # (one row per chunk) instead of dividing by zero below
        remain_token_len = 1

    # estimate the token count of all rows, using a fixed dummy row number
    value_string = "".join("row " + str(100) + " : " + " | ".join(_row_cells(row_example)) + " "
                           for row_example in rows)
    value_token_len = len(tokenizer.tokenize(value_string))
    # used to estimate the busy ratio
    whole_token_len = used_token_len + value_token_len

    # number of chunks needed; 0 means there is nothing to split at all
    chunk_size = math.ceil(value_token_len / remain_token_len)
    if chunk_size <= 1:
        return [_normalized_table], 1.0

    if split_mode == "greedy":
        pass
    elif split_mode == "average":
        remain_token_len = min(remain_token_len, 100 + math.ceil(value_token_len / chunk_size))
    else:
        raise Exception("Do not support split_mode {}".format(split_mode))

    current_chunk_remain_size = remain_token_len
    current_chunk_row = 0
    split_table_contents = []
    for ind, row_example in enumerate(rows):
        value_string = "row " + str(ind) + " : " + " | ".join(_row_cells(row_example))
        row_token_len = len(tokenizer.tokenize(value_string))
        # over the size limit: close the current chunk, unless it is still empty
        # (a single over-long row must not produce an empty chunk)
        if row_token_len > current_chunk_remain_size and ind > current_chunk_row:
            split_table_contents.append({
                "header": header,
                "rows": rows[current_chunk_row: ind]
            })
            current_chunk_row = ind
            current_chunk_remain_size = remain_token_len
        current_chunk_remain_size -= row_token_len

    # always flush the tail; the previous `!= number_of_rows - 1` test silently dropped
    # the last row whenever that row started a chunk of its own
    if current_chunk_row < len(rows):
        split_table_contents.append({
            "header": header,
            "rows": rows[current_chunk_row:]
        })

    return split_table_contents, float(max_sen_length / whole_token_len)
|
||||
|
||||
|
||||
def truncate_training_database(_table_id, _table_content, _question, _drop_ratio: float, _answers=None, _sql=None):
    """
    Randomly drop rows that share no cell with the question or the answers.

    A row is related when one of its lower-cased cells equals an answer, a whitespace
    separated piece of `_sql`, or a space separated word of the question. Related rows and
    their two neighbours on each side are never dropped.

    :param _table_id: table identifier, only used in the log message
    :param _table_content: `Dict` with `header` and `rows`; `rows` is modified in place
    :param _question: question text
    :param _drop_ratio: fraction of all rows that should be dropped at most
    :param _answers: optional list of answers used to protect rows
    :param _sql: optional SQL string whose pieces are treated like answers
    :return: the same `_table_content` object
    """
    rows = _table_content["rows"]

    # answers may be numbers, so stringify before lower-casing
    answer_set = set() if _answers is None else {str(ans_ex).lower() for ans_ex in _answers}
    # add _sql into answer_set
    if _sql is not None:
        answer_set.update(_sql.split())
    question_set = set(_question.strip("?!.,").split(" "))

    truncated_unrelated_indices = []
    # flat set of row indices that must be kept (related rows plus neighbours)
    related_indices = set()
    for _row_idx, row in enumerate(rows):
        lower_row = {str(cell).lower() for cell in row}
        if lower_row.isdisjoint(answer_set) and lower_row.isdisjoint(question_set):
            truncated_unrelated_indices.append(_row_idx)
        else:
            # add neighbours to preserve information aggressively
            related_indices.update(range(_row_idx - 2, _row_idx + 3))

    # remove the neighbours from the candidates
    truncated_unrelated_indices = [_row_idx for _row_idx in truncated_unrelated_indices
                                   if _row_idx not in related_indices]

    # select some cases to drop; a negative ratio simply drops nothing
    drop_items = max(0, min(len(truncated_unrelated_indices), int(len(rows) * _drop_ratio)))
    # sample without replacement so that exactly `drop_items` distinct rows are removed
    drop_row_indices = set(random.sample(truncated_unrelated_indices, k=drop_items))

    # slice assignment keeps the list object that callers may already hold
    rows[:] = [row for _row_idx, row in enumerate(rows) if _row_idx not in drop_row_indices]

    # only when the drop ratio is too large, logging for warning.
    if _drop_ratio >= 0.1:
        logger.warning("Drop {:.2f} rows in table {}".format(len(drop_row_indices), _table_id))

    return _table_content
|
||||
|
||||
|
||||
def truncate_database_values(_table_content, max_cell_length, max_cell_truncate):
    """
    Shorten over-long cell values in place so that the flattened table stays short.

    :param _table_content: `Dict` contains keys as `header` and `rows`
    :param max_cell_length: `int`, a string cell with at least this many tokens is shortened
    :param max_cell_truncate: `int`, number of leading tokens kept for a shortened cell
    :return: the (modified) table_content and a mapping from the original cell value to the
        value stored in the table
    """

    def _truncate_cell(cell_value):
        # numbers are passed through untouched
        if isinstance(cell_value, (int, float)):
            return cell_value
        # blank strings are passed through untouched as well
        if cell_value.strip() == "":
            return cell_value
        try_tokens = tokenizer.tokenize(cell_value)
        # short enough: signal "nothing to do" with None
        if len(try_tokens) < max_cell_length:
            return None
        return tokenizer.convert_tokens_to_string(try_tokens[:max_cell_truncate])

    cell_mapping = {}
    for row in _table_content["rows"]:
        for position in range(len(row)):
            original_value = row[position]
            new_value = _truncate_cell(original_value)
            if new_value is None:
                continue
            cell_mapping[original_value] = new_value
            row[position] = new_value
    return _table_content, cell_mapping
|
|
@ -0,0 +1,117 @@
|
|||
import os
|
||||
import requests
|
||||
import tarfile
|
||||
import zipfile
|
||||
import shutil
|
||||
"""
|
||||
Download datasets used in our paper and move them into `raw_dataset` folder.
|
||||
Now the script supports downloading the following datasets for different usage:
|
||||
|
||||
=== Fine-tuning ===
|
||||
1. WikiSQL
|
||||
2. WikiTableQuestions
|
||||
3. SQA
|
||||
4. TabFact
|
||||
|
||||
=== Pre-training ===
|
||||
1. Squall
|
||||
"""
|
||||
RAW_DATASET_FOLDER = "raw_dataset"
|
||||
|
||||
|
||||
def download_file(url):
    """
    Download file into local file system from url

    The file is written to the current working directory and named after the last path
    segment of `url`.

    :param url: address of the file to download
    :return: the local file name
    :raises requests.HTTPError: when the server answers with an error status
    """
    local_filename = url.split('/')[-1]
    with requests.get(url, stream=True) as r:
        # fail loudly instead of saving an HTTP error page as if it were the archive
        r.raise_for_status()
        # let urllib3 undo gzip / deflate content encoding before the bytes hit the disk
        r.raw.decode_content = True
        with open(local_filename, 'wb') as f:
            shutil.copyfileobj(r.raw, f)

    return local_filename
|
||||
|
||||
|
||||
def download_wikisql():
    """
    Download WikiSQL dataset and unpack it into `raw_dataset/wikisql`
    """
    WIKISQL_URL = "https://raw.github.com/salesforce/WikiSQL/master/data.tar.bz2"
    wikisql_raw_path = os.path.join(RAW_DATASET_FOLDER, "wikisql")
    wikisql_tar_file = download_file(WIKISQL_URL)
    # the context manager closes the archive even when the extraction fails
    # NOTE(review): extractall trusts the member paths stored in the archive
    with tarfile.open(wikisql_tar_file, "r:bz2") as tar:
        tar.extractall(wikisql_raw_path)
    # remove the original file
    os.remove(wikisql_tar_file)
|
||||
|
||||
|
||||
def download_wikitablequestion():
    """
    Download WikiTableQuestions dataset and unpack it into `raw_dataset/wtq`
    """
    WTQ_URL = "https://github.com/ppasupat/WikiTableQuestions/releases/download/v1.0.2/WikiTableQuestions-1.0.2-compact.zip"
    target_folder = os.path.join(RAW_DATASET_FOLDER, "wtq")
    archive_file = download_file(WTQ_URL)
    # the archive unpacks into a `WikiTableQuestions` folder inside raw_dataset
    with zipfile.ZipFile(archive_file) as archive:
        archive.extractall(RAW_DATASET_FOLDER)
    unpacked_folder = os.path.join(RAW_DATASET_FOLDER, "WikiTableQuestions")
    # rename the unpacked folder to its final name
    shutil.move(unpacked_folder, target_folder)
    # the downloaded archive is no longer needed
    os.remove(archive_file)
|
||||
|
||||
|
||||
def download_sqa():
    """
    Download SQA dataset and unpack it into `raw_dataset/sqa`
    """
    SQA_URL = "https://download.microsoft.com/download/1/D/C/1DC270D2-1B53-4A61-A2E3-88AB3E4E6E1F/SQA%20Release%201.0.zip"
    target_folder = os.path.join(RAW_DATASET_FOLDER, "sqa")
    archive_file = download_file(SQA_URL)
    # the archive unpacks into a `SQA Release 1.0` folder inside raw_dataset
    with zipfile.ZipFile(archive_file) as archive:
        archive.extractall(RAW_DATASET_FOLDER)
    unpacked_folder = os.path.join(RAW_DATASET_FOLDER, "SQA Release 1.0")
    # rename the unpacked folder to its final name
    shutil.move(unpacked_folder, target_folder)
    # the downloaded archive is no longer needed
    os.remove(archive_file)
|
||||
|
||||
|
||||
def download_tabfact():
    """
    Download TabFact dataset and unpack it into `raw_dataset/tabfact`

    The previous body was a verbatim copy of `download_sqa` and therefore fetched SQA again.
    """
    # NOTE(review): GitHub branch archive of the TabFact repository; the branch name
    # (`master`) and the resulting folder name should be confirmed against the repository
    TABFACT_URL = "https://github.com/wenhuchen/Table-Fact-Checking/archive/refs/heads/master.zip"
    tabfact_raw_path = os.path.join(RAW_DATASET_FOLDER, "tabfact")
    tabfact_zip_file = download_file(TABFACT_URL)
    # unzip and move it into raw_dataset folder
    with zipfile.ZipFile(tabfact_zip_file) as zf:
        zf.extractall(RAW_DATASET_FOLDER)
    unzip_tabfact_path = os.path.join(RAW_DATASET_FOLDER, "Table-Fact-Checking-master")
    shutil.move(unzip_tabfact_path, tabfact_raw_path)
    # remove the original file
    os.remove(tabfact_zip_file)
|
||||
|
||||
|
||||
def download_squall():
    """
    Download Squall dataset and unpack it into `raw_dataset/squall`
    """
    SQAULL_URL = "https://github.com/tzshi/squall/archive/refs/heads/main.zip"
    target_folder = os.path.join(RAW_DATASET_FOLDER, "squall")
    archive_file = download_file(SQAULL_URL)
    # the archive unpacks into a `squall-main` folder inside raw_dataset
    with zipfile.ZipFile(archive_file) as archive:
        archive.extractall(RAW_DATASET_FOLDER)
    unpacked_folder = os.path.join(RAW_DATASET_FOLDER, "squall-main")
    # rename the unpacked folder to its final name
    shutil.move(unpacked_folder, target_folder)
    # the downloaded archive is no longer needed
    os.remove(archive_file)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Only the SQA download is enabled; the other downloads are kept for reference.
    # download_wikisql()
    # download_wikitablequestion()
    download_sqa()
    # download_squall()
|
|
@ -0,0 +1,85 @@
|
|||
import os
|
||||
from argparse import ArgumentParser
|
||||
|
||||
import numpy
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
|
||||
from common.table_transform import *
|
||||
|
||||
random.seed(42)
|
||||
numpy.random.seed(42)
|
||||
|
||||
TABLE_PATH = os.path.join("raw_dataset", "sqa")
|
||||
|
||||
|
||||
def read_table_from_file(_wtq_table_name: str):
    """
    Load a csv table below `TABLE_PATH` with pandas and return it as header plus rows.

    :param _wtq_table_name: table file name relative to `TABLE_PATH`; must contain `.csv`
    :return: `Dict` with `header` (the column names) and `rows` (every cell as `str`)
    """
    assert ".csv" in _wtq_table_name
    table_data = pd.read_csv(os.path.join(TABLE_PATH, _wtq_table_name))
    # the first line of the file is the header and ends up in the column names
    column_names = list(table_data.columns)
    string_rows = [[str(cell) for cell in list(row_data)] for row_data in table_data.values]
    return {
        "header": column_names,
        "rows": string_rows
    }
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # local import: only this script entry point needs it
    import ast

    parser = ArgumentParser()
    parser.add_argument('--mode', help='name of the split, used for the output file names', type=str,
                        default='test')
    parser.add_argument('--source-file', help='source file for the prediction', type=str,
                        default='sqa/test.tsv')
    parser.add_argument('--data-dir', help='data directory to store the dataset', type=str,
                        default='dataset/sqa_chunk')

    parser = setup_parser(parser)
    args = parser.parse_args()

    folder = args.data_dir
    os.makedirs(folder, exist_ok=True)

    mode = args.mode

    # read the source once and close the handle right away
    with open(args.source_file, "r", encoding="utf8") as fs:
        examples = fs.readlines()

    # questions of the current conversation seen so far
    history = ""

    with open("{}/{}.src".format(folder, mode), "w", encoding="utf8") as input_f, \
            open("{}/{}.tgt".format(folder, mode), "w", encoding="utf8") as output_f:
        # line 1 of the file is the column header, so the data starts at line 2
        for line_no, example in enumerate(tqdm(examples[1:]), start=2):
            question = None
            try:
                anno_id, _, position, question, table_file, _, answer_text = example.strip("\n").split("\t")
                answer_text = answer_text.replace("\"\"", "\"").strip("\"'")
                if position == "0":
                    # a new conversation starts: reset history
                    history = ""
                question = question.lower()
                if history:
                    question = history + " " + question
                # the answer column is a Python list literal; parse it without executing code
                answer = ast.literal_eval(answer_text)
                table_content = read_table_from_file(table_file)
                input_sources, output_targets = build_example(args, question, answer,
                                                              table_content, table_file,
                                                              is_train=mode == "train",
                                                              max_sen_length=1024)
                for input_s, output_t in zip(input_sources, output_targets):
                    input_f.write(input_s + "\n")
                    output_f.write(output_t + "\n")
                # the next question of the conversation sees everything asked so far
                history = question
            except Exception as e:
                # best effort: report the broken line and continue with the next one
                print("Error case on Line: {}, {}".format(line_no, question))
                print(e)
|
|
@ -0,0 +1,98 @@
|
|||
# coding=utf8
|
||||
|
||||
import os
|
||||
import json
|
||||
from typing import List, Dict
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="facebook/bart-large")
|
||||
|
||||
|
||||
def read_jsonl(path: str) -> List[Dict]:
    """
    Read a JSON lines file.

    :param path: path of a UTF-8 encoded file with one JSON document per line
    :return: the decoded documents in file order; blank lines are skipped
    """
    data = []
    with open(path, 'r', encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            # a blank line (typically a trailing one) is not a JSON document
            if not line:
                continue
            data.append(json.loads(line))
    return data
|
||||
|
||||
|
||||
def flatten_table(table: List[List[str]]) -> str:
    """
    Flatten a table whose first row is the header into one lower-cased string.

    The header is rendered as `col : ...`, every following row as `row <index> : ...`,
    cells are joined by ` | ` and every segment ends with a space.
    """

    def _render(rid, row):
        cells = " | ".join(cell.lower() for cell in row) + " "
        prefix = "col : " if rid == 0 else "row %d : " % rid
        return prefix + cells

    return "".join(_render(rid, row) for rid, row in enumerate(table))
|
||||
|
||||
|
||||
def main(data_path: str, splits: List[str], task: str = "classification") -> None:
    """
    Convert the `<split>.jsonl` files in `data_path` into flattened input / label files.

    For `task == 'classification'` the outputs are `<split>.raw.input0` and `<split>.label`
    (integer labels); for any other task they are `<split>.src` and `<split>.tgt` with the
    label rendered as `yes` / `no`. Inputs longer than 1020 tokens are truncated.

    :param data_path: folder that holds the jsonl files and receives the outputs
    :param splits: split names, e.g. `train`, `valid`, `test`
    :param task: output flavour, see above
    """
    for split in splits:
        path = os.path.join(data_path, "%s.jsonl" % split)

        examples = read_jsonl(path)

        if task == 'classification':
            save_input0_path = os.path.join(data_path, "%s.raw.input0" % split)
            save_label_path = os.path.join(data_path, "%s.label" % split)
        else:
            save_input0_path = os.path.join(data_path, "%s.src" % split)
            save_label_path = os.path.join(data_path, "%s.tgt" % split)

        with open(save_input0_path, "w", encoding="utf8") as input0_f, \
                open(save_label_path, "w", encoding="utf8") as label_f:
            for ex in tqdm(examples):
                sentence = ex['statement']
                label = ex['label']
                table_text = flatten_table(ex['table_text'])
                input_line = "%s %s" % (sentence.lower(), table_text)
                try_input_tokens = tokenizer.tokenize(input_line)
                # remember the length before truncating so the warning reports the real size
                original_token_num = len(try_input_tokens)
                if original_token_num > 1020:
                    try_input_tokens = try_input_tokens[: 1020]
                    input_line = tokenizer.convert_tokens_to_string(try_input_tokens)
                    print("Warning: truncate the source from {} token to 1020 tokens.".format(original_token_num))
                input0_f.write(input_line + "\n")
                if task == 'classification':
                    label_f.write("%d\n" % label)
                else:
                    answer = "yes" if label == 1 else "no"
                    label_f.write("%s\n" % answer)
|
||||
|
||||
|
||||
def split_test_data(test_data_path, split_json_files, out_dir):
    """
    Write one `test_<name>.raw.input0` / `test_<name>.label` pair per test subset.

    :param test_data_path: jsonl file with all test examples
    :param split_json_files: json files that each hold the table ids of one subset; the file
        name without the `.json` extension becomes `<name>`
    :param out_dir: folder that receives the output files
    """
    examples = read_jsonl(test_data_path)
    for split_file in split_json_files:
        split_mode = os.path.splitext(os.path.basename(split_file))[0]
        # close the id file right after loading; a set makes the membership test O(1)
        with open(split_file, "r", encoding="utf8") as split_f:
            valid_ids = set(json.load(split_f))
        valid_examples = [example for example in examples if example["table_id"] in valid_ids]
        save_input0_path = os.path.join(out_dir, "%s.raw.input0" % ("test_" + split_mode))
        save_label_path = os.path.join(out_dir, "%s.label" % ("test_" + split_mode))

        with open(save_input0_path, "w", encoding="utf8") as input0_f, \
                open(save_label_path, "w", encoding="utf8") as label_f:
            for ex in tqdm(valid_examples):
                sentence = ex['statement']
                label = ex['label']
                table_text = flatten_table(ex['table_text'])
                input_line = "%s %s" % (sentence.lower(), table_text)
                try_input_tokens = tokenizer.tokenize(input_line)
                # keep at most 1020 tokens of the flattened input
                if len(try_input_tokens) > 1020:
                    try_input_tokens = try_input_tokens[: 1020]
                    input_line = tokenizer.convert_tokens_to_string(try_input_tokens)
                    print("Warning: truncate the source")
                input0_f.write(input_line + "\n")
                label_f.write("%d\n" % label)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # folder that holds the jsonl inputs and receives the converted splits
    data_path = "dataset/tabfact_classification"
    if not os.path.exists(data_path):
        os.makedirs(data_path)

    # convert the three standard splits for the classification task
    splits = ["train", "valid", "test"]
    main(data_path, splits, task="classification")

    # additionally write the complex / small / simple subsets of the test data
    subset_id_files = [
        "tabfact/complex.json",
        "tabfact/small.json",
        "tabfact/simple.json",
    ]
    split_test_data("tabfact/test.jsonl", subset_id_files, "dataset/tabfact_classification")
|
|
@ -0,0 +1,77 @@
|
|||
#!/usr/bin/env python
|
||||
import json
|
||||
import os
|
||||
from argparse import ArgumentParser
|
||||
from copy import deepcopy
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from common.table_transform import *
|
||||
from wikisql_utils.executor import retrieve_wikisql_query_answer_tapas, _TYPE_CONVERTER
|
||||
from wikisql_utils.wikisql_common import count_lines
|
||||
|
||||
TGT_DEL = ", "
|
||||
|
||||
|
||||
def _parse_table(table):
    """Return a deep copy of the table whose cells went through the per-column type converter.

    The unconverted rows of the copy are kept under `real_rows`.
    """
    converted = deepcopy(table)
    column_types = converted['types']
    raw_rows = converted['rows']
    converted['real_rows'] = raw_rows
    converted['rows'] = [
        [_TYPE_CONVERTER[column_types[column]](cell_value) for column, cell_value in enumerate(raw_row)]
        for raw_row in raw_rows
    ]
    return converted
|
||||
|
||||
|
||||
if __name__ == '__main__':

    parser = ArgumentParser()
    parser.add_argument('--mode', help='name of the split, used for the output file names', type=str,
                        default='train')
    parser.add_argument('--source-file', help='source file for the prediction', type=str, default='wikisql/data/train.jsonl')
    parser.add_argument('--db-file', help='source database for the prediction', type=str, default='wikisql/data/train.db')
    parser.add_argument('--table-file', help='source table content for the prediction', type=str,
                        default='wikisql/data/train.tables.jsonl')
    parser.add_argument('--data-dir', help='data directory to store the dataset', type=str,
                        default='dataset/wikisql')
    parser = setup_parser(parser)
    args = parser.parse_args()

    mode = args.mode
    # record table id to table content
    table_content_map = {}

    with open(args.table_file, "r", encoding="utf8") as table_f:
        for json_line in table_f:
            content = json.loads(json_line)
            table_content_map[content["id"]] = content

    folder = args.data_dir
    os.makedirs(folder, exist_ok=True)

    with open("{}/{}.src".format(folder, mode), "w", encoding="utf8") as input_f, \
            open("{}/{}.tgt".format(folder, mode), "w", encoding="utf8") as output_f, \
            open(args.source_file, "r", encoding="utf8") as fs:
        for ls in tqdm(fs, total=count_lines(args.source_file)):
            example = json.loads(ls)
            table_id = example["table_id"]
            table_content = table_content_map[table_id]
            question = example["question"].lower()
            tapas_table = _parse_table(table_content)
            answer = retrieve_wikisql_query_answer_tapas(tapas_table, example)
            # build_example truncates cells and drops rows in place; hand it a private copy
            # so the cached table stays intact for later questions on the same table
            input_sources, output_targets = build_example(args, question, answer,
                                                          deepcopy(table_content), table_id,
                                                          is_train=mode == "train",
                                                          max_sen_length=1024)
            for input_s, output_t in zip(input_sources, output_targets):
                input_f.write(input_s + "\n")
                output_f.write(output_t + "\n")
|
|
@ -0,0 +1,82 @@
|
|||
import logging
|
||||
import os
|
||||
from argparse import ArgumentParser
|
||||
|
||||
import numpy
|
||||
from tqdm import tqdm
|
||||
from typing import Tuple
|
||||
from table_transform import *
|
||||
|
||||
TABLE_PATH = "wtq_origin"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def read_table_from_file(_wtq_table_name: str):
    """
    Load the tab-separated version of a table below `TABLE_PATH`.

    :param _wtq_table_name: table file name containing `.csv`; the `.tsv` sibling is read
    :return: `Dict` with `header` (first line) and `rows` (remaining lines), every field
        stripped of surrounding whitespace
    """

    def _extract_content(_line: str):
        fields = _line.strip("\n").split("\t")
        # _vals = ["empty" if _ == "" else _ for _ in _vals]
        return [field.replace("\n", " ").strip() for field in fields]

    assert ".csv" in _wtq_table_name
    # use the normalized table file
    tsv_name = _wtq_table_name.replace(".csv", ".tsv")
    with open(os.path.join(TABLE_PATH, tsv_name), "r", encoding="utf8") as table_f:
        table_lines = table_f.readlines()

    # the first line is header, everything after it is a data row
    header = _extract_content(table_lines[0])
    rows = [_extract_content(data_line) for data_line in table_lines[1:]]

    return {
        "header": header,
        "rows": rows
    }
|
||||
|
||||
|
||||
if __name__ == '__main__':

    random.seed(42)
    numpy.random.seed(42)

    parser = ArgumentParser()
    parser.add_argument('--mode', help='name of the split, used for the output file names', type=str,
                        default='train')
    parser.add_argument('--source-file', help='source file for the prediction', type=str,
                        default='wtq_origin/train.tsv')
    parser.add_argument('--data-dir', help='data directory to store the dataset', type=str,
                        default='dataset/wtq_chunk')
    # without a default `None` would reach the token arithmetic; 1024 matches the value
    # the SQA and WikiSQL scripts pass explicitly
    parser.add_argument('--max-sen-len', help='maximum number of tokens of one flattened input', type=int,
                        default=1024)

    # setup up table transformation operations
    parser = setup_parser(parser)
    args = parser.parse_args()

    mode = args.mode
    folder = args.data_dir
    os.makedirs(folder, exist_ok=True)

    # read the source once and close the handle right away
    with open(args.source_file, "r", encoding="utf8") as fs:
        examples = fs.readlines()

    with open("{}/{}.src".format(folder, mode), "w", encoding="utf8") as input_f, \
            open("{}/{}.tgt".format(folder, mode), "w", encoding="utf8") as output_f:
        # the first line of the TSV is the column header
        for example in tqdm(examples[1:]):
            _, question, table_name, answer = example.strip("\n").split("\t")
            question = question.lower()
            answer = answer.split("|")
            # must contain rows and header keys
            table_content = read_table_from_file(table_name)
            # NOTE(review): the visible table_transform module only defines `build_example`
            # (same argument order); the former `build_fairseq_example` name is not defined there
            input_sources, output_targets = build_example(args, question, answer, table_content,
                                                          table_name, mode == "train", args.max_sen_len)
            for input_s, output_t in zip(input_sources, output_targets):
                input_f.write(input_s + "\n")
                output_f.write(output_t + "\n")
|
|
@ -0,0 +1,25 @@
|
|||
import os
|
||||
import json
|
||||
|
||||
|
||||
def _convert_split(src_path, tgt_path, out_path):
    """Pair the lines of `src_path` and `tgt_path` and write them to `out_path` as JSON lines."""
    with open(src_path, "r", encoding="utf8") as src_f, \
            open(tgt_path, "r", encoding="utf8") as tgt_f, \
            open(out_path, "w", encoding="utf8") as out_f:
        for src_line, tgt_line in zip(src_f, tgt_f):
            out_f.write(json.dumps({"input": src_line.strip(), "output": tgt_line.strip()}) + "\n")


def load_file_and_convert(fariseq_data_folder, huggingface_data_folder):
    """
    Convert the `train` / `valid` `.src` + `.tgt` file pairs into `train.json` / `valid.json`.

    Every output line is a JSON object with the stripped source line under `input` and the
    stripped target line under `output`.

    :param fariseq_data_folder: folder with `train.src`, `train.tgt`, `valid.src`, `valid.tgt`
    :param huggingface_data_folder: existing folder that receives `train.json` and `valid.json`
    """
    for split in ("train", "valid"):
        # every handle is closed by the helper, so the output is flushed when this returns
        _convert_split(os.path.join(fariseq_data_folder, split + ".src"),
                       os.path.join(fariseq_data_folder, split + ".tgt"),
                       os.path.join(huggingface_data_folder, split + ".json"))
|
|
@ -0,0 +1,250 @@
|
|||
# The following script is adapted from the script of TaPas.
|
||||
# Original: https://github.com/google-research/tapas/master/wikisql_utils.py
|
||||
|
||||
import dataclasses
|
||||
import enum
|
||||
import functools
|
||||
import math
|
||||
import re
|
||||
from typing import Text, Any
|
||||
|
||||
import six
|
||||
|
||||
EMPTY_ANSWER = "none"
|
||||
EMPTY_ANSWER_AGG = "none"
|
||||
|
||||
|
||||
def _split_thousands(delimiter, value):
    """Returns True if `value` splits on `delimiter` into several parts and one part has three characters."""
    parts = value.split(delimiter)
    if len(parts) <= 1:
        return False
    return any(len(part) == 3 for part in parts)
|
||||
|
||||
|
||||
def convert_to_float(value):
    """Converts value to a float using a series of increasingly complex heuristics.

    Args:
      value: object that needs to be converted. Allowed types include
        float/int/strings.

    Returns:
      A float interpretation of value.

    Raises:
      ValueError if the float conversion of value fails.
    """
    if isinstance(value, float):
        return value
    if isinstance(value, int):
        return float(value)
    # This module is Python 3 only (f-strings, dataclasses), so a plain `str`
    # check covers every string type.
    if not isinstance(value, str):
        raise ValueError("Argument value is not a string. Can't parse it as float")
    sanitized = value

    try:
        has_comma = "," in sanitized
        # Example: 1,000.7
        if "." in sanitized and has_comma:
            return float(sanitized.replace(",", ""))
        if has_comma:
            # 1,000
            if _split_thousands(",", sanitized):
                return float(sanitized.replace(",", ""))
            # 5,5556 -- a single comma that is not a thousands separator is
            # read as a decimal comma.
            if sanitized.count(",") == 1:
                return float(sanitized.replace(",", "."))
        # 0.0.0.1
        if sanitized.count(".") > 1:
            return float(sanitized.replace(".", ""))
        # 0,0,0,1
        if sanitized.count(",") > 1:
            return float(sanitized.replace(",", ""))
        return float(sanitized)
    except ValueError:
        # Avoid adding the sanitized value in the error message. `from None`
        # also keeps the original exception (which quotes the value) out of
        # the chained traceback.
        raise ValueError("Unable to convert value to float") from None
|
||||
|
||||
|
||||
def _normalize_float(answer):
|
||||
if answer is None:
|
||||
return None
|
||||
try:
|
||||
value = convert_to_float(answer)
|
||||
if isinstance(value, float) and math.isnan(value):
|
||||
return None
|
||||
return value
|
||||
except ValueError:
|
||||
return answer.lower()
|
||||
|
||||
|
||||
# Maps a column type name (as found in table['types']) to the callable that
# _parse_value applies to a condition value of that column: 'text' values
# pass through unchanged, 'real' values go through convert_to_float.
_TYPE_CONVERTER = {
    'text': lambda x: x,
    'real': convert_to_float,
}
|
||||
|
||||
|
||||
class _Aggregation(enum.Enum):
    """Aggregations as defined by WikiSQL. Indexes match the data."""
    # _get_answer_coordinates only builds members from indexes >= 3; MAX (1)
    # and MIN (2) are resolved there to a single cell and reported as NONE.
    NONE = 0
    MAX = 1
    MIN = 2
    COUNT = 3
    SUM = 4
    AVERAGE = 5
|
||||
|
||||
|
||||
class _Operator(enum.Enum):
    """The boolean operators used by WikiSQL. Indexes match the data."""
    # _compare maps these to ==, > and < respectively.
    EQUALS = 0
    GREATER = 1
    LESSER = 2
|
||||
|
||||
|
||||
@dataclasses.dataclass
class _Condition:
    """Represents one SQL WHERE clause (e.g. A = "a" or B > 5)."""
    # Column the clause applies to.
    # NOTE(review): _respect_conditions uses this to index both the row and
    # table['types'], although it is annotated as Text -- confirm the type.
    column: Text
    # Comparison applied between the table cell and cmp_value.
    operator: _Operator
    # Raw right-hand side; _parse_value converts it per the column type.
    cmp_value: Any
|
||||
|
||||
|
||||
_TOKENIZER = re.compile(r'\w+|[^\w\s]+', re.UNICODE | re.MULTILINE | re.DOTALL)
|
||||
|
||||
|
||||
def _normalize_for_match(x):
|
||||
return [t for t in _TOKENIZER.findall(x.lower())]
|
||||
|
||||
|
||||
def _compare(operator, src, tgt):
    """Apply the comparison named by `operator` to `src` and `tgt`.

    Raises:
      ValueError: if `operator` is not one of the _Operator members.
    """
    if operator == _Operator.EQUALS:
        return src == tgt
    if operator == _Operator.GREATER:
        return src > tgt
    if operator == _Operator.LESSER:
        return src < tgt
    raise ValueError(f'Unknown operator: {operator}')
|
||||
|
||||
|
||||
def _parse_value(table, column, cell_value):
    """Convert numeric values to floats and keep everything else as is.

    The converter is chosen from _TYPE_CONVERTER by the type that
    table['types'] declares for `column`.
    """
    column_type = table['types'][column]
    converter = _TYPE_CONVERTER[column_type]
    return converter(cell_value)
|
||||
|
||||
|
||||
def _is_string(x):
|
||||
return isinstance(x, str)
|
||||
|
||||
|
||||
def _respect_conditions(table, row,
|
||||
conditions):
|
||||
"""True if 'row' satisfies all 'conditions'."""
|
||||
for cond in conditions:
|
||||
table_value = row[cond.column]
|
||||
|
||||
cmp_value = _parse_value(table, cond.column, cond.cmp_value)
|
||||
|
||||
if _is_string(table_value) and _is_string(cmp_value):
|
||||
table_value = _normalize_for_match(table_value)
|
||||
cmp_value = _normalize_for_match(cmp_value)
|
||||
|
||||
if not isinstance(table_value, type(cmp_value)):
|
||||
raise ValueError('Type difference {} != {}'.format(
|
||||
type(table_value), type(cmp_value)))
|
||||
|
||||
if not _compare(cond.operator, table_value, cmp_value):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _get_float_answer(table, answer_coordinates, aggregation_op):
    """Applies operation to produce reference float answer.

    Returns a float when a numeric answer can be computed, EMPTY_ANSWER_AGG
    when nothing is selected for a non-COUNT aggregation, and None when the
    answer has to be read from the cell text instead.

    Raises:
      ValueError: when a single selected cell cannot be converted under a
        real aggregation, or when the aggregation is not supported.
    """
    counting = aggregation_op == _Aggregation.COUNT
    if not answer_coordinates:
        return 0.0 if counting else EMPTY_ANSWER_AGG

    # Count can support non numeric answers.
    if counting:
        return float(len(answer_coordinates))

    cells = [table['rows'][r][c] for r, c in answer_coordinates]
    plain_selection = aggregation_op == _Aggregation.NONE

    # A single cell is returned as a float when it converts; a failed
    # conversion is only tolerated when no aggregation was requested.
    if len(answer_coordinates) == 1:
        try:
            return convert_to_float(cells[0])
        except ValueError:
            if not plain_selection:
                raise

    if plain_selection:
        return None

    # Other aggregation only support numeric values. Bail out if we have strings.
    for cell in cells:
        if not isinstance(cell, (int, float)):
            return None

    if aggregation_op == _Aggregation.SUM:
        return float(sum(cells))
    if aggregation_op == _Aggregation.AVERAGE:
        return sum(cells) / len(answer_coordinates)
    raise ValueError(f'Unknown aggregation: {aggregation_op}')
|
||||
|
||||
|
||||
def _get_answer_coordinates(table, example):
    """Retrieves references coordinates by executing SQL.

    Returns a pair `(coordinates, aggregation)`: the `(row, column)` cells
    selected by example['sql'] and the aggregation still to be applied to
    them.
    """
    sql = example['sql']
    agg_index = sql['agg']

    # MAX and MIN are automatically supported by the model, so only the
    # indexes from COUNT upwards are kept as a real aggregation.
    if agg_index >= 3:
        aggregation_op = _Aggregation(agg_index)
    else:
        aggregation_op = _Aggregation.NONE

    target_column = sql['sel']
    conditions = [
        _Condition(column, _Operator(operator), cmp_value)
        for column, operator, cmp_value in sql['conds']
    ]

    indices = [
        (row_index, target_column)
        for row_index, row in enumerate(table['rows'])
        if _respect_conditions(table, row, conditions)
    ]

    # Parsing of MIN/MAX: with several candidate cells, keep only the
    # extreme one. Pairing each value with its position breaks ties in
    # favour of the position-wise smaller (MIN) or larger (MAX) entry.
    if len(indices) > 1 and agg_index in (1, 2):
        pick = max if agg_index == 1 else min
        candidates = [
            (table['rows'][i][j], position)
            for position, (i, j) in enumerate(indices)
        ]
        _, best_position = pick(candidates)
        return [indices[best_position]], _Aggregation.NONE

    return indices, aggregation_op
|
||||
|
||||
|
||||
def _get_answer_text(table,
|
||||
answer_coordinates,
|
||||
float_answer):
|
||||
if float_answer is not None:
|
||||
return [str(float_answer)]
|
||||
return [str(table['real_rows'][r][c]) for r, c in answer_coordinates]
|
||||
|
||||
|
||||
def retrieve_wikisql_query_answer_tapas(table, example):
    """Execute the SQL of `example` on `table` and return the answer strings.

    The result is a list of strings; when the query selects nothing it is
    `[EMPTY_ANSWER]`.
    """
    coordinates, aggregation = _get_answer_coordinates(table, example)
    numeric_answer = _get_float_answer(table, coordinates, aggregation)
    texts = _get_answer_text(table, coordinates, numeric_answer)
    # keep the original data the same with TaPas
    if not texts:
        return [EMPTY_ANSWER]
    return texts
|
|
@ -0,0 +1,240 @@
|
|||
import re
|
||||
from copy import deepcopy
|
||||
|
||||
from wikisql_common import detokenize
|
||||
|
||||
re_whitespace = re.compile(r'\s+', flags=re.UNICODE)
|
||||
|
||||
|
||||
class Query:
    """A WikiSQL-style query: one selected column, one aggregation, AND-ed conditions.

    `conditions` is a list of `[column_index, operator_index, value]`
    entries; `agg_index` indexes `agg_ops` and the operator index indexes
    `cond_ops`. Equality ignores the case of condition values and, unless
    the right-hand operand is `ordered`, the order of the conditions.
    """

    # Aggregation keywords, indexed by `agg_index` ('' means no aggregation).
    agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
    # Condition operators, indexed by the operator slot of a condition.
    cond_ops = ['=', '>', '<', 'OP']
    syms = ['SELECT', 'WHERE', 'AND', 'COL', 'TABLE', 'CAPTION', 'PAGE', 'SECTION', 'OP', 'COND', 'QUESTION', 'AGG',
            'AGGOPS', 'CONDOPS']

    def __init__(self, sel_index, agg_index, conditions=tuple(), ordered=False):
        self.sel_index = sel_index
        self.agg_index = agg_index
        # Copied so that later changes to the caller's sequence do not leak in.
        self.conditions = list(conditions)
        self.ordered = ordered

    def _normalized_conditions(self):
        """Conditions as hashable tuples with the value stringified and lower-cased."""
        return [(col, op, str(cond).lower()) for col, op, cond in self.conditions]

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        indices = self.sel_index == other.sel_index and self.agg_index == other.agg_index
        mine = self._normalized_conditions()
        theirs = other._normalized_conditions()
        # Only the right-hand operand decides whether condition order matters.
        if other.ordered:
            conds = mine == theirs
        else:
            conds = set(mine) == set(theirs)
        return indices and conds

    def __ne__(self, other):
        if isinstance(other, self.__class__):
            return not self.__eq__(other)
        return NotImplemented

    def __hash__(self):
        # Hash exactly what the order-insensitive __eq__ compares. Hashing
        # `__dict__` directly fails because `conditions` is a list, and would
        # disagree with the case-insensitive equality anyway.
        return hash((self.sel_index, self.agg_index, frozenset(self._normalized_conditions())))

    def __repr__(self):
        rep = 'SELECT {agg} {sel} FROM table'.format(
            agg=self.agg_ops[self.agg_index],
            sel='col{}'.format(self.sel_index),
        )
        if self.conditions:
            rep += ' WHERE ' + ' AND '.join(
                ['{} {} {}'.format('col{}'.format(i), self.cond_ops[o], v) for i, o, v in self.conditions])
        return rep

    def to_dict(self):
        """Return the query as a `{'sel', 'agg', 'conds'}` dict."""
        return {'sel': self.sel_index, 'agg': self.agg_index, 'conds': self.conditions}

    def lower(self):
        """Return a copy whose string condition values are lower-cased.

        Non-string values (e.g. numbers) are kept unchanged instead of
        failing on the missing `.lower()`.
        """
        conds = [
            [col, op, cond.lower() if isinstance(cond, str) else cond]
            for col, op, cond in self.conditions
        ]
        return self.__class__(self.sel_index, self.agg_index, conds)

    @classmethod
    def from_dict(cls, d, ordered=False):
        """Build a query from a `{'sel', 'agg', 'conds'}` dict."""
        return cls(sel_index=d['sel'], agg_index=d['agg'], conditions=d['conds'], ordered=ordered)

    @classmethod
    def from_tokenized_dict(cls, d):
        """Build a query from a dict whose condition values are tokenized."""
        conds = [[col, op, detokenize(val)] for col, op, val in d['conds']]
        return cls(d['sel'], d['agg'], conds)

    @classmethod
    def from_generated_dict(cls, d):
        """Build a query from a generated dict; values are detokenized like tokenized ones."""
        return cls.from_tokenized_dict(d)

    @staticmethod
    def _flatten(tokens):
        """Regroup per-token dicts into one dict of parallel lists."""
        return {
            'words': [t['word'] for t in tokens],
            'after': [t['after'] for t in tokens],
            'gloss': [t['gloss'] for t in tokens],
        }

    @staticmethod
    def _find_column(headers_no_whitespace, name):
        """Index of the header equal to `name` once whitespace is removed (ValueError if absent)."""
        return headers_no_whitespace.index(re_whitespace.sub('', name))

    @staticmethod
    def _prepare_terms(sequence, table, lowercase):
        """Cut the sequence at 'symend' and return `(terms, whitespace-free headers)`."""
        # Work on a copy: the truncation below rewrites the dict's values.
        sequence = deepcopy(sequence)
        if 'symend' in sequence['words']:
            end = sequence['words'].index('symend')
            for k, v in sequence.items():
                sequence[k] = v[:end]
        terms = [{'gloss': g, 'word': w, 'after': a} for g, w, a in
                 zip(sequence['gloss'], sequence['words'], sequence['after'])]
        headers = [detokenize(h) for h in table['header']]

        # lowercase everything
        if lowercase:
            headers = [h.lower() for h in headers]
            terms = [{k: v.lower() for k, v in t.items()} for t in terms]
        return terms, [re_whitespace.sub('', h) for h in headers]

    @classmethod
    def _parse_conditions(cls, where_terms, headers_no_whitespace):
        """Parse `symcol <column> symop <op> symcond <value> [symand ...]` into conditions."""
        conditions = []
        while where_terms:
            t = where_terms.pop(0)
            flat = cls._flatten(where_terms)
            if t['word'] != 'symcol':
                raise Exception('Missing conditional column {}'.format(flat['words']))
            try:
                op_index = flat['words'].index('symop')
                col_tokens = cls._flatten(where_terms[:op_index])
            except Exception as e:
                raise Exception('Missing conditional operator {}'.format(flat['words'])) from e
            cond_op = where_terms[op_index + 1]['word']
            try:
                cond_op = cls.cond_ops.index(cond_op.upper())
            except Exception as e:
                raise Exception('Invalid cond op {}'.format(cond_op)) from e
            try:
                cond_col = cls._find_column(headers_no_whitespace, detokenize(col_tokens))
            except Exception as e:
                raise Exception('Cannot find conditional column {}'.format(col_tokens['words'])) from e
            try:
                val_index = flat['words'].index('symcond')
            except Exception as e:
                raise Exception('Cannot find conditional value {}'.format(flat['words'])) from e

            # The value runs from just after 'symcond' up to the next
            # 'symand', or to the end of the sequence.
            where_terms = where_terms[val_index + 1:]
            flat = cls._flatten(where_terms)
            val_end_index = flat['words'].index('symand') if 'symand' in flat['words'] else len(where_terms)
            cond_val = detokenize(cls._flatten(where_terms[:val_end_index]))
            conditions.append([cond_col, cond_op, cond_val])
            where_terms = where_terms[val_end_index + 1:]
        return conditions

    @classmethod
    def from_sequence(cls, sequence, table, lowercase=True):
        """Parse a full `symselect symagg ... [symwhere ...]` token sequence.

        Raises:
          Exception: when the sequence is malformed or names an unknown
            aggregation, operator or column.
        """
        terms, headers_no_whitespace = cls._prepare_terms(sequence, table, lowercase)

        # get select
        if 'symselect' != terms.pop(0)['word']:
            raise Exception('Missing symselect operator')

        # get aggregation
        if 'symagg' != terms.pop(0)['word']:
            raise Exception('Missing symagg operator')
        agg_op = terms.pop(0)['word']

        if agg_op == 'symcol':
            # No aggregation keyword: the column marker came right away.
            agg_op = ''
        elif 'symcol' != terms.pop(0)['word']:
            raise Exception('Missing aggregation column')
        try:
            agg_op = cls.agg_ops.index(agg_op.upper())
        except Exception as e:
            raise Exception('Invalid agg op {}'.format(agg_op)) from e

        where_index = next((i for i, t in enumerate(terms) if t['word'] == 'symwhere'), len(terms))
        flat = cls._flatten(terms[:where_index])
        try:
            agg_col = cls._find_column(headers_no_whitespace, detokenize(flat))
        except Exception as e:
            raise Exception('Cannot find aggregation column {}'.format(flat['words'])) from e

        conditions = cls._parse_conditions(terms[where_index + 1:], headers_no_whitespace)
        return cls(agg_col, agg_op, conditions)

    @classmethod
    def from_partial_sequence(cls, agg_col, agg_op, sequence, table, lowercase=True):
        """Parse only the conditions of `sequence`; selection and aggregation are given.

        Raises:
          Exception: when the condition part of the sequence is malformed.
        """
        terms, headers_no_whitespace = cls._prepare_terms(sequence, table, lowercase)
        where_index = next((i for i, t in enumerate(terms) if t['word'] == 'symwhere'), len(terms))
        conditions = cls._parse_conditions(terms[where_index + 1:], headers_no_whitespace)
        return cls(agg_col, agg_op, conditions)
|
|
@ -0,0 +1,10 @@
|
|||
def count_lines(fname):
    """Return the number of lines in the UTF-8 text file `fname`."""
    total = 0
    with open(fname, "r", encoding="utf8") as handle:
        for _ in handle:
            total += 1
    return total
|
||||
|
||||
|
||||
def detokenize(tokens):
    """Rebuild text from a tokenized dict.

    Each entry of tokens['gloss'] is followed by the matching entry of
    tokens['after']; the joined string is returned with surrounding
    whitespace stripped.
    """
    pieces = (gloss + after for gloss, after in zip(tokens['gloss'], tokens['after']))
    return ''.join(pieces).strip()
|
Загрузка…
Ссылка в новой задаче