"""Compare ground-truth and predicted queries on annotated tables.

The script reads a table file (one JSON object per line) and a result log,
then reports logical-form accuracy (exact string match of the two queries)
and execution accuracy (both queries produce the same non-None result on
their table).

Queries are whitespace-separated token sequences of the form::

    <table-id> <aggregator> <selected-column> [<column> <op> <value>]...
"""
import json
import operator
import os
import re
from pprint import pprint

# the folder containing annotated tables
db_folder = "wikisql_data/annotated"

# the table files used
db_file = "test.tables.jsonl"

# the files generated by nl2prog_lite program
result_file = "test_top_1.log"

# Operators accepted by simple_parser.
_VALID_OPS = ("=", "<", ">", "<>", ">=", "<=")

# Operators whose operands are compared as floats.
_NUMERIC_COMPARATORS = {
    "<": operator.lt,
    ">": operator.gt,
    "<=": operator.le,
    ">=": operator.ge,
}

# Aggregators that operate on the selected values converted to floats.
_NUMERIC_AGGREGATORS = {
    "avg": lambda values: sum(values) / len(values),
    "sum": sum,
    "max": max,
    "min": min,
}


def load_annotated_tables(file):
    """Load tables from a JSON-lines file.

    Backticks in each raw line are replaced by single quotes before the
    line is parsed.

    Args:
        file: path of the JSON-lines table file.

    Returns:
        A dict mapping each table's ``"id"`` to its parsed JSON object.
    """
    tables = {}
    with open(file, "r") as f:
        for line in f:
            raw = json.loads(line.replace("`", "'"))
            tables[raw["id"]] = raw
    return tables


def load_queries(file):
    """Load query records from a result log.

    The log is read in groups of four lines; the first three lines of each
    group are used and the fourth is ignored.

    Args:
        file: path of the result log.

    Returns:
        A list of ``(second_line, third_line, first_line)`` tuples, one per
        group. ``main`` treats the second line as the ground-truth query and
        the third line as the predicted query.
    """
    queries = []
    with open(file, "r") as f:
        for i, line in enumerate(f):
            position = i % 4
            if position == 0:
                h = line
            elif position == 1:
                y = line
            elif position == 2:
                queries.append((y, line, h))
    return queries


def simple_parser(q):
    """Parse a whitespace-tokenized query string.

    Args:
        q: query string ``"<table> <aggr> <selcol> [<col> <op> <val>]..."``.

    Returns:
        A dict with keys ``"table"``, ``"aggr"``, ``"selcol"`` and ``"cond"``
        (a list of ``(column, operator, value)`` tuples).

    Raises:
        ValueError: if a condition uses an unknown operator.
        IndexError: if the query is too short or a condition is incomplete.
    """
    q_list = q.strip().split()
    cond = []
    # Conditions start at token 3 and occupy three tokens each.
    for i in range(3, len(q_list) - 1, 3):
        left, op, right = q_list[i], q_list[i + 1], q_list[i + 2]
        if op not in _VALID_OPS:
            raise ValueError("unsupported operator: {}".format(op))
        cond.append((left, op, right))
    return {
        "table": q_list[0],
        "aggr": q_list[1],
        "selcol": q_list[2],
        "cond": cond,
    }


def cond_check(row, conds):
    """Return True if ``row`` satisfies every condition in ``conds``.

    Args:
        row: sequence of string cell values.
        conds: iterable of ``(column_index, operator, value)`` tuples.

    Returns:
        True when all conditions hold (vacuously True for no conditions).

    Raises:
        ValueError: if an operator is unknown, or an operand of a numeric
            comparison cannot be converted to float.
    """
    for col, op, value in conds:
        lval = normalized(row[col])
        rval = normalized(value)
        if op == "=":
            satisfied = try_equal(lval, rval)
        elif op == "<>":
            satisfied = not try_equal(lval, rval)
        elif op in _NUMERIC_COMPARATORS:
            compare = _NUMERIC_COMPARATORS[op]
            satisfied = compare(to_float(lval), to_float(rval))
        else:
            raise ValueError("unsupported condition: {}".format((col, op, value)))
        if not satisfied:
            return False
    return True


def normalized(v):
    """Normalize a cell, header, or query token for comparison.

    Parentheses become ``-lrb-``/``-rrb-``, backticks become single quotes,
    non-ASCII characters are dropped, and surrounding whitespace is stripped.
    """
    x = v.replace("(", "-lrb-").replace(")", "-rrb-").replace("`", "'")
    x = re.sub(r'[^\x00-\x7F]+', '', x)
    # Only byte strings need decoding; text strings have no decode() on
    # Python 3.
    if isinstance(x, bytes):
        x = x.decode('utf-8', 'ignore')
    return x.strip()


