# 211 lines, 4.5 KiB, Python (file-viewer header residue from the original paste)

import json
import os
import re
from pprint import pprint

# Directory holding the annotated WikiSQL tables.
db_folder = "wikisql_data/annotated"
# JSON-Lines table file (inside db_folder) the queries are executed against.
db_file = "test.tables.jsonl"
# Query log produced by the nl2prog_lite program.
result_file = "test_top_1.log"


def load_annotated_tables(file):
    """Load a JSON-Lines table file into a dict keyed by table id.

    Every backtick in a line is replaced by a single quote before the line
    is parsed. Blank lines (e.g. a trailing empty line at the end of the
    file) are skipped instead of crashing ``json.loads``.

    Args:
        file: path to a file with one JSON object per line; each object
            must have an ``"id"`` key.

    Returns:
        dict mapping each object's ``"id"`` to the parsed object. If an id
        repeats, the last occurrence wins.

    Raises:
        ValueError: a non-blank line is not valid JSON.
        KeyError: a parsed object has no ``"id"`` key.
    """
    tables = {}
    with open(file, "r") as f:
        # Iterate the file directly rather than materialising readlines().
        for line in f:
            if not line.strip():
                continue
            raw = json.loads(line.replace("`", "'"))
            tables[raw["id"]] = raw
    return tables


def load_queries(file):
    """Read query records from a result log.

    The log is consumed in groups of four lines. From each group the first
    three lines are kept (with their trailing newlines); the fourth is
    ignored. A final group with fewer than three lines is dropped.

    Returns:
        list of ``(second, third, first)`` line tuples, one per complete
        group, in file order. ``main`` treats the second line as the
        ground-truth query and the third as the prediction.
    """
    with open(file, "r") as f:
        lines = f.readlines()

    records = []
    # start + 2 must still be a valid index for the group to be complete.
    for start in range(0, len(lines) - 2, 4):
        first_line, gold, pred = lines[start:start + 3]
        records.append((gold, pred, first_line))
    return records


def simple_parser(q):
    """Parse a whitespace-tokenised logical form into a query dict.

    Expected token layout::

        <table> <aggr> <selcol> [<column> <op> <value>]...

    i.e. three leading tokens followed by zero or more complete
    ``(column, op, value)`` triples. Tokens are split on whitespace, so a
    single column name or value cannot contain spaces.

    Args:
        q: the logical-form string; surrounding whitespace is ignored.

    Returns:
        dict with keys ``"table"``, ``"aggr"``, ``"selcol"`` and ``"cond"``,
        where ``"cond"`` is a list of ``(column, op, value)`` string tuples.

    Raises:
        ValueError: fewer than three tokens, an incomplete trailing
            condition, or an operator other than ``=``, ``<``, ``>``,
            ``<>``, ``>=``, ``<=``.
    """
    tokens = q.strip().split()
    if len(tokens) < 3:
        raise ValueError("logical form needs at least 3 tokens: {!r}".format(q))
    # Reject dangling tokens uniformly: the old loop bound raised IndexError
    # on a dangling pair but silently dropped a single dangling token.
    if (len(tokens) - 3) % 3:
        raise ValueError("incomplete condition in logical form: {!r}".format(q))

    cond = []
    for i in range(3, len(tokens), 3):
        column, op, value = tokens[i:i + 3]
        if op not in ("=", "<", ">", "<>", ">=", "<="):
            # Raise instead of exiting: callers catch parse failures and
            # count them (the old sys.exit also referenced an un-imported sys).
            raise ValueError(
                "unsupported operator {!r} in logical form: {!r}".format(op, q))
        cond.append((column, op, value))

    return {
        "table": tokens[0],
        "aggr": tokens[1],
        "selcol": tokens[2],
        "cond": cond,
    }


