PointerSQL/execute.py

211 строки
4.5 KiB
Python
Исходник Постоянная ссылка Обычный вид История

2018-07-17 01:06:18 +03:00
import os
import json
from pprint import pprint
import re
# Folder containing the annotated table files; joined with db_file in main().
db_folder = "wikisql_data/annotated"
# Name of the JSON-lines table file inside db_folder (one table object per line).
db_file = "test.tables.jsonl"
# Log file generated by the nl2prog_lite program; main() passes it to
# load_queries(), which reads it in groups of four lines.
result_file = "test_top_1.log"
def load_annotated_tables(file):
    """Load tables from a JSON-lines file and index them by their ``"id"``.

    Each non-blank line is parsed as one JSON object after every backtick in
    the raw line has been replaced by an apostrophe.

    Args:
        file: Path of the JSON-lines file to read.

    Returns:
        A dict mapping each object's ``"id"`` value to the parsed object. If
        two lines share an id, the later one wins.

    Raises:
        ValueError: If a non-blank line is not valid JSON.
        KeyError: If a parsed object has no ``"id"`` key.
    """
    tables = {}
    with open(file, "r") as f:
        # Iterate the file lazily instead of materialising it with readlines().
        for line in f:
            # Blank lines (typically a trailing newline at end of file) are
            # not records; json.loads would raise on them.
            if not line.strip():
                continue
            raw = json.loads(line.replace("`", "'"))
            tables[raw["id"]] = raw
    return tables
def load_queries(file):
    """Read a result log laid out in records of four lines each.

    Within every record, line 0 is kept as ``h``, line 1 as ``y`` and line 2
    as ``p``; line 3 is ignored. Lines are returned exactly as read,
    including their trailing newline.

    Args:
        file: Path of the log file to read.

    Returns:
        A list of ``(y, p, h)`` tuples, one per record that has at least its
        first three lines. An incomplete trailing record is dropped.
    """
    with open(file, "r") as f:
        lines = f.readlines()
    # Step over record starts directly. The bound ``len(lines) - 2`` ensures
    # that ``start + 2`` is always a valid index, so a truncated final record
    # is skipped rather than raising.
    return [
        (lines[start + 1], lines[start + 2], lines[start])
        for start in range(0, len(lines) - 2, 4)
    ]
def simple_parser(q):
    """Parse a whitespace-separated query string into a dict.

    The expected token layout is::

        <table> <aggr> <selcol> [<column> <op> <value>]...

    Args:
        q: The raw query line.

    Returns:
        A dict with keys ``"table"``, ``"aggr"``, ``"selcol"`` and ``"cond"``,
        where ``"cond"`` is a list of ``(column, op, value)`` string tuples.

    Raises:
        ValueError: If a condition uses an operator outside
            ``=, <, >, <>, >=, <=``.
        IndexError: If the query has fewer than three tokens or a trailing
            condition is missing its value.
    """
    allowed_ops = ("=", "<", ">", "<>", ">=", "<=")
    tokens = q.strip().split()
    cond = []
    # Conditions start at token 3 and come in groups of three.
    for i in range(3, len(tokens) - 1, 3):
        column = tokens[i]
        op = tokens[i + 1]
        value = tokens[i + 2]
        if op not in allowed_ops:
            # The original called sys.exit() here, but ``sys`` is never
            # imported, so it actually failed with NameError. Raise an
            # explicit, catchable error instead.
            raise ValueError("unsupported operator in query: {!r}".format(op))
        cond.append((column, op, value))
    return {
        "table": tokens[0],
        "aggr": tokens[1],
        "selcol": tokens[2],
        "cond": cond,
    }
