Fix the issue #1 (Traceback warning for semantic parsing in context)

This commit is contained in:
SivilTaram 2020-10-13 11:13:44 +08:00
Родитель 707b25588e
Коммит 2b59163b3c
2 изменённых файлов: 66 добавлений и 1 удалений

Просмотреть файл

@ -256,7 +256,7 @@ class SQLConverter(object):
None
]
"""
if isinstance(condition, list):
if isinstance(condition, list) and isinstance(condition[2], list) and isinstance(condition[3], list):
from_cond_col_inds = [condition[2][1][1], condition[3][1]]
for col_ind in from_cond_col_inds:
if self.col_names[col_ind].refer_table.name == join_tab_name:

Просмотреть файл

@ -0,0 +1,65 @@
import json
from context.converter import SQLConverter, SparcDBContext
import unittest
from allennlp.data.tokenizers import WordTokenizer
class TestSQLToSemQL(unittest.TestCase):
@staticmethod
def template(sql_plain, sql_text, db_id, expected_str):
sql_clause = json.loads(sql_text)
db_context = SparcDBContext(db_id=db_id,
utterance=[],
tokenizer=WordTokenizer(),
# TODO: Please first config the dataset path you want to test
tables_file="dataset_sparc\\tables.json",
database_path="dataset_sparc\\database")
converter = SQLConverter(db_context=db_context)
inter_seq = converter.translate_to_intermediate(sql_clause=sql_clause)
assert str(inter_seq) == expected_str, \
f'\nSQL:\t\t{sql_plain}\nExp:\t\t{expected_str}\nPred:\t\t{str(inter_seq)}\n'
def test_example(self):
db_id = "flight_2"
sql_plain = "SELECT * FROM AIRLINES"
sql_clause = """
{
"orderBy": [],
"from": {
"table_units": [
[
"table_unit",
0
]
],
"conds": []
},
"union": null,
"except": null,
"groupBy": [],
"limit": null,
"intersect": null,
"where": [],
"having": [],
"select": [
false,
[
[
0,
[
0,
[
0,
0,
false
],
null
]
]
]
]
}
"""
expected_action_str = "[Statement -> Root, Root -> Select, Select -> A, A -> none C T, C -> *, T -> airlines]"
self.template(sql_plain, sql_clause, db_id, expected_action_str)