def cond_check(row, conds):
    """Return True if ``row`` satisfies every condition in ``conds``.

    Both the cell value and the condition value pass through
    ``normalized`` first. ``=`` and ``<>`` use ``try_equal`` (string
    equality with a numeric fallback); the ordering operators compare the
    two sides as floats via ``to_float``. Evaluation stops at the first
    condition that fails.

    Args:
        row: one table row, indexable by column position.
        conds: iterable of ``(column_index, op, value)`` triples.

    Returns:
        True when all conditions hold (vacuously True for no conditions),
        False otherwise.

    Raises:
        ValueError: an unrecognised operator, or an ordering operator
            applied to a value ``to_float`` cannot parse.
    """
    for column, op, value in conds:
        lval = normalized(row[column])
        rval = normalized(value)

        if op == "=":
            ok = try_equal(lval, rval)
        elif op == "<>":
            # The parser accepts "<>", so it must be handled here too.
            ok = not try_equal(lval, rval)
        elif op == "<":
            # Strict comparisons: these were previously evaluated as <= / >=.
            ok = to_float(lval) < to_float(rval)
        elif op == ">":
            ok = to_float(lval) > to_float(rval)
        elif op == "<=":
            ok = to_float(lval) <= to_float(rval)
        elif op == ">=":
            ok = to_float(lval) >= to_float(rval)
        else:
            # Raise rather than sys.exit (sys was never imported here).
            raise ValueError("unsupported condition: {!r}".format((column, op, value)))

        if not ok:
            return False
    return True


def normalized(v):
    """Canonicalise a cell value or column name for comparison.

    Replaces ``(``/``)`` with the ``-lrb-``/``-rrb-`` tokens and backticks
    with single quotes, drops every non-ASCII character, and strips
    surrounding whitespace.

    Byte strings are decoded as UTF-8 first (undecodable bytes are ignored).
    The previous version called ``.decode`` on the result, which exists only
    on Python 2 byte strings and raised AttributeError for Python 3 ``str``.

    Returns:
        the normalised text string.
    """
    if isinstance(v, bytes):
        v = v.decode("utf-8", "ignore")
    x = v.replace("(", "-lrb-").replace(")", "-rrb-").replace("`", "'")
    return re.sub(r"[^\x00-\x7F]+", "", x).strip()


def to_float(v):
    """Parse a numeric string, ignoring any comma thousands separators.

    Example: ``"1,234.5"`` -> ``1234.5``. Raises ValueError (from ``float``)
    if what remains after removing commas is not a number.
    """
    digits = "".join(v.split(","))
    return float(digits)


def try_equal(v1, v2):
    """Compare two values directly, falling back to numeric comparison.

    Returns True if ``v1 == v2``, or if both parse with ``to_float`` (commas
    ignored) to equal numbers, e.g. ``"1,000"`` vs ``"1000.0"``. Returns
    False otherwise, including when either side is not numeric.
    """
    if v1 == v2:
        return True
    try:
        return to_float(v1) == to_float(v2)
    except (ValueError, TypeError, AttributeError):
        # Not parseable as a number, so the values are simply unequal.
        # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
        # are no longer swallowed.
        return False


def execute(table, q):
    """Run a parsed query (as produced by ``simple_parser``) on one table.

    Query column names and table header names are matched after
    ``normalized``. Rows accepted by ``cond_check`` are kept, and the
    selected column of those rows is aggregated according to ``q["aggr"]``.

    Args:
        table: dict with ``"header"`` (list of column names) and
            ``"content"`` (list of rows indexable by column position).
        q: dict with ``"aggr"``, ``"selcol"`` and ``"cond"`` keys.

    Returns:
        * ``"count"``: number of matching rows (0 when none match, as with
          SQL COUNT).
        * ``"select"``: the selected cell of the first matching row.
        * ``"avg"``/``"sum"``/``"max"``/``"min"``: float aggregate of the
          selected cells parsed with ``to_float``.
        For every aggregator except ``"count"``, None when no row matches.

    Raises:
        ValueError: unknown aggregator, or a cell ``to_float`` cannot parse
            under a numeric aggregator.
        KeyError: a referenced column is not in the header.
    """
    aggr = q["aggr"]
    # Validate the query up front: an unknown aggregator used to fall off
    # the end of the function and silently return None.
    if aggr not in ("select", "count", "avg", "sum", "max", "min"):
        raise ValueError("unsupported aggregator: {!r}".format(aggr))

    # If a normalised header name repeats, the last such column wins.
    header_to_index = {normalized(name): i for i, name in enumerate(table["header"])}
    conds = [(header_to_index[normalized(column)], op, value)
             for column, op, value in q["cond"]]
    sel_index = header_to_index[normalized(q["selcol"])]

    rows = [r for r in table["content"] if cond_check(r, conds)]

    if aggr == "count":
        return len(rows)
    if not rows:
        return None

    values = [r[sel_index] for r in rows]
    if aggr == "select":
        return values[0]

    numbers = [to_float(v) for v in values]
    if aggr == "avg":
        return sum(numbers) / len(numbers)
    if aggr == "sum":
        return sum(numbers)
    if aggr == "max":
        return max(numbers)
    return min(numbers)


