This commit is contained in:
longxud 2022-04-14 12:34:57 +08:00
Parent eb2335c51e
Commit 5ae4e368e2
55 changed files with 209,574 additions and 0 deletions


@ -0,0 +1,57 @@
## Introduction
This paper introduces [UniSAr](https://arxiv.org/pdf/2203.07781.pdf), which extends existing autoregressive language models with three non-invasive extensions that make them structure-aware:
(1) adding structure marks to encode the database schema, conversation context, and their relationships;
(2) constrained decoding to generate well-structured SQL for a given database schema; and
(3) SQL completion to fill in potentially missing JOIN relationships in the SQL based on the database schema.
[//]: # (On seven well-known text-to-SQL datasets covering multi-domain, multi-table and multi-turn, UniSAr demonstrates highly comparable or better performance to the most advanced specifically-designed text-to-SQL models.)
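To make extension (2) concrete, here is a minimal, self-contained sketch of trie-based constrained decoding, the idea behind the `prefix_allowed_tokens_fn` / `Trie` code vendored from GENRE in this commit. The token ids, vocabulary, and schema items below are made up for illustration and are not the repository's:
```python
from typing import Dict, List

def build_trie(sequences: List[List[int]]) -> Dict:
    """Store each schema item's token-id sequence in a nested-dict trie."""
    root: Dict = {}
    for seq in sequences:
        node = root
        for tok in seq:
            node = node.setdefault(tok, {})
    return root

def allowed_next(trie: Dict, prefix: List[int]) -> List[int]:
    """Return the token ids that keep `prefix` a valid schema item."""
    node = trie
    for tok in prefix:
        if tok not in node:
            return []          # prefix is not part of any schema item: prune it
        node = node[tok]
    return list(node.keys())

# two made-up schema items, e.g. "concert.name" and "concert.year"
schema_trie = build_trie([[11, 7, 3], [11, 7, 9]])
print(allowed_next(schema_trie, [11, 7]))  # -> [3, 9]
print(allowed_next(schema_trie, [42]))     # -> []
```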
## Dataset and Model
[Spider](https://github.com/taoyds/spider) -> `./data/spider`
[Fine-tuned BART model](https://huggingface.co/dreamerdeo/mark-bart/tree/main) -> `./models/spider_sl`
(Please download this model with `git-lfs` to avoid this [issue](https://github.com/DreamerDeo/UniSAr_text2sql/issues/1).)
```bash
sudo apt-get install git-lfs
git lfs install
git clone https://huggingface.co/dreamerdeo/mark-bart
```
## Main dependencies
* Python version >= 3.6
* PyTorch version >= 1.5.0
* `pip install -r requirements.txt`
* fairseq is undergoing changes without backward compatibility. Install `fairseq` from source and use [this](https://github.com/nicola-decao/fairseq/tree/fixing_prefix_allowed_tokens_fn) commit for reproducibility. See [here](https://github.com/pytorch/fairseq/pull/3276) for the PR that should fix `fairseq/master`.
## Evaluation Pipeline
Step 1: Preprocess by adding schema-linking and value-linking tags.
`python step1_schema_linking.py`
Step 2: Build the input and output for BART.
`python step2_serialization.py`
Step 3: Evaluate with or without constrained decoding.
`python step3_evaluate.py --constrain`
## Results
Prediction: `69.34`
Prediction with Constrained Decoding: `70.02`
## Interactive
`python interactive.py --logdir ./models/spider_sl --db_id student_1 --db-path ./data/spider/database --schema-path ./data/spider/tables.json`
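For programmatic use, a minimal sketch based on `interactive.py` is shown below; the `UnisarAPI` constructor arguments and the result keys mirror that script, and the paths are the repository defaults:
```python
import stanza
from unisar.api import UnisarAPI

stanza_model = stanza.Pipeline(lang='en', processors='tokenize,pos,lemma')
unisar = UnisarAPI('./models/spider_sl', './data/spider/database',
                   './data/spider/tables.json', stanza_model)

result = unisar.infer_query('Tell me the name about organization', 'student_1')
print(result['slml_question'])                  # question annotated with schema-linking marks
print(result['predict_sql'], result['score'])   # predicted SQL and its score
```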
## Reference Code
`https://github.com/ryanzhumich/editsql`
`https://github.com/benbogin/spider-schema-gnn-global`
`https://github.com/ElementAI/duorat`
`https://github.com/facebookresearch/GENRE`


@ -0,0 +1,20 @@
Namespace(align_suffix=None, alignfile=None, all_gather_list_size=16384, azureml_logging=False, bf16=False, bpe=None, cpu=False, criterion='cross_entropy', dataset_impl='mmap', destdir='./dataset_post/spider_sl/bin', empty_cache_freq=0, fp16=False, fp16_init_scale=128, fp16_no_flatten_grads=False, fp16_scale_tolerance=0.0, fp16_scale_window=None, joined_dictionary=False, log_format=None, log_interval=100, lr_scheduler='fixed', memory_efficient_bf16=False, memory_efficient_fp16=False, min_loss_scale=0.0001, model_parallel_size=1, no_progress_bar=False, nwordssrc=-1, nwordstgt=-1, only_source=False, optimizer=None, padding_factor=8, profile=False, quantization_config_path=None, reset_logging=False, scoring='bleu', seed=1, simul_type=None, source_lang='src', srcdict='./models/spider_sl/dict.source.txt', suppress_crashes=False, target_lang='tgt', task='translation', tensorboard_logdir=None, testpref=None, tgtdict='./models/spider_sl/dict.target.txt', threshold_loss_scale=None, thresholdsrc=0, thresholdtgt=0, tokenizer=None, tpu=False, trainpref=None, user_dir=None, validpref='./dataset_post/spider_sl/dev.bpe', wandb_project=None, workers=2)
Namespace(align_suffix=None, alignfile=None, all_gather_list_size=16384, azureml_logging=False, bf16=False, bpe=None, cpu=False, criterion='cross_entropy', dataset_impl='mmap', destdir='./dataset_post/spider_sl/bin', empty_cache_freq=0, fp16=False, fp16_init_scale=128, fp16_no_flatten_grads=False, fp16_scale_tolerance=0.0, fp16_scale_window=None, joined_dictionary=False, log_format=None, log_interval=100, lr_scheduler='fixed', memory_efficient_bf16=False, memory_efficient_fp16=False, min_loss_scale=0.0001, model_parallel_size=1, no_progress_bar=False, nwordssrc=-1, nwordstgt=-1, only_source=False, optimizer=None, padding_factor=8, profile=False, quantization_config_path=None, reset_logging=False, scoring='bleu', seed=1, simul_type=None, source_lang='src', srcdict='./models/spider_sl/dict.source.txt', suppress_crashes=False, target_lang='tgt', task='translation', tensorboard_logdir=None, testpref=None, tgtdict='./models/spider_sl/dict.target.txt', threshold_loss_scale=None, thresholdsrc=0, thresholdtgt=0, tokenizer=None, tpu=False, trainpref=None, user_dir=None, validpref='./dataset_post/spider_sl/dev.bpe', wandb_project=None, workers=2)
Namespace(align_suffix=None, alignfile=None, all_gather_list_size=16384, azureml_logging=False, bf16=False, bpe=None, cpu=False, criterion='cross_entropy', dataset_impl='mmap', destdir='./dataset_post/spider_sl/bin', empty_cache_freq=0, fp16=False, fp16_init_scale=128, fp16_no_flatten_grads=False, fp16_scale_tolerance=0.0, fp16_scale_window=None, joined_dictionary=False, log_format=None, log_interval=100, lr_scheduler='fixed', memory_efficient_bf16=False, memory_efficient_fp16=False, min_loss_scale=0.0001, model_parallel_size=1, no_progress_bar=False, nwordssrc=-1, nwordstgt=-1, only_source=False, optimizer=None, padding_factor=8, profile=False, quantization_config_path=None, reset_logging=False, scoring='bleu', seed=1, simul_type=None, source_lang='src', srcdict='./models/spider_sl/dict.src.txt', suppress_crashes=False, target_lang='tgt', task='translation', tensorboard_logdir=None, testpref=None, tgtdict='./models/spider_sl/dict.tgt.txt', threshold_loss_scale=None, thresholdsrc=0, thresholdtgt=0, tokenizer=None, tpu=False, trainpref=None, user_dir=None, validpref='./dataset_post/spider_sl/dev.bpe', wandb_project=None, workers=2)
[src] Dictionary: 50264 types
[src] ./dataset_post/spider_sl/dev.bpe.src: 1034 sents, 270269 tokens, 0.0% replaced by <unk>
[tgt] Dictionary: 50264 types
[tgt] ./dataset_post/spider_sl/dev.bpe.tgt: 1034 sents, 33481 tokens, 0.0% replaced by <unk>
Wrote preprocessed data to ./dataset_post/spider_sl/bin
Namespace(align_suffix=None, alignfile=None, all_gather_list_size=16384, azureml_logging=False, bf16=False, bpe=None, cpu=False, criterion='cross_entropy', dataset_impl='mmap', destdir='./dataset_post/spider_sl/bin', empty_cache_freq=0, fp16=False, fp16_init_scale=128, fp16_no_flatten_grads=False, fp16_scale_tolerance=0.0, fp16_scale_window=None, joined_dictionary=False, log_format=None, log_interval=100, lr_scheduler='fixed', memory_efficient_bf16=False, memory_efficient_fp16=False, min_loss_scale=0.0001, model_parallel_size=1, no_progress_bar=False, nwordssrc=-1, nwordstgt=-1, only_source=False, optimizer=None, padding_factor=8, profile=False, quantization_config_path=None, reset_logging=False, scoring='bleu', seed=1, simul_type=None, source_lang='src', srcdict='./models/spider_sl/dict.src.txt', suppress_crashes=False, target_lang='tgt', task='translation', tensorboard_logdir=None, testpref=None, tgtdict='./models/spider_sl/dict.tgt.txt', threshold_loss_scale=None, thresholdsrc=0, thresholdtgt=0, tokenizer=None, tpu=False, trainpref=None, user_dir=None, validpref='./dataset_post/spider_sl/dev.bpe', wandb_project=None, workers=2)
[src] Dictionary: 50264 types
[src] ./dataset_post/spider_sl/dev.bpe.src: 1034 sents, 270269 tokens, 0.0% replaced by <unk>
[tgt] Dictionary: 50264 types
[tgt] ./dataset_post/spider_sl/dev.bpe.tgt: 1034 sents, 33481 tokens, 0.0% replaced by <unk>
Wrote preprocessed data to ./dataset_post/spider_sl/bin
Namespace(align_suffix=None, alignfile=None, all_gather_list_size=16384, azureml_logging=False, bf16=False, bpe=None, cpu=False, criterion='cross_entropy', dataset_impl='mmap', destdir='./dataset_post/spider_sl/bin', empty_cache_freq=0, fp16=False, fp16_init_scale=128, fp16_no_flatten_grads=False, fp16_scale_tolerance=0.0, fp16_scale_window=None, joined_dictionary=False, log_format=None, log_interval=100, lr_scheduler='fixed', memory_efficient_bf16=False, memory_efficient_fp16=False, min_loss_scale=0.0001, model_parallel_size=1, no_progress_bar=False, nwordssrc=-1, nwordstgt=-1, only_source=False, optimizer=None, padding_factor=8, profile=False, quantization_config_path=None, reset_logging=False, scoring='bleu', seed=1, simul_type=None, source_lang='src', srcdict='./models/mark-bart/dict.src.txt', suppress_crashes=False, target_lang='tgt', task='translation', tensorboard_logdir=None, testpref=None, tgtdict='./models/mark-bart/dict.tgt.txt', threshold_loss_scale=None, thresholdsrc=0, thresholdtgt=0, tokenizer=None, tpu=False, trainpref=None, user_dir=None, validpref='./dataset_post/spider_sl/dev.bpe', wandb_project=None, workers=2)
[src] Dictionary: 50264 types
[src] ./dataset_post/spider_sl/dev.bpe.src: 1034 sents, 281522 tokens, 0.0% replaced by <unk>
[tgt] Dictionary: 50264 types
[tgt] ./dataset_post/spider_sl/dev.bpe.tgt: 1034 sents, 33481 tokens, 0.0% replaced by <unk>
Wrote preprocessed data to ./dataset_post/spider_sl/bin


@ -0,0 +1,133 @@
# Copyright (c) Facebook, Inc. and Microsoft Corporation.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List
import torch
from genre.trie import Trie
keyword = ['select', 'distinct', 'from', 'join', 'on', 'where', 'group', 'by', 'order', 'asc', 'desc', 'limit',
'having',
'and', 'not', 'or', 'like', 'between', 'in',
'sum', 'count', 'max', 'min', 'avg',
'(', ')', ',', '>', '<', '=', '>=', '!=', '<=',
'union', 'except', 'intersect',
'1', '2', '3', '4', '5']
def get_end_to_end_prefix_allowed_tokens_fn_hf(
model,
sentences: List[str],
start_mention_token="{",
end_mention_token="}",
start_entity_token="[",
end_entity_token="]",
mention_trie: Trie = None,
candidates_trie: Trie = None,
mention_to_candidates_dict: Dict[str, List[str]] = None,
):
return _get_end_to_end_prefix_allowed_tokens_fn(
lambda x: model.tokenizer.encode(x),
lambda x: model.tokenizer.decode(torch.tensor(x)),
model.tokenizer.bos_token_id,
model.tokenizer.pad_token_id,
model.tokenizer.eos_token_id,
len(model.tokenizer) - 1,
sentences,
start_mention_token,
end_mention_token,
start_entity_token,
end_entity_token,
mention_trie,
candidates_trie,
mention_to_candidates_dict,
)
def get_end_to_end_prefix_allowed_tokens_fn_fairseq(
model,
sentences: List[str],
start_mention_token="{",
end_mention_token="}",
start_entity_token="[",
end_entity_token="]",
mention_trie: Trie = None,
candidates_trie: Trie = None,
mention_to_candidates_dict: Dict[str, List[str]] = None,
):
return _get_end_to_end_prefix_allowed_tokens_fn(
lambda x: model.encode(x).tolist(),
lambda x: model.decode(torch.tensor(x)),
model.model.decoder.dictionary.bos(),
model.model.decoder.dictionary.pad(),
model.model.decoder.dictionary.eos(),
len(model.model.decoder.dictionary),
sentences,
start_mention_token,
end_mention_token,
start_entity_token,
end_entity_token,
mention_trie,
candidates_trie,
mention_to_candidates_dict,
)
def _get_end_to_end_prefix_allowed_tokens_fn(
encode_fn,
decode_fn,
bos_token_id,
pad_token_id,
eos_token_id,
vocabulary_length,
sentences: List[str],
start_mention_token="{",
end_mention_token="}",
start_entity_token="[",
end_entity_token="]",
mention_trie: Trie = None,
candidates_trie: Trie = None,
mention_to_candidates_dict: Dict[str, List[str]] = None,
):
assert not (
candidates_trie is not None and mention_to_candidates_dict is not None
), "`candidates_trie` and `mention_to_candidates_dict` cannot be both != `None`"
codes = {}
codes["EOS"] = eos_token_id
codes["BOS"] = bos_token_id
keyword_codes = {k: encode_fn(" {}".format(k))[1] for k in keyword}
keyword_codes['wselect'] = encode_fn("{}".format('select'))[1]
def prefix_allowed_tokens_fn(batch_id, sent):
sent = sent.tolist()
trie_out = get_trie_schema(sent)
return trie_out
def get_trie_schema(sent):
pointer_start = get_keyword_mention(sent)
keyword_rnt = list(keyword_codes.values())
if pointer_start + 1 < len(sent) and pointer_start != -1:
ment_next = mention_trie.get(sent[pointer_start + 1:])
if codes["EOS"] in ment_next:
return ment_next + keyword_rnt
else:
return ment_next
else:
ment_next = mention_trie.get([])
return ment_next + keyword_rnt + [codes["EOS"]]
def get_keyword_mention(sent):
pointer_start = -1
for i, e in enumerate(sent):
if e in keyword_codes.values():
pointer_start = i
return pointer_start
return prefix_allowed_tokens_fn
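# Illustrative usage (everything below is an assumption, not part of this file):
# build a Trie over the token-id sequences of the schema items, then pass the
# returned function to generation through the pinned fairseq fork that supports
# `prefix_allowed_tokens_fn`, e.g.
#   prefix_fn = get_end_to_end_prefix_allowed_tokens_fn_fairseq(
#       model, questions, mention_trie=schema_trie)
#   hypos = model.sample(questions, beam=5, prefix_allowed_tokens_fn=prefix_fn)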


@ -0,0 +1,157 @@
# Copyright (c) Facebook, Inc. and Microsoft Corporation.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import copy
import logging
import os
from collections import defaultdict
from typing import Dict, List
import torch
from fairseq import search, utils
from fairseq.models.bart import BARTHubInterface, BARTModel
from omegaconf import open_dict
logger = logging.getLogger(__name__)
class GENREHubInterface(BARTHubInterface):
def sample(
self,
sentences: List[str],
beam: int = 5,
verbose: bool = False,
text_to_id=None,
marginalize=False,
marginalize_lenpen=0.5,
max_len_a=1024,
max_len_b=1024,
**kwargs,
) -> List[str]:
if isinstance(sentences, str):
return self.sample([sentences], beam=beam, verbose=verbose, **kwargs)[0]
tokenized_sentences = [self.encode(sentence) for sentence in sentences]
batched_hypos = self.generate(
tokenized_sentences,
beam,
verbose,
max_len_a=max_len_a,
max_len_b=max_len_b,
**kwargs,
)
outputs = [
[
{"text": self.decode(hypo["tokens"]), "score": hypo["score"]}
for hypo in hypos
]
for hypos in batched_hypos
]
if text_to_id:
outputs = [
[{**hypo, "id": text_to_id(hypo["text"])} for hypo in hypos]
for hypos in outputs
]
if marginalize:
for (i, hypos), hypos_tok in zip(enumerate(outputs), batched_hypos):
outputs_dict = defaultdict(list)
for hypo, hypo_tok in zip(hypos, hypos_tok):
outputs_dict[hypo["id"]].append(
{**hypo, "len": len(hypo_tok["tokens"])}
)
outputs[i] = sorted(
[
{
"id": _id,
"texts": [hypo["text"] for hypo in hypos],
"scores": torch.stack(
[hypo["score"] for hypo in hypos]
),
"score": torch.stack(
[
hypo["score"]
* hypo["len"]
/ (hypo["len"] ** marginalize_lenpen)
for hypo in hypos
]
).logsumexp(-1),
}
for _id, hypos in outputs_dict.items()
],
key=lambda x: x["score"],
reverse=True,
)
return outputs
def generate(self, *args, **kwargs) -> List[List[Dict[str, torch.Tensor]]]:
return super(BARTHubInterface, self).generate(*args, **kwargs)
def encode(self, sentence) -> torch.LongTensor:
tokens = super(BARTHubInterface, self).encode(sentence)
tokens[
tokens >= len(self.task.target_dictionary)
] = self.task.target_dictionary.unk_index
if tokens[0] != self.task.target_dictionary.bos_index:
return torch.cat(
(torch.tensor([self.task.target_dictionary.bos_index]), tokens)
)
else:
return tokens
class GENRE(BARTModel):
@classmethod
def from_pretrained(
cls,
model_name_or_path,
checkpoint_file="model.pt",
data_name_or_path=".",
bpe="gpt2",
**kwargs,
):
from fairseq import hub_utils
x = hub_utils.from_pretrained(
model_name_or_path,
checkpoint_file,
data_name_or_path,
archive_map=cls.hub_models(),
bpe=bpe,
load_checkpoint_heads=True,
**kwargs,
)
return GENREHubInterface(x["args"], x["task"], x["models"][0])
class mGENRE(BARTModel):
@classmethod
def from_pretrained(
cls,
model_name_or_path,
sentencepiece_model="sentence.bpe.model",
checkpoint_file="model.pt",
data_name_or_path=".",
bpe="sentencepiece",
layernorm_embedding=True,
**kwargs,
):
from fairseq import hub_utils
x = hub_utils.from_pretrained(
model_name_or_path,
checkpoint_file,
data_name_or_path,
archive_map=cls.hub_models(),
bpe=bpe,
load_checkpoint_heads=True,
sentencepiece_model=os.path.join(model_name_or_path, sentencepiece_model),
**kwargs,
)
return GENREHubInterface(x["args"], x["task"], x["models"][0])


@ -0,0 +1,65 @@
# Copyright (c) Facebook, Inc. and Microsoft Corporation.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Dict, List
import torch
from genre.utils import chunk_it
from transformers import BartForConditionalGeneration, BartTokenizer
logger = logging.getLogger(__name__)
class GENREHubInterface(BartForConditionalGeneration):
def sample(
self, sentences: List[str], num_beams: int = 5, num_return_sequences=5, **kwargs
) -> List[str]:
input_args = {
k: v.to(self.device)
for k, v in self.tokenizer.batch_encode_plus(
sentences, padding=True, return_tensors="pt"
).items()
}
outputs = self.generate(
**input_args,
min_length=0,
max_length=1024,
num_beams=num_beams,
num_return_sequences=num_return_sequences,
output_scores=True,
return_dict_in_generate=True,
**kwargs
)
return chunk_it(
[
{
"text": text,
"logprob": score,
}
for text, score in zip(
self.tokenizer.batch_decode(
outputs.sequences, skip_special_tokens=True
),
outputs.sequences_scores,
)
],
num_return_sequences,
)
def encode(self, sentence):
return self.tokenizer.encode(sentence, return_tensors="pt")[0]
class GENRE(BartForConditionalGeneration):
@classmethod
def from_pretrained(cls, model_name_or_path):
model = GENREHubInterface.from_pretrained(model_name_or_path)
model.tokenizer = BartTokenizer.from_pretrained(model_name_or_path)
return model


@ -0,0 +1,189 @@
# Copyright (c) Facebook, Inc. and Microsoft Corporation.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List
try:
import marisa_trie
except ModuleNotFoundError:
pass
class Trie(object):
def __init__(self, sequences: List[List[int]] = []):
self.trie_dict = {}
self.len = 0
if sequences:
for sequence in sequences:
Trie._add_to_trie(sequence, self.trie_dict)
self.len += 1
self.append_trie = None
self.bos_token_id = None
def append(self, trie, bos_token_id):
self.append_trie = trie
self.bos_token_id = bos_token_id
def add(self, sequence: List[int]):
Trie._add_to_trie(sequence, self.trie_dict)
self.len += 1
def get(self, prefix_sequence: List[int]):
return Trie._get_from_trie(
prefix_sequence, self.trie_dict, self.append_trie, self.bos_token_id
)
@staticmethod
def load_from_dict(trie_dict):
trie = Trie()
trie.trie_dict = trie_dict
trie.len = sum(1 for _ in trie)
return trie
@staticmethod
def _add_to_trie(sequence: List[int], trie_dict: Dict):
if sequence:
if sequence[0] not in trie_dict:
trie_dict[sequence[0]] = {}
Trie._add_to_trie(sequence[1:], trie_dict[sequence[0]])
@staticmethod
def _get_from_trie(
prefix_sequence: List[int],
trie_dict: Dict,
append_trie=None,
bos_token_id: int = None,
):
if len(prefix_sequence) == 0:
output = list(trie_dict.keys())
if append_trie and bos_token_id in output:
output.remove(bos_token_id)
output += list(append_trie.trie_dict.keys())
return output
elif prefix_sequence[0] in trie_dict:
return Trie._get_from_trie(
prefix_sequence[1:],
trie_dict[prefix_sequence[0]],
append_trie,
bos_token_id,
)
else:
if append_trie:
return append_trie.get(prefix_sequence)
else:
return []
def __iter__(self):
def _traverse(prefix_sequence, trie_dict):
if trie_dict:
for next_token in trie_dict:
yield from _traverse(
prefix_sequence + [next_token], trie_dict[next_token]
)
else:
yield prefix_sequence
return _traverse([], self.trie_dict)
def __len__(self):
return self.len
def __getitem__(self, value):
return self.get(value)
class MarisaTrie(object):
def __init__(
self,
sequences: List[List[int]] = [],
cache_fist_branch=True,
max_token_id=256001,
):
self.int2char = [chr(i) for i in range(min(max_token_id, 55000))] + (
[chr(i) for i in range(65000, max_token_id + 10000)]
if max_token_id >= 55000
else []
)
self.char2int = {self.int2char[i]: i for i in range(max_token_id)}
self.cache_fist_branch = cache_fist_branch
if self.cache_fist_branch:
self.zero_iter = list({sequence[0] for sequence in sequences})
assert len(self.zero_iter) == 1
self.first_iter = list({sequence[1] for sequence in sequences})
self.trie = marisa_trie.Trie(
"".join([self.int2char[i] for i in sequence]) for sequence in sequences
)
def get(self, prefix_sequence: List[int]):
if self.cache_fist_branch and len(prefix_sequence) == 0:
return self.zero_iter
elif (
self.cache_fist_branch
and len(prefix_sequence) == 1
and self.zero_iter == prefix_sequence
):
return self.first_iter
else:
key = "".join([self.int2char[i] for i in prefix_sequence])
return list(
{
self.char2int[e[len(key)]]
for e in self.trie.keys(key)
if len(e) > len(key)
}
)
def __iter__(self):
for sequence in self.trie.iterkeys():
yield [self.char2int[e] for e in sequence]
def __len__(self):
return len(self.trie)
def __getitem__(self, value):
return self.get(value)
class DummyTrieMention(object):
def __init__(self, return_values):
self._return_values = return_values
def get(self, indices=None):
return self._return_values
class DummyTrieEntity(object):
def __init__(self, return_values, codes):
self._return_values = list(
set(return_values).difference(
set(
codes[e]
for e in (
"start_mention_token",
"end_mention_token",
"start_entity_token",
)
)
)
)
self._codes = codes
def get(self, indices, depth=0):
if len(indices) == 0 and depth == 0:
return self._codes["end_mention_token"]
elif len(indices) == 0 and depth == 1:
return self._codes["start_entity_token"]
elif len(indices) == 0:
return self._return_values
elif len(indices) == 1 and indices[0] == self._codes["end_entity_token"]:
return self._codes["EOS"]
else:
return self.get(indices[1:], depth=depth + 1)
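# Minimal usage sketch of Trie (the token ids are made up for illustration):
#   trie = Trie([[2, 10, 42, 2], [2, 10, 7, 2]])
#   trie.get([2, 10])  # -> [42, 7]: allowed continuations of the prefix [2, 10]
#   trie.get([])       # -> [2]:     allowed first tokens
# len(trie) == 2, and iterating over it yields the two stored sequences.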


@ -0,0 +1,52 @@
import argparse
import stanza
from unisar.api import UnisarAPI
class Interactive(object):
def __init__(self, Unisar: UnisarAPI):
self.unisar = Unisar
def ask_any_question(self, question, db_id):
results = self.unisar.infer_query(question, db_id)
print('input:', results['slml_question'])
print(f'"pred:" {results["predict_sql"]} ({results["score"]})')
# try:
# results = self.unisar.execute(results['query'])
# print(results)
# except Exception as e:
# print(str(e))
def show_schema(self, db_id):
for table in self.unisar.schema[db_id].values():
print("Table", f"{table.name}")
for column in table.columns:
print(" Column", f"{column.name}")
def run(self, db_id):
self.show_schema(db_id)
# self.ask_any_question('Tell me the name about organization', db_id)
while True:
question = input("Ask a question: ")
self.ask_any_question(question, db_id)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--logdir", default='./models/spider_sl')
parser.add_argument("--db_id", default='student_1')
parser.add_argument(
"--db-path", default='./data/spider/database',
help="The path to the sqlite database or csv file"
)
parser.add_argument(
"--schema-path", default='./data/spider/tables.json',
help="The path to the tables.json file with human-readable database schema."
)
args = parser.parse_args()
stanza_model = stanza.Pipeline(lang='en', processors='tokenize,pos,lemma')
interactive = Interactive(UnisarAPI(args.logdir, args.db_path, args.schema_path, stanza_model))
interactive.run(args.db_id)


@ -0,0 +1,136 @@
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and Microsoft Corporation.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import contextlib
import sys
from collections import Counter
from multiprocessing import Pool
# from .bpe_utils import get_encoder
from fairseq.data.encoders.gpt2_bpe import get_encoder
def main():
"""
Helper script to encode raw text with the GPT-2 BPE using multiple processes.
The encoder.json and vocab.bpe files can be obtained here:
- https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json
- https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe
"""
parser = argparse.ArgumentParser()
parser.add_argument(
"--encoder-json",
help="path to encoder.json",
)
parser.add_argument(
"--vocab-bpe",
type=str,
help="path to vocab.bpe",
)
# parser.add_argument(
# "--special-token",
# type=str,
# help="path to special tokens split by \n"
# )
parser.add_argument(
"--inputs",
nargs="+",
default=["-"],
help="input files to filter/encode",
)
parser.add_argument(
"--outputs",
nargs="+",
default=["-"],
help="path to save encoded outputs",
)
parser.add_argument(
"--keep-empty",
action="store_true",
help="keep empty lines",
)
parser.add_argument("--workers", type=int, default=20)
args = parser.parse_args()
assert len(args.inputs) == len(
args.outputs
), "number of input and output paths should match"
with contextlib.ExitStack() as stack:
inputs = [
stack.enter_context(open(input, "r", encoding="utf-8"))
if input != "-"
else sys.stdin
for input in args.inputs
]
outputs = [
stack.enter_context(open(output, "w", encoding="utf-8"))
if output != "-"
else sys.stdout
for output in args.outputs
]
encoder = MultiprocessingEncoder(args)
pool = Pool(args.workers, initializer=encoder.initializer)
encoded_lines = pool.imap(encoder.encode_lines, zip(*inputs), 100)
stats = Counter()
for i, (filt, enc_lines) in enumerate(encoded_lines, start=1):
if filt == "PASS":
for enc_line, output_h in zip(enc_lines, outputs):
print(enc_line, file=output_h)
else:
stats["num_filtered_" + filt] += 1
if i % 10000 == 0:
print("processed {} lines".format(i), file=sys.stderr)
for k, v in stats.most_common():
print("[{}] filtered {} lines".format(k, v), file=sys.stderr)
class MultiprocessingEncoder(object):
def __init__(self, args):
self.args = args
def initializer(self):
global bpe
bpe = get_encoder(self.args.encoder_json, self.args.vocab_bpe)
# bpe = get_encoder(self.args.encoder_json, self.args.vocab_bpe, self.args.special_token)
def encode(self, line):
global bpe
ids = bpe.encode(line)
return list(map(str, ids))
def decode(self, tokens):
global bpe
return bpe.decode(tokens)
def encode_lines(self, lines):
"""
Encode a set of lines. All lines will be encoded together.
"""
enc_lines = []
for line in lines:
line = line.strip()
if len(line) == 0 and not self.args.keep_empty:
return ["EMPTY", None]
tokens = self.encode(line)
enc_lines.append(" ".join(tokens))
return ["PASS", enc_lines]
def decode_lines(self, lines):
dec_lines = []
for line in lines:
tokens = map(int, line.strip().split())
dec_lines.append(self.decode(tokens))
return ["PASS", dec_lines]
if __name__ == "__main__":
main()


@ -0,0 +1,11 @@
dataclasses
tqdm
pydantic
sqlparse
unidecode
networkx
stanza
nltk
vocab
sentencepiece
torch


@ -0,0 +1,23 @@
#!/bin/bash
#requirement:
#./data/spider
#./BART-large
# data/spider -> data/spider_schema_linking_tag
python step1_schema_linking.py --dataset=spider
# data/spider_schema_linking_tag -> dataset_post/spider_sl
python step2_serialization.py
###training
python train.py \
--dataset_path ./dataset_post/spider_sl/bin/ \
--exp_name spider_sl_v1 \
--models_path ./models \
--total_num_update 10000 \
--max_tokens 1024 \
--bart_model_path ./data/BART-large
###evaluate
python step3_evaluate.py --constrain


@ -0,0 +1,257 @@
import re
from collections import Counter, defaultdict
from typing import Dict, Tuple, List
from unidecode import unidecode
from semparse.sql.spider_utils import TableColumn, read_dataset_schema, read_dataset_values
# == stop words that will be omitted by ContextGenerator
STOP_WORDS = {"", "", "all", "being", "-", "over", "through", "yourselves", "its", "before",
"hadn", "with", "had", ",", "should", "to", "only", "under", "ours", "has", "ought", "do",
"them", "his", "than", "very", "cannot", "they", "not", "during", "yourself", "him",
"nor", "did", "didn", "'ve", "this", "she", "each", "where", "because", "doing", "some", "we", "are",
"further", "ourselves", "out", "what", "for", "weren", "does", "above", "between", "mustn", "?",
"be", "hasn", "who", "were", "here", "shouldn", "let", "hers", "by", "both", "about", "couldn",
"of", "could", "against", "isn", "or", "own", "into", "while", "whom", "down", "wasn", "your",
"from", "her", "their", "aren", "there", "been", ".", "few", "too", "wouldn", "themselves",
":", "was", "until", "more", "himself", "on", "but", "don", "herself", "haven", "those", "he",
"me", "myself", "these", "up", ";", "below", "'re", "can", "theirs", "my", "and", "would", "then",
"is", "am", "it", "doesn", "an", "as", "itself", "at", "have", "in", "any", "if", "!",
"again", "'ll", "no", "that", "when", "same", "how", "other", "which", "you", "many", "shan",
"'t", "'s", "our", "after", "most", "'d", "such", "'m", "why", "a", "off", "i", "yours", "so",
"the", "having", "once"}
digits_list = [str(i) for i in range(10)]
class SpiderDBContext:
schemas = {}
db_knowledge_graph = {}
db_tables_data = {}
def __init__(self, db_id: str, utterance: str, tables_file: str, dataset_path: str, stanza_model=None, schemas=None,
original_utterance=None):
self.dataset_path = dataset_path
self.tables_file = tables_file
self.db_id = db_id
self.utterance = utterance
self.tokenized_utterance = utterance
self.stanza_model = stanza_model
self.original_utterance = original_utterance if original_utterance is not None else utterance
if schemas is not None:
SpiderDBContext.schemas = schemas
elif db_id not in SpiderDBContext.schemas:
SpiderDBContext.schemas = read_dataset_schema(self.tables_file, self.stanza_model)
self.schema = SpiderDBContext.schemas[db_id]
@staticmethod
def entity_key_for_column(table_name: str, column: TableColumn) -> str:
return f"{table_name.lower()}@{column.name.lower()}"
if column.foreign_key is not None:
column_type = "foreign"
elif column.is_primary_key:
column_type = "primary"
else:
column_type = column.column_type
return f"column:{column_type.lower()}:{table_name.lower()}:{column.name.lower()}"
def get_db_knowledge_graph(self, db_id: str):
db_schema = self.schema
tables = db_schema.values()
if db_id not in self.db_tables_data:
self.db_tables_data[db_id] = read_dataset_values(db_id, self.dataset_path, tables)
tables_data = self.db_tables_data[db_id]
string_column_mapping: Dict[str, set] = defaultdict(set)
for table, table_data in tables_data.items():
for table_row in table_data:
for column, cell_value in zip(db_schema[table.name].columns, table_row):
if column.column_type == 'text' and type(cell_value) is str:
cell_value_normalized = self.normalize_string(cell_value)
column_key = self.entity_key_for_column(table.name, column)
string_column_mapping[cell_value_normalized].add(column_key)
# for key in string_column_mapping:
# string_column_mapping[key]=list(string_column_mapping[key])
string_entities = self.get_entities_from_question(string_column_mapping)
value_match=[]
value_alignment = []
for item in string_entities:
value_match+=item['token_in_columns']
value_alignment.append(item['alignment'])
value_match = list(set(value_match))
r_schemas={}
for table in db_schema:
r_schemas["{0}".format(db_schema[table].name).lower()] = db_schema[table].lemma.lower()
for column in db_schema[table].columns:
r_schemas[f"{db_schema[table].name}@{column.name}".lower()] = column.lemma.strip('is ').lower()
question_tokens = [t for t in self.tokenized_utterance]
schema_counter = Counter()
partial_match = []
exact_match = []
for r_k, r_s in r_schemas.items():
schema_counter[r_s] = 0
#exact match
if r_s in self.tokenized_utterance and r_s not in STOP_WORDS:
schema_counter[r_s] += 2
#partial_match
else:
for tok in r_s.split(' '):
if tok in question_tokens and tok not in STOP_WORDS:
schema_counter[r_s]+=1
continue
for ques_tok in question_tokens:
if tok in STOP_WORDS or ques_tok in STOP_WORDS or \
len(tok)<=3 or len(ques_tok)<=3:
continue
if ques_tok in tok or tok in ques_tok:
schema_counter[r_s] += 1
if schema_counter[r_s]>=2:
exact_match.append(r_k)
elif schema_counter[r_s]==1:
partial_match.append(r_k)
return value_match, value_alignment, exact_match, partial_match
def _string_in_table(self, candidate: str,
string_column_mapping: Dict[str, set]) -> List[str]:
"""
Checks if the string occurs in the table, and if it does, returns the names of the columns
under which it occurs. If it does not, returns an empty list.
"""
candidate_column_names: List[str] = []
alignment = []
# First check if the entire candidate occurs as a cell.
candidate= candidate.strip('-_"\'')
if candidate in string_column_mapping and candidate not in digits_list:
candidate_column_names = string_column_mapping[candidate]
alignment.append((candidate,candidate))
# # If not, check if it is a substring pf any cell value.
# if not candidate_column_names:
# for cell_value, column_names in string_column_mapping.items():
# if candidate in re.split(' |_|:',
# cell_value) and candidate not in STOP_WORDS and candidate not in digits_list:
# candidate_column_names.extend(column_names)
# alignment.append((candidate, cell_value))
candidate_column_names = list(set(candidate_column_names))
return candidate_column_names, alignment
def get_entities_from_question(self, string_column_mapping: Dict[str, set]) -> List[Tuple[str, str]]:
entity_data = []
for cell_value, column_names in string_column_mapping.items():
if (cell_value.replace('_', ' ') in ' '.join(self.utterance) or cell_value.replace('_',
' ') in self.original_utterance) \
and len(re.split('_', cell_value)) >= 2:
entity_data.append({'value': cell_value,
'token_start': 0,
'token_end': 0,
'alignment': [(cell_value, cell_value)],
'token_in_columns': list(set(column_names))})
for i, token in enumerate(self.tokenized_utterance):
token_text = token
if token_text in STOP_WORDS:
continue
normalized_token_text = token_text
# normalized_token_text = self.normalize_string(token_text)
if not normalized_token_text:
continue
token_columns, alignment = self._string_in_table(normalized_token_text, string_column_mapping)
if token_columns:
entity_data.append({'value': normalized_token_text,
'token_start': i,
'token_end': i+1,
'alignment': alignment,
'token_in_columns': token_columns})
return entity_data
@staticmethod
def normalize_string(string: str) -> str:
"""
These are the transformation rules used to normalize cell in column names in Sempre. See
``edu.stanford.nlp.sempre.tables.StringNormalizationUtils.characterNormalize`` and
``edu.stanford.nlp.sempre.tables.TableTypeSystem.canonicalizeName``. We reproduce those
rules here to normalize and canonicalize cells and columns in the same way so that we can
match them against constants in logical forms appropriately.
"""
# Normalization rules from Sempre
# \u201A -> ,
string = re.sub("", ",", string)
string = re.sub("", ",,", string)
string = re.sub("[·・]", "../sql", string)
string = re.sub("", "...", string)
string = re.sub("ˆ", "^", string)
string = re.sub("˜", "~", string)
string = re.sub("", "<", string)
string = re.sub("", ">", string)
string = re.sub("[´`]", "'", string)
string = re.sub("[“”«»]", "\"", string)
string = re.sub("[•†‡²³]", "", string)
string = re.sub("[‐‑–—−]", "-", string)
# Oddly, some unicode characters get converted to _ instead of being stripped. Not really
# sure how sempre decides what to do with these... TODO(mattg): can we just get rid of the
# need for this function somehow? It's causing a whole lot of headaches.
string = re.sub("[ðø′″€⁄ªΣ]", "_", string)
# This is such a mess. There isn't just a block of unicode that we can strip out, because
# sometimes sempre just strips diacritics... We'll try stripping out a few separate
# blocks, skipping the ones that sempre skips...
string = re.sub("[\\u0180-\\u0210]", "", string).strip()
string = re.sub("[\\u0220-\\uFFFF]", "", string).strip()
string = string.replace("\\n", "_")
string = re.sub("\\s+", " ", string)
# Canonicalization rules from Sempre.
string = re.sub("[^\\w]", "_", string)
string = re.sub("_+", "_", string)
string = re.sub("_$", "", string)
return unidecode(string.lower())
def _expand_entities(self, question, entity_data, string_column_mapping: Dict[str, set]):
new_entities = []
for entity in entity_data:
# to ensure the same strings are not used over and over
if new_entities and entity['token_end'] <= new_entities[-1]['token_end']:
continue
current_start = entity['token_start']
current_end = entity['token_end']
current_token = entity['value']
current_token_type = entity['token_type']
current_token_columns = entity['token_in_columns']
while current_end < len(question):
next_token = question[current_end].text
next_token_normalized = self.normalize_string(next_token)
if next_token_normalized == "":
current_end += 1
continue
candidate = "%s_%s" %(current_token, next_token_normalized)
candidate_columns = self._string_in_table(candidate, string_column_mapping)
candidate_columns = list(set(candidate_columns).intersection(current_token_columns))
if not candidate_columns:
break
candidate_type = candidate_columns[0].split(":")[1]
if candidate_type != current_token_type:
break
current_end += 1
current_token = candidate
current_token_columns = candidate_columns
new_entities.append({'token_start': current_start,
'token_end': current_end,
'value': current_token,
'token_type': current_token_type,
'token_in_columns': current_token_columns})
return new_entities
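# Illustrative usage (paths and inputs are assumptions, not part of this file):
#   context = SpiderDBContext("student_1", tokenized_question,
#                             "./data/spider/tables.json", "./data/spider/database",
#                             stanza_model=stanza_model)
#   value_match, value_alignment, exact_match, partial_match = \
#       context.get_db_knowledge_graph("student_1")
# exact_match / partial_match hold schema keys ("table" or "table@column") found
# in the question; value_match holds DB cell values mentioned in the question.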


@ -0,0 +1,141 @@
# pylint: disable=anomalous-backslash-in-string
"""
A ``Text2SqlTableContext`` represents the SQL context in which an utterance appears
for the any of the text2sql datasets, with the grammar and the valid actions.
"""
from typing import List, Dict
from dataset_readers.dataset_util.spider_utils import Table
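# Illustrative derivation (an assumption, not part of the original file): after
# update_grammar_with_tables adds "singer" to table_name and "singer@age" to
# column_name, the value-anonymized query
#     select count ( * ) from singer where singer@age > 'value'
# is covered as
#     statement -> query -> select_core
#     select_core -> select_with_distinct select_results from_clause where_clause
#     select_results -> select_result -> expr -> value -> function   ("count" "(" "*" ")")
#     from_clause  -> "from" source -> single_source -> table_source -> table_name
#     where_clause -> "where" expr -> value binaryop expr -> col_ref ">" string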
GRAMMAR_DICTIONARY = {}
GRAMMAR_DICTIONARY["statement"] = ['(query ws iue ws query)', '(query ws)']
GRAMMAR_DICTIONARY["iue"] = ['"intersect"', '"except"', '"union"']
GRAMMAR_DICTIONARY["query"] = ['(ws select_core ws groupby_clause ws orderby_clause ws limit)',
'(ws select_core ws groupby_clause ws orderby_clause)',
'(ws select_core ws groupby_clause ws limit)',
'(ws select_core ws orderby_clause ws limit)',
'(ws select_core ws groupby_clause)',
'(ws select_core ws orderby_clause)',
'(ws select_core)']
GRAMMAR_DICTIONARY["select_core"] = ['(select_with_distinct ws select_results ws from_clause ws where_clause)',
'(select_with_distinct ws select_results ws from_clause)',
'(select_with_distinct ws select_results ws where_clause)',
'(select_with_distinct ws select_results)']
GRAMMAR_DICTIONARY["select_with_distinct"] = ['(ws "select" ws "distinct")', '(ws "select")']
GRAMMAR_DICTIONARY["select_results"] = ['(ws select_result ws "," ws select_results)', '(ws select_result)']
GRAMMAR_DICTIONARY["select_result"] = ['"*"', '(table_source ws ".*")',
'expr', 'col_ref']
GRAMMAR_DICTIONARY["from_clause"] = ['(ws "from" ws table_source ws join_clauses)',
'(ws "from" ws source)']
GRAMMAR_DICTIONARY["join_clauses"] = ['(join_clause ws join_clauses)', 'join_clause']
GRAMMAR_DICTIONARY["join_clause"] = ['"join" ws table_source ws "on" ws join_condition_clause']
GRAMMAR_DICTIONARY["join_condition_clause"] = ['(join_condition ws "and" ws join_condition_clause)', 'join_condition']
GRAMMAR_DICTIONARY["join_condition"] = ['ws col_ref ws "=" ws col_ref']
GRAMMAR_DICTIONARY["source"] = ['(ws single_source ws "," ws source)', '(ws single_source)']
GRAMMAR_DICTIONARY["single_source"] = ['table_source', 'source_subq']
GRAMMAR_DICTIONARY["source_subq"] = ['("(" ws query ws ")")']
# GRAMMAR_DICTIONARY["source_subq"] = ['("(" ws query ws ")" ws "as" ws name)', '("(" ws query ws ")")']
GRAMMAR_DICTIONARY["limit"] = ['("limit" ws non_literal_number)']
GRAMMAR_DICTIONARY["where_clause"] = ['(ws "where" wsp expr ws where_conj)', '(ws "where" wsp expr)']
GRAMMAR_DICTIONARY["where_conj"] = ['(ws "and" wsp expr ws where_conj)', '(ws "and" wsp expr)']
GRAMMAR_DICTIONARY["groupby_clause"] = ['(ws "group" ws "by" ws group_clause ws "having" ws expr)',
'(ws "group" ws "by" ws group_clause)']
GRAMMAR_DICTIONARY["group_clause"] = ['(ws expr ws "," ws group_clause)', '(ws expr)']
GRAMMAR_DICTIONARY["orderby_clause"] = ['ws "order" ws "by" ws order_clause']
GRAMMAR_DICTIONARY["order_clause"] = ['(ordering_term ws "," ws order_clause)', 'ordering_term']
GRAMMAR_DICTIONARY["ordering_term"] = ['(ws expr ws ordering)', '(ws expr)']
GRAMMAR_DICTIONARY["ordering"] = ['(ws "asc")', '(ws "desc")']
GRAMMAR_DICTIONARY["col_ref"] = ['(table_name ws "." ws column_name)', 'column_name']
GRAMMAR_DICTIONARY["table_source"] = ['(table_name ws "as" ws table_alias)', 'table_name']
GRAMMAR_DICTIONARY["table_name"] = ["table_alias"]
GRAMMAR_DICTIONARY["table_alias"] = ['"t1"', '"t2"', '"t3"', '"t4"']
GRAMMAR_DICTIONARY["column_name"] = []
GRAMMAR_DICTIONARY["ws"] = ['~"\s*"i']
GRAMMAR_DICTIONARY['wsp'] = ['~"\s+"i']
GRAMMAR_DICTIONARY["expr"] = ['in_expr',
# Like expressions.
'(value wsp "like" wsp string)',
# Between expressions.
'(value ws "between" wsp value ws "and" wsp value)',
# Binary expressions.
'(value ws binaryop wsp expr)',
# Unary expressions.
'(unaryop ws expr)',
'source_subq',
'value']
GRAMMAR_DICTIONARY["in_expr"] = ['(value wsp "not" wsp "in" wsp string_set)',
'(value wsp "in" wsp string_set)',
'(value wsp "not" wsp "in" wsp expr)',
'(value wsp "in" wsp expr)']
GRAMMAR_DICTIONARY["value"] = ['parenval', '"YEAR(CURDATE())"', 'number', 'boolean',
'function', 'col_ref', 'string']
GRAMMAR_DICTIONARY["parenval"] = ['"(" ws expr ws ")"']
GRAMMAR_DICTIONARY["function"] = ['(fname ws "(" ws "distinct" ws arg_list_or_star ws ")")',
'(fname ws "(" ws arg_list_or_star ws ")")']
GRAMMAR_DICTIONARY["arg_list_or_star"] = ['arg_list', '"*"']
GRAMMAR_DICTIONARY["arg_list"] = ['(expr ws "," ws arg_list)', 'expr']
# TODO(MARK): Massive hack, remove and modify the grammar accordingly
# GRAMMAR_DICTIONARY["number"] = ['~"\d*\.?\d+"i', "'3'", "'4'"]
GRAMMAR_DICTIONARY["non_literal_number"] = ['"1"', '"2"', '"3"', '"4"']
GRAMMAR_DICTIONARY["number"] = ['ws "value" ws']
GRAMMAR_DICTIONARY["string_set"] = ['ws "(" ws string_set_vals ws ")"']
GRAMMAR_DICTIONARY["string_set_vals"] = ['(string ws "," ws string_set_vals)', 'string']
# GRAMMAR_DICTIONARY["string"] = ['~"\'.*?\'"i']
GRAMMAR_DICTIONARY["string"] = ['"\'" ws "value" ws "\'"']
GRAMMAR_DICTIONARY["fname"] = ['"count"', '"sum"', '"max"', '"min"', '"avg"', '"all"']
GRAMMAR_DICTIONARY["boolean"] = ['"true"', '"false"']
# TODO(MARK): This is not tight enough. AND/OR are strictly boolean value operators.
GRAMMAR_DICTIONARY["binaryop"] = ['"+"', '"-"', '"*"', '"/"', '"="', '"!="', '"<>"',
'">="', '"<="', '">"', '"<"', '"and"', '"or"', '"like"']
GRAMMAR_DICTIONARY["unaryop"] = ['"+"', '"-"', '"not"', '"not"']
def update_grammar_with_tables(grammar_dictionary: Dict[str, List[str]],
schema: Dict[str, Table]) -> None:
table_names = sorted([f'"{table.lower()}"' for table in
list(schema.keys())], reverse=True)
grammar_dictionary['table_name'] += table_names
all_columns = set()
for table in schema.values():
all_columns.update([f'"{table.name.lower()}@{column.name.lower()}"' for column in table.columns if column.name != '*'])
sorted_columns = sorted([column for column in all_columns], reverse=True)
grammar_dictionary['column_name'] += sorted_columns
def update_grammar_to_be_table_names_free(grammar_dictionary: Dict[str, List[str]]):
"""
Remove table names from column names, remove aliases
"""
grammar_dictionary["column_name"] = []
grammar_dictionary["table_name"] = []
grammar_dictionary["col_ref"] = ['column_name']
grammar_dictionary["table_source"] = ['table_name']
del grammar_dictionary["table_alias"]
def update_grammar_flip_joins(grammar_dictionary: Dict[str, List[str]]):
"""
Remove table names from column names, remove aliases
"""
# using a simple rule such as join_clauses-> [(join_clauses ws join_clause), join_clause]
# resulted in a max recursion error, so for now just using a predefined max
# number of joins
grammar_dictionary["join_clauses"] = ['(join_clauses_1 ws join_clause)', 'join_clause']
grammar_dictionary["join_clauses_1"] = ['(join_clauses_2 ws join_clause)', 'join_clause']
grammar_dictionary["join_clauses_2"] = ['(join_clause ws join_clause)', 'join_clause']


@ -0,0 +1,607 @@
################################
# Assumptions:
# 1. sql is correct
# 2. only table name has alias
# 3. only one intersect/union/except
#
# val: number(float)/string(str)/sql(dict)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, col_unit/sql)
# cond_unit: (not_op, op_id, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
# 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
# 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
# 'where': condition
# 'groupBy': [col_unit1, col_unit2, ...]
# 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
# 'having': condition
# 'limit': None/limit value
# 'intersect': None/sql
# 'except': None/sql
# 'union': None/sql
# }
################################
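# Illustrative, hand-traced example (the schema and query are assumptions): for
# a database whose schema contains a "singer" table,
#   get_sql(schema, "select count(*) from singer")
# produces
#   {'select': (False, [(3, (0, (0, '__all__', False), None))]),   # agg_id 3 == 'count'
#    'from': {'table_units': [('table_unit', '__singer__')], 'conds': []},
#    'where': [], 'groupBy': [], 'having': [], 'orderBy': [],
#    'limit': None, 'intersect': None, 'union': None, 'except': None}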
import json
import sqlite3
import nltk
nltk.download('punkt')
from nltk import word_tokenize
CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
JOIN_KEYWORDS = ('join', 'on', 'as')
WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
UNIT_OPS = ('none', '-', '+', "*", '/')
WHERE_COLUMN_OPS = ( '=', '>', '<', '>=', '<=', '!=')
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
TABLE_TYPE = {
'sql': "sql",
'table_unit': "table_unit",
}
COND_OPS = ('and', 'or')
SQL_OPS = ('intersect', 'union', 'except')
ORDER_OPS = ('desc', 'asc')
mapped_entities = []
class Schema:
"""
Simple schema which maps table&column to a unique identifier
"""
def __init__(self, schema):
self._schema = schema
self._idMap = self._map(self._schema)
@property
def schema(self):
return self._schema
@property
def idMap(self):
return self._idMap
def _map(self, schema):
idMap = {'*': "__all__"}
id = 1
for key, vals in schema.items():
for val in vals:
idMap[key.lower() + "." + val.lower()] = "__" + key.lower() + "." + val.lower() + "__"
id += 1
for key in schema:
idMap[key.lower()] = "__" + key.lower() + "__"
id += 1
return idMap
def get_schema(db):
"""
Get database's schema, which is a dict with table name as key
and list of column names as value
:param db: database path
:return: schema dict
"""
schema = {}
conn = sqlite3.connect(db)
cursor = conn.cursor()
# fetch table names
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
tables = [str(table[0].lower()) for table in cursor.fetchall()]
# fetch table info
for table in tables:
cursor.execute("PRAGMA table_info({})".format(table))
schema[table] = [str(col[1].lower()) for col in cursor.fetchall()]
return schema
def get_schema_from_json(fpath):
with open(fpath, encoding='utf-8') as f:
data = json.load(f)
schema = {}
for entry in data:
if 'chase' in fpath:
schema[entry['db_id']]={}
for table_id,table_name in enumerate(entry['table_names']):
cols = []
for col in entry['column_names']:
if col[0] == table_id:
cols.append(str(col[-1].lower().replace(' ','_')))
schema[entry['db_id']][table_name] = cols
else:
table = str(entry['table'].lower())
cols = [str(col['column_name'].lower()) for col in entry['col_data']]
schema[table] = cols
return schema
def tokenize(string):
string = str(string)
string = string.replace("\'", "\"") # ensures all string values wrapped by "" problem??
quote_idxs = [idx for idx, char in enumerate(string) if char == '"']
assert len(quote_idxs) % 2 == 0, "Unexpected quote"
# keep string value as token
vals = {}
for i in range(len(quote_idxs)-1, -1, -2):
qidx1 = quote_idxs[i-1]
qidx2 = quote_idxs[i]
val = string[qidx1: qidx2+1]
key = "__val_{}_{}__".format(qidx1, qidx2)
string = string[:qidx1] + key + string[qidx2+1:]
vals[key] = val
toks = [word.lower() for word in word_tokenize(string)]
# replace with string value token
for i in range(len(toks)):
if toks[i] in vals:
toks[i] = vals[toks[i]]
# find if there exists !=, >=, <=
eq_idxs = [idx for idx, tok in enumerate(toks) if tok == "="]
eq_idxs.reverse()
prefix = ('!', '>', '<')
for eq_idx in eq_idxs:
pre_tok = toks[eq_idx-1]
if pre_tok in prefix:
toks = toks[:eq_idx-1] + [pre_tok + "="] + toks[eq_idx+1: ]
return toks
def scan_alias(toks):
"""Scan the index of 'as' and build the map for all alias"""
as_idxs = [idx for idx, tok in enumerate(toks) if tok == 'as']
alias = {}
for idx in as_idxs:
alias[toks[idx+1]] = toks[idx-1]
return alias
def get_tables_with_alias(schema, toks):
tables = scan_alias(toks)
for key in schema:
assert key not in tables, "Alias {} has the same name in table".format(key)
tables[key] = key
return tables
def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None):
"""
:returns next idx, column id
"""
global mapped_entities
tok = toks[start_idx]
if tok == "*":
return start_idx + 1, schema.idMap[tok]
if '.' in tok: # if token is a composite
alias, col = tok.split('.')
key = tables_with_alias[alias] + "." + col
mapped_entities.append((start_idx, tables_with_alias[alias] + "@" + col))
return start_idx+1, schema.idMap[key]
assert default_tables is not None and len(default_tables) > 0, "Default tables should not be None or empty"
for alias in default_tables:
table = tables_with_alias[alias]
if tok in schema.schema[table]:
key = table + "." + tok
mapped_entities.append((start_idx, table + "@" + tok))
return start_idx+1, schema.idMap[key]
assert False, "Error col: {}".format(tok)
def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
"""
:returns next idx, (agg_op id, col_id)
"""
idx = start_idx
len_ = len(toks)
isBlock = False
isDistinct = False
if toks[idx] == '(':
isBlock = True
idx += 1
if toks[idx] in AGG_OPS:
agg_id = AGG_OPS.index(toks[idx])
idx += 1
assert idx < len_ and toks[idx] == '('
idx += 1
if toks[idx] == "distinct":
idx += 1
isDistinct = True
idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables)
assert idx < len_ and toks[idx] == ')'
idx += 1
return idx, (agg_id, col_id, isDistinct)
if toks[idx] == "distinct":
idx += 1
isDistinct = True
agg_id = AGG_OPS.index("none")
idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables)
if isBlock:
assert toks[idx] == ')'
idx += 1 # skip ')'
return idx, (agg_id, col_id, isDistinct)
def parse_val_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
idx = start_idx
len_ = len(toks)
isBlock = False
if toks[idx] == '(':
isBlock = True
idx += 1
col_unit1 = None
col_unit2 = None
unit_op = UNIT_OPS.index('none')
idx, col_unit1 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
if idx < len_ and toks[idx] in UNIT_OPS:
unit_op = UNIT_OPS.index(toks[idx])
idx += 1
idx, col_unit2 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
if isBlock:
assert toks[idx] == ')'
idx += 1 # skip ')'
return idx, (unit_op, col_unit1, col_unit2)
def parse_table_unit(toks, start_idx, tables_with_alias, schema):
"""
:returns next idx, table id, table name
"""
idx = start_idx
len_ = len(toks)
key = tables_with_alias[toks[idx]]
if idx + 1 < len_ and toks[idx+1] == "as":
idx += 3
else:
idx += 1
return idx, schema.idMap[key], key
def parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None):
idx = start_idx
len_ = len(toks)
isBlock = False
if toks[idx] == '(':
isBlock = True
idx += 1
if toks[idx] == 'select':
idx, val = parse_sql(toks, idx, tables_with_alias, schema)
elif "\"" in toks[idx] or toks[idx]=='value': # token is a string value
val = toks[idx]
idx += 1
elif "'" == toks[idx]:
idx+=1
val=[]
while toks[idx]!="'":
val.append(toks[idx])
idx+=1
idx+=1
val='"'+''.join(val)+'"'
else:
try:
val = float(toks[idx])
idx += 1
except:
end_idx = idx
while end_idx < len_ and toks[end_idx] != ',' and toks[end_idx] != ')'\
and toks[end_idx] != 'and' and toks[end_idx] not in CLAUSE_KEYWORDS and toks[end_idx] not in JOIN_KEYWORDS:
end_idx += 1
idx, val = parse_col_unit(toks[: end_idx], start_idx, tables_with_alias, schema, default_tables)
idx = end_idx
if isBlock:
assert toks[idx] == ')'
idx += 1
return idx, val
def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=None):
idx = start_idx
len_ = len(toks)
conds = []
while idx < len_:
idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
not_op = False
if toks[idx] == 'not':
not_op = True
idx += 1
assert idx < len_ and toks[idx] in WHERE_OPS, "Error condition: idx: {}, tok: {}".format(idx, toks[idx])
op_id = WHERE_OPS.index(toks[idx])
idx += 1
# if idx<len_-2 and \
# '.' in toks[idx] and toks[idx+1] in UNIT_OPS and '.' in toks[idx+2]:
# idx, val_unit2 = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
# conds.append((not_op, op_id, val_unit, val_unit2))
# elif '.' in toks[idx]:
# idx, val_unit2 = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
# conds.append((not_op, op_id, val_unit, val_unit2))
# else:
val1 = val2 = None
if op_id == WHERE_OPS.index('between'): # between..and... special case: dual values
idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
assert toks[idx] == 'and'
idx += 1
idx, val2 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
else: # normal case: single value
idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
val2 = None
conds.append((not_op, op_id, val_unit, val1, val2))
if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";") or toks[idx] in JOIN_KEYWORDS):
break
if idx < len_ and toks[idx] in COND_OPS:
conds.append(toks[idx])
idx += 1 # skip and/or
return idx, conds
def parse_select(toks, start_idx, tables_with_alias, schema, default_tables=None):
idx = start_idx
len_ = len(toks)
assert toks[idx] == 'select', "'select' not found"
idx += 1
isDistinct = False
if idx < len_ and toks[idx] == 'distinct':
idx += 1
isDistinct = True
val_units = []
while idx < len_ and toks[idx] not in CLAUSE_KEYWORDS:
agg_id = AGG_OPS.index("none")
if toks[idx] in AGG_OPS:
agg_id = AGG_OPS.index(toks[idx])
idx += 1
idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
val_units.append((agg_id, val_unit))
if idx < len_ and toks[idx] == ',':
idx += 1 # skip ','
return idx, (isDistinct, val_units)
def parse_from(toks, start_idx, tables_with_alias, schema):
"""
Assume in the from clause, all table units are combined with join
"""
assert 'from' in toks[start_idx:], "'from' not found"
len_ = len(toks)
idx = toks.index('from', start_idx) + 1
default_tables = []
table_units = []
conds = []
while idx < len_:
isBlock = False
if toks[idx] == '(':
isBlock = True
idx += 1
if toks[idx] == 'select':
idx, sql = parse_sql(toks, idx, tables_with_alias, schema)
table_units.append((TABLE_TYPE['sql'], sql))
else:
if idx < len_ and toks[idx] == 'join':
idx += 1 # skip join
idx, table_unit, table_name = parse_table_unit(toks, idx, tables_with_alias, schema)
table_units.append((TABLE_TYPE['table_unit'],table_unit))
default_tables.append(table_name)
if idx < len_ and toks[idx] == "on":
idx += 1 # skip on
idx, this_conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
if len(conds) > 0:
conds.append('and')
conds.extend(this_conds)
if isBlock:
assert toks[idx] == ')'
idx += 1
if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
break
return idx, table_units, conds, default_tables
def parse_where(toks, start_idx, tables_with_alias, schema, default_tables):
idx = start_idx
len_ = len(toks)
if idx >= len_ or toks[idx] != 'where':
return idx, []
idx += 1
idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
return idx, conds
def parse_group_by(toks, start_idx, tables_with_alias, schema, default_tables):
idx = start_idx
len_ = len(toks)
col_units = []
if idx >= len_ or toks[idx] != 'group':
return idx, col_units
idx += 1
assert toks[idx] == 'by'
idx += 1
while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
idx, col_unit = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
col_units.append(col_unit)
if idx < len_ and toks[idx] == ',':
idx += 1 # skip ','
else:
break
return idx, col_units
def parse_order_by(toks, start_idx, tables_with_alias, schema, default_tables):
idx = start_idx
len_ = len(toks)
val_units = []
order_type = 'asc' # default type is 'asc'
if idx >= len_ or toks[idx] != 'order':
return idx, val_units
idx += 1
assert toks[idx] == 'by'
idx += 1
while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
val_units.append(val_unit)
if idx < len_ and toks[idx] in ORDER_OPS:
order_type = toks[idx]
idx += 1
if idx < len_ and toks[idx] == ',':
idx += 1 # skip ','
else:
break
return idx, (order_type, val_units)
def parse_having(toks, start_idx, tables_with_alias, schema, default_tables):
idx = start_idx
len_ = len(toks)
if idx >= len_ or toks[idx] != 'having':
return idx, []
idx += 1
idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
return idx, conds
def parse_limit(toks, start_idx):
idx = start_idx
len_ = len(toks)
if idx < len_ and toks[idx] == 'limit':
idx += 2
try:
limit_val = int(toks[idx-1])
except Exception:
limit_val = '"value"'
return idx, limit_val
return idx, None
def parse_sql(toks, start_idx, tables_with_alias, schema, mapped_entities_fn=None):
global mapped_entities
if mapped_entities_fn is not None:
mapped_entities = mapped_entities_fn()
isBlock = False # indicate whether this is a block of sql/sub-sql
len_ = len(toks)
idx = start_idx
sql = {}
if toks[idx] == '(':
isBlock = True
idx += 1
# parse from clause in order to get default tables
from_end_idx, table_units, conds, default_tables = parse_from(toks, start_idx, tables_with_alias, schema)
sql['from'] = {'table_units': table_units, 'conds': conds}
# select clause
_, select_col_units = parse_select(toks, idx, tables_with_alias, schema, default_tables)
idx = from_end_idx
sql['select'] = select_col_units
# where clause
idx, where_conds = parse_where(toks, idx, tables_with_alias, schema, default_tables)
sql['where'] = where_conds
# group by clause
idx, group_col_units = parse_group_by(toks, idx, tables_with_alias, schema, default_tables)
sql['groupBy'] = group_col_units
# having clause
idx, having_conds = parse_having(toks, idx, tables_with_alias, schema, default_tables)
sql['having'] = having_conds
# order by clause
idx, order_col_units = parse_order_by(toks, idx, tables_with_alias, schema, default_tables)
sql['orderBy'] = order_col_units
# limit clause
idx, limit_val = parse_limit(toks, idx)
sql['limit'] = limit_val
idx = skip_semicolon(toks, idx)
if isBlock:
assert toks[idx] == ')'
idx += 1 # skip ')'
idx = skip_semicolon(toks, idx)
# intersect/union/except clause
for op in SQL_OPS: # initialize IUE
sql[op] = None
if idx < len_ and toks[idx] in SQL_OPS:
sql_op = toks[idx]
idx += 1
idx, IUE_sql = parse_sql(toks, idx, tables_with_alias, schema)
sql[sql_op] = IUE_sql
if mapped_entities_fn is not None:
return idx, sql, mapped_entities
else:
return idx, sql
def load_data(fpath):
with open(fpath) as f:
data = json.load(f)
return data
def get_sql(schema, query):
toks = tokenize(query)
tables_with_alias = get_tables_with_alias(schema.schema, toks)
_, sql = parse_sql(toks, 0, tables_with_alias, schema)
return sql
def skip_semicolon(toks, start_idx):
idx = start_idx
while idx < len(toks) and toks[idx] == ";":
idx += 1
return idx

View file

@ -0,0 +1,168 @@
# MIT License
#
# Copyright (c) 2019 seq2struct contributors and Microsoft Corporation
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import os
import dataclasses
import json
from typing import Optional, Tuple, List, Iterable
import networkx as nx
from pydantic.dataclasses import dataclass
from pydantic.main import BaseConfig
from third_party.spider import evaluation
from third_party.spider.preprocess.schema import get_schemas_from_json, Schema
@dataclass
class SpiderTable:
id: int
name: List[str]
unsplit_name: str
orig_name: str
columns: List["SpiderColumn"] = dataclasses.field(default_factory=list)
primary_keys: List[str] = dataclasses.field(default_factory=list)
@dataclass
class SpiderColumn:
id: int
table: Optional[SpiderTable]
name: List[str]
unsplit_name: str
orig_name: str
type: str
foreign_key_for: Optional[str] = None
SpiderTable.__pydantic_model__.update_forward_refs()
class SpiderSchemaConfig:
arbitrary_types_allowed = True
@dataclass(config=SpiderSchemaConfig)
class SpiderSchema(BaseConfig):
db_id: str
tables: Tuple[SpiderTable, ...]
columns: Tuple[SpiderColumn, ...]
foreign_key_graph: nx.DiGraph
orig: dict
@dataclass
class SpiderItem:
question: str
slml_question: Optional[str]
query: str
spider_sql: dict
spider_schema: SpiderSchema
db_path: str
orig: dict
def schema_dict_to_spider_schema(schema_dict):
tables = tuple(
SpiderTable(id=i, name=name.split(), unsplit_name=name, orig_name=orig_name,)
for i, (name, orig_name) in enumerate(
zip(schema_dict["table_names"], schema_dict["table_names_original"])
)
)
columns = tuple(
SpiderColumn(
id=i,
table=tables[table_id] if table_id >= 0 else None,
name=col_name.split(),
unsplit_name=col_name,
orig_name=orig_col_name,
type=col_type,
)
for i, ((table_id, col_name), (_, orig_col_name), col_type,) in enumerate(
zip(
schema_dict["column_names"],
schema_dict["column_names_original"],
schema_dict["column_types"],
)
)
)
# Link columns to tables
for column in columns:
if column.table:
column.table.columns.append(column)
for column_id in schema_dict["primary_keys"]:
# Register primary keys
column = columns[column_id]
column.table.primary_keys.append(column)
foreign_key_graph = nx.DiGraph()
for source_column_id, dest_column_id in schema_dict["foreign_keys"]:
# Register foreign keys
source_column = columns[source_column_id]
dest_column = columns[dest_column_id]
source_column.foreign_key_for = dest_column
foreign_key_graph.add_edge(
source_column.table.id,
dest_column.table.id,
columns=(source_column_id, dest_column_id),
)
foreign_key_graph.add_edge(
dest_column.table.id,
source_column.table.id,
columns=(dest_column_id, source_column_id),
)
db_id = schema_dict["db_id"]
return SpiderSchema(db_id, tables, columns, foreign_key_graph, schema_dict)
def load_tables(paths):
schemas = {}
eval_foreign_key_maps = {}
with open(paths, 'r',encoding='UTF-8') as f:
schema_dicts = json.load(f)
for schema_dict in schema_dicts:
db_id = schema_dict["db_id"]
if 'column_names_original' not in schema_dict: # {'table': [col.lower, ..., ]} * -> __all__
# continue
schema_dict["column_names_original"] = schema_dict["column_names"]
schema_dict["table_names_original"] = schema_dict["table_names"]
# assert db_id not in schemas
schemas[db_id] = schema_dict_to_spider_schema(schema_dict)
eval_foreign_key_maps[db_id] = evaluation.build_foreign_key_map(schema_dict)
return schemas, eval_foreign_key_maps
def load_original_schemas(tables_paths):
all_schemas = {}
schemas, db_ids, tables = get_schemas_from_json(tables_paths)
for db_id in db_ids:
all_schemas[db_id] = Schema(schemas[db_id], tables[db_id])
return all_schemas

View file

@ -0,0 +1,343 @@
"""
Utility functions for reading the standardised text2sql datasets presented in
`"Improving Text to SQL Evaluation Methodology" <https://arxiv.org/abs/1806.09029>`_
"""
import json
import os
import sqlite3
from collections import defaultdict
from typing import List, Dict, Optional, Any
from semparse.sql.process_sql import get_tables_with_alias, parse_sql
class TableColumn:
def __init__(self,
name: str,
text: str,
column_type: str,
is_primary_key: bool,
foreign_key: Optional[str],
lemma: Optional[str]):
self.name = name
self.text = text
self.column_type = column_type
self.is_primary_key = is_primary_key
self.foreign_key = foreign_key
self.lemma = lemma
class Table:
def __init__(self,
name: str,
text: str,
columns: List[TableColumn],
lemma: Optional[str]):
self.name = name
self.text = text
self.columns = columns
self.lemma = lemma
def read_dataset_schema(schema_path: str, stanza_model=None) -> Dict[str, List[Table]]:
schemas: Dict[str, Dict[str, Table]] = defaultdict(dict)
dbs_json_blob = json.load(open(schema_path, "r", encoding='utf-8'))
for db in dbs_json_blob:
db_id = db['db_id']
column_id_to_table = {}
column_id_to_column = {}
concate_columns = [c[-1] for c in db['column_names']]
concate_tables = [c for c in db['table_names']]
#load stanza model
if stanza_model is not None:
lemma_columns = stanza_model('\n\n'.join(concate_columns).replace(' ','none'))
lemma_columns_collect = []
for sent in lemma_columns.sentences:
tmp = []
for word in sent.words:
if word.lemma != None:
tmp.append(word.lemma)
elif word.text==' ':
tmp.append('none')
else:
tmp.append(word.text)
lemma_columns_collect.append(' '.join(tmp))
lemma_tables = stanza_model('\n\n'.join(concate_tables).replace(' ','none'))
lemma_tables_collect = {}
for t,sent in zip(concate_tables, lemma_tables.sentences):
tmp = []
for word in sent.words:
if word.lemma != None:
tmp.append(word.lemma)
elif word.text == ' ':
tmp.append('none')
else:
tmp.append(word.text)
lemma_tables_collect[t]=' '.join(tmp)
else:
lemma_columns_collect = concate_columns
lemma_tables_collect = {t:t for t in concate_tables}
for i, (column, text, column_type) in enumerate(zip(db['column_names_original'], db['column_names'], db['column_types'])):
table_id, column_name = column
_, column_text = text
table_name = db['table_names_original'][table_id]
if table_name not in schemas[db_id]:
table_text = db['table_names'][table_id]
table_lemma = lemma_tables_collect[table_text]
schemas[db_id][table_name] = Table(table_name, table_text, [], table_lemma)
if column_name == "*":
continue
is_primary_key = i in db['primary_keys']
table_column = TableColumn(column_name.lower(), column_text, column_type, is_primary_key, None, lemma_columns_collect[i])
schemas[db_id][table_name].columns.append(table_column)
column_id_to_table[i] = table_name
column_id_to_column[i] = table_column
for (c1, c2) in db['foreign_keys']:
foreign_key = column_id_to_table[c2] + ':' + column_id_to_column[c2].name
column_id_to_column[c1].foreign_key = foreign_key
return {**schemas}
def read_dataset_values(db_id: str, dataset_path: str, tables: List[str]):
db = os.path.join(dataset_path, db_id, db_id + ".sqlite")
try:
conn = sqlite3.connect(db)
except Exception as e:
raise Exception(f"Can't connect to SQL: {e} in path {db}")
conn.text_factory = str
cursor = conn.cursor()
values = {}
for table in tables:
try:
cursor.execute(f"SELECT * FROM {table.name} LIMIT 5000")
values[table] = cursor.fetchall()
except:
conn.text_factory = lambda x: str(x, 'latin1')
cursor = conn.cursor()
cursor.execute(f"SELECT * FROM {table.name} LIMIT 5000")
values[table] = cursor.fetchall()
return values
def ent_key_to_name(key):
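    # Key format assumed from the unpacking below (illustrative):
    #   'table:<table>'                   -> '<table>'
    #   'column:<type>:<table>:<column>'  -> '<table>@<column>'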
parts = key.split(':')
if parts[0] == 'table':
return parts[1]
elif parts[0] == 'column':
_, _, table_name, column_name = parts
return f'{table_name}@{column_name}'
else:
return parts[1]
def fix_number_value(ex):
"""
There is something weird in the dataset files - the `query_toks_no_value` field anonymizes all values,
which is good since the evaluator doesn't check for the values. But it also anonymizes numbers that
should not be anonymized: e.g. LIMIT 3 becomes LIMIT 'value', while the evaluator fails if it is not a number.
"""
def split_and_keep(s, sep):
if not s: return [''] # consistent with string.split()
# Find replacement character that is not used in string
# i.e. just use the highest available character plus one
# Note: This fails if ord(max(s)) = 0x10FFFF (ValueError)
p = chr(ord(max(s)) + 1)
return s.replace(sep, p + sep + p).split(p)
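    # e.g. split_and_keep('t1.name', '.') -> ['t1', '.', 'name']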
# input is tokenized in different ways... so first try to make splits equal
query_toks = ex['query_toks']
ex['query_toks'] = []
for q in query_toks:
ex['query_toks'] += split_and_keep(q, '.')
i_val, i_no_val = 0, 0
while i_val < len(ex['query_toks']) and i_no_val < len(ex['query_toks_no_value']):
if ex['query_toks_no_value'][i_no_val] != 'value':
i_val += 1
i_no_val += 1
continue
i_val_end = i_val
while i_val + 1 < len(ex['query_toks']) and \
i_no_val + 1 < len(ex['query_toks_no_value']) and \
ex['query_toks'][i_val_end + 1].lower() != ex['query_toks_no_value'][i_no_val + 1].lower():
i_val_end += 1
if i_val == i_val_end and ex['query_toks'][i_val] in ["1", "2", "3", "4", "5"] and ex['query_toks'][i_val - 1].lower() == "limit":
ex['query_toks_no_value'][i_no_val] = ex['query_toks'][i_val]
i_val = i_val_end
i_val += 1
i_no_val += 1
return ex
_schemas_cache = None
def disambiguate_items(db_id: str, query_toks: List[str], tables_file: str, allow_aliases: bool) -> List[str]:
"""
we want the query tokens to be non-ambiguous - so we can change each column name to explicitly
tell which table it belongs to
parsed sql to sql clause is based on supermodel.gensql from syntaxsql
"""
class Schema:
"""
Simple schema which maps table&column to a unique identifier
"""
def __init__(self, schema, table):
self._schema = schema
self._table = table
self._idMap = self._map(self._schema, self._table)
@property
def schema(self):
return self._schema
@property
def idMap(self):
return self._idMap
def _map(self, schema, table):
column_names_original = table['column_names_original']
table_names_original = table['table_names_original']
# print 'column_names_original: ', column_names_original
# print 'table_names_original: ', table_names_original
            idMap = {}  # initialize up front so a missing leading '*' column cannot raise NameError
            for i, (tab_id, col) in enumerate(column_names_original):
                if tab_id == -1:
                    idMap['*'] = i
                else:
                    key = table_names_original[tab_id].lower()
                    val = col.lower().replace(' ', '_')
                    idMap[key + "." + val] = i
for i, tab in enumerate(table_names_original):
key = tab.lower()
idMap[key] = i
return idMap
def get_schemas_from_json(fpath):
global _schemas_cache
if _schemas_cache is not None:
return _schemas_cache
with open(fpath, encoding='utf-8') as f:
data = json.load(f)
db_names = [db['db_id'] for db in data]
tables = {}
schemas = {}
for db in data:
db_id = db['db_id']
schema = {} # {'table': [col.lower, ..., ]} * -> __all__
column_names_original = db['column_names_original'] if 'column_names_original' in db else db['column_names']
table_names_original = db['table_names_original'] if 'table_names_original' in db else db['table_names']
tables[db_id] = {'column_names_original': column_names_original,
'table_names_original': table_names_original}
for i, tabn in enumerate(table_names_original):
table = str(tabn.lower())
cols = [str(col.lower().replace(' ','_')) for td, col in column_names_original if td == i]
schema[table] = cols
schemas[db_id] = schema
_schemas_cache = schemas, db_names, tables
return _schemas_cache
schemas, db_names, tables = get_schemas_from_json(tables_file)
schema = Schema(schemas[db_id], tables[db_id])
fixed_toks = []
i = 0
while i < len(query_toks):
tok = query_toks[i]
if tok == 'value' or tok == "'value'":
            # TODO: value should always be between '/" (remove first if clause)
new_tok = f'"{tok}"'
elif tok in ['!','<','>'] and query_toks[i+1] == '=':
new_tok = tok + '='
i += 1
# elif i+1 < len(query_toks) and query_toks[i+1] == '.' and query_toks[i] in schema.schema.keys():
elif i + 1 < len(query_toks) and query_toks[i + 1] == '.':
new_tok = ''.join(query_toks[i:i+3])
i += 2
else:
new_tok = tok
fixed_toks.append(new_tok)
i += 1
toks = fixed_toks
tables_with_alias = get_tables_with_alias(schema.schema, toks)
_, sql, mapped_entities = parse_sql(toks, 0, tables_with_alias, schema, mapped_entities_fn=lambda: [])
for i, new_name in mapped_entities:
curr_tok = toks[i]
if '.' in curr_tok and allow_aliases:
parts = curr_tok.split('.')
assert(len(parts) == 2)
toks[i] = parts[0] + '.' + new_name
else:
toks[i] = new_name
if not allow_aliases:
toks = [tok for tok in toks if tok not in ['as', 't1', 't2', 't3', 't4', 't5', 't6', 't7', 't8', 't9', 't10']]
toks = [f'\'value\'' if tok == '"value"' else tok for tok in toks]
return toks
def remove_on(query):
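    # Drops every "on <col> = <col>" fragment (4 tokens), e.g. (illustrative)
    #   'select * from a join b on a.id = b.id where x = 1'
    #   -> 'select * from a join b where x = 1'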
query_tok = query.split()
sql_words = []
t = 0
while t < len(query_tok):
if query_tok[t] != 'on':
sql_words.append(query_tok[t])
t += 1
else:
t += 4
return ' '.join(sql_words)
def read_dataset_values_from_json(db_id: str, db_content_dict: Dict[str, Any], tables: List[str]):
values = {}
item = db_content_dict[db_id]
for table in tables:
values[table] = item['tables'][table.name]['cell']
return values
def extract_tree_style(sent):
    """
    sent: List
    Placeholder: tree-style extraction is not implemented yet; returns an empty list.
    """
    rnt = []
    return rnt
if __name__ == '__main__':
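    # Manual smoke test for the stanza lemmatization pipeline.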
import stanza
stanza_model = stanza.Pipeline('en')
doc = stanza_model("what is the name of the breed with the most dogs ?")
word=[word.lemma for sent in doc.sentences for word in sent.words]
rnt = []

View file

@ -0,0 +1,881 @@
################################
# val: number(float)/string(str)/sql(dict)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, col_unit/sql)
# cond_unit: (not_op, op_id, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
# 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
# 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
# 'where': condition
# 'groupBy': [col_unit1, col_unit2, ...]
# 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
# 'having': condition
# 'limit': None/limit value
# 'intersect': None/sql
# 'except': None/sql
# 'union': None/sql
# }
################################
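# Illustrative example (hypothetical table `singer`; column/table ids follow the
# Schema.idMap convention of process_sql):
#   "SELECT count(*) FROM singer"  roughly parses to
#   {
#     'select': (False, [(3, (0, (0, '__all__', False), None))]),  # 3 = AGG_OPS.index('count')
#     'from': {'table_units': [('table_unit', '__singer__')], 'conds': []},
#     'where': [], 'groupBy': [], 'having': [], 'orderBy': [], 'limit': None,
#     'intersect': None, 'union': None, 'except': None
#   }
################################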
import os
import json
import sqlite3
import argparse
from semparse.sql.process_sql import get_schema, Schema, get_sql
# Flag to disable value evaluation
DISABLE_VALUE = True
# Flag to disable distinct in select evaluation
DISABLE_DISTINCT = True
CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
JOIN_KEYWORDS = ('join', 'on', 'as')
WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
UNIT_OPS = ('none', '-', '+', "*", '/')
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
TABLE_TYPE = {
'sql': "sql",
'table_unit': "table_unit",
}
COND_OPS = ('and', 'or')
SQL_OPS = ('intersect', 'union', 'except')
ORDER_OPS = ('desc', 'asc')
HARDNESS = {
"component1": ('where', 'group', 'order', 'limit', 'join', 'or', 'like'),
"component2": ('except', 'union', 'intersect')
}
def condition_has_or(conds):
return 'or' in conds[1::2]
def condition_has_like(conds):
return WHERE_OPS.index('like') in [cond_unit[1] for cond_unit in conds[::2]]
def condition_has_sql(conds):
for cond_unit in conds[::2]:
val1, val2 = cond_unit[3], cond_unit[4]
if val1 is not None and type(val1) is dict:
return True
if val2 is not None and type(val2) is dict:
return True
return False
def val_has_op(val_unit):
return val_unit[0] != UNIT_OPS.index('none')
def has_agg(unit):
return unit[0] != AGG_OPS.index('none')
def accuracy(count, total):
if count == total:
return 1
return 0
def recall(count, total):
if count == total:
return 1
return 0
def F1(acc, rec):
if (acc + rec) == 0:
return 0
return (2. * acc * rec) / (acc + rec)
def get_scores(count, pred_total, label_total):
if pred_total != label_total:
return 0,0,0
elif count == pred_total:
return 1,1,1
return 0,0,0
def eval_sel(pred, label):
pred_sel = pred['select'][1]
label_sel = label['select'][1]
label_wo_agg = [unit[1] for unit in label_sel]
pred_total = len(pred_sel)
label_total = len(label_sel)
cnt = 0
cnt_wo_agg = 0
for unit in pred_sel:
if unit in label_sel:
cnt += 1
label_sel.remove(unit)
if unit[1] in label_wo_agg:
cnt_wo_agg += 1
label_wo_agg.remove(unit[1])
return label_total, pred_total, cnt, cnt_wo_agg
def eval_where(pred, label):
pred_conds = [unit for unit in pred['where'][::2]]
label_conds = [unit for unit in label['where'][::2]]
label_wo_agg = [unit[2] for unit in label_conds]
pred_total = len(pred_conds)
label_total = len(label_conds)
cnt = 0
cnt_wo_agg = 0
for unit in pred_conds:
if unit in label_conds:
cnt += 1
label_conds.remove(unit)
if unit[2] in label_wo_agg:
cnt_wo_agg += 1
label_wo_agg.remove(unit[2])
return label_total, pred_total, cnt, cnt_wo_agg
def eval_group(pred, label):
pred_cols = [unit[1] for unit in pred['groupBy']]
label_cols = [unit[1] for unit in label['groupBy']]
pred_total = len(pred_cols)
label_total = len(label_cols)
cnt = 0
pred_cols = [pred.split(".")[1] if "." in pred else pred for pred in pred_cols]
label_cols = [label.split(".")[1] if "." in label else label for label in label_cols]
for col in pred_cols:
if col in label_cols:
cnt += 1
label_cols.remove(col)
return label_total, pred_total, cnt
def eval_having(pred, label):
pred_total = label_total = cnt = 0
if len(pred['groupBy']) > 0:
pred_total = 1
if len(label['groupBy']) > 0:
label_total = 1
pred_cols = [unit[1] for unit in pred['groupBy']]
label_cols = [unit[1] for unit in label['groupBy']]
if pred_total == label_total == 1 \
and pred_cols == label_cols \
and pred['having'] == label['having']:
cnt = 1
return label_total, pred_total, cnt
def eval_order(pred, label):
pred_total = label_total = cnt = 0
if len(pred['orderBy']) > 0:
pred_total = 1
if len(label['orderBy']) > 0:
label_total = 1
if len(label['orderBy']) > 0 and pred['orderBy'] == label['orderBy'] and \
((pred['limit'] is None and label['limit'] is None) or (pred['limit'] is not None and label['limit'] is not None)):
cnt = 1
return label_total, pred_total, cnt
def eval_and_or(pred, label):
pred_ao = pred['where'][1::2]
label_ao = label['where'][1::2]
pred_ao = set(pred_ao)
label_ao = set(label_ao)
if pred_ao == label_ao:
return 1,1,1
return len(pred_ao),len(label_ao),0
def get_nestedSQL(sql):
nested = []
for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]:
if type(cond_unit[3]) is dict:
nested.append(cond_unit[3])
if type(cond_unit[4]) is dict:
nested.append(cond_unit[4])
if sql['intersect'] is not None:
nested.append(sql['intersect'])
if sql['except'] is not None:
nested.append(sql['except'])
if sql['union'] is not None:
nested.append(sql['union'])
return nested
def eval_nested(pred, label):
label_total = 0
pred_total = 0
cnt = 0
if pred is not None:
pred_total += 1
if label is not None:
label_total += 1
if pred is not None and label is not None:
cnt += Evaluator().eval_exact_match(pred, label)
return label_total, pred_total, cnt
def eval_IUEN(pred, label):
lt1, pt1, cnt1 = eval_nested(pred['intersect'], label['intersect'])
lt2, pt2, cnt2 = eval_nested(pred['except'], label['except'])
lt3, pt3, cnt3 = eval_nested(pred['union'], label['union'])
label_total = lt1 + lt2 + lt3
pred_total = pt1 + pt2 + pt3
cnt = cnt1 + cnt2 + cnt3
return label_total, pred_total, cnt
def get_keywords(sql):
res = set()
if len(sql['where']) > 0:
res.add('where')
if len(sql['groupBy']) > 0:
res.add('group')
if len(sql['having']) > 0:
res.add('having')
if len(sql['orderBy']) > 0:
res.add(sql['orderBy'][0])
res.add('order')
if sql['limit'] is not None:
res.add('limit')
if sql['except'] is not None:
res.add('except')
if sql['union'] is not None:
res.add('union')
if sql['intersect'] is not None:
res.add('intersect')
# or keyword
ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
if len([token for token in ao if token == 'or']) > 0:
res.add('or')
cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
# not keyword
if len([cond_unit for cond_unit in cond_units if cond_unit[0]]) > 0:
res.add('not')
# in keyword
if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('in')]) > 0:
res.add('in')
# like keyword
if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) > 0:
res.add('like')
return res
def eval_keywords(pred, label):
pred_keywords = get_keywords(pred)
label_keywords = get_keywords(label)
pred_total = len(pred_keywords)
label_total = len(label_keywords)
cnt = 0
for k in pred_keywords:
if k in label_keywords:
cnt += 1
return label_total, pred_total, cnt
def count_agg(units):
return len([unit for unit in units if has_agg(unit)])
def count_component1(sql):
count = 0
if len(sql['where']) > 0:
count += 1
if len(sql['groupBy']) > 0:
count += 1
if len(sql['orderBy']) > 0:
count += 1
if sql['limit'] is not None:
count += 1
if len(sql['from']['table_units']) > 0: # JOIN
count += len(sql['from']['table_units']) - 1
ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
count += len([token for token in ao if token == 'or'])
cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
count += len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')])
return count
def count_component2(sql):
nested = get_nestedSQL(sql)
return len(nested)
def count_others(sql):
count = 0
# number of aggregation
agg_count = count_agg(sql['select'][1])
agg_count += count_agg(sql['where'][::2])
agg_count += count_agg(sql['groupBy'])
if len(sql['orderBy']) > 0:
agg_count += count_agg([unit[1] for unit in sql['orderBy'][1] if unit[1]] +
[unit[2] for unit in sql['orderBy'][1] if unit[2]])
agg_count += count_agg(sql['having'])
if agg_count > 1:
count += 1
# number of select columns
if len(sql['select'][1]) > 1:
count += 1
# number of where conditions
if len(sql['where']) > 1:
count += 1
# number of group by clauses
if len(sql['groupBy']) > 1:
count += 1
return count
class Evaluator:
"""A simple evaluator"""
def __init__(self):
self.partial_scores = None
def eval_hardness(self, sql):
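        # Follows the official Spider hardness criteria: e.g. a query with at most one
        # "component1" item (where / group / order / limit / join / or / like), no
        # nesting, and no extra aggregations or select columns is rated "easy".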
count_comp1_ = count_component1(sql)
count_comp2_ = count_component2(sql)
count_others_ = count_others(sql)
if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0:
return "easy"
elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or \
(count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0):
return "medium"
elif (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) or \
(2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) or \
(count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1):
return "hard"
else:
return "extra"
def eval_exact_match(self, pred, label):
partial_scores = self.eval_partial_match(pred, label)
self.partial_scores = partial_scores
for _, score in partial_scores.items():
if score['f1'] != 1:
return 0
if len(label['from']['table_units']) > 0:
label_tables = sorted(label['from']['table_units'])
pred_tables = sorted(pred['from']['table_units'])
if label_tables != pred_tables:
return False
# if len(label['from']['conds']) > 0:
# label_joins = sorted(label['from']['conds'], key=lambda x: str(x))
# pred_joins = sorted(pred['from']['conds'], key=lambda x: str(x))
# if label_joins != pred_joins:
# return False
return 1
def eval_partial_match(self, pred, label):
res = {}
label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['select'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
res['select(no AGG)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['where'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
res['where(no OP)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_group(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['group(no Having)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
# res['group(no Having)'] = {'acc': 1, 'rec': 1, 'f1': 1, 'label_total': 1, 'pred_total': 1}
label_total, pred_total, cnt = eval_having(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['group'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
# res['group'] = {'acc': 1, 'rec': 1, 'f1': 1,'label_total':1,'pred_total':1}
label_total, pred_total, cnt = eval_order(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['order'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_and_or(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['and/or'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_IUEN(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['IUEN'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_keywords(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['keywords'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
return res
def isValidSQL(sql, db):
conn = sqlite3.connect(db)
cursor = conn.cursor()
try:
cursor.execute(sql, [])
except Exception as e:
return False
return True
def print_scores(scores, etype):
levels = ['easy', 'medium', 'hard', 'extra', 'all']
partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
'group', 'order', 'and/or', 'IUEN', 'keywords']
print("{:20} {:20} {:20} {:20} {:20} {:20}".format("", *levels))
counts = [scores[level]['count'] for level in levels]
print("{:20} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d}".format("count", *counts))
if etype in ["all", "exec"]:
print('===================== EXECUTION ACCURACY =====================')
this_scores = [scores[level]['exec'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("execution", *this_scores))
if etype in ["all", "match"]:
print('\n====================== EXACT MATCHING ACCURACY =====================')
exact_scores = [scores[level]['exact'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("exact match", *exact_scores))
print('\n---------------------PARTIAL MATCHING ACCURACY----------------------')
for type_ in partial_types:
this_scores = [scores[level]['partial'][type_]['acc'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
print('---------------------- PARTIAL MATCHING RECALL ----------------------')
for type_ in partial_types:
this_scores = [scores[level]['partial'][type_]['rec'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
print('---------------------- PARTIAL MATCHING F1 --------------------------')
for type_ in partial_types:
this_scores = [scores[level]['partial'][type_]['f1'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
def evaluate(gold, predict, db_dir, etype, kmaps):
with open(gold, encoding='utf-8') as f:
glist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
with open(predict, encoding='utf-8') as f:
plist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
# plist = [("select max(Share),min(Share) from performance where Type != 'terminal'", "orchestra")]
# glist = [("SELECT max(SHARE) , min(SHARE) FROM performance WHERE TYPE != 'Live final'", "orchestra")]
evaluator = Evaluator()
levels = ['easy', 'medium', 'hard', 'extra', 'all']
partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
'group', 'order', 'and/or', 'IUEN', 'keywords']
entries = []
scores = {}
for level in levels:
scores[level] = {'count': 0, 'partial': {}, 'exact': 0.}
scores[level]['exec'] = 0
for type_ in partial_types:
scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0.,'acc_count':0,'rec_count':0}
eval_err_num = 0
for p, g in zip(plist, glist):
p_str = p[0]
g_str, db = g
db_name = db
db = os.path.join(db_dir, db, db + ".sqlite")
schema = Schema(get_schema(db))
g_sql = get_sql(schema, g_str)
hardness = evaluator.eval_hardness(g_sql)
scores[hardness]['count'] += 1
scores['all']['count'] += 1
try:
p_sql = get_sql(schema, p_str)
except:
# If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql
p_sql = {
"except": None,
"from": {
"conds": [],
"table_units": []
},
"groupBy": [],
"having": [],
"intersect": None,
"limit": None,
"orderBy": [],
"select": [
False,
[]
],
"union": None,
"where": []
}
eval_err_num += 1
print("eval_err_num:{}".format(eval_err_num))
print(p_str)
print()
# rebuild sql for value evaluation
kmap = kmaps[db_name]
g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)
g_sql = rebuild_sql_val(g_sql)
g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)
p_sql = rebuild_sql_val(p_sql)
p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)
# p_sql_copy = copy.deepcopy(p_sql)
# g_sql_copy = copy.deepcopy(g_sql)
# if not eval_exec_match(db, p_str, g_str, p_sql_copy, g_sql_copy) and evaluator.eval_exact_match(p_sql_copy, g_sql_copy):
# a = 1
if etype in ["all", "exec"]:
exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql)
if exec_score:
scores[hardness]['exec'] += 1
scores['all']['exec'] += exec_score
if etype in ["all", "match"]:
exact_score = evaluator.eval_exact_match(p_sql, g_sql)
partial_scores = evaluator.partial_scores
if exact_score == 0:
print("{} pred: {}".format(hardness,p_str))
print("{} gold: {}".format(hardness,g_str))
print("")
scores[hardness]['exact'] += exact_score
scores['all']['exact'] += exact_score
for type_ in partial_types:
if partial_scores[type_]['pred_total'] > 0:
scores[hardness]['partial'][type_]['acc'] += partial_scores[type_]['acc']
scores[hardness]['partial'][type_]['acc_count'] += 1
if partial_scores[type_]['label_total'] > 0:
scores[hardness]['partial'][type_]['rec'] += partial_scores[type_]['rec']
scores[hardness]['partial'][type_]['rec_count'] += 1
scores[hardness]['partial'][type_]['f1'] += partial_scores[type_]['f1']
if partial_scores[type_]['pred_total'] > 0:
scores['all']['partial'][type_]['acc'] += partial_scores[type_]['acc']
scores['all']['partial'][type_]['acc_count'] += 1
if partial_scores[type_]['label_total'] > 0:
scores['all']['partial'][type_]['rec'] += partial_scores[type_]['rec']
scores['all']['partial'][type_]['rec_count'] += 1
scores['all']['partial'][type_]['f1'] += partial_scores[type_]['f1']
entries.append({
'predictSQL': p_str,
'goldSQL': g_str,
'hardness': hardness,
'exact': exact_score,
'partial': partial_scores
})
for level in levels:
if scores[level]['count'] == 0:
continue
if etype in ["all", "exec"]:
scores[level]['exec'] /= scores[level]['count']
if etype in ["all", "match"]:
scores[level]['exact'] /= scores[level]['count']
for type_ in partial_types:
if scores[level]['partial'][type_]['acc_count'] == 0:
scores[level]['partial'][type_]['acc'] = 0
else:
scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \
scores[level]['partial'][type_]['acc_count'] * 1.0
if scores[level]['partial'][type_]['rec_count'] == 0:
scores[level]['partial'][type_]['rec'] = 0
else:
scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \
scores[level]['partial'][type_]['rec_count'] * 1.0
if scores[level]['partial'][type_]['acc'] == 0 and scores[level]['partial'][type_]['rec'] == 0:
scores[level]['partial'][type_]['f1'] = 1
else:
scores[level]['partial'][type_]['f1'] = \
2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (
scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])
print_scores(scores, etype)
def eval_exec_match(db, p_str, g_str, pred, gold):
"""
return 1 if the values between prediction and gold are matching
in the corresponding index. Currently not support multiple col_unit(pairs).
"""
conn = sqlite3.connect(db)
cursor = conn.cursor()
conn.text_factory = bytes
try:
cursor.execute(p_str)
p_res = cursor.fetchall()
except:
return False
cursor.execute(g_str)
q_res = cursor.fetchall()
def res_map(res, val_units):
rmap = {}
for idx, val_unit in enumerate(val_units):
key = tuple(val_unit[1]) if not val_unit[2] else (val_unit[0], tuple(val_unit[1]), tuple(val_unit[2]))
rmap[key] = [r[idx] for r in res]
return rmap
p_val_units = [unit[1] for unit in pred['select'][1]]
q_val_units = [unit[1] for unit in gold['select'][1]]
return res_map(p_res, p_val_units) == res_map(q_res, q_val_units)
# Rebuild SQL functions for value evaluation
def rebuild_cond_unit_val(cond_unit):
if cond_unit is None or not DISABLE_VALUE:
return cond_unit
not_op, op_id, val_unit, val1, val2 = cond_unit
if type(val1) is not dict:
val1 = None
else:
val1 = rebuild_sql_val(val1)
if type(val2) is not dict:
val2 = None
else:
val2 = rebuild_sql_val(val2)
return not_op, op_id, val_unit, val1, val2
def rebuild_condition_val(condition):
if condition is None or not DISABLE_VALUE:
return condition
res = []
for idx, it in enumerate(condition):
if idx % 2 == 0:
res.append(rebuild_cond_unit_val(it))
else:
res.append(it)
return res
def rebuild_sql_val(sql):
if sql is None or not DISABLE_VALUE:
return sql
sql['from']['conds'] = rebuild_condition_val(sql['from']['conds'])
sql['having'] = rebuild_condition_val(sql['having'])
sql['where'] = rebuild_condition_val(sql['where'])
sql['intersect'] = rebuild_sql_val(sql['intersect'])
sql['except'] = rebuild_sql_val(sql['except'])
sql['union'] = rebuild_sql_val(sql['union'])
return sql
# Rebuild SQL functions for foreign key evaluation
def build_valid_col_units(table_units, schema):
col_ids = [table_unit[1] for table_unit in table_units if table_unit[0] == TABLE_TYPE['table_unit']]
prefixs = [col_id[:-2] for col_id in col_ids]
valid_col_units= []
for value in schema.idMap.values():
if '.' in value and value[:value.index('.')] in prefixs:
valid_col_units.append(value)
return valid_col_units
def rebuild_col_unit_col(valid_col_units, col_unit, kmap):
if col_unit is None:
return col_unit
agg_id, col_id, distinct = col_unit
if col_id in kmap and col_id in valid_col_units:
col_id = kmap[col_id]
if DISABLE_DISTINCT:
distinct = None
return agg_id, col_id, distinct
def rebuild_val_unit_col(valid_col_units, val_unit, kmap):
if val_unit is None:
return val_unit
unit_op, col_unit1, col_unit2 = val_unit
col_unit1 = rebuild_col_unit_col(valid_col_units, col_unit1, kmap)
col_unit2 = rebuild_col_unit_col(valid_col_units, col_unit2, kmap)
return unit_op, col_unit1, col_unit2
def rebuild_table_unit_col(valid_col_units, table_unit, kmap):
if table_unit is None:
return table_unit
table_type, col_unit_or_sql = table_unit
if isinstance(col_unit_or_sql, tuple):
col_unit_or_sql = rebuild_col_unit_col(valid_col_units, col_unit_or_sql, kmap)
return table_type, col_unit_or_sql
def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap):
if cond_unit is None:
return cond_unit
not_op, op_id, val_unit, val1, val2 = cond_unit
val_unit = rebuild_val_unit_col(valid_col_units, val_unit, kmap)
return not_op, op_id, val_unit, val1, val2
def rebuild_condition_col(valid_col_units, condition, kmap):
for idx in range(len(condition)):
if idx % 2 == 0:
condition[idx] = rebuild_cond_unit_col(valid_col_units, condition[idx], kmap)
return condition
def rebuild_select_col(valid_col_units, sel, kmap):
if sel is None:
return sel
distinct, _list = sel
new_list = []
for it in _list:
agg_id, val_unit = it
new_list.append((agg_id, rebuild_val_unit_col(valid_col_units, val_unit, kmap)))
if DISABLE_DISTINCT:
distinct = None
return distinct, new_list
def rebuild_from_col(valid_col_units, from_, kmap):
if from_ is None:
return from_
from_['table_units'] = [rebuild_table_unit_col(valid_col_units, table_unit, kmap) for table_unit in from_['table_units']]
from_['conds'] = rebuild_condition_col(valid_col_units, from_['conds'], kmap)
return from_
def rebuild_group_by_col(valid_col_units, group_by, kmap):
if group_by is None:
return group_by
return [rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by]
def rebuild_order_by_col(valid_col_units, order_by, kmap):
if order_by is None or len(order_by) == 0:
return order_by
direction, val_units = order_by
new_val_units = [rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units]
return direction, new_val_units
def rebuild_sql_col(valid_col_units, sql, kmap):
if sql is None:
return sql
sql['select'] = rebuild_select_col(valid_col_units, sql['select'], kmap)
sql['from'] = rebuild_from_col(valid_col_units, sql['from'], kmap)
sql['where'] = rebuild_condition_col(valid_col_units, sql['where'], kmap)
sql['groupBy'] = rebuild_group_by_col(valid_col_units, sql['groupBy'], kmap)
sql['orderBy'] = rebuild_order_by_col(valid_col_units, sql['orderBy'], kmap)
sql['having'] = rebuild_condition_col(valid_col_units, sql['having'], kmap)
sql['intersect'] = rebuild_sql_col(valid_col_units, sql['intersect'], kmap)
sql['except'] = rebuild_sql_col(valid_col_units, sql['except'], kmap)
sql['union'] = rebuild_sql_col(valid_col_units, sql['union'], kmap)
return sql
def build_foreign_key_map(entry):
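    # Greedily groups column ids linked by foreign keys and maps every column key
    # "__table.column__" to the key with the smallest column id in its group, so
    # equivalent join columns compare as equal during evaluation.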
cols_orig = entry["column_names_original"]
tables_orig = entry["table_names_original"]
# rebuild cols corresponding to idmap in Schema
cols = []
for col_orig in cols_orig:
if col_orig[0] >= 0:
t = tables_orig[col_orig[0]]
c = col_orig[1]
cols.append("__" + t.lower() + "." + c.lower() + "__")
else:
cols.append("__all__")
def keyset_in_list(k1, k2, k_list):
for k_set in k_list:
if k1 in k_set or k2 in k_set:
return k_set
new_k_set = set()
k_list.append(new_k_set)
return new_k_set
foreign_key_list = []
foreign_keys = entry["foreign_keys"]
for fkey in foreign_keys:
key1, key2 = fkey
key_set = keyset_in_list(key1, key2, foreign_key_list)
key_set.add(key1)
key_set.add(key2)
foreign_key_map = {}
for key_set in foreign_key_list:
sorted_list = sorted(list(key_set))
midx = sorted_list[0]
for idx in sorted_list:
foreign_key_map[cols[idx]] = cols[midx]
return foreign_key_map
def build_foreign_key_map_from_json(table):
with open(table, encoding='utf-8') as f:
data = json.load(f)
tables = {}
for entry in data:
tables[entry['db_id']] = build_foreign_key_map(entry)
return tables
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--gold', dest='gold', type=str)
parser.add_argument('--pred', dest='pred', type=str)
parser.add_argument('--db', dest='db', type=str)
parser.add_argument('--table', dest='table', type=str)
parser.add_argument('--etype', dest='etype', type=str)
args = parser.parse_args()
gold = args.gold
pred = args.pred
db_dir = args.db
table = args.table
etype = args.etype
assert etype in ["all", "exec", "match"], "Unknown evaluation method"
kmaps = build_foreign_key_map_from_json(table)
evaluate(gold, pred, db_dir, etype, kmaps)

View file

@ -0,0 +1,89 @@
import os
import sqlite3
from semparse.worlds.evaluate import Evaluator, build_valid_col_units, rebuild_sql_val, rebuild_sql_col, \
build_foreign_key_map_from_json
from semparse.sql.process_sql import Schema, get_schema, get_sql
_schemas = {}
kmaps = None
def evaluate(gold, predict, db_name, db_dir, table, check_valid: bool=True, db_schema=None) -> bool:
global kmaps
    for alias in ('t1', 't2', 't3', 't4'):
        gold = gold.replace(f"{alias} . ", "").replace(f" as {alias}", "")
        predict = predict.replace(f"{alias} . ", "").replace(f" as {alias}", "")
# sgrammar = SpiderGrammar(
# output_from=True,
# use_table_pointer=True,
# include_literals=True,
# include_columns=True,
# )
# try:
evaluator = Evaluator()
if kmaps is None:
kmaps = build_foreign_key_map_from_json(table)
if 'chase' in db_dir:
schema = _schemas[db_name] = Schema(db_schema)
elif db_name in _schemas:
schema = _schemas[db_name]
else:
db = os.path.join(db_dir, db_name, db_name + ".sqlite")
schema = _schemas[db_name] = Schema(get_schema(db))
g_sql = get_sql(schema, gold)
# try:
p_sql = get_sql(schema, predict)
# except Exception as e:
# print('evaluate_spider.py L39')
# return False
# rebuild sql for value evaluation
kmap = kmaps[db_name]
g_valid_col_units = build_valid_col_units(g_sql['from']['table_units'], schema)
g_sql = rebuild_sql_val(g_sql)
g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
p_valid_col_units = build_valid_col_units(p_sql['from']['table_units'], schema)
p_sql = rebuild_sql_val(p_sql)
p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)
exact_score = evaluator.eval_exact_match(p_sql, g_sql)
if not check_valid:
return exact_score
else:
return exact_score and check_valid_sql(predict, db_name, db_dir)
# except Exception as e:
# return 0
_conns = {}
def check_valid_sql(sql, db_name, db_dir, return_error=False):
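    # NOTE: validity checking is short-circuited by the unconditional return below,
    # so the SQLite execution check that follows is currently unreachable.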
return True
db = os.path.join(db_dir, db_name, db_name + ".sqlite")
if db_name == 'wta_1':
# TODO: seems like there is a problem with this dataset - slow response - add limit 1
return True if not return_error else (True, None)
if db_name not in _conns:
_conns[db_name] = sqlite3.connect(db)
# fixes an encoding bug
_conns[db_name].text_factory = bytes
conn = _conns[db_name]
cursor = conn.cursor()
try:
cursor.execute(sql)
cursor.fetchall()
return True if not return_error else (True, None)
except Exception as e:
return False if not return_error else (False, e.args[0])

View file

@ -0,0 +1,471 @@
"""
Based on https://github.com/ryanzhumich/editsql/blob/master/preprocess.py
"""
import argparse
import json
import os
import re
import stanza
import sqlparse
from tqdm import tqdm
from semparse.contexts.spider_db_context import SpiderDBContext
from semparse.sql.spider_utils import disambiguate_items, fix_number_value
from semparse.sql.spider_utils import read_dataset_schema
keyword = ['select', 'distinct', 'from', 'join', 'on', 'where', 'group', 'by', 'order', 'asc', 'desc', 'limit',
'having',
'and', 'not', 'or', 'like', 'between', 'in',
'sum', 'count', 'max', 'min', 'avg',
'(', ')', ',', '>', '<', '=', '==', '>=', '!=', '<=',
'union', 'except', 'intersect',
'\'value\'']
stanza.download('en')
stanza_model = stanza.Pipeline(lang='en', processors='tokenize,pos,lemma')
# stanza_model=None
def write_interaction(interaction_list, split, output_dir):
interaction = []
for db_id in interaction_list:
interaction += interaction_list[db_id]
json_split = os.path.join(output_dir, split + '.json')
with open(json_split, 'w', encoding="utf-8") as outfile:
json.dump(interaction, outfile, indent=2, ensure_ascii=False)
return
def read_database_schema(table_path):
schema_tokens = {}
column_names = {}
database_schemas_dict = {}
with open(table_path, 'r', encoding='UTF-8') as f:
database_schemas = json.load(f)
def get_schema_tokens(table_schema):
column_names_surface_form = []
column_names = []
column_names_original = table_schema['column_names_original']
table_names = table_schema['table_names']
table_names_original = table_schema['table_names_original']
for i, (table_id, column_name) in enumerate(column_names_original):
if table_id >= 0:
table_name = table_names_original[table_id]
column_name_surface_form = '{}.{}'.format(table_name, column_name)
else:
# this is just *
column_name_surface_form = column_name
column_names_surface_form.append(column_name_surface_form.lower())
column_names.append(column_name.lower())
# also add table_name.*
for table_name in table_names_original:
column_names_surface_form.append('{}.*'.format(table_name.lower()))
return column_names_surface_form, column_names
for table_schema in database_schemas:
database_id = table_schema['db_id']
if 'column_names_original' not in table_schema:
table_schema["column_names_original"] = table_schema["column_names"]
table_schema["table_names_original"] = table_schema["table_names"]
table_schema['table_names_original'] = [t.lower() for t in table_schema['table_names_original']]
table_schema['foreign_keys_col'] = [i[0] for i in table_schema['foreign_keys']]
structure_schema = []
for t in table_schema['foreign_keys']:
primary_col, foreign_col = t
primary_col = table_schema['column_names_original'][primary_col]
primary_col_tab = table_schema['table_names_original'][primary_col[0]].lower()
foreign_col = table_schema['column_names_original'][foreign_col]
foreign_col_tab = table_schema['table_names_original'][foreign_col[0]].lower()
structure_schema.append(f"( {primary_col_tab} , {foreign_col_tab} )")
structure_schema = list(sorted(set(structure_schema)))
table_schema['permutations'] = [structure_schema]
database_schemas_dict[database_id] = table_schema
schema_tokens[database_id], column_names[database_id] = get_schema_tokens(table_schema)
if 'column_rewrite_names' in table_schema:
for i in range(len(table_schema['column_rewrite_names'])):
table_schema['column_rewrite_names'][i] = [table_schema['column_names'][i][-1]] + \
table_schema['column_rewrite_names'][i][-1]
table_schema['column_rewrite_names'] = [list(set(map(lambda x: x.lower().replace(' ', ''), i))) for i in
table_schema['column_rewrite_names']]
for i in range(len(table_schema['table_rewrite_names'])):
table_schema['table_rewrite_names'][i] = [table_schema['table_names'][i]] + \
table_schema['table_rewrite_names'][i]
table_schema['table_rewrite_names'] = [list(set(map(lambda x: x.lower().replace(' ', ''), i))) for i in
table_schema['table_rewrite_names']]
return schema_tokens, column_names, database_schemas_dict
def remove_from_with_join(format_sql_2):
used_tables_list = []
format_sql_3 = []
table_to_name = {}
table_list = []
old_table_to_name = {}
old_table_list = []
for sub_sql in format_sql_2.split('\n'):
if 'select ' in sub_sql:
# only replace alias: t1 -> table_name, t2 -> table_name, etc...
if len(table_list) > 0:
for i in range(len(format_sql_3)):
for table, name in table_to_name.items():
format_sql_3[i] = format_sql_3[i].replace(table, name)
old_table_list = table_list
old_table_to_name = table_to_name
table_to_name = {}
table_list = []
format_sql_3.append(sub_sql)
elif sub_sql.startswith('from'):
new_sub_sql = None
sub_sql_tokens = sub_sql.split()
for t_i, t in enumerate(sub_sql_tokens):
if t == 'as':
table_to_name[sub_sql_tokens[t_i + 1]] = sub_sql_tokens[t_i - 1]
table_list.append(sub_sql_tokens[t_i - 1])
elif t == ')' and new_sub_sql is None:
# new_sub_sql keeps some trailing parts after ')'
new_sub_sql = ' '.join(sub_sql_tokens[t_i:])
if len(table_list) > 0:
# if it's a from clause with join
if new_sub_sql is not None:
format_sql_3.append(new_sub_sql)
used_tables_list.append(table_list)
else:
# if it's a from clause without join
table_list = old_table_list
table_to_name = old_table_to_name
assert 'join' not in sub_sql
if new_sub_sql is not None:
sub_sub_sql = sub_sql[:-len(new_sub_sql)].strip()
assert len(sub_sub_sql.split()) == 2
used_tables_list.append([sub_sub_sql.split()[1]])
format_sql_3.append(sub_sub_sql)
format_sql_3.append(new_sub_sql)
elif 'join' not in sub_sql:
assert len(sub_sql.split()) == 2 or len(sub_sql.split()) == 1
if len(sub_sql.split()) == 2:
used_tables_list.append([sub_sql.split()[1]])
format_sql_3.append(sub_sql)
else:
print('bad from clause in remove_from_with_join')
exit()
else:
format_sql_3.append(sub_sql)
if len(table_list) > 0:
for i in range(len(format_sql_3)):
for table, name in table_to_name.items():
format_sql_3[i] = format_sql_3[i].replace(table, name)
used_tables = []
for t in used_tables_list:
for tt in t:
used_tables.append(tt)
used_tables = list(set(used_tables))
return format_sql_3, used_tables, used_tables_list
def remove_from_without_join(format_sql_3, column_names, schema_tokens):
format_sql_4 = []
table_name = None
for sub_sql in format_sql_3.split('\n'):
if 'select ' in sub_sql:
if table_name:
for i in range(len(format_sql_4)):
tokens = format_sql_4[i].split()
for ii, token in enumerate(tokens):
if token in column_names and tokens[ii - 1] != '.':
if (ii + 1 < len(tokens) and tokens[ii + 1] != '.' and tokens[
ii + 1] != '(') or ii + 1 == len(tokens):
if '{}.{}'.format(table_name, token) in schema_tokens:
tokens[ii] = '{} . {}'.format(table_name, token)
format_sql_4[i] = ' '.join(tokens)
format_sql_4.append(sub_sql)
elif sub_sql.startswith('from'):
sub_sql_tokens = sub_sql.split()
if len(sub_sql_tokens) == 1:
table_name = None
elif len(sub_sql_tokens) == 2:
table_name = sub_sql_tokens[1]
else:
print('bad from clause in remove_from_without_join')
print(format_sql_3)
exit()
else:
format_sql_4.append(sub_sql)
if table_name:
for i in range(len(format_sql_4)):
tokens = format_sql_4[i].split()
for ii, token in enumerate(tokens):
if token in column_names and tokens[ii - 1] != '.':
if (ii + 1 < len(tokens) and tokens[ii + 1] != '.' and tokens[ii + 1] != '(') or ii + 1 == len(
tokens):
if '{}.{}'.format(table_name, token) in schema_tokens:
tokens[ii] = '{} . {}'.format(table_name, token)
format_sql_4[i] = ' '.join(tokens)
return format_sql_4
def add_table_name(format_sql_3, used_tables, column_names, schema_tokens):
# If just one table used, easy case, replace all column_name -> table_name.column_name
if len(used_tables) == 1:
table_name = used_tables[0]
format_sql_4 = []
for sub_sql in format_sql_3.split('\n'):
if sub_sql.startswith('from'):
format_sql_4.append(sub_sql)
continue
tokens = sub_sql.split()
for ii, token in enumerate(tokens):
if token in column_names and tokens[ii - 1] != '.':
if (ii + 1 < len(tokens) and tokens[ii + 1] != '.' and tokens[ii + 1] != '(') or ii + 1 == len(
tokens):
if '{}.{}'.format(table_name, token) in schema_tokens:
tokens[ii] = '{} . {}'.format(table_name, token)
format_sql_4.append(' '.join(tokens))
return format_sql_4
def get_table_name_for(token):
table_names = []
for table_name in used_tables:
if '{}.{}'.format(table_name, token) in schema_tokens:
table_names.append(table_name)
if len(table_names) == 0:
return 'table'
if len(table_names) > 1:
return None
else:
return table_names[0]
format_sql_4 = []
for sub_sql in format_sql_3.split('\n'):
if sub_sql.startswith('from'):
format_sql_4.append(sub_sql)
continue
tokens = sub_sql.split()
for ii, token in enumerate(tokens):
# skip *
if token == '*':
continue
if token in column_names and tokens[ii - 1] != '.':
if (ii + 1 < len(tokens) and tokens[ii + 1] != '.' and tokens[ii + 1] != '(') or ii + 1 == len(tokens):
table_name = get_table_name_for(token)
if table_name:
tokens[ii] = '{} . {}'.format(table_name, token)
format_sql_4.append(' '.join(tokens))
return format_sql_4
def normalize_space(format_sql):
format_sql_1 = [' '.join(
sub_sql.strip().replace(',', ' , ').replace('.', ' . ').replace('(', ' ( ').replace(')', ' ) ').split()) for
sub_sql in format_sql.split('\n')]
format_sql_1 = '\n'.join(format_sql_1)
format_sql_2 = format_sql_1.replace('\njoin', ' join').replace(',\n', ', ').replace(' where', '\nwhere').replace(
' intersect', '\nintersect').replace('\nand', ' and').replace('order by t2 .\nstart desc',
'order by t2 . start desc')
format_sql_2 = format_sql_2.replace('select\noperator', 'select operator').replace('select\nconstructor',
'select constructor').replace(
'select\nstart', 'select start').replace('select\ndrop', 'select drop').replace('select\nwork',
'select work').replace(
'select\ngroup', 'select group').replace('select\nwhere_built', 'select where_built').replace('select\norder',
'select order').replace(
'from\noperator', 'from operator').replace('from\nforward', 'from forward').replace('from\nfor',
'from for').replace(
'from\ndrop', 'from drop').replace('from\norder', 'from order').replace('.\nstart', '. start').replace(
'.\norder', '. order').replace('.\noperator', '. operator').replace('.\nsets', '. sets').replace(
'.\nwhere_built', '. where_built').replace('.\nwork', '. work').replace('.\nconstructor',
'. constructor').replace('.\ngroup',
'. group').replace(
'.\nfor', '. for').replace('.\ndrop', '. drop').replace('.\nwhere', '. where')
format_sql_2 = format_sql_2.replace('group by', 'group_by').replace('order by', 'order_by').replace('! =',
'!=').replace(
'limit value', 'limit_value')
return format_sql_2
def normalize_final_sql(format_sql_5):
format_sql_final = format_sql_5.replace('\n', ' ').replace(' . ', '.').replace('group by', 'group_by').replace(
'order by', 'order_by').replace('! =', '!=').replace('limit value', 'limit_value')
# normalize two bad sqls
if 't1' in format_sql_final or 't2' in format_sql_final or 't3' in format_sql_final or 't4' in format_sql_final:
format_sql_final = format_sql_final.replace('t2.dormid', 'dorm.dormid')
# This is the failure case of remove_from_without_join()
format_sql_final = format_sql_final.replace(
'select city.city_name where city.state_name in ( select state.state_name where state.state_name in ( select river.traverse where river.river_name = value ) and state.area = ( select min ( state.area ) where state.state_name in ( select river.traverse where river.river_name = value ) ) ) order_by population desc limit_value',
'select city.city_name where city.state_name in ( select state.state_name where state.state_name in ( select river.traverse where river.river_name = value ) and state.area = ( select min ( state.area ) where state.state_name in ( select river.traverse where river.river_name = value ) ) ) order_by city.population desc limit_value')
return format_sql_final
def normalize_original_sql(sql):
sql = [i.lower() for i in sql]
sql = ' '.join(sql).strip(';').replace("``", "'").replace("\"", "'").replace("''", "'")
sql = sql.replace(')from', ') from')
sql = sql.replace('(', ' ( ')
sql = sql.replace(')', ' ) ')
sql = re.sub('\s+', ' ', sql)
sql = re.sub(r"(')(\S+)", r"\1 \2", sql)
sql = re.sub(r"(\S+)(')", r"\1 \2", sql).split(' ')
sql = ' '.join(sql)
sql = sql.strip(' ;').replace('> =', '>=').replace('! =', '!=')
return sql.split(' ')
def parse_sql(sql_string, db_id, column_names, schema_tokens, schema):
format_sql = sqlparse.format(sql_string, reindent=True)
format_sql_2 = normalize_space(format_sql)
format_sql_3, used_tables, used_tables_list = remove_from_with_join(format_sql_2)
format_sql_3 = '\n'.join(format_sql_3)
format_sql_4 = add_table_name(format_sql_3, used_tables, column_names, schema_tokens)
format_sql_4 = '\n'.join(format_sql_4)
format_sql_5 = remove_from_without_join(format_sql_4, column_names, schema_tokens)
format_sql_5 = '\n'.join(format_sql_5)
format_sql_final = normalize_final_sql(format_sql_5)
return format_sql_final
def read_spider_split(dataset_path, table_path, database_path):
with open(dataset_path) as f:
split_data = json.load(f)
print('read_spider_split', dataset_path, len(split_data))
schemas = read_dataset_schema(table_path, stanza_model)
interaction_list = {}
for i, ex in enumerate(tqdm(split_data)):
db_id = ex['db_id']
ex['query_toks_no_value'] = normalize_original_sql(ex['query_toks_no_value'])
turn_sql = ' '.join(ex['query_toks_no_value'])
turn_sql = turn_sql.replace('select count ( * ) from follows group by value',
'select count ( * ) from follows group by f1')
ex['query_toks_no_value'] = turn_sql.split(' ')
ex = fix_number_value(ex)
try:
ex['query_toks_no_value'] = disambiguate_items(db_id, ex['query_toks_no_value'],
tables_file=table_path, allow_aliases=False)
except:
print(ex['query_toks'])
continue
final_sql_parse = ' '.join(ex['query_toks_no_value'])
final_utterance = ' '.join(ex['question_toks']).lower()
if stanza_model is not None:
lemma_utterance_stanza = stanza_model(final_utterance)
lemma_utterance = [word.lemma for sent in lemma_utterance_stanza.sentences for word in sent.words]
original_utterance = final_utterance
else:
original_utterance = lemma_utterance = final_utterance.split(' ')
# using db content
db_context = SpiderDBContext(db_id,
lemma_utterance,
tables_file=table_path,
dataset_path=database_path,
stanza_model=stanza_model,
schemas=schemas,
original_utterance=original_utterance)
value_match, value_alignment, exact_match, partial_match = db_context.get_db_knowledge_graph(db_id)
if value_match != []:
print(value_match, value_alignment)
if db_id not in interaction_list:
interaction_list[db_id] = []
interaction = {}
interaction['id'] = i
interaction['database_id'] = db_id
interaction['interaction'] = [{'utterance': final_utterance,
'db_id': db_id,
'query': ex['query'],
'question': ex['question'],
'sql': final_sql_parse,
'value_match': value_match,
'value_alignment': value_alignment,
'exact_match': exact_match,
'partial_match': partial_match,
}]
interaction_list[db_id].append(interaction)
return interaction_list
def preprocess_dataset(dataset, dataset_dir, output_dir, table_path, database_path):
# for session in ['train', 'dev']:
for session in ['dev']:
dataset_path = os.path.join(dataset_dir, f'{session}.json')
interaction_list = read_spider_split(dataset_path, table_path, database_path)
write_interaction(interaction_list, session, output_dir)
return interaction_list
def preprocess(dataset, dataset_dir, table_path, database_path, output_dir):
# directory
if not os.path.exists(output_dir):
os.mkdir(output_dir)
# read schema
print('Reading spider database schema file')
schema_tokens, column_names, database_schemas = read_database_schema(table_path)
print('total number of schema_tokens / databases:', len(schema_tokens))
output_table_path = os.path.join(output_dir, 'tables.json')
with open(output_table_path, 'w') as outfile:
json.dump([v for k, v in database_schemas.items()], outfile, indent=4)
# process (SQL, Query) pair in train/dev
preprocess_dataset(dataset, dataset_dir, output_dir, table_path, database_path)
return
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", choices=('spider', 'sparc', 'cosql'), default='spider')
args = parser.parse_args()
dataset = args.dataset
dataset_dir = f'./data/{dataset}/'
table_path = f'./data/{dataset}/tables.json'
database_path = f'./data/{dataset}/database'
output_dir = f'./data/{dataset}_schema_linking_tag'
preprocess(dataset, dataset_dir, table_path, database_path, output_dir)

@ -0,0 +1,174 @@
import os
import json
import argparse
import subprocess
from tqdm import tqdm
from step1_schema_linking import read_database_schema
from train import run_command
def running_process(generate_path):
# cmd = f"python -m multiprocessing_bpe_encoder \
# --encoder-json ./models/spider_sl/encoder.json \
# --vocab-bpe ./models/spider_sl/vocab.bpe \
# --inputs {generate_path}/train.src \
# --outputs {generate_path}/train.bpe.src \
# --workers 1 \
# --keep-empty"
# run_command(cmd)
#
# cmd = f"python -m multiprocessing_bpe_encoder \
# --encoder-json ./models/spider_sl/encoder.json \
# --vocab-bpe ./models/spider_sl/vocab.bpe \
# --inputs {generate_path}/train.tgt \
# --outputs {generate_path}/train.bpe.tgt \
# --workers 1 \
# --keep-empty"
# run_command(cmd)
cmd = f"python -m multiprocessing_bpe_encoder \
--encoder-json ./models/spider_sl/encoder.json \
--vocab-bpe ./models/spider_sl/vocab.bpe \
--inputs {generate_path}/dev.src \
--outputs {generate_path}/dev.bpe.src \
--workers 1 \
--keep-empty"
run_command(cmd)
cmd = f"python -m multiprocessing_bpe_encoder \
--encoder-json ./models/spider_sl/encoder.json \
--vocab-bpe ./models/spider_sl/vocab.bpe \
--inputs {generate_path}/dev.tgt \
--outputs {generate_path}/dev.bpe.tgt \
--workers 1 \
--keep-empty"
run_command(cmd)
# cmd = f'fairseq-preprocess --source-lang "src" --target-lang "tgt" \
# --trainpref {generate_path}/train.bpe \
# --validpref {generate_path}/dev.bpe \
# --destdir {generate_path}/bin \
# --workers 2 \
# --srcdict ./models/spider_sl/dict.src.txt \
# --tgtdict ./models/spider_sl/dict.tgt.txt '
cmd = f'fairseq-preprocess --source-lang "src" --target-lang "tgt" \
--validpref {generate_path}/dev.bpe \
--destdir {generate_path}/bin \
--workers 2 \
--srcdict ./models/spider_sl/dict.src.txt \
--tgtdict ./models/spider_sl/dict.tgt.txt '
subprocess.Popen(
cmd, universal_newlines=True, shell=True,
stdout=subprocess.PIPE, stderr=subprocess.PIPE).communicate()
def build_schema_linking_data(schema, question, item, turn_id, linking_type):
source_sequence_list, target_sequence_list = [], []
# column description
column_names = []
for i, (t, c) in enumerate(zip(schema['column_types'], schema['column_names_original'])):
if c[0] == -1:
column_names.append("{0} {1}".format(t, c[1].lower()))
else:
column_with_alias = "{0}@{1}".format(schema['table_names_original'][c[0]].lower(), c[1].lower())
tag_list = []
if column_with_alias in item['interaction'][turn_id]['exact_match']:
tag_list.append('EM')
elif column_with_alias in item['interaction'][turn_id]['partial_match']:
tag_list.append('PA')
if column_with_alias in item['interaction'][turn_id]['value_match']:
tag_list.append('VC')
# primary-foreign key
if i in schema['primary_keys']:
tag_list.append('RK')
elif i in schema['foreign_keys_col']:
tag_list.append('FO')
if tag_list != []:
column_names.append("{0} {1} {2}".format(' '.join(tag_list), t, column_with_alias))
else:
column_names.append("{0} {1}".format(t, column_with_alias))
# table description
table_names = []
for t in schema['table_names_original']:
tag_list = []
if t in item['interaction'][turn_id]['exact_match']:
tag_list.append('EM')
elif t in item['interaction'][turn_id]['partial_match']:
tag_list.append('PA')
if '_nosl' in linking_type or 'not' in linking_type:
tag_list = []
if tag_list != []:
table_names.append("{0} {1}".format(' '.join(tag_list), t.lower()))
else:
table_names.append("{0}".format(t.lower()))
table_names = ' | '.join(table_names)
column_names = ' | '.join(column_names)
for structure_schema_list in schema['permutations'][:10]:
structure_schema_str = ' | '.join(structure_schema_list)
source_sequence = f"<C> {column_names} | <T> {table_names} | <S> {structure_schema_str} | <Q> {question.lower()}"
target_sequence = item['interaction'][turn_id]['sql'].lower()
source_sequence_list.append(source_sequence)
target_sequence_list.append(target_sequence)
return source_sequence_list, target_sequence_list
def extract_input_and_output(example_lines, linking_type):
inputs = []
outputs = []
database_schema_filename = './data/spider/tables.json'
schema_tokens, column_names, database_schemas = read_database_schema(database_schema_filename)
for item in tqdm(example_lines):
question = item['interaction'][0]['utterance']
schema = database_schemas[item['database_id']]
source_sequence, target_sequence = build_schema_linking_data(schema=schema,
question=question,
item=item,
turn_id=0,
linking_type=linking_type)
outputs.extend(target_sequence)
inputs.extend(source_sequence)
assert len(inputs) == len(outputs)
return inputs, outputs
def read_dataflow_dataset(file_path, out_folder, session, linking_type):
train_out_path = os.path.join(out_folder, session)
train_src_writer = open(train_out_path + ".src", "w", encoding="utf8")
train_tgt_writer = open(train_out_path + ".tgt", "w", encoding="utf8")
with open(file_path, "r", encoding='utf-8') as data_file:
lines = json.load(data_file)
data_input, data_output = extract_input_and_output(lines, linking_type)
train_src_writer.write("\n".join(data_input))
train_tgt_writer.write("\n".join(data_output))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--sl_dataset_path", default='./data/spider_schema_linking_tag')
parser.add_argument("--output_path", default='./dataset_post/spider_sl')
parser.add_argument("--linking_type", default='default')
args = parser.parse_args()
# for session in ["train", "dev"]:
for session in ["dev"]:
file_path = os.path.join(args.sl_dataset_path, "{}.json".format(session))
out_folder = args.output_path
if not os.path.exists(out_folder):
os.makedirs(out_folder)
read_dataflow_dataset(file_path, out_folder, session, args.linking_type)
running_process(args.output_path)

@ -0,0 +1,336 @@
import argparse
import json
import re
import subprocess
from collections import defaultdict
from re import RegexFlag
import networkx as nx
import torch
from genre.fairseq_model import GENRE, mGENRE
from genre.entity_linking import get_end_to_end_prefix_allowed_tokens_fn_fairseq as get_prefix_allowed_tokens_fn
from genre.trie import Trie
from semparse.sql.spider import load_original_schemas, load_tables
from semparse.worlds.evaluate_spider import evaluate as evaluate_sql
from step1_schema_linking import read_database_schema
database_dir='./data/spider/database'
database_schema_filename = './data/spider/tables.json'
schema_tokens, column_names, database_schemas = read_database_schema(database_schema_filename)
with open(f'./data/spider/dev.json', 'r', encoding='utf-8') as f:
item = json.load(f)
sql_to_db = []
for i in item:
sql_to_db.append(i['db_id'])
def post_processing_sql(p_sql, foreign_key_maps, schemas, o_schemas):
foreign_key = {}
for k, v in foreign_key_maps.items():
if k == v:
continue
key = ' '.join(sorted([k.split('.')[0].strip('_'), v.split('.')[0].strip('_')]))
foreign_key[key] = (k.strip('_').replace('.', '@'), v.strip('_').replace('.', '@'))
primary_key = {}
for t in o_schemas.tables:
table = t.orig_name.lower()
if len(t.primary_keys) == 0:
continue
column = t.primary_keys[0].orig_name.lower()
primary_key[table] = f'{table}@{column}'
p_sql = re.sub(r'(=)(\S+)', r'\1 \2', p_sql)
p_sql = p_sql.split()
columns = ['*']
tables = []
for table, column_list in schemas.schema.items():
for column in column_list:
columns.append(f"{table}@{column}")
tables.append(table)
# infer table from mentioned column
all_from_table_ids = set()
from_idx = where_idx = group_idx = order_idx = -1
for idx, token in enumerate(p_sql):
if '@' in token and token in columns:
all_from_table_ids.add(schemas.idMap[token.split('@')[0]])
if token == 'from' and from_idx == -1:
from_idx = idx
if token == 'where' and where_idx == -1:
where_idx = idx
if token == 'group' and group_idx == -1:
group_idx = idx
if token == 'order' and order_idx == -1:
order_idx = idx
#don't process nested SQL (more than one select)
if len(re.findall('select', ' '.join(p_sql))) > 1 or len(all_from_table_ids) == 0:
return ' '.join(p_sql)
covered_tables = set()
candidate_table_ids = sorted(all_from_table_ids)
start_table_id = candidate_table_ids[0]
conds = set()
all_conds = []
for table_id in candidate_table_ids[1:]:
if table_id in covered_tables:
continue
try:
path = nx.shortest_path(
o_schemas.foreign_key_graph,
source=start_table_id,
target=table_id,
)
except (nx.NetworkXNoPath, nx.NodeNotFound):
covered_tables.add(table_id)
continue
for source_table_id, target_table_id in zip(path, path[1:]):
if target_table_id in covered_tables:
continue
covered_tables.add(target_table_id)
all_from_table_ids.add(target_table_id)
col1, col2 = o_schemas.foreign_key_graph[source_table_id][target_table_id]["columns"]
all_conds.append((columns[col1], columns[col2]))
conds.add((tables[source_table_id],
tables[target_table_id],
columns[col1],
columns[col2]))
all_from_table_ids = list(all_from_table_ids)
try:
tokens = ["from", tables[all_from_table_ids[0]]]
for i, table_id in enumerate(all_from_table_ids[1:]):
tokens += ["join"]
tokens += [tables[table_id]]
tokens += ["on", all_conds[i][0], "=", all_conds[i][1]]
except:
return ' '.join(p_sql)
if where_idx != -1:
p_sql = p_sql[:from_idx] + tokens + p_sql[where_idx:]
elif group_idx != -1:
p_sql = p_sql[:from_idx] + tokens + p_sql[group_idx:]
elif order_idx != -1:
p_sql = p_sql[:from_idx] + tokens + p_sql[order_idx:]
elif len(p_sql[:from_idx] + p_sql[from_idx:]) == len(p_sql):
p_sql = p_sql[:from_idx] + tokens
return ' '.join(p_sql)
def extract_structure_data(plain_text_content: str):
def sort_by_id(data):
data.sort(key=lambda x: int(x.split('\t')[0][2:]))
return data
data = []
original_schemas = load_original_schemas(database_schema_filename)
schemas, eval_foreign_key_maps = load_tables(database_schema_filename)
predict_outputs = sort_by_id(re.findall("^D.+", plain_text_content, RegexFlag.MULTILINE))
ground_outputs = sort_by_id(re.findall("^T.+", plain_text_content, RegexFlag.MULTILINE))
source_inputs = sort_by_id(re.findall("^S.+", plain_text_content, RegexFlag.MULTILINE))
for idx, (predict, ground, source) in enumerate(zip(predict_outputs, ground_outputs, source_inputs)):
predict_id, predict_score, predict_clean = predict.split('\t')
ground_id, ground_clean = ground.split('\t')
source_id, source_clean = source.split('\t')
db_id = sql_to_db[idx]
#try to postprocess the incomplete sql from
# (1) correcting the COLUMN in ON_CLAUSE based on foreign key graph
# (2) adding the underlying TABLE via searching shortest path
predict_clean = post_processing_sql(predict_clean, eval_foreign_key_maps[db_id], original_schemas[db_id],
schemas[db_id])
data.append((predict_id[2:], source_clean.split('<Q>')[-1].strip(), ground_clean, predict_clean, db_id))
return data
def evaluate(data):
def evaluate_example(_predict_str: str, _ground_str: str):
return re.sub("\s+", "", _predict_str.lower()) == re.sub("\s+", "", _ground_str.lower())
correct_num = 0
correct_tag_list = []
total = 0
tmp = []
for example in data:
idx, source_str, ground_str, predict_str, db_id = example
total += 1
try:
sql_match = evaluate_sql(gold=ground_str.replace('@', '.'),
predict=predict_str.replace('@', '.'),
db_name=db_id,
db_dir=database_dir,
table=database_schema_filename)
except:
print(predict_str)
sql_match = False
if (sql_match or evaluate_example(predict_str, ground_str)):
is_correct = True
correct_num += 1
else:
is_correct = False
tmp.append(is_correct)
correct_tag_list.append(is_correct)
print("Correct/Total : {}/{}, {:.4f}".format(correct_num, total, correct_num / total))
return correct_tag_list, correct_num, total
def predict_and_evaluate(model_path, dataset_path, constrain):
if constrain:
data = predict_with_constrain(
model_path=model_path,
dataset_path=dataset_path
)
else:
decode_without_constrain(
model_path=model_path,
dataset_path=dataset_path
)
with open('./eval/generate-valid.txt', "r", encoding="utf8") as generate_f:
file_content = generate_f.read()
data = extract_structure_data(file_content)
correct_arr, correct_num, total = evaluate(data)
with open('./eval/spider_eval.txt', "w", encoding="utf8") as eval_file:
for example, correct in zip(data, correct_arr):
eval_file.write(str(correct) + "\n" + "\n".join(
[example[0], "db: " + example[-1], example[1], "gold: " + example[2], "pred: " + example[3]]) + "\n\n")
return correct_num, total
def get_alias_schema(schemas):
alias_schema = {}
for db in schemas:
schema = schemas[db].orig
collect = []
for i, (t, c) in enumerate(zip(schema['column_types'], schema['column_names_original'])):
if c[0] == -1:
collect.append('*')
else:
column_with_alias = "{0}@{1}".format(schema['table_names_original'][c[0]].lower(), c[1].lower())
collect.append(column_with_alias)
for t in schema['table_names_original']:
collect.append(t.lower())
collect.append("'value'")
alias_schema[db] = collect
return alias_schema
def predict_with_constrain(model_path, dataset_path):
schemas, eval_foreign_key_maps = load_tables(database_schema_filename)
original_schemas = load_original_schemas(database_schema_filename)
with open(f'{dataset_path}/dev.src', 'r', encoding='utf-8') as f:
item = [i.strip() for i in f.readlines()]
with open(f'{dataset_path}/dev.tgt', 'r', encoding='utf-8') as f:
ground = [i.strip() for i in f.readlines()]
alias_schema = get_alias_schema(schemas)
item_db_cluster = defaultdict(list)
ground_db_cluster = defaultdict(list)
source_db_cluster = defaultdict(list)
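    # 1034 = size of the Spider dev set, so all dev examples are clustered below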
num_example = 1034
for db, sentence, g_sql in zip(sql_to_db[:num_example], item[:num_example], ground[:num_example]):
source = sentence.split('<Q>')[-1].strip()
item_db_cluster[db].append(sentence)
ground_db_cluster[db].append(g_sql)
source_db_cluster[db].append(source)
source = []
ground = []
for db, sentence in source_db_cluster.items():
source.extend(sentence)
for db, g_SQL in ground_db_cluster.items():
ground.extend(g_SQL)
model = GENRE.from_pretrained(model_path).eval()
if torch.cuda.is_available():
model.cuda()
result=[]
for db, sentence in item_db_cluster.items():
print(f'processing db: {db} with {len(sentence)} sentences')
rnt=decode_with_constrain(sentence, alias_schema[db], model)
result.extend([i[0]['text'] if isinstance(i[0]['text'], str) else i[0]['text'][0] for i in rnt])
eval_file_path= f'./eval/generate-valid-constrain.txt'
with open(eval_file_path, "w", encoding="utf8") as f:
f.write('\n'.join(result))
# result = []
# with open(f'./eval/generate-valid-constrain.txt', "r", encoding="utf8") as f:
# for idx, (sent, db_id) in enumerate(zip(f.readlines(), sql_to_db)):
# result.append(sent.strip())
data = []
for predict_id, (predict_clean, ground_clean, source_clean, db_id) in enumerate(
zip(result, ground, source, sql_to_db)):
predict_clean = post_processing_sql(predict_clean, eval_foreign_key_maps[db_id], original_schemas[db_id],
schemas[db_id])
data.append((str(predict_id), source_clean.split('<Q>')[-1].strip(), ground_clean, predict_clean, db_id))
return data
def decode_with_constrain(sentences, schema, model):
trie = Trie([
model.encode(" {}".format(e))[1:].tolist()
for e in schema
])
prefix_allowed_tokens_fn = get_prefix_allowed_tokens_fn(
model,
sentences,
mention_trie=trie,
)
return model.sample(
sentences,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
)
def decode_without_constrain( model_path, dataset_path):
cmd = f'fairseq-generate \
--path {model_path}/model.pt {dataset_path}/bin \
--gen-subset valid \
--nbest 1 \
--max-tokens 4096 \
--source-lang src --target-lang tgt \
--results-path ./eval \
--beam 5 \
--bpe gpt2 \
--remove-bpe \
--skip-invalid-size-inputs-valid-test'
subprocess.Popen(
cmd, universal_newlines=True, shell=True,
stdout=subprocess.PIPE, stderr=subprocess.PIPE).communicate()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", default='./models/spider_sl')
parser.add_argument("--dataset_path", default='./dataset_post/spider_sl')
parser.add_argument("--constrain", action='store_true')
args = parser.parse_args()
predict_and_evaluate(model_path=args.model_path,
dataset_path=args.dataset_path,
constrain=args.constrain)

unified_parser_text_to_sql/third_party/__init__.py vendored Normal file (0 lines added)
unified_parser_text_to_sql/third_party/spider/README.md vendored Normal file (232 lines added)
@ -0,0 +1,232 @@
# Spider: A Large-Scale Human-Labeled Dataset for Complex and Cross-Domain Semantic Parsing and Text-to-SQL Task
Spider is a large human-labeled dataset for complex and cross-domain semantic parsing and text-to-SQL task (natural language interfaces for relational databases). It is released along with our EMNLP 2018 paper: [Spider: A Large-Scale Human-Labeled Dataset for Complex and Cross-Domain Semantic Parsing and Text-to-SQL Task](https://arxiv.org/abs/1809.08887). This repo contains all code for evaluation, preprocessing, and all baselines used in our paper. Please refer to [the task site](https://yale-lily.github.io/spider) for a more general introduction and the leaderboard.
### Changelog
- `06/07/2020` We corrected some annotation errors and label mismatches (not errors) in Spider dev and test sets (~4% of dev examples updated, click [here](https://github.com/taoyds/spider/commit/25fcd85d9b6e94acaeb5e9172deadeefeed83f5e#diff-18b0a730a7b0d29b0a78a5070d971d49) for more details). Please download the Spider dataset from [the page](https://yale-lily.github.io/spider) again.
- `01/16/2020` For value prediction (in order to compute the execution accuracy), your model should be able to 1) copy from the question inputs, 2) retrieve from the database content (database content is available), or 3) generate numbers (e.g. 3 in "LIMIT 3").
- `1/14/2019` The submission tutorial is ready! Please follow it to get your results on the unreleased test data.
- `12/17/2018` We updated 7 sqlite database files. Please download the Spider data from the official website again. Please refer to [the issue 14](https://github.com/taoyds/spider/issues/14) for more details.
- `10/25/2018`: the evaluation script has been updated so that the table in `count(*)` cases is evaluated as well. Please check out [issue 5](https://github.com/taoyds/spider/issues/5) for more info. Results of all baselines and [syntaxSQL](https://github.com/taoyds/syntaxSQL) in the paper are updated as well.
- `10/25/2018`: to get the latest SQL parsing results (a few small bugs fixed), please use `preprocess/parse_raw_json.py` to update them. Please refer to [issue 3](https://github.com/taoyds/spider/issues/3) for more details.
### Citation
The dataset was annotated by 11 college students. When you use the Spider dataset, we would appreciate it if you cite the following:
```
@inproceedings{Yu&al.18c,
title = {Spider: A Large-Scale Human-Labeled Dataset for Complex and Cross-Domain Semantic Parsing and Text-to-SQL Task},
author = {Tao Yu and Rui Zhang and Kai Yang and Michihiro Yasunaga and Dongxu Wang and Zifan Li and James Ma and Irene Li and Qingning Yao and Shanelle Roman and Zilin Zhang and Dragomir Radev},
booktitle = "Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing",
address = "Brussels, Belgium",
publisher = "Association for Computational Linguistics",
year = 2018
}
```
### Installation
`evaluation.py` and `process_sql.py` are written in Python 3. Environment setup for each baseline is described in the README under each baseline directory.
### Data Content and Format
#### Question, SQL, and Parsed SQL
Each example in `train.json` and `dev.json` contains the following fields:
- `question`: the natural language question
- `question_toks`: the natural language question tokens
- `db_id`: the database id to which this question is addressed.
- `query`: the SQL query corresponding to the question.
- `query_toks`: the SQL query tokens corresponding to the question.
- `sql`: parsed results of this SQL query using `process_sql.py`. Please refer to `parsed_sql_examples.sql` in the `preprocess` directory for the detailed documentation.
```
{
"db_id": "world_1",
"query": "SELECT avg(LifeExpectancy) FROM country WHERE Name NOT IN (SELECT T1.Name FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T2.Language = \"English\" AND T2.IsOfficial = \"T\")",
"query_toks": ["SELECT", "avg", "(", "LifeExpectancy", ")", "FROM", ...],
"question": "What is average life expectancy in the countries where English is not the official language?",
"question_toks": ["What", "is", "average", "life", ...],
"sql": {
"except": null,
"from": {
"conds": [],
"table_units": [
...
},
"groupBy": [],
"having": [],
"intersect": null,
"limit": null,
"orderBy": [],
"select": [
...
],
"union": null,
"where": [
[
true,
...
{
"except": null,
"from": {
"conds": [
[
false,
2,
[
...
},
"groupBy": [],
"having": [],
"intersect": null,
"limit": null,
"orderBy": [],
"select": [
false,
...
"union": null,
"where": [
[
false,
2,
[
0,
...
}
},
```
#### Tables
`tables.json` contains the following information for each database:
- `db_id`: database id
- `table_names_original`: original table names stored in the database.
- `table_names`: cleaned and normalized table names. We make sure the table names are meaningful. [to be changed]
- `column_names_original`: original column names stored in the database. Each column looks like `[0, "id"]`: `0` is the index of the table in `table_names` (here, `city`), and `"id"` is the column name.
- `column_names`: cleaned and normalized column names. We make sure the column names are meaningful. [to be changed]
- `column_types`: data type of each column
- `foreign_keys`: foreign keys in the database. `[3, 8]` gives the indices of two columns in `column_names`; these two columns form a foreign-key link between two different tables.
- `primary_keys`: primary keys in the database. Each number is an index into `column_names`.
```
{
"column_names": [
[
0,
"id"
],
[
0,
"name"
],
[
0,
"country code"
],
[
0,
"district"
],
.
.
.
],
"column_names_original": [
[
0,
"ID"
],
[
0,
"Name"
],
[
0,
"CountryCode"
],
[
0,
"District"
],
.
.
.
],
"column_types": [
"number",
"text",
"text",
"text",
.
.
.
],
"db_id": "world_1",
"foreign_keys": [
[
3,
8
],
[
23,
8
]
],
"primary_keys": [
1,
8,
23
],
"table_names": [
"city",
"sqlite sequence",
"country",
"country language"
],
"table_names_original": [
"city",
"sqlite_sequence",
"country",
"countrylanguage"
]
}
```
#### Databases
All table contents are contained in corresponding SQLite3 database files.
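For a quick look inside one of these databases, here is a minimal sketch (the database path and `concert_singer` id are assumptions for illustration; it reads table and column names the same way `get_schema()` in `process_sql.py` does):
```
# Minimal sketch: list tables and columns of one Spider SQLite database.
import sqlite3

conn = sqlite3.connect("database/concert_singer/concert_singer.sqlite")  # assumed path
cursor = conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
for (table,) in cursor.fetchall():
    cursor.execute("PRAGMA table_info('{}')".format(table))
    print(table, [col[1] for col in cursor.fetchall()])
conn.close()
```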
### Evaluation
Our evaluation metrics include Component Matching, Exact Matching, and Execution Accuracy. For component and exact matching evaluation, instead of simply performing string comparison between the predicted and gold SQL queries, we decompose each SQL query into several clauses and conduct a set comparison within each clause (see the sketch after the usage block below).
For Execution Accuracy, our current models do not predict any values in SQL conditions, so we do not provide execution accuracies. However, we encourage you to report it in future submissions. For value prediction, you can assume that a list of gold values for each question is given; your model has to fill them into the right slots in the SQL.
Please refer to [our paper]() and [this page](https://github.com/taoyds/spider/tree/master/evaluation) for more details and examples.
```
python evaluation.py --gold [gold file] --pred [predicted file] --etype [evaluation type] --db [database dir] --table [table file]
arguments:
[gold file] gold.sql file where each line is `a gold SQL \t db_id`
[predicted file] predicted sql file where each line is a predicted SQL
[evaluation type] "match" for exact set matching score, "exec" for execution score, and "all" for both
[database dir] directory which contains sub-directories where each SQLite3 database is stored
[table file] table.json file which includes foreign key info of each database
```
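The clause-level set comparison mentioned above can be pictured with the minimal sketch below. It is illustrative only, not the official `evaluation.py` (whose diff is omitted above); the helper name is made up for this example:
```
# Illustrative only: component matching treats each clause as a set,
# so column order inside a clause does not matter.
def clause_set_match(pred_clause, gold_clause):
    """True if two clauses contain the same components, ignoring order."""
    # repr() makes nested (possibly unhashable) parsed-SQL components comparable
    return set(map(repr, pred_clause)) == set(map(repr, gold_clause))

# e.g. "SELECT name , age" vs. "SELECT age , name" count as a clause match
pred_select = [(0, "name"), (0, "age")]
gold_select = [(0, "age"), (0, "name")]
assert clause_set_match(pred_select, gold_select)
```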
### FAQ

unified_parser_text_to_sql/third_party/spider/evaluation.py vendored Normal file (1119 lines added)
Diff not shown because of its large size.

@ -0,0 +1,15 @@
## Data Preprocess
#### Get Parsed SQL Output
The SQL parsing script is `process_sql.py` in the main directory. Please refer to `parsed_sql_examples.sql` for the explanation of some parsed SQL output examples.
If you would like to use `process_sql.py` to parse SQL queries yourself, `parse_sql_one.py` provides an example of how the script is called, as shown in the sketch below. Alternatively, you can use `parse_raw_json.py` to update all parsed SQL results (the value of `sql`) in `train.json` and `dev.json`.
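A minimal standalone call is sketched here; it mirrors `preprocess/parse_sql_one.py` further down in this commit, and the import paths and file locations are assumptions (the script is assumed to run from the `spider/` directory with `tables.json` present):
```
# Minimal sketch mirroring preprocess/parse_sql_one.py (paths are assumptions):
# parse a single SQL query into the Spider parsed-SQL dictionary.
from process_sql import get_sql
from preprocess.schema import Schema, get_schemas_from_json

schemas, db_names, tables = get_schemas_from_json("tables.json")
db_id = "concert_singer"
schema = Schema(schemas[db_id], tables[db_id])
sql = "SELECT name , country , age FROM singer ORDER BY age DESC"
sql_label = get_sql(schema, sql)
print(sql_label["orderBy"])  # ('desc', [...]) in the parsed-SQL format
```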
#### Get Table Info from Database
To generate the final `tables.json` file, run the script below. It reads the SQLite files from the `database/` directory and the previous `tables.json` with hand-corrected names:
```
python process/get_tables.py [dir includes many subdirs containing database.sqlite files] [output file name e.g. output.json] [existing tables.json file to be inherited]
```


@ -0,0 +1,162 @@
import os
import sys
import json
import sqlite3
from os import listdir, makedirs
from os.path import isfile, isdir, join, split, exists, splitext
from nltk import word_tokenize, tokenize
import traceback
EXIST = {"atis", "geo", "advising", "yelp", "restaurants", "imdb", "academic"}
def convert_fk_index(data):
fk_holder = []
for fk in data["foreign_keys"]:
tn, col, ref_tn, ref_col = fk[0][0], fk[0][1], fk[1][0], fk[1][1]
ref_cid, cid = None, None
try:
tid = data["table_names_original"].index(tn)
ref_tid = data["table_names_original"].index(ref_tn)
for i, (tab_id, col_org) in enumerate(data["column_names_original"]):
if tab_id == ref_tid and ref_col == col_org:
ref_cid = i
elif tid == tab_id and col == col_org:
cid = i
if ref_cid and cid:
fk_holder.append([cid, ref_cid])
except:
traceback.print_exc()
print("table_names_original: ", data["table_names_original"])
print("finding tab name: ", tn, ref_tn)
sys.exit()
return fk_holder
def dump_db_json_schema(db, f):
"""read table and column info"""
conn = sqlite3.connect(db)
conn.execute("pragma foreign_keys=ON")
cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table';")
data = {
"db_id": f,
"table_names_original": [],
"table_names": [],
"column_names_original": [(-1, "*")],
"column_names": [(-1, "*")],
"column_types": ["text"],
"primary_keys": [],
"foreign_keys": [],
}
fk_holder = []
for i, item in enumerate(cursor.fetchall()):
table_name = item[0]
data["table_names_original"].append(table_name)
data["table_names"].append(table_name.lower().replace("_", " "))
fks = conn.execute(
"PRAGMA foreign_key_list('{}') ".format(table_name)
).fetchall()
# print("db:{} table:{} fks:{}".format(f,table_name,fks))
fk_holder.extend([[(table_name, fk[3]), (fk[2], fk[4])] for fk in fks])
cur = conn.execute("PRAGMA table_info('{}') ".format(table_name))
for j, col in enumerate(cur.fetchall()):
data["column_names_original"].append((i, col[1]))
data["column_names"].append((i, col[1].lower().replace("_", " ")))
# varchar, '' -> text, int, numeric -> integer,
col_type = col[2].lower()
if (
"char" in col_type
or col_type == ""
or "text" in col_type
or "var" in col_type
):
data["column_types"].append("text")
elif (
"int" in col_type
or "numeric" in col_type
or "decimal" in col_type
or "number" in col_type
or "id" in col_type
or "real" in col_type
or "double" in col_type
or "float" in col_type
):
data["column_types"].append("number")
elif "date" in col_type or "time" in col_type or "year" in col_type:
data["column_types"].append("time")
elif "boolean" in col_type:
data["column_types"].append("boolean")
else:
data["column_types"].append("others")
if col[5] == 1:
data["primary_keys"].append(len(data["column_names"]) - 1)
data["foreign_keys"] = fk_holder
data["foreign_keys"] = convert_fk_index(data)
return data
if __name__ == "__main__":
if len(sys.argv) < 2:
print(
"Usage: python get_tables.py [dir includes many subdirs containing database.sqlite files] [output file name e.g. output.json] [existing tables.json file to be inherited]"
)
sys.exit()
input_dir = sys.argv[1]
output_file = sys.argv[2]
ex_tab_file = sys.argv[3]
all_fs = [
df for df in listdir(input_dir) if exists(join(input_dir, df, df + ".sqlite"))
]
with open(ex_tab_file) as f:
ex_tabs = json.load(f)
# for tab in ex_tabs:
# tab["foreign_keys"] = convert_fk_index(tab)
ex_tabs = {tab["db_id"]: tab for tab in ex_tabs if tab["db_id"] in all_fs}
print("precessed file num: ", len(ex_tabs))
not_fs = [
df
for df in listdir(input_dir)
if not exists(join(input_dir, df, df + ".sqlite"))
]
for d in not_fs:
print("no sqlite file found in: ", d)
db_files = [
(df + ".sqlite", df)
for df in listdir(input_dir)
if exists(join(input_dir, df, df + ".sqlite"))
]
tables = []
for f, df in db_files:
# if df in ex_tabs.keys():
# print 'reading old db: ', df
# tables.append(ex_tabs[df])
db = join(input_dir, df, f)
print("\nreading new db: ", df)
table = dump_db_json_schema(db, df)
        # guard: a database may have no entry in the inherited tables.json
        prev_tab_num = len(ex_tabs[df]["table_names"]) if df in ex_tabs else 0
        prev_col_num = len(ex_tabs[df]["column_names"]) if df in ex_tabs else 0
cur_tab_num = len(table["table_names"])
cur_col_num = len(table["column_names"])
if (
df in ex_tabs.keys()
and prev_tab_num == cur_tab_num
and prev_col_num == cur_col_num
and prev_tab_num != 0
and len(ex_tabs[df]["column_names"]) > 1
):
table["table_names"] = ex_tabs[df]["table_names"]
table["column_names"] = ex_tabs[df]["column_names"]
else:
print("\n----------------------------------problem db: ", df)
tables.append(table)
print("final db num: ", len(tables))
with open(output_file, "wt") as out:
json.dump(tables, out, sort_keys=True, indent=2, separators=(",", ": "))

@ -0,0 +1,45 @@
import os, sys
import json
import sqlite3
import traceback
import argparse
import tqdm
from ..process_sql import get_sql
from .schema import Schema, get_schemas_from_json
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--input", required=True)
parser.add_argument("--tables", required=True)
parser.add_argument("--output", required=True)
args = parser.parse_args()
sql_path = args.input
output_file = args.output
table_file = args.tables
schemas, db_names, tables = get_schemas_from_json(table_file)
with open(sql_path) as inf:
sql_data = json.load(inf)
sql_data_new = []
for data in tqdm.tqdm(sql_data):
try:
db_id = data["db_id"]
schema = schemas[db_id]
table = tables[db_id]
schema = Schema(schema, table)
sql = data["query"]
sql_label = get_sql(schema, sql)
data["sql"] = sql_label
sql_data_new.append(data)
except:
print("db_id: ", db_id)
print("sql: ", sql)
raise
with open(output_file, "wt") as out:
json.dump(sql_data_new, out, sort_keys=True, indent=4, separators=(",", ": "))

@ -0,0 +1,27 @@
import os
import traceback
import re
import sys
import json
import sqlite3
import random
from os import listdir, makedirs
from collections import OrderedDict
from nltk import word_tokenize, tokenize
from os.path import isfile, isdir, join, split, exists, splitext
from ..process_sql import get_sql
from .schema import Schema, get_schemas_from_json
if __name__ == "__main__":
sql = "SELECT name , country , age FROM singer ORDER BY age DESC"
db_id = "concert_singer"
table_file = "tables.json"
schemas, db_names, tables = get_schemas_from_json(table_file)
schema = schemas[db_id]
table = tables[db_id]
schema = Schema(schema, table)
sql_label = get_sql(schema, sql)

@ -0,0 +1,74 @@
################################
# Assumptions:
# 1. sql is correct
# 2. only table name has alias
# 3. only one intersect/union/except
#
# val: number(float)/string(str)/sql(dict)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, col_unit/sql)
# cond_unit: (not_op, op_id, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
# 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
# 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
# 'where': condition
# 'groupBy': [col_unit1, col_unit2, ...]
# 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
# 'having': condition
# 'limit': None/limit value
# 'intersect': None/sql
# 'except': None/sql
# 'union': None/sql
# }
################################
CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
JOIN_KEYWORDS = ('join', 'on', 'as')
WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
UNIT_OPS = ('none', '-', '+', "*", '/')
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
TABLE_TYPE = {
'sql': "sql",
'table_unit': "table_unit",
}
COND_OPS = ('and', 'or')
SQL_OPS = ('intersect', 'union', 'except')
ORDER_OPS = ('desc', 'asc')
############ Example 1 ############
SELECT T1.lname , T1.fname
FROM artists AS T1 JOIN paintings AS T2 ON T1.artistID = T2.painterID
EXCEPT
SELECT T1.lname , T1.fname
FROM artists AS T1 JOIN sculptures AS T2 ON T1.artistID = T2.sculptorID
'select': (False #is distinct#, [(0 #index of AGG_OPS#, (0 #index of unit_op to save col1-col2 cases#, (0#index of AGG_OPS#, '__artists.lname__' #index of column names#, False #is DISTINCT#), None #col_unit2 usually is None, for saving col1-col2#)), (0, (0, (0, '__artists.fname__', False), None))])
'from': {'table_units' #list of tables in from#: [('table_unit' #some froms are nested sql#, '__artists__' #gonna be index of table#), ('table_unit', '__paintings__')],
'conds': [(False #if there is NOT#, 2 #index of WHERE_OPS#, (0 #index of unit_op to save col1-col2 cases#, (0 #index of AGG_OPS#, '__artists.artistid__' #index of column names#, False #is DISTINCT#), None #col_unit2 usually is None, for saving col1-col2#), 't2.painterid' #val1, here is t2 ref id#, None #val2 for between val1 and val2#]}
'except': {'from': {'table_units': [('table_unit', '__artists__'), ('table_unit', '__sculptures__')], 'conds': [(False, 2, (0, (0, '__artists.artistid__', False), None), (0, '__sculptures.sculptorid__', False), None)]}, 'select': (False, [(0, (0, (0, '__artists.lname__', False), None)), (0, (0, (0, '__artists.fname__', False), None))])}
############ Example 2 ############
SELECT paintingID
FROM paintings
WHERE height_mm > (SELECT max(height_mm)
FROM paintings
WHERE YEAR > 1900)
'select': (False, [(0, (0, (0, '__paintings.paintingid__', False), None))])
'from': {'table_units': [('table_unit', '__paintings__')], 'conds': []}
'where': [(False, 3, (0, (0, '__paintings.height_mm__', False), None) #finshed val_unit1#, #start val1 which is a sql# {'from': {'table_units': [('table_unit', '__paintings__')], 'conds': []}, 'where': [(False, 3, (0, (0, '__paintings.year__', False), None), 1900.0 #cond val1#, None #cond val2#)], 'select': (False, [(1, (0, (0, '__paintings.height_mm__', False), None))])}, None)]
############ Example 3 ############
ORDER BY count(*) DESC LIMIT 1
'orderBy': ('desc', [(0 #index of unit_op no -/+#, (3 #agg count index#, '__all__', False), None)])
############ Example 4 ############
GROUP BY T2.painterID HAVING count(*) >= 2
'groupBy': [(0 #index of AGG_OPS#, '__paintings.painterid__', False #is distinct#)], 'having': [(False, 5, (0, (3, '__all__', False), None), 2.0, None)]

@ -0,0 +1,80 @@
import json
class Schema:
"""
Simple schema which maps table&column to a unique identifier
"""
def __init__(self, schema, table):
self._schema = schema
self._table = table
self._idMap = self._map(self._schema, self._table)
@property
def schema(self):
return self._schema
@property
def idMap(self):
return self._idMap
def _map(self, schema, table):
if 'column_names_original' not in table: # {'table': [col.lower, ..., ]} * -> __all__
table["column_names_original"] = table["column_names"]
table["table_names_original"] = table["table_names"]
column_names_original = table["column_names_original"]
table_names_original = table["table_names_original"]
# print 'column_names_original: ', column_names_original
# print 'table_names_original: ', table_names_original
for i, (tab_id, col) in enumerate(column_names_original):
if tab_id == -1:
idMap = {"*": i}
else:
key = table_names_original[tab_id].lower()
val = col.lower()
idMap[key + "." + val] = i
for i, tab in enumerate(table_names_original):
key = tab.lower()
idMap[key] = i
return idMap
def _get_schemas_from_json(data: dict):
db_names = [db["db_id"] for db in data]
tables = {}
schemas = {}
for db in data:
db_id = db["db_id"]
schema = {}
if 'column_names_original' not in db:
db["column_names_original"] = db["column_names"]
db["table_names_original"] = db["table_names"]
column_names_original = db["column_names_original"]
table_names_original = db["table_names_original"]
tables[db_id] = {
"column_names_original": column_names_original,
"table_names_original": table_names_original,
}
for i, tabn in enumerate(table_names_original):
table = str(tabn.lower())
cols = [str(col.lower()) for td, col in column_names_original if td == i]
schema[table] = cols
schemas[db_id] = schema
return schemas, db_names, tables
def get_schemas_from_json(fpath):
with open(fpath, 'r',encoding='UTF-8') as f:
data = json.load(f)
return _get_schemas_from_json(data)

unified_parser_text_to_sql/third_party/spider/process_sql.py vendored Normal file (659 lines added)

@ -0,0 +1,659 @@
################################
# Assumptions:
# 1. sql is correct
# 2. only table name has alias
# 3. only one intersect/union/except
#
# val: number(float)/string(str)/sql(dict)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, col_unit/sql)
# cond_unit: (not_op, op_id, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
# 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
# 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
# 'where': condition
# 'groupBy': [col_unit1, col_unit2, ...]
# 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
# 'having': condition
# 'limit': None/limit value
# 'intersect': None/sql
# 'except': None/sql
# 'union': None/sql
# }
################################
import json
import sqlite3
from nltk import word_tokenize
import copy
CLAUSE_KEYWORDS = (
"select",
"from",
"where",
"group",
"order",
"limit",
"intersect",
"union",
"except",
)
JOIN_KEYWORDS = ("join", "on", "as")
WHERE_OPS = (
"not",
"between",
"=",
">",
"<",
">=",
"<=",
"!=",
"in",
"like",
"is",
"exists",
)
UNIT_OPS = ("none", "-", "+", "*", "/")
AGG_OPS = ("none", "max", "min", "count", "sum", "avg")
TABLE_TYPE = {
"sql": "sql",
"table_unit": "table_unit",
}
COND_OPS = ("and", "or")
SQL_OPS = ("intersect", "union", "except")
ORDER_OPS = ("desc", "asc")
global_mentioned_schema = []
class Schema:
"""
Simple schema which maps table&column to a unique identifier
"""
def __init__(self, schema):
self._schema = schema
self._idMap = self._map(self._schema)
@property
def schema(self):
return self._schema
@property
def idMap(self):
return self._idMap
def _map(self, schema):
idMap = {"*": "__all__"}
id = 1
for key, vals in schema.items():
for val in vals:
idMap[key.lower() + "." + val.lower()] = (
"__" + key.lower() + "." + val.lower() + "__"
)
id += 1
for key in schema:
idMap[key.lower()] = "__" + key.lower() + "__"
id += 1
return idMap
def get_schema(db):
"""
Get database's schema, which is a dict with table name as key
and list of column names as value
:param db: database path
:return: schema dict
"""
schema = {}
conn = sqlite3.connect(db)
cursor = conn.cursor()
# fetch table names
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
tables = [str(table[0].lower()) for table in cursor.fetchall()]
# fetch table info
for table in tables:
cursor.execute("PRAGMA table_info({})".format(table))
schema[table] = [str(col[1].lower()) for col in cursor.fetchall()]
return schema
def get_schema_from_json(fpath):
with open(fpath) as f:
data = json.load(f)
schema = {}
for entry in data:
table = str(entry["table"].lower())
cols = [str(col["column_name"].lower()) for col in entry["col_data"]]
schema[table] = cols
return schema
def tokenize(string):
string = str(string)
string = string.replace(
"'", '"'
) # ensures all string values wrapped by "" problem??
quote_idxs = [idx for idx, char in enumerate(string) if char == '"']
assert len(quote_idxs) % 2 == 0, "Unexpected quote"
# keep string value as token
vals = {}
for i in range(len(quote_idxs) - 1, -1, -2):
qidx1 = quote_idxs[i - 1]
qidx2 = quote_idxs[i]
val = string[qidx1 : qidx2 + 1]
key = "__val_{}_{}__".format(qidx1, qidx2)
string = string[:qidx1] + key + string[qidx2 + 1 :]
vals[key] = val
toks = [word.lower() for word in word_tokenize(string)]
# replace with string value token
for i in range(len(toks)):
if toks[i] in vals:
toks[i] = vals[toks[i]]
# find if there exists !=, >=, <=
eq_idxs = [idx for idx, tok in enumerate(toks) if tok == "="]
eq_idxs.reverse()
prefix = ("!", ">", "<")
for eq_idx in eq_idxs:
pre_tok = toks[eq_idx - 1]
if pre_tok in prefix:
toks = toks[: eq_idx - 1] + [pre_tok + "="] + toks[eq_idx + 1 :]
return toks
def scan_alias(toks):
"""Scan the index of 'as' and build the map for all alias"""
as_idxs = [idx for idx, tok in enumerate(toks) if tok == "as"]
alias = {}
for idx in as_idxs:
alias[toks[idx + 1]] = toks[idx - 1]
return alias
def get_tables_with_alias(schema, toks):
tables = scan_alias(toks)
for key in schema:
assert key not in tables, "Alias {} has the same name in table".format(key)
tables[key] = key
return tables
def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None):
"""
:returns next idx, column id
"""
tok = toks[start_idx]
if tok == "*":
global_mentioned_schema.append(('column', schema.idMap[tok]))
return start_idx + 1, schema.idMap[tok]
if "." in tok: # if token is a composite
alias, col = tok.split(".")
key = tables_with_alias[alias] + "." + col
global_mentioned_schema.append(('column', schema.idMap[key]))
return start_idx + 1, schema.idMap[key]
assert (
default_tables is not None and len(default_tables) > 0
), "Default tables should not be None or empty"
for alias in default_tables:
table = tables_with_alias[alias]
if tok in schema.schema[table]:
key = table + "." + tok
global_mentioned_schema.append(('column', schema.idMap[key]))
return start_idx + 1, schema.idMap[key]
assert False, "Error col: {}".format(tok)
def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
"""
:returns next idx, (agg_op id, col_id)
"""
idx = start_idx
len_ = len(toks)
isBlock = False
isDistinct = False
if toks[idx] == "(":
isBlock = True
idx += 1
if toks[idx] in AGG_OPS:
agg_id = AGG_OPS.index(toks[idx])
idx += 1
assert idx < len_ and toks[idx] == "("
idx += 1
if toks[idx] == "distinct":
idx += 1
isDistinct = True
idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables)
assert idx < len_ and toks[idx] == ")"
idx += 1
return idx, (agg_id, col_id, isDistinct)
if toks[idx] == "distinct":
idx += 1
isDistinct = True
agg_id = AGG_OPS.index("none")
idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables)
if isBlock:
assert toks[idx] == ")"
idx += 1 # skip ')'
return idx, (agg_id, col_id, isDistinct)
def parse_val_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
idx = start_idx
len_ = len(toks)
isBlock = False
if toks[idx] == "(":
isBlock = True
idx += 1
col_unit1 = None
col_unit2 = None
unit_op = UNIT_OPS.index("none")
idx, col_unit1 = parse_col_unit(
toks, idx, tables_with_alias, schema, default_tables
)
if idx < len_ and toks[idx] in UNIT_OPS:
unit_op = UNIT_OPS.index(toks[idx])
idx += 1
idx, col_unit2 = parse_col_unit(
toks, idx, tables_with_alias, schema, default_tables
)
if isBlock:
assert toks[idx] == ")"
idx += 1 # skip ')'
return idx, (unit_op, col_unit1, col_unit2)
def parse_table_unit(toks, start_idx, tables_with_alias, schema):
"""
:returns next idx, table id, table name
"""
idx = start_idx
len_ = len(toks)
key = tables_with_alias[toks[idx]]
if idx + 1 < len_ and toks[idx + 1] == "as":
idx += 3
else:
idx += 1
global_mentioned_schema.append(('table', schema.idMap[key]))
return idx, schema.idMap[key], key
def parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None):
idx = start_idx
len_ = len(toks)
isBlock = False
if toks[idx] == "(":
isBlock = True
idx += 1
if toks[idx] == "select":
idx, val = parse_sql(toks, idx, tables_with_alias, schema)
elif '"' in toks[idx]: # token is a string value
val = toks[idx]
idx += 1
else:
try:
val = float(toks[idx])
idx += 1
except:
end_idx = idx
while (
end_idx < len_
and toks[end_idx] != ","
and toks[end_idx] != ")"
and toks[end_idx] != "and"
and toks[end_idx] not in CLAUSE_KEYWORDS
and toks[end_idx] not in JOIN_KEYWORDS
):
end_idx += 1
idx, val = parse_col_unit(
toks[start_idx:end_idx], 0, tables_with_alias, schema, default_tables
)
idx = end_idx
if isBlock:
assert toks[idx] == ")"
idx += 1
return idx, val
def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=None):
idx = start_idx
len_ = len(toks)
conds = []
while idx < len_:
idx, val_unit = parse_val_unit(
toks, idx, tables_with_alias, schema, default_tables
)
not_op = False
if toks[idx] == "not":
not_op = True
idx += 1
assert (
idx < len_ and toks[idx] in WHERE_OPS
), "Error condition: idx: {}, tok: {}".format(idx, toks[idx])
op_id = WHERE_OPS.index(toks[idx])
idx += 1
val1 = val2 = None
if op_id == WHERE_OPS.index(
"between"
): # between..and... special case: dual values
idx, val1 = parse_value(
toks, idx, tables_with_alias, schema, default_tables
)
assert toks[idx] == "and"
idx += 1
idx, val2 = parse_value(
toks, idx, tables_with_alias, schema, default_tables
)
else: # normal case: single value
idx, val1 = parse_value(
toks, idx, tables_with_alias, schema, default_tables
)
val2 = None
conds.append((not_op, op_id, val_unit, val1, val2))
if idx < len_ and (
toks[idx] in CLAUSE_KEYWORDS
or toks[idx] in (")", ";")
or toks[idx] in JOIN_KEYWORDS
):
break
if idx < len_ and toks[idx] in COND_OPS:
conds.append(toks[idx])
idx += 1 # skip and/or
return idx, conds
def parse_select(toks, start_idx, tables_with_alias, schema, default_tables=None):
idx = start_idx
len_ = len(toks)
assert toks[idx] == "select", "'select' not found"
idx += 1
isDistinct = False
if idx < len_ and toks[idx] == "distinct":
idx += 1
isDistinct = True
val_units = []
while idx < len_ and toks[idx] not in CLAUSE_KEYWORDS:
agg_id = AGG_OPS.index("none")
if toks[idx] in AGG_OPS:
agg_id = AGG_OPS.index(toks[idx])
idx += 1
idx, val_unit = parse_val_unit(
toks, idx, tables_with_alias, schema, default_tables
)
val_units.append((agg_id, val_unit))
if idx < len_ and toks[idx] == ",":
idx += 1 # skip ','
return idx, (isDistinct, val_units)
def parse_from(toks, start_idx, tables_with_alias, schema):
"""
Assume in the from clause, all table units are combined with join
"""
assert "from" in toks[start_idx:], "'from' not found"
len_ = len(toks)
idx = toks.index("from", start_idx) + 1
default_tables = []
table_units = []
conds = []
while idx < len_:
isBlock = False
if toks[idx] == "(":
isBlock = True
idx += 1
if toks[idx] == "select":
idx, sql = parse_sql(toks, idx, tables_with_alias, schema)
table_units.append((TABLE_TYPE["sql"], sql))
else:
if idx < len_ and toks[idx] == "join":
idx += 1 # skip join
idx, table_unit, table_name = parse_table_unit(
toks, idx, tables_with_alias, schema
)
table_units.append((TABLE_TYPE["table_unit"], table_unit))
default_tables.append(table_name)
if idx < len_ and toks[idx] == "on":
idx += 1 # skip on
idx, this_conds = parse_condition(
toks, idx, tables_with_alias, schema, default_tables
)
if len(conds) > 0:
conds.append("and")
conds.extend(this_conds)
if isBlock:
assert toks[idx] == ")"
idx += 1
if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
break
return idx, table_units, conds, default_tables
def parse_where(toks, start_idx, tables_with_alias, schema, default_tables):
idx = start_idx
len_ = len(toks)
if idx >= len_ or toks[idx] != "where":
return idx, []
idx += 1
idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
return idx, conds
def parse_group_by(toks, start_idx, tables_with_alias, schema, default_tables):
idx = start_idx
len_ = len(toks)
col_units = []
if idx >= len_ or toks[idx] != "group":
return idx, col_units
idx += 1
assert toks[idx] == "by"
idx += 1
while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
idx, col_unit = parse_col_unit(
toks, idx, tables_with_alias, schema, default_tables
)
col_units.append(col_unit)
if idx < len_ and toks[idx] == ",":
idx += 1 # skip ','
else:
break
return idx, col_units
def parse_order_by(toks, start_idx, tables_with_alias, schema, default_tables):
idx = start_idx
len_ = len(toks)
val_units = []
order_type = "asc" # default type is 'asc'
if idx >= len_ or toks[idx] != "order":
return idx, val_units
idx += 1
assert toks[idx] == "by"
idx += 1
while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
idx, val_unit = parse_val_unit(
toks, idx, tables_with_alias, schema, default_tables
)
val_units.append(val_unit)
if idx < len_ and toks[idx] in ORDER_OPS:
order_type = toks[idx]
idx += 1
if idx < len_ and toks[idx] == ",":
idx += 1 # skip ','
else:
break
return idx, (order_type, val_units)
def parse_having(toks, start_idx, tables_with_alias, schema, default_tables):
idx = start_idx
len_ = len(toks)
if idx >= len_ or toks[idx] != "having":
return idx, []
idx += 1
idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
return idx, conds
def parse_limit(toks, start_idx):
idx = start_idx
len_ = len(toks)
if idx < len_ and toks[idx] == "limit":
idx += 2
return idx, int(toks[idx - 1])
return idx, None
def parse_sql(toks, start_idx, tables_with_alias, schema, require_mentioned_schema=False):
global global_mentioned_schema
global_mentioned_schema = []
isBlock = False # indicate whether this is a block of sql/sub-sql
len_ = len(toks)
idx = start_idx
sql = {}
if toks[idx] == "(":
isBlock = True
idx += 1
# parse from clause in order to get default tables
from_end_idx, table_units, conds, default_tables = parse_from(
toks, start_idx, tables_with_alias, schema
)
sql["from"] = {"table_units": table_units, "conds": conds}
# select clause
_, select_col_units = parse_select(
toks, idx, tables_with_alias, schema, default_tables
)
idx = from_end_idx
sql["select"] = select_col_units
# where clause
idx, where_conds = parse_where(toks, idx, tables_with_alias, schema, default_tables)
sql["where"] = where_conds
# group by clause
idx, group_col_units = parse_group_by(
toks, idx, tables_with_alias, schema, default_tables
)
sql["groupBy"] = group_col_units
# having clause
idx, having_conds = parse_having(
toks, idx, tables_with_alias, schema, default_tables
)
sql["having"] = having_conds
# order by clause
idx, order_col_units = parse_order_by(
toks, idx, tables_with_alias, schema, default_tables
)
sql["orderBy"] = order_col_units
# limit clause
idx, limit_val = parse_limit(toks, idx)
sql["limit"] = limit_val
idx = skip_semicolon(toks, idx)
if isBlock:
assert toks[idx] == ")"
idx += 1 # skip ')'
idx = skip_semicolon(toks, idx)
# intersect/union/except clause
for op in SQL_OPS: # initialize IUE
sql[op] = None
if idx < len_ and toks[idx] in SQL_OPS:
sql_op = toks[idx]
idx += 1
idx, IUE_sql = parse_sql(toks, idx, tables_with_alias, schema)
sql[sql_op] = IUE_sql
if require_mentioned_schema:
return idx, sql, global_mentioned_schema
else:
return idx, sql
def load_data(fpath):
with open(fpath) as f:
data = json.load(f)
return data
def get_sql(schema, query):
toks = tokenize(query)
tables_with_alias = get_tables_with_alias(schema.schema, toks)
_, sql = parse_sql(toks, 0, tables_with_alias, schema)
return sql
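# A minimal usage sketch for get_sql (assuming the Schema/get_schema helpers of
# Spider's process_sql.py, on which these parsers appear to be based; database
# and table names below are illustrative only):
#
#   schema = Schema(get_schema("./data/spider/database/student_1/student_1.sqlite"))
#   sql_dict = get_sql(schema, "select count(*) from student")
#   # sql_dict follows the Spider SQL dictionary format, with keys such as
#   # "select", "from", "where", "groupBy", "having", "orderBy" and "limit".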
def extract_mentioned_schema_in_sql(schema, query):
toks = tokenize(query)
tables_with_alias = get_tables_with_alias(schema.schema, toks)
    _, sql, mentioned_schema = parse_sql(toks, 0, tables_with_alias, schema, require_mentioned_schema=True)
return mentioned_schema
def skip_semicolon(toks, start_idx):
idx = start_idx
while idx < len(toks) and toks[idx] == ";":
idx += 1
return idx

@ -0,0 +1,74 @@
import subprocess
import argparse
import os
def run_command(bash_command):
    # Run the command and wait for it to finish. No pipes are attached, so the
    # child's stdout/stderr go to the console and communicate() returns (None, None).
    process = subprocess.Popen(bash_command.split())
    output, error = process.communicate()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_path", type=str, default="", help="dataset path")
parser.add_argument("--exp_name", type=str, default="", help="test")
parser.add_argument("--models_path", type=str, default="", help="models path")
parser.add_argument("--bart_model_path", type=str, default="", help="bart init models path")
parser.add_argument("--total_num_update", type=int, default=200000)
parser.add_argument("--max_tokens", type=int, default=4096)
parser.add_argument("--tensorboard_path", type=str, default="", help="tensorboard path")
args = parser.parse_args()
print("START training")
run_command("printenv")
restore_file = os.path.join(args.bart_model_path, "model.pt")
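    # The recipe below mirrors the standard fairseq BART fine-tuning setup:
    # label-smoothed cross entropy, Adam with a polynomial-decay LR schedule and
    # 5k warmup updates, and --restore-file pointing at the pretrained BART
    # checkpoint with optimizer/meter/dataloader state reset so fine-tuning
    # starts fresh from the pretrained weights.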
cmd = f"""
fairseq-train {args.dataset_path} \
--save-dir {args.models_path}/{args.exp_name} \
--restore-file {restore_file} \
--arch bart_large \
--criterion label_smoothed_cross_entropy \
--source-lang src \
--target-lang tgt \
--truncate-source \
--label-smoothing 0.1 \
--max-tokens {args.max_tokens} \
--update-freq 4 \
--max-update {args.total_num_update} \
--required-batch-size-multiple 1 \
--dropout 0.1 \
--attention-dropout 0.1 \
--relu-dropout 0.0 \
--weight-decay 0.05 \
--optimizer adam \
--adam-eps 1e-08 \
--clip-norm 0.1 \
--lr-scheduler polynomial_decay \
--lr 1e-05 \
--total-num-update {args.total_num_update} \
--warmup-updates 5000 \
--ddp-backend no_c10d \
--num-workers 20 \
--reset-meters \
--reset-optimizer \
--reset-dataloader \
--share-all-embeddings \
--layernorm-embedding \
--share-decoder-input-output-embed \
--skip-invalid-size-inputs-valid-test \
--log-format json \
--log-interval 10 \
--save-interval-updates 500 \
--validate-interval-updates 500 \
--validate-interval 10 \
--save-interval 10 \
--patience 200 \
--no-last-checkpoints \
--no-save-optimizer-state \
--report-accuracy
"""
print("RUN {}".format(cmd))
run_command(cmd)
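# Example invocation (a sketch; the script name and all paths are placeholders,
# since they are not fixed by this file):
#
#   python train.py \
#       --dataset_path <binarized-fairseq-data-dir> \
#       --models_path <output-models-dir> \
#       --exp_name <experiment-name> \
#       --bart_model_path <pretrained-bart-dir-containing-model.pt> \
#       --total_num_update 200000 \
#       --max_tokens 4096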

@ -0,0 +1,110 @@
"""
Based on https://github.com/ElementAI/Unisar/blob/master/Unisar/api.py
"""
import os
import subprocess
from typing import Optional
import torch
from genre.fairseq_model import GENRE
from semparse.contexts.spider_db_context import SpiderDBContext
from semparse.sql.spider import load_original_schemas, load_tables
from semparse.sql.spider_utils import read_dataset_schema
from step1_schema_linking import read_database_schema
from step2_serialization import build_schema_linking_data
from step3_evaluate import decode_with_constrain, get_alias_schema, post_processing_sql
def convert_csv_to_sqlite(csv_path: str):
# TODO: infer types when importing
db_path = csv_path + ".sqlite"
if os.path.exists(db_path):
os.remove(db_path)
subprocess.run(["sqlite3", db_path, ".mode csv", f".import {csv_path} Data"])
return db_path
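# Usage sketch: convert_csv_to_sqlite("my_table.csv") writes my_table.csv.sqlite,
# importing the CSV into a single table named "Data". It relies on the sqlite3
# command-line tool being available on PATH.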
class UnisarAPI(object):
"""Run Unisar model on a given database."""
def __init__(self, log_dir: str, db_path: str, schema_path: Optional[str], stanza_model):
self.log_dir = log_dir
self.db_path = db_path
self.schema_path = schema_path
self.stanza_model = stanza_model
# if self.db_path.endswith(".sqlite"):
# pass
# elif self.db_path.endswith(".csv"):
# self.db_path = convert_csv_to_sqlite(self.db_path)
# else:
# raise ValueError("expected either .sqlite or .csv file")
self.schema = read_dataset_schema(self.schema_path, stanza_model)
_, _, self.database_schemas = read_database_schema(self.schema_path)
self.model = GENRE.from_pretrained(self.log_dir).eval()
if torch.cuda.is_available():
self.model.cuda()
def infer_query(self, question, db_id):
        ### step 1: schema linking
lemma_utterance_stanza = self.stanza_model(question)
lemma_utterance = [word.lemma for sent in lemma_utterance_stanza.sentences for word in sent.words]
db_context = SpiderDBContext(db_id,
lemma_utterance,
tables_file=self.schema_path,
dataset_path=self.db_path,
stanza_model=self.stanza_model,
schemas=self.schema,
original_utterance=question)
value_match, value_alignment, exact_match, partial_match = db_context.get_db_knowledge_graph(db_id)
item = {}
item['interaction'] = [{'db_id': db_id,
'question': question,
'sql': '',
'value_match': value_match,
'value_alignment': value_alignment,
'exact_match': exact_match,
'partial_match': partial_match,
}]
        ### step 2: serialization
source_sequence, _ = build_schema_linking_data(schema=self.database_schemas[db_id],
question=question,
item=item,
turn_id=0,
linking_type='default')
slml_question = source_sequence[0]
        ### step 3: prediction
schemas, eval_foreign_key_maps = load_tables(self.schema_path)
original_schemas = load_original_schemas(self.schema_path)
alias_schema = get_alias_schema(schemas)
rnt = decode_with_constrain(slml_question, alias_schema[db_id], self.model)
predict_sql = rnt[0]['text'] if isinstance(rnt[0]['text'], str) else rnt[0]['text'][0]
score = rnt[0]['score'].tolist()
        predict_sql = post_processing_sql(predict_sql, eval_foreign_key_maps[db_id], original_schemas[db_id], schemas[db_id])
return {
"slml_question": slml_question,
"predict_sql": predict_sql,
"score": score
}
def execute(self, query):
### TODO: replace the query with value version
pass
# conn = sqlite3.connect(self.db_path)
# # Temporary Hack: makes sure all literals are collated in a case-insensitive way
# query = add_collate_nocase(query)
# results = conn.execute(query).fetchall()
# conn.close()
# return results
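# A minimal usage sketch (paths and the stanza configuration are illustrative,
# not prescribed by this file):
#
#   import stanza
#   nlp = stanza.Pipeline("en", processors="tokenize,pos,lemma")
#   api = UnisarAPI(log_dir="./models/spider_sl",
#                   db_path="./data/spider/database",
#                   schema_path="./data/spider/tables.json",
#                   stanza_model=nlp)
#   result = api.infer_query("How many students are there?", db_id="student_1")
#   print(result["slml_question"])
#   print(result["predict_sql"])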