# ContextualSP/interactive_text_to_sql/parsers/parser.py
# coding: utf-8
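"""Parser wrappers for the interactive text-to-SQL system.

Parser defines the interface (parse() must return a dict containing
'predict_sql' or 'predict_semql'); IRNetSpiderParser adapts a trained IRNet
Spider model loaded from an AllenNLP archive and caches preprocessed
instances on disk with dill.
"""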
import json
import dill
import hashlib
import os
import torch
from allennlp.common.util import JsonDict, sanitize
from allennlp.data import Instance
from allennlp.data import Vocabulary
from allennlp.models.archival import load_archive
from parsers.irnet.context.converter import ActionConverter
from parsers.irnet.dataset_reader.spider_reader import SpiderDatasetReader
from parsers.irnet.models.sparc_parser import SparcParser  # noqa: F401 (likely imported for its side effect of registering the model class with AllenNLP)


class Parser:
    def __init__(self, model: torch.nn.Module):
        assert model is not None
        model.eval()
        self.model = model

    def parse(self, example):
        # requirement: 'predict_sql' or 'predict_semql' must be in the returned dict
        raise NotImplementedError()


class IRNetSpiderParser(Parser):
    def __init__(self, model):
        super().__init__(model)
        self.spider_dataset_reader = SpiderDatasetReader()

    def parse(self, example):
        # cache preprocessed instances on disk, keyed by a hash of the example
        hash_id = self.hash_dict(example)[:7]
        cache_path = f'cache/spider_instance/{hash_id}.bin'
        if os.path.exists(cache_path):
            with open(cache_path, 'rb') as f:
                instance = dill.load(f)
        else:
            db_id = example['db_id']
            inter_utter_list = [example['question']]
            sql_list = [example['sql']]
            sql_query_list = [example['query']]
            instance = self.spider_dataset_reader.text_to_instance(
                utter_list=inter_utter_list,
                db_id=db_id,
                sql_list=sql_list,
                sql_query_list=sql_query_list
            )
            os.makedirs(os.path.dirname(cache_path), exist_ok=True)
            with open(cache_path, 'wb') as f:
                dill.dump(instance, f)
        parsed_result = self.parse_instance(instance)
        return parsed_result

    def parse_instance(self, instance: Instance) -> JsonDict:
        # map predicted action indices back to production rule strings
        index_to_rule = [production_rule_field.rule
                         for production_rule_field in instance.fields['valid_actions_list'].field_list[0].field_list]
        # run the model on the single instance and convert tensors to plain Python types
        results = sanitize(self.model.forward_on_instance(instance))
        rule_repr = [index_to_rule[ind] for ind in results['best_predict']]
        ground_rule_repr = [index_to_rule[ind] for ind in results['ground_truth']]
        db_context = instance.fields['worlds'].field_list[0].metadata.db_context
        action_converter = ActionConverter(db_context)
        predict_sql = action_converter.translate_to_sql(rule_repr)
        ground_sql = action_converter.translate_to_sql(ground_rule_repr)
        dis_results = {'predict': rule_repr,
                       'predict_sql': predict_sql,
                       'ground': ground_rule_repr,
                       'ground_sql': ground_sql,
                       'table_content': results['table_content']}
        return dis_results

    def hash_dict(self, d):
        # use a fresh hasher per call: reusing a single hashlib object accumulates
        # state, so repeated calls on the same dict would yield different digests;
        # sort_keys makes the digest independent of key insertion order
        dict_str = json.dumps(d, sort_keys=True)
        sha1 = hashlib.sha1()
        sha1.update(bytes(dict_str, encoding='utf-8'))
        return sha1.hexdigest()

    @staticmethod
    def get_parser():
        dataset_path = 'data/datasets/spider'
        vocab = Vocabulary.from_files('parsers/irnet/checkpoints/v1.0_spider_baseline_model/vocabulary')
        overrides = {
            "dataset_path": dataset_path,
            "train_data_path": "train.json",
            "validation_data_path": "dev.json"
        }
        parser_model = load_archive('parsers/irnet/checkpoints/v1.0_spider_baseline_model/model.tar.gz',
                                    cuda_device=0,  # assumes a GPU; pass -1 to run on CPU
                                    overrides=json.dumps(overrides)).model
        parser_model.sql_metric_util._evaluator.update_dataset_path(dataset_path=dataset_path)
        parser = IRNetSpiderParser(model=parser_model)
        return parser


if __name__ == '__main__':
    parser: Parser = IRNetSpiderParser.get_parser()
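    # Usage sketch, an assumption rather than part of the original script:
    # feed the parser one example from the Spider dev split configured above
    # (standard Spider JSON: a list of dicts with 'db_id', 'question', 'sql',
    # and 'query' keys).
    with open('data/datasets/spider/dev.json', 'r', encoding='utf-8') as f:
        dev_examples = json.load(f)
    result = parser.parse(dev_examples[0])
    print(result['predict_sql'])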