ContextualSP/unified_parser_text_to_sql/step3_evaluate.py

import argparse
import json
import re
import subprocess
from collections import defaultdict
from re import RegexFlag
import networkx as nx
import torch
from genre.fairseq_model import GENRE, mGENRE
from genre.entity_linking import get_end_to_end_prefix_allowed_tokens_fn_fairseq as get_prefix_allowed_tokens_fn
from genre.trie import Trie
from semparse.sql.spider import load_original_schemas, load_tables
from semparse.worlds.evaluate_spider import evaluate as evaluate_sql
from step1_schema_linking import read_database_schema

database_dir = './data/spider/database'
database_schema_filename = './data/spider/tables.json'
schema_tokens, column_names, database_schemas = read_database_schema(database_schema_filename)
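
# dev.json is read only to recover the database each dev example belongs to;
# sql_to_db[idx] is later used to pick the right schema for the idx-th prediction.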
with open('./data/spider/dev.json', 'r', encoding='utf-8') as f:
    item = json.load(f)

sql_to_db = []
for i in item:
    sql_to_db.append(i['db_id'])
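
# post_processing_sql repairs an incomplete predicted query by (1) collecting every
# table mentioned through a table@column token and (2) reconnecting those tables with
# JOIN ... ON clauses found via shortest paths in the foreign-key graph.
# Purely illustrative (hypothetical) example of the rewrite, assuming a concert/singer
# schema linked on singer_id; the exact table order depends on the schema indices:
#   in : select singer@name from concert where concert@year = 'value'
#   out: select singer@name from concert join singer on concert@singer_id = singer@singer_id
#        where concert@year = 'value'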
def post_processing_sql(p_sql, foreign_key_maps, schemas, o_schemas):
    foreign_key = {}
    for k, v in foreign_key_maps.items():
        if k == v:
            continue
        key = ' '.join(sorted([k.split('.')[0].strip('_'), v.split('.')[0].strip('_')]))
        foreign_key[key] = (k.strip('_').replace('.', '@'), v.strip('_').replace('.', '@'))

    primary_key = {}
    for t in o_schemas.tables:
        table = t.orig_name.lower()
        if len(t.primary_keys) == 0:
            continue
        column = t.primary_keys[0].orig_name.lower()
        primary_key[table] = f'{table}@{column}'

    p_sql = re.sub(r'(=)(\S+)', r'\1 \2', p_sql)
    p_sql = p_sql.split()

    columns = ['*']
    tables = []
    for table, column_list in schemas.schema.items():
        for column in column_list:
            columns.append(f"{table}@{column}")
        tables.append(table)

    # infer table from mentioned column
    all_from_table_ids = set()
    from_idx = where_idx = group_idx = order_idx = -1
    for idx, token in enumerate(p_sql):
        if '@' in token and token in columns:
            all_from_table_ids.add(schemas.idMap[token.split('@')[0]])
        if token == 'from' and from_idx == -1:
            from_idx = idx
        if token == 'where' and where_idx == -1:
            where_idx = idx
        if token == 'group' and group_idx == -1:
            group_idx = idx
        if token == 'order' and order_idx == -1:
            order_idx = idx

    # don't process nested SQL (more than one select)
    if len(re.findall('select', ' '.join(p_sql))) > 1 or len(all_from_table_ids) == 0:
        return ' '.join(p_sql)

    covered_tables = set()
    candidate_table_ids = sorted(all_from_table_ids)
    start_table_id = candidate_table_ids[0]
    conds = set()
    all_conds = []
    for table_id in candidate_table_ids[1:]:
        if table_id in covered_tables:
            continue
        try:
            path = nx.shortest_path(
                o_schemas.foreign_key_graph,
                source=start_table_id,
                target=table_id,
            )
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            covered_tables.add(table_id)
            continue
        for source_table_id, target_table_id in zip(path, path[1:]):
            if target_table_id in covered_tables:
                continue
            covered_tables.add(target_table_id)
            all_from_table_ids.add(target_table_id)
            col1, col2 = o_schemas.foreign_key_graph[source_table_id][target_table_id]["columns"]
            all_conds.append((columns[col1], columns[col2]))
            conds.add((tables[source_table_id],
                       tables[target_table_id],
                       columns[col1],
                       columns[col2]))

    all_from_table_ids = list(all_from_table_ids)
    try:
        tokens = ["from", tables[all_from_table_ids[0]]]
        for i, table_id in enumerate(all_from_table_ids[1:]):
            tokens += ["join"]
            tokens += [tables[table_id]]
            tokens += ["on", all_conds[i][0], "=", all_conds[i][1]]
    except:
        return ' '.join(p_sql)

    if where_idx != -1:
        p_sql = p_sql[:from_idx] + tokens + p_sql[where_idx:]
    elif group_idx != -1:
        p_sql = p_sql[:from_idx] + tokens + p_sql[group_idx:]
    elif order_idx != -1:
        p_sql = p_sql[:from_idx] + tokens + p_sql[order_idx:]
    elif len(p_sql[:from_idx] + p_sql[from_idx:]) == len(p_sql):
        p_sql = p_sql[:from_idx] + tokens

    return ' '.join(p_sql)

def extract_structure_data(plain_text_content: str):
    def sort_by_id(data):
        data.sort(key=lambda x: int(x.split('\t')[0][2:]))
        return data

    data = []
    original_schemas = load_original_schemas(database_schema_filename)
    schemas, eval_foreign_key_maps = load_tables(database_schema_filename)

    predict_outputs = sort_by_id(re.findall("^D.+", plain_text_content, RegexFlag.MULTILINE))
    ground_outputs = sort_by_id(re.findall("^T.+", plain_text_content, RegexFlag.MULTILINE))
    source_inputs = sort_by_id(re.findall("^S.+", plain_text_content, RegexFlag.MULTILINE))
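    # fairseq-generate prefixes each log line with S-<id> (source), T-<id> (target/reference)
    # and D-<id> (detokenized hypothesis); sorting by that id restores the original dev order,
    # so entry idx below lines up with sql_to_db[idx].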
    for idx, (predict, ground, source) in enumerate(zip(predict_outputs, ground_outputs, source_inputs)):
        predict_id, predict_score, predict_clean = predict.split('\t')
        ground_id, ground_clean = ground.split('\t')
        source_id, source_clean = source.split('\t')
        db_id = sql_to_db[idx]

        # try to post-process the incomplete SQL by
        # (1) correcting the COLUMN in the ON clause based on the foreign-key graph
        # (2) adding the missing TABLE via searching the shortest path
        predict_clean = post_processing_sql(predict_clean, eval_foreign_key_maps[db_id], original_schemas[db_id],
                                            schemas[db_id])

        data.append((predict_id[2:], source_clean.split('<Q>')[-1].strip(), ground_clean, predict_clean, db_id))

    return data

def evaluate(data):
    def evaluate_example(_predict_str: str, _ground_str: str):
        return re.sub(r"\s+", "", _predict_str.lower()) == re.sub(r"\s+", "", _ground_str.lower())

    correct_num = 0
    correct_tag_list = []
    total = 0
    tmp = []
    for example in data:
        idx, source_str, ground_str, predict_str, db_id = example
        total += 1
        try:
            sql_match = evaluate_sql(gold=ground_str.replace('@', '.'),
                                     predict=predict_str.replace('@', '.'),
                                     db_name=db_id,
                                     db_dir=database_dir,
                                     table=database_schema_filename)
        except:
            print(predict_str)
            sql_match = False
        if sql_match or evaluate_example(predict_str, ground_str):
            is_correct = True
            correct_num += 1
        else:
            is_correct = False
        tmp.append(is_correct)
        correct_tag_list.append(is_correct)

    print("Correct/Total : {}/{}, {:.4f}".format(correct_num, total, correct_num / total))
    return correct_tag_list, correct_num, total
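
# Entry point used by __main__: with --constrain, predictions come from GENRE decoding
# restricted to the database schema; otherwise fairseq-generate writes
# ./eval/generate-valid.txt, which is parsed back into evaluation tuples. Per-example
# results are dumped to ./eval/spider_eval.txt in both cases.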
def predict_and_evaluate(model_path, dataset_path, constrain):
    if constrain:
        data = predict_with_constrain(
            model_path=model_path,
            dataset_path=dataset_path
        )
    else:
        decode_without_constrain(
            model_path=model_path,
            dataset_path=dataset_path
        )
        with open('./eval/generate-valid.txt', "r", encoding="utf8") as generate_f:
            file_content = generate_f.read()
        data = extract_structure_data(file_content)

    correct_arr, correct_num, total = evaluate(data)

    with open('./eval/spider_eval.txt', "w", encoding="utf8") as eval_file:
        for example, correct in zip(data, correct_arr):
            eval_file.write(str(correct) + "\n" + "\n".join(
                [example[0], "db: " + example[-1], example[1], "gold: " + example[2], "pred: " + example[3]]) + "\n\n")

    return correct_num, total

def get_alias_schema(schemas):
    alias_schema = {}
    for db in schemas:
        schema = schemas[db].orig
        collect = []
        for i, (t, c) in enumerate(zip(schema['column_types'], schema['column_names_original'])):
            if c[0] == -1:
                collect.append('*')
            else:
                column_with_alias = "{0}@{1}".format(schema['table_names_original'][c[0]].lower(), c[1].lower())
                collect.append(column_with_alias)
        for t in schema['table_names_original']:
            collect.append(t.lower())
        collect.append("'value'")
        alias_schema[db] = collect
    return alias_schema
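
# get_alias_schema collects, per database, the candidate vocabulary for constrained
# decoding: '*', every column as table@column, every table name, and the literal
# "'value'" placeholder. For a hypothetical database with a single table
# singer(name, age) this would be ['*', 'singer@name', 'singer@age', 'singer', "'value'"].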

def predict_with_constrain(model_path, dataset_path):
    schemas, eval_foreign_key_maps = load_tables(database_schema_filename)
    original_schemas = load_original_schemas(database_schema_filename)

    with open(f'{dataset_path}/dev.src', 'r', encoding='utf-8') as f:
        item = [i.strip() for i in f.readlines()]
    with open(f'{dataset_path}/dev.tgt', 'r', encoding='utf-8') as f:
        ground = [i.strip() for i in f.readlines()]

    alias_schema = get_alias_schema(schemas)

    item_db_cluster = defaultdict(list)
    ground_db_cluster = defaultdict(list)
    source_db_cluster = defaultdict(list)

    num_example = 1034
    for db, sentence, g_sql in zip(sql_to_db[:num_example], item[:num_example], ground[:num_example]):
        source = sentence.split('<Q>')[-1].strip()
        item_db_cluster[db].append(sentence)
        ground_db_cluster[db].append(g_sql)
        source_db_cluster[db].append(source)

    source = []
    ground = []
    for db, sentence in source_db_cluster.items():
        source.extend(sentence)
    for db, g_SQL in ground_db_cluster.items():
        ground.extend(g_SQL)

    model = GENRE.from_pretrained(model_path).eval()
    if torch.cuda.is_available():
        model.cuda()

    result = []
    for db, sentence in item_db_cluster.items():
        print(f'processing db: {db} with {len(sentence)} sentences')
        rnt = decode_with_constrain(sentence, alias_schema[db], model)
        result.extend([i[0]['text'] if isinstance(i[0]['text'], str) else i[0]['text'][0] for i in rnt])

    eval_file_path = './eval/generate-valid-constrain.txt'
    with open(eval_file_path, "w", encoding="utf8") as f:
        f.write('\n'.join(result))

    # result = []
    # with open('./eval/generate-valid-constrain.txt', "r", encoding="utf8") as f:
    #     for idx, (sent, db_id) in enumerate(zip(f.readlines(), sql_to_db)):
    #         result.append(sent.strip())

    data = []
    for predict_id, (predict_clean, ground_clean, source_clean, db_id) in enumerate(
            zip(result, ground, source, sql_to_db)):
        predict_clean = post_processing_sql(predict_clean, eval_foreign_key_maps[db_id], original_schemas[db_id],
                                            schemas[db_id])
        data.append((str(predict_id), source_clean.split('<Q>')[-1].strip(), ground_clean, predict_clean, db_id))

    return data
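
# Constrained decoding (below): the Trie stores the BPE token ids of every schema item,
# and GENRE's prefix-allowed-tokens function masks the beam search with it, so that
# generated mentions are forced to be items of the current database's schema.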
def decode_with_constrain(sentences, schema, model):
    trie = Trie([
        model.encode(" {}".format(e))[1:].tolist()
        for e in schema
    ])

    prefix_allowed_tokens_fn = get_prefix_allowed_tokens_fn(
        model,
        sentences,
        mention_trie=trie,
    )

    return model.sample(
        sentences,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
    )

def decode_without_constrain(model_path, dataset_path):
    cmd = f'fairseq-generate \
        --path {model_path}/model.pt {dataset_path}/bin \
        --gen-subset valid \
        --nbest 1 \
        --max-tokens 4096 \
        --source-lang src --target-lang tgt \
        --results-path ./eval \
        --beam 5 \
        --bpe gpt2 \
        --remove-bpe \
        --skip-invalid-size-inputs-valid-test'

    subprocess.Popen(
        cmd, universal_newlines=True, shell=True,
        stdout=subprocess.PIPE, stderr=subprocess.PIPE).communicate()
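
# Typical invocation, using the repository defaults declared below:
#   python step3_evaluate.py --model_path ./models/spider_sl --dataset_path ./dataset_post/spider_sl
# Add --constrain to decode with the schema-restricted trie instead of fairseq-generate.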
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", default='./models/spider_sl')
    parser.add_argument("--dataset_path", default='./dataset_post/spider_sl')
    parser.add_argument("--constrain", action='store_true')
    args = parser.parse_args()

    predict_and_evaluate(model_path=args.model_path,
                         dataset_path=args.dataset_path,
                         constrain=args.constrain)