def cond_check(row, conds):
    """Return True when ``row`` satisfies every condition in ``conds``.

    Both operands are passed through ``normalized`` first. ``=`` and ``<>``
    compare via ``try_equal``; the ordering operators compare the operands
    as floats via ``to_float``.

    Args:
        row: Sequence of cell values for one table row.
        conds: Iterable of ``(column_index, op, value)`` triples.

    Returns:
        True if all conditions hold (or ``conds`` is empty), else False.

    Raises:
        ValueError: If an operator is not one of ``=, <>, <, >, <=, >=``, or
            if an ordering comparison is evaluated on a non-numeric operand.
    """
    check = True
    for cond in conds:
        lval = normalized(row[cond[0]])
        rval = normalized(cond[2])
        op = cond[1]
        # ``check and ...`` short-circuits: once a condition has failed, later
        # comparisons are not evaluated, but their operator is still validated.
        if op == "=":
            check = check and try_equal(lval, rval)
        elif op == "<>":
            # simple_parser accepts "<>", so it is handled here as "not equal".
            check = check and not try_equal(lval, rval)
        elif op == "<":
            # Strict comparison; this branch previously used "<=".
            check = check and to_float(lval) < to_float(rval)
        elif op == ">":
            # Strict comparison; this branch previously used ">=".
            check = check and to_float(lval) > to_float(rval)
        elif op == "<=":
            check = check and to_float(lval) <= to_float(rval)
        elif op == ">=":
            check = check and to_float(lval) >= to_float(rval)
        else:
            # The original called sys.exit() without importing ``sys``.
            raise ValueError("unsupported condition: {!r}".format(cond))
    return check
def normalized(v):
    """Canonicalise a cell or header string before comparison.

    Parentheses become the ``-lrb-`` / ``-rrb-`` tokens, backticks become
    apostrophes, every non-ASCII character is removed, and surrounding
    whitespace is stripped.

    Args:
        v: Text to normalise. Byte strings are decoded as UTF-8, ignoring
            undecodable bytes.

    Returns:
        The normalised text string.
    """
    # Decode up front rather than calling .decode() on the re.sub() result:
    # on Python 3 that result is ``str``, which has no .decode() method.
    if isinstance(v, bytes):
        v = v.decode('utf-8', 'ignore')
    x = v.replace("(", "-lrb-").replace(")", "-rrb-").replace("`", "\'")
    return re.sub(r'[^\x00-\x7F]+', '', x).strip()
def to_float(v):
    """Convert a numeric string to ``float``, ignoring any commas in it."""
    # Splitting on commas and re-joining drops every comma, e.g. "1,234" -> "1234".
    without_commas = "".join(v.split(","))
    return float(without_commas)
def try_equal(v1, v2):
    """Return True if two values are equal directly or as floats.

    The values are first compared as-is. If they differ, both are converted
    with ``to_float`` and compared numerically, so ``"1,000"`` equals
    ``"1000.0"``.

    Args:
        v1: First value.
        v2: Second value.

    Returns:
        True if the values match either way. False if they differ or if
        either one cannot be converted to a float.
    """
    if v1 == v2:
        return True
    try:
        return to_float(v1) == to_float(v2)
    except (ValueError, TypeError, AttributeError):
        # Non-numeric operands are simply unequal. Catching only conversion
        # errors lets KeyboardInterrupt and SystemExit propagate, which the
        # previous bare ``except`` swallowed.
        return False
def execute(table, q):
    """Run a parsed query against one table and return its result.

    Args:
        table: Dict with a ``"header"`` list of column names and a
            ``"content"`` list of rows.
        q: Dict with keys ``"aggr"``, ``"selcol"`` and ``"cond"``.

    Returns:
        ``None`` if no row matches the conditions or the aggregation name is
        not recognised. Otherwise: the selected cell of the first matching
        row for ``"select"``, the number of matching rows for ``"count"``,
        or a float for ``"avg"``, ``"sum"``, ``"max"`` and ``"min"``.
    """
    # Map normalised header names to column positions; on duplicate names the
    # later column wins.
    column_of = {
        normalized(name): position
        for position, name in enumerate(table["header"])
    }

    # Swap each condition's column name for its column position.
    resolved = [
        (column_of[normalized(c[0])], c[1], c[2]) for c in q["cond"]
    ]

    matching_rows = [
        row for row in table["content"] if cond_check(row, resolved)
    ]
    if not matching_rows:
        return None

    target = column_of[normalized(q["selcol"])]
    values = [row[target] for row in matching_rows]

    aggr = q["aggr"]
    if aggr == "select":
        return values[0]
    if aggr == "count":
        return len(values)
    if aggr not in ("avg", "sum", "max", "min"):
        return None

    # The remaining aggregations all work on float values.
    numbers = [to_float(v) for v in values]
    if aggr == "avg":
        return sum(numbers) / len(values)
    if aggr == "sum":
        return sum(numbers)
    if aggr == "max":
        return max(numbers)
    return min(numbers)
