From 2b59163b3cca9922098c19895943b2c9e57c3447 Mon Sep 17 00:00:00 2001 From: SivilTaram Date: Tue, 13 Oct 2020 11:13:44 +0800 Subject: [PATCH] Fix the issue #1 (Traceback warning for semantic parsing in context) --- .../context/converter.py | 2 +- .../test_sql_to_semql.py | 65 +++++++++++++++++++ 2 files changed, 66 insertions(+), 1 deletion(-) create mode 100644 semantic_parsing_in_context/test_sql_to_semql.py diff --git a/semantic_parsing_in_context/context/converter.py b/semantic_parsing_in_context/context/converter.py index 72850af..8e1b3ab 100644 --- a/semantic_parsing_in_context/context/converter.py +++ b/semantic_parsing_in_context/context/converter.py @@ -256,7 +256,7 @@ class SQLConverter(object): None ] """ - if isinstance(condition, list): + if isinstance(condition, list) and isinstance(condition[2], list) and isinstance(condition[3], list): from_cond_col_inds = [condition[2][1][1], condition[3][1]] for col_ind in from_cond_col_inds: if self.col_names[col_ind].refer_table.name == join_tab_name: diff --git a/semantic_parsing_in_context/test_sql_to_semql.py b/semantic_parsing_in_context/test_sql_to_semql.py new file mode 100644 index 0000000..0e97b5a --- /dev/null +++ b/semantic_parsing_in_context/test_sql_to_semql.py @@ -0,0 +1,65 @@ +import json +from context.converter import SQLConverter, SparcDBContext +import unittest +from allennlp.data.tokenizers import WordTokenizer + + +class TestSQLToSemQL(unittest.TestCase): + + @staticmethod + def template(sql_plain, sql_text, db_id, expected_str): + sql_clause = json.loads(sql_text) + db_context = SparcDBContext(db_id=db_id, + utterance=[], + tokenizer=WordTokenizer(), + # TODO: Please first config the dataset path you want to test + tables_file="dataset_sparc\\tables.json", + database_path="dataset_sparc\\database") + converter = SQLConverter(db_context=db_context) + inter_seq = converter.translate_to_intermediate(sql_clause=sql_clause) + assert str(inter_seq) == expected_str, \ + f'\nSQL:\t\t{sql_plain}\nExp:\t\t{expected_str}\nPred:\t\t{str(inter_seq)}\n' + + def test_example(self): + db_id = "flight_2" + sql_plain = "SELECT * FROM AIRLINES" + sql_clause = """ + { + "orderBy": [], + "from": { + "table_units": [ + [ + "table_unit", + 0 + ] + ], + "conds": [] + }, + "union": null, + "except": null, + "groupBy": [], + "limit": null, + "intersect": null, + "where": [], + "having": [], + "select": [ + false, + [ + [ + 0, + [ + 0, + [ + 0, + 0, + false + ], + null + ] + ] + ] + ] + } + """ + expected_action_str = "[Statement -> Root, Root -> Select, Select -> A, A -> none C T, C -> *, T -> airlines]" + self.template(sql_plain, sql_clause, db_id, expected_action_str)