init unisar
This commit is contained in:
Parent eb2335c51e
Commit 5ae4e368e2

@@ -0,0 +1,57 @@
## Introduction

This paper introduces [UniSAr](https://arxiv.org/pdf/2203.07781.pdf), which extends existing autoregressive language models with three non-invasive extensions that make them structure-aware:

(1) adding structure marks to encode the database schema, the conversation context, and their relationships;

(2) constrained decoding to produce well-structured SQL for the given database schema; and

(3) SQL completion to fill in JOIN relationships that may be missing from the generated SQL, based on the database schema.

[//]: # (On seven well-known text-to-SQL datasets covering multi-domain, multi-table and multi-turn settings, UniSAr demonstrates performance highly comparable to or better than the most advanced specifically-designed text-to-SQL models.)
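
At decoding time, the only hook the model needs is a callback that maps the tokens generated so far to the tokens allowed next. The sketch below is a toy illustration of that idea with made-up token ids, not the shipped code; the actual implementation in this commit uses the GENRE-style `Trie` and a `prefix_allowed_tokens_fn` built from the database schema and SQL keywords.

```python
from typing import Dict, List

class PrefixTrie:
    """Maps a generated prefix to the token ids that may legally follow it."""
    def __init__(self, sequences: List[List[int]]):
        self.root: Dict[int, dict] = {}
        for seq in sequences:
            node = self.root
            for tok in seq:
                node = node.setdefault(tok, {})

    def allowed(self, prefix: List[int]) -> List[int]:
        node = self.root
        for tok in prefix:
            node = node.get(tok)
            if node is None:
                return []            # the prefix left the trie: nothing is allowed
        return list(node.keys())

# hypothetical token ids for the sequence "select name from student"
trie = PrefixTrie([[11, 42, 17, 99]])
print(trie.allowed([11, 42]))        # -> [17]; only schema-consistent continuations survive
```
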
## Dataset and Model

[Spider](https://github.com/taoyds/spider) -> `./data/spider`

[Fine-tuned BART model](https://huggingface.co/dreamerdeo/mark-bart/tree/main) -> `./models/spider_sl`

(Please download this model with `git-lfs` to avoid this [issue](https://github.com/DreamerDeo/UniSAr_text2sql/issues/1).)

```bash
sudo apt-get install git-lfs
git lfs install
git clone https://huggingface.co/dreamerdeo/mark-bart
```
## Main Dependencies

* Python version >= 3.6
* PyTorch version >= 1.5.0
* `pip install -r requirements.txt`
* `fairseq` is going through changes without backward compatibility. Install `fairseq` from source and use [this](https://github.com/nicola-decao/fairseq/tree/fixing_prefix_allowed_tokens_fn) commit for reproducibility. See [here](https://github.com/pytorch/fairseq/pull/3276) for the current PR that should fix `fairseq/master`.

## Evaluation Pipeline

Step 1: Preprocess the data by adding schema-linking and value-linking tags.

`python step1_schema_linking.py`

Step 2: Build the input and output sequences for BART.

`python step2_serialization.py`

Step 3: Run the evaluation script, with or without constrained decoding.

`python step3_evaluate.py --constrain`

## Results

Prediction: `69.34`

Prediction with Constrained Decoding: `70.02`

## Interactive

`python interactive.py --logdir ./models/spider_sl --db_id student_1 --db-path ./data/spider/database --schema-path ./data/spider/tables.json`
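
The same pipeline can also be driven programmatically. The snippet below is a minimal sketch based on `interactive.py` from this commit; the paths follow the layout above, and the example question is a placeholder.

```python
import stanza
from unisar.api import UnisarAPI

stanza_model = stanza.Pipeline(lang='en', processors='tokenize,pos,lemma')
unisar = UnisarAPI('./models/spider_sl', './data/spider/database',
                   './data/spider/tables.json', stanza_model)

result = unisar.infer_query('How many students are there?', 'student_1')
print(result['slml_question'])                  # question annotated with schema-linking marks
print(result['predict_sql'], result['score'])   # predicted SQL and its model score
```
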
## Reference Code

* https://github.com/ryanzhumich/editsql
* https://github.com/benbogin/spider-schema-gnn-global
* https://github.com/ElementAI/duorat
* https://github.com/facebookresearch/GENRE
@@ -0,0 +1,20 @@
Namespace(align_suffix=None, alignfile=None, all_gather_list_size=16384, azureml_logging=False, bf16=False, bpe=None, cpu=False, criterion='cross_entropy', dataset_impl='mmap', destdir='./dataset_post/spider_sl/bin', empty_cache_freq=0, fp16=False, fp16_init_scale=128, fp16_no_flatten_grads=False, fp16_scale_tolerance=0.0, fp16_scale_window=None, joined_dictionary=False, log_format=None, log_interval=100, lr_scheduler='fixed', memory_efficient_bf16=False, memory_efficient_fp16=False, min_loss_scale=0.0001, model_parallel_size=1, no_progress_bar=False, nwordssrc=-1, nwordstgt=-1, only_source=False, optimizer=None, padding_factor=8, profile=False, quantization_config_path=None, reset_logging=False, scoring='bleu', seed=1, simul_type=None, source_lang='src', srcdict='./models/spider_sl/dict.source.txt', suppress_crashes=False, target_lang='tgt', task='translation', tensorboard_logdir=None, testpref=None, tgtdict='./models/spider_sl/dict.target.txt', threshold_loss_scale=None, thresholdsrc=0, thresholdtgt=0, tokenizer=None, tpu=False, trainpref=None, user_dir=None, validpref='./dataset_post/spider_sl/dev.bpe', wandb_project=None, workers=2)
|
||||
Namespace(align_suffix=None, alignfile=None, all_gather_list_size=16384, azureml_logging=False, bf16=False, bpe=None, cpu=False, criterion='cross_entropy', dataset_impl='mmap', destdir='./dataset_post/spider_sl/bin', empty_cache_freq=0, fp16=False, fp16_init_scale=128, fp16_no_flatten_grads=False, fp16_scale_tolerance=0.0, fp16_scale_window=None, joined_dictionary=False, log_format=None, log_interval=100, lr_scheduler='fixed', memory_efficient_bf16=False, memory_efficient_fp16=False, min_loss_scale=0.0001, model_parallel_size=1, no_progress_bar=False, nwordssrc=-1, nwordstgt=-1, only_source=False, optimizer=None, padding_factor=8, profile=False, quantization_config_path=None, reset_logging=False, scoring='bleu', seed=1, simul_type=None, source_lang='src', srcdict='./models/spider_sl/dict.source.txt', suppress_crashes=False, target_lang='tgt', task='translation', tensorboard_logdir=None, testpref=None, tgtdict='./models/spider_sl/dict.target.txt', threshold_loss_scale=None, thresholdsrc=0, thresholdtgt=0, tokenizer=None, tpu=False, trainpref=None, user_dir=None, validpref='./dataset_post/spider_sl/dev.bpe', wandb_project=None, workers=2)
|
||||
Namespace(align_suffix=None, alignfile=None, all_gather_list_size=16384, azureml_logging=False, bf16=False, bpe=None, cpu=False, criterion='cross_entropy', dataset_impl='mmap', destdir='./dataset_post/spider_sl/bin', empty_cache_freq=0, fp16=False, fp16_init_scale=128, fp16_no_flatten_grads=False, fp16_scale_tolerance=0.0, fp16_scale_window=None, joined_dictionary=False, log_format=None, log_interval=100, lr_scheduler='fixed', memory_efficient_bf16=False, memory_efficient_fp16=False, min_loss_scale=0.0001, model_parallel_size=1, no_progress_bar=False, nwordssrc=-1, nwordstgt=-1, only_source=False, optimizer=None, padding_factor=8, profile=False, quantization_config_path=None, reset_logging=False, scoring='bleu', seed=1, simul_type=None, source_lang='src', srcdict='./models/spider_sl/dict.src.txt', suppress_crashes=False, target_lang='tgt', task='translation', tensorboard_logdir=None, testpref=None, tgtdict='./models/spider_sl/dict.tgt.txt', threshold_loss_scale=None, thresholdsrc=0, thresholdtgt=0, tokenizer=None, tpu=False, trainpref=None, user_dir=None, validpref='./dataset_post/spider_sl/dev.bpe', wandb_project=None, workers=2)
|
||||
[src] Dictionary: 50264 types
|
||||
[src] ./dataset_post/spider_sl/dev.bpe.src: 1034 sents, 270269 tokens, 0.0% replaced by <unk>
|
||||
[tgt] Dictionary: 50264 types
|
||||
[tgt] ./dataset_post/spider_sl/dev.bpe.tgt: 1034 sents, 33481 tokens, 0.0% replaced by <unk>
|
||||
Wrote preprocessed data to ./dataset_post/spider_sl/bin
|
||||
Namespace(align_suffix=None, alignfile=None, all_gather_list_size=16384, azureml_logging=False, bf16=False, bpe=None, cpu=False, criterion='cross_entropy', dataset_impl='mmap', destdir='./dataset_post/spider_sl/bin', empty_cache_freq=0, fp16=False, fp16_init_scale=128, fp16_no_flatten_grads=False, fp16_scale_tolerance=0.0, fp16_scale_window=None, joined_dictionary=False, log_format=None, log_interval=100, lr_scheduler='fixed', memory_efficient_bf16=False, memory_efficient_fp16=False, min_loss_scale=0.0001, model_parallel_size=1, no_progress_bar=False, nwordssrc=-1, nwordstgt=-1, only_source=False, optimizer=None, padding_factor=8, profile=False, quantization_config_path=None, reset_logging=False, scoring='bleu', seed=1, simul_type=None, source_lang='src', srcdict='./models/spider_sl/dict.src.txt', suppress_crashes=False, target_lang='tgt', task='translation', tensorboard_logdir=None, testpref=None, tgtdict='./models/spider_sl/dict.tgt.txt', threshold_loss_scale=None, thresholdsrc=0, thresholdtgt=0, tokenizer=None, tpu=False, trainpref=None, user_dir=None, validpref='./dataset_post/spider_sl/dev.bpe', wandb_project=None, workers=2)
|
||||
[src] Dictionary: 50264 types
|
||||
[src] ./dataset_post/spider_sl/dev.bpe.src: 1034 sents, 270269 tokens, 0.0% replaced by <unk>
|
||||
[tgt] Dictionary: 50264 types
|
||||
[tgt] ./dataset_post/spider_sl/dev.bpe.tgt: 1034 sents, 33481 tokens, 0.0% replaced by <unk>
|
||||
Wrote preprocessed data to ./dataset_post/spider_sl/bin
|
||||
Namespace(align_suffix=None, alignfile=None, all_gather_list_size=16384, azureml_logging=False, bf16=False, bpe=None, cpu=False, criterion='cross_entropy', dataset_impl='mmap', destdir='./dataset_post/spider_sl/bin', empty_cache_freq=0, fp16=False, fp16_init_scale=128, fp16_no_flatten_grads=False, fp16_scale_tolerance=0.0, fp16_scale_window=None, joined_dictionary=False, log_format=None, log_interval=100, lr_scheduler='fixed', memory_efficient_bf16=False, memory_efficient_fp16=False, min_loss_scale=0.0001, model_parallel_size=1, no_progress_bar=False, nwordssrc=-1, nwordstgt=-1, only_source=False, optimizer=None, padding_factor=8, profile=False, quantization_config_path=None, reset_logging=False, scoring='bleu', seed=1, simul_type=None, source_lang='src', srcdict='./models/mark-bart/dict.src.txt', suppress_crashes=False, target_lang='tgt', task='translation', tensorboard_logdir=None, testpref=None, tgtdict='./models/mark-bart/dict.tgt.txt', threshold_loss_scale=None, thresholdsrc=0, thresholdtgt=0, tokenizer=None, tpu=False, trainpref=None, user_dir=None, validpref='./dataset_post/spider_sl/dev.bpe', wandb_project=None, workers=2)
|
||||
[src] Dictionary: 50264 types
|
||||
[src] ./dataset_post/spider_sl/dev.bpe.src: 1034 sents, 281522 tokens, 0.0% replaced by <unk>
|
||||
[tgt] Dictionary: 50264 types
|
||||
[tgt] ./dataset_post/spider_sl/dev.bpe.tgt: 1034 sents, 33481 tokens, 0.0% replaced by <unk>
|
||||
Wrote preprocessed data to ./dataset_post/spider_sl/bin
|
Binary files added (contents not shown):

unified_parser_text_to_sql/dataset_post/spider_sl/bin/valid.src-tgt.src.bin
unified_parser_text_to_sql/dataset_post/spider_sl/bin/valid.src-tgt.src.idx
unified_parser_text_to_sql/dataset_post/spider_sl/bin/valid.src-tgt.tgt.bin
unified_parser_text_to_sql/dataset_post/spider_sl/bin/valid.src-tgt.tgt.idx
@@ -0,0 +1,133 @@
# Copyright (c) Facebook, Inc. and Microsoft Corporation.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
from genre.trie import Trie
|
||||
|
||||
keyword = ['select', 'distinct', 'from', 'join', 'on', 'where', 'group', 'by', 'order', 'asc', 'desc', 'limit',
|
||||
'having',
|
||||
'and', 'not', 'or', 'like', 'between', 'in',
|
||||
'sum', 'count', 'max', 'min', 'avg',
|
||||
'(', ')', ',', '>', '<', '=', '>=', '!=', '<=',
|
||||
'union', 'except', 'intersect',
|
||||
'1', '2', '3', '4', '5']
|
||||
|
||||
|
||||
def get_end_to_end_prefix_allowed_tokens_fn_hf(
|
||||
model,
|
||||
sentences: List[str],
|
||||
start_mention_token="{",
|
||||
end_mention_token="}",
|
||||
start_entity_token="[",
|
||||
end_entity_token="]",
|
||||
mention_trie: Trie = None,
|
||||
candidates_trie: Trie = None,
|
||||
mention_to_candidates_dict: Dict[str, List[str]] = None,
|
||||
):
|
||||
return _get_end_to_end_prefix_allowed_tokens_fn(
|
||||
lambda x: model.tokenizer.encode(x),
|
||||
lambda x: model.tokenizer.decode(torch.tensor(x)),
|
||||
model.tokenizer.bos_token_id,
|
||||
model.tokenizer.pad_token_id,
|
||||
model.tokenizer.eos_token_id,
|
||||
len(model.tokenizer) - 1,
|
||||
sentences,
|
||||
start_mention_token,
|
||||
end_mention_token,
|
||||
start_entity_token,
|
||||
end_entity_token,
|
||||
mention_trie,
|
||||
candidates_trie,
|
||||
mention_to_candidates_dict,
|
||||
)
|
||||
|
||||
|
||||
def get_end_to_end_prefix_allowed_tokens_fn_fairseq(
|
||||
model,
|
||||
sentences: List[str],
|
||||
start_mention_token="{",
|
||||
end_mention_token="}",
|
||||
start_entity_token="[",
|
||||
end_entity_token="]",
|
||||
mention_trie: Trie = None,
|
||||
candidates_trie: Trie = None,
|
||||
mention_to_candidates_dict: Dict[str, List[str]] = None,
|
||||
):
|
||||
return _get_end_to_end_prefix_allowed_tokens_fn(
|
||||
lambda x: model.encode(x).tolist(),
|
||||
lambda x: model.decode(torch.tensor(x)),
|
||||
model.model.decoder.dictionary.bos(),
|
||||
model.model.decoder.dictionary.pad(),
|
||||
model.model.decoder.dictionary.eos(),
|
||||
len(model.model.decoder.dictionary),
|
||||
sentences,
|
||||
start_mention_token,
|
||||
end_mention_token,
|
||||
start_entity_token,
|
||||
end_entity_token,
|
||||
mention_trie,
|
||||
candidates_trie,
|
||||
mention_to_candidates_dict,
|
||||
)
|
||||
|
||||
|
||||
def _get_end_to_end_prefix_allowed_tokens_fn(
|
||||
encode_fn,
|
||||
decode_fn,
|
||||
bos_token_id,
|
||||
pad_token_id,
|
||||
eos_token_id,
|
||||
vocabulary_length,
|
||||
sentences: List[str],
|
||||
start_mention_token="{",
|
||||
end_mention_token="}",
|
||||
start_entity_token="[",
|
||||
end_entity_token="]",
|
||||
mention_trie: Trie = None,
|
||||
candidates_trie: Trie = None,
|
||||
mention_to_candidates_dict: Dict[str, List[str]] = None,
|
||||
):
|
||||
assert not (
|
||||
candidates_trie is not None and mention_to_candidates_dict is not None
|
||||
), "`candidates_trie` and `mention_to_candidates_dict` cannot be both != `None`"
|
||||
|
||||
codes = {}
|
||||
codes["EOS"] = eos_token_id
|
||||
codes["BOS"] = bos_token_id
|
||||
|
||||
keyword_codes = {k: encode_fn(" {}".format(k))[1] for k in keyword}
|
||||
keyword_codes['wselect'] = encode_fn("{}".format('select'))[1]
|
||||
|
||||
def prefix_allowed_tokens_fn(batch_id, sent):
|
||||
sent = sent.tolist()
|
||||
trie_out = get_trie_schema(sent)
|
||||
return trie_out
|
||||
|
||||
def get_trie_schema(sent):
|
||||
pointer_start = get_keyword_mention(sent)
|
||||
keyword_rnt = list(keyword_codes.values())
|
||||
|
||||
if pointer_start + 1 < len(sent) and pointer_start != -1:
|
||||
ment_next = mention_trie.get(sent[pointer_start + 1:])
|
||||
if codes["EOS"] in ment_next:
|
||||
return ment_next + keyword_rnt
|
||||
else:
|
||||
return ment_next
|
||||
else:
|
||||
ment_next = mention_trie.get([])
|
||||
return ment_next + keyword_rnt + [codes["EOS"]]
|
||||
|
||||
def get_keyword_mention(sent):
|
||||
pointer_start = -1
|
||||
for i, e in enumerate(sent):
|
||||
if e in keyword_codes.values():
|
||||
pointer_start = i
|
||||
return pointer_start
|
||||
|
||||
return prefix_allowed_tokens_fn
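
# Usage sketch (hypothetical; `model`, `sentences` and `schema_trie` are placeholders): the callback
# built above can be handed to beam search so that only SQL over the schema items stored in the
# mention trie (plus the SQL keywords) is generated. The keyword argument below is the one exposed
# by the pinned `fixing_prefix_allowed_tokens_fn` fairseq branch:
#
#     prefix_fn = get_end_to_end_prefix_allowed_tokens_fn_fairseq(
#         model, sentences, mention_trie=schema_trie)
#     model.sample(sentences, prefix_allowed_tokens_fn=prefix_fn)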
@@ -0,0 +1,157 @@
# Copyright (c) Facebook, Inc. and Microsoft Corporation.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
from fairseq import search, utils
|
||||
from fairseq.models.bart import BARTHubInterface, BARTModel
|
||||
from omegaconf import open_dict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GENREHubInterface(BARTHubInterface):
|
||||
def sample(
|
||||
self,
|
||||
sentences: List[str],
|
||||
beam: int = 5,
|
||||
verbose: bool = False,
|
||||
text_to_id=None,
|
||||
marginalize=False,
|
||||
marginalize_lenpen=0.5,
|
||||
max_len_a=1024,
|
||||
max_len_b=1024,
|
||||
**kwargs,
|
||||
) -> List[str]:
|
||||
|
||||
if isinstance(sentences, str):
|
||||
return self.sample([sentences], beam=beam, verbose=verbose, **kwargs)[0]
|
||||
tokenized_sentences = [self.encode(sentence) for sentence in sentences]
|
||||
|
||||
batched_hypos = self.generate(
|
||||
tokenized_sentences,
|
||||
beam,
|
||||
verbose,
|
||||
max_len_a=max_len_a,
|
||||
max_len_b=max_len_b,
|
||||
**kwargs,
|
||||
)
|
||||
outputs = [
|
||||
[
|
||||
{"text": self.decode(hypo["tokens"]), "score": hypo["score"]}
|
||||
for hypo in hypos
|
||||
]
|
||||
for hypos in batched_hypos
|
||||
]
|
||||
if text_to_id:
|
||||
outputs = [
|
||||
[{**hypo, "id": text_to_id(hypo["text"])} for hypo in hypos]
|
||||
for hypos in outputs
|
||||
]
|
||||
|
||||
if marginalize:
|
||||
for (i, hypos), hypos_tok in zip(enumerate(outputs), batched_hypos):
|
||||
outputs_dict = defaultdict(list)
|
||||
for hypo, hypo_tok in zip(hypos, hypos_tok):
|
||||
outputs_dict[hypo["id"]].append(
|
||||
{**hypo, "len": len(hypo_tok["tokens"])}
|
||||
)
|
||||
|
||||
outputs[i] = sorted(
|
||||
[
|
||||
{
|
||||
"id": _id,
|
||||
"texts": [hypo["text"] for hypo in hypos],
|
||||
"scores": torch.stack(
|
||||
[hypo["score"] for hypo in hypos]
|
||||
),
|
||||
"score": torch.stack(
|
||||
[
|
||||
hypo["score"]
|
||||
* hypo["len"]
|
||||
/ (hypo["len"] ** marginalize_lenpen)
|
||||
for hypo in hypos
|
||||
]
|
||||
).logsumexp(-1),
|
||||
}
|
||||
for _id, hypos in outputs_dict.items()
|
||||
],
|
||||
key=lambda x: x["score"],
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
def generate(self, *args, **kwargs) -> List[List[Dict[str, torch.Tensor]]]:
|
||||
return super(BARTHubInterface, self).generate(*args, **kwargs)
|
||||
|
||||
def encode(self, sentence) -> torch.LongTensor:
|
||||
tokens = super(BARTHubInterface, self).encode(sentence)
|
||||
tokens[
|
||||
tokens >= len(self.task.target_dictionary)
|
||||
] = self.task.target_dictionary.unk_index
|
||||
if tokens[0] != self.task.target_dictionary.bos_index:
|
||||
return torch.cat(
|
||||
(torch.tensor([self.task.target_dictionary.bos_index]), tokens)
|
||||
)
|
||||
else:
|
||||
return tokens
|
||||
|
||||
|
||||
class GENRE(BARTModel):
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
model_name_or_path,
|
||||
checkpoint_file="model.pt",
|
||||
data_name_or_path=".",
|
||||
bpe="gpt2",
|
||||
**kwargs,
|
||||
):
|
||||
from fairseq import hub_utils
|
||||
|
||||
x = hub_utils.from_pretrained(
|
||||
model_name_or_path,
|
||||
checkpoint_file,
|
||||
data_name_or_path,
|
||||
archive_map=cls.hub_models(),
|
||||
bpe=bpe,
|
||||
load_checkpoint_heads=True,
|
||||
**kwargs,
|
||||
)
|
||||
return GENREHubInterface(x["args"], x["task"], x["models"][0])
|
||||
|
||||
|
||||
class mGENRE(BARTModel):
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
model_name_or_path,
|
||||
sentencepiece_model="sentence.bpe.model",
|
||||
checkpoint_file="model.pt",
|
||||
data_name_or_path=".",
|
||||
bpe="sentencepiece",
|
||||
layernorm_embedding=True,
|
||||
**kwargs,
|
||||
):
|
||||
from fairseq import hub_utils
|
||||
|
||||
x = hub_utils.from_pretrained(
|
||||
model_name_or_path,
|
||||
checkpoint_file,
|
||||
data_name_or_path,
|
||||
archive_map=cls.hub_models(),
|
||||
bpe=bpe,
|
||||
load_checkpoint_heads=True,
|
||||
sentencepiece_model=os.path.join(model_name_or_path, sentencepiece_model),
|
||||
**kwargs,
|
||||
)
|
||||
return GENREHubInterface(x["args"], x["task"], x["models"][0])
@@ -0,0 +1,65 @@
# Copyright (c) Facebook, Inc. and Microsoft Corporation.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
from genre.utils import chunk_it
|
||||
from transformers import BartForConditionalGeneration, BartTokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GENREHubInterface(BartForConditionalGeneration):
|
||||
def sample(
|
||||
self, sentences: List[str], num_beams: int = 5, num_return_sequences=5, **kwargs
|
||||
) -> List[str]:
|
||||
|
||||
input_args = {
|
||||
k: v.to(self.device)
|
||||
for k, v in self.tokenizer.batch_encode_plus(
|
||||
sentences, padding=True, return_tensors="pt"
|
||||
).items()
|
||||
}
|
||||
|
||||
outputs = self.generate(
|
||||
**input_args,
|
||||
min_length=0,
|
||||
max_length=1024,
|
||||
num_beams=num_beams,
|
||||
num_return_sequences=num_return_sequences,
|
||||
output_scores=True,
|
||||
return_dict_in_generate=True,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
return chunk_it(
|
||||
[
|
||||
{
|
||||
"text": text,
|
||||
"logprob": score,
|
||||
}
|
||||
for text, score in zip(
|
||||
self.tokenizer.batch_decode(
|
||||
outputs.sequences, skip_special_tokens=True
|
||||
),
|
||||
outputs.sequences_scores,
|
||||
)
|
||||
],
|
||||
num_return_sequences,
|
||||
)
|
||||
|
||||
def encode(self, sentence):
|
||||
return self.tokenizer.encode(sentence, return_tensors="pt")[0]
|
||||
|
||||
|
||||
class GENRE(BartForConditionalGeneration):
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_name_or_path):
|
||||
model = GENREHubInterface.from_pretrained(model_name_or_path)
|
||||
model.tokenizer = BartTokenizer.from_pretrained(model_name_or_path)
|
||||
return model
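
# Minimal usage sketch (the checkpoint path is an assumption; any BART checkpoint with a matching
# tokenizer follows the same pattern):
#
#     model = GENRE.from_pretrained("./models/mark-bart").eval()
#     model.sample(["question with schema marks ..."], num_beams=5, num_return_sequences=5)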
@@ -0,0 +1,189 @@
# Copyright (c) Facebook, Inc. and Microsoft Corporation.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
try:
|
||||
import marisa_trie
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
|
||||
|
||||
class Trie(object):
|
||||
def __init__(self, sequences: List[List[int]] = []):
|
||||
self.trie_dict = {}
|
||||
self.len = 0
|
||||
if sequences:
|
||||
for sequence in sequences:
|
||||
Trie._add_to_trie(sequence, self.trie_dict)
|
||||
self.len += 1
|
||||
|
||||
self.append_trie = None
|
||||
self.bos_token_id = None
|
||||
|
||||
def append(self, trie, bos_token_id):
|
||||
self.append_trie = trie
|
||||
self.bos_token_id = bos_token_id
|
||||
|
||||
def add(self, sequence: List[int]):
|
||||
Trie._add_to_trie(sequence, self.trie_dict)
|
||||
self.len += 1
|
||||
|
||||
def get(self, prefix_sequence: List[int]):
|
||||
return Trie._get_from_trie(
|
||||
prefix_sequence, self.trie_dict, self.append_trie, self.bos_token_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load_from_dict(trie_dict):
|
||||
trie = Trie()
|
||||
trie.trie_dict = trie_dict
|
||||
trie.len = sum(1 for _ in trie)
|
||||
return trie
|
||||
|
||||
@staticmethod
|
||||
def _add_to_trie(sequence: List[int], trie_dict: Dict):
|
||||
if sequence:
|
||||
if sequence[0] not in trie_dict:
|
||||
trie_dict[sequence[0]] = {}
|
||||
Trie._add_to_trie(sequence[1:], trie_dict[sequence[0]])
|
||||
|
||||
@staticmethod
|
||||
def _get_from_trie(
|
||||
prefix_sequence: List[int],
|
||||
trie_dict: Dict,
|
||||
append_trie=None,
|
||||
bos_token_id: int = None,
|
||||
):
|
||||
if len(prefix_sequence) == 0:
|
||||
output = list(trie_dict.keys())
|
||||
if append_trie and bos_token_id in output:
|
||||
output.remove(bos_token_id)
|
||||
output += list(append_trie.trie_dict.keys())
|
||||
return output
|
||||
elif prefix_sequence[0] in trie_dict:
|
||||
return Trie._get_from_trie(
|
||||
prefix_sequence[1:],
|
||||
trie_dict[prefix_sequence[0]],
|
||||
append_trie,
|
||||
bos_token_id,
|
||||
)
|
||||
else:
|
||||
if append_trie:
|
||||
return append_trie.get(prefix_sequence)
|
||||
else:
|
||||
return []
|
||||
|
||||
def __iter__(self):
|
||||
def _traverse(prefix_sequence, trie_dict):
|
||||
if trie_dict:
|
||||
for next_token in trie_dict:
|
||||
yield from _traverse(
|
||||
prefix_sequence + [next_token], trie_dict[next_token]
|
||||
)
|
||||
else:
|
||||
yield prefix_sequence
|
||||
|
||||
return _traverse([], self.trie_dict)
|
||||
|
||||
def __len__(self):
|
||||
return self.len
|
||||
|
||||
def __getitem__(self, value):
|
||||
return self.get(value)
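
# Example (token ids are made up): the trie maps an already-generated prefix to the token ids
# that may follow it.
#
#     trie = Trie([[2, 10, 15], [2, 10, 16], [2, 11]])
#     trie.get([])        # -> [2]
#     trie.get([2])       # -> [10, 11]
#     trie.get([2, 10])   # -> [15, 16]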
|
||||
|
||||
|
||||
class MarisaTrie(object):
|
||||
def __init__(
|
||||
self,
|
||||
sequences: List[List[int]] = [],
|
||||
cache_fist_branch=True,
|
||||
max_token_id=256001,
|
||||
):
|
||||
|
||||
self.int2char = [chr(i) for i in range(min(max_token_id, 55000))] + (
|
||||
[chr(i) for i in range(65000, max_token_id + 10000)]
|
||||
if max_token_id >= 55000
|
||||
else []
|
||||
)
|
||||
self.char2int = {self.int2char[i]: i for i in range(max_token_id)}
|
||||
|
||||
self.cache_fist_branch = cache_fist_branch
|
||||
if self.cache_fist_branch:
|
||||
self.zero_iter = list({sequence[0] for sequence in sequences})
|
||||
assert len(self.zero_iter) == 1
|
||||
self.first_iter = list({sequence[1] for sequence in sequences})
|
||||
|
||||
self.trie = marisa_trie.Trie(
|
||||
"".join([self.int2char[i] for i in sequence]) for sequence in sequences
|
||||
)
|
||||
|
||||
def get(self, prefix_sequence: List[int]):
|
||||
if self.cache_fist_branch and len(prefix_sequence) == 0:
|
||||
return self.zero_iter
|
||||
elif (
|
||||
self.cache_fist_branch
|
||||
and len(prefix_sequence) == 1
|
||||
and self.zero_iter == prefix_sequence
|
||||
):
|
||||
return self.first_iter
|
||||
else:
|
||||
key = "".join([self.int2char[i] for i in prefix_sequence])
|
||||
return list(
|
||||
{
|
||||
self.char2int[e[len(key)]]
|
||||
for e in self.trie.keys(key)
|
||||
if len(e) > len(key)
|
||||
}
|
||||
)
|
||||
|
||||
def __iter__(self):
|
||||
for sequence in self.trie.iterkeys():
|
||||
yield [self.char2int[e] for e in sequence]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.trie)
|
||||
|
||||
def __getitem__(self, value):
|
||||
return self.get(value)
|
||||
|
||||
|
||||
class DummyTrieMention(object):
|
||||
def __init__(self, return_values):
|
||||
self._return_values = return_values
|
||||
|
||||
def get(self, indices=None):
|
||||
return self._return_values
|
||||
|
||||
|
||||
class DummyTrieEntity(object):
|
||||
def __init__(self, return_values, codes):
|
||||
self._return_values = list(
|
||||
set(return_values).difference(
|
||||
set(
|
||||
codes[e]
|
||||
for e in (
|
||||
"start_mention_token",
|
||||
"end_mention_token",
|
||||
"start_entity_token",
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
self._codes = codes
|
||||
|
||||
def get(self, indices, depth=0):
|
||||
if len(indices) == 0 and depth == 0:
|
||||
return self._codes["end_mention_token"]
|
||||
elif len(indices) == 0 and depth == 1:
|
||||
return self._codes["start_entity_token"]
|
||||
elif len(indices) == 0:
|
||||
return self._return_values
|
||||
elif len(indices) == 1 and indices[0] == self._codes["end_entity_token"]:
|
||||
return self._codes["EOS"]
|
||||
else:
|
||||
return self.get(indices[1:], depth=depth + 1)
|
@@ -0,0 +1,52 @@
import argparse
|
||||
import stanza
|
||||
from unisar.api import UnisarAPI
|
||||
|
||||
|
||||
class Interactive(object):
|
||||
|
||||
def __init__(self, Unisar: UnisarAPI):
|
||||
self.unisar = Unisar
|
||||
|
||||
def ask_any_question(self, question, db_id):
|
||||
results = self.unisar.infer_query(question, db_id)
|
||||
|
||||
print('input:', results['slml_question'])
|
||||
print(f'"pred:" {results["predict_sql"]} ({results["score"]})')
|
||||
# try:
|
||||
# results = self.unisar.execute(results['query'])
|
||||
# print(results)
|
||||
# except Exception as e:
|
||||
# print(str(e))
|
||||
|
||||
def show_schema(self, db_id):
|
||||
for table in self.unisar.schema[db_id].values():
|
||||
print("Table", f"{table.name}")
|
||||
for column in table.columns:
|
||||
print(" Column", f"{column.name}")
|
||||
|
||||
def run(self, db_id):
|
||||
self.show_schema(db_id)
|
||||
# self.ask_any_question('Tell me the name about organization', db_id)
|
||||
while True:
|
||||
question = input("Ask a question: ")
|
||||
self.ask_any_question(question, db_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--logdir", default='./models/spider_sl')
|
||||
parser.add_argument("--db_id", default='student_1')
|
||||
parser.add_argument(
|
||||
"--db-path", default='./data/spider/database',
|
||||
help="The path to the sqlite database or csv file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--schema-path", default='./data/spider/tables.json',
|
||||
help="The path to the tables.json file with human-readable database schema."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
stanza_model = stanza.Pipeline(lang='en', processors='tokenize,pos,lemma')
|
||||
interactive = Interactive(UnisarAPI(args.logdir, args.db_path, args.schema_path, stanza_model))
|
||||
interactive.run(args.db_id)
@@ -0,0 +1,136 @@
#!/usr/bin/env python
|
||||
# Copyright (c) Facebook, Inc. and Microsoft Corporation.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import sys
|
||||
from collections import Counter
|
||||
from multiprocessing import Pool
|
||||
# from .bpe_utils import get_encoder
|
||||
from fairseq.data.encoders.gpt2_bpe import get_encoder
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Helper script to encode raw text with the GPT-2 BPE using multiple processes.
|
||||
|
||||
The encoder.json and vocab.bpe files can be obtained here:
|
||||
- https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json
|
||||
- https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe
|
||||
"""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--encoder-json",
|
||||
help="path to encoder.json",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vocab-bpe",
|
||||
type=str,
|
||||
help="path to vocab.bpe",
|
||||
)
|
||||
# parser.add_argument(
|
||||
# "--special-token",
|
||||
# type=str,
|
||||
# help="path to special tokens split by \n"
|
||||
# )
|
||||
parser.add_argument(
|
||||
"--inputs",
|
||||
nargs="+",
|
||||
default=["-"],
|
||||
help="input files to filter/encode",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--outputs",
|
||||
nargs="+",
|
||||
default=["-"],
|
||||
help="path to save encoded outputs",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--keep-empty",
|
||||
action="store_true",
|
||||
help="keep empty lines",
|
||||
)
|
||||
parser.add_argument("--workers", type=int, default=20)
|
||||
args = parser.parse_args()
|
||||
|
||||
assert len(args.inputs) == len(
|
||||
args.outputs
|
||||
), "number of input and output paths should match"
|
||||
|
||||
with contextlib.ExitStack() as stack:
|
||||
inputs = [
|
||||
stack.enter_context(open(input, "r", encoding="utf-8"))
|
||||
if input != "-"
|
||||
else sys.stdin
|
||||
for input in args.inputs
|
||||
]
|
||||
outputs = [
|
||||
stack.enter_context(open(output, "w", encoding="utf-8"))
|
||||
if output != "-"
|
||||
else sys.stdout
|
||||
for output in args.outputs
|
||||
]
|
||||
|
||||
encoder = MultiprocessingEncoder(args)
|
||||
pool = Pool(args.workers, initializer=encoder.initializer)
|
||||
encoded_lines = pool.imap(encoder.encode_lines, zip(*inputs), 100)
|
||||
|
||||
stats = Counter()
|
||||
for i, (filt, enc_lines) in enumerate(encoded_lines, start=1):
|
||||
if filt == "PASS":
|
||||
for enc_line, output_h in zip(enc_lines, outputs):
|
||||
print(enc_line, file=output_h)
|
||||
else:
|
||||
stats["num_filtered_" + filt] += 1
|
||||
if i % 10000 == 0:
|
||||
print("processed {} lines".format(i), file=sys.stderr)
|
||||
|
||||
for k, v in stats.most_common():
|
||||
print("[{}] filtered {} lines".format(k, v), file=sys.stderr)
|
||||
|
||||
|
||||
class MultiprocessingEncoder(object):
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
|
||||
def initializer(self):
|
||||
global bpe
|
||||
bpe = get_encoder(self.args.encoder_json, self.args.vocab_bpe)
|
||||
# bpe = get_encoder(self.args.encoder_json, self.args.vocab_bpe, self.args.special_token)
|
||||
|
||||
def encode(self, line):
|
||||
global bpe
|
||||
ids = bpe.encode(line)
|
||||
return list(map(str, ids))
|
||||
|
||||
def decode(self, tokens):
|
||||
global bpe
|
||||
return bpe.decode(tokens)
|
||||
|
||||
def encode_lines(self, lines):
|
||||
"""
|
||||
Encode a set of lines. All lines will be encoded together.
|
||||
"""
|
||||
enc_lines = []
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if len(line) == 0 and not self.args.keep_empty:
|
||||
return ["EMPTY", None]
|
||||
tokens = self.encode(line)
|
||||
enc_lines.append(" ".join(tokens))
|
||||
return ["PASS", enc_lines]
|
||||
|
||||
def decode_lines(self, lines):
|
||||
dec_lines = []
|
||||
for line in lines:
|
||||
tokens = map(int, line.strip().split())
|
||||
dec_lines.append(self.decode(tokens))
|
||||
return ["PASS", dec_lines]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
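
# Invocation sketch (the script and data file names are assumptions; the flags come from the
# argument parser above, and encoder.json / vocab.bpe can be fetched from the URLs in main()):
#
#     python multiprocessing_bpe_encoder.py \
#         --encoder-json encoder.json --vocab-bpe vocab.bpe \
#         --inputs dev.src --outputs dev.bpe.src \
#         --workers 20 --keep-empty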
@@ -0,0 +1,11 @@
dataclasses
|
||||
tqdm
|
||||
pydantic
|
||||
sqlparse
|
||||
unidecode
|
||||
networkx
|
||||
stanza
|
||||
nltk
|
||||
vocab
|
||||
sentencepiece
|
||||
torch
@@ -0,0 +1,23 @@
#!/bin/bash
|
||||
|
||||
#requirement:
|
||||
#./data/spider
|
||||
#./BART-large
|
||||
|
||||
# data/spider -> data/spider_schema_linking_tag
|
||||
python step1_schema_linking.py --dataset=spider
|
||||
|
||||
# data/spider_schema_linking_tag -> dataset_post/spider_sl
|
||||
python step2_serialization.py
|
||||
|
||||
###training
|
||||
python train.py \
|
||||
--dataset_path ./dataset_post/spider_sl/bin/ \
|
||||
--exp_name spider_sl_v1 \
|
||||
--models_path ./models \
|
||||
--total_num_update 10000 \
|
||||
--max_tokens 1024 \
|
||||
--bart_model_path ./data/BART-large \
|
||||
|
||||
###evaluate
|
||||
python step3_evaluate.py --constrain
@@ -0,0 +1,257 @@
import re
|
||||
from collections import Counter, defaultdict
|
||||
from typing import Dict, Tuple, List
|
||||
|
||||
from unidecode import unidecode
|
||||
|
||||
from semparse.sql.spider_utils import TableColumn, read_dataset_schema, read_dataset_values
|
||||
|
||||
# == stop words that will be omitted by ContextGenerator
|
||||
STOP_WORDS = {"", "", "all", "being", "-", "over", "through", "yourselves", "its", "before",
|
||||
"hadn", "with", "had", ",", "should", "to", "only", "under", "ours", "has", "ought", "do",
|
||||
"them", "his", "than", "very", "cannot", "they", "not", "during", "yourself", "him",
|
||||
"nor", "did", "didn", "'ve", "this", "she", "each", "where", "because", "doing", "some", "we", "are",
|
||||
"further", "ourselves", "out", "what", "for", "weren", "does", "above", "between", "mustn", "?",
|
||||
"be", "hasn", "who", "were", "here", "shouldn", "let", "hers", "by", "both", "about", "couldn",
|
||||
"of", "could", "against", "isn", "or", "own", "into", "while", "whom", "down", "wasn", "your",
|
||||
"from", "her", "their", "aren", "there", "been", ".", "few", "too", "wouldn", "themselves",
|
||||
":", "was", "until", "more", "himself", "on", "but", "don", "herself", "haven", "those", "he",
|
||||
"me", "myself", "these", "up", ";", "below", "'re", "can", "theirs", "my", "and", "would", "then",
|
||||
"is", "am", "it", "doesn", "an", "as", "itself", "at", "have", "in", "any", "if", "!",
|
||||
"again", "'ll", "no", "that", "when", "same", "how", "other", "which", "you", "many", "shan",
|
||||
"'t", "'s", "our", "after", "most", "'d", "such", "'m", "why", "a", "off", "i", "yours", "so",
|
||||
"the", "having", "once"}
|
||||
|
||||
digits_list = [str(i) for i in range(10)]
|
||||
|
||||
class SpiderDBContext:
|
||||
schemas = {}
|
||||
db_knowledge_graph = {}
|
||||
db_tables_data = {}
|
||||
|
||||
def __init__(self, db_id: str, utterance: str, tables_file: str, dataset_path: str, stanza_model=None, schemas=None,
|
||||
original_utterance=None):
|
||||
self.dataset_path = dataset_path
|
||||
self.tables_file = tables_file
|
||||
self.db_id = db_id
|
||||
self.utterance = utterance
|
||||
self.tokenized_utterance = utterance
|
||||
self.stanza_model = stanza_model
|
||||
self.original_utterance = original_utterance if original_utterance is not None else utterance
|
||||
|
||||
if schemas is not None:
|
||||
SpiderDBContext.schemas = schemas
|
||||
elif db_id not in SpiderDBContext.schemas:
|
||||
SpiderDBContext.schemas = read_dataset_schema(self.tables_file, self.stanza_model)
|
||||
|
||||
self.schema = SpiderDBContext.schemas[db_id]
|
||||
|
||||
@staticmethod
|
||||
def entity_key_for_column(table_name: str, column: TableColumn) -> str:
|
||||
# NOTE: this early return keeps the plain "table@column" key; the typed-key logic below is unreachable.
return f"{table_name.lower()}@{column.name.lower()}"
|
||||
if column.foreign_key is not None:
|
||||
column_type = "foreign"
|
||||
elif column.is_primary_key:
|
||||
column_type = "primary"
|
||||
else:
|
||||
column_type = column.column_type
|
||||
return f"column:{column_type.lower()}:{table_name.lower()}:{column.name.lower()}"
|
||||
|
||||
def get_db_knowledge_graph(self, db_id: str):
|
||||
db_schema = self.schema
|
||||
tables = db_schema.values()
|
||||
|
||||
if db_id not in self.db_tables_data:
|
||||
self.db_tables_data[db_id] = read_dataset_values(db_id, self.dataset_path, tables)
|
||||
|
||||
tables_data = self.db_tables_data[db_id]
|
||||
string_column_mapping: Dict[str, set] = defaultdict(set)
|
||||
|
||||
for table, table_data in tables_data.items():
|
||||
for table_row in table_data:
|
||||
for column, cell_value in zip(db_schema[table.name].columns, table_row):
|
||||
if column.column_type == 'text' and type(cell_value) is str:
|
||||
cell_value_normalized = self.normalize_string(cell_value)
|
||||
column_key = self.entity_key_for_column(table.name, column)
|
||||
string_column_mapping[cell_value_normalized].add(column_key)
|
||||
# for key in string_column_mapping:
|
||||
# string_column_mapping[key]=list(string_column_mapping[key])
|
||||
|
||||
string_entities = self.get_entities_from_question(string_column_mapping)
|
||||
|
||||
value_match=[]
|
||||
value_alignment = []
|
||||
for item in string_entities:
|
||||
value_match+=item['token_in_columns']
|
||||
value_alignment.append(item['alignment'])
|
||||
value_match = list(set(value_match))
|
||||
|
||||
r_schemas={}
|
||||
for table in db_schema:
|
||||
r_schemas["{0}".format(db_schema[table].name).lower()] = db_schema[table].lemma.lower()
|
||||
for column in db_schema[table].columns:
|
||||
r_schemas[f"{db_schema[table].name}@{column.name}".lower()] = column.lemma.strip('is ').lower()
|
||||
|
||||
question_tokens = [t for t in self.tokenized_utterance]
|
||||
schema_counter = Counter()
|
||||
|
||||
partial_match = []
|
||||
exact_match = []
|
||||
|
||||
for r_k, r_s in r_schemas.items():
|
||||
schema_counter[r_s] = 0
|
||||
#exact match
|
||||
if r_s in self.tokenized_utterance and r_s not in STOP_WORDS:
|
||||
schema_counter[r_s] += 2
|
||||
#partial_match
|
||||
else:
|
||||
for tok in r_s.split(' '):
|
||||
if tok in question_tokens and tok not in STOP_WORDS:
|
||||
schema_counter[r_s]+=1
|
||||
continue
|
||||
|
||||
for ques_tok in question_tokens:
|
||||
if tok in STOP_WORDS or ques_tok in STOP_WORDS or \
|
||||
len(tok)<=3 or len(ques_tok)<=3:
|
||||
continue
|
||||
if ques_tok in tok or tok in ques_tok:
|
||||
schema_counter[r_s] += 1
|
||||
|
||||
if schema_counter[r_s]>=2:
|
||||
exact_match.append(r_k)
|
||||
elif schema_counter[r_s]==1:
|
||||
partial_match.append(r_k)
|
||||
|
||||
return value_match, value_alignment, exact_match, partial_match
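
# Shapes of the four return values (column keys follow entity_key_for_column, i.e. "table@column"):
#   value_match     - column keys whose cell values appear in the question
#   value_alignment - for each matched value, its (question_token, cell_value) pairs
#   exact_match     - schema items strongly matched against the question (full name or multi-token overlap)
#   partial_match   - schema items that only share a single token with the question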
|
||||
|
||||
|
||||
def _string_in_table(self, candidate: str,
|
||||
string_column_mapping: Dict[str, set]) -> List[str]:
|
||||
"""
|
||||
Checks if the string occurs in the table, and if it does, returns the names of the columns
|
||||
under which it occurs. If it does not, returns an empty list.
|
||||
"""
|
||||
candidate_column_names: List[str] = []
|
||||
alignment = []
|
||||
# First check if the entire candidate occurs as a cell.
|
||||
candidate= candidate.strip('-_"\'')
|
||||
if candidate in string_column_mapping and candidate not in digits_list:
|
||||
candidate_column_names = string_column_mapping[candidate]
|
||||
alignment.append((candidate,candidate))
|
||||
# # If not, check if it is a substring of any cell value.
|
||||
# if not candidate_column_names:
|
||||
# for cell_value, column_names in string_column_mapping.items():
|
||||
# if candidate in re.split(' |_|:',
|
||||
# cell_value) and candidate not in STOP_WORDS and candidate not in digits_list:
|
||||
# candidate_column_names.extend(column_names)
|
||||
# alignment.append((candidate, cell_value))
|
||||
candidate_column_names = list(set(candidate_column_names))
|
||||
return candidate_column_names, alignment
|
||||
|
||||
def get_entities_from_question(self, string_column_mapping: Dict[str, set]) -> List[Tuple[str, str]]:
|
||||
entity_data = []
|
||||
|
||||
for cell_value, column_names in string_column_mapping.items():
|
||||
if (cell_value.replace('_', ' ') in ' '.join(self.utterance) or cell_value.replace('_',
|
||||
' ') in self.original_utterance) \
|
||||
and len(re.split('_', cell_value)) >= 2:
|
||||
entity_data.append({'value': cell_value,
|
||||
'token_start': 0,
|
||||
'token_end': 0,
|
||||
'alignment': [(cell_value, cell_value)],
|
||||
'token_in_columns': list(set(column_names))})
|
||||
|
||||
for i, token in enumerate(self.tokenized_utterance):
|
||||
token_text = token
|
||||
if token_text in STOP_WORDS:
|
||||
continue
|
||||
normalized_token_text = token_text
|
||||
# normalized_token_text = self.normalize_string(token_text)
|
||||
if not normalized_token_text:
|
||||
continue
|
||||
token_columns, alignment = self._string_in_table(normalized_token_text, string_column_mapping)
|
||||
if token_columns:
|
||||
entity_data.append({'value': normalized_token_text,
|
||||
'token_start': i,
|
||||
'token_end': i+1,
|
||||
'alignment': alignment,
|
||||
'token_in_columns': token_columns})
|
||||
|
||||
return entity_data
|
||||
|
||||
@staticmethod
|
||||
def normalize_string(string: str) -> str:
|
||||
"""
|
||||
These are the transformation rules used to normalize cell in column names in Sempre. See
|
||||
``edu.stanford.nlp.sempre.tables.StringNormalizationUtils.characterNormalize`` and
|
||||
``edu.stanford.nlp.sempre.tables.TableTypeSystem.canonicalizeName``. We reproduce those
|
||||
rules here to normalize and canonicalize cells and columns in the same way so that we can
|
||||
match them against constants in logical forms appropriately.
|
||||
"""
|
||||
# Normalization rules from Sempre
|
||||
# \u201A -> ,
|
||||
string = re.sub("‚", ",", string)
|
||||
string = re.sub("„", ",,", string)
|
||||
string = re.sub("[·・]", ".", string)
|
||||
string = re.sub("…", "...", string)
|
||||
string = re.sub("ˆ", "^", string)
|
||||
string = re.sub("˜", "~", string)
|
||||
string = re.sub("‹", "<", string)
|
||||
string = re.sub("›", ">", string)
|
||||
string = re.sub("[‘’´`]", "'", string)
|
||||
string = re.sub("[“”«»]", "\"", string)
|
||||
string = re.sub("[•†‡²³]", "", string)
|
||||
string = re.sub("[‐‑–—−]", "-", string)
|
||||
# Oddly, some unicode characters get converted to _ instead of being stripped. Not really
|
||||
# sure how sempre decides what to do with these... TODO(mattg): can we just get rid of the
|
||||
# need for this function somehow? It's causing a whole lot of headaches.
|
||||
string = re.sub("[ðø′″€⁄ªΣ]", "_", string)
|
||||
# This is such a mess. There isn't just a block of unicode that we can strip out, because
|
||||
# sometimes sempre just strips diacritics... We'll try stripping out a few separate
|
||||
# blocks, skipping the ones that sempre skips...
|
||||
string = re.sub("[\\u0180-\\u0210]", "", string).strip()
|
||||
string = re.sub("[\\u0220-\\uFFFF]", "", string).strip()
|
||||
string = string.replace("\\n", "_")
|
||||
string = re.sub("\\s+", " ", string)
|
||||
# Canonicalization rules from Sempre.
|
||||
string = re.sub("[^\\w]", "_", string)
|
||||
string = re.sub("_+", "_", string)
|
||||
string = re.sub("_$", "", string)
|
||||
return unidecode(string.lower())
|
||||
|
||||
def _expand_entities(self, question, entity_data, string_column_mapping: Dict[str, set]):
|
||||
new_entities = []
|
||||
for entity in entity_data:
|
||||
# to ensure the same strings are not used over and over
|
||||
if new_entities and entity['token_end'] <= new_entities[-1]['token_end']:
|
||||
continue
|
||||
current_start = entity['token_start']
|
||||
current_end = entity['token_end']
|
||||
current_token = entity['value']
|
||||
current_token_type = entity['token_type']
|
||||
current_token_columns = entity['token_in_columns']
|
||||
|
||||
while current_end < len(question):
|
||||
next_token = question[current_end].text
|
||||
next_token_normalized = self.normalize_string(next_token)
|
||||
if next_token_normalized == "":
|
||||
current_end += 1
|
||||
continue
|
||||
candidate = "%s_%s" %(current_token, next_token_normalized)
|
||||
candidate_columns, _ = self._string_in_table(candidate, string_column_mapping)  # returns (columns, alignment)
|
||||
candidate_columns = list(set(candidate_columns).intersection(current_token_columns))
|
||||
if not candidate_columns:
|
||||
break
|
||||
candidate_type = candidate_columns[0].split(":")[1]
|
||||
if candidate_type != current_token_type:
|
||||
break
|
||||
current_end += 1
|
||||
current_token = candidate
|
||||
current_token_columns = candidate_columns
|
||||
|
||||
new_entities.append({'token_start': current_start,
|
||||
'token_end': current_end,
|
||||
'value': current_token,
|
||||
'token_type': current_token_type,
|
||||
'token_in_columns': current_token_columns})
|
||||
return new_entities
@@ -0,0 +1,141 @@
# pylint: disable=anomalous-backslash-in-string
|
||||
"""
|
||||
A ``Text2SqlTableContext`` represents the SQL context in which an utterance appears
|
||||
for the any of the text2sql datasets, with the grammar and the valid actions.
|
||||
"""
|
||||
from typing import List, Dict
|
||||
|
||||
from dataset_readers.dataset_util.spider_utils import Table
|
||||
|
||||
|
||||
GRAMMAR_DICTIONARY = {}
|
||||
GRAMMAR_DICTIONARY["statement"] = ['(query ws iue ws query)', '(query ws)']
|
||||
GRAMMAR_DICTIONARY["iue"] = ['"intersect"', '"except"', '"union"']
|
||||
GRAMMAR_DICTIONARY["query"] = ['(ws select_core ws groupby_clause ws orderby_clause ws limit)',
|
||||
'(ws select_core ws groupby_clause ws orderby_clause)',
|
||||
'(ws select_core ws groupby_clause ws limit)',
|
||||
'(ws select_core ws orderby_clause ws limit)',
|
||||
'(ws select_core ws groupby_clause)',
|
||||
'(ws select_core ws orderby_clause)',
|
||||
'(ws select_core)']
|
||||
|
||||
GRAMMAR_DICTIONARY["select_core"] = ['(select_with_distinct ws select_results ws from_clause ws where_clause)',
|
||||
'(select_with_distinct ws select_results ws from_clause)',
|
||||
'(select_with_distinct ws select_results ws where_clause)',
|
||||
'(select_with_distinct ws select_results)']
|
||||
GRAMMAR_DICTIONARY["select_with_distinct"] = ['(ws "select" ws "distinct")', '(ws "select")']
|
||||
GRAMMAR_DICTIONARY["select_results"] = ['(ws select_result ws "," ws select_results)', '(ws select_result)']
|
||||
GRAMMAR_DICTIONARY["select_result"] = ['"*"', '(table_source ws ".*")',
|
||||
'expr', 'col_ref']
|
||||
|
||||
GRAMMAR_DICTIONARY["from_clause"] = ['(ws "from" ws table_source ws join_clauses)',
|
||||
'(ws "from" ws source)']
|
||||
GRAMMAR_DICTIONARY["join_clauses"] = ['(join_clause ws join_clauses)', 'join_clause']
|
||||
GRAMMAR_DICTIONARY["join_clause"] = ['"join" ws table_source ws "on" ws join_condition_clause']
|
||||
GRAMMAR_DICTIONARY["join_condition_clause"] = ['(join_condition ws "and" ws join_condition_clause)', 'join_condition']
|
||||
GRAMMAR_DICTIONARY["join_condition"] = ['ws col_ref ws "=" ws col_ref']
|
||||
GRAMMAR_DICTIONARY["source"] = ['(ws single_source ws "," ws source)', '(ws single_source)']
|
||||
GRAMMAR_DICTIONARY["single_source"] = ['table_source', 'source_subq']
|
||||
GRAMMAR_DICTIONARY["source_subq"] = ['("(" ws query ws ")")']
|
||||
# GRAMMAR_DICTIONARY["source_subq"] = ['("(" ws query ws ")" ws "as" ws name)', '("(" ws query ws ")")']
|
||||
GRAMMAR_DICTIONARY["limit"] = ['("limit" ws non_literal_number)']
|
||||
|
||||
GRAMMAR_DICTIONARY["where_clause"] = ['(ws "where" wsp expr ws where_conj)', '(ws "where" wsp expr)']
|
||||
GRAMMAR_DICTIONARY["where_conj"] = ['(ws "and" wsp expr ws where_conj)', '(ws "and" wsp expr)']
|
||||
|
||||
GRAMMAR_DICTIONARY["groupby_clause"] = ['(ws "group" ws "by" ws group_clause ws "having" ws expr)',
|
||||
'(ws "group" ws "by" ws group_clause)']
|
||||
GRAMMAR_DICTIONARY["group_clause"] = ['(ws expr ws "," ws group_clause)', '(ws expr)']
|
||||
|
||||
GRAMMAR_DICTIONARY["orderby_clause"] = ['ws "order" ws "by" ws order_clause']
|
||||
GRAMMAR_DICTIONARY["order_clause"] = ['(ordering_term ws "," ws order_clause)', 'ordering_term']
|
||||
GRAMMAR_DICTIONARY["ordering_term"] = ['(ws expr ws ordering)', '(ws expr)']
|
||||
GRAMMAR_DICTIONARY["ordering"] = ['(ws "asc")', '(ws "desc")']
|
||||
|
||||
GRAMMAR_DICTIONARY["col_ref"] = ['(table_name ws "." ws column_name)', 'column_name']
|
||||
GRAMMAR_DICTIONARY["table_source"] = ['(table_name ws "as" ws table_alias)', 'table_name']
|
||||
GRAMMAR_DICTIONARY["table_name"] = ["table_alias"]
|
||||
GRAMMAR_DICTIONARY["table_alias"] = ['"t1"', '"t2"', '"t3"', '"t4"']
|
||||
GRAMMAR_DICTIONARY["column_name"] = []
|
||||
|
||||
GRAMMAR_DICTIONARY["ws"] = ['~"\s*"i']
|
||||
GRAMMAR_DICTIONARY['wsp'] = ['~"\s+"i']
|
||||
|
||||
GRAMMAR_DICTIONARY["expr"] = ['in_expr',
|
||||
# Like expressions.
|
||||
'(value wsp "like" wsp string)',
|
||||
# Between expressions.
|
||||
'(value ws "between" wsp value ws "and" wsp value)',
|
||||
# Binary expressions.
|
||||
'(value ws binaryop wsp expr)',
|
||||
# Unary expressions.
|
||||
'(unaryop ws expr)',
|
||||
'source_subq',
|
||||
'value']
|
||||
GRAMMAR_DICTIONARY["in_expr"] = ['(value wsp "not" wsp "in" wsp string_set)',
|
||||
'(value wsp "in" wsp string_set)',
|
||||
'(value wsp "not" wsp "in" wsp expr)',
|
||||
'(value wsp "in" wsp expr)']
|
||||
|
||||
GRAMMAR_DICTIONARY["value"] = ['parenval', '"YEAR(CURDATE())"', 'number', 'boolean',
|
||||
'function', 'col_ref', 'string']
|
||||
GRAMMAR_DICTIONARY["parenval"] = ['"(" ws expr ws ")"']
|
||||
GRAMMAR_DICTIONARY["function"] = ['(fname ws "(" ws "distinct" ws arg_list_or_star ws ")")',
|
||||
'(fname ws "(" ws arg_list_or_star ws ")")']
|
||||
|
||||
GRAMMAR_DICTIONARY["arg_list_or_star"] = ['arg_list', '"*"']
|
||||
GRAMMAR_DICTIONARY["arg_list"] = ['(expr ws "," ws arg_list)', 'expr']
|
||||
# TODO(MARK): Massive hack, remove and modify the grammar accordingly
|
||||
# GRAMMAR_DICTIONARY["number"] = ['~"\d*\.?\d+"i', "'3'", "'4'"]
|
||||
GRAMMAR_DICTIONARY["non_literal_number"] = ['"1"', '"2"', '"3"', '"4"']
|
||||
GRAMMAR_DICTIONARY["number"] = ['ws "value" ws']
|
||||
GRAMMAR_DICTIONARY["string_set"] = ['ws "(" ws string_set_vals ws ")"']
|
||||
GRAMMAR_DICTIONARY["string_set_vals"] = ['(string ws "," ws string_set_vals)', 'string']
|
||||
# GRAMMAR_DICTIONARY["string"] = ['~"\'.*?\'"i']
|
||||
GRAMMAR_DICTIONARY["string"] = ['"\'" ws "value" ws "\'"']
|
||||
GRAMMAR_DICTIONARY["fname"] = ['"count"', '"sum"', '"max"', '"min"', '"avg"', '"all"']
|
||||
GRAMMAR_DICTIONARY["boolean"] = ['"true"', '"false"']
|
||||
|
||||
# TODO(MARK): This is not tight enough. AND/OR are strictly boolean value operators.
|
||||
GRAMMAR_DICTIONARY["binaryop"] = ['"+"', '"-"', '"*"', '"/"', '"="', '"!="', '"<>"',
|
||||
'">="', '"<="', '">"', '"<"', '"and"', '"or"', '"like"']
|
||||
GRAMMAR_DICTIONARY["unaryop"] = ['"+"', '"-"', '"not"', '"not"']
|
||||
|
||||
|
||||
def update_grammar_with_tables(grammar_dictionary: Dict[str, List[str]],
|
||||
schema: Dict[str, Table]) -> None:
|
||||
table_names = sorted([f'"{table.lower()}"' for table in
|
||||
list(schema.keys())], reverse=True)
|
||||
grammar_dictionary['table_name'] += table_names
|
||||
|
||||
all_columns = set()
|
||||
for table in schema.values():
|
||||
all_columns.update([f'"{table.name.lower()}@{column.name.lower()}"' for column in table.columns if column.name != '*'])
|
||||
sorted_columns = sorted([column for column in all_columns], reverse=True)
|
||||
grammar_dictionary['column_name'] += sorted_columns
|
||||
|
||||
|
||||
def update_grammar_to_be_table_names_free(grammar_dictionary: Dict[str, List[str]]):
|
||||
"""
|
||||
Remove table names from column names, remove aliases
|
||||
"""
|
||||
|
||||
grammar_dictionary["column_name"] = []
|
||||
grammar_dictionary["table_name"] = []
|
||||
grammar_dictionary["col_ref"] = ['column_name']
|
||||
grammar_dictionary["table_source"] = ['table_name']
|
||||
|
||||
del grammar_dictionary["table_alias"]
|
||||
|
||||
|
||||
def update_grammar_flip_joins(grammar_dictionary: Dict[str, List[str]]):
|
||||
"""
|
||||
Remove table names from column names, remove aliases
|
||||
"""
|
||||
|
||||
# using a simple rule such as join_clauses-> [(join_clauses ws join_clause), join_clause]
|
||||
# resulted in a max recursion error, so for now just using a predefined max
|
||||
# number of joins
|
||||
grammar_dictionary["join_clauses"] = ['(join_clauses_1 ws join_clause)', 'join_clause']
|
||||
grammar_dictionary["join_clauses_1"] = ['(join_clauses_2 ws join_clause)', 'join_clause']
|
||||
grammar_dictionary["join_clauses_2"] = ['(join_clause ws join_clause)', 'join_clause']
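
# Usage sketch (the `schema` dict of Table objects is assumed to come from
# dataset_readers.dataset_util.spider_utils):
#
#     grammar = {k: list(v) for k, v in GRAMMAR_DICTIONARY.items()}  # work on a copy
#     update_grammar_with_tables(grammar, schema)
#     # grammar["table_name"] / grammar["column_name"] now also contain the quoted table
#     # and "table@column" terminals for this database.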
@@ -0,0 +1,607 @@
################################
|
||||
# Assumptions:
|
||||
# 1. sql is correct
|
||||
# 2. only table name has alias
|
||||
# 3. only one intersect/union/except
|
||||
#
|
||||
# val: number(float)/string(str)/sql(dict)
|
||||
# col_unit: (agg_id, col_id, isDistinct(bool))
|
||||
# val_unit: (unit_op, col_unit1, col_unit2)
|
||||
# table_unit: (table_type, col_unit/sql)
|
||||
# cond_unit: (not_op, op_id, val_unit, val1, val2)
|
||||
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
|
||||
# sql {
|
||||
# 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
|
||||
# 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
|
||||
# 'where': condition
|
||||
# 'groupBy': [col_unit1, col_unit2, ...]
|
||||
# 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
|
||||
# 'having': condition
|
||||
# 'limit': None/limit value
|
||||
# 'intersect': None/sql
|
||||
# 'except': None/sql
|
||||
# 'union': None/sql
|
||||
# }
|
||||
################################
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
import nltk
|
||||
nltk.download('punkt')
|
||||
from nltk import word_tokenize
|
||||
|
||||
|
||||
CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
|
||||
JOIN_KEYWORDS = ('join', 'on', 'as')
|
||||
|
||||
WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
|
||||
UNIT_OPS = ('none', '-', '+', "*", '/')
|
||||
WHERE_COLUMN_OPS = ( '=', '>', '<', '>=', '<=', '!=')
|
||||
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
|
||||
TABLE_TYPE = {
|
||||
'sql': "sql",
|
||||
'table_unit': "table_unit",
|
||||
}
|
||||
|
||||
COND_OPS = ('and', 'or')
|
||||
SQL_OPS = ('intersect', 'union', 'except')
|
||||
ORDER_OPS = ('desc', 'asc')
|
||||
mapped_entities = []
|
||||
|
||||
|
||||
class Schema:
|
||||
"""
|
||||
Simple schema which maps table&column to a unique identifier
|
||||
"""
|
||||
def __init__(self, schema):
|
||||
self._schema = schema
|
||||
self._idMap = self._map(self._schema)
|
||||
|
||||
@property
|
||||
def schema(self):
|
||||
return self._schema
|
||||
|
||||
@property
|
||||
def idMap(self):
|
||||
return self._idMap
|
||||
|
||||
def _map(self, schema):
|
||||
idMap = {'*': "__all__"}
|
||||
id = 1
|
||||
for key, vals in schema.items():
|
||||
for val in vals:
|
||||
idMap[key.lower() + "." + val.lower()] = "__" + key.lower() + "." + val.lower() + "__"
|
||||
id += 1
|
||||
|
||||
for key in schema:
|
||||
idMap[key.lower()] = "__" + key.lower() + "__"
|
||||
id += 1
|
||||
|
||||
return idMap
|
||||
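# Example (a sketch): Schema({'singer': ['name', 'age']}).idMap ==
#   {'*': '__all__', 'singer.name': '__singer.name__', 'singer.age': '__singer.age__',
#    'singer': '__singer__'}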
|
||||
|
||||
def get_schema(db):
|
||||
"""
|
||||
Get database's schema, which is a dict with table name as key
|
||||
and list of column names as value
|
||||
:param db: database path
|
||||
:return: schema dict
|
||||
"""
|
||||
|
||||
schema = {}
|
||||
conn = sqlite3.connect(db)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# fetch table names
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
|
||||
tables = [str(table[0].lower()) for table in cursor.fetchall()]
|
||||
|
||||
# fetch table info
|
||||
for table in tables:
|
||||
cursor.execute("PRAGMA table_info({})".format(table))
|
||||
schema[table] = [str(col[1].lower()) for col in cursor.fetchall()]
|
||||
|
||||
return schema
|
||||
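# Usage sketch (the database path is an assumption following the Spider layout):
#   schema = get_schema('./data/spider/database/concert_singer/concert_singer.sqlite')
#   # -> {'singer': ['singer_id', 'name', ...], 'concert': [...], ...}, all lowercased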
|
||||
|
||||
def get_schema_from_json(fpath):
|
||||
with open(fpath, encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
schema = {}
|
||||
for entry in data:
|
||||
if 'chase' in fpath:
|
||||
schema[entry['db_id']]={}
|
||||
for table_id,table_name in enumerate(entry['table_names']):
|
||||
cols = []
|
||||
for col in entry['column_names']:
|
||||
if col[0] == table_id:
|
||||
cols.append(str(col[-1].lower().replace(' ','_')))
|
||||
schema[entry['db_id']][table_name] = cols
|
||||
else:
|
||||
table = str(entry['table'].lower())
|
||||
cols = [str(col['column_name'].lower()) for col in entry['col_data']]
|
||||
schema[table] = cols
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def tokenize(string):
|
||||
string = str(string)
|
||||
string = string.replace("\'", "\"")  # normalize quotes so all string values are wrapped in double quotes
|
||||
quote_idxs = [idx for idx, char in enumerate(string) if char == '"']
|
||||
assert len(quote_idxs) % 2 == 0, "Unexpected quote"
|
||||
|
||||
# keep string value as token
|
||||
vals = {}
|
||||
for i in range(len(quote_idxs)-1, -1, -2):
|
||||
qidx1 = quote_idxs[i-1]
|
||||
qidx2 = quote_idxs[i]
|
||||
val = string[qidx1: qidx2+1]
|
||||
key = "__val_{}_{}__".format(qidx1, qidx2)
|
||||
string = string[:qidx1] + key + string[qidx2+1:]
|
||||
vals[key] = val
|
||||
|
||||
toks = [word.lower() for word in word_tokenize(string)]
|
||||
# replace with string value token
|
||||
for i in range(len(toks)):
|
||||
if toks[i] in vals:
|
||||
toks[i] = vals[toks[i]]
|
||||
|
||||
# find if there exists !=, >=, <=
|
||||
eq_idxs = [idx for idx, tok in enumerate(toks) if tok == "="]
|
||||
eq_idxs.reverse()
|
||||
prefix = ('!', '>', '<')
|
||||
for eq_idx in eq_idxs:
|
||||
pre_tok = toks[eq_idx-1]
|
||||
if pre_tok in prefix:
|
||||
toks = toks[:eq_idx-1] + [pre_tok + "="] + toks[eq_idx+1: ]
|
||||
|
||||
return toks
|
||||
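# Usage sketch: quoted values are kept as single tokens and split comparison operators are
# merged back, e.g.
#   tokenize("SELECT name FROM singer WHERE age >= 20 AND country = 'USA'")
#   -> ['select', 'name', 'from', 'singer', 'where', 'age', '>=', '20', 'and', 'country', '=', '"USA"']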
|
||||
|
||||
def scan_alias(toks):
|
||||
"""Scan the index of 'as' and build the map for all alias"""
|
||||
as_idxs = [idx for idx, tok in enumerate(toks) if tok == 'as']
|
||||
alias = {}
|
||||
for idx in as_idxs:
|
||||
alias[toks[idx+1]] = toks[idx-1]
|
||||
return alias
|
||||
|
||||
|
||||
def get_tables_with_alias(schema, toks):
|
||||
tables = scan_alias(toks)
|
||||
for key in schema:
|
||||
assert key not in tables, "Alias {} has the same name in table".format(key)
|
||||
tables[key] = key
|
||||
return tables
|
||||
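# Example (a sketch): for the tokens of
#   "select t1.name from singer as t1 join concert as t2 ..."
# scan_alias() returns {'t1': 'singer', 't2': 'concert'}, and get_tables_with_alias()
# additionally maps every schema table to itself, e.g. adding {'singer': 'singer', 'concert': 'concert'}.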
|
||||
|
||||
def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None):
|
||||
"""
|
||||
:returns next idx, column id
|
||||
"""
|
||||
global mapped_entities
|
||||
tok = toks[start_idx]
|
||||
if tok == "*":
|
||||
return start_idx + 1, schema.idMap[tok]
|
||||
|
||||
if '.' in tok: # if token is a composite
|
||||
alias, col = tok.split('.')
|
||||
key = tables_with_alias[alias] + "." + col
|
||||
mapped_entities.append((start_idx, tables_with_alias[alias] + "@" + col))
|
||||
return start_idx+1, schema.idMap[key]
|
||||
|
||||
assert default_tables is not None and len(default_tables) > 0, "Default tables should not be None or empty"
|
||||
|
||||
for alias in default_tables:
|
||||
table = tables_with_alias[alias]
|
||||
if tok in schema.schema[table]:
|
||||
key = table + "." + tok
|
||||
mapped_entities.append((start_idx, table + "@" + tok))
|
||||
return start_idx+1, schema.idMap[key]
|
||||
|
||||
assert False, "Error col: {}".format(tok)
|
||||
|
||||
|
||||
def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
|
||||
"""
|
||||
:returns next idx, (agg_op id, col_id)
|
||||
"""
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
isBlock = False
|
||||
isDistinct = False
|
||||
if toks[idx] == '(':
|
||||
isBlock = True
|
||||
idx += 1
|
||||
|
||||
if toks[idx] in AGG_OPS:
|
||||
agg_id = AGG_OPS.index(toks[idx])
|
||||
idx += 1
|
||||
assert idx < len_ and toks[idx] == '('
|
||||
idx += 1
|
||||
if toks[idx] == "distinct":
|
||||
idx += 1
|
||||
isDistinct = True
|
||||
idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables)
|
||||
assert idx < len_ and toks[idx] == ')'
|
||||
idx += 1
|
||||
return idx, (agg_id, col_id, isDistinct)
|
||||
|
||||
if toks[idx] == "distinct":
|
||||
idx += 1
|
||||
isDistinct = True
|
||||
agg_id = AGG_OPS.index("none")
|
||||
idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables)
|
||||
|
||||
if isBlock:
|
||||
assert toks[idx] == ')'
|
||||
idx += 1 # skip ')'
|
||||
|
||||
return idx, (agg_id, col_id, isDistinct)
|
||||
|
||||
|
||||
def parse_val_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
isBlock = False
|
||||
if toks[idx] == '(':
|
||||
isBlock = True
|
||||
idx += 1
|
||||
|
||||
col_unit1 = None
|
||||
col_unit2 = None
|
||||
unit_op = UNIT_OPS.index('none')
|
||||
|
||||
idx, col_unit1 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
|
||||
if idx < len_ and toks[idx] in UNIT_OPS:
|
||||
unit_op = UNIT_OPS.index(toks[idx])
|
||||
idx += 1
|
||||
idx, col_unit2 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
|
||||
|
||||
if isBlock:
|
||||
assert toks[idx] == ')'
|
||||
idx += 1 # skip ')'
|
||||
|
||||
return idx, (unit_op, col_unit1, col_unit2)
|
||||
|
||||
|
||||
def parse_table_unit(toks, start_idx, tables_with_alias, schema):
|
||||
"""
|
||||
:returns next idx, table id, table name
|
||||
"""
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
key = tables_with_alias[toks[idx]]
|
||||
|
||||
if idx + 1 < len_ and toks[idx+1] == "as":
|
||||
idx += 3
|
||||
else:
|
||||
idx += 1
|
||||
|
||||
return idx, schema.idMap[key], key
|
||||
|
||||
|
||||
def parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None):
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
|
||||
isBlock = False
|
||||
if toks[idx] == '(':
|
||||
isBlock = True
|
||||
idx += 1
|
||||
|
||||
if toks[idx] == 'select':
|
||||
idx, val = parse_sql(toks, idx, tables_with_alias, schema)
|
||||
elif "\"" in toks[idx] or toks[idx]=='value': # token is a string value
|
||||
val = toks[idx]
|
||||
idx += 1
|
||||
elif "'" == toks[idx]:
|
||||
idx+=1
|
||||
val=[]
|
||||
while toks[idx]!="'":
|
||||
val.append(toks[idx])
|
||||
idx+=1
|
||||
idx+=1
|
||||
val='"'+''.join(val)+'"'
|
||||
else:
|
||||
try:
|
||||
val = float(toks[idx])
|
||||
idx += 1
|
||||
except:
|
||||
end_idx = idx
|
||||
while end_idx < len_ and toks[end_idx] != ',' and toks[end_idx] != ')'\
|
||||
and toks[end_idx] != 'and' and toks[end_idx] not in CLAUSE_KEYWORDS and toks[end_idx] not in JOIN_KEYWORDS:
|
||||
end_idx += 1
|
||||
|
||||
idx, val = parse_col_unit(toks[: end_idx], start_idx, tables_with_alias, schema, default_tables)
|
||||
idx = end_idx
|
||||
|
||||
if isBlock:
|
||||
assert toks[idx] == ')'
|
||||
idx += 1
|
||||
|
||||
return idx, val
|
||||
|
||||
|
||||
def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=None):
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
conds = []
|
||||
|
||||
while idx < len_:
|
||||
idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
|
||||
not_op = False
|
||||
if toks[idx] == 'not':
|
||||
not_op = True
|
||||
idx += 1
|
||||
|
||||
assert idx < len_ and toks[idx] in WHERE_OPS, "Error condition: idx: {}, tok: {}".format(idx, toks[idx])
|
||||
op_id = WHERE_OPS.index(toks[idx])
|
||||
idx += 1
|
||||
|
||||
# if idx<len_-2 and \
|
||||
# '.' in toks[idx] and toks[idx+1] in UNIT_OPS and '.' in toks[idx+2]:
|
||||
# idx, val_unit2 = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
|
||||
# conds.append((not_op, op_id, val_unit, val_unit2))
|
||||
# elif '.' in toks[idx]:
|
||||
# idx, val_unit2 = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
|
||||
# conds.append((not_op, op_id, val_unit, val_unit2))
|
||||
# else:
|
||||
val1 = val2 = None
|
||||
if op_id == WHERE_OPS.index('between'): # between..and... special case: dual values
|
||||
idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
|
||||
assert toks[idx] == 'and'
|
||||
idx += 1
|
||||
idx, val2 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
|
||||
else: # normal case: single value
|
||||
idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
|
||||
val2 = None
|
||||
|
||||
conds.append((not_op, op_id, val_unit, val1, val2))
|
||||
|
||||
if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";") or toks[idx] in JOIN_KEYWORDS):
|
||||
break
|
||||
|
||||
if idx < len_ and toks[idx] in COND_OPS:
|
||||
conds.append(toks[idx])
|
||||
idx += 1 # skip and/or
|
||||
|
||||
return idx, conds
|
||||
|
||||
|
||||
def parse_select(toks, start_idx, tables_with_alias, schema, default_tables=None):
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
|
||||
assert toks[idx] == 'select', "'select' not found"
|
||||
idx += 1
|
||||
isDistinct = False
|
||||
if idx < len_ and toks[idx] == 'distinct':
|
||||
idx += 1
|
||||
isDistinct = True
|
||||
val_units = []
|
||||
|
||||
while idx < len_ and toks[idx] not in CLAUSE_KEYWORDS:
|
||||
agg_id = AGG_OPS.index("none")
|
||||
if toks[idx] in AGG_OPS:
|
||||
agg_id = AGG_OPS.index(toks[idx])
|
||||
idx += 1
|
||||
idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
|
||||
val_units.append((agg_id, val_unit))
|
||||
if idx < len_ and toks[idx] == ',':
|
||||
idx += 1 # skip ','
|
||||
|
||||
return idx, (isDistinct, val_units)
|
||||
|
||||
|
||||
def parse_from(toks, start_idx, tables_with_alias, schema):
|
||||
"""
|
||||
Assume in the from clause, all table units are combined with join
|
||||
"""
|
||||
assert 'from' in toks[start_idx:], "'from' not found"
|
||||
|
||||
len_ = len(toks)
|
||||
idx = toks.index('from', start_idx) + 1
|
||||
default_tables = []
|
||||
table_units = []
|
||||
conds = []
|
||||
|
||||
while idx < len_:
|
||||
isBlock = False
|
||||
if toks[idx] == '(':
|
||||
isBlock = True
|
||||
idx += 1
|
||||
|
||||
if toks[idx] == 'select':
|
||||
idx, sql = parse_sql(toks, idx, tables_with_alias, schema)
|
||||
table_units.append((TABLE_TYPE['sql'], sql))
|
||||
else:
|
||||
if idx < len_ and toks[idx] == 'join':
|
||||
idx += 1 # skip join
|
||||
idx, table_unit, table_name = parse_table_unit(toks, idx, tables_with_alias, schema)
|
||||
table_units.append((TABLE_TYPE['table_unit'],table_unit))
|
||||
default_tables.append(table_name)
|
||||
if idx < len_ and toks[idx] == "on":
|
||||
idx += 1 # skip on
|
||||
idx, this_conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
|
||||
if len(conds) > 0:
|
||||
conds.append('and')
|
||||
conds.extend(this_conds)
|
||||
|
||||
if isBlock:
|
||||
assert toks[idx] == ')'
|
||||
idx += 1
|
||||
if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
|
||||
break
|
||||
|
||||
return idx, table_units, conds, default_tables
|
||||
|
||||
|
||||
def parse_where(toks, start_idx, tables_with_alias, schema, default_tables):
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
|
||||
if idx >= len_ or toks[idx] != 'where':
|
||||
return idx, []
|
||||
|
||||
idx += 1
|
||||
idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
|
||||
return idx, conds
|
||||
|
||||
|
||||
def parse_group_by(toks, start_idx, tables_with_alias, schema, default_tables):
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
col_units = []
|
||||
|
||||
if idx >= len_ or toks[idx] != 'group':
|
||||
return idx, col_units
|
||||
|
||||
idx += 1
|
||||
assert toks[idx] == 'by'
|
||||
idx += 1
|
||||
|
||||
while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
|
||||
idx, col_unit = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
|
||||
col_units.append(col_unit)
|
||||
if idx < len_ and toks[idx] == ',':
|
||||
idx += 1 # skip ','
|
||||
else:
|
||||
break
|
||||
|
||||
return idx, col_units
|
||||
|
||||
|
||||
def parse_order_by(toks, start_idx, tables_with_alias, schema, default_tables):
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
val_units = []
|
||||
order_type = 'asc' # default type is 'asc'
|
||||
|
||||
if idx >= len_ or toks[idx] != 'order':
|
||||
return idx, val_units
|
||||
|
||||
idx += 1
|
||||
assert toks[idx] == 'by'
|
||||
idx += 1
|
||||
|
||||
while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
|
||||
idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
|
||||
val_units.append(val_unit)
|
||||
if idx < len_ and toks[idx] in ORDER_OPS:
|
||||
order_type = toks[idx]
|
||||
idx += 1
|
||||
if idx < len_ and toks[idx] == ',':
|
||||
idx += 1 # skip ','
|
||||
else:
|
||||
break
|
||||
|
||||
return idx, (order_type, val_units)
|
||||
|
||||
|
||||
def parse_having(toks, start_idx, tables_with_alias, schema, default_tables):
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
|
||||
if idx >= len_ or toks[idx] != 'having':
|
||||
return idx, []
|
||||
|
||||
idx += 1
|
||||
idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
|
||||
return idx, conds
|
||||
|
||||
|
||||
def parse_limit(toks, start_idx):
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
|
||||
if idx < len_ and toks[idx] == 'limit':
|
||||
idx += 2
|
||||
try:
|
||||
limit_val = int(toks[idx-1])
|
||||
except Exception:
|
||||
limit_val = '"value"'
|
||||
return idx, limit_val
|
||||
|
||||
return idx, None
|
||||
|
||||
|
||||
def parse_sql(toks, start_idx, tables_with_alias, schema, mapped_entities_fn=None):
|
||||
global mapped_entities
|
||||
|
||||
if mapped_entities_fn is not None:
|
||||
mapped_entities = mapped_entities_fn()
|
||||
isBlock = False # indicate whether this is a block of sql/sub-sql
|
||||
len_ = len(toks)
|
||||
idx = start_idx
|
||||
|
||||
sql = {}
|
||||
if toks[idx] == '(':
|
||||
isBlock = True
|
||||
idx += 1
|
||||
|
||||
# parse from clause in order to get default tables
|
||||
from_end_idx, table_units, conds, default_tables = parse_from(toks, start_idx, tables_with_alias, schema)
|
||||
sql['from'] = {'table_units': table_units, 'conds': conds}
|
||||
# select clause
|
||||
_, select_col_units = parse_select(toks, idx, tables_with_alias, schema, default_tables)
|
||||
idx = from_end_idx
|
||||
sql['select'] = select_col_units
|
||||
# where clause
|
||||
idx, where_conds = parse_where(toks, idx, tables_with_alias, schema, default_tables)
|
||||
sql['where'] = where_conds
|
||||
# group by clause
|
||||
idx, group_col_units = parse_group_by(toks, idx, tables_with_alias, schema, default_tables)
|
||||
sql['groupBy'] = group_col_units
|
||||
# having clause
|
||||
idx, having_conds = parse_having(toks, idx, tables_with_alias, schema, default_tables)
|
||||
sql['having'] = having_conds
|
||||
# order by clause
|
||||
idx, order_col_units = parse_order_by(toks, idx, tables_with_alias, schema, default_tables)
|
||||
sql['orderBy'] = order_col_units
|
||||
# limit clause
|
||||
idx, limit_val = parse_limit(toks, idx)
|
||||
sql['limit'] = limit_val
|
||||
|
||||
idx = skip_semicolon(toks, idx)
|
||||
if isBlock:
|
||||
assert toks[idx] == ')'
|
||||
idx += 1 # skip ')'
|
||||
idx = skip_semicolon(toks, idx)
|
||||
|
||||
# intersect/union/except clause
|
||||
for op in SQL_OPS: # initialize IUE
|
||||
sql[op] = None
|
||||
if idx < len_ and toks[idx] in SQL_OPS:
|
||||
sql_op = toks[idx]
|
||||
idx += 1
|
||||
idx, IUE_sql = parse_sql(toks, idx, tables_with_alias, schema)
|
||||
sql[sql_op] = IUE_sql
|
||||
|
||||
if mapped_entities_fn is not None:
|
||||
return idx, sql, mapped_entities
|
||||
else:
|
||||
return idx, sql
|
||||
|
||||
|
||||
def load_data(fpath):
|
||||
with open(fpath) as f:
|
||||
data = json.load(f)
|
||||
return data
|
||||
|
||||
|
||||
def get_sql(schema, query):
|
||||
toks = tokenize(query)
|
||||
tables_with_alias = get_tables_with_alias(schema.schema, toks)
|
||||
_, sql = parse_sql(toks, 0, tables_with_alias, schema)
|
||||
|
||||
return sql
|
||||
|
||||
|
||||
def skip_semicolon(toks, start_idx):
|
||||
idx = start_idx
|
||||
while idx < len(toks) and toks[idx] == ";":
|
||||
idx += 1
|
||||
return idx
|
|
@ -0,0 +1,168 @@
|
|||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2019 seq2struct contributors and Microsoft Corporation
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
|
||||
import os
|
||||
import dataclasses
|
||||
import json
|
||||
from typing import Optional, Tuple, List, Iterable
|
||||
|
||||
import networkx as nx
|
||||
from pydantic.dataclasses import dataclass
|
||||
from pydantic.main import BaseConfig
|
||||
|
||||
from third_party.spider import evaluation
|
||||
from third_party.spider.preprocess.schema import get_schemas_from_json, Schema
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpiderTable:
|
||||
id: int
|
||||
name: List[str]
|
||||
unsplit_name: str
|
||||
orig_name: str
|
||||
columns: List["SpiderColumn"] = dataclasses.field(default_factory=list)
|
||||
primary_keys: List[str] = dataclasses.field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpiderColumn:
|
||||
id: int
|
||||
table: Optional[SpiderTable]
|
||||
name: List[str]
|
||||
unsplit_name: str
|
||||
orig_name: str
|
||||
type: str
|
||||
foreign_key_for: Optional[str] = None
|
||||
|
||||
|
||||
SpiderTable.__pydantic_model__.update_forward_refs()
|
||||
|
||||
|
||||
class SpiderSchemaConfig:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
@dataclass(config=SpiderSchemaConfig)
|
||||
class SpiderSchema(BaseConfig):
|
||||
db_id: str
|
||||
tables: Tuple[SpiderTable, ...]
|
||||
columns: Tuple[SpiderColumn, ...]
|
||||
foreign_key_graph: nx.DiGraph
|
||||
orig: dict
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpiderItem:
|
||||
question: str
|
||||
slml_question: Optional[str]
|
||||
query: str
|
||||
spider_sql: dict
|
||||
spider_schema: SpiderSchema
|
||||
db_path: str
|
||||
orig: dict
|
||||
|
||||
|
||||
def schema_dict_to_spider_schema(schema_dict):
|
||||
tables = tuple(
|
||||
SpiderTable(id=i, name=name.split(), unsplit_name=name, orig_name=orig_name,)
|
||||
for i, (name, orig_name) in enumerate(
|
||||
zip(schema_dict["table_names"], schema_dict["table_names_original"])
|
||||
)
|
||||
)
|
||||
columns = tuple(
|
||||
SpiderColumn(
|
||||
id=i,
|
||||
table=tables[table_id] if table_id >= 0 else None,
|
||||
name=col_name.split(),
|
||||
unsplit_name=col_name,
|
||||
orig_name=orig_col_name,
|
||||
type=col_type,
|
||||
)
|
||||
for i, ((table_id, col_name), (_, orig_col_name), col_type,) in enumerate(
|
||||
zip(
|
||||
schema_dict["column_names"],
|
||||
schema_dict["column_names_original"],
|
||||
schema_dict["column_types"],
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Link columns to tables
|
||||
for column in columns:
|
||||
if column.table:
|
||||
column.table.columns.append(column)
|
||||
|
||||
for column_id in schema_dict["primary_keys"]:
|
||||
# Register primary keys
|
||||
column = columns[column_id]
|
||||
column.table.primary_keys.append(column)
|
||||
|
||||
foreign_key_graph = nx.DiGraph()
|
||||
for source_column_id, dest_column_id in schema_dict["foreign_keys"]:
|
||||
# Register foreign keys
|
||||
source_column = columns[source_column_id]
|
||||
dest_column = columns[dest_column_id]
|
||||
source_column.foreign_key_for = dest_column
|
||||
foreign_key_graph.add_edge(
|
||||
source_column.table.id,
|
||||
dest_column.table.id,
|
||||
columns=(source_column_id, dest_column_id),
|
||||
)
|
||||
foreign_key_graph.add_edge(
|
||||
dest_column.table.id,
|
||||
source_column.table.id,
|
||||
columns=(dest_column_id, source_column_id),
|
||||
)
|
||||
|
||||
db_id = schema_dict["db_id"]
|
||||
return SpiderSchema(db_id, tables, columns, foreign_key_graph, schema_dict)
|
||||
|
||||
|
||||
def load_tables(paths):
|
||||
schemas = {}
|
||||
eval_foreign_key_maps = {}
|
||||
|
||||
with open(paths, 'r',encoding='UTF-8') as f:
|
||||
schema_dicts = json.load(f)
|
||||
|
||||
for schema_dict in schema_dicts:
|
||||
db_id = schema_dict["db_id"]
|
||||
if 'column_names_original' not in schema_dict:  # fall back to the normalized names when the original ones are missing
|
||||
# continue
|
||||
schema_dict["column_names_original"] = schema_dict["column_names"]
|
||||
schema_dict["table_names_original"] = schema_dict["table_names"]
|
||||
|
||||
# assert db_id not in schemas
|
||||
schemas[db_id] = schema_dict_to_spider_schema(schema_dict)
|
||||
eval_foreign_key_maps[db_id] = evaluation.build_foreign_key_map(schema_dict)
|
||||
|
||||
|
||||
return schemas, eval_foreign_key_maps
|
||||
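# Minimal usage sketch (the path and db_id are assumptions matching the Spider layout):
#   schemas, fk_maps = load_tables('./data/spider/tables.json')
#   schema = schemas['concert_singer']              # a SpiderSchema instance
#   table_names = [t.orig_name for t in schema.tables]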
|
||||
|
||||
def load_original_schemas(tables_paths):
|
||||
all_schemas = {}
|
||||
schemas, db_ids, tables = get_schemas_from_json(tables_paths)
|
||||
for db_id in db_ids:
|
||||
all_schemas[db_id] = Schema(schemas[db_id], tables[db_id])
|
||||
return all_schemas
|
|
@ -0,0 +1,343 @@
|
|||
"""
|
||||
Utility functions for reading the standardised text2sql datasets presented in
|
||||
`"Improving Text to SQL Evaluation Methodology" <https://arxiv.org/abs/1806.09029>`_
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
from collections import defaultdict
|
||||
from typing import List, Dict, Optional, Any
|
||||
from semparse.sql.process_sql import get_tables_with_alias, parse_sql
|
||||
|
||||
|
||||
class TableColumn:
|
||||
def __init__(self,
|
||||
name: str,
|
||||
text: str,
|
||||
column_type: str,
|
||||
is_primary_key: bool,
|
||||
foreign_key: Optional[str],
|
||||
lemma: Optional[str]):
|
||||
self.name = name
|
||||
self.text = text
|
||||
self.column_type = column_type
|
||||
self.is_primary_key = is_primary_key
|
||||
self.foreign_key = foreign_key
|
||||
self.lemma = lemma
|
||||
|
||||
|
||||
class Table:
|
||||
def __init__(self,
|
||||
name: str,
|
||||
text: str,
|
||||
columns: List[TableColumn],
|
||||
lemma: Optional[str]):
|
||||
self.name = name
|
||||
self.text = text
|
||||
self.columns = columns
|
||||
self.lemma = lemma
|
||||
|
||||
|
||||
def read_dataset_schema(schema_path: str, stanza_model=None) -> Dict[str, Dict[str, Table]]:
|
||||
schemas: Dict[str, Dict[str, Table]] = defaultdict(dict)
|
||||
dbs_json_blob = json.load(open(schema_path, "r", encoding='utf-8'))
|
||||
for db in dbs_json_blob:
|
||||
db_id = db['db_id']
|
||||
|
||||
column_id_to_table = {}
|
||||
column_id_to_column = {}
|
||||
|
||||
concate_columns = [c[-1] for c in db['column_names']]
|
||||
concate_tables = [c for c in db['table_names']]
|
||||
|
||||
#load stanza model
|
||||
if stanza_model is not None:
|
||||
lemma_columns = stanza_model('\n\n'.join(concate_columns).replace(' ','none'))
|
||||
lemma_columns_collect = []
|
||||
for sent in lemma_columns.sentences:
|
||||
tmp = []
|
||||
for word in sent.words:
|
||||
if word.lemma != None:
|
||||
tmp.append(word.lemma)
|
||||
elif word.text==' ':
|
||||
tmp.append('none')
|
||||
else:
|
||||
tmp.append(word.text)
|
||||
lemma_columns_collect.append(' '.join(tmp))
|
||||
|
||||
lemma_tables = stanza_model('\n\n'.join(concate_tables).replace(' ','none'))
|
||||
lemma_tables_collect = {}
|
||||
for t,sent in zip(concate_tables, lemma_tables.sentences):
|
||||
tmp = []
|
||||
for word in sent.words:
|
||||
if word.lemma != None:
|
||||
tmp.append(word.lemma)
|
||||
elif word.text == ' ':
|
||||
tmp.append('none')
|
||||
else:
|
||||
tmp.append(word.text)
|
||||
lemma_tables_collect[t]=' '.join(tmp)
|
||||
else:
|
||||
lemma_columns_collect = concate_columns
|
||||
lemma_tables_collect = {t:t for t in concate_tables}
|
||||
|
||||
for i, (column, text, column_type) in enumerate(zip(db['column_names_original'], db['column_names'], db['column_types'])):
|
||||
table_id, column_name = column
|
||||
_, column_text = text
|
||||
|
||||
table_name = db['table_names_original'][table_id]
|
||||
|
||||
if table_name not in schemas[db_id]:
|
||||
table_text = db['table_names'][table_id]
|
||||
table_lemma = lemma_tables_collect[table_text]
|
||||
schemas[db_id][table_name] = Table(table_name, table_text, [], table_lemma)
|
||||
|
||||
if column_name == "*":
|
||||
continue
|
||||
|
||||
is_primary_key = i in db['primary_keys']
|
||||
table_column = TableColumn(column_name.lower(), column_text, column_type, is_primary_key, None, lemma_columns_collect[i])
|
||||
schemas[db_id][table_name].columns.append(table_column)
|
||||
column_id_to_table[i] = table_name
|
||||
column_id_to_column[i] = table_column
|
||||
|
||||
for (c1, c2) in db['foreign_keys']:
|
||||
foreign_key = column_id_to_table[c2] + ':' + column_id_to_column[c2].name
|
||||
column_id_to_column[c1].foreign_key = foreign_key
|
||||
|
||||
return {**schemas}
|
||||
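# Usage sketch (the path is an assumption): read_dataset_schema('./data/spider/tables.json')
# returns {db_id: {table_name: Table}}; passing a stanza Pipeline as stanza_model additionally
# fills the lemma fields of the tables and columns.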
|
||||
|
||||
def read_dataset_values(db_id: str, dataset_path: str, tables: List[str]):
|
||||
db = os.path.join(dataset_path, db_id, db_id + ".sqlite")
|
||||
try:
|
||||
conn = sqlite3.connect(db)
|
||||
except Exception as e:
|
||||
raise Exception(f"Can't connect to SQL: {e} in path {db}")
|
||||
conn.text_factory = str
|
||||
cursor = conn.cursor()
|
||||
|
||||
values = {}
|
||||
|
||||
for table in tables:
|
||||
try:
|
||||
cursor.execute(f"SELECT * FROM {table.name} LIMIT 5000")
|
||||
values[table] = cursor.fetchall()
|
||||
except:
|
||||
conn.text_factory = lambda x: str(x, 'latin1')
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(f"SELECT * FROM {table.name} LIMIT 5000")
|
||||
values[table] = cursor.fetchall()
|
||||
|
||||
return values
|
||||
|
||||
|
||||
def ent_key_to_name(key):
|
||||
parts = key.split(':')
|
||||
if parts[0] == 'table':
|
||||
return parts[1]
|
||||
elif parts[0] == 'column':
|
||||
_, _, table_name, column_name = parts
|
||||
return f'{table_name}@{column_name}'
|
||||
else:
|
||||
return parts[1]
|
||||
|
||||
|
||||
def fix_number_value(ex):
|
||||
"""
|
||||
There is something weird in the dataset files - the `query_toks_no_value` field anonymizes all values,
|
||||
which is good since the evaluator doesn't check for the values. But it also anonymizes numbers that
|
||||
should not be anonymized: e.g. LIMIT 3 becomes LIMIT 'value', but the evaluator fails when the LIMIT argument is not a number.
|
||||
"""
|
||||
|
||||
def split_and_keep(s, sep):
|
||||
if not s: return [''] # consistent with string.split()
|
||||
|
||||
# Find replacement character that is not used in string
|
||||
# i.e. just use the highest available character plus one
|
||||
# Note: This fails if ord(max(s)) = 0x10FFFF (ValueError)
|
||||
p = chr(ord(max(s)) + 1)
|
||||
|
||||
return s.replace(sep, p + sep + p).split(p)
|
||||
|
||||
# input is tokenized in different ways... so first try to make splits equal
|
||||
query_toks = ex['query_toks']
|
||||
ex['query_toks'] = []
|
||||
for q in query_toks:
|
||||
ex['query_toks'] += split_and_keep(q, '.')
|
||||
|
||||
i_val, i_no_val = 0, 0
|
||||
while i_val < len(ex['query_toks']) and i_no_val < len(ex['query_toks_no_value']):
|
||||
if ex['query_toks_no_value'][i_no_val] != 'value':
|
||||
i_val += 1
|
||||
i_no_val += 1
|
||||
continue
|
||||
|
||||
i_val_end = i_val
|
||||
while i_val + 1 < len(ex['query_toks']) and \
|
||||
i_no_val + 1 < len(ex['query_toks_no_value']) and \
|
||||
ex['query_toks'][i_val_end + 1].lower() != ex['query_toks_no_value'][i_no_val + 1].lower():
|
||||
i_val_end += 1
|
||||
|
||||
if i_val == i_val_end and ex['query_toks'][i_val] in ["1", "2", "3", "4", "5"] and ex['query_toks'][i_val - 1].lower() == "limit":
|
||||
ex['query_toks_no_value'][i_no_val] = ex['query_toks'][i_val]
|
||||
i_val = i_val_end
|
||||
|
||||
i_val += 1
|
||||
i_no_val += 1
|
||||
|
||||
return ex
|
||||
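# Example (a sketch of the intended behaviour): for a gold query ending in "... ORDER BY age DESC LIMIT 3",
# the dataset's query_toks_no_value ends with [..., 'limit', 'value']; fix_number_value() copies the literal
# back so that it ends with [..., 'limit', '3'], which is the form the evaluator expects for LIMIT.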
|
||||
|
||||
_schemas_cache = None
|
||||
|
||||
|
||||
def disambiguate_items(db_id: str, query_toks: List[str], tables_file: str, allow_aliases: bool) -> List[str]:
|
||||
"""
|
||||
We want the query tokens to be unambiguous, so we rewrite each column name to explicitly
|
||||
state which table it belongs to.
|
||||
|
||||
Converting the parsed SQL back into SQL clauses is based on supermodel.gensql from syntaxsql.
|
||||
"""
|
||||
|
||||
class Schema:
|
||||
"""
|
||||
Simple schema which maps table&column to a unique identifier
|
||||
"""
|
||||
|
||||
def __init__(self, schema, table):
|
||||
self._schema = schema
|
||||
self._table = table
|
||||
self._idMap = self._map(self._schema, self._table)
|
||||
|
||||
@property
|
||||
def schema(self):
|
||||
return self._schema
|
||||
|
||||
@property
|
||||
def idMap(self):
|
||||
return self._idMap
|
||||
|
||||
def _map(self, schema, table):
|
||||
column_names_original = table['column_names_original']
|
||||
table_names_original = table['table_names_original']
|
||||
# print 'column_names_original: ', column_names_original
|
||||
# print 'table_names_original: ', table_names_original
|
||||
for i, (tab_id, col) in enumerate(column_names_original):
|
||||
if tab_id == -1:
|
||||
idMap = {'*': i}
|
||||
else:
|
||||
key = table_names_original[tab_id].lower()
|
||||
val = col.lower().replace(' ','_')
|
||||
idMap[key + "." + val] = i
|
||||
|
||||
for i, tab in enumerate(table_names_original):
|
||||
key = tab.lower()
|
||||
idMap[key] = i
|
||||
|
||||
return idMap
|
||||
|
||||
def get_schemas_from_json(fpath):
|
||||
global _schemas_cache
|
||||
|
||||
if _schemas_cache is not None:
|
||||
return _schemas_cache
|
||||
|
||||
with open(fpath, encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
db_names = [db['db_id'] for db in data]
|
||||
|
||||
tables = {}
|
||||
schemas = {}
|
||||
for db in data:
|
||||
db_id = db['db_id']
|
||||
schema = {} # {'table': [col.lower, ..., ]} * -> __all__
|
||||
column_names_original = db['column_names_original'] if 'column_names_original' in db else db['column_names']
|
||||
table_names_original = db['table_names_original'] if 'table_names_original' in db else db['table_names']
|
||||
tables[db_id] = {'column_names_original': column_names_original,
|
||||
'table_names_original': table_names_original}
|
||||
for i, tabn in enumerate(table_names_original):
|
||||
table = str(tabn.lower())
|
||||
cols = [str(col.lower().replace(' ','_')) for td, col in column_names_original if td == i]
|
||||
schema[table] = cols
|
||||
schemas[db_id] = schema
|
||||
|
||||
_schemas_cache = schemas, db_names, tables
|
||||
return _schemas_cache
|
||||
|
||||
schemas, db_names, tables = get_schemas_from_json(tables_file)
|
||||
schema = Schema(schemas[db_id], tables[db_id])
|
||||
|
||||
fixed_toks = []
|
||||
i = 0
|
||||
while i < len(query_toks):
|
||||
tok = query_toks[i]
|
||||
if tok == 'value' or tok == "'value'":
|
||||
# TODO: value should always be between '/" (remove first if clause)
|
||||
new_tok = f'"{tok}"'
|
||||
elif tok in ['!','<','>'] and query_toks[i+1] == '=':
|
||||
new_tok = tok + '='
|
||||
i += 1
|
||||
# elif i+1 < len(query_toks) and query_toks[i+1] == '.' and query_toks[i] in schema.schema.keys():
|
||||
elif i + 1 < len(query_toks) and query_toks[i + 1] == '.':
|
||||
new_tok = ''.join(query_toks[i:i+3])
|
||||
i += 2
|
||||
else:
|
||||
new_tok = tok
|
||||
fixed_toks.append(new_tok)
|
||||
i += 1
|
||||
|
||||
toks = fixed_toks
|
||||
|
||||
tables_with_alias = get_tables_with_alias(schema.schema, toks)
|
||||
_, sql, mapped_entities = parse_sql(toks, 0, tables_with_alias, schema, mapped_entities_fn=lambda: [])
|
||||
|
||||
for i, new_name in mapped_entities:
|
||||
curr_tok = toks[i]
|
||||
if '.' in curr_tok and allow_aliases:
|
||||
parts = curr_tok.split('.')
|
||||
assert(len(parts) == 2)
|
||||
toks[i] = parts[0] + '.' + new_name
|
||||
else:
|
||||
toks[i] = new_name
|
||||
|
||||
if not allow_aliases:
|
||||
toks = [tok for tok in toks if tok not in ['as', 't1', 't2', 't3', 't4', 't5', 't6', 't7', 't8', 't9', 't10']]
|
||||
|
||||
toks = [f'\'value\'' if tok == '"value"' else tok for tok in toks]
|
||||
|
||||
return toks
|
||||
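# Example (a sketch): with allow_aliases=False, the tokens of
#   "select t1.name from singer as t1 join concert as t2 on t1.singer_id = t2.singer_id"
# come back roughly as
#   ['select', 'singer@name', 'from', 'singer', 'join', 'concert', 'on',
#    'singer@singer_id', '=', 'concert@singer_id']
# i.e. aliases and AS clauses are dropped and every column is prefixed with its table.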
|
||||
|
||||
def remove_on(query):
|
||||
query_tok = query.split()
|
||||
sql_words = []
|
||||
t = 0
|
||||
while t < len(query_tok):
|
||||
if query_tok[t] != 'on':
|
||||
sql_words.append(query_tok[t])
|
||||
t += 1
|
||||
else:
|
||||
t += 4
|
||||
return ' '.join(sql_words)
|
||||
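# Example: remove_on("select * from singer join concert on singer.id = concert.singer_id")
#   -> "select * from singer join concert"
# (the 'on' keyword and the following three tokens of the join condition are skipped).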
|
||||
def read_dataset_values_from_json(db_id: str, db_content_dict: Dict[str, Any], tables: List[str]):
|
||||
|
||||
values = {}
|
||||
item = db_content_dict[db_id]
|
||||
for table in tables:
|
||||
values[table] = item['tables'][table.name]['cell']
|
||||
return values
|
||||
|
||||
def extract_tree_style(sent):
|
||||
"""
|
||||
sent: List
|
||||
"""
|
||||
rnt = []
# NOTE: tree-style extraction is not implemented yet; return the empty list as a placeholder
return rnt
|
||||
|
||||
if __name__ == '__main__':
|
||||
import stanza
|
||||
stanza_model = stanza.Pipeline('en')
|
||||
doc = stanza_model("what is the name of the breed with the most dogs ?")
|
||||
word=[word.lemma for sent in doc.sentences for word in sent.words]
|
||||
rnt = []
|
|
@ -0,0 +1,881 @@
|
|||
################################
|
||||
# val: number(float)/string(str)/sql(dict)
|
||||
# col_unit: (agg_id, col_id, isDistinct(bool))
|
||||
# val_unit: (unit_op, col_unit1, col_unit2)
|
||||
# table_unit: (table_type, col_unit/sql)
|
||||
# cond_unit: (not_op, op_id, val_unit, val1, val2)
|
||||
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
|
||||
# sql {
|
||||
# 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
|
||||
# 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
|
||||
# 'where': condition
|
||||
# 'groupBy': [col_unit1, col_unit2, ...]
|
||||
# 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
|
||||
# 'having': condition
|
||||
# 'limit': None/limit value
|
||||
# 'intersect': None/sql
|
||||
# 'except': None/sql
|
||||
# 'union': None/sql
|
||||
# }
|
||||
################################
|
||||
import os
|
||||
import json
|
||||
import sqlite3
|
||||
import argparse
|
||||
|
||||
from semparse.sql.process_sql import get_schema, Schema, get_sql
|
||||
|
||||
# Flag to disable value evaluation
|
||||
DISABLE_VALUE = True
|
||||
# Flag to disable distinct in select evaluation
|
||||
DISABLE_DISTINCT = True
|
||||
|
||||
|
||||
CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
|
||||
JOIN_KEYWORDS = ('join', 'on', 'as')
|
||||
|
||||
WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
|
||||
UNIT_OPS = ('none', '-', '+', "*", '/')
|
||||
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
|
||||
TABLE_TYPE = {
|
||||
'sql': "sql",
|
||||
'table_unit': "table_unit",
|
||||
}
|
||||
|
||||
COND_OPS = ('and', 'or')
|
||||
SQL_OPS = ('intersect', 'union', 'except')
|
||||
ORDER_OPS = ('desc', 'asc')
|
||||
|
||||
|
||||
HARDNESS = {
|
||||
"component1": ('where', 'group', 'order', 'limit', 'join', 'or', 'like'),
|
||||
"component2": ('except', 'union', 'intersect')
|
||||
}
|
||||
|
||||
|
||||
def condition_has_or(conds):
|
||||
return 'or' in conds[1::2]
|
||||
|
||||
|
||||
def condition_has_like(conds):
|
||||
return WHERE_OPS.index('like') in [cond_unit[1] for cond_unit in conds[::2]]
|
||||
|
||||
|
||||
def condition_has_sql(conds):
|
||||
for cond_unit in conds[::2]:
|
||||
val1, val2 = cond_unit[3], cond_unit[4]
|
||||
if val1 is not None and type(val1) is dict:
|
||||
return True
|
||||
if val2 is not None and type(val2) is dict:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def val_has_op(val_unit):
|
||||
return val_unit[0] != UNIT_OPS.index('none')
|
||||
|
||||
|
||||
def has_agg(unit):
|
||||
return unit[0] != AGG_OPS.index('none')
|
||||
|
||||
|
||||
def accuracy(count, total):
|
||||
if count == total:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
def recall(count, total):
|
||||
if count == total:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
def F1(acc, rec):
|
||||
if (acc + rec) == 0:
|
||||
return 0
|
||||
return (2. * acc * rec) / (acc + rec)
|
||||
|
||||
|
||||
def get_scores(count, pred_total, label_total):
|
||||
if pred_total != label_total:
|
||||
return 0,0,0
|
||||
elif count == pred_total:
|
||||
return 1,1,1
|
||||
return 0,0,0
|
||||
|
||||
|
||||
def eval_sel(pred, label):
|
||||
pred_sel = pred['select'][1]
|
||||
label_sel = label['select'][1]
|
||||
label_wo_agg = [unit[1] for unit in label_sel]
|
||||
pred_total = len(pred_sel)
|
||||
label_total = len(label_sel)
|
||||
cnt = 0
|
||||
cnt_wo_agg = 0
|
||||
|
||||
for unit in pred_sel:
|
||||
if unit in label_sel:
|
||||
cnt += 1
|
||||
label_sel.remove(unit)
|
||||
if unit[1] in label_wo_agg:
|
||||
cnt_wo_agg += 1
|
||||
label_wo_agg.remove(unit[1])
|
||||
|
||||
return label_total, pred_total, cnt, cnt_wo_agg
|
||||
|
||||
|
||||
def eval_where(pred, label):
|
||||
pred_conds = [unit for unit in pred['where'][::2]]
|
||||
label_conds = [unit for unit in label['where'][::2]]
|
||||
label_wo_agg = [unit[2] for unit in label_conds]
|
||||
pred_total = len(pred_conds)
|
||||
label_total = len(label_conds)
|
||||
cnt = 0
|
||||
cnt_wo_agg = 0
|
||||
|
||||
for unit in pred_conds:
|
||||
if unit in label_conds:
|
||||
cnt += 1
|
||||
label_conds.remove(unit)
|
||||
if unit[2] in label_wo_agg:
|
||||
cnt_wo_agg += 1
|
||||
label_wo_agg.remove(unit[2])
|
||||
|
||||
return label_total, pred_total, cnt, cnt_wo_agg
|
||||
|
||||
|
||||
def eval_group(pred, label):
|
||||
pred_cols = [unit[1] for unit in pred['groupBy']]
|
||||
label_cols = [unit[1] for unit in label['groupBy']]
|
||||
pred_total = len(pred_cols)
|
||||
label_total = len(label_cols)
|
||||
cnt = 0
|
||||
pred_cols = [pred.split(".")[1] if "." in pred else pred for pred in pred_cols]
|
||||
label_cols = [label.split(".")[1] if "." in label else label for label in label_cols]
|
||||
for col in pred_cols:
|
||||
if col in label_cols:
|
||||
cnt += 1
|
||||
label_cols.remove(col)
|
||||
return label_total, pred_total, cnt
|
||||
|
||||
|
||||
def eval_having(pred, label):
|
||||
pred_total = label_total = cnt = 0
|
||||
if len(pred['groupBy']) > 0:
|
||||
pred_total = 1
|
||||
if len(label['groupBy']) > 0:
|
||||
label_total = 1
|
||||
|
||||
pred_cols = [unit[1] for unit in pred['groupBy']]
|
||||
label_cols = [unit[1] for unit in label['groupBy']]
|
||||
if pred_total == label_total == 1 \
|
||||
and pred_cols == label_cols \
|
||||
and pred['having'] == label['having']:
|
||||
cnt = 1
|
||||
|
||||
return label_total, pred_total, cnt
|
||||
|
||||
|
||||
def eval_order(pred, label):
|
||||
pred_total = label_total = cnt = 0
|
||||
if len(pred['orderBy']) > 0:
|
||||
pred_total = 1
|
||||
if len(label['orderBy']) > 0:
|
||||
label_total = 1
|
||||
if len(label['orderBy']) > 0 and pred['orderBy'] == label['orderBy'] and \
|
||||
((pred['limit'] is None and label['limit'] is None) or (pred['limit'] is not None and label['limit'] is not None)):
|
||||
cnt = 1
|
||||
return label_total, pred_total, cnt
|
||||
|
||||
|
||||
def eval_and_or(pred, label):
|
||||
pred_ao = pred['where'][1::2]
|
||||
label_ao = label['where'][1::2]
|
||||
pred_ao = set(pred_ao)
|
||||
label_ao = set(label_ao)
|
||||
|
||||
if pred_ao == label_ao:
|
||||
return 1,1,1
|
||||
return len(pred_ao),len(label_ao),0
|
||||
|
||||
|
||||
def get_nestedSQL(sql):
|
||||
nested = []
|
||||
for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]:
|
||||
if type(cond_unit[3]) is dict:
|
||||
nested.append(cond_unit[3])
|
||||
if type(cond_unit[4]) is dict:
|
||||
nested.append(cond_unit[4])
|
||||
if sql['intersect'] is not None:
|
||||
nested.append(sql['intersect'])
|
||||
if sql['except'] is not None:
|
||||
nested.append(sql['except'])
|
||||
if sql['union'] is not None:
|
||||
nested.append(sql['union'])
|
||||
return nested
|
||||
|
||||
|
||||
def eval_nested(pred, label):
|
||||
label_total = 0
|
||||
pred_total = 0
|
||||
cnt = 0
|
||||
if pred is not None:
|
||||
pred_total += 1
|
||||
if label is not None:
|
||||
label_total += 1
|
||||
if pred is not None and label is not None:
|
||||
cnt += Evaluator().eval_exact_match(pred, label)
|
||||
return label_total, pred_total, cnt
|
||||
|
||||
|
||||
def eval_IUEN(pred, label):
|
||||
lt1, pt1, cnt1 = eval_nested(pred['intersect'], label['intersect'])
|
||||
lt2, pt2, cnt2 = eval_nested(pred['except'], label['except'])
|
||||
lt3, pt3, cnt3 = eval_nested(pred['union'], label['union'])
|
||||
label_total = lt1 + lt2 + lt3
|
||||
pred_total = pt1 + pt2 + pt3
|
||||
cnt = cnt1 + cnt2 + cnt3
|
||||
return label_total, pred_total, cnt
|
||||
|
||||
|
||||
def get_keywords(sql):
|
||||
res = set()
|
||||
if len(sql['where']) > 0:
|
||||
res.add('where')
|
||||
if len(sql['groupBy']) > 0:
|
||||
res.add('group')
|
||||
if len(sql['having']) > 0:
|
||||
res.add('having')
|
||||
if len(sql['orderBy']) > 0:
|
||||
res.add(sql['orderBy'][0])
|
||||
res.add('order')
|
||||
if sql['limit'] is not None:
|
||||
res.add('limit')
|
||||
if sql['except'] is not None:
|
||||
res.add('except')
|
||||
if sql['union'] is not None:
|
||||
res.add('union')
|
||||
if sql['intersect'] is not None:
|
||||
res.add('intersect')
|
||||
|
||||
# or keyword
|
||||
ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
|
||||
if len([token for token in ao if token == 'or']) > 0:
|
||||
res.add('or')
|
||||
|
||||
cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
|
||||
# not keyword
|
||||
if len([cond_unit for cond_unit in cond_units if cond_unit[0]]) > 0:
|
||||
res.add('not')
|
||||
|
||||
# in keyword
|
||||
if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('in')]) > 0:
|
||||
res.add('in')
|
||||
|
||||
# like keyword
|
||||
if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) > 0:
|
||||
res.add('like')
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def eval_keywords(pred, label):
|
||||
pred_keywords = get_keywords(pred)
|
||||
label_keywords = get_keywords(label)
|
||||
pred_total = len(pred_keywords)
|
||||
label_total = len(label_keywords)
|
||||
cnt = 0
|
||||
|
||||
for k in pred_keywords:
|
||||
if k in label_keywords:
|
||||
cnt += 1
|
||||
return label_total, pred_total, cnt
|
||||
|
||||
|
||||
def count_agg(units):
|
||||
return len([unit for unit in units if has_agg(unit)])
|
||||
|
||||
|
||||
def count_component1(sql):
|
||||
count = 0
|
||||
if len(sql['where']) > 0:
|
||||
count += 1
|
||||
if len(sql['groupBy']) > 0:
|
||||
count += 1
|
||||
if len(sql['orderBy']) > 0:
|
||||
count += 1
|
||||
if sql['limit'] is not None:
|
||||
count += 1
|
||||
if len(sql['from']['table_units']) > 0: # JOIN
|
||||
count += len(sql['from']['table_units']) - 1
|
||||
|
||||
ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
|
||||
count += len([token for token in ao if token == 'or'])
|
||||
cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
|
||||
count += len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')])
|
||||
|
||||
return count
|
||||
|
||||
|
||||
def count_component2(sql):
|
||||
nested = get_nestedSQL(sql)
|
||||
return len(nested)
|
||||
|
||||
|
||||
def count_others(sql):
|
||||
count = 0
|
||||
# number of aggregation
|
||||
agg_count = count_agg(sql['select'][1])
|
||||
agg_count += count_agg(sql['where'][::2])
|
||||
agg_count += count_agg(sql['groupBy'])
|
||||
if len(sql['orderBy']) > 0:
|
||||
agg_count += count_agg([unit[1] for unit in sql['orderBy'][1] if unit[1]] +
|
||||
[unit[2] for unit in sql['orderBy'][1] if unit[2]])
|
||||
agg_count += count_agg(sql['having'])
|
||||
if agg_count > 1:
|
||||
count += 1
|
||||
|
||||
# number of select columns
|
||||
if len(sql['select'][1]) > 1:
|
||||
count += 1
|
||||
|
||||
# number of where conditions
|
||||
if len(sql['where']) > 1:
|
||||
count += 1
|
||||
|
||||
# number of group by clauses
|
||||
if len(sql['groupBy']) > 1:
|
||||
count += 1
|
||||
|
||||
return count
|
||||
|
||||
|
||||
class Evaluator:
|
||||
"""A simple evaluator"""
|
||||
def __init__(self):
|
||||
self.partial_scores = None
|
||||
|
||||
def eval_hardness(self, sql):
|
||||
count_comp1_ = count_component1(sql)
|
||||
count_comp2_ = count_component2(sql)
|
||||
count_others_ = count_others(sql)
|
||||
|
||||
if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0:
|
||||
return "easy"
|
||||
elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or \
|
||||
(count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0):
|
||||
return "medium"
|
||||
elif (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) or \
|
||||
(2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) or \
|
||||
(count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1):
|
||||
return "hard"
|
||||
else:
|
||||
return "extra"
|
||||
|
||||
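# Example (a sketch): a query with a single WHERE condition, no aggregation and no nesting gives
# count_component1 == 1, count_others == 0, count_component2 == 0 -> "easy"; adding ORDER BY ... LIMIT
# and extra SELECT columns pushes it into the harder buckets.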
def eval_exact_match(self, pred, label):
|
||||
partial_scores = self.eval_partial_match(pred, label)
|
||||
self.partial_scores = partial_scores
|
||||
|
||||
for _, score in partial_scores.items():
|
||||
if score['f1'] != 1:
|
||||
return 0
|
||||
if len(label['from']['table_units']) > 0:
|
||||
label_tables = sorted(label['from']['table_units'])
|
||||
pred_tables = sorted(pred['from']['table_units'])
|
||||
if label_tables != pred_tables:
|
||||
return False
|
||||
# if len(label['from']['conds']) > 0:
|
||||
# label_joins = sorted(label['from']['conds'], key=lambda x: str(x))
|
||||
# pred_joins = sorted(pred['from']['conds'], key=lambda x: str(x))
|
||||
# if label_joins != pred_joins:
|
||||
# return False
|
||||
return 1
|
||||
|
||||
def eval_partial_match(self, pred, label):
|
||||
res = {}
|
||||
|
||||
label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['select'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
|
||||
res['select(no AGG)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['where'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
|
||||
res['where(no OP)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt = eval_group(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['group(no Having)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
# res['group(no Having)'] = {'acc': 1, 'rec': 1, 'f1': 1, 'label_total': 1, 'pred_total': 1}
|
||||
|
||||
label_total, pred_total, cnt = eval_having(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['group'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
# res['group'] = {'acc': 1, 'rec': 1, 'f1': 1,'label_total':1,'pred_total':1}
|
||||
|
||||
label_total, pred_total, cnt = eval_order(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['order'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt = eval_and_or(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['and/or'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt = eval_IUEN(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['IUEN'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
label_total, pred_total, cnt = eval_keywords(pred, label)
|
||||
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
|
||||
res['keywords'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def isValidSQL(sql, db):
|
||||
conn = sqlite3.connect(db)
|
||||
cursor = conn.cursor()
|
||||
try:
|
||||
cursor.execute(sql, [])
|
||||
except Exception as e:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def print_scores(scores, etype):
|
||||
levels = ['easy', 'medium', 'hard', 'extra', 'all']
|
||||
partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
|
||||
'group', 'order', 'and/or', 'IUEN', 'keywords']
|
||||
|
||||
print("{:20} {:20} {:20} {:20} {:20} {:20}".format("", *levels))
|
||||
counts = [scores[level]['count'] for level in levels]
|
||||
print("{:20} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d}".format("count", *counts))
|
||||
|
||||
if etype in ["all", "exec"]:
|
||||
print('===================== EXECUTION ACCURACY =====================')
|
||||
this_scores = [scores[level]['exec'] for level in levels]
|
||||
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("execution", *this_scores))
|
||||
|
||||
if etype in ["all", "match"]:
|
||||
print('\n====================== EXACT MATCHING ACCURACY =====================')
|
||||
exact_scores = [scores[level]['exact'] for level in levels]
|
||||
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("exact match", *exact_scores))
|
||||
print('\n---------------------PARTIAL MATCHING ACCURACY----------------------')
|
||||
for type_ in partial_types:
|
||||
this_scores = [scores[level]['partial'][type_]['acc'] for level in levels]
|
||||
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
|
||||
|
||||
print('---------------------- PARTIAL MATCHING RECALL ----------------------')
|
||||
for type_ in partial_types:
|
||||
this_scores = [scores[level]['partial'][type_]['rec'] for level in levels]
|
||||
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
|
||||
|
||||
print('---------------------- PARTIAL MATCHING F1 --------------------------')
|
||||
for type_ in partial_types:
|
||||
this_scores = [scores[level]['partial'][type_]['f1'] for level in levels]
|
||||
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
|
||||
|
||||
|
||||
def evaluate(gold, predict, db_dir, etype, kmaps):
|
||||
with open(gold, encoding='utf-8') as f:
|
||||
glist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
|
||||
|
||||
with open(predict, encoding='utf-8') as f:
|
||||
plist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
|
||||
# plist = [("select max(Share),min(Share) from performance where Type != 'terminal'", "orchestra")]
|
||||
# glist = [("SELECT max(SHARE) , min(SHARE) FROM performance WHERE TYPE != 'Live final'", "orchestra")]
|
||||
evaluator = Evaluator()
|
||||
|
||||
levels = ['easy', 'medium', 'hard', 'extra', 'all']
|
||||
partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
|
||||
'group', 'order', 'and/or', 'IUEN', 'keywords']
|
||||
entries = []
|
||||
scores = {}
|
||||
|
||||
for level in levels:
|
||||
scores[level] = {'count': 0, 'partial': {}, 'exact': 0.}
|
||||
scores[level]['exec'] = 0
|
||||
for type_ in partial_types:
|
||||
scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0.,'acc_count':0,'rec_count':0}
|
||||
|
||||
eval_err_num = 0
|
||||
for p, g in zip(plist, glist):
|
||||
p_str = p[0]
|
||||
g_str, db = g
|
||||
db_name = db
|
||||
db = os.path.join(db_dir, db, db + ".sqlite")
|
||||
schema = Schema(get_schema(db))
|
||||
g_sql = get_sql(schema, g_str)
|
||||
hardness = evaluator.eval_hardness(g_sql)
|
||||
scores[hardness]['count'] += 1
|
||||
scores['all']['count'] += 1
|
||||
|
||||
try:
|
||||
p_sql = get_sql(schema, p_str)
|
||||
except:
|
||||
# If p_sql cannot be parsed, fall back to an empty SQL structure so it can still be scored against the gold SQL
|
||||
p_sql = {
|
||||
"except": None,
|
||||
"from": {
|
||||
"conds": [],
|
||||
"table_units": []
|
||||
},
|
||||
"groupBy": [],
|
||||
"having": [],
|
||||
"intersect": None,
|
||||
"limit": None,
|
||||
"orderBy": [],
|
||||
"select": [
|
||||
False,
|
||||
[]
|
||||
],
|
||||
"union": None,
|
||||
"where": []
|
||||
}
|
||||
eval_err_num += 1
|
||||
print("eval_err_num:{}".format(eval_err_num))
|
||||
print(p_str)
|
||||
print()
|
||||
|
||||
# rebuild sql for value evaluation
|
||||
kmap = kmaps[db_name]
|
||||
g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)
|
||||
g_sql = rebuild_sql_val(g_sql)
|
||||
g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
|
||||
p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)
|
||||
p_sql = rebuild_sql_val(p_sql)
|
||||
p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)
|
||||
|
||||
# p_sql_copy = copy.deepcopy(p_sql)
|
||||
# g_sql_copy = copy.deepcopy(g_sql)
|
||||
# if not eval_exec_match(db, p_str, g_str, p_sql_copy, g_sql_copy) and evaluator.eval_exact_match(p_sql_copy, g_sql_copy):
|
||||
# a = 1
|
||||
|
||||
if etype in ["all", "exec"]:
|
||||
exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql)
|
||||
if exec_score:
|
||||
scores[hardness]['exec'] += 1
|
||||
scores['all']['exec'] += exec_score
|
||||
|
||||
if etype in ["all", "match"]:
|
||||
exact_score = evaluator.eval_exact_match(p_sql, g_sql)
|
||||
partial_scores = evaluator.partial_scores
|
||||
if exact_score == 0:
|
||||
print("{} pred: {}".format(hardness,p_str))
|
||||
print("{} gold: {}".format(hardness,g_str))
|
||||
print("")
|
||||
scores[hardness]['exact'] += exact_score
|
||||
scores['all']['exact'] += exact_score
|
||||
for type_ in partial_types:
|
||||
if partial_scores[type_]['pred_total'] > 0:
|
||||
scores[hardness]['partial'][type_]['acc'] += partial_scores[type_]['acc']
|
||||
scores[hardness]['partial'][type_]['acc_count'] += 1
|
||||
if partial_scores[type_]['label_total'] > 0:
|
||||
scores[hardness]['partial'][type_]['rec'] += partial_scores[type_]['rec']
|
||||
scores[hardness]['partial'][type_]['rec_count'] += 1
|
||||
scores[hardness]['partial'][type_]['f1'] += partial_scores[type_]['f1']
|
||||
if partial_scores[type_]['pred_total'] > 0:
|
||||
scores['all']['partial'][type_]['acc'] += partial_scores[type_]['acc']
|
||||
scores['all']['partial'][type_]['acc_count'] += 1
|
||||
if partial_scores[type_]['label_total'] > 0:
|
||||
scores['all']['partial'][type_]['rec'] += partial_scores[type_]['rec']
|
||||
scores['all']['partial'][type_]['rec_count'] += 1
|
||||
scores['all']['partial'][type_]['f1'] += partial_scores[type_]['f1']
|
||||
|
||||
entries.append({
|
||||
'predictSQL': p_str,
|
||||
'goldSQL': g_str,
|
||||
'hardness': hardness,
|
||||
'exact': exact_score,
|
||||
'partial': partial_scores
|
||||
})
|
||||
|
||||
for level in levels:
|
||||
if scores[level]['count'] == 0:
|
||||
continue
|
||||
if etype in ["all", "exec"]:
|
||||
scores[level]['exec'] /= scores[level]['count']
|
||||
|
||||
if etype in ["all", "match"]:
|
||||
scores[level]['exact'] /= scores[level]['count']
|
||||
for type_ in partial_types:
|
||||
if scores[level]['partial'][type_]['acc_count'] == 0:
|
||||
scores[level]['partial'][type_]['acc'] = 0
|
||||
else:
|
||||
scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \
|
||||
scores[level]['partial'][type_]['acc_count'] * 1.0
|
||||
if scores[level]['partial'][type_]['rec_count'] == 0:
|
||||
scores[level]['partial'][type_]['rec'] = 0
|
||||
else:
|
||||
scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \
|
||||
scores[level]['partial'][type_]['rec_count'] * 1.0
|
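||||
# if both partial accuracy and recall are zero for this component, F1 is defined as 1 (convention kept from the Spider evaluation script)
|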
||||
if scores[level]['partial'][type_]['acc'] == 0 and scores[level]['partial'][type_]['rec'] == 0:
|
||||
scores[level]['partial'][type_]['f1'] = 1
|
||||
else:
|
||||
scores[level]['partial'][type_]['f1'] = \
|
||||
2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (
|
||||
scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])
|
||||
|
||||
print_scores(scores, etype)
|
||||
|
||||
|
||||
def eval_exec_match(db, p_str, g_str, pred, gold):
|
||||
"""
|
||||
Return True if the executed result values of the prediction and gold match
|
||||
in the corresponding columns. Multiple col_unit pairs are currently not supported.
|
||||
"""
|
||||
conn = sqlite3.connect(db)
|
||||
cursor = conn.cursor()
|
||||
conn.text_factory = bytes
|
||||
try:
|
||||
cursor.execute(p_str)
|
||||
p_res = cursor.fetchall()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
cursor.execute(g_str)
|
||||
q_res = cursor.fetchall()
|
||||
|
||||
def res_map(res, val_units):
|
||||
rmap = {}
|
||||
for idx, val_unit in enumerate(val_units):
|
||||
key = tuple(val_unit[1]) if not val_unit[2] else (val_unit[0], tuple(val_unit[1]), tuple(val_unit[2]))
|
||||
rmap[key] = [r[idx] for r in res]
|
||||
return rmap
|
||||
|
||||
p_val_units = [unit[1] for unit in pred['select'][1]]
|
||||
q_val_units = [unit[1] for unit in gold['select'][1]]
|
||||
return res_map(p_res, p_val_units) == res_map(q_res, q_val_units)
|
||||
|
||||
|
||||
# Rebuild SQL functions for value evaluation
|
||||
def rebuild_cond_unit_val(cond_unit):
|
||||
if cond_unit is None or not DISABLE_VALUE:
|
||||
return cond_unit
|
||||
|
||||
not_op, op_id, val_unit, val1, val2 = cond_unit
|
||||
if type(val1) is not dict:
|
||||
val1 = None
|
||||
else:
|
||||
val1 = rebuild_sql_val(val1)
|
||||
if type(val2) is not dict:
|
||||
val2 = None
|
||||
else:
|
||||
val2 = rebuild_sql_val(val2)
|
||||
return not_op, op_id, val_unit, val1, val2
|
||||
|
||||
|
||||
def rebuild_condition_val(condition):
|
||||
if condition is None or not DISABLE_VALUE:
|
||||
return condition
|
||||
|
||||
res = []
|
||||
for idx, it in enumerate(condition):
|
||||
if idx % 2 == 0:
|
||||
res.append(rebuild_cond_unit_val(it))
|
||||
else:
|
||||
res.append(it)
|
||||
return res
|
||||
|
||||
|
||||
def rebuild_sql_val(sql):
|
||||
if sql is None or not DISABLE_VALUE:
|
||||
return sql
|
||||
|
||||
sql['from']['conds'] = rebuild_condition_val(sql['from']['conds'])
|
||||
sql['having'] = rebuild_condition_val(sql['having'])
|
||||
sql['where'] = rebuild_condition_val(sql['where'])
|
||||
sql['intersect'] = rebuild_sql_val(sql['intersect'])
|
||||
sql['except'] = rebuild_sql_val(sql['except'])
|
||||
sql['union'] = rebuild_sql_val(sql['union'])
|
||||
|
||||
return sql
|
||||
|
||||
|
||||
# Rebuild SQL functions for foreign key evaluation
|
||||
def build_valid_col_units(table_units, schema):
|
||||
col_ids = [table_unit[1] for table_unit in table_units if table_unit[0] == TABLE_TYPE['table_unit']]
|
||||
prefixs = [col_id[:-2] for col_id in col_ids]
|
||||
valid_col_units= []
|
||||
for value in schema.idMap.values():
|
||||
if '.' in value and value[:value.index('.')] in prefixs:
|
||||
valid_col_units.append(value)
|
||||
return valid_col_units
|
||||
|
||||
|
||||
def rebuild_col_unit_col(valid_col_units, col_unit, kmap):
|
||||
if col_unit is None:
|
||||
return col_unit
|
||||
|
||||
agg_id, col_id, distinct = col_unit
|
||||
if col_id in kmap and col_id in valid_col_units:
|
||||
col_id = kmap[col_id]
|
||||
if DISABLE_DISTINCT:
|
||||
distinct = None
|
||||
return agg_id, col_id, distinct
|
||||
|
||||
|
||||
def rebuild_val_unit_col(valid_col_units, val_unit, kmap):
|
||||
if val_unit is None:
|
||||
return val_unit
|
||||
|
||||
unit_op, col_unit1, col_unit2 = val_unit
|
||||
col_unit1 = rebuild_col_unit_col(valid_col_units, col_unit1, kmap)
|
||||
col_unit2 = rebuild_col_unit_col(valid_col_units, col_unit2, kmap)
|
||||
return unit_op, col_unit1, col_unit2
|
||||
|
||||
|
||||
def rebuild_table_unit_col(valid_col_units, table_unit, kmap):
|
||||
if table_unit is None:
|
||||
return table_unit
|
||||
|
||||
table_type, col_unit_or_sql = table_unit
|
||||
if isinstance(col_unit_or_sql, tuple):
|
||||
col_unit_or_sql = rebuild_col_unit_col(valid_col_units, col_unit_or_sql, kmap)
|
||||
return table_type, col_unit_or_sql
|
||||
|
||||
|
||||
def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap):
|
||||
if cond_unit is None:
|
||||
return cond_unit
|
||||
|
||||
not_op, op_id, val_unit, val1, val2 = cond_unit
|
||||
val_unit = rebuild_val_unit_col(valid_col_units, val_unit, kmap)
|
||||
return not_op, op_id, val_unit, val1, val2
|
||||
|
||||
|
||||
def rebuild_condition_col(valid_col_units, condition, kmap):
|
||||
for idx in range(len(condition)):
|
||||
if idx % 2 == 0:
|
||||
condition[idx] = rebuild_cond_unit_col(valid_col_units, condition[idx], kmap)
|
||||
return condition
|
||||
|
||||
|
||||
def rebuild_select_col(valid_col_units, sel, kmap):
|
||||
if sel is None:
|
||||
return sel
|
||||
distinct, _list = sel
|
||||
new_list = []
|
||||
for it in _list:
|
||||
agg_id, val_unit = it
|
||||
new_list.append((agg_id, rebuild_val_unit_col(valid_col_units, val_unit, kmap)))
|
||||
if DISABLE_DISTINCT:
|
||||
distinct = None
|
||||
return distinct, new_list
|
||||
|
||||
|
||||
def rebuild_from_col(valid_col_units, from_, kmap):
|
||||
if from_ is None:
|
||||
return from_
|
||||
|
||||
from_['table_units'] = [rebuild_table_unit_col(valid_col_units, table_unit, kmap) for table_unit in from_['table_units']]
|
||||
from_['conds'] = rebuild_condition_col(valid_col_units, from_['conds'], kmap)
|
||||
return from_
|
||||
|
||||
|
||||
def rebuild_group_by_col(valid_col_units, group_by, kmap):
|
||||
if group_by is None:
|
||||
return group_by
|
||||
|
||||
return [rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by]
|
||||
|
||||
|
||||
def rebuild_order_by_col(valid_col_units, order_by, kmap):
|
||||
if order_by is None or len(order_by) == 0:
|
||||
return order_by
|
||||
|
||||
direction, val_units = order_by
|
||||
new_val_units = [rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units]
|
||||
return direction, new_val_units
|
||||
|
||||
|
||||
def rebuild_sql_col(valid_col_units, sql, kmap):
|
||||
if sql is None:
|
||||
return sql
|
||||
|
||||
sql['select'] = rebuild_select_col(valid_col_units, sql['select'], kmap)
|
||||
sql['from'] = rebuild_from_col(valid_col_units, sql['from'], kmap)
|
||||
sql['where'] = rebuild_condition_col(valid_col_units, sql['where'], kmap)
|
||||
sql['groupBy'] = rebuild_group_by_col(valid_col_units, sql['groupBy'], kmap)
|
||||
sql['orderBy'] = rebuild_order_by_col(valid_col_units, sql['orderBy'], kmap)
|
||||
sql['having'] = rebuild_condition_col(valid_col_units, sql['having'], kmap)
|
||||
sql['intersect'] = rebuild_sql_col(valid_col_units, sql['intersect'], kmap)
|
||||
sql['except'] = rebuild_sql_col(valid_col_units, sql['except'], kmap)
|
||||
sql['union'] = rebuild_sql_col(valid_col_units, sql['union'], kmap)
|
||||
|
||||
return sql
|
||||
|
||||
|
||||
def build_foreign_key_map(entry):
|
||||
cols_orig = entry["column_names_original"]
|
||||
tables_orig = entry["table_names_original"]
|
||||
|
||||
# rebuild cols corresponding to idmap in Schema
|
||||
cols = []
|
||||
for col_orig in cols_orig:
|
||||
if col_orig[0] >= 0:
|
||||
t = tables_orig[col_orig[0]]
|
||||
c = col_orig[1]
|
||||
cols.append("__" + t.lower() + "." + c.lower() + "__")
|
||||
else:
|
||||
cols.append("__all__")
|
||||
|
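||||
# helper: return the key set that already contains k1 or k2; otherwise create, register, and return a new set
|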
||||
def keyset_in_list(k1, k2, k_list):
|
||||
for k_set in k_list:
|
||||
if k1 in k_set or k2 in k_set:
|
||||
return k_set
|
||||
new_k_set = set()
|
||||
k_list.append(new_k_set)
|
||||
return new_k_set
|
||||
|
||||
foreign_key_list = []
|
||||
foreign_keys = entry["foreign_keys"]
|
||||
for fkey in foreign_keys:
|
||||
key1, key2 = fkey
|
||||
key_set = keyset_in_list(key1, key2, foreign_key_list)
|
||||
key_set.add(key1)
|
||||
key_set.add(key2)
|
||||
|
||||
foreign_key_map = {}
|
||||
for key_set in foreign_key_list:
|
||||
sorted_list = sorted(list(key_set))
|
||||
midx = sorted_list[0]
|
||||
for idx in sorted_list:
|
||||
foreign_key_map[cols[idx]] = cols[midx]
|
||||
|
||||
return foreign_key_map
|
||||
|
||||
|
||||
def build_foreign_key_map_from_json(table):
|
||||
with open(table, encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
tables = {}
|
||||
for entry in data:
|
||||
tables[entry['db_id']] = build_foreign_key_map(entry)
|
||||
return tables
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--gold', dest='gold', type=str)
|
||||
parser.add_argument('--pred', dest='pred', type=str)
|
||||
parser.add_argument('--db', dest='db', type=str)
|
||||
parser.add_argument('--table', dest='table', type=str)
|
||||
parser.add_argument('--etype', dest='etype', type=str)
|
||||
args = parser.parse_args()
|
||||
|
||||
gold = args.gold
|
||||
pred = args.pred
|
||||
db_dir = args.db
|
||||
table = args.table
|
||||
etype = args.etype
|
||||
|
||||
assert etype in ["all", "exec", "match"], "Unknown evaluation method"
|
||||
|
||||
kmaps = build_foreign_key_map_from_json(table)
|
||||
|
||||
evaluate(gold, pred, db_dir, etype, kmaps)
|
|
@ -0,0 +1,89 @@
|
|||
import os
|
||||
import sqlite3
|
||||
|
||||
from semparse.worlds.evaluate import Evaluator, build_valid_col_units, rebuild_sql_val, rebuild_sql_col, \
|
||||
build_foreign_key_map_from_json
|
||||
from semparse.sql.process_sql import Schema, get_schema, get_sql
|
||||
|
||||
_schemas = {}
|
||||
kmaps = None
|
||||
|
||||
|
||||
def evaluate(gold, predict, db_name, db_dir, table, check_valid: bool=True, db_schema=None) -> bool:
|
||||
global kmaps
|
||||
|
||||
gold=gold.replace("t1 . ","").replace("t2 . ",'').replace("t3 . ",'').replace("t4 . ",'').replace(" as t1",'').replace(" as t2",'').replace(" as t3",'').replace(" as t4",'')
|
||||
predict=predict.replace("t1 . ","").replace("t2 . ",'').replace("t3 . ",'').replace("t4 . ",'').replace(" as t1",'').replace(" as t2",'').replace(" as t3",'').replace(" as t4",'')
|
||||
|
||||
# sgrammar = SpiderGrammar(
|
||||
# output_from=True,
|
||||
# use_table_pointer=True,
|
||||
# include_literals=True,
|
||||
# include_columns=True,
|
||||
# )
|
||||
# try:
|
||||
evaluator = Evaluator()
|
||||
|
||||
if kmaps is None:
|
||||
kmaps = build_foreign_key_map_from_json(table)
|
||||
|
||||
|
||||
if 'chase' in db_dir:
|
||||
schema = _schemas[db_name] = Schema(db_schema)
|
||||
elif db_name in _schemas:
|
||||
schema = _schemas[db_name]
|
||||
else:
|
||||
db = os.path.join(db_dir, db_name, db_name + ".sqlite")
|
||||
schema = _schemas[db_name] = Schema(get_schema(db))
|
||||
|
||||
g_sql = get_sql(schema, gold)
|
||||
# try:
|
||||
p_sql = get_sql(schema, predict)
|
||||
# except Exception as e:
|
||||
# print('evaluate_spider.py L39')
|
||||
# return False
|
||||
|
||||
# rebuild sql for value evaluation
|
||||
kmap = kmaps[db_name]
|
||||
g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)
|
||||
g_sql = rebuild_sql_val(g_sql)
|
||||
g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
|
||||
p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)
|
||||
p_sql = rebuild_sql_val(p_sql)
|
||||
p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)
|
||||
|
||||
exact_score = evaluator.eval_exact_match(p_sql, g_sql)
|
||||
|
||||
if not check_valid:
|
||||
return exact_score
|
||||
else:
|
||||
return exact_score and check_valid_sql(predict, db_name, db_dir)
|
||||
# except Exception as e:
|
||||
# return 0
|
||||
|
||||
|
||||
_conns = {}
|
||||
|
||||
|
||||
def check_valid_sql(sql, db_name, db_dir, return_error=False):
|
||||
return True  # NOTE: the validity check is short-circuited here; the SQLite-based checks below are currently unreachable
|
||||
db = os.path.join(db_dir, db_name, db_name + ".sqlite")
|
||||
|
||||
if db_name == 'wta_1':
|
||||
# TODO: seems like there is a problem with this dataset - slow response - add limit 1
|
||||
return True if not return_error else (True, None)
|
||||
|
||||
if db_name not in _conns:
|
||||
_conns[db_name] = sqlite3.connect(db)
|
||||
|
||||
# fixes an encoding bug
|
||||
_conns[db_name].text_factory = bytes
|
||||
|
||||
conn = _conns[db_name]
|
||||
cursor = conn.cursor()
|
||||
try:
|
||||
cursor.execute(sql)
|
||||
cursor.fetchall()
|
||||
return True if not return_error else (True, None)
|
||||
except Exception as e:
|
||||
return False if not return_error else (False, e.args[0])
|
|
@ -0,0 +1,471 @@
|
|||
"""
|
||||
Based on https://github.com/ryanzhumich/editsql/blob/master/preprocess.py
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import stanza
|
||||
import sqlparse
|
||||
from tqdm import tqdm
|
||||
|
||||
from semparse.contexts.spider_db_context import SpiderDBContext
|
||||
from semparse.sql.spider_utils import disambiguate_items, fix_number_value
|
||||
from semparse.sql.spider_utils import read_dataset_schema
|
||||
|
||||
keyword = ['select', 'distinct', 'from', 'join', 'on', 'where', 'group', 'by', 'order', 'asc', 'desc', 'limit',
|
||||
'having',
|
||||
'and', 'not', 'or', 'like', 'between', 'in',
|
||||
'sum', 'count', 'max', 'min', 'avg',
|
||||
'(', ')', ',', '>', '<', '=', '==', '>=', '!=', '<=',
|
||||
'union', 'except', 'intersect',
|
||||
'\'value\'']
|
||||
|
||||
|
||||
stanza.download('en')
|
||||
stanza_model = stanza.Pipeline(lang='en', processors='tokenize,pos,lemma')
|
||||
# stanza_model=None
|
||||
|
||||
def write_interaction(interaction_list, split, output_dir):
|
||||
interaction = []
|
||||
for db_id in interaction_list:
|
||||
interaction += interaction_list[db_id]
|
||||
|
||||
json_split = os.path.join(output_dir, split + '.json')
|
||||
with open(json_split, 'w', encoding="utf-8") as outfile:
|
||||
json.dump(interaction, outfile, indent=2, ensure_ascii=False)
|
||||
return
|
||||
|
||||
|
||||
def read_database_schema(table_path):
|
||||
schema_tokens = {}
|
||||
column_names = {}
|
||||
database_schemas_dict = {}
|
||||
|
||||
with open(table_path, 'r', encoding='UTF-8') as f:
|
||||
database_schemas = json.load(f)
|
||||
|
||||
def get_schema_tokens(table_schema):
|
||||
column_names_surface_form = []
|
||||
column_names = []
|
||||
column_names_original = table_schema['column_names_original']
|
||||
table_names = table_schema['table_names']
|
||||
table_names_original = table_schema['table_names_original']
|
||||
for i, (table_id, column_name) in enumerate(column_names_original):
|
||||
if table_id >= 0:
|
||||
table_name = table_names_original[table_id]
|
||||
column_name_surface_form = '{}.{}'.format(table_name, column_name)
|
||||
else:
|
||||
# this is just *
|
||||
column_name_surface_form = column_name
|
||||
column_names_surface_form.append(column_name_surface_form.lower())
|
||||
column_names.append(column_name.lower())
|
||||
|
||||
# also add table_name.*
|
||||
for table_name in table_names_original:
|
||||
column_names_surface_form.append('{}.*'.format(table_name.lower()))
|
||||
|
||||
return column_names_surface_form, column_names
|
||||
|
||||
for table_schema in database_schemas:
|
||||
database_id = table_schema['db_id']
|
||||
if 'column_names_original' not in table_schema:
|
||||
table_schema["column_names_original"] = table_schema["column_names"]
|
||||
table_schema["table_names_original"] = table_schema["table_names"]
|
||||
|
||||
table_schema['table_names_original'] = [t.lower() for t in table_schema['table_names_original']]
|
||||
table_schema['foreign_keys_col'] = [i[0] for i in table_schema['foreign_keys']]
|
||||
|
||||
structure_schema = []
|
||||
for t in table_schema['foreign_keys']:
|
||||
primary_col, foreign_col = t
|
||||
primary_col = table_schema['column_names_original'][primary_col]
|
||||
primary_col_tab = table_schema['table_names_original'][primary_col[0]].lower()
|
||||
foreign_col = table_schema['column_names_original'][foreign_col]
|
||||
foreign_col_tab = table_schema['table_names_original'][foreign_col[0]].lower()
|
||||
structure_schema.append(f"( {primary_col_tab} , {foreign_col_tab} )")
|
||||
structure_schema = list(sorted(set(structure_schema)))
|
||||
|
||||
table_schema['permutations'] = [structure_schema]
|
||||
database_schemas_dict[database_id] = table_schema
|
||||
schema_tokens[database_id], column_names[database_id] = get_schema_tokens(table_schema)
|
||||
|
||||
if 'column_rewrite_names' in table_schema:
|
||||
for i in range(len(table_schema['column_rewrite_names'])):
|
||||
table_schema['column_rewrite_names'][i] = [table_schema['column_names'][i][-1]] + \
|
||||
table_schema['column_rewrite_names'][i][-1]
|
||||
table_schema['column_rewrite_names'] = [list(set(map(lambda x: x.lower().replace(' ', ''), i))) for i in
|
||||
table_schema['column_rewrite_names']]
|
||||
|
||||
for i in range(len(table_schema['table_rewrite_names'])):
|
||||
table_schema['table_rewrite_names'][i] = [table_schema['table_names'][i]] + \
|
||||
table_schema['table_rewrite_names'][i]
|
||||
table_schema['table_rewrite_names'] = [list(set(map(lambda x: x.lower().replace(' ', ''), i))) for i in
|
||||
table_schema['table_rewrite_names']]
|
||||
|
||||
return schema_tokens, column_names, database_schemas_dict
|
||||
|
||||
|
||||
def remove_from_with_join(format_sql_2):
|
||||
used_tables_list = []
|
||||
format_sql_3 = []
|
||||
table_to_name = {}
|
||||
table_list = []
|
||||
old_table_to_name = {}
|
||||
old_table_list = []
|
||||
for sub_sql in format_sql_2.split('\n'):
|
||||
if 'select ' in sub_sql:
|
||||
# only replace alias: t1 -> table_name, t2 -> table_name, etc...
|
||||
if len(table_list) > 0:
|
||||
for i in range(len(format_sql_3)):
|
||||
for table, name in table_to_name.items():
|
||||
format_sql_3[i] = format_sql_3[i].replace(table, name)
|
||||
|
||||
old_table_list = table_list
|
||||
old_table_to_name = table_to_name
|
||||
table_to_name = {}
|
||||
table_list = []
|
||||
format_sql_3.append(sub_sql)
|
||||
elif sub_sql.startswith('from'):
|
||||
new_sub_sql = None
|
||||
sub_sql_tokens = sub_sql.split()
|
||||
for t_i, t in enumerate(sub_sql_tokens):
|
||||
if t == 'as':
|
||||
table_to_name[sub_sql_tokens[t_i + 1]] = sub_sql_tokens[t_i - 1]
|
||||
table_list.append(sub_sql_tokens[t_i - 1])
|
||||
elif t == ')' and new_sub_sql is None:
|
||||
# new_sub_sql keeps some trailing parts after ')'
|
||||
new_sub_sql = ' '.join(sub_sql_tokens[t_i:])
|
||||
if len(table_list) > 0:
|
||||
# if it's a from clause with join
|
||||
if new_sub_sql is not None:
|
||||
format_sql_3.append(new_sub_sql)
|
||||
|
||||
used_tables_list.append(table_list)
|
||||
else:
|
||||
# if it's a from clause without join
|
||||
table_list = old_table_list
|
||||
table_to_name = old_table_to_name
|
||||
assert 'join' not in sub_sql
|
||||
if new_sub_sql is not None:
|
||||
sub_sub_sql = sub_sql[:-len(new_sub_sql)].strip()
|
||||
assert len(sub_sub_sql.split()) == 2
|
||||
used_tables_list.append([sub_sub_sql.split()[1]])
|
||||
format_sql_3.append(sub_sub_sql)
|
||||
format_sql_3.append(new_sub_sql)
|
||||
elif 'join' not in sub_sql:
|
||||
assert len(sub_sql.split()) == 2 or len(sub_sql.split()) == 1
|
||||
if len(sub_sql.split()) == 2:
|
||||
used_tables_list.append([sub_sql.split()[1]])
|
||||
|
||||
format_sql_3.append(sub_sql)
|
||||
else:
|
||||
print('bad from clause in remove_from_with_join')
|
||||
exit()
|
||||
else:
|
||||
format_sql_3.append(sub_sql)
|
||||
|
||||
if len(table_list) > 0:
|
||||
for i in range(len(format_sql_3)):
|
||||
for table, name in table_to_name.items():
|
||||
format_sql_3[i] = format_sql_3[i].replace(table, name)
|
||||
|
||||
used_tables = []
|
||||
for t in used_tables_list:
|
||||
for tt in t:
|
||||
used_tables.append(tt)
|
||||
used_tables = list(set(used_tables))
|
||||
|
||||
return format_sql_3, used_tables, used_tables_list
|
||||
|
||||
|
||||
def remove_from_without_join(format_sql_3, column_names, schema_tokens):
|
||||
format_sql_4 = []
|
||||
table_name = None
|
||||
for sub_sql in format_sql_3.split('\n'):
|
||||
if 'select ' in sub_sql:
|
||||
if table_name:
|
||||
for i in range(len(format_sql_4)):
|
||||
tokens = format_sql_4[i].split()
|
||||
for ii, token in enumerate(tokens):
|
||||
if token in column_names and tokens[ii - 1] != '.':
|
||||
if (ii + 1 < len(tokens) and tokens[ii + 1] != '.' and tokens[
|
||||
ii + 1] != '(') or ii + 1 == len(tokens):
|
||||
if '{}.{}'.format(table_name, token) in schema_tokens:
|
||||
tokens[ii] = '{} . {}'.format(table_name, token)
|
||||
format_sql_4[i] = ' '.join(tokens)
|
||||
|
||||
format_sql_4.append(sub_sql)
|
||||
elif sub_sql.startswith('from'):
|
||||
sub_sql_tokens = sub_sql.split()
|
||||
if len(sub_sql_tokens) == 1:
|
||||
table_name = None
|
||||
elif len(sub_sql_tokens) == 2:
|
||||
table_name = sub_sql_tokens[1]
|
||||
else:
|
||||
print('bad from clause in remove_from_without_join')
|
||||
print(format_sql_3)
|
||||
exit()
|
||||
else:
|
||||
format_sql_4.append(sub_sql)
|
||||
|
||||
if table_name:
|
||||
for i in range(len(format_sql_4)):
|
||||
tokens = format_sql_4[i].split()
|
||||
for ii, token in enumerate(tokens):
|
||||
if token in column_names and tokens[ii - 1] != '.':
|
||||
if (ii + 1 < len(tokens) and tokens[ii + 1] != '.' and tokens[ii + 1] != '(') or ii + 1 == len(
|
||||
tokens):
|
||||
if '{}.{}'.format(table_name, token) in schema_tokens:
|
||||
tokens[ii] = '{} . {}'.format(table_name, token)
|
||||
format_sql_4[i] = ' '.join(tokens)
|
||||
|
||||
return format_sql_4
|
||||
|
||||
|
||||
def add_table_name(format_sql_3, used_tables, column_names, schema_tokens):
|
||||
# If just one table used, easy case, replace all column_name -> table_name.column_name
|
||||
if len(used_tables) == 1:
|
||||
table_name = used_tables[0]
|
||||
format_sql_4 = []
|
||||
for sub_sql in format_sql_3.split('\n'):
|
||||
if sub_sql.startswith('from'):
|
||||
format_sql_4.append(sub_sql)
|
||||
continue
|
||||
|
||||
tokens = sub_sql.split()
|
||||
for ii, token in enumerate(tokens):
|
||||
if token in column_names and tokens[ii - 1] != '.':
|
||||
if (ii + 1 < len(tokens) and tokens[ii + 1] != '.' and tokens[ii + 1] != '(') or ii + 1 == len(
|
||||
tokens):
|
||||
if '{}.{}'.format(table_name, token) in schema_tokens:
|
||||
tokens[ii] = '{} . {}'.format(table_name, token)
|
||||
format_sql_4.append(' '.join(tokens))
|
||||
return format_sql_4
|
||||
|
||||
def get_table_name_for(token):
|
||||
table_names = []
|
||||
for table_name in used_tables:
|
||||
if '{}.{}'.format(table_name, token) in schema_tokens:
|
||||
table_names.append(table_name)
|
||||
if len(table_names) == 0:
|
||||
return 'table'
|
||||
if len(table_names) > 1:
|
||||
return None
|
||||
else:
|
||||
return table_names[0]
|
||||
|
||||
format_sql_4 = []
|
||||
for sub_sql in format_sql_3.split('\n'):
|
||||
if sub_sql.startswith('from'):
|
||||
format_sql_4.append(sub_sql)
|
||||
continue
|
||||
|
||||
tokens = sub_sql.split()
|
||||
for ii, token in enumerate(tokens):
|
||||
# skip *
|
||||
if token == '*':
|
||||
continue
|
||||
if token in column_names and tokens[ii - 1] != '.':
|
||||
if (ii + 1 < len(tokens) and tokens[ii + 1] != '.' and tokens[ii + 1] != '(') or ii + 1 == len(tokens):
|
||||
table_name = get_table_name_for(token)
|
||||
if table_name:
|
||||
tokens[ii] = '{} . {}'.format(table_name, token)
|
||||
format_sql_4.append(' '.join(tokens))
|
||||
|
||||
return format_sql_4
|
||||
|
||||
|
||||
def normalize_space(format_sql):
|
||||
format_sql_1 = [' '.join(
|
||||
sub_sql.strip().replace(',', ' , ').replace('.', ' . ').replace('(', ' ( ').replace(')', ' ) ').split()) for
|
||||
sub_sql in format_sql.split('\n')]
|
||||
format_sql_1 = '\n'.join(format_sql_1)
|
||||
|
||||
format_sql_2 = format_sql_1.replace('\njoin', ' join').replace(',\n', ', ').replace(' where', '\nwhere').replace(
|
||||
' intersect', '\nintersect').replace('\nand', ' and').replace('order by t2 .\nstart desc',
|
||||
'order by t2 . start desc')
|
||||
|
||||
format_sql_2 = format_sql_2.replace('select\noperator', 'select operator').replace('select\nconstructor',
|
||||
'select constructor').replace(
|
||||
'select\nstart', 'select start').replace('select\ndrop', 'select drop').replace('select\nwork',
|
||||
'select work').replace(
|
||||
'select\ngroup', 'select group').replace('select\nwhere_built', 'select where_built').replace('select\norder',
|
||||
'select order').replace(
|
||||
'from\noperator', 'from operator').replace('from\nforward', 'from forward').replace('from\nfor',
|
||||
'from for').replace(
|
||||
'from\ndrop', 'from drop').replace('from\norder', 'from order').replace('.\nstart', '. start').replace(
|
||||
'.\norder', '. order').replace('.\noperator', '. operator').replace('.\nsets', '. sets').replace(
|
||||
'.\nwhere_built', '. where_built').replace('.\nwork', '. work').replace('.\nconstructor',
|
||||
'. constructor').replace('.\ngroup',
|
||||
'. group').replace(
|
||||
'.\nfor', '. for').replace('.\ndrop', '. drop').replace('.\nwhere', '. where')
|
||||
|
||||
format_sql_2 = format_sql_2.replace('group by', 'group_by').replace('order by', 'order_by').replace('! =',
|
||||
'!=').replace(
|
||||
'limit value', 'limit_value')
|
||||
return format_sql_2
|
||||
|
||||
|
||||
def normalize_final_sql(format_sql_5):
|
||||
format_sql_final = format_sql_5.replace('\n', ' ').replace(' . ', '.').replace('group by', 'group_by').replace(
|
||||
'order by', 'order_by').replace('! =', '!=').replace('limit value', 'limit_value')
|
||||
|
||||
# normalize two bad sqls
|
||||
if 't1' in format_sql_final or 't2' in format_sql_final or 't3' in format_sql_final or 't4' in format_sql_final:
|
||||
format_sql_final = format_sql_final.replace('t2.dormid', 'dorm.dormid')
|
||||
|
||||
# This is the failure case of remove_from_without_join()
|
||||
format_sql_final = format_sql_final.replace(
|
||||
'select city.city_name where city.state_name in ( select state.state_name where state.state_name in ( select river.traverse where river.river_name = value ) and state.area = ( select min ( state.area ) where state.state_name in ( select river.traverse where river.river_name = value ) ) ) order_by population desc limit_value',
|
||||
'select city.city_name where city.state_name in ( select state.state_name where state.state_name in ( select river.traverse where river.river_name = value ) and state.area = ( select min ( state.area ) where state.state_name in ( select river.traverse where river.river_name = value ) ) ) order_by city.population desc limit_value')
|
||||
|
||||
return format_sql_final
|
||||
|
||||
|
||||
def normalize_original_sql(sql):
|
||||
sql = [i.lower() for i in sql]
|
||||
sql = ' '.join(sql).strip(';').replace("``", "'").replace("\"", "'").replace("''", "'")
|
||||
sql = sql.replace(')from', ') from')
|
||||
sql = sql.replace('(', ' ( ')
|
||||
sql = sql.replace(')', ' ) ')
|
||||
sql = re.sub(r'\s+', ' ', sql)
|
||||
|
||||
sql = re.sub(r"(')(\S+)", r"\1 \2", sql)
|
||||
sql = re.sub(r"(\S+)(')", r"\1 \2", sql).split(' ')
|
||||
|
||||
sql = ' '.join(sql)
|
||||
sql = sql.strip(' ;').replace('> =', '>=').replace('! =', '!=')
|
||||
return sql.split(' ')
|
||||
|
||||
|
||||
def parse_sql(sql_string, db_id, column_names, schema_tokens, schema):
|
||||
format_sql = sqlparse.format(sql_string, reindent=True)
|
||||
format_sql_2 = normalize_space(format_sql)
|
||||
|
||||
format_sql_3, used_tables, used_tables_list = remove_from_with_join(format_sql_2)
|
||||
|
||||
format_sql_3 = '\n'.join(format_sql_3)
|
||||
format_sql_4 = add_table_name(format_sql_3, used_tables, column_names, schema_tokens)
|
||||
|
||||
format_sql_4 = '\n'.join(format_sql_4)
|
||||
format_sql_5 = remove_from_without_join(format_sql_4, column_names, schema_tokens)
|
||||
|
||||
format_sql_5 = '\n'.join(format_sql_5)
|
||||
format_sql_final = normalize_final_sql(format_sql_5)
|
||||
|
||||
return format_sql_final
|
||||
|
||||
|
||||
def read_spider_split(dataset_path, table_path, database_path):
|
||||
with open(dataset_path) as f:
|
||||
split_data = json.load(f)
|
||||
print('read_spider_split', dataset_path, len(split_data))
|
||||
|
||||
schemas = read_dataset_schema(table_path, stanza_model)
|
||||
|
||||
interaction_list = {}
|
||||
for i, ex in enumerate(tqdm(split_data)):
|
||||
db_id = ex['db_id']
|
||||
|
||||
ex['query_toks_no_value'] = normalize_original_sql(ex['query_toks_no_value'])
|
||||
turn_sql = ' '.join(ex['query_toks_no_value'])
|
||||
turn_sql = turn_sql.replace('select count ( * ) from follows group by value',
|
||||
'select count ( * ) from follows group by f1')
|
||||
ex['query_toks_no_value'] = turn_sql.split(' ')
|
||||
|
||||
ex = fix_number_value(ex)
|
||||
try:
|
||||
ex['query_toks_no_value'] = disambiguate_items(db_id, ex['query_toks_no_value'],
|
||||
tables_file=table_path, allow_aliases=False)
|
||||
except Exception:
|
||||
print(ex['query_toks'])
|
||||
continue
|
||||
|
||||
final_sql_parse = ' '.join(ex['query_toks_no_value'])
|
||||
final_utterance = ' '.join(ex['question_toks']).lower()
|
||||
|
||||
if stanza_model is not None:
|
||||
lemma_utterance_stanza = stanza_model(final_utterance)
|
||||
lemma_utterance = [word.lemma for sent in lemma_utterance_stanza.sentences for word in sent.words]
|
||||
original_utterance = final_utterance
|
||||
else:
|
||||
original_utterance = lemma_utterance = final_utterance.split(' ')
|
||||
|
||||
# using db content
|
||||
db_context = SpiderDBContext(db_id,
|
||||
lemma_utterance,
|
||||
tables_file=table_path,
|
||||
dataset_path=database_path,
|
||||
stanza_model=stanza_model,
|
||||
schemas=schemas,
|
||||
original_utterance=original_utterance)
|
||||
|
||||
value_match, value_alignment, exact_match, partial_match = db_context.get_db_knowledge_graph(db_id)
|
||||
|
||||
if value_match != []:
|
||||
print(value_match, value_alignment)
|
||||
|
||||
if db_id not in interaction_list:
|
||||
interaction_list[db_id] = []
|
||||
|
||||
interaction = {}
|
||||
interaction['id'] = i
|
||||
interaction['database_id'] = db_id
|
||||
interaction['interaction'] = [{'utterance': final_utterance,
|
||||
'db_id': db_id,
|
||||
'query': ex['query'],
|
||||
'question': ex['question'],
|
||||
'sql': final_sql_parse,
|
||||
'value_match': value_match,
|
||||
'value_alignment': value_alignment,
|
||||
'exact_match': exact_match,
|
||||
'partial_match': partial_match,
|
||||
}]
|
||||
interaction_list[db_id].append(interaction)
|
||||
|
||||
return interaction_list
|
||||
|
||||
|
||||
|
||||
def preprocess_dataset(dataset, dataset_dir, output_dir, table_path, database_path):
|
||||
# for session in ['train', 'dev']:
|
||||
for session in ['dev']:
|
||||
dataset_path = os.path.join(dataset_dir, f'{session}.json')
|
||||
interaction_list = read_spider_split(dataset_path, table_path, database_path)
|
||||
write_interaction(interaction_list, session, output_dir)
|
||||
|
||||
return interaction_list
|
||||
|
||||
|
||||
def preprocess(dataset, dataset_dir, table_path, database_path, output_dir):
|
||||
# directory
|
||||
if not os.path.exists(output_dir):
|
||||
os.mkdir(output_dir)
|
||||
|
||||
# read schema
|
||||
print('Reading spider database schema file')
|
||||
schema_tokens, column_names, database_schemas = read_database_schema(table_path)
|
||||
print('total number of schema_tokens / databases:', len(schema_tokens))
|
||||
output_table_path = os.path.join(output_dir, 'tables.json')
|
||||
with open(output_table_path, 'w') as outfile:
|
||||
json.dump([v for k, v in database_schemas.items()], outfile, indent=4)
|
||||
|
||||
# process (SQL, Query) pair in train/dev
|
||||
preprocess_dataset(dataset, dataset_dir, output_dir, table_path, database_path)
|
||||
|
||||
return
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--dataset", choices=('spider', 'sparc', 'cosql'), default='spider')
|
||||
args = parser.parse_args()
|
||||
|
||||
dataset = args.dataset
|
||||
dataset_dir = f'./data/{dataset}/'
|
||||
table_path = f'./data/{dataset}/tables.json'
|
||||
database_path = f'./data/{dataset}/database'
|
||||
output_dir = f'./data/{dataset}_schema_linking_tag'
|
||||
|
||||
preprocess(dataset, dataset_dir, table_path, database_path, output_dir)
|
|
@ -0,0 +1,174 @@
|
|||
import os
|
||||
import json
|
||||
import argparse
|
||||
import subprocess
|
||||
from tqdm import tqdm
|
||||
from step1_schema_linking import read_database_schema
|
||||
from train import run_command
|
||||
|
||||
def running_process(generate_path):
|
||||
# cmd = f"python -m multiprocessing_bpe_encoder \
|
||||
# --encoder-json ./models/spider_sl/encoder.json \
|
||||
# --vocab-bpe ./models/spider_sl/vocab.bpe \
|
||||
# --inputs {generate_path}/train.src \
|
||||
# --outputs {generate_path}/train.bpe.src \
|
||||
# --workers 1 \
|
||||
# --keep-empty"
|
||||
# run_command(cmd)
|
||||
#
|
||||
# cmd = f"python -m multiprocessing_bpe_encoder \
|
||||
# --encoder-json ./models/spider_sl/encoder.json \
|
||||
# --vocab-bpe ./models/spider_sl/vocab.bpe \
|
||||
# --inputs {generate_path}/train.tgt \
|
||||
# --outputs {generate_path}/train.bpe.tgt \
|
||||
# --workers 1 \
|
||||
# --keep-empty"
|
||||
# run_command(cmd)
|
||||
|
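||||
# BPE-encode only the dev split here; the corresponding train-split commands are kept commented out above
|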
||||
cmd = f"python -m multiprocessing_bpe_encoder \
|
||||
--encoder-json ./models/spider_sl/encoder.json \
|
||||
--vocab-bpe ./models/spider_sl/vocab.bpe \
|
||||
--inputs {generate_path}/dev.src \
|
||||
--outputs {generate_path}/dev.bpe.src \
|
||||
--workers 1 \
|
||||
--keep-empty"
|
||||
run_command(cmd)
|
||||
|
||||
cmd = f"python -m multiprocessing_bpe_encoder \
|
||||
--encoder-json ./models/spider_sl/encoder.json \
|
||||
--vocab-bpe ./models/spider_sl/vocab.bpe \
|
||||
--inputs {generate_path}/dev.tgt \
|
||||
--outputs {generate_path}/dev.bpe.tgt \
|
||||
--workers 1 \
|
||||
--keep-empty"
|
||||
run_command(cmd)
|
||||
|
||||
# cmd = f'fairseq-preprocess --source-lang "src" --target-lang "tgt" \
|
||||
# --trainpref {generate_path}/train.bpe \
|
||||
# --validpref {generate_path}/dev.bpe \
|
||||
# --destdir {generate_path}/bin \
|
||||
# --workers 2 \
|
||||
# --srcdict ./models/spider_sl/dict.src.txt \
|
||||
# --tgtdict ./models/spider_sl/dict.tgt.txt '
|
||||
|
||||
cmd = f'fairseq-preprocess --source-lang "src" --target-lang "tgt" \
|
||||
--validpref {generate_path}/dev.bpe \
|
||||
--destdir {generate_path}/bin \
|
||||
--workers 2 \
|
||||
--srcdict ./models/spider_sl/dict.src.txt \
|
||||
--tgtdict ./models/spider_sl/dict.tgt.txt '
|
||||
|
||||
|
||||
subprocess.Popen(
|
||||
cmd, universal_newlines=True, shell=True,
|
||||
stdout=subprocess.PIPE, stderr=subprocess.PIPE).communicate()
|
||||
|
||||
def build_schema_linking_data(schema, question, item, turn_id, linking_type):
|
||||
source_sequence_list, target_sequence_list = [], []
|
||||
|
||||
# column description
|
||||
column_names = []
|
||||
for i, (t, c) in enumerate(zip(schema['column_types'], schema['column_names_original'])):
|
||||
if c[0] == -1:
|
||||
column_names.append("{0} {1}".format(t, c[1].lower()))
|
||||
else:
|
||||
column_with_alias = "{0}@{1}".format(schema['table_names_original'][c[0]].lower(), c[1].lower())
|
||||
tag_list = []
|
||||
if column_with_alias in item['interaction'][turn_id]['exact_match']:
|
||||
tag_list.append('EM')
|
||||
elif column_with_alias in item['interaction'][turn_id]['partial_match']:
|
||||
tag_list.append('PA')
|
||||
if column_with_alias in item['interaction'][turn_id]['value_match']:
|
||||
tag_list.append('VC')
|
||||
|
||||
# primary-foreign key
|
||||
if i in schema['primary_keys']:
|
||||
tag_list.append('RK')
|
||||
elif i in schema['foreign_keys_col']:
|
||||
tag_list.append('FO')
|
||||
|
||||
if tag_list != []:
|
||||
column_names.append("{0} {1} {2}".format(' '.join(tag_list), t, column_with_alias))
|
||||
else:
|
||||
column_names.append("{0} {1}".format(t, column_with_alias))
|
||||
|
||||
# table description
|
||||
table_names = []
|
||||
for t in schema['table_names_original']:
|
||||
tag_list = []
|
||||
if t in item['interaction'][turn_id]['exact_match']:
|
||||
tag_list.append('EM')
|
||||
elif t in item['interaction'][turn_id]['partial_match']:
|
||||
tag_list.append('PA')
|
||||
if '_nosl' in linking_type or 'not' in linking_type:
|
||||
tag_list = []
|
||||
|
||||
if tag_list != []:
|
||||
table_names.append("{0} {1}".format(' '.join(tag_list), t.lower()))
|
||||
else:
|
||||
table_names.append("{0}".format(t.lower()))
|
||||
|
||||
table_names = ' | '.join(table_names)
|
||||
column_names = ' | '.join(column_names)
|
||||
for structure_schema_list in schema['permutations'][:10]:
|
||||
structure_schema_str = ' | '.join(structure_schema_list)
|
||||
|
||||
source_sequence = f"<C> {column_names} | <T> {table_names} | <S> {structure_schema_str} | <Q> {question.lower()}"
|
||||
target_sequence = item['interaction'][turn_id]['sql'].lower()
|
||||
|
||||
source_sequence_list.append(source_sequence)
|
||||
target_sequence_list.append(target_sequence)
|
||||
|
||||
return source_sequence_list, target_sequence_list
|
||||
|
||||
def extract_input_and_output(example_lines, linking_type):
|
||||
inputs = []
|
||||
outputs = []
|
||||
database_schema_filename = './data/spider/tables.json'
|
||||
|
||||
|
||||
schema_tokens, column_names, database_schemas = read_database_schema(database_schema_filename)
|
||||
|
||||
for item in tqdm(example_lines):
|
||||
question = item['interaction'][0]['utterance']
|
||||
schema = database_schemas[item['database_id']]
|
||||
source_sequence, target_sequence = build_schema_linking_data(schema=schema,
|
||||
question=question,
|
||||
item=item,
|
||||
turn_id=0,
|
||||
linking_type=linking_type)
|
||||
outputs.extend(target_sequence)
|
||||
inputs.extend(source_sequence)
|
||||
|
||||
|
||||
assert len(inputs) == len(outputs)
|
||||
return inputs, outputs
|
||||
|
||||
|
||||
def read_dataflow_dataset(file_path, out_folder, session, linking_type):
|
||||
train_out_path = os.path.join(out_folder, session)
|
||||
train_src_writer = open(train_out_path + ".src", "w", encoding="utf8")
|
||||
train_tgt_writer = open(train_out_path + ".tgt", "w", encoding="utf8")
|
||||
|
||||
with open(file_path, "r", encoding='utf-8') as data_file:
|
||||
lines = json.load(data_file)
|
||||
data_input, data_output = extract_input_and_output(lines, linking_type)
|
||||
train_src_writer.write("\n".join(data_input))
|
||||
train_tgt_writer.write("\n".join(data_output))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--sl_dataset_path", default='./data/spider_schema_linking_tag')
|
||||
parser.add_argument("--output_path", default='./dataset_post/spider_sl')
|
||||
parser.add_argument("--linking_type", default='default')
|
||||
args = parser.parse_args()
|
||||
|
||||
# for session in ["train", "dev"]:
|
||||
for session in ["dev"]:
|
||||
file_path = os.path.join(args.sl_dataset_path, "{}.json".format(session))
|
||||
out_folder = args.output_path
|
||||
if not os.path.exists(out_folder):
|
||||
os.makedirs(out_folder)
|
||||
read_dataflow_dataset(file_path, out_folder, session, args.linking_type)
|
||||
running_process(args.output_path)
|
|
@ -0,0 +1,336 @@
|
|||
import argparse
|
||||
import json
|
||||
import re
|
||||
import subprocess
|
||||
from collections import defaultdict
|
||||
from re import RegexFlag
|
||||
import networkx as nx
|
||||
import torch
|
||||
|
||||
from genre.fairseq_model import GENRE, mGENRE
|
||||
from genre.entity_linking import get_end_to_end_prefix_allowed_tokens_fn_fairseq as get_prefix_allowed_tokens_fn
|
||||
from genre.trie import Trie
|
||||
from semparse.sql.spider import load_original_schemas, load_tables
|
||||
from semparse.worlds.evaluate_spider import evaluate as evaluate_sql
|
||||
from step1_schema_linking import read_database_schema
|
||||
|
||||
database_dir='./data/spider/database'
|
||||
database_schema_filename = './data/spider/tables.json'
|
||||
schema_tokens, column_names, database_schemas = read_database_schema(database_schema_filename)
|
||||
|
||||
with open(f'./data/spider/dev.json', 'r', encoding='utf-8') as f:
|
||||
item = json.load(f)
|
||||
sql_to_db = []
|
||||
for i in item:
|
||||
sql_to_db.append(i['db_id'])
|
||||
|
||||
|
||||
def post_processing_sql(p_sql, foreign_key_maps, schemas, o_schemas):
|
||||
foreign_key = {}
|
||||
for k, v in foreign_key_maps.items():
|
||||
if k == v:
|
||||
continue
|
||||
key = ' '.join(sorted([k.split('.')[0].strip('_'), v.split('.')[0].strip('_')]))
|
||||
foreign_key[key] = (k.strip('_').replace('.', '@'), v.strip('_').replace('.', '@'))
|
||||
|
||||
primary_key = {}
|
||||
for t in o_schemas.tables:
|
||||
table = t.orig_name.lower()
|
||||
if len(t.primary_keys) == 0:
|
||||
continue
|
||||
column = t.primary_keys[0].orig_name.lower()
|
||||
primary_key[table] = f'{table}@{column}'
|
||||
|
||||
p_sql = re.sub(r'(=)(\S+)', r'\1 \2', p_sql)
|
||||
p_sql = p_sql.split()
|
||||
|
||||
columns = ['*']
|
||||
tables = []
|
||||
for table, column_list in schemas.schema.items():
|
||||
for column in column_list:
|
||||
columns.append(f"{table}@{column}")
|
||||
tables.append(table)
|
||||
|
||||
# infer table from mentioned column
|
||||
all_from_table_ids = set()
|
||||
from_idx = where_idx = group_idx = order_idx = -1
|
||||
for idx, token in enumerate(p_sql):
|
||||
if '@' in token and token in columns:
|
||||
all_from_table_ids.add(schemas.idMap[token.split('@')[0]])
|
||||
if token == 'from' and from_idx == -1:
|
||||
from_idx = idx
|
||||
if token == 'where' and where_idx == -1:
|
||||
where_idx = idx
|
||||
if token == 'group' and group_idx == -1:
|
||||
group_idx = idx
|
||||
if token == 'order' and order_idx == -1:
|
||||
order_idx = idx
|
||||
|
||||
#don't process nested SQL (more than one select)
|
||||
if len(re.findall('select', ' '.join(p_sql))) > 1 or len(all_from_table_ids) == 0:
|
||||
return ' '.join(p_sql)
|
||||
|
||||
covered_tables = set()
|
||||
candidate_table_ids = sorted(all_from_table_ids)
|
||||
start_table_id = candidate_table_ids[0]
|
||||
conds = set()
|
||||
all_conds = []
|
||||
|
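||||
# link every other mentioned table to the first one via shortest paths in the foreign-key graph, collecting join conditions along each path
|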
||||
for table_id in candidate_table_ids[1:]:
|
||||
if table_id in covered_tables:
|
||||
continue
|
||||
try:
|
||||
path = nx.shortest_path(
|
||||
o_schemas.foreign_key_graph,
|
||||
source=start_table_id,
|
||||
target=table_id,
|
||||
)
|
||||
except (nx.NetworkXNoPath, nx.NodeNotFound):
|
||||
covered_tables.add(table_id)
|
||||
continue
|
||||
|
||||
for source_table_id, target_table_id in zip(path, path[1:]):
|
||||
if target_table_id in covered_tables:
|
||||
continue
|
||||
covered_tables.add(target_table_id)
|
||||
all_from_table_ids.add(target_table_id)
|
||||
col1, col2 = o_schemas.foreign_key_graph[source_table_id][target_table_id]["columns"]
|
||||
all_conds.append((columns[col1], columns[col2]))
|
||||
|
||||
conds.add((tables[source_table_id],
|
||||
tables[target_table_id],
|
||||
columns[col1],
|
||||
columns[col2]))
|
||||
|
||||
all_from_table_ids = list(all_from_table_ids)
|
||||
|
||||
try:
|
||||
tokens = ["from", tables[all_from_table_ids[0]]]
|
||||
for i, table_id in enumerate(all_from_table_ids[1:]):
|
||||
tokens += ["join"]
|
||||
tokens += [tables[table_id]]
|
||||
tokens += ["on", all_conds[i][0], "=", all_conds[i][1]]
|
||||
except Exception:
|
||||
return ' '.join(p_sql)
|
||||
|
||||
if where_idx != -1:
|
||||
p_sql = p_sql[:from_idx] + tokens + p_sql[where_idx:]
|
||||
elif group_idx != -1:
|
||||
p_sql = p_sql[:from_idx] + tokens + p_sql[group_idx:]
|
||||
elif order_idx != -1:
|
||||
p_sql = p_sql[:from_idx] + tokens + p_sql[order_idx:]
|
||||
elif len(p_sql[:from_idx] + p_sql[from_idx:]) == len(p_sql):
|
||||
p_sql = p_sql[:from_idx] + tokens
|
||||
|
||||
return ' '.join(p_sql)
|
||||
|
||||
|
||||
def extract_structure_data(plain_text_content: str):
|
||||
def sort_by_id(data):
|
||||
data.sort(key=lambda x: int(x.split('\t')[0][2:]))
|
||||
return data
|
||||
|
||||
data = []
|
||||
original_schemas = load_original_schemas(database_schema_filename)
|
||||
schemas, eval_foreign_key_maps = load_tables(database_schema_filename)
|
||||
|
||||
predict_outputs = sort_by_id(re.findall("^D.+", plain_text_content, RegexFlag.MULTILINE))
|
||||
ground_outputs = sort_by_id(re.findall("^T.+", plain_text_content, RegexFlag.MULTILINE))
|
||||
source_inputs = sort_by_id(re.findall("^S.+", plain_text_content, RegexFlag.MULTILINE))
|
||||
|
||||
for idx, (predict, ground, source) in enumerate(zip(predict_outputs, ground_outputs, source_inputs)):
|
||||
predict_id, predict_score, predict_clean = predict.split('\t')
|
||||
ground_id, ground_clean = ground.split('\t')
|
||||
source_id, source_clean = source.split('\t')
|
||||
|
||||
db_id = sql_to_db[idx]
|
||||
# try to post-process the incomplete SQL by:
|
||||
# (1) correcting the COLUMN in the ON clause based on the foreign-key graph, and
|
||||
# (2) adding the missing underlying TABLE(s) found via shortest-path search
|
||||
predict_clean = post_processing_sql(predict_clean, eval_foreign_key_maps[db_id], original_schemas[db_id],
|
||||
schemas[db_id])
|
||||
|
||||
data.append((predict_id[2:], source_clean.split('<Q>')[-1].strip(), ground_clean, predict_clean, db_id))
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def evaluate(data):
|
||||
def evaluate_example(_predict_str: str, _ground_str: str):
|
||||
return re.sub("\s+", "", _predict_str.lower()) == re.sub("\s+", "", _ground_str.lower())
|
||||
|
||||
correct_num = 0
|
||||
correct_tag_list = []
|
||||
total = 0
|
||||
|
||||
tmp = []
|
||||
for example in data:
|
||||
idx, source_str, ground_str, predict_str, db_id = example
|
||||
total += 1
|
||||
|
||||
try:
|
||||
sql_match = evaluate_sql(gold=ground_str.replace('@', '.'),
|
||||
predict=predict_str.replace('@', '.'),
|
||||
db_name=db_id,
|
||||
db_dir=database_dir,
|
||||
table=database_schema_filename)
|
||||
except Exception:
|
||||
print(predict_str)
|
||||
sql_match = False
|
||||
|
||||
if (sql_match or evaluate_example(predict_str, ground_str)):
|
||||
is_correct = True
|
||||
correct_num += 1
|
||||
else:
|
||||
is_correct = False
|
||||
|
||||
tmp.append(is_correct)
|
||||
correct_tag_list.append(is_correct)
|
||||
|
||||
print("Correct/Total : {}/{}, {:.4f}".format(correct_num, total, correct_num / total))
|
||||
return correct_tag_list, correct_num, total
|
||||
|
||||
|
||||
def predict_and_evaluate(model_path, dataset_path, constrain):
|
||||
if constrain:
|
||||
data = predict_with_constrain(
|
||||
model_path=model_path,
|
||||
dataset_path=dataset_path
|
||||
)
|
||||
else:
|
||||
decode_without_constrain(
|
||||
model_path=model_path,
|
||||
dataset_path=dataset_path
|
||||
)
|
||||
with open('./eval/generate-valid.txt', "r", encoding="utf8") as generate_f:
|
||||
file_content = generate_f.read()
|
||||
data = extract_structure_data(file_content)
|
||||
|
||||
correct_arr, correct_num, total = evaluate(data)
|
||||
with open('./eval/spider_eval.txt', "w", encoding="utf8") as eval_file:
|
||||
for example, correct in zip(data, correct_arr):
|
||||
eval_file.write(str(correct) + "\n" + "\n".join(
|
||||
[example[0], "db: " + example[-1], example[1], "gold: " + example[2], "pred: " + example[3]]) + "\n\n")
|
||||
return correct_num, total
|
||||
|
||||
def get_alias_schema(schemas):
|
||||
alias_schema = {}
|
||||
for db in schemas:
|
||||
schema = schemas[db].orig
|
||||
collect = []
|
||||
for i, (t, c) in enumerate(zip(schema['column_types'], schema['column_names_original'])):
|
||||
if c[0] == -1:
|
||||
collect.append('*')
|
||||
else:
|
||||
column_with_alias = "{0}@{1}".format(schema['table_names_original'][c[0]].lower(), c[1].lower())
|
||||
collect.append(column_with_alias)
|
||||
for t in schema['table_names_original']:
|
||||
collect.append(t.lower())
|
||||
collect.append("'value'")
|
||||
alias_schema[db] = collect
|
||||
return alias_schema
|
||||
|
||||
def predict_with_constrain(model_path, dataset_path):
|
||||
schemas, eval_foreign_key_maps = load_tables(database_schema_filename)
|
||||
original_schemas = load_original_schemas(database_schema_filename)
|
||||
with open(f'{dataset_path}/dev.src', 'r', encoding='utf-8') as f:
|
||||
item = [i.strip() for i in f.readlines()]
|
||||
with open(f'{dataset_path}/dev.tgt', 'r', encoding='utf-8') as f:
|
||||
ground = [i.strip() for i in f.readlines()]
|
||||
|
||||
alias_schema = get_alias_schema(schemas)
|
||||
|
||||
item_db_cluster = defaultdict(list)
|
||||
ground_db_cluster = defaultdict(list)
|
||||
source_db_cluster = defaultdict(list)
|
||||
|
||||
num_example = 1034  # number of examples in the Spider dev set
|
||||
for db, sentence, g_sql in zip(sql_to_db[:num_example], item[:num_example], ground[:num_example]):
|
||||
source = sentence.split('<Q>')[-1].strip()
|
||||
item_db_cluster[db].append(sentence)
|
||||
ground_db_cluster[db].append(g_sql)
|
||||
source_db_cluster[db].append(source)
|
||||
|
||||
source = []
|
||||
ground = []
|
||||
for db, sentence in source_db_cluster.items():
|
||||
source.extend(sentence)
|
||||
for db, g_SQL in ground_db_cluster.items():
|
||||
ground.extend(g_SQL)
|
||||
|
||||
model = GENRE.from_pretrained(model_path).eval()
|
||||
if torch.cuda.is_available():
|
||||
model.cuda()
|
||||
|
||||
result=[]
|
||||
for db, sentence in item_db_cluster.items():
|
||||
print(f'processing db: {db} with {len(sentence)} sentences')
|
||||
rnt=decode_with_constrain(sentence, alias_schema[db], model)
|
||||
result.extend([i[0]['text'] if isinstance(i[0]['text'], str) else i[0]['text'][0] for i in rnt])
|
||||
|
||||
eval_file_path = './eval/generate-valid-constrain.txt'
|
||||
with open(eval_file_path, "w", encoding="utf8") as f:
|
||||
f.write('\n'.join(result))
|
||||
|
||||
# result = []
|
||||
# with open(f'./eval/generate-valid-constrain.txt', "r", encoding="utf8") as f:
|
||||
# for idx, (sent, db_id) in enumerate(zip(f.readlines(), sql_to_db)):
|
||||
# result.append(sent.strip())
|
||||
|
||||
data = []
|
||||
for predict_id, (predict_clean, ground_clean, source_clean, db_id) in enumerate(
|
||||
zip(result, ground, source, sql_to_db)):
|
||||
predict_clean = post_processing_sql(predict_clean, eval_foreign_key_maps[db_id], original_schemas[db_id],
|
||||
schemas[db_id])
|
||||
data.append((str(predict_id), source_clean.split('<Q>')[-1].strip(), ground_clean, predict_clean, db_id))
|
||||
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def decode_with_constrain(sentences, schema, model):
|
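||||
# build a prefix trie over the encoded schema items; the prefix_allowed_tokens_fn below constrains decoding to entries of this trie
|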
||||
trie = Trie([
|
||||
model.encode(" {}".format(e))[1:].tolist()
|
||||
for e in schema
|
||||
])
|
||||
|
||||
prefix_allowed_tokens_fn = get_prefix_allowed_tokens_fn(
|
||||
model,
|
||||
sentences,
|
||||
mention_trie=trie,
|
||||
)
|
||||
|
||||
return model.sample(
|
||||
sentences,
|
||||
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
||||
)
|
||||
|
||||
|
||||
def decode_without_constrain( model_path, dataset_path):
|
||||
cmd = f'fairseq-generate \
|
||||
--path {model_path}/model.pt {dataset_path}/bin \
|
||||
--gen-subset valid \
|
||||
--nbest 1 \
|
||||
--max-tokens 4096 \
|
||||
--source-lang src --target-lang tgt \
|
||||
--results-path ./eval \
|
||||
--beam 5 \
|
||||
--bpe gpt2 \
|
||||
--remove-bpe \
|
||||
--skip-invalid-size-inputs-valid-test'
|
||||
|
||||
subprocess.Popen(
|
||||
cmd, universal_newlines=True, shell=True,
|
||||
stdout=subprocess.PIPE, stderr=subprocess.PIPE).communicate()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model_path", default='./models/spider_sl')
|
||||
parser.add_argument("--dataset_path", default='./dataset_post/spider_sl')
|
||||
parser.add_argument("--constrain", action='store_true')
|
||||
args = parser.parse_args()
|
||||
|
||||
predict_and_evaluate(model_path=args.model_path,
|
||||
dataset_path=args.dataset_path,
|
||||
constrain=args.constrain)
|
|
@ -0,0 +1,232 @@
|
|||
# Spider: A Large-Scale Human-Labeled Dataset for Complex and Cross-Domain Semantic Parsing and Text-to-SQL Task
|
||||
|
||||
Spider is a large human-labeled dataset for complex and cross-domain semantic parsing and text-to-SQL tasks (natural language interfaces for relational databases). It is released along with our EMNLP 2018 paper: [Spider: A Large-Scale Human-Labeled Dataset for Complex and Cross-Domain Semantic Parsing and Text-to-SQL Task](https://arxiv.org/abs/1809.08887). This repo contains all code for evaluation, preprocessing, and all baselines used in our paper. Please refer to [the task site](https://yale-lily.github.io/spider) for a more general introduction and the leaderboard.
|
||||
|
||||
|
||||
### Changelog
|
||||
|
||||
- `06/07/2020` We corrected some annotation errors and label mismatches (not errors) in Spider dev and test sets (~4% of dev examples updated, click [here](https://github.com/taoyds/spider/commit/25fcd85d9b6e94acaeb5e9172deadeefeed83f5e#diff-18b0a730a7b0d29b0a78a5070d971d49) for more details). Please download the Spider dataset from [the page](https://yale-lily.github.io/spider) again.
|
||||
- `01/16/2020` For value prediction (in order to compute the execution accuracy), your model should be able to 1) copy from the question inputs, 2) retrieve from the database content (database content is available), or 3) generate numbers (e.g. 3 in "LIMIT 3").
|
||||
- `1/14/2019` The submission tutorial is ready! Please follow it to get your results on the unreleased test data.
|
||||
- `12/17/2018` We updated 7 sqlite database files. Please download the Spider data from the official website again. Please refer to [the issue 14](https://github.com/taoyds/spider/issues/14) for more details.
|
||||
- `10/25/2018`: the evaluation script is updated so that the table in `count(*)` cases will be evaluated as well. Please check out [the issue 5](https://github.com/taoyds/spider/issues/5) for more info. Results of all baselines and [syntaxSQL](https://github.com/taoyds/syntaxSQL) in the paper are updated as well.
|
||||
- `10/25/2018`: to get the latest SQL parsing results (a few small bugs fixed), please use `preprocess/parse_raw_json.py` to update. Please refer to [the issue 3](https://github.com/taoyds/spider/issues/3) for more details.
|
||||
|
||||
### Citation
|
||||
|
||||
The dataset was annotated by 11 college students. If you use the Spider dataset, we would appreciate it if you cite the following:
|
||||
|
||||
```
|
||||
@inproceedings{Yu&al.18c,
|
||||
title = {Spider: A Large-Scale Human-Labeled Dataset for Complex and Cross-Domain Semantic Parsing and Text-to-SQL Task},
|
||||
  author = {Tao Yu and Rui Zhang and Kai Yang and Michihiro Yasunaga and Dongxu Wang and Zifan Li and James Ma and Irene Li and Qingning Yao and Shanelle Roman and Zilin Zhang and Dragomir Radev},
|
||||
booktitle = "Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing",
|
||||
address = "Brussels, Belgium",
|
||||
publisher = "Association for Computational Linguistics",
|
||||
year = 2018
|
||||
}
|
||||
```
|
||||
|
||||
### Installation
|
||||
|
||||
`evaluation.py` and `process_sql.py` are written in Python 3. Environment setup for each baseline is described in the README under the corresponding baseline directory.
|
||||
|
||||
|
||||
### Data Content and Format
|
||||
|
||||
#### Question, SQL, and Parsed SQL
|
||||
|
||||
Each example in `train.json` and `dev.json` contains the following fields:
|
||||
- `question`: the natural language question
|
||||
- `question_toks`: the natural language question tokens
|
||||
- `db_id`: the database id to which this question is addressed.
|
||||
- `query`: the SQL query corresponding to the question.
|
||||
- `query_toks`: the SQL query tokens corresponding to the question.
|
||||
- `sql`: the parsed result of this SQL query produced by `process_sql.py`. Please refer to `parsed_sql_examples.sql` in the `preprocess` directory for detailed documentation.
|
||||
|
||||
|
||||
```
|
||||
{
|
||||
"db_id": "world_1",
|
||||
"query": "SELECT avg(LifeExpectancy) FROM country WHERE Name NOT IN (SELECT T1.Name FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T2.Language = \"English\" AND T2.IsOfficial = \"T\")",
|
||||
"query_toks": ["SELECT", "avg", "(", "LifeExpectancy", ")", "FROM", ...],
|
||||
"question": "What is average life expectancy in the countries where English is not the official language?",
|
||||
"question_toks": ["What", "is", "average", "life", ...],
|
||||
"sql": {
|
||||
"except": null,
|
||||
"from": {
|
||||
"conds": [],
|
||||
"table_units": [
|
||||
...
|
||||
},
|
||||
"groupBy": [],
|
||||
"having": [],
|
||||
"intersect": null,
|
||||
"limit": null,
|
||||
"orderBy": [],
|
||||
"select": [
|
||||
...
|
||||
],
|
||||
"union": null,
|
||||
"where": [
|
||||
[
|
||||
true,
|
||||
...
|
||||
{
|
||||
"except": null,
|
||||
"from": {
|
||||
"conds": [
|
||||
[
|
||||
false,
|
||||
2,
|
||||
[
|
||||
...
|
||||
},
|
||||
"groupBy": [],
|
||||
"having": [],
|
||||
"intersect": null,
|
||||
"limit": null,
|
||||
"orderBy": [],
|
||||
"select": [
|
||||
false,
|
||||
...
|
||||
"union": null,
|
||||
"where": [
|
||||
[
|
||||
false,
|
||||
2,
|
||||
[
|
||||
0,
|
||||
...
|
||||
}
|
||||
},
|
||||
|
||||
```
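To load the data and inspect these fields, a minimal sketch is given below (the `dev.json` path is a placeholder for wherever you unpacked the dataset):

```
import json

# placeholder path: point this at the dev.json downloaded from the task site
with open("dev.json") as f:
    examples = json.load(f)

ex = examples[0]
print(ex["db_id"], ex["question"])
print(ex["query"])
print(ex["sql"]["select"])  # the parsed SELECT clause, in the format shown above
```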
|
||||
|
||||
#### Tables
|
||||
|
||||
`tables.json` contains the following information for each database:
|
||||
- `db_id`: database id
|
||||
- `table_names_original`: original table names stored in the database.
|
||||
- `table_names`: cleaned and normalized table names. We make sure the table names are meaningful. [to be changed]
|
||||
- `column_names_original`: original column names stored in the database. Each column looks like `[0, "id"]`: `0` is the index of the table in `table_names` (here, `city`), and `"id"` is the column name.
|
||||
- `column_names`: cleaned and normalized column names. We make sure the column names are meaningful. [to be changed]
|
||||
- `column_types`: data type of each column
|
||||
- `foreign_keys`: foreign keys in the database. `[3, 8]` means that the columns with indices 3 and 8 in `column_names` form a foreign-key pair between two different tables (a short resolution sketch follows the JSON example below).
|
||||
- `primary_keys`: primary keys in the database. Each number is an index into `column_names`.
|
||||
|
||||
|
||||
```
|
||||
{
|
||||
"column_names": [
|
||||
[
|
||||
0,
|
||||
"id"
|
||||
],
|
||||
[
|
||||
0,
|
||||
"name"
|
||||
],
|
||||
[
|
||||
0,
|
||||
"country code"
|
||||
],
|
||||
[
|
||||
0,
|
||||
"district"
|
||||
],
|
||||
.
|
||||
.
|
||||
.
|
||||
],
|
||||
"column_names_original": [
|
||||
[
|
||||
0,
|
||||
"ID"
|
||||
],
|
||||
[
|
||||
0,
|
||||
"Name"
|
||||
],
|
||||
[
|
||||
0,
|
||||
"CountryCode"
|
||||
],
|
||||
[
|
||||
0,
|
||||
"District"
|
||||
],
|
||||
.
|
||||
.
|
||||
.
|
||||
],
|
||||
"column_types": [
|
||||
"number",
|
||||
"text",
|
||||
"text",
|
||||
"text",
|
||||
.
|
||||
.
|
||||
.
|
||||
],
|
||||
"db_id": "world_1",
|
||||
"foreign_keys": [
|
||||
[
|
||||
3,
|
||||
8
|
||||
],
|
||||
[
|
||||
23,
|
||||
8
|
||||
]
|
||||
],
|
||||
"primary_keys": [
|
||||
1,
|
||||
8,
|
||||
23
|
||||
],
|
||||
"table_names": [
|
||||
"city",
|
||||
"sqlite sequence",
|
||||
"country",
|
||||
"country language"
|
||||
],
|
||||
"table_names_original": [
|
||||
"city",
|
||||
"sqlite_sequence",
|
||||
"country",
|
||||
"countrylanguage"
|
||||
]
|
||||
}
|
||||
```
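As a quick sanity check of this format, the following sketch (the `tables.json` path and the helper name are only illustrative) resolves the `foreign_keys` and `primary_keys` indices back to `table.column` names:

```
import json

def describe_keys(tables_path="tables.json", db_id="world_1"):
    # load the schema entry of a single database
    with open(tables_path) as f:
        entry = next(db for db in json.load(f) if db["db_id"] == db_id)

    # each element of column_names_original is [table_index, column_name];
    # table_index -1 marks the special "*" column
    def col_name(idx):
        tab_idx, col = entry["column_names_original"][idx]
        return "*" if tab_idx == -1 else entry["table_names_original"][tab_idx] + "." + col

    for col_idx, ref_idx in entry["foreign_keys"]:
        print("foreign key:", col_name(col_idx), "->", col_name(ref_idx))
    for col_idx in entry["primary_keys"]:
        print("primary key:", col_name(col_idx))
```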
|
||||
|
||||
|
||||
#### Databases
|
||||
|
||||
All table contents are contained in corresponding SQLite3 database files.
|
||||
|
||||
|
||||
### Evaluation
|
||||
|
||||
Our evaluation metrics include Component Matching, Exact Matching, and Execution Accuracy. For component and exact matching evaluation, instead of simply conducting string comparison between the predicted and gold SQL queries, we decompose each SQL into several clauses, and conduct set comparison in each SQL clause.
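As a toy illustration of the idea (this is not the logic used in `evaluation.py`, just a sketch of clause-level set comparison on the parsed `sql` dicts):

```
# compare one clause of two parsed SQL dicts as sets rather than strings:
# the SELECT clauses match if they contain the same (agg, val_unit) items,
# regardless of their order in the query
def select_clause_match(gold_sql, pred_sql):
    gold_units = {str(unit) for unit in gold_sql["select"][1]}
    pred_units = {str(unit) for unit in pred_sql["select"][1]}
    return gold_units == pred_units
```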
|
||||
|
||||
For Execution Accuracy, our current models do not predict any value in SQL conditions, so we do not report execution accuracy. However, we encourage you to provide it in future submissions. For value prediction, you can assume that a list of gold values for each question is given. Your model has to fill them into the right slots in the SQL.
|
||||
|
||||
Please refer to [our paper](https://arxiv.org/abs/1809.08887) and [this page](https://github.com/taoyds/spider/tree/master/evaluation) for more details and examples.
|
||||
|
||||
```
|
||||
python evaluation.py --gold [gold file] --pred [predicted file] --etype [evaluation type] --db [database dir] --table [table file]
|
||||
|
||||
arguments:
|
||||
[gold file] gold.sql file where each line is `a gold SQL \t db_id`
|
||||
[predicted file] predicted sql file where each line is a predicted SQL
|
||||
[evaluation type] "match" for exact set matching score, "exec" for execution score, and "all" for both
|
||||
[database dir] directory which contains sub-directories where each SQLite3 database is stored
|
||||
[table file] tables.json file which includes foreign key info of each database
|
||||
|
||||
```
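For example, a full evaluation on the dev set could be run as follows (all file and directory names here are placeholders for your own paths):

```
python evaluation.py --gold dev_gold.sql --pred dev_pred.sql --etype all --db database/ --table tables.json
```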
|
||||
|
||||
### FAQ
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
## Data Preprocess
|
||||
|
||||
#### Get Parsed SQL Output
|
||||
|
||||
The SQL parsing script is `process_sql.py` in the main directory. Please refer to `parsed_sql_examples.sql` for explanations of some example parsed SQL outputs.
|
||||
|
||||
If you would like to use `process_sql.py` to parse SQL queries yourself, `parse_sql_one.py` provides an example of how the script is called (a minimal sketch is shown below). Alternatively, you can use `parse_raw_json.py` to update all parsed SQL results (the value for `sql`) in `train.json` and `dev.json`.
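A minimal sketch of that call pattern, following `parse_sql_one.py` (the import paths depend on where the scripts sit in your checkout, and the query, database id, and `tables.json` path are placeholders):

```
# assumes you run from the directory containing process_sql.py,
# with schema.py importable from the preprocess package
from process_sql import get_sql
from preprocess.schema import Schema, get_schemas_from_json

sql = "SELECT name , country , age FROM singer ORDER BY age DESC"
db_id = "concert_singer"
table_file = "tables.json"

schemas, db_names, tables = get_schemas_from_json(table_file)
schema = Schema(schemas[db_id], tables[db_id])
sql_label = get_sql(schema, sql)  # nested dict in the format documented in parsed_sql_examples.sql
```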
|
||||
|
||||
#### Get Table Info from Database
|
||||
|
||||
To generate the final `tables.json` file, run `preprocess/get_tables.py`. It reads the SQLite files from the `database/` directory and the previous `tables.json` with hand-corrected names (a concrete example follows the usage line below):
|
||||
|
||||
```
|
||||
python preprocess/get_tables.py [dir includes many subdirs containing database.sqlite files] [output file name e.g. output.json] [existing tables.json file to be inherited]
|
||||
```
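For example, assuming the databases live under `database/` and the current hand-corrected file is `tables.json` (both paths are placeholders):

```
python preprocess/get_tables.py database/ tables_new.json tables.json
```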
|
|
@ -0,0 +1,162 @@
|
|||
import os
|
||||
import sys
|
||||
import json
|
||||
import sqlite3
|
||||
from os import listdir, makedirs
|
||||
from os.path import isfile, isdir, join, split, exists, splitext
|
||||
from nltk import word_tokenize, tokenize
|
||||
import traceback
|
||||
|
||||
EXIST = {"atis", "geo", "advising", "yelp", "restaurants", "imdb", "academic"}
|
||||
|
||||
|
||||
def convert_fk_index(data):
|
||||
fk_holder = []
|
||||
for fk in data["foreign_keys"]:
|
||||
tn, col, ref_tn, ref_col = fk[0][0], fk[0][1], fk[1][0], fk[1][1]
|
||||
ref_cid, cid = None, None
|
||||
try:
|
||||
tid = data["table_names_original"].index(tn)
|
||||
ref_tid = data["table_names_original"].index(ref_tn)
|
||||
|
||||
for i, (tab_id, col_org) in enumerate(data["column_names_original"]):
|
||||
if tab_id == ref_tid and ref_col == col_org:
|
||||
ref_cid = i
|
||||
elif tid == tab_id and col == col_org:
|
||||
cid = i
|
||||
if ref_cid is not None and cid is not None:
|
||||
fk_holder.append([cid, ref_cid])
|
||||
except:
|
||||
traceback.print_exc()
|
||||
print("table_names_original: ", data["table_names_original"])
|
||||
print("finding tab name: ", tn, ref_tn)
|
||||
sys.exit()
|
||||
return fk_holder
|
||||
|
||||
|
||||
def dump_db_json_schema(db, f):
|
||||
"""read table and column info"""
|
||||
|
||||
conn = sqlite3.connect(db)
|
||||
conn.execute("pragma foreign_keys=ON")
|
||||
cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table';")
|
||||
|
||||
data = {
|
||||
"db_id": f,
|
||||
"table_names_original": [],
|
||||
"table_names": [],
|
||||
"column_names_original": [(-1, "*")],
|
||||
"column_names": [(-1, "*")],
|
||||
"column_types": ["text"],
|
||||
"primary_keys": [],
|
||||
"foreign_keys": [],
|
||||
}
|
||||
|
||||
fk_holder = []
|
||||
for i, item in enumerate(cursor.fetchall()):
|
||||
table_name = item[0]
|
||||
data["table_names_original"].append(table_name)
|
||||
data["table_names"].append(table_name.lower().replace("_", " "))
|
||||
fks = conn.execute(
|
||||
"PRAGMA foreign_key_list('{}') ".format(table_name)
|
||||
).fetchall()
|
||||
# print("db:{} table:{} fks:{}".format(f,table_name,fks))
|
||||
fk_holder.extend([[(table_name, fk[3]), (fk[2], fk[4])] for fk in fks])
|
||||
cur = conn.execute("PRAGMA table_info('{}') ".format(table_name))
|
||||
for j, col in enumerate(cur.fetchall()):
|
||||
data["column_names_original"].append((i, col[1]))
|
||||
data["column_names"].append((i, col[1].lower().replace("_", " ")))
|
||||
# varchar, '' -> text, int, numeric -> integer,
|
||||
col_type = col[2].lower()
|
||||
if (
|
||||
"char" in col_type
|
||||
or col_type == ""
|
||||
or "text" in col_type
|
||||
or "var" in col_type
|
||||
):
|
||||
data["column_types"].append("text")
|
||||
elif (
|
||||
"int" in col_type
|
||||
or "numeric" in col_type
|
||||
or "decimal" in col_type
|
||||
or "number" in col_type
|
||||
or "id" in col_type
|
||||
or "real" in col_type
|
||||
or "double" in col_type
|
||||
or "float" in col_type
|
||||
):
|
||||
data["column_types"].append("number")
|
||||
elif "date" in col_type or "time" in col_type or "year" in col_type:
|
||||
data["column_types"].append("time")
|
||||
elif "boolean" in col_type:
|
||||
data["column_types"].append("boolean")
|
||||
else:
|
||||
data["column_types"].append("others")
|
||||
|
||||
if col[5] == 1:
|
||||
data["primary_keys"].append(len(data["column_names"]) - 1)
|
||||
|
||||
data["foreign_keys"] = fk_holder
|
||||
data["foreign_keys"] = convert_fk_index(data)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) < 2:
|
||||
print(
|
||||
"Usage: python get_tables.py [dir includes many subdirs containing database.sqlite files] [output file name e.g. output.json] [existing tables.json file to be inherited]"
|
||||
)
|
||||
sys.exit()
|
||||
input_dir = sys.argv[1]
|
||||
output_file = sys.argv[2]
|
||||
ex_tab_file = sys.argv[3]
|
||||
|
||||
all_fs = [
|
||||
df for df in listdir(input_dir) if exists(join(input_dir, df, df + ".sqlite"))
|
||||
]
|
||||
with open(ex_tab_file) as f:
|
||||
ex_tabs = json.load(f)
|
||||
# for tab in ex_tabs:
|
||||
# tab["foreign_keys"] = convert_fk_index(tab)
|
||||
ex_tabs = {tab["db_id"]: tab for tab in ex_tabs if tab["db_id"] in all_fs}
|
||||
print("precessed file num: ", len(ex_tabs))
|
||||
not_fs = [
|
||||
df
|
||||
for df in listdir(input_dir)
|
||||
if not exists(join(input_dir, df, df + ".sqlite"))
|
||||
]
|
||||
for d in not_fs:
|
||||
print("no sqlite file found in: ", d)
|
||||
db_files = [
|
||||
(df + ".sqlite", df)
|
||||
for df in listdir(input_dir)
|
||||
if exists(join(input_dir, df, df + ".sqlite"))
|
||||
]
|
||||
tables = []
|
||||
for f, df in db_files:
|
||||
# if df in ex_tabs.keys():
|
||||
# print 'reading old db: ', df
|
||||
# tables.append(ex_tabs[df])
|
||||
db = join(input_dir, df, f)
|
||||
print("\nreading new db: ", df)
|
||||
table = dump_db_json_schema(db, df)
|
||||
# guard: a database may be new and absent from the existing tables.json
prev_tab_num = len(ex_tabs[df]["table_names"]) if df in ex_tabs else 0
prev_col_num = len(ex_tabs[df]["column_names"]) if df in ex_tabs else 0
|
||||
cur_tab_num = len(table["table_names"])
|
||||
cur_col_num = len(table["column_names"])
|
||||
if (
|
||||
df in ex_tabs.keys()
|
||||
and prev_tab_num == cur_tab_num
|
||||
and prev_col_num == cur_col_num
|
||||
and prev_tab_num != 0
|
||||
and len(ex_tabs[df]["column_names"]) > 1
|
||||
):
|
||||
table["table_names"] = ex_tabs[df]["table_names"]
|
||||
table["column_names"] = ex_tabs[df]["column_names"]
|
||||
else:
|
||||
print("\n----------------------------------problem db: ", df)
|
||||
tables.append(table)
|
||||
print("final db num: ", len(tables))
|
||||
with open(output_file, "wt") as out:
|
||||
json.dump(tables, out, sort_keys=True, indent=2, separators=(",", ": "))
|
unified_parser_text_to_sql/third_party/spider/preprocess/parse_raw_json.py (vendored, new file, 45 lines)
|
@ -0,0 +1,45 @@
|
|||
import os, sys
|
||||
import json
|
||||
import sqlite3
|
||||
import traceback
|
||||
import argparse
|
||||
import tqdm
|
||||
|
||||
from ..process_sql import get_sql
|
||||
from .schema import Schema, get_schemas_from_json
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--input", required=True)
|
||||
parser.add_argument("--tables", required=True)
|
||||
parser.add_argument("--output", required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_path = args.input
|
||||
output_file = args.output
|
||||
table_file = args.tables
|
||||
|
||||
schemas, db_names, tables = get_schemas_from_json(table_file)
|
||||
|
||||
with open(sql_path) as inf:
|
||||
sql_data = json.load(inf)
|
||||
|
||||
sql_data_new = []
|
||||
for data in tqdm.tqdm(sql_data):
|
||||
try:
|
||||
db_id = data["db_id"]
|
||||
schema = schemas[db_id]
|
||||
table = tables[db_id]
|
||||
schema = Schema(schema, table)
|
||||
sql = data["query"]
|
||||
sql_label = get_sql(schema, sql)
|
||||
data["sql"] = sql_label
|
||||
sql_data_new.append(data)
|
||||
except:
|
||||
print("db_id: ", db_id)
|
||||
print("sql: ", sql)
|
||||
raise
|
||||
|
||||
with open(output_file, "wt") as out:
|
||||
json.dump(sql_data_new, out, sort_keys=True, indent=4, separators=(",", ": "))
|
unified_parser_text_to_sql/third_party/spider/preprocess/parse_sql_one.py (vendored, new file, 27 lines)
|
@ -0,0 +1,27 @@
|
|||
import os
|
||||
import traceback
|
||||
import re
|
||||
import sys
|
||||
import json
|
||||
import sqlite3
|
||||
import random
|
||||
from os import listdir, makedirs
|
||||
from collections import OrderedDict
|
||||
from nltk import word_tokenize, tokenize
|
||||
from os.path import isfile, isdir, join, split, exists, splitext
|
||||
|
||||
from ..process_sql import get_sql
|
||||
from .schema import Schema, get_schemas_from_json
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
sql = "SELECT name , country , age FROM singer ORDER BY age DESC"
|
||||
db_id = "concert_singer"
|
||||
table_file = "tables.json"
|
||||
|
||||
schemas, db_names, tables = get_schemas_from_json(table_file)
|
||||
schema = schemas[db_id]
|
||||
table = tables[db_id]
|
||||
schema = Schema(schema, table)
|
||||
sql_label = get_sql(schema, sql)
|
unified_parser_text_to_sql/third_party/spider/preprocess/parsed_sql_examples.sql (vendored, new file, 74 lines)
|
@ -0,0 +1,74 @@
|
|||
################################
|
||||
# Assumptions:
|
||||
# 1. sql is correct
|
||||
# 2. only table name has alias
|
||||
# 3. only one intersect/union/except
|
||||
#
|
||||
# val: number(float)/string(str)/sql(dict)
|
||||
# col_unit: (agg_id, col_id, isDistinct(bool))
|
||||
# val_unit: (unit_op, col_unit1, col_unit2)
|
||||
# table_unit: (table_type, col_unit/sql)
|
||||
# cond_unit: (not_op, op_id, val_unit, val1, val2)
|
||||
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
|
||||
# sql {
|
||||
# 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
|
||||
# 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
|
||||
# 'where': condition
|
||||
# 'groupBy': [col_unit1, col_unit2, ...]
|
||||
# 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
|
||||
# 'having': condition
|
||||
# 'limit': None/limit value
|
||||
# 'intersect': None/sql
|
||||
# 'except': None/sql
|
||||
# 'union': None/sql
|
||||
# }
|
||||
################################
|
||||
|
||||
CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
|
||||
JOIN_KEYWORDS = ('join', 'on', 'as')
|
||||
|
||||
WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
|
||||
UNIT_OPS = ('none', '-', '+', "*", '/')
|
||||
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
|
||||
TABLE_TYPE = {
|
||||
'sql': "sql",
|
||||
'table_unit': "table_unit",
|
||||
}
|
||||
COND_OPS = ('and', 'or')
|
||||
SQL_OPS = ('intersect', 'union', 'except')
|
||||
ORDER_OPS = ('desc', 'asc')
|
||||
|
||||
|
||||
############ Example 1 ############
|
||||
|
||||
SELECT T1.lname , T1.fname
|
||||
FROM artists AS T1 JOIN paintings AS T2 ON T1.artistID = T2.painterID
|
||||
EXCEPT
|
||||
SELECT T1.lname , T1.fname
|
||||
FROM artists AS T1 JOIN sculptures AS T2 ON T1.artistID = T2.sculptorID
|
||||
|
||||
'select': (False #is distinct#, [(0 #index of AGG_OPS#, (0 #index of unit_op to save col1-col2 cases#, (0#index of AGG_OPS#, '__artists.lname__' #index of column names#, False #is DISTINCT#), None #col_unit2 usually is None, for saving col1-col2#)), (0, (0, (0, '__artists.fname__', False), None))])
|
||||
'from': {'table_units' #list of tables in from#: [('table_unit' #some froms are nested sql#, '__artists__' #gonna be index of table#), ('table_unit', '__paintings__')],
|
||||
'conds': [(False #if there is NOT#, 2 #index of WHERE_OPS#, (0 #index of unit_op to save col1-col2 cases#, (0 #index of AGG_OPS#, '__artists.artistid__' #index of column names#, False #is DISTINCT#), None #col_unit2 usually is None, for saving col1-col2#), 't2.painterid' #val1, here is t2 ref id#, None #val2 for between val1 and val2#)]}
|
||||
'except': {'from': {'table_units': [('table_unit', '__artists__'), ('table_unit', '__sculptures__')], 'conds': [(False, 2, (0, (0, '__artists.artistid__', False), None), (0, '__sculptures.sculptorid__', False), None)]}, 'select': (False, [(0, (0, (0, '__artists.lname__', False), None)), (0, (0, (0, '__artists.fname__', False), None))])}
|
||||
|
||||
############ Example 2 ############
|
||||
|
||||
SELECT paintingID
|
||||
FROM paintings
|
||||
WHERE height_mm > (SELECT max(height_mm)
|
||||
FROM paintings
|
||||
WHERE YEAR > 1900)
|
||||
'select': (False, [(0, (0, (0, '__paintings.paintingid__', False), None))])
|
||||
'from': {'table_units': [('table_unit', '__paintings__')], 'conds': []}
|
||||
'where': [(False, 3, (0, (0, '__paintings.height_mm__', False), None) #finished val_unit1#, #start val1 which is a sql# {'from': {'table_units': [('table_unit', '__paintings__')], 'conds': []}, 'where': [(False, 3, (0, (0, '__paintings.year__', False), None), 1900.0 #cond val1#, None #cond val2#)], 'select': (False, [(1, (0, (0, '__paintings.height_mm__', False), None))])}, None)]
|
||||
|
||||
############ Example 3 ############
|
||||
|
||||
ORDER BY count(*) DESC LIMIT 1
|
||||
'orderBy': ('desc', [(0 #index of unit_op no -/+#, (3 #agg count index#, '__all__', False), None)])
|
||||
|
||||
############ Example 4 ############
|
||||
|
||||
GROUP BY T2.painterID HAVING count(*) >= 2
|
||||
'groupBy': [(0 #index of AGG_OPS#, '__paintings.painterid__', False #is distinct#)], 'having': [(False, 5, (0, (3, '__all__', False), None), 2.0, None)]
|
|
@ -0,0 +1,80 @@
|
|||
import json
|
||||
|
||||
|
||||
class Schema:
|
||||
"""
|
||||
Simple schema which maps table&column to a unique identifier
|
||||
"""
|
||||
|
||||
def __init__(self, schema, table):
|
||||
self._schema = schema
|
||||
self._table = table
|
||||
self._idMap = self._map(self._schema, self._table)
|
||||
|
||||
@property
|
||||
def schema(self):
|
||||
return self._schema
|
||||
|
||||
@property
|
||||
def idMap(self):
|
||||
return self._idMap
|
||||
|
||||
def _map(self, schema, table):
|
||||
|
||||
if 'column_names_original' not in table: # {'table': [col.lower, ..., ]} * -> __all__
|
||||
table["column_names_original"] = table["column_names"]
|
||||
table["table_names_original"] = table["table_names"]
|
||||
|
||||
column_names_original = table["column_names_original"]
|
||||
table_names_original = table["table_names_original"]
|
||||
# print 'column_names_original: ', column_names_original
|
||||
# print 'table_names_original: ', table_names_original
|
||||
for i, (tab_id, col) in enumerate(column_names_original):
|
||||
if tab_id == -1:
|
||||
idMap = {"*": i}
|
||||
else:
|
||||
key = table_names_original[tab_id].lower()
|
||||
val = col.lower()
|
||||
idMap[key + "." + val] = i
|
||||
|
||||
for i, tab in enumerate(table_names_original):
|
||||
key = tab.lower()
|
||||
idMap[key] = i
|
||||
|
||||
return idMap
|
||||
|
||||
|
||||
def _get_schemas_from_json(data: dict):
|
||||
db_names = [db["db_id"] for db in data]
|
||||
|
||||
tables = {}
|
||||
schemas = {}
|
||||
for db in data:
|
||||
db_id = db["db_id"]
|
||||
schema = {}
|
||||
|
||||
if 'column_names_original' not in db:
|
||||
db["column_names_original"] = db["column_names"]
|
||||
db["table_names_original"] = db["table_names"]
|
||||
|
||||
column_names_original = db["column_names_original"]
|
||||
table_names_original = db["table_names_original"]
|
||||
|
||||
tables[db_id] = {
|
||||
"column_names_original": column_names_original,
|
||||
"table_names_original": table_names_original,
|
||||
}
|
||||
|
||||
for i, tabn in enumerate(table_names_original):
|
||||
table = str(tabn.lower())
|
||||
cols = [str(col.lower()) for td, col in column_names_original if td == i]
|
||||
schema[table] = cols
|
||||
schemas[db_id] = schema
|
||||
|
||||
return schemas, db_names, tables
|
||||
|
||||
|
||||
def get_schemas_from_json(fpath):
|
||||
with open(fpath, 'r',encoding='UTF-8') as f:
|
||||
data = json.load(f)
|
||||
return _get_schemas_from_json(data)
|
|
@ -0,0 +1,659 @@
|
|||
################################
|
||||
# Assumptions:
|
||||
# 1. sql is correct
|
||||
# 2. only table name has alias
|
||||
# 3. only one intersect/union/except
|
||||
#
|
||||
# val: number(float)/string(str)/sql(dict)
|
||||
# col_unit: (agg_id, col_id, isDistinct(bool))
|
||||
# val_unit: (unit_op, col_unit1, col_unit2)
|
||||
# table_unit: (table_type, col_unit/sql)
|
||||
# cond_unit: (not_op, op_id, val_unit, val1, val2)
|
||||
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
|
||||
# sql {
|
||||
# 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
|
||||
# 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
|
||||
# 'where': condition
|
||||
# 'groupBy': [col_unit1, col_unit2, ...]
|
||||
# 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
|
||||
# 'having': condition
|
||||
# 'limit': None/limit value
|
||||
# 'intersect': None/sql
|
||||
# 'except': None/sql
|
||||
# 'union': None/sql
|
||||
# }
|
||||
################################
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from nltk import word_tokenize
|
||||
import copy
|
||||
|
||||
CLAUSE_KEYWORDS = (
|
||||
"select",
|
||||
"from",
|
||||
"where",
|
||||
"group",
|
||||
"order",
|
||||
"limit",
|
||||
"intersect",
|
||||
"union",
|
||||
"except",
|
||||
)
|
||||
JOIN_KEYWORDS = ("join", "on", "as")
|
||||
|
||||
WHERE_OPS = (
|
||||
"not",
|
||||
"between",
|
||||
"=",
|
||||
">",
|
||||
"<",
|
||||
">=",
|
||||
"<=",
|
||||
"!=",
|
||||
"in",
|
||||
"like",
|
||||
"is",
|
||||
"exists",
|
||||
)
|
||||
UNIT_OPS = ("none", "-", "+", "*", "/")
|
||||
AGG_OPS = ("none", "max", "min", "count", "sum", "avg")
|
||||
TABLE_TYPE = {
|
||||
"sql": "sql",
|
||||
"table_unit": "table_unit",
|
||||
}
|
||||
|
||||
COND_OPS = ("and", "or")
|
||||
SQL_OPS = ("intersect", "union", "except")
|
||||
ORDER_OPS = ("desc", "asc")
|
||||
|
||||
global_mentioned_schema = []
|
||||
|
||||
class Schema:
|
||||
"""
|
||||
Simple schema which maps table&column to a unique identifier
|
||||
"""
|
||||
|
||||
def __init__(self, schema):
|
||||
self._schema = schema
|
||||
self._idMap = self._map(self._schema)
|
||||
|
||||
@property
|
||||
def schema(self):
|
||||
return self._schema
|
||||
|
||||
@property
|
||||
def idMap(self):
|
||||
return self._idMap
|
||||
|
||||
def _map(self, schema):
|
||||
idMap = {"*": "__all__"}
|
||||
id = 1
|
||||
for key, vals in schema.items():
|
||||
for val in vals:
|
||||
idMap[key.lower() + "." + val.lower()] = (
|
||||
"__" + key.lower() + "." + val.lower() + "__"
|
||||
)
|
||||
id += 1
|
||||
|
||||
for key in schema:
|
||||
idMap[key.lower()] = "__" + key.lower() + "__"
|
||||
id += 1
|
||||
|
||||
return idMap
|
||||
|
||||
|
||||
def get_schema(db):
|
||||
"""
|
||||
Get database's schema, which is a dict with table name as key
|
||||
and list of column names as value
|
||||
:param db: database path
|
||||
:return: schema dict
|
||||
"""
|
||||
|
||||
schema = {}
|
||||
conn = sqlite3.connect(db)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# fetch table names
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
|
||||
tables = [str(table[0].lower()) for table in cursor.fetchall()]
|
||||
|
||||
# fetch table info
|
||||
for table in tables:
|
||||
cursor.execute("PRAGMA table_info({})".format(table))
|
||||
schema[table] = [str(col[1].lower()) for col in cursor.fetchall()]
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def get_schema_from_json(fpath):
|
||||
with open(fpath) as f:
|
||||
data = json.load(f)
|
||||
|
||||
schema = {}
|
||||
for entry in data:
|
||||
table = str(entry["table"].lower())
|
||||
cols = [str(col["column_name"].lower()) for col in entry["col_data"]]
|
||||
schema[table] = cols
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def tokenize(string):
|
||||
string = str(string)
|
||||
string = string.replace(
|
||||
"'", '"'
|
||||
) # ensures all string values wrapped by "" problem??
|
||||
quote_idxs = [idx for idx, char in enumerate(string) if char == '"']
|
||||
assert len(quote_idxs) % 2 == 0, "Unexpected quote"
|
||||
|
||||
# keep string value as token
|
||||
vals = {}
|
||||
for i in range(len(quote_idxs) - 1, -1, -2):
|
||||
qidx1 = quote_idxs[i - 1]
|
||||
qidx2 = quote_idxs[i]
|
||||
val = string[qidx1 : qidx2 + 1]
|
||||
key = "__val_{}_{}__".format(qidx1, qidx2)
|
||||
string = string[:qidx1] + key + string[qidx2 + 1 :]
|
||||
vals[key] = val
|
||||
|
||||
toks = [word.lower() for word in word_tokenize(string)]
|
||||
# replace with string value token
|
||||
for i in range(len(toks)):
|
||||
if toks[i] in vals:
|
||||
toks[i] = vals[toks[i]]
|
||||
|
||||
# find if there exists !=, >=, <=
|
||||
eq_idxs = [idx for idx, tok in enumerate(toks) if tok == "="]
|
||||
eq_idxs.reverse()
|
||||
prefix = ("!", ">", "<")
|
||||
for eq_idx in eq_idxs:
|
||||
pre_tok = toks[eq_idx - 1]
|
||||
if pre_tok in prefix:
|
||||
toks = toks[: eq_idx - 1] + [pre_tok + "="] + toks[eq_idx + 1 :]
|
||||
|
||||
return toks
|
||||
|
||||
|
||||
def scan_alias(toks):
|
||||
"""Scan the index of 'as' and build the map for all alias"""
|
||||
as_idxs = [idx for idx, tok in enumerate(toks) if tok == "as"]
|
||||
alias = {}
|
||||
for idx in as_idxs:
|
||||
alias[toks[idx + 1]] = toks[idx - 1]
|
||||
return alias
|
||||
|
||||
|
||||
def get_tables_with_alias(schema, toks):
|
||||
tables = scan_alias(toks)
|
||||
for key in schema:
|
||||
assert key not in tables, "Alias {} has the same name in table".format(key)
|
||||
tables[key] = key
|
||||
return tables
|
||||
|
||||
|
||||
def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None):
|
||||
"""
|
||||
:returns next idx, column id
|
||||
"""
|
||||
tok = toks[start_idx]
|
||||
if tok == "*":
|
||||
global_mentioned_schema.append(('column', schema.idMap[tok]))
|
||||
return start_idx + 1, schema.idMap[tok]
|
||||
|
||||
if "." in tok: # if token is a composite
|
||||
alias, col = tok.split(".")
|
||||
key = tables_with_alias[alias] + "." + col
|
||||
global_mentioned_schema.append(('column', schema.idMap[key]))
|
||||
return start_idx + 1, schema.idMap[key]
|
||||
|
||||
assert (
|
||||
default_tables is not None and len(default_tables) > 0
|
||||
), "Default tables should not be None or empty"
|
||||
|
||||
for alias in default_tables:
|
||||
table = tables_with_alias[alias]
|
||||
if tok in schema.schema[table]:
|
||||
key = table + "." + tok
|
||||
global_mentioned_schema.append(('column', schema.idMap[key]))
|
||||
return start_idx + 1, schema.idMap[key]
|
||||
|
||||
assert False, "Error col: {}".format(tok)
|
||||
|
||||
|
||||
def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
|
||||
"""
|
||||
:returns next idx, (agg_op id, col_id)
|
||||
"""
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
isBlock = False
|
||||
isDistinct = False
|
||||
if toks[idx] == "(":
|
||||
isBlock = True
|
||||
idx += 1
|
||||
|
||||
if toks[idx] in AGG_OPS:
|
||||
agg_id = AGG_OPS.index(toks[idx])
|
||||
idx += 1
|
||||
assert idx < len_ and toks[idx] == "("
|
||||
idx += 1
|
||||
if toks[idx] == "distinct":
|
||||
idx += 1
|
||||
isDistinct = True
|
||||
idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables)
|
||||
assert idx < len_ and toks[idx] == ")"
|
||||
idx += 1
|
||||
return idx, (agg_id, col_id, isDistinct)
|
||||
|
||||
if toks[idx] == "distinct":
|
||||
idx += 1
|
||||
isDistinct = True
|
||||
agg_id = AGG_OPS.index("none")
|
||||
idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables)
|
||||
|
||||
if isBlock:
|
||||
assert toks[idx] == ")"
|
||||
idx += 1 # skip ')'
|
||||
|
||||
return idx, (agg_id, col_id, isDistinct)
|
||||
|
||||
|
||||
def parse_val_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
isBlock = False
|
||||
if toks[idx] == "(":
|
||||
isBlock = True
|
||||
idx += 1
|
||||
|
||||
col_unit1 = None
|
||||
col_unit2 = None
|
||||
unit_op = UNIT_OPS.index("none")
|
||||
|
||||
idx, col_unit1 = parse_col_unit(
|
||||
toks, idx, tables_with_alias, schema, default_tables
|
||||
)
|
||||
if idx < len_ and toks[idx] in UNIT_OPS:
|
||||
unit_op = UNIT_OPS.index(toks[idx])
|
||||
idx += 1
|
||||
idx, col_unit2 = parse_col_unit(
|
||||
toks, idx, tables_with_alias, schema, default_tables
|
||||
)
|
||||
|
||||
if isBlock:
|
||||
assert toks[idx] == ")"
|
||||
idx += 1 # skip ')'
|
||||
|
||||
return idx, (unit_op, col_unit1, col_unit2)
|
||||
|
||||
|
||||
def parse_table_unit(toks, start_idx, tables_with_alias, schema):
|
||||
"""
|
||||
:returns next idx, table id, table name
|
||||
"""
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
key = tables_with_alias[toks[idx]]
|
||||
|
||||
if idx + 1 < len_ and toks[idx + 1] == "as":
|
||||
idx += 3
|
||||
else:
|
||||
idx += 1
|
||||
|
||||
global_mentioned_schema.append(('table', schema.idMap[key]))
|
||||
return idx, schema.idMap[key], key
|
||||
|
||||
|
||||
def parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None):
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
|
||||
isBlock = False
|
||||
if toks[idx] == "(":
|
||||
isBlock = True
|
||||
idx += 1
|
||||
|
||||
if toks[idx] == "select":
|
||||
idx, val = parse_sql(toks, idx, tables_with_alias, schema)
|
||||
elif '"' in toks[idx]: # token is a string value
|
||||
val = toks[idx]
|
||||
idx += 1
|
||||
else:
|
||||
try:
|
||||
val = float(toks[idx])
|
||||
idx += 1
|
||||
except:
|
||||
end_idx = idx
|
||||
while (
|
||||
end_idx < len_
|
||||
and toks[end_idx] != ","
|
||||
and toks[end_idx] != ")"
|
||||
and toks[end_idx] != "and"
|
||||
and toks[end_idx] not in CLAUSE_KEYWORDS
|
||||
and toks[end_idx] not in JOIN_KEYWORDS
|
||||
):
|
||||
end_idx += 1
|
||||
|
||||
idx, val = parse_col_unit(
|
||||
toks[start_idx:end_idx], 0, tables_with_alias, schema, default_tables
|
||||
)
|
||||
idx = end_idx
|
||||
|
||||
if isBlock:
|
||||
assert toks[idx] == ")"
|
||||
idx += 1
|
||||
|
||||
return idx, val
|
||||
|
||||
|
||||
def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=None):
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
conds = []
|
||||
|
||||
while idx < len_:
|
||||
idx, val_unit = parse_val_unit(
|
||||
toks, idx, tables_with_alias, schema, default_tables
|
||||
)
|
||||
not_op = False
|
||||
if toks[idx] == "not":
|
||||
not_op = True
|
||||
idx += 1
|
||||
|
||||
assert (
|
||||
idx < len_ and toks[idx] in WHERE_OPS
|
||||
), "Error condition: idx: {}, tok: {}".format(idx, toks[idx])
|
||||
op_id = WHERE_OPS.index(toks[idx])
|
||||
idx += 1
|
||||
val1 = val2 = None
|
||||
if op_id == WHERE_OPS.index(
|
||||
"between"
|
||||
): # between..and... special case: dual values
|
||||
idx, val1 = parse_value(
|
||||
toks, idx, tables_with_alias, schema, default_tables
|
||||
)
|
||||
assert toks[idx] == "and"
|
||||
idx += 1
|
||||
idx, val2 = parse_value(
|
||||
toks, idx, tables_with_alias, schema, default_tables
|
||||
)
|
||||
else: # normal case: single value
|
||||
idx, val1 = parse_value(
|
||||
toks, idx, tables_with_alias, schema, default_tables
|
||||
)
|
||||
val2 = None
|
||||
|
||||
conds.append((not_op, op_id, val_unit, val1, val2))
|
||||
|
||||
if idx < len_ and (
|
||||
toks[idx] in CLAUSE_KEYWORDS
|
||||
or toks[idx] in (")", ";")
|
||||
or toks[idx] in JOIN_KEYWORDS
|
||||
):
|
||||
break
|
||||
|
||||
if idx < len_ and toks[idx] in COND_OPS:
|
||||
conds.append(toks[idx])
|
||||
idx += 1 # skip and/or
|
||||
|
||||
return idx, conds
|
||||
|
||||
|
||||
def parse_select(toks, start_idx, tables_with_alias, schema, default_tables=None):
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
|
||||
assert toks[idx] == "select", "'select' not found"
|
||||
idx += 1
|
||||
isDistinct = False
|
||||
if idx < len_ and toks[idx] == "distinct":
|
||||
idx += 1
|
||||
isDistinct = True
|
||||
val_units = []
|
||||
|
||||
while idx < len_ and toks[idx] not in CLAUSE_KEYWORDS:
|
||||
agg_id = AGG_OPS.index("none")
|
||||
if toks[idx] in AGG_OPS:
|
||||
agg_id = AGG_OPS.index(toks[idx])
|
||||
idx += 1
|
||||
idx, val_unit = parse_val_unit(
|
||||
toks, idx, tables_with_alias, schema, default_tables
|
||||
)
|
||||
val_units.append((agg_id, val_unit))
|
||||
if idx < len_ and toks[idx] == ",":
|
||||
idx += 1 # skip ','
|
||||
|
||||
return idx, (isDistinct, val_units)
|
||||
|
||||
|
||||
def parse_from(toks, start_idx, tables_with_alias, schema):
|
||||
"""
|
||||
Assume in the from clause, all table units are combined with join
|
||||
"""
|
||||
assert "from" in toks[start_idx:], "'from' not found"
|
||||
|
||||
len_ = len(toks)
|
||||
idx = toks.index("from", start_idx) + 1
|
||||
default_tables = []
|
||||
table_units = []
|
||||
conds = []
|
||||
|
||||
while idx < len_:
|
||||
isBlock = False
|
||||
if toks[idx] == "(":
|
||||
isBlock = True
|
||||
idx += 1
|
||||
|
||||
if toks[idx] == "select":
|
||||
idx, sql = parse_sql(toks, idx, tables_with_alias, schema)
|
||||
table_units.append((TABLE_TYPE["sql"], sql))
|
||||
else:
|
||||
if idx < len_ and toks[idx] == "join":
|
||||
idx += 1 # skip join
|
||||
idx, table_unit, table_name = parse_table_unit(
|
||||
toks, idx, tables_with_alias, schema
|
||||
)
|
||||
table_units.append((TABLE_TYPE["table_unit"], table_unit))
|
||||
default_tables.append(table_name)
|
||||
if idx < len_ and toks[idx] == "on":
|
||||
idx += 1 # skip on
|
||||
idx, this_conds = parse_condition(
|
||||
toks, idx, tables_with_alias, schema, default_tables
|
||||
)
|
||||
if len(conds) > 0:
|
||||
conds.append("and")
|
||||
conds.extend(this_conds)
|
||||
|
||||
if isBlock:
|
||||
assert toks[idx] == ")"
|
||||
idx += 1
|
||||
if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
|
||||
break
|
||||
|
||||
return idx, table_units, conds, default_tables
|
||||
|
||||
|
||||
def parse_where(toks, start_idx, tables_with_alias, schema, default_tables):
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
|
||||
if idx >= len_ or toks[idx] != "where":
|
||||
return idx, []
|
||||
|
||||
idx += 1
|
||||
idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
|
||||
return idx, conds
|
||||
|
||||
|
||||
def parse_group_by(toks, start_idx, tables_with_alias, schema, default_tables):
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
col_units = []
|
||||
|
||||
if idx >= len_ or toks[idx] != "group":
|
||||
return idx, col_units
|
||||
|
||||
idx += 1
|
||||
assert toks[idx] == "by"
|
||||
idx += 1
|
||||
|
||||
while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
|
||||
idx, col_unit = parse_col_unit(
|
||||
toks, idx, tables_with_alias, schema, default_tables
|
||||
)
|
||||
col_units.append(col_unit)
|
||||
if idx < len_ and toks[idx] == ",":
|
||||
idx += 1 # skip ','
|
||||
else:
|
||||
break
|
||||
|
||||
return idx, col_units
|
||||
|
||||
|
||||
def parse_order_by(toks, start_idx, tables_with_alias, schema, default_tables):
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
val_units = []
|
||||
order_type = "asc" # default type is 'asc'
|
||||
|
||||
if idx >= len_ or toks[idx] != "order":
|
||||
return idx, val_units
|
||||
|
||||
idx += 1
|
||||
assert toks[idx] == "by"
|
||||
idx += 1
|
||||
|
||||
while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
|
||||
idx, val_unit = parse_val_unit(
|
||||
toks, idx, tables_with_alias, schema, default_tables
|
||||
)
|
||||
val_units.append(val_unit)
|
||||
if idx < len_ and toks[idx] in ORDER_OPS:
|
||||
order_type = toks[idx]
|
||||
idx += 1
|
||||
if idx < len_ and toks[idx] == ",":
|
||||
idx += 1 # skip ','
|
||||
else:
|
||||
break
|
||||
|
||||
return idx, (order_type, val_units)
|
||||
|
||||
|
||||
def parse_having(toks, start_idx, tables_with_alias, schema, default_tables):
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
|
||||
if idx >= len_ or toks[idx] != "having":
|
||||
return idx, []
|
||||
|
||||
idx += 1
|
||||
idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
|
||||
return idx, conds
|
||||
|
||||
|
||||
def parse_limit(toks, start_idx):
|
||||
idx = start_idx
|
||||
len_ = len(toks)
|
||||
|
||||
if idx < len_ and toks[idx] == "limit":
|
||||
idx += 2
|
||||
return idx, int(toks[idx - 1])
|
||||
|
||||
return idx, None
|
||||
|
||||
|
||||
def parse_sql(toks, start_idx, tables_with_alias, schema, require_mentioned_schema=False):
|
||||
global global_mentioned_schema
|
||||
global_mentioned_schema = []
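# note: nested parse_sql calls (for subqueries and intersect/union/except) reset this list as well, so mentions recorded before a nested call are discarded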
|
||||
|
||||
isBlock = False # indicate whether this is a block of sql/sub-sql
|
||||
len_ = len(toks)
|
||||
idx = start_idx
|
||||
|
||||
sql = {}
|
||||
if toks[idx] == "(":
|
||||
isBlock = True
|
||||
idx += 1
|
||||
|
||||
# parse from clause in order to get default tables
|
||||
from_end_idx, table_units, conds, default_tables = parse_from(
|
||||
toks, start_idx, tables_with_alias, schema
|
||||
)
|
||||
sql["from"] = {"table_units": table_units, "conds": conds}
|
||||
# select clause
|
||||
_, select_col_units = parse_select(
|
||||
toks, idx, tables_with_alias, schema, default_tables
|
||||
)
|
||||
idx = from_end_idx
|
||||
sql["select"] = select_col_units
|
||||
# where clause
|
||||
idx, where_conds = parse_where(toks, idx, tables_with_alias, schema, default_tables)
|
||||
sql["where"] = where_conds
|
||||
# group by clause
|
||||
idx, group_col_units = parse_group_by(
|
||||
toks, idx, tables_with_alias, schema, default_tables
|
||||
)
|
||||
sql["groupBy"] = group_col_units
|
||||
# having clause
|
||||
idx, having_conds = parse_having(
|
||||
toks, idx, tables_with_alias, schema, default_tables
|
||||
)
|
||||
sql["having"] = having_conds
|
||||
# order by clause
|
||||
idx, order_col_units = parse_order_by(
|
||||
toks, idx, tables_with_alias, schema, default_tables
|
||||
)
|
||||
sql["orderBy"] = order_col_units
|
||||
# limit clause
|
||||
idx, limit_val = parse_limit(toks, idx)
|
||||
sql["limit"] = limit_val
|
||||
|
||||
idx = skip_semicolon(toks, idx)
|
||||
if isBlock:
|
||||
assert toks[idx] == ")"
|
||||
idx += 1 # skip ')'
|
||||
idx = skip_semicolon(toks, idx)
|
||||
|
||||
# intersect/union/except clause
|
||||
for op in SQL_OPS: # initialize IUE
|
||||
sql[op] = None
|
||||
if idx < len_ and toks[idx] in SQL_OPS:
|
||||
sql_op = toks[idx]
|
||||
idx += 1
|
||||
idx, IUE_sql = parse_sql(toks, idx, tables_with_alias, schema)
|
||||
sql[sql_op] = IUE_sql
|
||||
|
||||
if require_mentioned_schema:
|
||||
return idx, sql, global_mentioned_schema
|
||||
else:
|
||||
return idx, sql
|
||||
|
||||
|
||||
def load_data(fpath):
|
||||
with open(fpath) as f:
|
||||
data = json.load(f)
|
||||
return data
|
||||
|
||||
|
||||
def get_sql(schema, query):
|
||||
toks = tokenize(query)
|
||||
tables_with_alias = get_tables_with_alias(schema.schema, toks)
|
||||
_, sql = parse_sql(toks, 0, tables_with_alias, schema)
|
||||
|
||||
return sql
|
||||
|
||||
def extract_mentioned_schema_in_sql(schema, query):
|
||||
toks = tokenize(query)
|
||||
tables_with_alias = get_tables_with_alias(schema.schema, toks)
|
||||
|
||||
_, sql, mentioned_schema = parse_sql(toks, 0, tables_with_alias, schema, require_mentioned_schema = True)
|
||||
|
||||
return mentioned_schema
|
||||
|
||||
def skip_semicolon(toks, start_idx):
|
||||
idx = start_idx
|
||||
while idx < len(toks) and toks[idx] == ";":
|
||||
idx += 1
|
||||
return idx
|
|
@ -0,0 +1,74 @@
|
|||
import subprocess
|
||||
import argparse
|
||||
import os
|
||||
|
||||
|
||||
def run_command(bash_command):
|
||||
# drop shell line-continuation backslashes so the multi-line command splits into clean argv tokens
tokens = [tok for tok in bash_command.split() if tok != "\\"]
process = subprocess.Popen(tokens)
output, error = process.communicate()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--dataset_path", type=str, default="", help="dataset path")
|
||||
parser.add_argument("--exp_name", type=str, default="", help="test")
|
||||
parser.add_argument("--models_path", type=str, default="", help="models path")
|
||||
parser.add_argument("--bart_model_path", type=str, default="", help="bart init models path")
|
||||
parser.add_argument("--total_num_update", type=int, default=200000)
|
||||
parser.add_argument("--max_tokens", type=int, default=4096)
|
||||
parser.add_argument("--tensorboard_path", type=str, default="", help="tensorboard path")
|
||||
args = parser.parse_args()
|
||||
|
||||
print("START training")
|
||||
run_command("printenv")
|
||||
|
||||
restore_file = os.path.join(args.bart_model_path, "model.pt")
|
||||
|
||||
cmd = f"""
|
||||
fairseq-train {args.dataset_path} \
|
||||
--save-dir {args.models_path}/{args.exp_name} \
|
||||
--restore-file {restore_file} \
|
||||
--arch bart_large \
|
||||
--criterion label_smoothed_cross_entropy \
|
||||
--source-lang src \
|
||||
--target-lang tgt \
|
||||
--truncate-source \
|
||||
--label-smoothing 0.1 \
|
||||
--max-tokens {args.max_tokens} \
|
||||
--update-freq 4 \
|
||||
--max-update {args.total_num_update} \
|
||||
--required-batch-size-multiple 1 \
|
||||
--dropout 0.1 \
|
||||
--attention-dropout 0.1 \
|
||||
--relu-dropout 0.0 \
|
||||
--weight-decay 0.05 \
|
||||
--optimizer adam \
|
||||
--adam-eps 1e-08 \
|
||||
--clip-norm 0.1 \
|
||||
--lr-scheduler polynomial_decay \
|
||||
--lr 1e-05 \
|
||||
--total-num-update {args.total_num_update} \
|
||||
--warmup-updates 5000 \
|
||||
--ddp-backend no_c10d \
|
||||
--num-workers 20 \
|
||||
--reset-meters \
|
||||
--reset-optimizer \
|
||||
--reset-dataloader \
|
||||
--share-all-embeddings \
|
||||
--layernorm-embedding \
|
||||
--share-decoder-input-output-embed \
|
||||
--skip-invalid-size-inputs-valid-test \
|
||||
--log-format json \
|
||||
--log-interval 10 \
|
||||
--save-interval-updates 500 \
|
||||
--validate-interval-updates 500 \
|
||||
--validate-interval 10 \
|
||||
--save-interval 10 \
|
||||
--patience 200 \
|
||||
--no-last-checkpoints \
|
||||
--no-save-optimizer-state \
|
||||
--report-accuracy
|
||||
"""
|
||||
|
||||
print("RUN {}".format(cmd))
|
||||
run_command(cmd)
|
|
@ -0,0 +1,110 @@
|
|||
"""
|
||||
Based on https://github.com/ElementAI/Unisar/blob/master/Unisar/api.py
|
||||
"""
|
||||
import os
|
||||
import subprocess
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from genre.fairseq_model import GENRE
|
||||
from semparse.contexts.spider_db_context import SpiderDBContext
|
||||
from semparse.sql.spider import load_original_schemas, load_tables
|
||||
from semparse.sql.spider_utils import read_dataset_schema
|
||||
from step1_schema_linking import read_database_schema
|
||||
from step2_serialization import build_schema_linking_data
|
||||
from step3_evaluate import decode_with_constrain, get_alias_schema, post_processing_sql
|
||||
|
||||
|
||||
class UnisarAPI(object):
|
||||
def __init__(self, logdir: str, config_path: str):
|
||||
self.model = self.inferer.load_model(logdir, step=None)
|
||||
|
||||
|
||||
def convert_csv_to_sqlite(csv_path: str):
|
||||
# TODO: infer types when importing
|
||||
db_path = csv_path + ".sqlite"
|
||||
if os.path.exists(db_path):
|
||||
os.remove(db_path)
|
||||
subprocess.run(["sqlite3", db_path, ".mode csv", f".import {csv_path} Data"])
|
||||
return db_path
|
||||
|
||||
|
||||
class UnisarAPI(object):
|
||||
"""Run Unisar model on a given database."""
|
||||
|
||||
def __init__(self, log_dir: str, db_path: str, schema_path: Optional[str], stanza_model):
|
||||
self.log_dir = log_dir
|
||||
self.db_path = db_path
|
||||
self.schema_path = schema_path
|
||||
self.stanza_model = stanza_model
|
||||
|
||||
# if self.db_path.endswith(".sqlite"):
|
||||
# pass
|
||||
# elif self.db_path.endswith(".csv"):
|
||||
# self.db_path = convert_csv_to_sqlite(self.db_path)
|
||||
# else:
|
||||
# raise ValueError("expected either .sqlite or .csv file")
|
||||
|
||||
self.schema = read_dataset_schema(self.schema_path, stanza_model)
|
||||
_, _, self.database_schemas = read_database_schema(self.schema_path)
|
||||
self.model = GENRE.from_pretrained(self.log_dir).eval()
|
||||
if torch.cuda.is_available():
|
||||
self.model.cuda()
|
||||
|
||||
def infer_query(self, question, db_id):
|
||||
###step-1 schema-linking
|
||||
lemma_utterance_stanza = self.stanza_model(question)
|
||||
lemma_utterance = [word.lemma for sent in lemma_utterance_stanza.sentences for word in sent.words]
|
||||
db_context = SpiderDBContext(db_id,
|
||||
lemma_utterance,
|
||||
tables_file=self.schema_path,
|
||||
dataset_path=self.db_path,
|
||||
stanza_model=self.stanza_model,
|
||||
schemas=self.schema,
|
||||
original_utterance=question)
|
||||
value_match, value_alignment, exact_match, partial_match = db_context.get_db_knowledge_graph(db_id)
|
||||
|
||||
item = {}
|
||||
item['interaction'] = [{'db_id': db_id,
|
||||
'question': question,
|
||||
'sql': '',
|
||||
'value_match': value_match,
|
||||
'value_alignment': value_alignment,
|
||||
'exact_match': exact_match,
|
||||
'partial_match': partial_match,
|
||||
}]
|
||||
|
||||
###step-2 serialization
|
||||
source_sequence, _ = build_schema_linking_data(schema=self.database_schemas[db_id],
|
||||
question=question,
|
||||
item=item,
|
||||
turn_id=0,
|
||||
linking_type='default')
|
||||
slml_question = source_sequence[0]
|
||||
|
||||
###step-3 prediction
|
||||
schemas, eval_foreign_key_maps = load_tables(self.schema_path)
|
||||
original_schemas = load_original_schemas(self.schema_path)
|
||||
alias_schema = get_alias_schema(schemas)
|
||||
rnt = decode_with_constrain(slml_question, alias_schema[db_id], self.model)
|
||||
predict_sql = rnt[0]['text'] if isinstance(rnt[0]['text'], str) else rnt[0]['text'][0]
|
||||
score = rnt[0]['score'].tolist()
|
||||
|
||||
predict_sql = post_processing_sql(predict_sql, eval_foreign_key_maps[db_id], original_schemas[db_id],schemas[db_id])
|
||||
|
||||
return {
|
||||
"slml_question": slml_question,
|
||||
"predict_sql": predict_sql,
|
||||
"score": score
|
||||
}
|
||||
|
||||
def execute(self, query):
|
||||
### TODO: replace the query with value version
|
||||
pass
|
||||
# conn = sqlite3.connect(self.db_path)
|
||||
# # Temporary Hack: makes sure all literals are collated in a case-insensitive way
|
||||
# query = add_collate_nocase(query)
|
||||
# results = conn.execute(query).fetchall()
|
||||
# conn.close()
|
||||
# return results
|