def main():
    """Compare ground-truth and predicted queries and report accuracy.

    Loads the tables and the result log, then for every
    ``(ground truth, prediction, h)`` record checks whether the two queries
    are textually identical (logical-form accuracy) and whether they return
    the same non-None value on the table (execution accuracy). Prints the
    error counters and both accuracies, and writes every mismatching record
    to ``tmp_wrong_queries.txt``.

    Raises:
        KeyError: If a query names a table id missing from the loaded tables.
    """
    all_wrong_ones = []
    # load tables and queries
    tables = load_annotated_tables(os.path.join(db_folder, db_file))
    queries = load_queries(result_file)
    lf_equal = 0.
    ex_equal = 0.
    # NOTE(review): the denominator starts at 402 instead of 0, presumably to
    # count examples missing from the log as failures. Confirm with the data
    # pipeline.
    total_num = 402.
    wrong_a = 0  # an exception was raised while executing either query
    wrong_b = 0  # either query failed to parse
    wrong_c = 0  # ground-truth query returned None
    wrong_d = 0  # prediction returned None on a pair not counted as equal
    wrong_e = 0  # ground-truth query raised (also counted in wrong_a)
    for raw_q1, raw_q2, h in queries:
        total_num += 1
        if raw_q1.strip() == raw_q2.strip():
            lf_equal += 1
        # compare execution equivalence
        try:
            q1 = simple_parser(raw_q1)
            q2 = simple_parser(raw_q2)
        except Exception:
            wrong_b += 1
            continue
        # Pairs that target different tables are skipped: they stay in the
        # denominator but are neither executed nor recorded as wrong.
        if q1["table"] != q2["table"]:
            continue
        table = tables[q1["table"]]
        try:
            try:
                v1 = execute(table, q1)
            except Exception:
                wrong_e += 1
                # Re-raise so the outer handler records the pair, instead of
                # executing the query a second time just to fail again.
                raise
            if v1 is None:
                wrong_c += 1
            v2 = execute(table, q2)
            if v1 is not None and v1 == v2:
                ex_equal += 1
            else:
                all_wrong_ones.append((raw_q1, raw_q2, h))
                if v2 is None:
                    wrong_d += 1
        except Exception:
            all_wrong_ones.append((raw_q1, raw_q2, h))
            wrong_a += 1
    print("#Q2 (predition) result is wrong: {}".format(wrong_a - wrong_e + wrong_d))
    print("#Q1 or Q2 fail to parse: {}".format(wrong_b))
    print("#Q1 (ground truth) exec to None: {}".format(wrong_c))
    print("#Q1 (ground truth) failed to execute: {}".format(wrong_e))
    print('Logical Form Accuracy: {}'.format(lf_equal / total_num))
    print('Execute Accuracy: {}'.format(ex_equal / total_num))
    with open("tmp_wrong_queries.txt", "w") as f:
        # Each entry is written as h, ground truth, prediction, blank line.
        # The three strings keep the newlines they were read with.
        for p in all_wrong_ones:
            f.write(p[2])
            f.write(p[0])
            f.write(p[1])
            f.write("\n")
if __name__ == '__main__':
main()