def to_float(v):
    """Convert a numeric string to float, ignoring thousands separators."""
    return float(v.replace(",", ""))


def try_equal(v1, v2):
    """Return True if the values are equal as strings or as floats."""
    if v1 == v2:
        return True
    try:
        return to_float(v1) == to_float(v2)
    except (ValueError, TypeError, AttributeError):
        # At least one operand is not numeric, so they are simply unequal.
        return False


def execute(table, q):
    """Execute a parsed query against a table.

    Args:
        table: dict with ``"header"`` (column names) and ``"content"``
            (list of rows).
        q: query dict as produced by ``simple_parser``.

    Returns:
        The aggregated value, or None when no row matches the conditions or
        the aggregator is not recognized.

    Raises:
        KeyError: if a referenced column is not in the header.
        ValueError: if a numeric comparison or aggregation meets a
            non-numeric value.
    """
    header_to_index = {
        normalized(name): i for i, name in enumerate(table["header"])
    }
    conds = [
        (header_to_index[normalized(c[0])], c[1], c[2]) for c in q["cond"]
    ]
    filtered_rows = [r for r in table["content"] if cond_check(r, conds)]
    if not filtered_rows:
        return None

    sel_index = header_to_index[normalized(q["selcol"])]
    selected_val = [r[sel_index] for r in filtered_rows]

    aggr = q["aggr"]
    if aggr == "select":
        return selected_val[0]
    if aggr == "count":
        return len(selected_val)
    if aggr in _NUMERIC_AGGREGATORS:
        return _NUMERIC_AGGREGATORS[aggr]([to_float(v) for v in selected_val])
    return None


def main():
    """Evaluate the result log and print accuracy statistics.

    Queries that fail to match by execution are written to
    ``tmp_wrong_queries.txt``.
    """
    all_wrong_ones = []

    # load tables and queries
    tables = load_annotated_tables(os.path.join(db_folder, db_file))
    queries = load_queries(result_file)

    lf_equal = 0.
    ex_equal = 0.
    # NOTE(review): the total starts at 402 rather than 0; presumably this
    # accounts for examples that are absent from the result log -- confirm.
    total_num = 402.

    wrong_a = 0  # either query raised during execution
    wrong_b = 0  # either query failed to parse
    wrong_c = 0  # ground-truth query executed to None
    wrong_d = 0  # predicted query executed to None
    wrong_e = 0  # ground-truth query raised during execution

    for raw_q1, raw_q2, h in queries:
        total_num += 1
        if raw_q1.strip() == raw_q2.strip():
            lf_equal += 1

        # compare execution equivalence
        try:
            q1 = simple_parser(raw_q1)
            q2 = simple_parser(raw_q2)
        except (ValueError, IndexError):
            wrong_b += 1
            continue

        if q1["table"] != q2["table"]:
            continue

        table = tables[q1["table"]]

        # Any execution failure counts the pair as wrong; this is the
        # evaluation boundary, so a broad catch is intentional.
        try:
            v1 = execute(table, q1)
        except Exception:
            wrong_e += 1
            wrong_a += 1
            all_wrong_ones.append((raw_q1, raw_q2, h))
            continue

        if v1 is None:
            wrong_c += 1

        try:
            v2 = execute(table, q2)
        except Exception:
            wrong_a += 1
            all_wrong_ones.append((raw_q1, raw_q2, h))
            continue

        if v1 is not None and v1 == v2:
            ex_equal += 1
        else:
            all_wrong_ones.append((raw_q1, raw_q2, h))
            if v2 is None:
                wrong_d += 1

    print("#Q2 (predition) result is wrong: {}".format(wrong_a - wrong_e + wrong_d))
    print("#Q1 or Q2 fail to parse: {}".format(wrong_b))
    print("#Q1 (ground truth) exec to None: {}".format(wrong_c))
    print("#Q1 (ground truth) failed to execute: {}".format(wrong_e))

    print('Logical Form Accuracy: {}'.format(lf_equal / total_num))
    print('Execute Accuracy: {}'.format(ex_equal / total_num))

    with open("tmp_wrong_queries.txt", "w") as f:
        for p in all_wrong_ones:
            f.write(p[2])
            f.write(p[0])
            f.write(p[1])
            f.write("\n")


if __name__ == '__main__':
    main()