def main(num_missing=402):
    """Score logged predictions against ground-truth queries.

    Loads the tables from ``db_folder``/``db_file`` and the query records
    from ``result_file``. For each ``(ground truth, prediction, first line)``
    record it:

    * counts an exact (whitespace-trimmed) string match of the two logical
      forms towards logical-form accuracy;
    * parses both with ``simple_parser`` and, if they name the same table,
      executes both with ``execute``. The record counts towards execution
      accuracy when the ground-truth result is not None and the two results
      are equal.

    Prints failure tallies and both accuracies, and writes every record that
    failed execution or did not match to ``tmp_wrong_queries.txt`` (first
    line, ground truth, prediction, blank line).

    Args:
        num_missing: added to the number of records read to form the
            accuracy denominator (the original script hard-coded 402).
            Presumably the number of test examples absent from
            ``result_file``, which are thereby scored as wrong -- TODO
            confirm.
    """
    tables = load_annotated_tables(os.path.join(db_folder, db_file))
    queries = load_queries(result_file)

    lf_equal = 0      # logical forms identical as strings
    ex_equal = 0      # execution results agree (ground truth not None)
    parse_failed = 0  # ground truth or prediction failed to parse
    gold_none = 0     # ground truth executed to None
    gold_failed = 0   # ground truth raised during execution / table missing
    pred_none = 0     # prediction executed to None
    pred_failed = 0   # prediction raised during execution
    wrong_ones = []

    for raw_q1, raw_q2, h in queries:
        if raw_q1.strip() == raw_q2.strip():
            lf_equal += 1

        try:
            q1 = simple_parser(raw_q1)
            q2 = simple_parser(raw_q2)
        except Exception:
            parse_failed += 1
            continue

        if q1["table"] != q2["table"]:
            # Stays in the denominator (scored as wrong) but, as before, is
            # not written to the wrong-queries file.
            continue

        record = (raw_q1, raw_q2, h)

        try:
            # A missing table used to raise outside any try and abort the
            # whole run; it is now tallied as a ground-truth failure.
            table = tables[q1["table"]]
            v1 = execute(table, q1)
        except Exception:
            gold_failed += 1
            wrong_ones.append(record)
            continue
        if v1 is None:
            gold_none += 1

        try:
            v2 = execute(table, q2)
        except Exception:
            pred_failed += 1
            wrong_ones.append(record)
            continue

        if v1 is not None and v1 == v2:
            ex_equal += 1
        else:
            wrong_ones.append(record)
            if v2 is None:
                pred_none += 1

    # A float denominator keeps true division on Python 2 as well; "or 1.0"
    # avoids ZeroDivisionError when there is nothing to score.
    total = float(len(queries) + num_missing) or 1.0

    print("#Q2 (predition) result is wrong: {}".format(pred_failed + pred_none))
    print("#Q1 or Q2 fail to parse: {}".format(parse_failed))
    print("#Q1 (ground truth) exec to None: {}".format(gold_none))
    print("#Q1 (ground truth) failed to execute: {}".format(gold_failed))

    print('Logical Form Accuracy: {}'.format(lf_equal / total))
    print('Execute Accuracy: {}'.format(ex_equal / total))

    with open("tmp_wrong_queries.txt", "w") as f:
        for raw_q1, raw_q2, h in wrong_ones:
            f.write(h)
            f.write(raw_q1)
            f.write(raw_q2)
            f.write("\n")


# Run the evaluation only when executed as a script, never on import.
if __name__ == "__main__":
    main()