Mirror of https://github.com/microsoft/IRNet.git

add files

Parent: 4e4521e795
Commit: de4c3d02ca

README.md | 88

@@ -1,14 +1,86 @@

# IRNet

Code for our ACL'19 accepted paper: [Towards Complex Text-to-SQL in Cross-Domain Database with Intermediate Representation](https://arxiv.org/pdf/1905.08205.pdf)

<p align='center'>
  <img src='https://zhanzecheng.github.io/architecture.png' width="91%"/>
</p>

## Environment Setup

* `Python 3.6`
* `PyTorch 0.4.0` or higher

Once Python and PyTorch are set up, install the Python dependencies via `pip install -r requirements.txt`.

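A minimal environment sketch (the conda environment name and the exact PyTorch 0.4.x build are illustrative assumptions, not part of this repository):

```bash
# Illustrative setup: Python 3.6 environment plus PyTorch 0.4.x
conda create -n irnet python=3.6 -y      # environment name "irnet" is arbitrary
conda activate irnet
pip install torch==0.4.1                 # pick the build matching your CUDA setup
pip install -r requirements.txt          # dependencies listed in this commit
```
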
## Running Code

#### Data preparation

* Download the [GloVe embeddings](https://nlp.stanford.edu/data/wordvecs/glove.42B.300d.zip) and put `glove.42B.300d` under the `./data/` directory
* Download the [pretrained IRNet model](https://drive.google.com/open?id=1VoV28fneYss8HaZmoThGlvYU3A-aK31q) and put `IRNet_pretrained.model` under the `./saved_model/` directory
* Download the preprocessed train/dev datasets from [here](https://drive.google.com/open?id=1YFV1GoLivOMlmunKW0nkzefKULO4wtrn) and put `train.json`, `dev.json`, and `tables.json` under the `./data/` directory (a layout sketch follows this list)

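A minimal sketch of placing the downloads (paths follow the list above; `eval.sh` in this commit expects the unzipped embeddings at `./data/glove.42B.300d.txt`):

```bash
# Illustrative placement of the downloaded files
mkdir -p data saved_model
unzip glove.42B.300d.zip -d data/                 # -> data/glove.42B.300d.txt
mv IRNet_pretrained.model saved_model/
mv train.json dev.json tables.json data/
```
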
##### Generating train/dev data by yourself

You can also preprocess the original [Spider data](https://drive.google.com/uc?export=download&id=11icoH_EA-NYb0OrPTdehRWm_d7-DIzWX) yourself: download it, put `train.json`, `dev.json`, and `tables.json` under the `./data/` directory, and follow the instructions in `./preprocess/`.

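For reference, the preprocessing entry point added in this commit is `run_me.sh`, which downloads the required NLTK data, runs `data_process.py`, and then converts SQL to SemQL with `sql2SemQL.py`, rewriting the data file in place. An illustrative invocation on the training split (the relative paths are assumptions about where you placed the files):

```bash
cd preprocess
# Arguments: <data json> <tables json> <output>; the processed file overwrites the first argument
sh run_me.sh ../data/train.json ../data/tables.json ../data/train.json
cd ..
```
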
#### Training

Run `train.sh` to train IRNet:

`sh train.sh [GPU_ID] [SAVE_FOLD]`

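A concrete example (the GPU id and save folder are illustrative):

```bash
sh train.sh 0 saved_model/irnet_run1
```
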
#### Testing

Run `eval.sh` to evaluate IRNet:

`sh eval.sh [GPU_ID] [OUTPUT_FOLD]`

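As defined in this commit, `eval.sh` first runs `eval.py` to dump the predicted logical forms to `predict_lf.json`, then calls `sem2SQL.py` to convert them into SQL under the given output folder. For example (GPU id and folder are illustrative):

```bash
sh eval.sh 0 output/
```
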
#### Evaluation

You can follow the general evaluation process described on the [Spider page](https://github.com/taoyds/spider).

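A sketch of running Spider's official evaluation script on the generated SQL (the flag names below reflect the Spider repository's `evaluation.py`; double-check them against the Spider page, and the gold/prediction file names are illustrative):

```bash
# gold queries, predicted queries, database directory, and tables.json come from the Spider release
python evaluation.py --gold dev_gold.sql --pred predicted_sql.txt \
  --etype match --db database/ --table tables.json
```
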
## Results

| **Model**        | Dev <br /> Exact Set Match <br /> Accuracy | Test <br /> Exact Set Match <br /> Accuracy |
| ---------------- | ------------------------------------------ | ------------------------------------------- |
| IRNet            | 53.2                                       | 46.7                                        |
| IRNet+BERT(base) | 61.9                                       | **54.7**                                    |

## Citation

If you use IRNet, please cite the following work.

```
@article{GuoIRNet2019,
  author  = {Jiaqi Guo and Zecheng Zhan and Yan Gao and Yan Xiao and Jian-Guang Lou and Ting Liu and Dongmei Zhang},
  title   = {Towards Complex Text-to-SQL in Cross-Domain Database with Intermediate Representation},
  journal = {arXiv preprint arXiv:1905.08205},
  year    = {2019},
  note    = {version 1}
}
```

## Thanks

We would like to thank [Tao Yu](https://taoyds.github.io/) and [Bo Pang](https://www.linkedin.com/in/bo-pang/) for running the evaluations on our submitted models.
We are also grateful for the flexible semantic parser [TranX](https://github.com/pcyin/tranX), which inspired our work.

# Contributing

This project welcomes contributions and suggestions. Most contributions require you to agree to a
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.

When you submit a pull request, a CLA bot will automatically determine whether you need to provide
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
provided by the bot. You will only need to do this once across all repos using our CLA.

This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.

@@ -0,0 +1,10 @@

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# -*- coding: utf-8 -*-
"""
# @Time : 2019/5/25
# @Author : Jiaqi&Zecheng
# @File : __init__.py
# @Software: PyCharm
"""

@@ -0,0 +1,58 @@

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# -*- coding: utf-8 -*-
"""
# @Time : 2019/5/27
# @Author : Jiaqi&Zecheng
# @File : eval.py
# @Software: PyCharm
"""


import torch
from src import args as arg
from src import utils
from src.models.model import IRNet
from src.rule import semQL


def evaluate(args):
    """
    Load a pretrained IRNet checkpoint, decode the dev set and dump the
    predicted logical forms to ./predict_lf.json.
    :param args: parsed command-line arguments
    :return: None
    """

    grammar = semQL.Grammar()
    sql_data, table_data, val_sql_data, \
        val_table_data = utils.load_dataset(args.dataset, use_small=args.toy)

    model = IRNet(args, grammar)

    if args.cuda: model.cuda()

    print('load pretrained model from %s' % (args.load_model))
    pretrained_model = torch.load(args.load_model,
                                  map_location=lambda storage, loc: storage)
    import copy
    pretrained_modeled = copy.deepcopy(pretrained_model)
    # Drop checkpoint entries that no longer exist in the current model definition
    for k in pretrained_model.keys():
        if k not in model.state_dict().keys():
            del pretrained_modeled[k]

    model.load_state_dict(pretrained_modeled)

    model.word_emb = utils.load_word_emb(args.glove_embed_path)

    json_datas = utils.epoch_acc(model, args.batch_size, val_sql_data, val_table_data,
                                 beam_size=args.beam_size)
    # utils.eval_acc(json_datas, val_sql_data)
    import json
    with open('./predict_lf.json', 'w') as f:
        json.dump(json_datas, f)


if __name__ == '__main__':
    arg_parser = arg.init_arg_parser()
    args = arg.init_config(arg_parser)
    print(args)
    evaluate(args)

@@ -0,0 +1,28 @@

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

#!/bin/bash

devices=$1
save_name=$2

CUDA_VISIBLE_DEVICES=$devices python -u eval.py --dataset ./data \
--glove_embed_path ./data/glove.42B.300d.txt \
--cuda \
--epoch 50 \
--loss_epoch_threshold 50 \
--sketch_loss_coefficie 1.0 \
--beam_size 5 \
--seed 90 \
--save ${save_name} \
--embed_size 300 \
--sentence_features \
--column_pointer \
--hidden_size 300 \
--lr_scheduler \
--lr_scheduler_gammar 0.5 \
--att_vec_size 300 \
--load_model ./saved_model/IRNet_pretrained.model

python sem2SQL.py --data_path ./data --input_path predict_lf.json --output_path ${save_name}

@@ -0,0 +1,219 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
# @Time : 2019/5/24
|
||||
# @Author : Jiaqi&Zecheng
|
||||
# @File : data_process.py
|
||||
# @Software: PyCharm
|
||||
"""
|
||||
import json
|
||||
import argparse
|
||||
import nltk
|
||||
import os
|
||||
import pickle
|
||||
from utils import symbol_filter, re_lemma, fully_part_header, group_header, partial_header, num2year, group_symbol, group_values, group_digital
|
||||
from utils import AGG, wordnet_lemmatizer
|
||||
from utils import load_dataSets
|
||||
|
||||
def process_datas(datas, args):
|
||||
"""
|
||||
|
||||
:param datas:
|
||||
:param args:
|
||||
:return:
|
||||
"""
|
||||
with open(os.path.join(args.conceptNet, 'english_RelatedTo.pkl'), 'rb') as f:
|
||||
english_RelatedTo = pickle.load(f)
|
||||
|
||||
with open(os.path.join(args.conceptNet, 'english_IsA.pkl'), 'rb') as f:
|
||||
english_IsA = pickle.load(f)
|
||||
|
||||
# copy of the origin question_toks
|
||||
for d in datas:
|
||||
if 'origin_question_toks' not in d:
|
||||
d['origin_question_toks'] = d['question_toks']
|
||||
|
||||
for entry in datas:
|
||||
entry['question_toks'] = symbol_filter(entry['question_toks'])
|
||||
origin_question_toks = symbol_filter([x for x in entry['origin_question_toks'] if x.lower() != 'the'])
|
||||
question_toks = [wordnet_lemmatizer.lemmatize(x.lower()) for x in entry['question_toks'] if x.lower() != 'the']
|
||||
|
||||
entry['question_toks'] = question_toks
|
||||
|
||||
table_names = []
|
||||
table_names_pattern = []
|
||||
|
||||
for y in entry['table_names']:
|
||||
x = [wordnet_lemmatizer.lemmatize(x.lower()) for x in y.split(' ')]
|
||||
table_names.append(" ".join(x))
|
||||
x = [re_lemma(x.lower()) for x in y.split(' ')]
|
||||
table_names_pattern.append(" ".join(x))
|
||||
|
||||
header_toks = []
|
||||
header_toks_list = []
|
||||
|
||||
header_toks_pattern = []
|
||||
header_toks_list_pattern = []
|
||||
|
||||
for y in entry['col_set']:
|
||||
x = [wordnet_lemmatizer.lemmatize(x.lower()) for x in y.split(' ')]
|
||||
header_toks.append(" ".join(x))
|
||||
header_toks_list.append(x)
|
||||
|
||||
x = [re_lemma(x.lower()) for x in y.split(' ')]
|
||||
header_toks_pattern.append(" ".join(x))
|
||||
header_toks_list_pattern.append(x)
|
||||
|
||||
num_toks = len(question_toks)
|
||||
idx = 0
|
||||
tok_concol = []
|
||||
type_concol = []
|
||||
nltk_result = nltk.pos_tag(question_toks)
|
||||
|
||||
while idx < num_toks:
|
||||
|
||||
# fully header
|
||||
end_idx, header = fully_part_header(question_toks, idx, num_toks, header_toks)
|
||||
if header:
|
||||
tok_concol.append(question_toks[idx: end_idx])
|
||||
type_concol.append(["col"])
|
||||
idx = end_idx
|
||||
continue
|
||||
|
||||
# check for table
|
||||
end_idx, tname = group_header(question_toks, idx, num_toks, table_names)
|
||||
if tname:
|
||||
tok_concol.append(question_toks[idx: end_idx])
|
||||
type_concol.append(["table"])
|
||||
idx = end_idx
|
||||
continue
|
||||
|
||||
# check for column
|
||||
end_idx, header = group_header(question_toks, idx, num_toks, header_toks)
|
||||
if header:
|
||||
tok_concol.append(question_toks[idx: end_idx])
|
||||
type_concol.append(["col"])
|
||||
idx = end_idx
|
||||
continue
|
||||
|
||||
# check for partial column
|
||||
end_idx, tname = partial_header(question_toks, idx, header_toks_list)
|
||||
if tname:
|
||||
tok_concol.append(tname)
|
||||
type_concol.append(["col"])
|
||||
idx = end_idx
|
||||
continue
|
||||
|
||||
# check for aggregation
|
||||
end_idx, agg = group_header(question_toks, idx, num_toks, AGG)
|
||||
if agg:
|
||||
tok_concol.append(question_toks[idx: end_idx])
|
||||
type_concol.append(["agg"])
|
||||
idx = end_idx
|
||||
continue
|
||||
|
||||
if nltk_result[idx][1] == 'RBR' or nltk_result[idx][1] == 'JJR':
|
||||
tok_concol.append([question_toks[idx]])
|
||||
type_concol.append(['MORE'])
|
||||
idx += 1
|
||||
continue
|
||||
|
||||
if nltk_result[idx][1] == 'RBS' or nltk_result[idx][1] == 'JJS':
|
||||
tok_concol.append([question_toks[idx]])
|
||||
type_concol.append(['MOST'])
|
||||
idx += 1
|
||||
continue
|
||||
|
||||
# string match for Time Format
|
||||
if num2year(question_toks[idx]):
|
||||
question_toks[idx] = 'year'
|
||||
end_idx, header = group_header(question_toks, idx, num_toks, header_toks)
|
||||
if header:
|
||||
tok_concol.append(question_toks[idx: end_idx])
|
||||
type_concol.append(["col"])
|
||||
idx = end_idx
|
||||
continue
|
||||
|
||||
def get_concept_result(toks, graph):
|
||||
for begin_id in range(0, len(toks)):
|
||||
for r_ind in reversed(range(1, len(toks) + 1 - begin_id)):
|
||||
tmp_query = "_".join(toks[begin_id:r_ind])
|
||||
if tmp_query in graph:
|
||||
mi = graph[tmp_query]
|
||||
for col in entry['col_set']:
|
||||
if col in mi:
|
||||
return col
|
||||
|
||||
end_idx, symbol = group_symbol(question_toks, idx, num_toks)
|
||||
if symbol:
|
||||
tmp_toks = [x for x in question_toks[idx: end_idx]]
|
||||
assert len(tmp_toks) > 0, print(symbol, question_toks)
|
||||
pro_result = get_concept_result(tmp_toks, english_IsA)
|
||||
if pro_result is None:
|
||||
pro_result = get_concept_result(tmp_toks, english_RelatedTo)
|
||||
if pro_result is None:
|
||||
pro_result = "NONE"
|
||||
for tmp in tmp_toks:
|
||||
tok_concol.append([tmp])
|
||||
type_concol.append([pro_result])
|
||||
pro_result = "NONE"
|
||||
idx = end_idx
|
||||
continue
|
||||
|
||||
end_idx, values = group_values(origin_question_toks, idx, num_toks)
|
||||
if values and (len(values) > 1 or question_toks[idx - 1] not in ['?', '.']):
|
||||
tmp_toks = [wordnet_lemmatizer.lemmatize(x) for x in question_toks[idx: end_idx] if x.isalnum() is True]
|
||||
assert len(tmp_toks) > 0, print(question_toks[idx: end_idx], values, question_toks, idx, end_idx)
|
||||
pro_result = get_concept_result(tmp_toks, english_IsA)
|
||||
if pro_result is None:
|
||||
pro_result = get_concept_result(tmp_toks, english_RelatedTo)
|
||||
if pro_result is None:
|
||||
pro_result = "NONE"
|
||||
for tmp in tmp_toks:
|
||||
tok_concol.append([tmp])
|
||||
type_concol.append([pro_result])
|
||||
pro_result = "NONE"
|
||||
idx = end_idx
|
||||
continue
|
||||
|
||||
result = group_digital(question_toks, idx)
|
||||
if result is True:
|
||||
tok_concol.append(question_toks[idx: idx + 1])
|
||||
type_concol.append(["value"])
|
||||
idx += 1
|
||||
continue
|
||||
# map the lemmatized token "ha" (from "has") back to "have"
if question_toks[idx] == 'ha':
question_toks[idx] = 'have'
|
||||
|
||||
tok_concol.append([question_toks[idx]])
|
||||
type_concol.append(['NONE'])
|
||||
idx += 1
|
||||
continue
|
||||
|
||||
entry['question_arg'] = tok_concol
|
||||
entry['question_arg_type'] = type_concol
|
||||
entry['nltk_pos'] = nltk_result
|
||||
|
||||
return datas
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
arg_parser = argparse.ArgumentParser()
|
||||
arg_parser.add_argument('--data_path', type=str, help='dataset', required=True)
|
||||
arg_parser.add_argument('--table_path', type=str, help='table dataset', required=True)
|
||||
arg_parser.add_argument('--output', type=str, help='output data')
|
||||
args = arg_parser.parse_args()
|
||||
args.conceptNet = './conceptNet'
|
||||
|
||||
# loading dataSets
|
||||
datas, table = load_dataSets(args)
|
||||
|
||||
# process datasets
|
||||
process_result = process_datas(datas, args)
|
||||
|
||||
with open(args.output, 'w') as f:
|
||||
json.dump(datas, f)
|
||||
|
||||
|
|
@@ -0,0 +1,15 @@

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# -*- coding: utf-8 -*-
"""
# @Time : 2019/1/29
# @Author : Jiaqi&Zecheng
# @File : download_nltk.py
# @Software: PyCharm
"""
import nltk

# Fetch the NLTK resources used by the preprocessing scripts
nltk.download('averaged_perceptron_tagger')
nltk.download('punkt')
nltk.download('wordnet')

@@ -0,0 +1,19 @@

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

#!/bin/bash

data=$1
table_data=$2
output=$3

echo "Start downloading NLTK data"
python download_nltk.py

echo "Start processing the original Spider dataset"
python data_process.py --data_path ${data} --table_path ${table_data} --output "process_data.json"

echo "Start generating SemQL from SQL"
python sql2SemQL.py --data_path process_data.json --table_path ${table_data} --output ${data}

rm process_data.json

@@ -0,0 +1,391 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
# @Time : 2019/5/24
|
||||
# @Author : Jiaqi&Zecheng
|
||||
# @File : sql2SemQL.py
|
||||
# @Software: PyCharm
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
|
||||
import copy
|
||||
from utils import load_dataSets
|
||||
|
||||
sys.path.append("..")
|
||||
from src.rule.semQL import Root1, Root, N, A, C, T, Sel, Sup, Filter, Order
|
||||
|
||||
class Parser:
|
||||
def __init__(self):
|
||||
self.copy_selec = None
|
||||
self.sel_result = []
|
||||
self.colSet = set()
|
||||
|
||||
def _init_rule(self):
|
||||
self.copy_selec = None
|
||||
self.colSet = set()
|
||||
|
||||
def _parse_root(self, sql):
|
||||
"""
|
||||
parsing the sql by the grammar
|
||||
R ::= Select | Select Filter | Select Order | ... |
|
||||
:return: [R(), states]
|
||||
"""
|
||||
use_sup, use_ord, use_fil = True, True, False
|
||||
|
||||
if sql['sql']['limit'] == None:
|
||||
use_sup = False
|
||||
|
||||
if sql['sql']['orderBy'] == []:
|
||||
use_ord = False
|
||||
elif sql['sql']['limit'] != None:
|
||||
use_ord = False
|
||||
|
||||
# check the where and having
|
||||
if sql['sql']['where'] != [] or \
|
||||
sql['sql']['having'] != []:
|
||||
use_fil = True
|
||||
|
||||
if use_fil and use_sup:
|
||||
return [Root(0)], ['FILTER', 'SUP', 'SEL']
|
||||
elif use_fil and use_ord:
|
||||
return [Root(1)], ['ORDER', 'FILTER', 'SEL']
|
||||
elif use_sup:
|
||||
return [Root(2)], ['SUP', 'SEL']
|
||||
elif use_fil:
|
||||
return [Root(3)], ['FILTER', 'SEL']
|
||||
elif use_ord:
|
||||
return [Root(4)], ['ORDER', 'SEL']
|
||||
else:
|
||||
return [Root(5)], ['SEL']
|
||||
|
||||
def _parser_column0(self, sql, select):
|
||||
"""
|
||||
Find table of column '*'
|
||||
:return: T(table_id)
|
||||
"""
|
||||
if len(sql['sql']['from']['table_units']) == 1:
|
||||
return T(sql['sql']['from']['table_units'][0][1])
|
||||
else:
|
||||
table_list = []
|
||||
for tmp_t in sql['sql']['from']['table_units']:
|
||||
if type(tmp_t[1]) == int:
|
||||
table_list.append(tmp_t[1])
|
||||
table_set, other_set = set(table_list), set()
|
||||
for sel_p in select:
|
||||
if sel_p[1][1][1] != 0:
|
||||
other_set.add(sql['col_table'][sel_p[1][1][1]])
|
||||
|
||||
if len(sql['sql']['where']) == 1:
|
||||
other_set.add(sql['col_table'][sql['sql']['where'][0][2][1][1]])
|
||||
elif len(sql['sql']['where']) == 3:
|
||||
other_set.add(sql['col_table'][sql['sql']['where'][0][2][1][1]])
|
||||
other_set.add(sql['col_table'][sql['sql']['where'][2][2][1][1]])
|
||||
elif len(sql['sql']['where']) == 5:
|
||||
other_set.add(sql['col_table'][sql['sql']['where'][0][2][1][1]])
|
||||
other_set.add(sql['col_table'][sql['sql']['where'][2][2][1][1]])
|
||||
other_set.add(sql['col_table'][sql['sql']['where'][4][2][1][1]])
|
||||
table_set = table_set - other_set
|
||||
if len(table_set) == 1:
|
||||
return T(list(table_set)[0])
|
||||
elif len(table_set) == 0 and sql['sql']['groupBy'] != []:
|
||||
return T(sql['col_table'][sql['sql']['groupBy'][0][1]])
|
||||
else:
|
||||
question = sql['question']
|
||||
self.sel_result.append(question)
|
||||
print('column * table error')
|
||||
return T(sql['sql']['from']['table_units'][0][1])
|
||||
|
||||
def _parse_select(self, sql):
|
||||
"""
|
||||
parsing the sql by the grammar
|
||||
Select ::= A | AA | AAA | ... |
|
||||
A ::= agg column table
|
||||
:return: [Sel(), states]
|
||||
"""
|
||||
result = []
|
||||
select = sql['sql']['select'][1]
|
||||
result.append(Sel(0))
|
||||
result.append(N(len(select) - 1))
|
||||
|
||||
for sel in select:
|
||||
result.append(A(sel[0]))
|
||||
self.colSet.add(sql['col_set'].index(sql['names'][sel[1][1][1]]))
|
||||
result.append(C(sql['col_set'].index(sql['names'][sel[1][1][1]])))
|
||||
# now check for the situation with *
|
||||
if sel[1][1][1] == 0:
|
||||
result.append(self._parser_column0(sql, select))
|
||||
else:
|
||||
result.append(T(sql['col_table'][sel[1][1][1]]))
|
||||
if not self.copy_selec:
|
||||
self.copy_selec = [copy.deepcopy(result[-2]), copy.deepcopy(result[-1])]
|
||||
|
||||
return result, None
|
||||
|
||||
def _parse_sup(self, sql):
|
||||
"""
|
||||
parsing the sql by the grammar
|
||||
Sup ::= Most A | Least A
|
||||
A ::= agg column table
|
||||
:return: [Sup(), states]
|
||||
"""
|
||||
result = []
|
||||
select = sql['sql']['select'][1]
|
||||
if sql['sql']['limit'] == None:
|
||||
return result, None
|
||||
if sql['sql']['orderBy'][0] == 'desc':
|
||||
result.append(Sup(0))
|
||||
else:
|
||||
result.append(Sup(1))
|
||||
|
||||
result.append(A(sql['sql']['orderBy'][1][0][1][0]))
|
||||
self.colSet.add(sql['col_set'].index(sql['names'][sql['sql']['orderBy'][1][0][1][1]]))
|
||||
result.append(C(sql['col_set'].index(sql['names'][sql['sql']['orderBy'][1][0][1][1]])))
|
||||
if sql['sql']['orderBy'][1][0][1][1] == 0:
|
||||
result.append(self._parser_column0(sql, select))
|
||||
else:
|
||||
result.append(T(sql['col_table'][sql['sql']['orderBy'][1][0][1][1]]))
|
||||
return result, None
|
||||
|
||||
def _parse_filter(self, sql):
|
||||
"""
|
||||
parsing the sql by the grammar
|
||||
Filter ::= and Filter Filter | ... |
|
||||
A ::= agg column table
|
||||
:return: [Filter(), states]
|
||||
"""
|
||||
result = []
|
||||
# check the where
|
||||
if sql['sql']['where'] != [] and sql['sql']['having'] != []:
|
||||
result.append(Filter(0))
|
||||
|
||||
if sql['sql']['where'] != []:
|
||||
# check the not and/or
|
||||
if len(sql['sql']['where']) == 1:
|
||||
result.extend(self.parse_one_condition(sql['sql']['where'][0], sql['names'], sql))
|
||||
elif len(sql['sql']['where']) == 3:
|
||||
if sql['sql']['where'][1] == 'or':
|
||||
result.append(Filter(1))
|
||||
else:
|
||||
result.append(Filter(0))
|
||||
result.extend(self.parse_one_condition(sql['sql']['where'][0], sql['names'], sql))
|
||||
result.extend(self.parse_one_condition(sql['sql']['where'][2], sql['names'], sql))
|
||||
else:
|
||||
if sql['sql']['where'][1] == 'and' and sql['sql']['where'][3] == 'and':
|
||||
result.append(Filter(0))
|
||||
result.extend(self.parse_one_condition(sql['sql']['where'][0], sql['names'], sql))
|
||||
result.append(Filter(0))
|
||||
result.extend(self.parse_one_condition(sql['sql']['where'][2], sql['names'], sql))
|
||||
result.extend(self.parse_one_condition(sql['sql']['where'][4], sql['names'], sql))
|
||||
elif sql['sql']['where'][1] == 'and' and sql['sql']['where'][3] == 'or':
|
||||
result.append(Filter(1))
|
||||
result.append(Filter(0))
|
||||
result.extend(self.parse_one_condition(sql['sql']['where'][0], sql['names'], sql))
|
||||
result.extend(self.parse_one_condition(sql['sql']['where'][2], sql['names'], sql))
|
||||
result.extend(self.parse_one_condition(sql['sql']['where'][4], sql['names'], sql))
|
||||
elif sql['sql']['where'][1] == 'or' and sql['sql']['where'][3] == 'and':
|
||||
result.append(Filter(1))
|
||||
result.append(Filter(0))
|
||||
result.extend(self.parse_one_condition(sql['sql']['where'][2], sql['names'], sql))
|
||||
result.extend(self.parse_one_condition(sql['sql']['where'][4], sql['names'], sql))
|
||||
result.extend(self.parse_one_condition(sql['sql']['where'][0], sql['names'], sql))
|
||||
else:
|
||||
result.append(Filter(1))
|
||||
result.append(Filter(1))
|
||||
result.extend(self.parse_one_condition(sql['sql']['where'][0], sql['names'], sql))
|
||||
result.extend(self.parse_one_condition(sql['sql']['where'][2], sql['names'], sql))
|
||||
result.extend(self.parse_one_condition(sql['sql']['where'][4], sql['names'], sql))
|
||||
|
||||
# check having
|
||||
if sql['sql']['having'] != []:
|
||||
result.extend(self.parse_one_condition(sql['sql']['having'][0], sql['names'], sql))
|
||||
return result, None
|
||||
|
||||
def _parse_order(self, sql):
|
||||
"""
|
||||
parsing the sql by the grammar
|
||||
Order ::= asc A | desc A
|
||||
A ::= agg column table
|
||||
:return: [Order(), states]
|
||||
"""
|
||||
result = []
|
||||
|
||||
if 'order' not in sql['query_toks_no_value'] or 'by' not in sql['query_toks_no_value']:
|
||||
return result, None
|
||||
elif 'limit' in sql['query_toks_no_value']:
|
||||
return result, None
|
||||
else:
|
||||
if sql['sql']['orderBy'] == []:
|
||||
return result, None
|
||||
else:
|
||||
select = sql['sql']['select'][1]
|
||||
if sql['sql']['orderBy'][0] == 'desc':
|
||||
result.append(Order(0))
|
||||
else:
|
||||
result.append(Order(1))
|
||||
result.append(A(sql['sql']['orderBy'][1][0][1][0]))
|
||||
self.colSet.add(sql['col_set'].index(sql['names'][sql['sql']['orderBy'][1][0][1][1]]))
|
||||
result.append(C(sql['col_set'].index(sql['names'][sql['sql']['orderBy'][1][0][1][1]])))
|
||||
if sql['sql']['orderBy'][1][0][1][1] == 0:
|
||||
result.append(self._parser_column0(sql, select))
|
||||
else:
|
||||
result.append(T(sql['col_table'][sql['sql']['orderBy'][1][0][1][1]]))
|
||||
return result, None
|
||||
|
||||
|
||||
def parse_one_condition(self, sql_condit, names, sql):
|
||||
result = []
|
||||
# check if V(root)
|
||||
nest_query = True
|
||||
if type(sql_condit[3]) != dict:
|
||||
nest_query = False
|
||||
|
||||
if sql_condit[0] == True:
|
||||
if sql_condit[1] == 9:
|
||||
# not like only with values
|
||||
fil = Filter(10)
|
||||
elif sql_condit[1] == 8:
|
||||
# not in with Root
|
||||
fil = Filter(19)
|
||||
else:
|
||||
print(sql_condit[1])
|
||||
raise NotImplementedError("not implement for the others FIL")
|
||||
else:
|
||||
# check for Filter (<,=,>,!=,between, >=, <=, ...)
|
||||
single_map = {1:8,2:2,3:5,4:4,5:7,6:6,7:3}
|
||||
nested_map = {1:15,2:11,3:13,4:12,5:16,6:17,7:14}
|
||||
if sql_condit[1] in [1, 2, 3, 4, 5, 6, 7]:
|
||||
if nest_query == False:
|
||||
fil = Filter(single_map[sql_condit[1]])
|
||||
else:
|
||||
fil = Filter(nested_map[sql_condit[1]])
|
||||
elif sql_condit[1] == 9:
|
||||
fil = Filter(9)
|
||||
elif sql_condit[1] == 8:
|
||||
fil = Filter(18)
|
||||
else:
|
||||
print(sql_condit[1])
|
||||
raise NotImplementedError("not implement for the others FIL")
|
||||
|
||||
result.append(fil)
|
||||
result.append(A(sql_condit[2][1][0]))
|
||||
self.colSet.add(sql['col_set'].index(sql['names'][sql_condit[2][1][1]]))
|
||||
result.append(C(sql['col_set'].index(sql['names'][sql_condit[2][1][1]])))
|
||||
if sql_condit[2][1][1] == 0:
|
||||
select = sql['sql']['select'][1]
|
||||
result.append(self._parser_column0(sql, select))
|
||||
else:
|
||||
result.append(T(sql['col_table'][sql_condit[2][1][1]]))
|
||||
|
||||
# check for the nested value
|
||||
if type(sql_condit[3]) == dict:
|
||||
nest_query = {}
|
||||
nest_query['names'] = names
|
||||
nest_query['query_toks_no_value'] = ""
|
||||
nest_query['sql'] = sql_condit[3]
|
||||
nest_query['col_table'] = sql['col_table']
|
||||
nest_query['col_set'] = sql['col_set']
|
||||
nest_query['table_names'] = sql['table_names']
|
||||
nest_query['question'] = sql['question']
|
||||
nest_query['query'] = sql['query']
|
||||
nest_query['keys'] = sql['keys']
|
||||
result.extend(self.parser(nest_query))
|
||||
|
||||
return result
|
||||
|
||||
def _parse_step(self, state, sql):
|
||||
|
||||
if state == 'ROOT':
|
||||
return self._parse_root(sql)
|
||||
|
||||
if state == 'SEL':
|
||||
return self._parse_select(sql)
|
||||
|
||||
elif state == 'SUP':
|
||||
return self._parse_sup(sql)
|
||||
|
||||
elif state == 'FILTER':
|
||||
return self._parse_filter(sql)
|
||||
|
||||
elif state == 'ORDER':
|
||||
return self._parse_order(sql)
|
||||
else:
|
||||
raise NotImplementedError("Not the right state")
|
||||
|
||||
def full_parse(self, query):
|
||||
sql = query['sql']
|
||||
nest_query = {}
|
||||
nest_query['names'] = query['names']
|
||||
nest_query['query_toks_no_value'] = ""
|
||||
nest_query['col_table'] = query['col_table']
|
||||
nest_query['col_set'] = query['col_set']
|
||||
nest_query['table_names'] = query['table_names']
|
||||
nest_query['question'] = query['question']
|
||||
nest_query['query'] = query['query']
|
||||
nest_query['keys'] = query['keys']
|
||||
|
||||
if sql['intersect']:
|
||||
results = [Root1(0)]
|
||||
nest_query['sql'] = sql['intersect']
|
||||
results.extend(self.parser(query))
|
||||
results.extend(self.parser(nest_query))
|
||||
return results
|
||||
|
||||
if sql['union']:
|
||||
results = [Root1(1)]
|
||||
nest_query['sql'] = sql['union']
|
||||
results.extend(self.parser(query))
|
||||
results.extend(self.parser(nest_query))
|
||||
return results
|
||||
|
||||
if sql['except']:
|
||||
results = [Root1(2)]
|
||||
nest_query['sql'] = sql['except']
|
||||
results.extend(self.parser(query))
|
||||
results.extend(self.parser(nest_query))
|
||||
return results
|
||||
|
||||
results = [Root1(3)]
|
||||
results.extend(self.parser(query))
|
||||
|
||||
return results
|
||||
|
||||
def parser(self, query):
|
||||
stack = ["ROOT"]
|
||||
result = []
|
||||
while len(stack) > 0:
|
||||
state = stack.pop()
|
||||
step_result, step_state = self._parse_step(state, query)
|
||||
result.extend(step_result)
|
||||
if step_state:
|
||||
stack.extend(step_state)
|
||||
return result
|
||||
|
||||
if __name__ == '__main__':
|
||||
arg_parser = argparse.ArgumentParser()
|
||||
arg_parser.add_argument('--data_path', type=str, help='dataset', required=True)
|
||||
arg_parser.add_argument('--table_path', type=str, help='table dataset', required=True)
|
||||
arg_parser.add_argument('--output', type=str, help='output data', required=True)
|
||||
args = arg_parser.parse_args()
|
||||
|
||||
parser = Parser()
|
||||
|
||||
# loading dataSets
|
||||
datas, table = load_dataSets(args)
|
||||
processed_data = []
|
||||
|
||||
for i, d in enumerate(datas):
|
||||
if len(datas[i]['sql']['select'][1]) > 5:
|
||||
continue
|
||||
r = parser.full_parse(datas[i])
|
||||
datas[i]['rule_label'] = " ".join([str(x) for x in r])
|
||||
processed_data.append(datas[i])
|
||||
|
||||
print('Finished %s datas and failed %s datas' % (len(processed_data), len(datas) - len(processed_data)))
|
||||
with open(args.output, 'w', encoding='utf8') as f:
|
||||
f.write(json.dumps(processed_data))
|
||||
|
|
@@ -0,0 +1,173 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
# @Time : 2019/5/24
|
||||
# @Author : Jiaqi&Zecheng
|
||||
# @File : utils.py
|
||||
# @Software: PyCharm
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
from pattern.en import lemma
|
||||
from nltk.stem import WordNetLemmatizer
|
||||
|
||||
VALUE_FILTER = ['what', 'how', 'list', 'give', 'show', 'find', 'id', 'order', 'when']
|
||||
AGG = ['average', 'sum', 'max', 'min', 'minimum', 'maximum', 'between']
|
||||
|
||||
wordnet_lemmatizer = WordNetLemmatizer()
|
||||
|
||||
def load_dataSets(args):
|
||||
with open(args.table_path, 'r', encoding='utf8') as f:
|
||||
table_datas = json.load(f)
|
||||
with open(args.data_path, 'r', encoding='utf8') as f:
|
||||
datas = json.load(f)
|
||||
|
||||
output_tab = {}
|
||||
tables = {}
|
||||
tabel_name = set()
|
||||
for i in range(len(table_datas)):
|
||||
table = table_datas[i]
|
||||
temp = {}
|
||||
temp['col_map'] = table['column_names']
|
||||
temp['table_names'] = table['table_names']
|
||||
tmp_col = []
|
||||
for cc in [x[1] for x in table['column_names']]:
|
||||
if cc not in tmp_col:
|
||||
tmp_col.append(cc)
|
||||
table['col_set'] = tmp_col
|
||||
db_name = table['db_id']
|
||||
tabel_name.add(db_name)
|
||||
table['schema_content'] = [col[1] for col in table['column_names']]
|
||||
table['col_table'] = [col[0] for col in table['column_names']]
|
||||
output_tab[db_name] = temp
|
||||
tables[db_name] = table
|
||||
|
||||
for d in datas:
|
||||
d['names'] = tables[d['db_id']]['schema_content']
|
||||
d['table_names'] = tables[d['db_id']]['table_names']
|
||||
d['col_set'] = tables[d['db_id']]['col_set']
|
||||
d['col_table'] = tables[d['db_id']]['col_table']
|
||||
keys = {}
|
||||
for kv in tables[d['db_id']]['foreign_keys']:
|
||||
keys[kv[0]] = kv[1]
|
||||
keys[kv[1]] = kv[0]
|
||||
for id_k in tables[d['db_id']]['primary_keys']:
|
||||
keys[id_k] = id_k
|
||||
d['keys'] = keys
|
||||
return datas, tables
|
||||
|
||||
def group_header(toks, idx, num_toks, header_toks):
|
||||
for endIdx in reversed(range(idx + 1, num_toks+1)):
|
||||
sub_toks = toks[idx: endIdx]
|
||||
sub_toks = " ".join(sub_toks)
|
||||
if sub_toks in header_toks:
|
||||
return endIdx, sub_toks
|
||||
return idx, None
|
||||
|
||||
def fully_part_header(toks, idx, num_toks, header_toks):
|
||||
for endIdx in reversed(range(idx + 1, num_toks+1)):
|
||||
sub_toks = toks[idx: endIdx]
|
||||
if len(sub_toks) > 1:
|
||||
sub_toks = " ".join(sub_toks)
|
||||
if sub_toks in header_toks:
|
||||
return endIdx, sub_toks
|
||||
return idx, None
|
||||
|
||||
def partial_header(toks, idx, header_toks):
|
||||
def check_in(list_one, list_two):
|
||||
if len(set(list_one) & set(list_two)) == len(list_one) and (len(list_two) <= 3):
|
||||
return True
|
||||
for endIdx in reversed(range(idx + 1, len(toks))):
|
||||
sub_toks = toks[idx: min(endIdx, len(toks))]
|
||||
if len(sub_toks) > 1:
|
||||
flag_count = 0
|
||||
tmp_heads = None
|
||||
for heads in header_toks:
|
||||
if check_in(sub_toks, heads):
|
||||
flag_count += 1
|
||||
tmp_heads = heads
|
||||
if flag_count == 1:
|
||||
return endIdx, tmp_heads
|
||||
return idx, None
|
||||
|
||||
def symbol_filter(questions):
|
||||
question_tmp_q = []
|
||||
for q_id, q_val in enumerate(questions):
|
||||
if len(q_val) > 2 and q_val[0] in ["'", '"', '`', '‘', '“'] and q_val[-1] in ["'", '"', '`', '’']:
|
||||
question_tmp_q.append("'")
|
||||
question_tmp_q += ["".join(q_val[1:-1])]
|
||||
question_tmp_q.append("'")
|
||||
elif len(q_val) > 2 and q_val[0] in ["'", '"', '`', '‘']:
|
||||
question_tmp_q.append("'")
|
||||
question_tmp_q += ["".join(q_val[1:])]
|
||||
elif len(q_val) > 2 and q_val[-1] in ["'", '"', '`', '’']:
|
||||
question_tmp_q += ["".join(q_val[0:-1])]
|
||||
question_tmp_q.append("'")
|
||||
elif q_val in ["'", '"', '`', '‘', '’', '``', "''"]:
|
||||
question_tmp_q += ["'"]
|
||||
else:
|
||||
question_tmp_q += [q_val]
|
||||
return question_tmp_q
|
||||
|
||||
|
||||
def group_values(toks, idx, num_toks):
|
||||
def check_isupper(tok_lists):
|
||||
for tok_one in tok_lists:
|
||||
if tok_one[0].isupper() is False:
|
||||
return False
|
||||
return True
|
||||
|
||||
for endIdx in reversed(range(idx + 1, num_toks + 1)):
|
||||
sub_toks = toks[idx: endIdx]
|
||||
|
||||
if len(sub_toks) > 1 and check_isupper(sub_toks) is True:
|
||||
return endIdx, sub_toks
|
||||
if len(sub_toks) == 1:
|
||||
if sub_toks[0][0].isupper() and sub_toks[0].lower() not in VALUE_FILTER and \
|
||||
sub_toks[0].lower().isalnum() is True:
|
||||
return endIdx, sub_toks
|
||||
return idx, None
|
||||
|
||||
|
||||
def group_digital(toks, idx):
|
||||
test = toks[idx].replace(':', '')
|
||||
test = test.replace('.', '')
|
||||
if test.isdigit():
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def group_symbol(toks, idx, num_toks):
|
||||
if toks[idx-1] == "'":
|
||||
for i in range(0, min(3, num_toks-idx)):
|
||||
if toks[i + idx] == "'":
|
||||
return i + idx, toks[idx:i+idx]
|
||||
return idx, None
|
||||
|
||||
|
||||
def num2year(tok):
|
||||
if len(str(tok)) == 4 and str(tok).isdigit() and int(str(tok)[:2]) < 22 and int(str(tok)[:2]) > 15:
|
||||
return True
|
||||
return False
|
||||
|
||||
def set_header(toks, header_toks, tok_concol, idx, num_toks):
|
||||
def check_in(list_one, list_two):
|
||||
if set(list_one) == set(list_two):
|
||||
return True
|
||||
for endIdx in range(idx, num_toks):
|
||||
toks += tok_concol[endIdx]
|
||||
if len(tok_concol[endIdx]) > 1:
|
||||
break
|
||||
for heads in header_toks:
|
||||
if check_in(toks, heads):
|
||||
return heads
|
||||
return None
|
||||
|
||||
def re_lemma(string):
|
||||
lema = lemma(string.lower())
|
||||
if len(lema) > 0:
|
||||
return lema
|
||||
else:
|
||||
return string.lower()
|
|
@@ -0,0 +1,8 @@

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

nltk==3.4
pattern
numpy==1.14.0
pytorch-pretrained-bert==0.5.1
tqdm==4.31.1

@@ -0,0 +1,697 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
# @Time : 2019/5/27
|
||||
# @Author : Jiaqi&Zecheng
|
||||
# @File : sem2SQL.py
|
||||
# @Software: PyCharm
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import traceback
|
||||
|
||||
from src.rule.graph import Graph
|
||||
from src.rule.semQL import Sup, Sel, Order, Root, Filter, A, N, C, T, Root1
|
||||
from src.rule.sem_utils import alter_inter, alter_not_in, alter_column0, load_dataSets
|
||||
|
||||
|
||||
def split_logical_form(lf):
|
||||
indexs = [i+1 for i, letter in enumerate(lf) if letter == ')']
|
||||
indexs.insert(0, 0)
|
||||
components = list()
|
||||
for i in range(1, len(indexs)):
|
||||
components.append(lf[indexs[i-1]:indexs[i]].strip())
|
||||
return components
|
||||
|
||||
|
||||
def pop_front(array):
|
||||
if len(array) == 0:
|
||||
return 'None'
|
||||
return array.pop(0)
|
||||
|
||||
|
||||
def is_end(components, transformed_sql, is_root_processed):
|
||||
end = False
|
||||
c = pop_front(components)
|
||||
c_instance = eval(c)
|
||||
|
||||
if isinstance(c_instance, Root) and is_root_processed:
|
||||
# intersect, union, except
|
||||
end = True
|
||||
elif isinstance(c_instance, Filter):
|
||||
if 'where' not in transformed_sql:
|
||||
end = True
|
||||
else:
|
||||
num_conjunction = 0
|
||||
for f in transformed_sql['where']:
|
||||
if isinstance(f, str) and (f == 'and' or f == 'or'):
|
||||
num_conjunction += 1
|
||||
current_filters = len(transformed_sql['where'])
|
||||
valid_filters = current_filters - num_conjunction
|
||||
if valid_filters >= num_conjunction + 1:
|
||||
end = True
|
||||
elif isinstance(c_instance, Order):
|
||||
if 'order' not in transformed_sql:
|
||||
end = True
|
||||
elif len(transformed_sql['order']) == 0:
|
||||
end = False
|
||||
else:
|
||||
end = True
|
||||
elif isinstance(c_instance, Sup):
|
||||
if 'sup' not in transformed_sql:
|
||||
end = True
|
||||
elif len(transformed_sql['sup']) == 0:
|
||||
end = False
|
||||
else:
|
||||
end = True
|
||||
components.insert(0, c)
|
||||
return end
|
||||
|
||||
|
||||
def _transform(components, transformed_sql, col_set, table_names, schema):
|
||||
processed_root = False
|
||||
current_table = schema
|
||||
|
||||
while len(components) > 0:
|
||||
if is_end(components, transformed_sql, processed_root):
|
||||
break
|
||||
c = pop_front(components)
|
||||
c_instance = eval(c)
|
||||
if isinstance(c_instance, Root):
|
||||
processed_root = True
|
||||
transformed_sql['select'] = list()
|
||||
if c_instance.id_c == 0:
|
||||
transformed_sql['where'] = list()
|
||||
transformed_sql['sup'] = list()
|
||||
elif c_instance.id_c == 1:
|
||||
transformed_sql['where'] = list()
|
||||
transformed_sql['order'] = list()
|
||||
elif c_instance.id_c == 2:
|
||||
transformed_sql['sup'] = list()
|
||||
elif c_instance.id_c == 3:
|
||||
transformed_sql['where'] = list()
|
||||
elif c_instance.id_c == 4:
|
||||
transformed_sql['order'] = list()
|
||||
elif isinstance(c_instance, Sel):
|
||||
continue
|
||||
elif isinstance(c_instance, N):
|
||||
for i in range(c_instance.id_c + 1):
|
||||
agg = eval(pop_front(components))
|
||||
column = eval(pop_front(components))
|
||||
_table = pop_front(components)
|
||||
table = eval(_table)
|
||||
if not isinstance(table, T):
|
||||
table = None
|
||||
components.insert(0, _table)
|
||||
assert isinstance(agg, A) and isinstance(column, C)
|
||||
|
||||
transformed_sql['select'].append((
|
||||
agg.production.split()[1],
|
||||
replace_col_with_original_col(col_set[column.id_c], table_names[table.id_c], current_table) if table is not None else col_set[column.id_c],
|
||||
table_names[table.id_c] if table is not None else table
|
||||
))
|
||||
|
||||
elif isinstance(c_instance, Sup):
|
||||
transformed_sql['sup'].append(c_instance.production.split()[1])
|
||||
agg = eval(pop_front(components))
|
||||
column = eval(pop_front(components))
|
||||
_table = pop_front(components)
|
||||
table = eval(_table)
|
||||
if not isinstance(table, T):
|
||||
table = None
|
||||
components.insert(0, _table)
|
||||
assert isinstance(agg, A) and isinstance(column, C)
|
||||
|
||||
transformed_sql['sup'].append(agg.production.split()[1])
|
||||
if table:
|
||||
fix_col_id = replace_col_with_original_col(col_set[column.id_c], table_names[table.id_c], current_table)
|
||||
else:
|
||||
fix_col_id = col_set[column.id_c]
|
||||
raise RuntimeError('not found table !!!!')
|
||||
transformed_sql['sup'].append(fix_col_id)
|
||||
transformed_sql['sup'].append(table_names[table.id_c] if table is not None else table)
|
||||
|
||||
elif isinstance(c_instance, Order):
|
||||
transformed_sql['order'].append(c_instance.production.split()[1])
|
||||
agg = eval(pop_front(components))
|
||||
column = eval(pop_front(components))
|
||||
_table = pop_front(components)
|
||||
table = eval(_table)
|
||||
if not isinstance(table, T):
|
||||
table = None
|
||||
components.insert(0, _table)
|
||||
assert isinstance(agg, A) and isinstance(column, C)
|
||||
transformed_sql['order'].append(agg.production.split()[1])
|
||||
transformed_sql['order'].append(replace_col_with_original_col(col_set[column.id_c], table_names[table.id_c], current_table))
|
||||
transformed_sql['order'].append(table_names[table.id_c] if table is not None else table)
|
||||
|
||||
elif isinstance(c_instance, Filter):
|
||||
op = c_instance.production.split()[1]
|
||||
if op == 'and' or op == 'or':
|
||||
transformed_sql['where'].append(op)
|
||||
else:
|
||||
# No subquery
|
||||
agg = eval(pop_front(components))
|
||||
column = eval(pop_front(components))
|
||||
_table = pop_front(components)
|
||||
table = eval(_table)
|
||||
if not isinstance(table, T):
|
||||
table = None
|
||||
components.insert(0, _table)
|
||||
assert isinstance(agg, A) and isinstance(column, C)
|
||||
if len(c_instance.production.split()) == 3:
|
||||
if table:
|
||||
fix_col_id = replace_col_with_original_col(col_set[column.id_c], table_names[table.id_c], current_table)
|
||||
else:
|
||||
fix_col_id = col_set[column.id_c]
|
||||
raise RuntimeError('not found table !!!!')
|
||||
transformed_sql['where'].append((
|
||||
op,
|
||||
agg.production.split()[1],
|
||||
fix_col_id,
|
||||
table_names[table.id_c] if table is not None else table,
|
||||
None
|
||||
))
|
||||
else:
|
||||
# Subquery
|
||||
new_dict = dict()
|
||||
new_dict['sql'] = transformed_sql['sql']
|
||||
transformed_sql['where'].append((
|
||||
op,
|
||||
agg.production.split()[1],
|
||||
replace_col_with_original_col(col_set[column.id_c], table_names[table.id_c], current_table),
|
||||
table_names[table.id_c] if table is not None else table,
|
||||
_transform(components, new_dict, col_set, table_names, schema)
|
||||
))
|
||||
|
||||
return transformed_sql
|
||||
|
||||
|
||||
def transform(query, schema, origin=None):
|
||||
preprocess_schema(schema)
|
||||
if origin is None:
|
||||
lf = query['model_result_replace']
|
||||
else:
|
||||
lf = origin
|
||||
# lf = query['rule_label']
|
||||
col_set = query['col_set']
|
||||
table_names = query['table_names']
|
||||
current_table = schema
|
||||
|
||||
current_table['schema_content_clean'] = [x[1] for x in current_table['column_names']]
|
||||
current_table['schema_content'] = [x[1] for x in current_table['column_names_original']]
|
||||
|
||||
components = split_logical_form(lf)
|
||||
|
||||
transformed_sql = dict()
|
||||
transformed_sql['sql'] = query
|
||||
c = pop_front(components)
|
||||
c_instance = eval(c)
|
||||
assert isinstance(c_instance, Root1)
|
||||
if c_instance.id_c == 0:
|
||||
transformed_sql['intersect'] = dict()
|
||||
transformed_sql['intersect']['sql'] = query
|
||||
|
||||
_transform(components, transformed_sql, col_set, table_names, schema)
|
||||
_transform(components, transformed_sql['intersect'], col_set, table_names, schema)
|
||||
elif c_instance.id_c == 1:
|
||||
transformed_sql['union'] = dict()
|
||||
transformed_sql['union']['sql'] = query
|
||||
_transform(components, transformed_sql, col_set, table_names, schema)
|
||||
_transform(components, transformed_sql['union'], col_set, table_names, schema)
|
||||
elif c_instance.id_c == 2:
|
||||
transformed_sql['except'] = dict()
|
||||
transformed_sql['except']['sql'] = query
|
||||
_transform(components, transformed_sql, col_set, table_names, schema)
|
||||
_transform(components, transformed_sql['except'], col_set, table_names, schema)
|
||||
else:
|
||||
_transform(components, transformed_sql, col_set, table_names, schema)
|
||||
|
||||
parse_result = to_str(transformed_sql, 1, schema)
|
||||
|
||||
parse_result = parse_result.replace('\t', '')
|
||||
return [parse_result]
|
||||
|
||||
def col_to_str(agg, col, tab, table_names, N=1):
|
||||
_col = col.replace(' ', '_')
|
||||
if agg == 'none':
|
||||
if tab not in table_names:
|
||||
table_names[tab] = 'T' + str(len(table_names) + N)
|
||||
table_alias = table_names[tab]
|
||||
if col == '*':
|
||||
return '*'
|
||||
return '%s.%s' % (table_alias, _col)
|
||||
else:
|
||||
if col == '*':
|
||||
if tab is not None and tab not in table_names:
|
||||
table_names[tab] = 'T' + str(len(table_names) + N)
|
||||
return '%s(%s)' % (agg, _col)
|
||||
else:
|
||||
if tab not in table_names:
|
||||
table_names[tab] = 'T' + str(len(table_names) + N)
|
||||
table_alias = table_names[tab]
|
||||
return '%s(%s.%s)' % (agg, table_alias, _col)
|
||||
|
||||
|
||||
def infer_from_clause(table_names, schema, columns):
|
||||
tables = list(table_names.keys())
|
||||
# print(table_names)
|
||||
start_table = None
|
||||
end_table = None
|
||||
join_clause = list()
|
||||
if len(tables) == 1:
|
||||
join_clause.append((tables[0], table_names[tables[0]]))
|
||||
elif len(tables) == 2:
|
||||
use_graph = True
|
||||
# print(schema['graph'].vertices)
|
||||
for t in tables:
|
||||
if t not in schema['graph'].vertices:
|
||||
use_graph = False
|
||||
break
|
||||
if use_graph:
|
||||
start_table = tables[0]
|
||||
end_table = tables[1]
|
||||
_tables = list(schema['graph'].dijkstra(tables[0], tables[1]))
|
||||
# print('Two tables: ', _tables)
|
||||
max_key = 1
|
||||
for t, k in table_names.items():
|
||||
_k = int(k[1:])
|
||||
if _k > max_key:
|
||||
max_key = _k
|
||||
for t in _tables:
|
||||
if t not in table_names:
|
||||
table_names[t] = 'T' + str(max_key + 1)
|
||||
max_key += 1
|
||||
join_clause.append((t, table_names[t],))
|
||||
else:
|
||||
join_clause = list()
|
||||
for t in tables:
|
||||
join_clause.append((t, table_names[t],))
|
||||
else:
|
||||
# > 2
|
||||
# print('More than 2 table')
|
||||
for t in tables:
|
||||
join_clause.append((t, table_names[t],))
|
||||
|
||||
if len(join_clause) >= 3:
|
||||
star_table = None
|
||||
for agg, col, tab in columns:
|
||||
if col == '*':
|
||||
star_table = tab
|
||||
break
|
||||
if star_table is not None:
|
||||
star_table_count = 0
|
||||
for agg, col, tab in columns:
|
||||
if tab == star_table and col != '*':
|
||||
star_table_count += 1
|
||||
if star_table_count == 0 and ((end_table is None or end_table == star_table) or (start_table is None or start_table == star_table)):
|
||||
# Remove the table the rest tables still can join without star_table
|
||||
new_join_clause = list()
|
||||
for t in join_clause:
|
||||
if t[0] != star_table:
|
||||
new_join_clause.append(t)
|
||||
join_clause = new_join_clause
|
||||
|
||||
join_clause = ' JOIN '.join(['%s AS %s' % (jc[0], jc[1]) for jc in join_clause])
|
||||
return 'FROM ' + join_clause
|
||||
|
||||
def replace_col_with_original_col(query, col, current_table):
|
||||
# print(query, col)
|
||||
if query == '*':
|
||||
return query
|
||||
|
||||
cur_table = col
|
||||
cur_col = query
|
||||
single_final_col = None
|
||||
# print(query, col)
|
||||
for col_ind, col_name in enumerate(current_table['schema_content_clean']):
|
||||
if col_name == cur_col:
|
||||
assert cur_table in current_table['table_names']
|
||||
if current_table['table_names'][current_table['col_table'][col_ind]] == cur_table:
|
||||
single_final_col = current_table['column_names_original'][col_ind][1]
|
||||
break
|
||||
|
||||
assert single_final_col
|
||||
# if query != single_final_col:
|
||||
# print(query, single_final_col)
|
||||
return single_final_col
|
||||
|
||||
|
||||
def build_graph(schema):
|
||||
relations = list()
|
||||
foreign_keys = schema['foreign_keys']
|
||||
for (fkey, pkey) in foreign_keys:
|
||||
fkey_table = schema['table_names_original'][schema['column_names'][fkey][0]]
|
||||
pkey_table = schema['table_names_original'][schema['column_names'][pkey][0]]
|
||||
relations.append((fkey_table, pkey_table))
|
||||
relations.append((pkey_table, fkey_table))
|
||||
return Graph(relations)
|
||||
|
||||
|
||||
def preprocess_schema(schema):
|
||||
tmp_col = []
|
||||
for cc in [x[1] for x in schema['column_names']]:
|
||||
if cc not in tmp_col:
|
||||
tmp_col.append(cc)
|
||||
schema['col_set'] = tmp_col
|
||||
# print table
|
||||
schema['schema_content'] = [col[1] for col in schema['column_names']]
|
||||
schema['col_table'] = [col[0] for col in schema['column_names']]
|
||||
graph = build_graph(schema)
|
||||
schema['graph'] = graph
|
||||
|
||||
|
||||
|
||||
|
||||
def to_str(sql_json, N_T, schema, pre_table_names=None):
|
||||
all_columns = list()
|
||||
select_clause = list()
|
||||
table_names = dict()
|
||||
current_table = schema
|
||||
for (agg, col, tab) in sql_json['select']:
|
||||
all_columns.append((agg, col, tab))
|
||||
select_clause.append(col_to_str(agg, col, tab, table_names, N_T))
|
||||
select_clause_str = 'SELECT ' + ', '.join(select_clause).strip()
|
||||
|
||||
sup_clause = ''
|
||||
order_clause = ''
|
||||
direction_map = {"des": 'DESC', 'asc': 'ASC'}
|
||||
|
||||
if 'sup' in sql_json:
|
||||
(direction, agg, col, tab,) = sql_json['sup']
|
||||
all_columns.append((agg, col, tab))
|
||||
subject = col_to_str(agg, col, tab, table_names, N_T)
|
||||
sup_clause = ('ORDER BY %s %s LIMIT 1' % (subject, direction_map[direction])).strip()
|
||||
elif 'order' in sql_json:
|
||||
(direction, agg, col, tab,) = sql_json['order']
|
||||
all_columns.append((agg, col, tab))
|
||||
subject = col_to_str(agg, col, tab, table_names, N_T)
|
||||
order_clause = ('ORDER BY %s %s' % (subject, direction_map[direction])).strip()
|
||||
|
||||
has_group_by = False
|
||||
where_clause = ''
|
||||
have_clause = ''
|
||||
if 'where' in sql_json:
|
||||
conjunctions = list()
|
||||
filters = list()
|
||||
# print(sql_json['where'])
|
||||
for f in sql_json['where']:
|
||||
if isinstance(f, str):
|
||||
conjunctions.append(f)
|
||||
else:
|
||||
op, agg, col, tab, value = f
|
||||
if value:
|
||||
value['sql'] = sql_json['sql']
|
||||
all_columns.append((agg, col, tab))
|
||||
subject = col_to_str(agg, col, tab, table_names, N_T)
|
||||
if value is None:
|
||||
where_value = '1'
|
||||
if op == 'between':
|
||||
where_value = '1 AND 2'
|
||||
filters.append('%s %s %s' % (subject, op, where_value))
|
||||
else:
|
||||
if op == 'in' and len(value['select']) == 1 and value['select'][0][0] == 'none' \
|
||||
and 'where' not in value and 'order' not in value and 'sup' not in value:
|
||||
# and value['select'][0][2] not in table_names:
|
||||
if value['select'][0][2] not in table_names:
|
||||
table_names[value['select'][0][2]] = 'T' + str(len(table_names) + N_T)
|
||||
filters.append(None)
|
||||
|
||||
else:
|
||||
filters.append('%s %s %s' % (subject, op, '(' + to_str(value, len(table_names) + 1, schema) + ')'))
|
||||
if len(conjunctions):
|
||||
filters.append(conjunctions.pop())
|
||||
|
||||
aggs = ['count(', 'avg(', 'min(', 'max(', 'sum(']
|
||||
having_filters = list()
|
||||
idx = 0
|
||||
while idx < len(filters):
|
||||
_filter = filters[idx]
|
||||
if _filter is None:
|
||||
idx += 1
|
||||
continue
|
||||
for agg in aggs:
|
||||
if _filter.startswith(agg):
|
||||
having_filters.append(_filter)
|
||||
filters.pop(idx)
|
||||
# print(filters)
|
||||
if 0 < idx and (filters[idx - 1] in ['and', 'or']):
|
||||
filters.pop(idx - 1)
|
||||
# print(filters)
|
||||
break
|
||||
else:
|
||||
idx += 1
|
||||
if len(having_filters) > 0:
|
||||
have_clause = 'HAVING ' + ' '.join(having_filters).strip()
|
||||
if len(filters) > 0:
|
||||
# print(filters)
|
||||
filters = [_f for _f in filters if _f is not None]
|
||||
conjun_num = 0
|
||||
filter_num = 0
|
||||
for _f in filters:
|
||||
if _f in ['or', 'and']:
|
||||
conjun_num += 1
|
||||
else:
|
||||
filter_num += 1
|
||||
if conjun_num > 0 and filter_num != (conjun_num + 1):
|
||||
# assert 'and' in filters
|
||||
idx = 0
|
||||
while idx < len(filters):
|
||||
if filters[idx] == 'and':
|
||||
if idx - 1 == 0:
|
||||
filters.pop(idx)
|
||||
break
|
||||
if filters[idx - 1] in ['and', 'or']:
|
||||
filters.pop(idx)
|
||||
break
|
||||
if idx + 1 >= len(filters) - 1:
|
||||
filters.pop(idx)
|
||||
break
|
||||
if filters[idx + 1] in ['and', 'or']:
|
||||
filters.pop(idx)
|
||||
break
|
||||
idx += 1
|
||||
if len(filters) > 0:
|
||||
where_clause = 'WHERE ' + ' '.join(filters).strip()
|
||||
where_clause = where_clause.replace('not_in', 'NOT IN')
|
||||
else:
|
||||
where_clause = ''
|
||||
|
||||
if len(having_filters) > 0:
|
||||
has_group_by = True
|
||||
|
||||
for agg in ['count(', 'avg(', 'min(', 'max(', 'sum(']:
|
||||
if (len(sql_json['select']) > 1 and agg in select_clause_str)\
|
||||
or agg in sup_clause or agg in order_clause:
|
||||
has_group_by = True
|
||||
break
|
||||
|
||||
group_by_clause = ''
|
||||
if has_group_by:
|
||||
if len(table_names) == 1:
|
||||
# check none agg
|
||||
is_agg_flag = False
|
||||
for (agg, col, tab) in sql_json['select']:
|
||||
|
||||
if agg == 'none':
|
||||
group_by_clause = 'GROUP BY ' + col_to_str(agg, col, tab, table_names, N_T)
|
||||
else:
|
||||
is_agg_flag = True
|
||||
|
||||
if is_agg_flag is False and len(group_by_clause) > 5:
|
||||
group_by_clause = "GROUP BY"
|
||||
for (agg, col, tab) in sql_json['select']:
|
||||
group_by_clause = group_by_clause + ' ' + col_to_str(agg, col, tab, table_names, N_T)
|
||||
|
||||
if len(group_by_clause) < 5:
|
||||
if 'count(*)' in select_clause_str:
|
||||
current_table = schema
|
||||
for primary in current_table['primary_keys']:
|
||||
if current_table['table_names'][current_table['col_table'][primary]] in table_names :
|
||||
group_by_clause = 'GROUP BY ' + col_to_str('none', current_table['schema_content'][primary],
|
||||
current_table['table_names'][
|
||||
current_table['col_table'][primary]],
|
||||
table_names, N_T)
|
||||
else:
|
||||
# if only one select
|
||||
if len(sql_json['select']) == 1:
|
||||
agg, col, tab = sql_json['select'][0]
|
||||
non_lists = [tab]
|
||||
fix_flag = False
|
||||
# add tab from other part
|
||||
for key, value in table_names.items():
|
||||
if key not in non_lists:
|
||||
non_lists.append(key)
|
||||
|
||||
a = non_lists[0]
|
||||
b = None
|
||||
for non in non_lists:
|
||||
if a != non:
|
||||
b = non
|
||||
if b:
|
||||
for pair in current_table['foreign_keys']:
|
||||
t1 = current_table['table_names'][current_table['col_table'][pair[0]]]
|
||||
t2 = current_table['table_names'][current_table['col_table'][pair[1]]]
|
||||
if t1 in [a, b] and t2 in [a, b]:
|
||||
if pre_table_names and t1 not in pre_table_names:
|
||||
assert t2 in pre_table_names
|
||||
t1 = t2
|
||||
group_by_clause = 'GROUP BY ' + col_to_str('none',
|
||||
current_table['schema_content'][pair[0]],
|
||||
t1,
|
||||
table_names, N_T)
|
||||
fix_flag = True
|
||||
break
|
||||
|
||||
if fix_flag is False:
|
||||
agg, col, tab = sql_json['select'][0]
|
||||
group_by_clause = 'GROUP BY ' + col_to_str(agg, col, tab, table_names, N_T)
|
||||
|
||||
else:
|
||||
# check if there are only one non agg
|
||||
non_agg, non_agg_count = None, 0
|
||||
non_lists = []
|
||||
for (agg, col, tab) in sql_json['select']:
|
||||
if agg == 'none':
|
||||
non_agg = (agg, col, tab)
|
||||
non_lists.append(tab)
|
||||
non_agg_count += 1
|
||||
|
||||
non_lists = list(set(non_lists))
|
||||
# print(non_lists)
|
||||
if non_agg_count == 1:
|
||||
group_by_clause = 'GROUP BY ' + col_to_str(non_agg[0], non_agg[1], non_agg[2], table_names, N_T)
|
||||
elif non_agg:
|
||||
find_flag = False
|
||||
fix_flag = False
|
||||
find_primary = None
|
||||
if len(non_lists) <= 1:
|
||||
for key, value in table_names.items():
|
||||
if key not in non_lists:
|
||||
non_lists.append(key)
|
||||
if len(non_lists) > 1:
|
||||
a = non_lists[0]
|
||||
b = None
|
||||
for non in non_lists:
|
||||
if a != non:
|
||||
b = non
|
||||
if b:
|
||||
for pair in current_table['foreign_keys']:
|
||||
t1 = current_table['table_names'][current_table['col_table'][pair[0]]]
|
||||
t2 = current_table['table_names'][current_table['col_table'][pair[1]]]
|
||||
if t1 in [a, b] and t2 in [a, b]:
|
||||
if pre_table_names and t1 not in pre_table_names:
|
||||
assert t2 in pre_table_names
|
||||
t1 = t2
|
||||
group_by_clause = 'GROUP BY ' + col_to_str('none',
|
||||
current_table['schema_content'][pair[0]],
|
||||
t1,
|
||||
table_names, N_T)
|
||||
fix_flag = True
|
||||
break
|
||||
tab = non_agg[2]
|
||||
assert tab in current_table['table_names']
|
||||
|
||||
for primary in current_table['primary_keys']:
|
||||
if current_table['table_names'][current_table['col_table'][primary]] == tab:
|
||||
find_flag = True
|
||||
find_primary = (current_table['schema_content'][primary], tab)
|
||||
if fix_flag is False:
|
||||
if find_flag is False:
|
||||
# rely on count *
|
||||
foreign = []
|
||||
for pair in current_table['foreign_keys']:
|
||||
if current_table['table_names'][current_table['col_table'][pair[0]]] == tab:
|
||||
foreign.append(pair[1])
|
||||
if current_table['table_names'][current_table['col_table'][pair[1]]] == tab:
|
||||
foreign.append(pair[0])
|
||||
|
||||
for pair in foreign:
|
||||
if current_table['table_names'][current_table['col_table'][pair]] in table_names:
|
||||
group_by_clause = 'GROUP BY ' + col_to_str('none', current_table['schema_content'][pair],
|
||||
current_table['table_names'][current_table['col_table'][pair]],
|
||||
table_names, N_T)
|
||||
find_flag = True
|
||||
break
|
||||
if find_flag is False:
|
||||
for (agg, col, tab) in sql_json['select']:
|
||||
if 'id' in col.lower():
|
||||
group_by_clause = 'GROUP BY ' + col_to_str(agg, col, tab, table_names, N_T)
|
||||
break
|
||||
if len(group_by_clause) > 5:
|
||||
pass
|
||||
else:
|
||||
raise RuntimeError('fail to convert')
|
||||
else:
|
||||
group_by_clause = 'GROUP BY ' + col_to_str('none', find_primary[0],
|
||||
find_primary[1],
|
||||
table_names, N_T)
|
||||
intersect_clause = ''
|
||||
if 'intersect' in sql_json:
|
||||
sql_json['intersect']['sql'] = sql_json['sql']
|
||||
intersect_clause = 'INTERSECT ' + to_str(sql_json['intersect'], len(table_names) + 1, schema, table_names)
|
||||
union_clause = ''
|
||||
if 'union' in sql_json:
|
||||
sql_json['union']['sql'] = sql_json['sql']
|
||||
union_clause = 'UNION ' + to_str(sql_json['union'], len(table_names) + 1, schema, table_names)
|
||||
except_clause = ''
|
||||
if 'except' in sql_json:
|
||||
sql_json['except']['sql'] = sql_json['sql']
|
||||
except_clause = 'EXCEPT ' + to_str(sql_json['except'], len(table_names) + 1, schema, table_names)
|
||||
|
||||
# print(current_table['table_names_original'])
|
||||
table_names_replace = {}
|
||||
for a, b in zip(current_table['table_names_original'], current_table['table_names']):
|
||||
table_names_replace[b] = a
|
||||
new_table_names = {}
|
||||
for key, value in table_names.items():
|
||||
if key is None:
|
||||
continue
|
||||
new_table_names[table_names_replace[key]] = value
|
||||
from_clause = infer_from_clause(new_table_names, schema, all_columns).strip()
|
||||
|
||||
sql = ' '.join([select_clause_str, from_clause, where_clause, group_by_clause, have_clause, sup_clause, order_clause,
|
||||
intersect_clause, union_clause, except_clause])
|
||||
|
||||
return sql
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--data_path', type=str, help='dataset path', required=True)
    arg_parser.add_argument('--input_path', type=str, help='predicted logical form', required=True)
    arg_parser.add_argument('--output_path', type=str, help='output data')
    args = arg_parser.parse_args()

    # loading dataSets
    datas, schemas = load_dataSets(args)
    alter_not_in(datas, schemas=schemas)
    alter_inter(datas)
    alter_column0(datas)

    index = range(len(datas))
    count = 0
    exception_count = 0
    with open(args.output_path, 'w', encoding='utf8') as d:
        for i in index:
            try:
                result = transform(datas[i], schemas[datas[i]['db_id']])
                d.write(result[0] + '\n')
                count += 1
            except Exception as e:
                result = transform(datas[i], schemas[datas[i]['db_id']],
                                   origin='Root1(3) Root(5) Sel(0) N(0) A(3) C(0) T(0)')
                exception_count += 1
                d.write(result[0] + '\n')
                count += 1
                print(e)
                print('Exception')
                print(traceback.format_exc())
                print('===\n\n')

    print(count, exception_count)
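For orientation, a minimal sketch (hypothetical file names) of driving this conversion step from Python instead of the command line, reusing the load_dataSets and transform helpers defined in this file; the fallback origin string is the same minimal single-column SemQL tree the script writes when a prediction cannot be converted.

```
from argparse import Namespace

# Hypothetical paths; load_dataSets and transform come from this file.
args = Namespace(data_path='./data',
                 input_path='predict_lf.json',
                 output_path='output.sql')
datas, schemas = load_dataSets(args)

for data in datas[:5]:
    try:
        sql = transform(data, schemas[data['db_id']])[0]
    except Exception:
        # same fallback as in the loop above
        sql = transform(data, schemas[data['db_id']],
                        origin='Root1(3) Root(5) Sel(0) N(0) A(3) C(0) T(0)')[0]
    print(sql)
```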

@@ -0,0 +1,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# -*- coding: utf-8 -*-
"""
# @Time : 2019/5/25
# @Author : Jiaqi&Zecheng
# @File : __init__.py
# @Software: PyCharm
"""

@@ -0,0 +1,82 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# -*- coding: utf-8 -*-
"""
# @Time : 2019/5/25
# @Author : Jiaqi&Zecheng
# @File : args.py
# @Software: PyCharm
"""

import random
import argparse
import torch
import numpy as np

def init_arg_parser():
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--seed', default=5783287, type=int, help='random seed')
    arg_parser.add_argument('--cuda', action='store_true', help='use gpu')
    arg_parser.add_argument('--lr_scheduler', action='store_true', help='use learning rate scheduler')
    arg_parser.add_argument('--lr_scheduler_gammar', default=0.5, type=float, help='decay rate of learning rate scheduler')
    arg_parser.add_argument('--column_pointer', action='store_true', help='use column pointer')
    arg_parser.add_argument('--loss_epoch_threshold', default=20, type=int, help='loss epoch threshold')
    arg_parser.add_argument('--sketch_loss_coefficient', default=0.2, type=float, help='sketch loss coefficient')
    arg_parser.add_argument('--sentence_features', action='store_true', help='use sentence features')
    arg_parser.add_argument('--model_name', choices=['transformer', 'rnn', 'table', 'sketch'], default='rnn',
                            help='model name')

    arg_parser.add_argument('--lstm', choices=['lstm', 'lstm_with_dropout', 'parent_feed'], default='lstm')

    arg_parser.add_argument('--load_model', default=None, type=str, help='load a pre-trained model')
    arg_parser.add_argument('--glove_embed_path', default="glove.42B.300d.txt", type=str)

    arg_parser.add_argument('--batch_size', default=64, type=int, help='batch size')
    arg_parser.add_argument('--beam_size', default=5, type=int, help='beam size for beam search')
    arg_parser.add_argument('--embed_size', default=300, type=int, help='size of word embeddings')
    arg_parser.add_argument('--col_embed_size', default=300, type=int, help='size of word embeddings')

    arg_parser.add_argument('--action_embed_size', default=128, type=int, help='size of word embeddings')
    arg_parser.add_argument('--type_embed_size', default=128, type=int, help='size of word embeddings')
    arg_parser.add_argument('--hidden_size', default=100, type=int, help='size of LSTM hidden states')
    arg_parser.add_argument('--att_vec_size', default=100, type=int, help='size of attentional vector')
    arg_parser.add_argument('--dropout', default=0.3, type=float, help='dropout rate')
    arg_parser.add_argument('--word_dropout', default=0.2, type=float, help='word dropout rate')

    # readout layer
    arg_parser.add_argument('--no_query_vec_to_action_map', default=False, action='store_true')
    arg_parser.add_argument('--readout', default='identity', choices=['identity', 'non_linear'])
    arg_parser.add_argument('--query_vec_to_action_diff_map', default=False, action='store_true')

    arg_parser.add_argument('--column_att', choices=['dot_prod', 'affine'], default='affine')

    arg_parser.add_argument('--decode_max_time_step', default=40, type=int, help='maximum number of time steps used '
                                                                                 'in decoding and sampling')

    arg_parser.add_argument('--save_to', default='model', type=str, help='save trained model to')
    arg_parser.add_argument('--toy', action='store_true',
                            help='If set, use small data; used for fast debugging.')
    arg_parser.add_argument('--clip_grad', default=5., type=float, help='clip gradients')
    arg_parser.add_argument('--max_epoch', default=-1, type=int, help='maximum number of training epoches')
    arg_parser.add_argument('--optimizer', default='Adam', type=str, help='optimizer')
    arg_parser.add_argument('--lr', default=0.001, type=float, help='learning rate')

    arg_parser.add_argument('--dataset', default="./data", type=str)

    arg_parser.add_argument('--epoch', default=50, type=int, help='Maximum Epoch')
    arg_parser.add_argument('--save', default='./', type=str,
                            help="Path to save the checkpoint and logs of epoch")

    return arg_parser

def init_config(arg_parser):
    args = arg_parser.parse_args()
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
    np.random.seed(int(args.seed * 13 / 7))
    random.seed(int(args.seed))
    return args
@@ -0,0 +1,206 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
# @Time : 2019/5/25
|
||||
# @Author : Jiaqi&Zecheng
|
||||
# @File : beam.py
|
||||
# @Software: PyCharm
|
||||
"""
|
||||
|
||||
import copy
|
||||
|
||||
from src.rule import semQL
|
||||
|
||||
|
||||
class ActionInfo(object):
|
||||
"""sufficient statistics for making a prediction of an action at a time step"""
|
||||
|
||||
def __init__(self, action=None):
|
||||
self.t = 0
|
||||
self.score = 0
|
||||
self.parent_t = -1
|
||||
self.action = action
|
||||
self.frontier_prod = None
|
||||
self.frontier_field = None
|
||||
|
||||
# for GenToken actions only
|
||||
self.copy_from_src = False
|
||||
self.src_token_position = -1
|
||||
|
||||
|
||||
class Beams(object):
|
||||
def __init__(self, is_sketch=False):
|
||||
self.actions = []
|
||||
self.action_infos = []
|
||||
self.inputs = []
|
||||
self.score = 0.
|
||||
self.t = 0
|
||||
self.is_sketch = is_sketch
|
||||
self.sketch_step = 0
|
||||
self.sketch_attention_history = list()
|
||||
|
||||
def get_availableClass(self):
|
||||
"""
|
||||
return the available action class
|
||||
:return:
|
||||
"""
|
||||
|
||||
# TODO: it could be update by speed
|
||||
# return the available class using rule
|
||||
# FIXME: now should change for these 11: "Filter 1 ROOT",
|
||||
def check_type(lists):
|
||||
for s in lists:
|
||||
if type(s) == int:
|
||||
return False
|
||||
return True
|
||||
|
||||
stack = [semQL.Root1]
|
||||
for action in self.actions:
|
||||
infer_action = action.get_next_action(is_sketch=self.is_sketch)
|
||||
infer_action.reverse()
|
||||
if stack[-1] is type(action):
|
||||
stack.pop()
|
||||
# check if they are non-terminal
|
||||
if check_type(infer_action):
|
||||
stack.extend(infer_action)
|
||||
else:
|
||||
raise RuntimeError("Not the right action")
|
||||
|
||||
result = stack[-1] if len(stack) > 0 else None
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def get_parent_action(cls, actions):
|
||||
"""
|
||||
|
||||
:param actions:
|
||||
:return:
|
||||
"""
|
||||
|
||||
def check_type(lists):
|
||||
for s in lists:
|
||||
if type(s) == int:
|
||||
return False
|
||||
return True
|
||||
|
||||
# check the origin state Root
|
||||
if len(actions) == 0:
|
||||
return None
|
||||
|
||||
stack = [semQL.Root1]
|
||||
for id_x, action in enumerate(actions):
|
||||
infer_action = action.get_next_action()
|
||||
for ac in infer_action:
|
||||
ac.parent = action
|
||||
ac.pt = id_x
|
||||
infer_action.reverse()
|
||||
if stack[-1] is type(action):
|
||||
stack.pop()
|
||||
# check if they are non-terminal
|
||||
if check_type(infer_action):
|
||||
stack.extend(infer_action)
|
||||
else:
|
||||
for t in actions:
|
||||
if type(t) != semQL.C:
|
||||
print(t, end="")
|
||||
print('asd')
|
||||
print(action)
|
||||
print(stack[-1])
|
||||
raise RuntimeError("Not the right action")
|
||||
result = stack[-1] if len(stack) > 0 else None
|
||||
|
||||
return result
|
||||
|
||||
def apply_action(self, action):
|
||||
# TODO: not finish implement yet
|
||||
self.t += 1
|
||||
self.actions.append(action)
|
||||
|
||||
def clone_and_apply_action(self, action):
|
||||
new_hyp = self.copy()
|
||||
new_hyp.apply_action(action)
|
||||
|
||||
return new_hyp
|
||||
|
||||
def clone_and_apply_action_info(self, action_info):
|
||||
action = action_info.action
|
||||
action.score = action_info.score
|
||||
new_hyp = self.clone_and_apply_action(action)
|
||||
new_hyp.action_infos.append(action_info)
|
||||
new_hyp.sketch_step = self.sketch_step
|
||||
new_hyp.sketch_attention_history = copy.copy(self.sketch_attention_history)
|
||||
|
||||
return new_hyp
|
||||
|
||||
def copy(self):
|
||||
new_hyp = Beams(is_sketch=self.is_sketch)
|
||||
# if self.tree:
|
||||
# new_hyp.tree = self.tree.copy()
|
||||
|
||||
new_hyp.actions = list(self.actions)
|
||||
new_hyp.score = self.score
|
||||
new_hyp.t = self.t
|
||||
new_hyp.sketch_step = self.sketch_step
|
||||
new_hyp.sketch_attention_history = copy.copy(self.sketch_attention_history)
|
||||
|
||||
return new_hyp
|
||||
|
||||
def infer_n(self):
|
||||
if len(self.actions) > 4:
|
||||
prev_action = self.actions[-3]
|
||||
if isinstance(prev_action, semQL.Filter):
|
||||
if prev_action.id_c > 11:
|
||||
# Nested Query, only select 1 column
|
||||
return ['N A']
|
||||
if self.actions[0].id_c != 3:
|
||||
return [self.actions[3].production]
|
||||
return semQL.N._init_grammar()
|
||||
|
||||
@property
|
||||
def completed(self):
|
||||
return True if self.get_availableClass() is None else False
|
||||
|
||||
@property
|
||||
def is_valid(self):
|
||||
actions = self.actions
|
||||
return self.check_sel_valid(actions)
|
||||
|
||||
def check_sel_valid(self, actions):
|
||||
find_sel = False
|
||||
sel_actions = list()
|
||||
for ac in actions:
|
||||
if type(ac) == semQL.Sel:
|
||||
find_sel = True
|
||||
elif find_sel and type(ac) in [semQL.N, semQL.T, semQL.C, semQL.A]:
|
||||
if type(ac) not in [semQL.N]:
|
||||
sel_actions.append(ac)
|
||||
elif find_sel and type(ac) not in [semQL.N, semQL.T, semQL.C, semQL.A]:
|
||||
break
|
||||
|
||||
if find_sel is False:
|
||||
return True
|
||||
|
||||
# not the complete sel lf
|
||||
if len(sel_actions) % 3 != 0:
|
||||
return True
|
||||
|
||||
sel_string = list()
|
||||
for i in range(len(sel_actions) // 3):
|
||||
if (sel_actions[i * 3 + 0].id_c, sel_actions[i * 3 + 1].id_c, sel_actions[i * 3 + 2].id_c) in sel_string:
|
||||
return False
|
||||
else:
|
||||
sel_string.append(
|
||||
(sel_actions[i * 3 + 0].id_c, sel_actions[i * 3 + 1].id_c, sel_actions[i * 3 + 2].id_c))
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test = Beams(is_sketch=True)
|
||||
# print(semQL.Root1(1).get_next_action())
|
||||
test.actions.append(semQL.Root1(3))
|
||||
test.actions.append(semQL.Root(5))
|
||||
|
||||
print(str(test.get_availableClass()))
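Beyond the smoke test above, a minimal usage sketch (assuming the bundled semQL grammar) of how decoding code extends a hypothesis without mutating the original:

```
from src.rule import semQL
from src.beam import Beams

hyp = Beams(is_sketch=True)
hyp = hyp.clone_and_apply_action(semQL.Root1(3))  # each call returns a fresh copy
hyp = hyp.clone_and_apply_action(semQL.Root(5))
print(hyp.get_availableClass())  # next grammar class the decoder has to expand
print(hyp.completed)             # True only once the grammar stack is empty
```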
|
|
@@ -0,0 +1,140 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
# @Time : 2019/5/25
|
||||
# @Author : Jiaqi&Zecheng
|
||||
# @File : utils.py
|
||||
# @Software: PyCharm
|
||||
"""
|
||||
|
||||
import copy
|
||||
|
||||
import src.rule.semQL as define_rule
|
||||
from src.models import nn_utils
|
||||
|
||||
|
||||
class Example:
|
||||
"""
|
||||
|
||||
"""
|
||||
def __init__(self, src_sent, tgt_actions=None, vis_seq=None, tab_cols=None, col_num=None, sql=None,
|
||||
one_hot_type=None, col_hot_type=None, schema_len=None, tab_ids=None,
|
||||
table_names=None, table_len=None, col_table_dict=None, cols=None,
|
||||
table_col_name=None, table_col_len=None,
|
||||
col_pred=None, tokenized_src_sent=None,
|
||||
):
|
||||
|
||||
self.src_sent = src_sent
|
||||
self.tokenized_src_sent = tokenized_src_sent
|
||||
self.vis_seq = vis_seq
|
||||
self.tab_cols = tab_cols
|
||||
self.col_num = col_num
|
||||
self.sql = sql
|
||||
self.one_hot_type=one_hot_type
|
||||
self.col_hot_type = col_hot_type
|
||||
self.schema_len = schema_len
|
||||
self.tab_ids = tab_ids
|
||||
self.table_names = table_names
|
||||
self.table_len = table_len
|
||||
self.col_table_dict = col_table_dict
|
||||
self.cols = cols
|
||||
self.table_col_name = table_col_name
|
||||
self.table_col_len = table_col_len
|
||||
self.col_pred = col_pred
|
||||
self.tgt_actions = tgt_actions
|
||||
self.truth_actions = copy.deepcopy(tgt_actions)
|
||||
|
||||
self.sketch = list()
|
||||
if self.truth_actions:
|
||||
for ta in self.truth_actions:
|
||||
if isinstance(ta, define_rule.C) or isinstance(ta, define_rule.T) or isinstance(ta, define_rule.A):
|
||||
continue
|
||||
self.sketch.append(ta)
|
||||
|
||||
|
||||
class cached_property(object):
|
||||
""" A property that is only computed once per instance and then replaces
|
||||
itself with an ordinary attribute. Deleting the attribute resets the
|
||||
property.
|
||||
|
||||
Source: https://github.com/bottlepy/bottle/commit/fa7733e075da0d790d809aa3d2f53071897e6f76
|
||||
"""
|
||||
|
||||
def __init__(self, func):
|
||||
self.__doc__ = getattr(func, '__doc__')
|
||||
self.func = func
|
||||
|
||||
def __get__(self, obj, cls):
|
||||
if obj is None:
|
||||
return self
|
||||
value = obj.__dict__[self.func.__name__] = self.func(obj)
|
||||
return value
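A small illustration of the descriptor above (hypothetical Lazy class): the wrapped method runs once per instance, after which the value stored in the instance dict shadows the property.

```
class Lazy:
    @cached_property
    def expensive(self):
        print('computed')  # printed only on the first access
        return 42

obj = Lazy()
print(obj.expensive)  # computed, then 42
print(obj.expensive)  # 42, served from obj.__dict__ without recomputation
```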
|
||||
|
||||
|
||||
class Batch(object):
|
||||
def __init__(self, examples, grammar, cuda=False):
|
||||
self.examples = examples
|
||||
|
||||
if examples[0].tgt_actions:
|
||||
self.max_action_num = max(len(e.tgt_actions) for e in self.examples)
|
||||
self.max_sketch_num = max(len(e.sketch) for e in self.examples)
|
||||
|
||||
self.src_sents = [e.src_sent for e in self.examples]
|
||||
self.src_sents_len = [len(e.src_sent) for e in self.examples]
|
||||
self.tokenized_src_sents = [e.tokenized_src_sent for e in self.examples]
|
||||
self.tokenized_src_sents_len = [len(e.tokenized_src_sent) for e in examples]
|
||||
self.src_sents_word = [e.src_sent for e in self.examples]
|
||||
self.table_sents_word = [[" ".join(x) for x in e.tab_cols] for e in self.examples]
|
||||
|
||||
self.schema_sents_word = [[" ".join(x) for x in e.table_names] for e in self.examples]
|
||||
|
||||
self.src_type = [e.one_hot_type for e in self.examples]
|
||||
self.col_hot_type = [e.col_hot_type for e in self.examples]
|
||||
self.table_sents = [e.tab_cols for e in self.examples]
|
||||
self.col_num = [e.col_num for e in self.examples]
|
||||
self.tab_ids = [e.tab_ids for e in self.examples]
|
||||
self.table_names = [e.table_names for e in self.examples]
|
||||
self.table_len = [e.table_len for e in examples]
|
||||
self.col_table_dict = [e.col_table_dict for e in examples]
|
||||
self.table_col_name = [e.table_col_name for e in examples]
|
||||
self.table_col_len = [e.table_col_len for e in examples]
|
||||
self.col_pred = [e.col_pred for e in examples]
|
||||
|
||||
self.grammar = grammar
|
||||
self.cuda = cuda
|
||||
|
||||
def __len__(self):
|
||||
return len(self.examples)
|
||||
|
||||
|
||||
def table_dict_mask(self, table_dict):
|
||||
return nn_utils.table_dict_to_mask_tensor(self.table_len, table_dict, cuda=self.cuda)
|
||||
|
||||
@cached_property
|
||||
def pred_col_mask(self):
|
||||
return nn_utils.pred_col_mask(self.col_pred, self.col_num)
|
||||
|
||||
@cached_property
|
||||
def schema_token_mask(self):
|
||||
return nn_utils.length_array_to_mask_tensor(self.table_len, cuda=self.cuda)
|
||||
|
||||
@cached_property
|
||||
def table_token_mask(self):
|
||||
return nn_utils.length_array_to_mask_tensor(self.col_num, cuda=self.cuda)
|
||||
|
||||
@cached_property
|
||||
def table_appear_mask(self):
|
||||
return nn_utils.appear_to_mask_tensor(self.col_num, cuda=self.cuda)
|
||||
|
||||
@cached_property
|
||||
def table_unk_mask(self):
|
||||
return nn_utils.length_array_to_mask_tensor(self.col_num, cuda=self.cuda, value=None)
|
||||
|
||||
@cached_property
|
||||
def src_token_mask(self):
|
||||
return nn_utils.length_array_to_mask_tensor(self.src_sents_len,
|
||||
cuda=self.cuda)
|
||||
|
||||
|
|
@@ -0,0 +1,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# -*- coding: utf-8 -*-
"""
# @Time : 2019/5/26
# @Author : Jiaqi&Zecheng
# @File : __init__.py.py
# @Software: PyCharm
"""
@@ -0,0 +1,160 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
# @Time : 2019/5/26
|
||||
# @Author : Jiaqi&Zecheng
|
||||
# @File : basic_model.py
|
||||
# @Software: PyCharm
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.nn.utils
|
||||
from torch.autograd import Variable
|
||||
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
|
||||
|
||||
from src.rule import semQL as define_rule
|
||||
|
||||
|
||||
class BasicModel(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(BasicModel, self).__init__()
|
||||
pass
|
||||
|
||||
def embedding_cosine(self, src_embedding, table_embedding, table_unk_mask):
|
||||
embedding_differ = []
|
||||
for i in range(table_embedding.size(1)):
|
||||
one_table_embedding = table_embedding[:, i, :]
|
||||
one_table_embedding = one_table_embedding.unsqueeze(1).expand(table_embedding.size(0),
|
||||
src_embedding.size(1),
|
||||
table_embedding.size(2))
|
||||
|
||||
topk_val = F.cosine_similarity(one_table_embedding, src_embedding, dim=-1)
|
||||
|
||||
embedding_differ.append(topk_val)
|
||||
embedding_differ = torch.stack(embedding_differ).transpose(1, 0)
|
||||
embedding_differ.data.masked_fill_(table_unk_mask.unsqueeze(2).expand(
|
||||
table_embedding.size(0),
|
||||
table_embedding.size(1),
|
||||
embedding_differ.size(2)
|
||||
), 0)
|
||||
|
||||
return embedding_differ
|
||||
|
||||
def encode(self, src_sents_var, src_sents_len, q_onehot_project=None):
|
||||
"""
|
||||
encode the source sequence
|
||||
:return:
|
||||
src_encodings: Variable(batch_size, src_sent_len, hidden_size * 2)
|
||||
last_state, last_cell: Variable(batch_size, hidden_size)
|
||||
"""
|
||||
src_token_embed = self.gen_x_batch(src_sents_var)
|
||||
|
||||
if q_onehot_project is not None:
|
||||
src_token_embed = torch.cat([src_token_embed, q_onehot_project], dim=-1)
|
||||
|
||||
packed_src_token_embed = pack_padded_sequence(src_token_embed, src_sents_len, batch_first=True)
|
||||
# src_encodings: (tgt_query_len, batch_size, hidden_size)
|
||||
src_encodings, (last_state, last_cell) = self.encoder_lstm(packed_src_token_embed)
|
||||
src_encodings, _ = pad_packed_sequence(src_encodings, batch_first=True)
|
||||
# src_encodings: (batch_size, tgt_query_len, hidden_size)
|
||||
# src_encodings = src_encodings.permute(1, 0, 2)
|
||||
# (batch_size, hidden_size * 2)
|
||||
last_state = torch.cat([last_state[0], last_state[1]], -1)
|
||||
last_cell = torch.cat([last_cell[0], last_cell[1]], -1)
|
||||
|
||||
return src_encodings, (last_state, last_cell)
|
||||
|
||||
def input_type(self, values_list):
|
||||
B = len(values_list)
|
||||
val_len = []
|
||||
for value in values_list:
|
||||
val_len.append(len(value))
|
||||
max_len = max(val_len)
|
||||
# for the Begin and End
|
||||
val_emb_array = np.zeros((B, max_len, values_list[0].shape[1]), dtype=np.float32)
|
||||
for i in range(B):
|
||||
val_emb_array[i, :val_len[i], :] = values_list[i][:, :]
|
||||
|
||||
val_inp = torch.from_numpy(val_emb_array)
|
||||
if self.args.cuda:
|
||||
val_inp = val_inp.cuda()
|
||||
val_inp_var = Variable(val_inp)
|
||||
return val_inp_var
|
||||
|
||||
def padding_sketch(self, sketch):
|
||||
padding_result = []
|
||||
for action in sketch:
|
||||
padding_result.append(action)
|
||||
if type(action) == define_rule.N:
|
||||
for _ in range(action.id_c + 1):
|
||||
padding_result.append(define_rule.A(0))
|
||||
padding_result.append(define_rule.C(0))
|
||||
padding_result.append(define_rule.T(0))
|
||||
elif type(action) == define_rule.Filter and 'A' in action.production:
|
||||
padding_result.append(define_rule.A(0))
|
||||
padding_result.append(define_rule.C(0))
|
||||
padding_result.append(define_rule.T(0))
|
||||
elif type(action) == define_rule.Order or type(action) == define_rule.Sup:
|
||||
padding_result.append(define_rule.A(0))
|
||||
padding_result.append(define_rule.C(0))
|
||||
padding_result.append(define_rule.T(0))
|
||||
|
||||
return padding_result
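A minimal sketch (assuming the semQL rules imported above) of what padding_sketch produces: placeholder A, C and T slots are appended after the sketch actions that need them, so the second decoding stage only has to fill in columns and tables.

```
from src.rule import semQL as define_rule
from src.models.basic_model import BasicModel

sketch = [define_rule.Root1(3), define_rule.Root(5),
          define_rule.Sel(0), define_rule.N(0)]
padded = BasicModel().padding_sketch(sketch)
print([type(a).__name__ for a in padded])
# ['Root1', 'Root', 'Sel', 'N', 'A', 'C', 'T']
```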
|
||||
|
||||
def gen_x_batch(self, q):
|
||||
B = len(q)
|
||||
val_embs = []
|
||||
val_len = np.zeros(B, dtype=np.int64)
|
||||
is_list = False
|
||||
if type(q[0][0]) == list:
|
||||
is_list = True
|
||||
for i, one_q in enumerate(q):
|
||||
if not is_list:
|
||||
q_val = list(
|
||||
map(lambda x: self.word_emb.get(x, np.zeros(self.args.col_embed_size, dtype=np.float32)), one_q))
|
||||
else:
|
||||
q_val = []
|
||||
for ws in one_q:
|
||||
emb_list = []
|
||||
ws_len = len(ws)
|
||||
for w in ws:
|
||||
emb_list.append(self.word_emb.get(w, self.word_emb['unk']))
|
||||
if ws_len == 0:
|
||||
raise Exception("word list should not be empty!")
|
||||
elif ws_len == 1:
|
||||
q_val.append(emb_list[0])
|
||||
else:
|
||||
q_val.append(sum(emb_list) / float(ws_len))
|
||||
|
||||
val_embs.append(q_val)
|
||||
val_len[i] = len(q_val)
|
||||
max_len = max(val_len)
|
||||
|
||||
val_emb_array = np.zeros((B, max_len, self.args.col_embed_size), dtype=np.float32)
|
||||
for i in range(B):
|
||||
for t in range(len(val_embs[i])):
|
||||
val_emb_array[i, t, :] = val_embs[i][t]
|
||||
val_inp = torch.from_numpy(val_emb_array)
|
||||
if self.args.cuda:
|
||||
val_inp = val_inp.cuda()
|
||||
return val_inp
|
||||
|
||||
def save(self, path):
|
||||
dir_name = os.path.dirname(path)
|
||||
if not os.path.exists(dir_name):
|
||||
os.makedirs(dir_name)
|
||||
|
||||
params = {
|
||||
'args': self.args,
|
||||
'vocab': self.vocab,
|
||||
'grammar': self.grammar,
|
||||
'state_dict': self.state_dict()
|
||||
}
|
||||
torch.save(params, path)
|
|
@@ -0,0 +1,764 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
# @Time : 2019/5/25
|
||||
# @Author : Jiaqi&Zecheng
|
||||
# @File : model.py
|
||||
# @Software: PyCharm
|
||||
"""
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.nn.utils
|
||||
from torch.autograd import Variable
|
||||
|
||||
from src.beam import Beams, ActionInfo
|
||||
from src.dataset import Batch
|
||||
from src.models import nn_utils
|
||||
from src.models.basic_model import BasicModel
|
||||
from src.models.pointer_net import PointerNet
|
||||
from src.rule import semQL as define_rule
|
||||
|
||||
|
||||
class IRNet(BasicModel):
|
||||
|
||||
def __init__(self, args, grammar):
|
||||
super(IRNet, self).__init__()
|
||||
self.args = args
|
||||
self.grammar = grammar
|
||||
self.use_column_pointer = args.column_pointer
|
||||
self.use_sentence_features = args.sentence_features
|
||||
|
||||
if args.cuda:
|
||||
self.new_long_tensor = torch.cuda.LongTensor
|
||||
self.new_tensor = torch.cuda.FloatTensor
|
||||
else:
|
||||
self.new_long_tensor = torch.LongTensor
|
||||
self.new_tensor = torch.FloatTensor
|
||||
|
||||
self.encoder_lstm = nn.LSTM(args.embed_size, args.hidden_size // 2, bidirectional=True,
|
||||
batch_first=True)
|
||||
|
||||
input_dim = args.action_embed_size + \
|
||||
args.att_vec_size + \
|
||||
args.type_embed_size
|
||||
# previous action
|
||||
# input feeding
|
||||
# pre type embedding
|
||||
|
||||
self.lf_decoder_lstm = nn.LSTMCell(input_dim, args.hidden_size)
|
||||
|
||||
self.sketch_decoder_lstm = nn.LSTMCell(input_dim, args.hidden_size)
|
||||
|
||||
# initialize the decoder's state and cells with encoder hidden states
|
||||
self.decoder_cell_init = nn.Linear(args.hidden_size, args.hidden_size)
|
||||
|
||||
self.att_sketch_linear = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
|
||||
self.att_lf_linear = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
|
||||
|
||||
self.sketch_att_vec_linear = nn.Linear(args.hidden_size + args.hidden_size, args.att_vec_size, bias=False)
|
||||
self.lf_att_vec_linear = nn.Linear(args.hidden_size + args.hidden_size, args.att_vec_size, bias=False)
|
||||
|
||||
self.prob_att = nn.Linear(args.att_vec_size, 1)
|
||||
self.prob_len = nn.Linear(1, 1)
|
||||
|
||||
self.col_type = nn.Linear(4, args.col_embed_size)
|
||||
self.sketch_encoder = nn.LSTM(args.action_embed_size, args.action_embed_size // 2, bidirectional=True,
|
||||
batch_first=True)
|
||||
|
||||
self.production_embed = nn.Embedding(len(grammar.prod2id), args.action_embed_size)
|
||||
self.type_embed = nn.Embedding(len(grammar.type2id), args.type_embed_size)
|
||||
self.production_readout_b = nn.Parameter(torch.FloatTensor(len(grammar.prod2id)).zero_())
|
||||
|
||||
self.att_project = nn.Linear(args.hidden_size + args.type_embed_size, args.hidden_size)
|
||||
|
||||
self.N_embed = nn.Embedding(len(define_rule.N._init_grammar()), args.action_embed_size)
|
||||
|
||||
self.read_out_act = F.tanh if args.readout == 'non_linear' else nn_utils.identity
|
||||
|
||||
self.query_vec_to_action_embed = nn.Linear(args.att_vec_size, args.action_embed_size,
|
||||
bias=args.readout == 'non_linear')
|
||||
|
||||
self.production_readout = lambda q: F.linear(self.read_out_act(self.query_vec_to_action_embed(q)),
|
||||
self.production_embed.weight, self.production_readout_b)
|
||||
|
||||
self.q_att = nn.Linear(args.hidden_size, args.embed_size)
|
||||
|
||||
self.column_rnn_input = nn.Linear(args.col_embed_size, args.action_embed_size, bias=False)
|
||||
self.table_rnn_input = nn.Linear(args.col_embed_size, args.action_embed_size, bias=False)
|
||||
|
||||
self.dropout = nn.Dropout(args.dropout)
|
||||
|
||||
self.column_pointer_net = PointerNet(args.hidden_size, args.col_embed_size, attention_type=args.column_att)
|
||||
|
||||
self.table_pointer_net = PointerNet(args.hidden_size, args.col_embed_size, attention_type=args.column_att)
|
||||
|
||||
# initialize the embedding layers
|
||||
nn.init.xavier_normal_(self.production_embed.weight.data)
|
||||
nn.init.xavier_normal_(self.type_embed.weight.data)
|
||||
nn.init.xavier_normal_(self.N_embed.weight.data)
|
||||
print('Use Column Pointer: ', True if self.use_column_pointer else False)
|
||||
|
||||
def forward(self, examples):
|
||||
args = self.args
|
||||
# now should implement the examples
|
||||
batch = Batch(examples, self.grammar, cuda=self.args.cuda)
|
||||
|
||||
table_appear_mask = batch.table_appear_mask
|
||||
|
||||
|
||||
src_encodings, (last_state, last_cell) = self.encode(batch.src_sents, batch.src_sents_len, None)
|
||||
|
||||
src_encodings = self.dropout(src_encodings)
|
||||
|
||||
utterance_encodings_sketch_linear = self.att_sketch_linear(src_encodings)
|
||||
utterance_encodings_lf_linear = self.att_lf_linear(src_encodings)
|
||||
|
||||
dec_init_vec = self.init_decoder_state(last_cell)
|
||||
h_tm1 = dec_init_vec
|
||||
action_probs = [[] for _ in examples]
|
||||
|
||||
zero_action_embed = Variable(self.new_tensor(args.action_embed_size).zero_())
|
||||
zero_type_embed = Variable(self.new_tensor(args.type_embed_size).zero_())
|
||||
|
||||
sketch_attention_history = list()
|
||||
|
||||
for t in range(batch.max_sketch_num):
|
||||
if t == 0:
|
||||
x = Variable(self.new_tensor(len(batch), self.sketch_decoder_lstm.input_size).zero_(),
|
||||
requires_grad=False)
|
||||
else:
|
||||
a_tm1_embeds = []
|
||||
pre_types = []
|
||||
for e_id, example in enumerate(examples):
|
||||
|
||||
if t < len(example.sketch):
|
||||
# get the last action
|
||||
# This is the action embedding
|
||||
action_tm1 = example.sketch[t - 1]
|
||||
if type(action_tm1) in [define_rule.Root1,
|
||||
define_rule.Root,
|
||||
define_rule.Sel,
|
||||
define_rule.Filter,
|
||||
define_rule.Sup,
|
||||
define_rule.N,
|
||||
define_rule.Order]:
|
||||
a_tm1_embed = self.production_embed.weight[self.grammar.prod2id[action_tm1.production]]
|
||||
else:
|
||||
print(action_tm1, 'only for sketch')
|
||||
quit()
|
||||
a_tm1_embed = zero_action_embed
|
||||
pass
|
||||
else:
|
||||
a_tm1_embed = zero_action_embed
|
||||
|
||||
a_tm1_embeds.append(a_tm1_embed)
|
||||
|
||||
a_tm1_embeds = torch.stack(a_tm1_embeds)
|
||||
inputs = [a_tm1_embeds]
|
||||
|
||||
for e_id, example in enumerate(examples):
|
||||
if t < len(example.sketch):
|
||||
action_tm = example.sketch[t - 1]
|
||||
pre_type = self.type_embed.weight[self.grammar.type2id[type(action_tm)]]
|
||||
else:
|
||||
pre_type = zero_type_embed
|
||||
pre_types.append(pre_type)
|
||||
|
||||
pre_types = torch.stack(pre_types)
|
||||
|
||||
inputs.append(att_tm1)
|
||||
inputs.append(pre_types)
|
||||
x = torch.cat(inputs, dim=-1)
|
||||
|
||||
src_mask = batch.src_token_mask
|
||||
|
||||
(h_t, cell_t), att_t, aw = self.step(x, h_tm1, src_encodings,
|
||||
utterance_encodings_sketch_linear, self.sketch_decoder_lstm,
|
||||
self.sketch_att_vec_linear,
|
||||
src_token_mask=src_mask, return_att_weight=True)
|
||||
sketch_attention_history.append(att_t)
|
||||
|
||||
# get the Root possibility
|
||||
apply_rule_prob = F.softmax(self.production_readout(att_t), dim=-1)
|
||||
|
||||
for e_id, example in enumerate(examples):
|
||||
if t < len(example.sketch):
|
||||
action_t = example.sketch[t]
|
||||
act_prob_t_i = apply_rule_prob[e_id, self.grammar.prod2id[action_t.production]]
|
||||
action_probs[e_id].append(act_prob_t_i)
|
||||
|
||||
h_tm1 = (h_t, cell_t)
|
||||
att_tm1 = att_t
|
||||
|
||||
sketch_prob_var = torch.stack(
|
||||
[torch.stack(action_probs_i, dim=0).log().sum() for action_probs_i in action_probs], dim=0)
|
||||
|
||||
table_embedding = self.gen_x_batch(batch.table_sents)
|
||||
src_embedding = self.gen_x_batch(batch.src_sents)
|
||||
schema_embedding = self.gen_x_batch(batch.table_names)
|
||||
|
||||
# get emb differ
|
||||
embedding_differ = self.embedding_cosine(src_embedding=src_embedding, table_embedding=table_embedding,
|
||||
table_unk_mask=batch.table_unk_mask)
|
||||
|
||||
schema_differ = self.embedding_cosine(src_embedding=src_embedding, table_embedding=schema_embedding,
|
||||
table_unk_mask=batch.schema_token_mask)
|
||||
|
||||
tab_ctx = (src_encodings.unsqueeze(1) * embedding_differ.unsqueeze(3)).sum(2)
|
||||
schema_ctx = (src_encodings.unsqueeze(1) * schema_differ.unsqueeze(3)).sum(2)
|
||||
|
||||
|
||||
table_embedding = table_embedding + tab_ctx
|
||||
|
||||
schema_embedding = schema_embedding + schema_ctx
|
||||
|
||||
col_type = self.input_type(batch.col_hot_type)
|
||||
|
||||
col_type_var = self.col_type(col_type)
|
||||
|
||||
table_embedding = table_embedding + col_type_var
|
||||
|
||||
batch_table_dict = batch.col_table_dict
|
||||
table_enable = np.zeros(shape=(len(examples)))
|
||||
action_probs = [[] for _ in examples]
|
||||
|
||||
h_tm1 = dec_init_vec
|
||||
|
||||
for t in range(batch.max_action_num):
|
||||
if t == 0:
|
||||
# x = self.lf_begin_vec.unsqueeze(0).repeat(len(batch), 1)
|
||||
x = Variable(self.new_tensor(len(batch), self.lf_decoder_lstm.input_size).zero_(), requires_grad=False)
|
||||
else:
|
||||
a_tm1_embeds = []
|
||||
pre_types = []
|
||||
|
||||
for e_id, example in enumerate(examples):
|
||||
if t < len(example.tgt_actions):
|
||||
action_tm1 = example.tgt_actions[t - 1]
|
||||
if type(action_tm1) in [define_rule.Root1,
|
||||
define_rule.Root,
|
||||
define_rule.Sel,
|
||||
define_rule.Filter,
|
||||
define_rule.Sup,
|
||||
define_rule.N,
|
||||
define_rule.Order,
|
||||
]:
|
||||
|
||||
a_tm1_embed = self.production_embed.weight[self.grammar.prod2id[action_tm1.production]]
|
||||
|
||||
else:
|
||||
if isinstance(action_tm1, define_rule.C):
|
||||
a_tm1_embed = self.column_rnn_input(table_embedding[e_id, action_tm1.id_c])
|
||||
elif isinstance(action_tm1, define_rule.T):
|
||||
a_tm1_embed = self.column_rnn_input(schema_embedding[e_id, action_tm1.id_c])
|
||||
elif isinstance(action_tm1, define_rule.A):
|
||||
a_tm1_embed = self.production_embed.weight[self.grammar.prod2id[action_tm1.production]]
|
||||
else:
|
||||
print(action_tm1, 'not implement')
|
||||
quit()
|
||||
a_tm1_embed = zero_action_embed
|
||||
pass
|
||||
|
||||
else:
|
||||
a_tm1_embed = zero_action_embed
|
||||
a_tm1_embeds.append(a_tm1_embed)
|
||||
|
||||
a_tm1_embeds = torch.stack(a_tm1_embeds)
|
||||
|
||||
inputs = [a_tm1_embeds]
|
||||
|
||||
# tgt t-1 action type
|
||||
for e_id, example in enumerate(examples):
|
||||
if t < len(example.tgt_actions):
|
||||
action_tm = example.tgt_actions[t - 1]
|
||||
pre_type = self.type_embed.weight[self.grammar.type2id[type(action_tm)]]
|
||||
else:
|
||||
pre_type = zero_type_embed
|
||||
pre_types.append(pre_type)
|
||||
|
||||
pre_types = torch.stack(pre_types)
|
||||
|
||||
inputs.append(att_tm1)
|
||||
|
||||
inputs.append(pre_types)
|
||||
|
||||
x = torch.cat(inputs, dim=-1)
|
||||
|
||||
src_mask = batch.src_token_mask
|
||||
|
||||
(h_t, cell_t), att_t, aw = self.step(x, h_tm1, src_encodings,
|
||||
utterance_encodings_lf_linear, self.lf_decoder_lstm,
|
||||
self.lf_att_vec_linear,
|
||||
src_token_mask=src_mask, return_att_weight=True)
|
||||
|
||||
apply_rule_prob = F.softmax(self.production_readout(att_t), dim=-1)
|
||||
table_appear_mask_val = torch.from_numpy(table_appear_mask)
|
||||
if self.args.cuda:
|
||||
table_appear_mask_val = table_appear_mask_val.cuda()
|
||||
|
||||
if self.use_column_pointer:
|
||||
gate = F.sigmoid(self.prob_att(att_t))
|
||||
weights = self.column_pointer_net(src_encodings=table_embedding, query_vec=att_t.unsqueeze(0),
|
||||
src_token_mask=None) * table_appear_mask_val * gate + self.column_pointer_net(
|
||||
src_encodings=table_embedding, query_vec=att_t.unsqueeze(0),
|
||||
src_token_mask=None) * (1 - table_appear_mask_val) * (1 - gate)
|
||||
else:
|
||||
weights = self.column_pointer_net(src_encodings=table_embedding, query_vec=att_t.unsqueeze(0),
|
||||
src_token_mask=batch.table_token_mask)
|
||||
|
||||
weights.data.masked_fill_(batch.table_token_mask, -float('inf'))
|
||||
|
||||
column_attention_weights = F.softmax(weights, dim=-1)
|
||||
|
||||
table_weights = self.table_pointer_net(src_encodings=schema_embedding, query_vec=att_t.unsqueeze(0),
|
||||
src_token_mask=None)
|
||||
|
||||
schema_token_mask = batch.schema_token_mask.expand_as(table_weights)
|
||||
table_weights.data.masked_fill_(schema_token_mask, -float('inf'))
|
||||
table_dict = [batch_table_dict[x_id][int(x)] for x_id, x in enumerate(table_enable.tolist())]
|
||||
table_mask = batch.table_dict_mask(table_dict)
|
||||
table_weights.data.masked_fill_(table_mask, -float('inf'))
|
||||
|
||||
table_weights = F.softmax(table_weights, dim=-1)
|
||||
# now get the loss
|
||||
for e_id, example in enumerate(examples):
|
||||
if t < len(example.tgt_actions):
|
||||
action_t = example.tgt_actions[t]
|
||||
if isinstance(action_t, define_rule.C):
|
||||
table_appear_mask[e_id, action_t.id_c] = 1
|
||||
table_enable[e_id] = action_t.id_c
|
||||
act_prob_t_i = column_attention_weights[e_id, action_t.id_c]
|
||||
action_probs[e_id].append(act_prob_t_i)
|
||||
elif isinstance(action_t, define_rule.T):
|
||||
act_prob_t_i = table_weights[e_id, action_t.id_c]
|
||||
action_probs[e_id].append(act_prob_t_i)
|
||||
elif isinstance(action_t, define_rule.A):
|
||||
act_prob_t_i = apply_rule_prob[e_id, self.grammar.prod2id[action_t.production]]
|
||||
action_probs[e_id].append(act_prob_t_i)
|
||||
else:
|
||||
pass
|
||||
h_tm1 = (h_t, cell_t)
|
||||
att_tm1 = att_t
|
||||
lf_prob_var = torch.stack(
|
||||
[torch.stack(action_probs_i, dim=0).log().sum() for action_probs_i in action_probs], dim=0)
|
||||
|
||||
return [sketch_prob_var, lf_prob_var]
|
||||
|
||||
def parse(self, examples, beam_size=5):
|
||||
"""
|
||||
one example a time
|
||||
:param examples:
|
||||
:param beam_size:
|
||||
:return:
|
||||
"""
|
||||
batch = Batch([examples], self.grammar, cuda=self.args.cuda)
|
||||
|
||||
src_encodings, (last_state, last_cell) = self.encode(batch.src_sents, batch.src_sents_len, None)
|
||||
src_encodings = self.dropout(src_encodings)
|
||||
|
||||
utterance_encodings_sketch_linear = self.att_sketch_linear(src_encodings)
|
||||
utterance_encodings_lf_linear = self.att_lf_linear(src_encodings)
|
||||
|
||||
dec_init_vec = self.init_decoder_state(last_cell)
|
||||
h_tm1 = dec_init_vec
|
||||
|
||||
t = 0
|
||||
beams = [Beams(is_sketch=True)]
|
||||
completed_beams = []
|
||||
|
||||
while len(completed_beams) < beam_size and t < self.args.decode_max_time_step:
|
||||
hyp_num = len(beams)
|
||||
exp_src_enconding = src_encodings.expand(hyp_num, src_encodings.size(1),
|
||||
src_encodings.size(2))
|
||||
exp_src_encodings_sketch_linear = utterance_encodings_sketch_linear.expand(hyp_num,
|
||||
utterance_encodings_sketch_linear.size(
|
||||
1),
|
||||
utterance_encodings_sketch_linear.size(
|
||||
2))
|
||||
if t == 0:
|
||||
with torch.no_grad():
|
||||
x = Variable(self.new_tensor(1, self.sketch_decoder_lstm.input_size).zero_())
|
||||
else:
|
||||
a_tm1_embeds = []
|
||||
pre_types = []
|
||||
for e_id, hyp in enumerate(beams):
|
||||
action_tm1 = hyp.actions[-1]
|
||||
if type(action_tm1) in [define_rule.Root1,
|
||||
define_rule.Root,
|
||||
define_rule.Sel,
|
||||
define_rule.Filter,
|
||||
define_rule.Sup,
|
||||
define_rule.N,
|
||||
define_rule.Order]:
|
||||
a_tm1_embed = self.production_embed.weight[self.grammar.prod2id[action_tm1.production]]
|
||||
else:
|
||||
raise ValueError('unknown action %s' % action_tm1)
|
||||
|
||||
a_tm1_embeds.append(a_tm1_embed)
|
||||
a_tm1_embeds = torch.stack(a_tm1_embeds)
|
||||
inputs = [a_tm1_embeds]
|
||||
|
||||
for e_id, hyp in enumerate(beams):
|
||||
action_tm = hyp.actions[-1]
|
||||
pre_type = self.type_embed.weight[self.grammar.type2id[type(action_tm)]]
|
||||
pre_types.append(pre_type)
|
||||
|
||||
pre_types = torch.stack(pre_types)
|
||||
|
||||
inputs.append(att_tm1)
|
||||
inputs.append(pre_types)
|
||||
x = torch.cat(inputs, dim=-1)
|
||||
|
||||
(h_t, cell_t), att_t = self.step(x, h_tm1, exp_src_enconding,
|
||||
exp_src_encodings_sketch_linear, self.sketch_decoder_lstm,
|
||||
self.sketch_att_vec_linear,
|
||||
src_token_mask=None)
|
||||
|
||||
apply_rule_log_prob = F.log_softmax(self.production_readout(att_t), dim=-1)
|
||||
|
||||
new_hyp_meta = []
|
||||
for hyp_id, hyp in enumerate(beams):
|
||||
action_class = hyp.get_availableClass()
|
||||
if action_class in [define_rule.Root1,
|
||||
define_rule.Root,
|
||||
define_rule.Sel,
|
||||
define_rule.Filter,
|
||||
define_rule.Sup,
|
||||
define_rule.N,
|
||||
define_rule.Order]:
|
||||
possible_productions = self.grammar.get_production(action_class)
|
||||
for possible_production in possible_productions:
|
||||
prod_id = self.grammar.prod2id[possible_production]
|
||||
prod_score = apply_rule_log_prob[hyp_id, prod_id]
|
||||
new_hyp_score = hyp.score + prod_score.data.cpu()
|
||||
meta_entry = {'action_type': action_class, 'prod_id': prod_id,
|
||||
'score': prod_score, 'new_hyp_score': new_hyp_score,
|
||||
'prev_hyp_id': hyp_id}
|
||||
new_hyp_meta.append(meta_entry)
|
||||
else:
|
||||
raise RuntimeError('No right action class')
|
||||
|
||||
if not new_hyp_meta: break
|
||||
|
||||
new_hyp_scores = torch.stack([x['new_hyp_score'] for x in new_hyp_meta], dim=0)
|
||||
top_new_hyp_scores, meta_ids = torch.topk(new_hyp_scores,
|
||||
k=min(new_hyp_scores.size(0),
|
||||
beam_size - len(completed_beams)))
|
||||
|
||||
live_hyp_ids = []
|
||||
new_beams = []
|
||||
for new_hyp_score, meta_id in zip(top_new_hyp_scores.data.cpu(), meta_ids.data.cpu()):
|
||||
action_info = ActionInfo()
|
||||
hyp_meta_entry = new_hyp_meta[meta_id]
|
||||
prev_hyp_id = hyp_meta_entry['prev_hyp_id']
|
||||
prev_hyp = beams[prev_hyp_id]
|
||||
action_type_str = hyp_meta_entry['action_type']
|
||||
prod_id = hyp_meta_entry['prod_id']
|
||||
if prod_id < len(self.grammar.id2prod):
|
||||
production = self.grammar.id2prod[prod_id]
|
||||
action = action_type_str(list(action_type_str._init_grammar()).index(production))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
action_info.action = action
|
||||
action_info.t = t
|
||||
action_info.score = hyp_meta_entry['score']
|
||||
new_hyp = prev_hyp.clone_and_apply_action_info(action_info)
|
||||
new_hyp.score = new_hyp_score
|
||||
new_hyp.inputs.extend(prev_hyp.inputs)
|
||||
|
||||
if new_hyp.is_valid is False:
|
||||
continue
|
||||
|
||||
if new_hyp.completed:
|
||||
completed_beams.append(new_hyp)
|
||||
else:
|
||||
new_beams.append(new_hyp)
|
||||
live_hyp_ids.append(prev_hyp_id)
|
||||
|
||||
if live_hyp_ids:
|
||||
h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids])
|
||||
att_tm1 = att_t[live_hyp_ids]
|
||||
beams = new_beams
|
||||
t += 1
|
||||
else:
|
||||
break
|
||||
|
||||
# now get the sketch result
|
||||
completed_beams.sort(key=lambda hyp: -hyp.score)
|
||||
if len(completed_beams) == 0:
|
||||
return [[], []]
|
||||
|
||||
sketch_actions = completed_beams[0].actions
|
||||
# sketch_actions = examples.sketch
|
||||
|
||||
padding_sketch = self.padding_sketch(sketch_actions)
|
||||
|
||||
table_embedding = self.gen_x_batch(batch.table_sents)
|
||||
src_embedding = self.gen_x_batch(batch.src_sents)
|
||||
schema_embedding = self.gen_x_batch(batch.table_names)
|
||||
|
||||
# get emb differ
|
||||
embedding_differ = self.embedding_cosine(src_embedding=src_embedding, table_embedding=table_embedding,
|
||||
table_unk_mask=batch.table_unk_mask)
|
||||
|
||||
schema_differ = self.embedding_cosine(src_embedding=src_embedding, table_embedding=schema_embedding,
|
||||
table_unk_mask=batch.schema_token_mask)
|
||||
|
||||
tab_ctx = (src_encodings.unsqueeze(1) * embedding_differ.unsqueeze(3)).sum(2)
|
||||
schema_ctx = (src_encodings.unsqueeze(1) * schema_differ.unsqueeze(3)).sum(2)
|
||||
|
||||
table_embedding = table_embedding + tab_ctx
|
||||
|
||||
schema_embedding = schema_embedding + schema_ctx
|
||||
|
||||
col_type = self.input_type(batch.col_hot_type)
|
||||
|
||||
col_type_var = self.col_type(col_type)
|
||||
|
||||
table_embedding = table_embedding + col_type_var
|
||||
|
||||
batch_table_dict = batch.col_table_dict
|
||||
|
||||
h_tm1 = dec_init_vec
|
||||
|
||||
t = 0
|
||||
beams = [Beams(is_sketch=False)]
|
||||
completed_beams = []
|
||||
|
||||
while len(completed_beams) < beam_size and t < self.args.decode_max_time_step:
|
||||
hyp_num = len(beams)
|
||||
|
||||
# expand value
|
||||
exp_src_encodings = src_encodings.expand(hyp_num, src_encodings.size(1),
|
||||
src_encodings.size(2))
|
||||
exp_utterance_encodings_lf_linear = utterance_encodings_lf_linear.expand(hyp_num,
|
||||
utterance_encodings_lf_linear.size(
|
||||
1),
|
||||
utterance_encodings_lf_linear.size(
|
||||
2))
|
||||
exp_table_embedding = table_embedding.expand(hyp_num, table_embedding.size(1),
|
||||
table_embedding.size(2))
|
||||
|
||||
exp_schema_embedding = schema_embedding.expand(hyp_num, schema_embedding.size(1),
|
||||
schema_embedding.size(2))
|
||||
|
||||
|
||||
table_appear_mask = batch.table_appear_mask
|
||||
table_appear_mask = np.zeros((hyp_num, table_appear_mask.shape[1]), dtype=np.float32)
|
||||
table_enable = np.zeros(shape=(hyp_num))
|
||||
for e_id, hyp in enumerate(beams):
|
||||
for act in hyp.actions:
|
||||
if type(act) == define_rule.C:
|
||||
table_appear_mask[e_id][act.id_c] = 1
|
||||
table_enable[e_id] = act.id_c
|
||||
|
||||
if t == 0:
|
||||
with torch.no_grad():
|
||||
x = Variable(self.new_tensor(1, self.lf_decoder_lstm.input_size).zero_())
|
||||
else:
|
||||
a_tm1_embeds = []
|
||||
pre_types = []
|
||||
for e_id, hyp in enumerate(beams):
|
||||
action_tm1 = hyp.actions[-1]
|
||||
if type(action_tm1) in [define_rule.Root1,
|
||||
define_rule.Root,
|
||||
define_rule.Sel,
|
||||
define_rule.Filter,
|
||||
define_rule.Sup,
|
||||
define_rule.N,
|
||||
define_rule.Order]:
|
||||
|
||||
a_tm1_embed = self.production_embed.weight[self.grammar.prod2id[action_tm1.production]]
|
||||
hyp.sketch_step += 1
|
||||
elif isinstance(action_tm1, define_rule.C):
|
||||
a_tm1_embed = self.column_rnn_input(table_embedding[0, action_tm1.id_c])
|
||||
elif isinstance(action_tm1, define_rule.T):
|
||||
a_tm1_embed = self.column_rnn_input(schema_embedding[0, action_tm1.id_c])
|
||||
elif isinstance(action_tm1, define_rule.A):
|
||||
a_tm1_embed = self.production_embed.weight[self.grammar.prod2id[action_tm1.production]]
|
||||
else:
|
||||
raise ValueError('unknown action %s' % action_tm1)
|
||||
|
||||
a_tm1_embeds.append(a_tm1_embed)
|
||||
|
||||
a_tm1_embeds = torch.stack(a_tm1_embeds)
|
||||
|
||||
inputs = [a_tm1_embeds]
|
||||
|
||||
for e_id, hyp in enumerate(beams):
|
||||
action_tm = hyp.actions[-1]
|
||||
pre_type = self.type_embed.weight[self.grammar.type2id[type(action_tm)]]
|
||||
pre_types.append(pre_type)
|
||||
|
||||
pre_types = torch.stack(pre_types)
|
||||
|
||||
inputs.append(att_tm1)
|
||||
inputs.append(pre_types)
|
||||
x = torch.cat(inputs, dim=-1)
|
||||
|
||||
(h_t, cell_t), att_t = self.step(x, h_tm1, exp_src_encodings,
|
||||
exp_utterance_encodings_lf_linear, self.lf_decoder_lstm,
|
||||
self.lf_att_vec_linear,
|
||||
src_token_mask=None)
|
||||
|
||||
apply_rule_log_prob = F.log_softmax(self.production_readout(att_t), dim=-1)
|
||||
|
||||
table_appear_mask_val = torch.from_numpy(table_appear_mask)
|
||||
|
||||
if self.args.cuda: table_appear_mask_val = table_appear_mask_val.cuda()
|
||||
|
||||
if self.use_column_pointer:
|
||||
gate = F.sigmoid(self.prob_att(att_t))
|
||||
weights = self.column_pointer_net(src_encodings=exp_table_embedding, query_vec=att_t.unsqueeze(0),
|
||||
src_token_mask=None) * table_appear_mask_val * gate + self.column_pointer_net(
|
||||
src_encodings=exp_table_embedding, query_vec=att_t.unsqueeze(0),
|
||||
src_token_mask=None) * (1 - table_appear_mask_val) * (1 - gate)
|
||||
# weights = weights + self.col_attention_out(exp_embedding_differ).squeeze()
|
||||
else:
|
||||
weights = self.column_pointer_net(src_encodings=exp_table_embedding, query_vec=att_t.unsqueeze(0),
|
||||
src_token_mask=batch.table_token_mask)
|
||||
# weights.data.masked_fill_(exp_col_pred_mask, -float('inf'))
|
||||
|
||||
column_selection_log_prob = F.log_softmax(weights, dim=-1)
|
||||
|
||||
table_weights = self.table_pointer_net(src_encodings=exp_schema_embedding, query_vec=att_t.unsqueeze(0),
|
||||
src_token_mask=None)
|
||||
# table_weights = self.table_pointer_net(src_encodings=exp_schema_embedding, query_vec=att_t.unsqueeze(0), src_token_mask=None)
|
||||
|
||||
schema_token_mask = batch.schema_token_mask.expand_as(table_weights)
|
||||
table_weights.data.masked_fill_(schema_token_mask, -float('inf'))
|
||||
|
||||
table_dict = [batch_table_dict[0][int(x)] for x_id, x in enumerate(table_enable.tolist())]
|
||||
table_mask = batch.table_dict_mask(table_dict)
|
||||
table_weights.data.masked_fill_(table_mask, -float('inf'))
|
||||
|
||||
table_weights = F.log_softmax(table_weights, dim=-1)
|
||||
|
||||
new_hyp_meta = []
|
||||
for hyp_id, hyp in enumerate(beams):
|
||||
# TODO: should change this
|
||||
if type(padding_sketch[t]) == define_rule.A:
|
||||
possible_productions = self.grammar.get_production(define_rule.A)
|
||||
for possible_production in possible_productions:
|
||||
prod_id = self.grammar.prod2id[possible_production]
|
||||
prod_score = apply_rule_log_prob[hyp_id, prod_id]
|
||||
|
||||
new_hyp_score = hyp.score + prod_score.data.cpu()
|
||||
meta_entry = {'action_type': define_rule.A, 'prod_id': prod_id,
|
||||
'score': prod_score, 'new_hyp_score': new_hyp_score,
|
||||
'prev_hyp_id': hyp_id}
|
||||
new_hyp_meta.append(meta_entry)
|
||||
|
||||
elif type(padding_sketch[t]) == define_rule.C:
|
||||
for col_id, _ in enumerate(batch.table_sents[0]):
|
||||
col_sel_score = column_selection_log_prob[hyp_id, col_id]
|
||||
new_hyp_score = hyp.score + col_sel_score.data.cpu()
|
||||
meta_entry = {'action_type': define_rule.C, 'col_id': col_id,
|
||||
'score': col_sel_score, 'new_hyp_score': new_hyp_score,
|
||||
'prev_hyp_id': hyp_id}
|
||||
new_hyp_meta.append(meta_entry)
|
||||
elif type(padding_sketch[t]) == define_rule.T:
|
||||
for t_id, _ in enumerate(batch.table_names[0]):
|
||||
t_sel_score = table_weights[hyp_id, t_id]
|
||||
new_hyp_score = hyp.score + t_sel_score.data.cpu()
|
||||
|
||||
meta_entry = {'action_type': define_rule.T, 't_id': t_id,
|
||||
'score': t_sel_score, 'new_hyp_score': new_hyp_score,
|
||||
'prev_hyp_id': hyp_id}
|
||||
new_hyp_meta.append(meta_entry)
|
||||
else:
|
||||
prod_id = self.grammar.prod2id[padding_sketch[t].production]
|
||||
new_hyp_score = hyp.score + torch.tensor(0.0)
|
||||
meta_entry = {'action_type': type(padding_sketch[t]), 'prod_id': prod_id,
|
||||
'score': torch.tensor(0.0), 'new_hyp_score': new_hyp_score,
|
||||
'prev_hyp_id': hyp_id}
|
||||
new_hyp_meta.append(meta_entry)
|
||||
|
||||
if not new_hyp_meta: break
|
||||
|
||||
new_hyp_scores = torch.stack([x['new_hyp_score'] for x in new_hyp_meta], dim=0)
|
||||
top_new_hyp_scores, meta_ids = torch.topk(new_hyp_scores,
|
||||
k=min(new_hyp_scores.size(0),
|
||||
beam_size - len(completed_beams)))
|
||||
|
||||
live_hyp_ids = []
|
||||
new_beams = []
|
||||
for new_hyp_score, meta_id in zip(top_new_hyp_scores.data.cpu(), meta_ids.data.cpu()):
|
||||
action_info = ActionInfo()
|
||||
hyp_meta_entry = new_hyp_meta[meta_id]
|
||||
prev_hyp_id = hyp_meta_entry['prev_hyp_id']
|
||||
prev_hyp = beams[prev_hyp_id]
|
||||
|
||||
action_type_str = hyp_meta_entry['action_type']
|
||||
if 'prod_id' in hyp_meta_entry:
|
||||
prod_id = hyp_meta_entry['prod_id']
|
||||
if action_type_str == define_rule.C:
|
||||
col_id = hyp_meta_entry['col_id']
|
||||
action = define_rule.C(col_id)
|
||||
elif action_type_str == define_rule.T:
|
||||
t_id = hyp_meta_entry['t_id']
|
||||
action = define_rule.T(t_id)
|
||||
elif prod_id < len(self.grammar.id2prod):
|
||||
production = self.grammar.id2prod[prod_id]
|
||||
action = action_type_str(list(action_type_str._init_grammar()).index(production))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
action_info.action = action
|
||||
action_info.t = t
|
||||
action_info.score = hyp_meta_entry['score']
|
||||
|
||||
new_hyp = prev_hyp.clone_and_apply_action_info(action_info)
|
||||
new_hyp.score = new_hyp_score
|
||||
new_hyp.inputs.extend(prev_hyp.inputs)
|
||||
|
||||
if new_hyp.is_valid is False:
|
||||
continue
|
||||
|
||||
if new_hyp.completed:
|
||||
completed_beams.append(new_hyp)
|
||||
else:
|
||||
new_beams.append(new_hyp)
|
||||
live_hyp_ids.append(prev_hyp_id)
|
||||
|
||||
if live_hyp_ids:
|
||||
h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids])
|
||||
att_tm1 = att_t[live_hyp_ids]
|
||||
|
||||
beams = new_beams
|
||||
t += 1
|
||||
else:
|
||||
break
|
||||
|
||||
completed_beams.sort(key=lambda hyp: -hyp.score)
|
||||
|
||||
return [completed_beams, sketch_actions]
|
||||
|
||||
def step(self, x, h_tm1, src_encodings, src_encodings_att_linear, decoder, attention_func, src_token_mask=None,
|
||||
return_att_weight=False):
|
||||
# h_t: (batch_size, hidden_size)
|
||||
h_t, cell_t = decoder(x, h_tm1)
|
||||
|
||||
ctx_t, alpha_t = nn_utils.dot_prod_attention(h_t,
|
||||
src_encodings, src_encodings_att_linear,
|
||||
mask=src_token_mask)
|
||||
|
||||
att_t = F.tanh(attention_func(torch.cat([h_t, ctx_t], 1)))
|
||||
att_t = self.dropout(att_t)
|
||||
|
||||
if return_att_weight:
|
||||
return (h_t, cell_t), att_t, alpha_t
|
||||
else:
|
||||
return (h_t, cell_t), att_t
|
||||
|
||||
def init_decoder_state(self, enc_last_cell):
|
||||
h_0 = self.decoder_cell_init(enc_last_cell)
|
||||
h_0 = F.tanh(h_0)
|
||||
|
||||
return h_0, Variable(self.new_tensor(h_0.size()).zero_())
|
||||
|
|
@@ -0,0 +1,236 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
# @Time : 2019/5/25
|
||||
# @Author : Jiaqi&Zecheng
|
||||
# @File : utils.py
|
||||
# @Software: PyCharm
|
||||
"""
|
||||
import torch.nn.functional as F
|
||||
import torch.nn.init as init
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.autograd import Variable
|
||||
from six.moves import xrange
|
||||
|
||||
def dot_prod_attention(h_t, src_encoding, src_encoding_att_linear, mask=None):
|
||||
"""
|
||||
:param h_t: (batch_size, hidden_size)
|
||||
:param src_encoding: (batch_size, src_sent_len, hidden_size * 2)
|
||||
:param src_encoding_att_linear: (batch_size, src_sent_len, hidden_size)
|
||||
:param mask: (batch_size, src_sent_len)
|
||||
"""
|
||||
# (batch_size, src_sent_len)
|
||||
att_weight = torch.bmm(src_encoding_att_linear, h_t.unsqueeze(2)).squeeze(2)
|
||||
if mask is not None:
|
||||
att_weight.data.masked_fill_(mask, -float('inf'))
|
||||
att_weight = F.softmax(att_weight, dim=-1)
|
||||
|
||||
att_view = (att_weight.size(0), 1, att_weight.size(1))
|
||||
# (batch_size, hidden_size)
|
||||
ctx_vec = torch.bmm(att_weight.view(*att_view), src_encoding).squeeze(1)
|
||||
|
||||
return ctx_vec, att_weight
|
||||
|
||||
|
||||
def length_array_to_mask_tensor(length_array, cuda=False, value=None):
|
||||
max_len = max(length_array)
|
||||
batch_size = len(length_array)
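# 1 marks padding positions to be masked out; the real token positions of each sequence are set to 0 below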
mask = np.ones((batch_size, max_len), dtype=np.uint8)
|
||||
for i, seq_len in enumerate(length_array):
|
||||
mask[i][:seq_len] = 0
|
||||
|
||||
if value is not None:
|
||||
for b_id in range(len(value)):
|
||||
for c_id, c in enumerate(value[b_id]):
|
||||
if value[b_id][c_id] == [3]:
|
||||
mask[b_id][c_id] = 1
|
||||
|
||||
mask = torch.ByteTensor(mask)
|
||||
return mask.cuda() if cuda else mask
|
||||
|
||||
|
||||
def table_dict_to_mask_tensor(length_array, table_dict, cuda=False ):
|
||||
max_len = max(length_array)
|
||||
batch_size = len(table_dict)
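# only the candidate table ids listed for each example stay un-masked (0); every other table remains masked (1)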
mask = np.ones((batch_size, max_len), dtype=np.uint8)
|
||||
for i, ta_val in enumerate(table_dict):
|
||||
for tt in ta_val:
|
||||
mask[i][tt] = 0
|
||||
|
||||
mask = torch.ByteTensor(mask)
|
||||
return mask.cuda() if cuda else mask
|
||||
|
||||
|
||||
def length_position_tensor(length_array, cuda=False, value=None):
|
||||
max_len = max(length_array)
|
||||
batch_size = len(length_array)
|
||||
|
||||
mask = np.zeros((batch_size, max_len), dtype=np.float32)
|
||||
|
||||
for b_id in range(batch_size):
|
||||
for len_c in range(length_array[b_id]):
|
||||
mask[b_id][len_c] = len_c + 1
|
||||
|
||||
mask = torch.LongTensor(mask)
|
||||
return mask.cuda() if cuda else mask
|
||||
|
||||
|
||||
def appear_to_mask_tensor(length_array, cuda=False, value=None):
|
||||
max_len = max(length_array)
|
||||
batch_size = len(length_array)
|
||||
mask = np.zeros((batch_size, max_len), dtype=np.float32)
|
||||
return mask
|
||||
|
||||
def pred_col_mask(value, max_len):
|
||||
max_len = max(max_len)
|
||||
batch_size = len(value)
|
||||
mask = np.ones((batch_size, max_len), dtype=np.uint8)
|
||||
for v_ind, v_val in enumerate(value):
|
||||
for v in v_val:
|
||||
mask[v_ind][v] = 0
|
||||
mask = torch.ByteTensor(mask)
|
||||
return mask.cuda()
|
||||
|
||||
|
||||
def input_transpose(sents, pad_token):
|
||||
"""
|
||||
transform the input List[sequence] of size (batch_size, max_sent_len)
|
||||
into a list of size (batch_size, max_sent_len), with proper padding
|
||||
"""
|
||||
max_len = max(len(s) for s in sents)
|
||||
batch_size = len(sents)
|
||||
sents_t = []
|
||||
masks = []
|
||||
for e_id in range(batch_size):
|
||||
if type(sents[0][0]) != list:
|
||||
sents_t.append([sents[e_id][i] if len(sents[e_id]) > i else pad_token for i in range(max_len)])
|
||||
else:
|
||||
sents_t.append([sents[e_id][i] if len(sents[e_id]) > i else [pad_token] for i in range(max_len)])
|
||||
|
||||
masks.append([1 if len(sents[e_id]) > i else 0 for i in range(max_len)])
|
||||
|
||||
return sents_t, masks
|
||||
|
||||
|
||||
def word2id(sents, vocab):
|
||||
if type(sents[0]) == list:
|
||||
if type(sents[0][0]) != list:
|
||||
return [[vocab[w] for w in s] for s in sents]
|
||||
else:
|
||||
return [[[vocab[w] for w in s] for s in v] for v in sents ]
|
||||
else:
|
||||
return [vocab[w] for w in sents]
|
||||
|
||||
|
||||
def id2word(sents, vocab):
|
||||
if type(sents[0]) == list:
|
||||
return [[vocab.id2word[w] for w in s] for s in sents]
|
||||
else:
|
||||
return [vocab.id2word[w] for w in sents]
|
||||
|
||||
|
||||
def to_input_variable(sequences, vocab, cuda=False, training=True):
|
||||
"""
|
||||
given a list of sequences,
|
||||
return a tensor of shape (max_sent_len, batch_size)
|
||||
"""
|
||||
word_ids = word2id(sequences, vocab)
|
||||
sents_t, masks = input_transpose(word_ids, vocab['<pad>'])
|
||||
|
||||
if type(sents_t[0][0]) != list:
|
||||
with torch.no_grad():
|
||||
sents_var = Variable(torch.LongTensor(sents_t), requires_grad=False)
|
||||
if cuda:
|
||||
sents_var = sents_var.cuda()
|
||||
else:
|
||||
sents_var = sents_t
|
||||
|
||||
return sents_var
|
||||
|
||||
|
||||
def variable_constr(x, v, cuda=False):
    # x names a tensor constructor (e.g. 'FloatTensor'); look it up by name, since torch.x / torch.cuda.x would raise AttributeError
    constructor = getattr(torch.cuda, x) if cuda else getattr(torch, x)
    return Variable(constructor(v))
|
||||
|
||||
|
||||
def batch_iter(examples, batch_size, shuffle=False):
|
||||
index_arr = np.arange(len(examples))
|
||||
if shuffle:
|
||||
np.random.shuffle(index_arr)
|
||||
|
||||
batch_num = int(np.ceil(len(examples) / float(batch_size)))
|
||||
for batch_id in xrange(batch_num):
|
||||
batch_ids = index_arr[batch_size * batch_id: batch_size * (batch_id + 1)]
|
||||
batch_examples = [examples[i] for i in batch_ids]
|
||||
|
||||
yield batch_examples
|
||||
|
||||
|
||||
def isnan(data):
|
||||
data = data.cpu().numpy()
|
||||
return np.isnan(data).any() or np.isinf(data).any()
|
||||
|
||||
|
||||
def log_sum_exp(inputs, dim=None, keepdim=False):
|
||||
"""Numerically stable logsumexp.
|
||||
source: https://github.com/pytorch/pytorch/issues/2591
|
||||
|
||||
Args:
|
||||
inputs: A Variable with any shape.
|
||||
dim: An integer.
|
||||
keepdim: A boolean.
|
||||
|
||||
Returns:
|
||||
Equivalent of log(sum(exp(inputs), dim=dim, keepdim=keepdim)).
|
||||
"""
|
||||
# For a 1-D array x (any array along a single dimension),
|
||||
# log sum exp(x) = s + log sum exp(x - s)
|
||||
# with s = max(x) being a common choice.
|
||||
|
||||
if dim is None:
|
||||
inputs = inputs.view(-1)
|
||||
dim = 0
|
||||
s, _ = torch.max(inputs, dim=dim, keepdim=True)
|
||||
outputs = s + (inputs - s).exp().sum(dim=dim, keepdim=True).log()
|
||||
if not keepdim:
|
||||
outputs = outputs.squeeze(dim)
|
||||
return outputs
|
||||
|
||||
|
||||
def uniform_init(lower, upper, params):
|
||||
for p in params:
|
||||
p.data.uniform_(lower, upper)
|
||||
|
||||
|
||||
def glorot_init(params):
|
||||
for p in params:
|
||||
if len(p.data.size()) > 1:
|
||||
init.xavier_normal_(p.data)
|
||||
|
||||
|
||||
def identity(x):
|
||||
return x
|
||||
|
||||
|
||||
def pad_matrix(matrixs, cuda=False):
|
||||
"""
|
||||
:param matrixs:
|
||||
:return: [batch_size, max_shape, max_shape], [batch_size]
|
||||
"""
|
||||
shape = [m.shape[0] for m in matrixs]
|
||||
max_shape = max(shape)
|
||||
tensors = list()
|
||||
for s, m in zip(shape, matrixs):
|
||||
delta = max_shape - s
|
||||
if s > 0:
|
||||
tensors.append(torch.as_tensor(np.pad(m, [(0, delta), (0, delta)], mode='constant'), dtype=torch.float))
|
||||
else:
|
||||
tensors.append(torch.as_tensor(m, dtype=torch.float))
|
||||
tensors = torch.stack(tensors)
|
||||
if cuda:
|
||||
tensors = tensors.cuda()
|
||||
return tensors
|
|
@@ -0,0 +1,106 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# coding=utf8
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.utils
|
||||
from torch.nn import Parameter
|
||||
|
||||
|
||||
class AuxiliaryPointerNet(nn.Module):
|
||||
|
||||
def __init__(self, query_vec_size, src_encoding_size, attention_type='affine'):
|
||||
super(AuxiliaryPointerNet, self).__init__()
|
||||
|
||||
assert attention_type in ('affine', 'dot_prod')
|
||||
if attention_type == 'affine':
|
||||
self.src_encoding_linear = nn.Linear(src_encoding_size, query_vec_size, bias=False)
|
||||
self.auxiliary_encoding_linear = nn.Linear(src_encoding_size, query_vec_size, bias=False)
|
||||
self.attention_type = attention_type
|
||||
|
||||
def forward(self, src_encodings, src_context_encodings, src_token_mask, query_vec):
|
||||
"""
|
||||
:param src_context_encodings: Variable(batch_size, src_sent_len, src_encoding_size)
|
||||
:param src_encodings: Variable(batch_size, src_sent_len, src_encoding_size)
|
||||
:param src_token_mask: Variable(batch_size, src_sent_len)
|
||||
:param query_vec: Variable(tgt_action_num, batch_size, query_vec_size)
|
||||
:return: Variable(tgt_action_num, batch_size, src_sent_len)
|
||||
"""
|
||||
|
||||
# (batch_size, 1, src_sent_len, query_vec_size)
|
||||
encodings = src_encodings.clone()
|
||||
context_encodings = src_context_encodings.clone()
|
||||
if self.attention_type == 'affine':
|
||||
encodings = self.src_encoding_linear(src_encodings)
|
||||
context_encodings = self.auxiliary_encoding_linear(src_context_encodings)
|
||||
encodings = encodings.unsqueeze(1)
|
||||
context_encodings = context_encodings.unsqueeze(1)
|
||||
|
||||
# (batch_size, tgt_action_num, query_vec_size, 1)
|
||||
q = query_vec.permute(1, 0, 2).unsqueeze(3)
|
||||
|
||||
# (batch_size, tgt_action_num, src_sent_len)
|
||||
weights = torch.matmul(encodings, q).squeeze(3)
|
||||
context_weights = torch.matmul(context_encodings, q).squeeze(3)
|
||||
|
||||
# (tgt_action_num, batch_size, src_sent_len)
|
||||
weights = weights.permute(1, 0, 2)
|
||||
context_weights = context_weights.permute(1, 0, 2)
|
||||
|
||||
if src_token_mask is not None:
|
||||
# (tgt_action_num, batch_size, src_sent_len)
|
||||
src_token_mask = src_token_mask.unsqueeze(0).expand_as(weights)
|
||||
weights.data.masked_fill_(src_token_mask, -float('inf'))
|
||||
context_weights.data.masked_fill_(src_token_mask, -float('inf'))
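# combine the primary attention logits with the auxiliary context logits; sigma down-weights the auxiliary term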
sigma = 0.1
|
||||
return weights.squeeze(0) + sigma * context_weights.squeeze(0)
|
||||
|
||||
|
||||
class PointerNet(nn.Module):
|
||||
def __init__(self, query_vec_size, src_encoding_size, attention_type='affine'):
|
||||
super(PointerNet, self).__init__()
|
||||
|
||||
assert attention_type in ('affine', 'dot_prod')
|
||||
if attention_type == 'affine':
|
||||
self.src_encoding_linear = nn.Linear(src_encoding_size, query_vec_size, bias=False)
|
||||
|
||||
self.attention_type = attention_type
|
||||
self.input_linear = nn.Linear(query_vec_size, query_vec_size)
|
||||
self.type_linear = nn.Linear(32, query_vec_size)
|
||||
self.V = Parameter(torch.FloatTensor(query_vec_size), requires_grad=True)
|
||||
self.tanh = nn.Tanh()
|
||||
self.context_linear = nn.Conv1d(src_encoding_size, query_vec_size, 1, 1)
|
||||
self.coverage_linear = nn.Conv1d(1, query_vec_size, 1, 1)
|
||||
|
||||
|
||||
nn.init.uniform_(self.V, -1, 1)
|
||||
|
||||
def forward(self, src_encodings, src_token_mask, query_vec):
|
||||
"""
|
||||
:param src_encodings: Variable(batch_size, src_sent_len, hidden_size * 2)
|
||||
:param src_token_mask: Variable(batch_size, src_sent_len)
|
||||
:param query_vec: Variable(tgt_action_num, batch_size, query_vec_size)
|
||||
:return: Variable(tgt_action_num, batch_size, src_sent_len)
|
||||
"""
|
||||
|
||||
# (batch_size, 1, src_sent_len, query_vec_size)
|
||||
|
||||
if self.attention_type == 'affine':
|
||||
src_encodings = self.src_encoding_linear(src_encodings)
|
||||
src_encodings = src_encodings.unsqueeze(1)
|
||||
|
||||
# (batch_size, tgt_action_num, query_vec_size, 1)
|
||||
q = query_vec.permute(1, 0, 2).unsqueeze(3)
|
||||
|
||||
weights = torch.matmul(src_encodings, q).squeeze(3)
|
||||
|
||||
weights = weights.permute(1, 0, 2)
|
||||
|
||||
if src_token_mask is not None:
|
||||
src_token_mask = src_token_mask.unsqueeze(0).expand_as(weights)
|
||||
weights.data.masked_fill_(src_token_mask, -float('inf'))
|
||||
|
||||
return weights.squeeze(0)
|
|
@@ -0,0 +1,10 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
# @Time : 2019/5/27
|
||||
# @Author : Jiaqi&Zecheng
|
||||
# @File    : __init__.py
|
||||
# @Software: PyCharm
|
||||
"""
|
|
@@ -0,0 +1,130 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
# @Time : 2019/5/25
|
||||
# @Author : Jiaqi&Zecheng
|
||||
# @File : utils.py
|
||||
# @Software: PyCharm
|
||||
"""
|
||||
|
||||
from collections import deque, namedtuple
|
||||
|
||||
|
||||
# we'll use infinity as a default distance to nodes.
|
||||
inf = float('inf')
|
||||
Edge = namedtuple('Edge', 'start, end, cost')
|
||||
|
||||
|
||||
def make_edge(start, end, cost=1):
|
||||
return Edge(start, end, cost)
|
||||
|
||||
|
||||
class Graph:
|
||||
def __init__(self, edges):
|
||||
# let's check that the data is right
|
||||
wrong_edges = [i for i in edges if len(i) not in [2, 3]]
|
||||
if wrong_edges:
|
||||
raise ValueError('Wrong edges data: {}'.format(wrong_edges))
|
||||
|
||||
self.edges = [make_edge(*edge) for edge in edges]
|
||||
|
||||
@property
|
||||
def vertices(self):
|
||||
return set(
|
||||
# this piece of magic turns ([1,2], [3,4]) into [1, 2, 3, 4]
|
||||
# the set above makes its elements unique.
|
||||
sum(
|
||||
([edge.start, edge.end] for edge in self.edges), []
|
||||
)
|
||||
)
|
||||
|
||||
def get_node_pairs(self, n1, n2, both_ends=True):
|
||||
if both_ends:
|
||||
node_pairs = [[n1, n2], [n2, n1]]
|
||||
else:
|
||||
node_pairs = [[n1, n2]]
|
||||
return node_pairs
|
||||
|
||||
def remove_edge(self, n1, n2, both_ends=True):
|
||||
node_pairs = self.get_node_pairs(n1, n2, both_ends)
|
||||
edges = self.edges[:]
|
||||
for edge in edges:
|
||||
if [edge.start, edge.end] in node_pairs:
|
||||
self.edges.remove(edge)
|
||||
|
||||
def add_edge(self, n1, n2, cost=1, both_ends=True):
|
||||
node_pairs = self.get_node_pairs(n1, n2, both_ends)
|
||||
for edge in self.edges:
|
||||
if [edge.start, edge.end] in node_pairs:
|
||||
raise ValueError('Edge {} {} already exists'.format(n1, n2))
|
||||
|
||||
self.edges.append(Edge(start=n1, end=n2, cost=cost))
|
||||
if both_ends:
|
||||
self.edges.append(Edge(start=n2, end=n1, cost=cost))
|
||||
|
||||
@property
|
||||
def neighbours(self):
|
||||
neighbours = {vertex: set() for vertex in self.vertices}
|
||||
for edge in self.edges:
|
||||
neighbours[edge.start].add((edge.end, edge.cost))
|
||||
|
||||
return neighbours
|
||||
|
||||
def dijkstra(self, source, dest):
|
||||
assert source in self.vertices, 'Such source node doesn\'t exist'
assert dest in self.vertices, 'Such destination node doesn\'t exist'
|
||||
|
||||
# 1. Mark all nodes unvisited and store them.
|
||||
# 2. Set the distance to zero for our initial node
|
||||
# and to infinity for other nodes.
|
||||
distances = {vertex: inf for vertex in self.vertices}
|
||||
previous_vertices = {
|
||||
vertex: None for vertex in self.vertices
|
||||
}
|
||||
distances[source] = 0
|
||||
vertices = self.vertices.copy()
|
||||
|
||||
while vertices:
|
||||
# 3. Select the unvisited node with the smallest distance,
|
||||
# it becomes the current node.
|
||||
current_vertex = min(
|
||||
vertices, key=lambda vertex: distances[vertex])
|
||||
|
||||
# 6. Stop, if the smallest distance
|
||||
# among the unvisited nodes is infinity.
|
||||
if distances[current_vertex] == inf:
|
||||
break
|
||||
|
||||
# 4. Find unvisited neighbors for the current node
|
||||
# and calculate their distances through the current node.
|
||||
for neighbour, cost in self.neighbours[current_vertex]:
|
||||
alternative_route = distances[current_vertex] + cost
|
||||
|
||||
# Compare the newly calculated distance to the assigned
|
||||
# and save the smaller one.
|
||||
if alternative_route < distances[neighbour]:
|
||||
distances[neighbour] = alternative_route
|
||||
previous_vertices[neighbour] = current_vertex
|
||||
|
||||
# 5. Mark the current node as visited
|
||||
# and remove it from the unvisited set.
|
||||
vertices.remove(current_vertex)
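# walk back from the destination through previous_vertices to recover the shortest path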
path, current_vertex = deque(), dest
|
||||
while previous_vertices[current_vertex] is not None:
|
||||
path.appendleft(current_vertex)
|
||||
current_vertex = previous_vertices[current_vertex]
|
||||
if path:
|
||||
path.appendleft(current_vertex)
|
||||
return path
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
graph = Graph([
|
||||
("a", "b", 7), ("a", "c", 9), ("a", "f", 14), ("b", "c", 10),
|
||||
("b", "d", 15), ("c", "d", 11), ("c", "f", 2), ("d", "e", 6),
|
||||
("e", "f", 9)])
|
||||
|
||||
print(graph.dijkstra("a", "e"))
|
|
@@ -0,0 +1,229 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
# @Time : 2019/5/25
|
||||
# @Author : Jiaqi&Zecheng
|
||||
# @File : utils.py
|
||||
# @Software: PyCharm
|
||||
"""
|
||||
import copy
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
|
||||
from src.rule import semQL as define_rule
|
||||
from src.rule.semQL import Sup, Sel, Order, Root, Filter, A, N, C, T, Root1
|
||||
|
||||
|
||||
def _build_single_filter(lf, f):
|
||||
# No conjunction
|
||||
agg = lf.pop(0)
|
||||
column = lf.pop(0)
|
||||
if len(lf) == 0:
|
||||
table = None
|
||||
else:
|
||||
table = lf.pop(0)
|
||||
if not isinstance(table, define_rule.T):
|
||||
lf.insert(0, table)
|
||||
table = None
|
||||
assert isinstance(agg, define_rule.A) and isinstance(column, define_rule.C)
|
||||
if len(f.production.split()) == 3:
|
||||
f.add_children(agg)
|
||||
agg.set_parent(f)
|
||||
agg.add_children(column)
|
||||
column.set_parent(agg)
|
||||
if table is not None:
|
||||
column.add_children(table)
|
||||
table.set_parent(column)
|
||||
else:
|
||||
# Subquery
|
||||
f.add_children(agg)
|
||||
agg.set_parent(f)
|
||||
agg.add_children(column)
|
||||
column.set_parent(agg)
|
||||
if table is not None:
|
||||
column.add_children(table)
|
||||
table.set_parent(column)
|
||||
_root = _build(lf)
|
||||
f.add_children(_root)
|
||||
_root.set_parent(f)
|
||||
|
||||
|
||||
def _build_filter(lf, root_filter):
|
||||
assert isinstance(root_filter, define_rule.Filter)
|
||||
op = root_filter.production.split()[1]
|
||||
if op == 'and' or op == 'or':
|
||||
for i in range(2):
|
||||
child = lf.pop(0)
|
||||
op = child.production.split()[1]
|
||||
if op == 'and' or op == 'or':
|
||||
_f = _build_filter(lf, child)
|
||||
root_filter.add_children(_f)
|
||||
_f.set_parent(root_filter)
|
||||
else:
|
||||
_build_single_filter(lf, child)
|
||||
root_filter.add_children(child)
|
||||
child.set_parent(root_filter)
|
||||
else:
|
||||
_build_single_filter(lf, root_filter)
|
||||
return root_filter
|
||||
|
||||
|
||||
def _build(lf):
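# rebuild one Root subtree: pop the Root action, then attach its Sel / Sup / Order / Filter children from the flat action list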
root = lf.pop(0)
|
||||
assert isinstance(root, define_rule.Root)
|
||||
length = len(root.production.split()) - 1
|
||||
while len(root.children) != length:
|
||||
c_instance = lf.pop(0)
|
||||
if isinstance(c_instance, define_rule.Sel):
|
||||
sel_instance = c_instance
|
||||
root.add_children(sel_instance)
|
||||
sel_instance.set_parent(root)
|
||||
|
||||
# define_rule.N
|
||||
c_instance = lf.pop(0)
|
||||
c_instance.set_parent(sel_instance)
|
||||
sel_instance.add_children(c_instance)
|
||||
assert isinstance(c_instance, define_rule.N)
|
||||
for i in range(c_instance.id_c + 1):
|
||||
agg = lf.pop(0)
|
||||
column = lf.pop(0)
|
||||
if len(lf) == 0:
|
||||
table = None
|
||||
else:
|
||||
table = lf.pop(0)
|
||||
if not isinstance(table, define_rule.T):
|
||||
lf.insert(0, table)
|
||||
table = None
|
||||
assert isinstance(agg, define_rule.A) and isinstance(column, define_rule.C)
|
||||
c_instance.add_children(agg)
|
||||
agg.set_parent(c_instance)
|
||||
agg.add_children(column)
|
||||
column.set_parent(agg)
|
||||
if table is not None:
|
||||
column.add_children(table)
|
||||
table.set_parent(column)
|
||||
|
||||
elif isinstance(c_instance, define_rule.Sup) or isinstance(c_instance, define_rule.Order):
|
||||
root.add_children(c_instance)
|
||||
c_instance.set_parent(root)
|
||||
|
||||
agg = lf.pop(0)
|
||||
column = lf.pop(0)
|
||||
if len(lf) == 0:
|
||||
table = None
|
||||
else:
|
||||
table = lf.pop(0)
|
||||
if not isinstance(table, define_rule.T):
|
||||
lf.insert(0, table)
|
||||
table = None
|
||||
assert isinstance(agg, define_rule.A) and isinstance(column, define_rule.C)
|
||||
c_instance.add_children(agg)
|
||||
agg.set_parent(c_instance)
|
||||
agg.add_children(column)
|
||||
column.set_parent(agg)
|
||||
if table is not None:
|
||||
column.add_children(table)
|
||||
table.set_parent(column)
|
||||
|
||||
elif isinstance(c_instance, define_rule.Filter):
|
||||
_build_filter(lf, c_instance)
|
||||
root.add_children(c_instance)
|
||||
c_instance.set_parent(root)
|
||||
|
||||
return root
|
||||
|
||||
|
||||
def build_tree(lf):
|
||||
root = lf.pop(0)
|
||||
assert isinstance(root, define_rule.Root1)
|
||||
if root.id_c == 0 or root.id_c == 1 or root.id_c == 2:
|
||||
root_1 = _build(lf)
|
||||
root_2 = _build(lf)
|
||||
root.add_children(root_1)
|
||||
root.add_children(root_2)
|
||||
root_1.set_parent(root)
|
||||
root_2.set_parent(root)
|
||||
else:
|
||||
root_1 = _build(lf)
|
||||
root.add_children(root_1)
|
||||
root_1.set_parent(root)
|
||||
verify(root)
|
||||
# eliminate_parent(root)
|
||||
|
||||
|
||||
def eliminate_parent(node):
|
||||
for child in node.children:
|
||||
eliminate_parent(child)
|
||||
node.children = list()
|
||||
|
||||
|
||||
def verify(node):
|
||||
if isinstance(node, C) and len(node.children) > 0:
|
||||
table = node.children[0]
|
||||
assert table is None or isinstance(table, T)
|
||||
if isinstance(node, T):
|
||||
return
|
||||
children_num = len(node.children)
|
||||
if isinstance(node, Root1):
|
||||
if node.id_c == 0 or node.id_c == 1 or node.id_c == 2:
|
||||
assert children_num == 2
|
||||
else:
|
||||
assert children_num == 1
|
||||
elif isinstance(node, Root):
|
||||
assert children_num == len(node.production.split()) - 1
|
||||
elif isinstance(node, N):
|
||||
assert children_num == int(node.id_c) + 1
|
||||
elif isinstance(node, Sup) or isinstance(node, Order) or isinstance(node, Sel):
|
||||
assert children_num == 1
|
||||
elif isinstance(node, Filter):
|
||||
op = node.production.split()[1]
|
||||
if op == 'and' or op == 'or':
|
||||
assert children_num == 2
|
||||
else:
|
||||
if len(node.production.split()) == 3:
|
||||
assert children_num == 1
|
||||
else:
|
||||
assert children_num == 2
|
||||
for child in node.children:
|
||||
assert child.parent == node
|
||||
verify(child)
|
||||
|
||||
|
||||
def label_matrix(lf, matrix, node):
|
||||
nindex = lf.index(node)
|
||||
for child in node.children:
|
||||
if child not in lf:
|
||||
continue
|
||||
index = lf.index(child)
|
||||
matrix[nindex][index] = 1
|
||||
label_matrix(lf, matrix, child)
|
||||
|
||||
|
||||
def build_adjacency_matrix(lf, symmetry=False):
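# A, C and T actions are leaves and are skipped; matrix[i][j] = 1 when structural rule j is a child of rule i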
_lf = list()
|
||||
for rule in lf:
|
||||
if isinstance(rule, A) or isinstance(rule, C) or isinstance(rule, T):
|
||||
continue
|
||||
_lf.append(rule)
|
||||
length = len(_lf)
|
||||
matrix = np.zeros((length, length,))
|
||||
label_matrix(_lf, matrix, _lf[0])
|
||||
if symmetry:
|
||||
matrix += matrix.T
|
||||
return matrix
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
with open(r'..\data\train.json', 'r') as f:
|
||||
data = json.load(f)
|
||||
for d in data:
|
||||
rule_label = [eval(x) for x in d['rule_label'].strip().split(' ')]
|
||||
print(d['question'])
|
||||
print(rule_label)
|
||||
build_tree(copy.copy(rule_label))
|
||||
adjacency_matrix = build_adjacency_matrix(rule_label, symmetry=True)
|
||||
print(adjacency_matrix)
|
||||
print('===\n\n')
|
|
@@ -0,0 +1,399 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
# @Time : 2019/5/24
|
||||
# @Author : Jiaqi&Zecheng
|
||||
# @File : semQL.py
|
||||
# @Software: PyCharm
|
||||
"""
|
||||
|
||||
Keywords = ['des', 'asc', 'and', 'or', 'sum', 'min', 'max', 'avg', 'none', '=', '!=', '<', '>', '<=', '>=', 'between', 'like', 'not_like'] + [
|
||||
'in', 'not_in', 'count', 'intersect', 'union', 'except'
|
||||
]
|
||||
|
||||
|
||||
class Grammar(object):
|
||||
def __init__(self, is_sketch=False):
|
||||
self.begin = 0
|
||||
self.type_id = 0
|
||||
self.is_sketch = is_sketch
|
||||
self.prod2id = {}
|
||||
self.type2id = {}
|
||||
self._init_grammar(Sel)
|
||||
self._init_grammar(Root)
|
||||
self._init_grammar(Sup)
|
||||
self._init_grammar(Filter)
|
||||
self._init_grammar(Order)
|
||||
self._init_grammar(N)
|
||||
self._init_grammar(Root1)
|
||||
|
||||
if not self.is_sketch:
|
||||
self._init_grammar(A)
|
||||
|
||||
self._init_id2prod()
|
||||
self.type2id[C] = self.type_id
|
||||
self.type_id += 1
|
||||
self.type2id[T] = self.type_id
|
||||
|
||||
def _init_grammar(self, Cls):
|
||||
"""
|
||||
get the production of class Cls
|
||||
:param Cls:
|
||||
:return:
|
||||
"""
|
||||
production = Cls._init_grammar()
|
||||
for p in production:
|
||||
self.prod2id[p] = self.begin
|
||||
self.begin += 1
|
||||
self.type2id[Cls] = self.type_id
|
||||
self.type_id += 1
|
||||
|
||||
def _init_id2prod(self):
|
||||
self.id2prod = {}
|
||||
for key, value in self.prod2id.items():
|
||||
self.id2prod[value] = key
|
||||
|
||||
def get_production(self, Cls):
|
||||
return Cls._init_grammar()
|
||||
|
||||
|
||||
class Action(object):
|
||||
def __init__(self):
|
||||
self.pt = 0
|
||||
self.production = None
|
||||
self.children = list()
|
||||
|
||||
def get_next_action(self, is_sketch=False):
|
||||
actions = list()
|
||||
for x in self.production.split(' ')[1:]:
|
||||
if x not in Keywords:
|
||||
rule_type = eval(x)
|
||||
if is_sketch:
|
||||
if rule_type is not A:
|
||||
actions.append(rule_type)
|
||||
else:
|
||||
actions.append(rule_type)
|
||||
return actions
|
||||
|
||||
def set_parent(self, parent):
|
||||
self.parent = parent
|
||||
|
||||
def add_children(self, child):
|
||||
self.children.append(child)
|
||||
|
||||
|
||||
class Root1(Action):
|
||||
def __init__(self, id_c, parent=None):
|
||||
super(Root1, self).__init__()
|
||||
self.parent = parent
|
||||
self.id_c = id_c
|
||||
self._init_grammar()
|
||||
self.production = self.grammar_dict[id_c]
|
||||
|
||||
@classmethod
|
||||
def _init_grammar(self):
|
||||
# TODO: should add Root grammar to this
|
||||
self.grammar_dict = {
|
||||
0: 'Root1 intersect Root Root',
|
||||
1: 'Root1 union Root Root',
|
||||
2: 'Root1 except Root Root',
|
||||
3: 'Root1 Root',
|
||||
}
|
||||
self.production_id = {}
|
||||
for id_x, value in enumerate(self.grammar_dict.values()):
|
||||
self.production_id[value] = id_x
|
||||
|
||||
return self.grammar_dict.values()
|
||||
|
||||
def __str__(self):
|
||||
return 'Root1(' + str(self.id_c) + ')'
|
||||
|
||||
def __repr__(self):
|
||||
return 'Root1(' + str(self.id_c) + ')'
|
||||
|
||||
|
||||
class Root(Action):
|
||||
def __init__(self, id_c, parent=None):
|
||||
super(Root, self).__init__()
|
||||
self.parent = parent
|
||||
self.id_c = id_c
|
||||
self._init_grammar()
|
||||
self.production = self.grammar_dict[id_c]
|
||||
|
||||
@classmethod
|
||||
def _init_grammar(self):
|
||||
# TODO: should add Root grammar to this
|
||||
self.grammar_dict = {
|
||||
0: 'Root Sel Sup Filter',
|
||||
1: 'Root Sel Filter Order',
|
||||
2: 'Root Sel Sup',
|
||||
3: 'Root Sel Filter',
|
||||
4: 'Root Sel Order',
|
||||
5: 'Root Sel'
|
||||
}
|
||||
self.production_id = {}
|
||||
for id_x, value in enumerate(self.grammar_dict.values()):
|
||||
self.production_id[value] = id_x
|
||||
|
||||
return self.grammar_dict.values()
|
||||
|
||||
def __str__(self):
|
||||
return 'Root(' + str(self.id_c) + ')'
|
||||
|
||||
def __repr__(self):
|
||||
return 'Root(' + str(self.id_c) + ')'
|
||||
|
||||
|
||||
class N(Action):
|
||||
"""
|
||||
Number of Columns
|
||||
"""
|
||||
def __init__(self, id_c, parent=None):
|
||||
super(N, self).__init__()
|
||||
self.parent = parent
|
||||
self.id_c = id_c
|
||||
self._init_grammar()
|
||||
self.production = self.grammar_dict[id_c]
|
||||
|
||||
@classmethod
|
||||
def _init_grammar(self):
|
||||
self.grammar_dict = {
|
||||
0: 'N A',
|
||||
1: 'N A A',
|
||||
2: 'N A A A',
|
||||
3: 'N A A A A',
|
||||
4: 'N A A A A A'
|
||||
}
|
||||
self.production_id = {}
|
||||
for id_x, value in enumerate(self.grammar_dict.values()):
|
||||
self.production_id[value] = id_x
|
||||
|
||||
return self.grammar_dict.values()
|
||||
|
||||
def __str__(self):
|
||||
return 'N(' + str(self.id_c) + ')'
|
||||
|
||||
def __repr__(self):
|
||||
return 'N(' + str(self.id_c) + ')'
|
||||
|
||||
class C(Action):
|
||||
"""
|
||||
Column
|
||||
"""
|
||||
def __init__(self, id_c, parent=None):
|
||||
super(C, self).__init__()
|
||||
self.parent = parent
|
||||
self.id_c = id_c
|
||||
self.production = 'C T'
|
||||
self.table = None
|
||||
|
||||
def __str__(self):
|
||||
return 'C(' + str(self.id_c) + ')'
|
||||
|
||||
def __repr__(self):
|
||||
return 'C(' + str(self.id_c) + ')'
|
||||
|
||||
|
||||
class T(Action):
|
||||
"""
|
||||
Table
|
||||
"""
|
||||
def __init__(self, id_c, parent=None):
|
||||
super(T, self).__init__()
|
||||
|
||||
self.parent = parent
|
||||
self.id_c = id_c
|
||||
self.production = 'T min'
|
||||
self.table = None
|
||||
|
||||
def __str__(self):
|
||||
return 'T(' + str(self.id_c) + ')'
|
||||
|
||||
def __repr__(self):
|
||||
return 'T(' + str(self.id_c) + ')'
|
||||
|
||||
|
||||
class A(Action):
|
||||
"""
|
||||
Aggregator
|
||||
"""
|
||||
def __init__(self, id_c, parent=None):
|
||||
super(A, self).__init__()
|
||||
|
||||
self.parent = parent
|
||||
self.id_c = id_c
|
||||
self._init_grammar()
|
||||
self.production = self.grammar_dict[id_c]
|
||||
|
||||
@classmethod
|
||||
def _init_grammar(self):
|
||||
# TODO: should add Root grammar to this
|
||||
self.grammar_dict = {
|
||||
0: 'A none C',
|
||||
1: 'A max C',
|
||||
2: "A min C",
|
||||
3: "A count C",
|
||||
4: "A sum C",
|
||||
5: "A avg C"
|
||||
}
|
||||
self.production_id = {}
|
||||
for id_x, value in enumerate(self.grammar_dict.values()):
|
||||
self.production_id[value] = id_x
|
||||
|
||||
return self.grammar_dict.values()
|
||||
|
||||
def __str__(self):
|
||||
return 'A(' + str(self.id_c) + ')'
|
||||
|
||||
def __repr__(self):
|
||||
return 'A(' + str(self.grammar_dict[self.id_c].split(' ')[1]) + ')'
|
||||
|
||||
|
||||
class Sel(Action):
|
||||
"""
|
||||
Select
|
||||
"""
|
||||
def __init__(self, id_c, parent=None):
|
||||
super(Sel, self).__init__()
|
||||
|
||||
self.parent = parent
|
||||
self.id_c = id_c
|
||||
self._init_grammar()
|
||||
self.production = self.grammar_dict[id_c]
|
||||
|
||||
@classmethod
|
||||
def _init_grammar(self):
|
||||
self.grammar_dict = {
|
||||
0: 'Sel N',
|
||||
}
|
||||
self.production_id = {}
|
||||
for id_x, value in enumerate(self.grammar_dict.values()):
|
||||
self.production_id[value] = id_x
|
||||
|
||||
return self.grammar_dict.values()
|
||||
|
||||
def __str__(self):
|
||||
return 'Sel(' + str(self.id_c) + ')'
|
||||
|
||||
def __repr__(self):
|
||||
return 'Sel(' + str(self.id_c) + ')'
|
||||
|
||||
class Filter(Action):
|
||||
"""
|
||||
Filter
|
||||
"""
|
||||
def __init__(self, id_c, parent=None):
|
||||
super(Filter, self).__init__()
|
||||
|
||||
self.parent = parent
|
||||
self.id_c = id_c
|
||||
self._init_grammar()
|
||||
self.production = self.grammar_dict[id_c]
|
||||
|
||||
@classmethod
|
||||
def _init_grammar(self):
|
||||
self.grammar_dict = {
|
||||
# 0: "Filter 1"
|
||||
0: 'Filter and Filter Filter',
|
||||
1: 'Filter or Filter Filter',
|
||||
2: 'Filter = A',
|
||||
3: 'Filter != A',
|
||||
4: 'Filter < A',
|
||||
5: 'Filter > A',
|
||||
6: 'Filter <= A',
|
||||
7: 'Filter >= A',
|
||||
8: 'Filter between A',
|
||||
9: 'Filter like A',
|
||||
10: 'Filter not_like A',
|
||||
# now begin root
|
||||
11: 'Filter = A Root',
|
||||
12: 'Filter < A Root',
|
||||
13: 'Filter > A Root',
|
||||
14: 'Filter != A Root',
|
||||
15: 'Filter between A Root',
|
||||
16: 'Filter >= A Root',
|
||||
17: 'Filter <= A Root',
|
||||
# now for In
|
||||
18: 'Filter in A Root',
|
||||
19: 'Filter not_in A Root'
|
||||
|
||||
}
|
||||
self.production_id = {}
|
||||
for id_x, value in enumerate(self.grammar_dict.values()):
|
||||
self.production_id[value] = id_x
|
||||
|
||||
return self.grammar_dict.values()
|
||||
|
||||
def __str__(self):
|
||||
return 'Filter(' + str(self.id_c) + ')'
|
||||
|
||||
def __repr__(self):
|
||||
return 'Filter(' + str(self.grammar_dict[self.id_c]) + ')'
|
||||
|
||||
|
||||
class Sup(Action):
|
||||
"""
|
||||
Superlative
|
||||
"""
|
||||
def __init__(self, id_c, parent=None):
|
||||
super(Sup, self).__init__()
|
||||
|
||||
self.parent = parent
|
||||
self.id_c = id_c
|
||||
self._init_grammar()
|
||||
self.production = self.grammar_dict[id_c]
|
||||
|
||||
@classmethod
|
||||
def _init_grammar(self):
|
||||
self.grammar_dict = {
|
||||
0: 'Sup des A',
|
||||
1: 'Sup asc A',
|
||||
}
|
||||
self.production_id = {}
|
||||
for id_x, value in enumerate(self.grammar_dict.values()):
|
||||
self.production_id[value] = id_x
|
||||
|
||||
return self.grammar_dict.values()
|
||||
|
||||
def __str__(self):
|
||||
return 'Sup(' + str(self.id_c) + ')'
|
||||
|
||||
def __repr__(self):
|
||||
return 'Sup(' + str(self.id_c) + ')'
|
||||
|
||||
|
||||
class Order(Action):
|
||||
"""
|
||||
Order
|
||||
"""
|
||||
def __init__(self, id_c, parent=None):
|
||||
super(Order, self).__init__()
|
||||
|
||||
self.parent = parent
|
||||
self.id_c = id_c
|
||||
self._init_grammar()
|
||||
self.production = self.grammar_dict[id_c]
|
||||
|
||||
@classmethod
|
||||
def _init_grammar(self):
|
||||
self.grammar_dict = {
|
||||
0: 'Order des A',
|
||||
1: 'Order asc A',
|
||||
}
|
||||
self.production_id = {}
|
||||
for id_x, value in enumerate(self.grammar_dict.values()):
|
||||
self.production_id[value] = id_x
|
||||
|
||||
return self.grammar_dict.values()
|
||||
|
||||
def __str__(self):
|
||||
return 'Order(' + str(self.id_c) + ')'
|
||||
|
||||
def __repr__(self):
|
||||
return 'Order(' + str(self.id_c) + ')'
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print(list(Root._init_grammar()))
|
|
@@ -0,0 +1,321 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
# @Time : 2019/5/27
|
||||
# @Author : Jiaqi&Zecheng
|
||||
# @File : sem_utils.py
|
||||
# @Software: PyCharm
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import argparse
|
||||
import re as regex
|
||||
from nltk.stem import WordNetLemmatizer
|
||||
from pattern.en import lemma
|
||||
wordnet_lemmatizer = WordNetLemmatizer()
|
||||
|
||||
|
||||
def load_dataSets(args):
|
||||
with open(args.input_path, 'r') as f:
|
||||
datas = json.load(f)
|
||||
with open(os.path.join(args.data_path, 'tables.json'), 'r', encoding='utf8') as f:
|
||||
table_datas = json.load(f)
|
||||
schemas = dict()
|
||||
for i in range(len(table_datas)):
|
||||
schemas[table_datas[i]['db_id']] = table_datas[i]
|
||||
return datas, schemas
|
||||
|
||||
|
||||
def partial_match(query, table_name):
|
||||
query = [lemma(x) for x in query]
|
||||
table_name = [lemma(x) for x in table_name]
|
||||
if query in table_name:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_partial_match(query, table_names):
|
||||
query = lemma(query)
|
||||
table_names = [[lemma(x) for x in names.split(' ') ] for names in table_names]
|
||||
same_count = 0
|
||||
result = None
|
||||
for names in table_names:
|
||||
if query in names:
|
||||
same_count += 1
|
||||
result = names
|
||||
return result if same_count == 1 else False
|
||||
|
||||
|
||||
def multi_option(question, q_ind, names, N):
|
||||
for i in range(q_ind + 1, q_ind + N + 1):
|
||||
if i < len(question):
|
||||
re = is_partial_match(question[i][0], names)
|
||||
if re is not False:
|
||||
return re
|
||||
return False
|
||||
|
||||
|
||||
def multi_equal(question, q_ind, names, N):
|
||||
for i in range(q_ind + 1, q_ind + N + 1):
|
||||
if i < len(question):
|
||||
if question[i] == names:
|
||||
return i
|
||||
return False
|
||||
|
||||
|
||||
def random_choice(question_arg, question_arg_type, names, ground_col_labels, q_ind, N, origin_name):
|
||||
# first, check whether another table is mentioned in the question
|
||||
for t_ind, t_val in enumerate(question_arg_type):
|
||||
if t_val == ['table']:
|
||||
return names[origin_name.index(question_arg[t_ind])]
|
||||
for i in range(q_ind + 1, q_ind + N + 1):
|
||||
if i < len(question_arg):
|
||||
if len(ground_col_labels) == 0:
|
||||
for n in names:
|
||||
if partial_match(question_arg[i][0], n) is True:
|
||||
return n
|
||||
else:
|
||||
for n_id, n in enumerate(names):
|
||||
if n_id in ground_col_labels and partial_match(question_arg[i][0], n) is True:
|
||||
return n
|
||||
if len(ground_col_labels) > 0:
|
||||
return names[ground_col_labels[0]]
|
||||
else:
|
||||
return names[0]
|
||||
|
||||
|
||||
def find_table(cur_table, origin_table_names, question_arg_type, question_arg):
|
||||
h_table = None
|
||||
for i in range(len(question_arg_type))[::-1]:
|
||||
if question_arg_type[i] == ['table']:
|
||||
h_table = question_arg[i]
|
||||
h_table = origin_table_names.index(h_table)
|
||||
if h_table != cur_table:
|
||||
break
|
||||
if h_table != cur_table:
|
||||
return h_table
|
||||
|
||||
# find partial
|
||||
for i in range(len(question_arg_type))[::-1]:
|
||||
if question_arg_type[i] == ['NONE']:
|
||||
for t_id, table_name in enumerate(origin_table_names):
|
||||
if partial_match(question_arg[i], table_name) is True and t_id != h_table:
|
||||
return t_id
|
||||
|
||||
# random return
|
||||
for i in range(len(question_arg_type))[::-1]:
|
||||
if question_arg_type[i] == ['table']:
|
||||
h_table = question_arg[i]
|
||||
h_table = origin_table_names.index(h_table)
|
||||
return h_table
|
||||
|
||||
return cur_table
|
||||
|
||||
|
||||
def alter_not_in(datas, schemas):
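# Filter(19) is the 'Filter not_in A Root' rule; patch its column and table arguments using the schema's primary and foreign keys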
for d in datas:
|
||||
if 'Filter(19)' in d['model_result']:
|
||||
current_table = schemas[d['db_id']]
|
||||
current_table['schema_content_clean'] = [x[1] for x in current_table['column_names']]
|
||||
current_table['col_table'] = [col[0] for col in current_table['column_names']]
|
||||
origin_table_names = [[wordnet_lemmatizer.lemmatize(x.lower()) for x in names.split(' ')] for names in
|
||||
d['table_names']]
|
||||
question_arg_type = d['question_arg_type']
|
||||
question_arg = d['question_arg']
|
||||
pred_label = d['model_result'].split(' ')
|
||||
|
||||
# get potiantial table
|
||||
cur_table = None
|
||||
for label_id, label_val in enumerate(pred_label):
|
||||
if label_val in ['Filter(19)']:
|
||||
cur_table = int(pred_label[label_id - 1][2:-1])
|
||||
break
|
||||
|
||||
h_table = find_table(cur_table, origin_table_names, question_arg_type, question_arg)
|
||||
|
||||
for label_id, label_val in enumerate(pred_label):
|
||||
if label_val in ['Filter(19)']:
|
||||
for primary in current_table['primary_keys']:
|
||||
if int(current_table['col_table'][primary]) == int(pred_label[label_id - 1][2:-1]):
|
||||
pred_label[label_id + 2] = 'C(' + str(
|
||||
d['col_set'].index(current_table['schema_content_clean'][primary])) + ')'
|
||||
break
|
||||
for pair in current_table['foreign_keys']:
|
||||
if int(current_table['col_table'][pair[0]]) == h_table and d['col_set'].index(
|
||||
current_table['schema_content_clean'][pair[1]]) == int(pred_label[label_id + 2][2:-1]):
|
||||
pred_label[label_id + 8] = 'C(' + str(
|
||||
d['col_set'].index(current_table['schema_content_clean'][pair[0]])) + ')'
|
||||
pred_label[label_id + 9] = 'T(' + str(h_table) + ')'
|
||||
break
|
||||
elif int(current_table['col_table'][pair[1]]) == h_table and d['col_set'].index(
|
||||
current_table['schema_content_clean'][pair[0]]) == int(pred_label[label_id + 2][2:-1]):
|
||||
pred_label[label_id + 8] = 'C(' + str(
|
||||
d['col_set'].index(current_table['schema_content_clean'][pair[1]])) + ')'
|
||||
pred_label[label_id + 9] = 'T(' + str(h_table) + ')'
|
||||
break
|
||||
pred_label[label_id + 3] = pred_label[label_id - 1]
|
||||
|
||||
d['model_result'] = " ".join(pred_label)
|
||||
|
||||
|
||||
def alter_inter(datas):
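# Filter(0) is the 'Filter and Filter Filter' rule; when both conjuncts constrain the same column, rewrite the query as an intersect of two Root subqueries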
for d in datas:
|
||||
if 'Filter(0)' in d['model_result']:
|
||||
now_result = d['model_result'].split(' ')
|
||||
index = now_result.index('Filter(0)')
|
||||
c1 = None
|
||||
c2 = None
|
||||
for i in range(index + 1, len(now_result)):
|
||||
if c1 is None and 'C(' in now_result[i]:
|
||||
c1 = now_result[i]
|
||||
elif c1 is not None and c2 is None and 'C(' in now_result[i]:
|
||||
c2 = now_result[i]
|
||||
|
||||
if c1 != c2 or c1 is None or c2 is None:
|
||||
continue
|
||||
replace_result = ['Root1(0)'] + now_result[1:now_result.index('Filter(0)')]
|
||||
for r_id, r_val in enumerate(now_result[now_result.index('Filter(0)') + 2:]):
|
||||
if 'Filter' in r_val:
|
||||
break
|
||||
|
||||
replace_result = replace_result + now_result[now_result.index('Filter(0)') + 1:r_id + now_result.index(
|
||||
'Filter(0)') + 2]
|
||||
replace_result = replace_result + now_result[1:now_result.index('Filter(0)')]
|
||||
|
||||
replace_result = replace_result + now_result[r_id + now_result.index('Filter(0)') + 2:]
|
||||
replace_result = " ".join(replace_result)
|
||||
d['model_result'] = replace_result
|
||||
|
||||
|
||||
def alter_column0(datas):
|
||||
"""
|
||||
Attach column * table
|
||||
:return: model_result_replace
|
||||
"""
|
||||
zero_count = 0
|
||||
count = 0
|
||||
result = []
|
||||
for d in datas:
|
||||
if 'C(0)' in d['model_result']:
|
||||
pattern = regex.compile(r'C\(.*?\) T\(.*?\)')
|
||||
result_pattern = list(set(pattern.findall(d['model_result'])))
|
||||
ground_col_labels = []
|
||||
for pa in result_pattern:
|
||||
pa = pa.split(' ')
|
||||
if pa[0] != 'C(0)':
|
||||
index = int(pa[1][2:-1])
|
||||
ground_col_labels.append(index)
|
||||
|
||||
ground_col_labels = list(set(ground_col_labels))
|
||||
question_arg_type = d['question_arg_type']
|
||||
question_arg = d['question_arg']
|
||||
table_names = [[lemma(x) for x in names.split(' ')] for names in d['table_names']]
|
||||
origin_table_names = [[wordnet_lemmatizer.lemmatize(x.lower()) for x in names.split(' ')] for names in
|
||||
d['table_names']]
|
||||
count += 1
|
||||
easy_flag = False
|
||||
for q_ind, q in enumerate(d['question_arg']):
|
||||
q = [lemma(x) for x in q]
|
||||
q_str = " ".join(" ".join(x) for x in d['question_arg'])
|
||||
if 'how many' in q_str or 'number of' in q_str or 'count of' in q_str:
|
||||
easy_flag = True
|
||||
if easy_flag:
|
||||
# check for the last one is a table word
|
||||
for q_ind, q in enumerate(d['question_arg']):
|
||||
if (q_ind > 0 and q == ['many'] and d['question_arg'][q_ind - 1] == ['how']) or (
|
||||
q_ind > 0 and q == ['of'] and d['question_arg'][q_ind - 1] == ['number']) or (
|
||||
q_ind > 0 and q == ['of'] and d['question_arg'][q_ind - 1] == ['count']):
|
||||
re = multi_equal(question_arg_type, q_ind, ['table'], 2)
|
||||
if re is not False:
|
||||
# This step works for the "number of [table]" example
|
||||
table_result = table_names[origin_table_names.index(question_arg[re])]
|
||||
result.append((d['query'], d['question'], table_result, d))
|
||||
break
|
||||
else:
|
||||
re = multi_option(question_arg, q_ind, d['table_names'], 2)
|
||||
if re is not False:
|
||||
table_result = re
|
||||
result.append((d['query'], d['question'], table_result, d))
|
||||
pass
|
||||
else:
|
||||
re = multi_equal(question_arg_type, q_ind, ['table'], len(question_arg_type))
|
||||
if re is not False:
|
||||
# This step works for the "number of [table]" example
|
||||
table_result = table_names[origin_table_names.index(question_arg[re])]
|
||||
result.append((d['query'], d['question'], table_result, d))
|
||||
break
|
||||
pass
|
||||
table_result = random_choice(question_arg=question_arg,
|
||||
question_arg_type=question_arg_type,
|
||||
names=table_names,
|
||||
ground_col_labels=ground_col_labels, q_ind=q_ind, N=2,
|
||||
origin_name=origin_table_names)
|
||||
result.append((d['query'], d['question'], table_result, d))
|
||||
|
||||
zero_count += 1
|
||||
break
|
||||
|
||||
else:
|
||||
M_OP = False
|
||||
for q_ind, q in enumerate(d['question_arg']):
|
||||
if M_OP is False and q in [['than'], ['least'], ['most'], ['msot'], ['fewest']] or \
|
||||
question_arg_type[q_ind] == ['M_OP']:
|
||||
M_OP = True
|
||||
re = multi_equal(question_arg_type, q_ind, ['table'], 3)
|
||||
if re is not False:
|
||||
# This step works for the "number of [table]" example
|
||||
table_result = table_names[origin_table_names.index(question_arg[re])]
|
||||
result.append((d['query'], d['question'], table_result, d))
|
||||
break
|
||||
else:
|
||||
re = multi_option(question_arg, q_ind, d['table_names'], 3)
|
||||
if re is not False:
|
||||
table_result = re
|
||||
# print(table_result)
|
||||
result.append((d['query'], d['question'], table_result, d))
|
||||
pass
|
||||
else:
|
||||
# zero_count += 1
|
||||
re = multi_equal(question_arg_type, q_ind, ['table'], len(question_arg_type))
|
||||
if re is not False:
|
||||
# This step works for the "number of [table]" example
|
||||
table_result = table_names[origin_table_names.index(question_arg[re])]
|
||||
result.append((d['query'], d['question'], table_result, d))
|
||||
break
|
||||
|
||||
table_result = random_choice(question_arg=question_arg,
|
||||
question_arg_type=question_arg_type,
|
||||
names=table_names,
|
||||
ground_col_labels=ground_col_labels, q_ind=q_ind, N=2,
|
||||
origin_name=origin_table_names)
|
||||
result.append((d['query'], d['question'], table_result, d))
|
||||
|
||||
pass
|
||||
if M_OP is False:
|
||||
table_result = random_choice(question_arg=question_arg,
|
||||
question_arg_type=question_arg_type,
|
||||
names=table_names, ground_col_labels=ground_col_labels, q_ind=q_ind,
|
||||
N=2,
|
||||
origin_name=origin_table_names)
|
||||
result.append((d['query'], d['question'], table_result, d))
|
||||
|
||||
for re in result:
|
||||
table_names = [[lemma(x) for x in names.split(' ')] for names in re[3]['table_names']]
|
||||
origin_table_names = [[x for x in names.split(' ')] for names in re[3]['table_names']]
|
||||
if re[2] in table_names:
|
||||
re[3]['rule_count'] = table_names.index(re[2])
|
||||
else:
|
||||
re[3]['rule_count'] = origin_table_names.index(re[2])
|
||||
|
||||
for data in datas:
|
||||
if 'rule_count' in data:
|
||||
str_replace = 'C(0) T(' + str(data['rule_count']) + ')'
|
||||
replace_result = regex.sub(r'C\(0\) T\(.\)', str_replace, data['model_result'])
|
||||
data['model_result_replace'] = replace_result
|
||||
else:
|
||||
data['model_result_replace'] = data['model_result']
|
||||
|
||||
|
|
@@ -0,0 +1,348 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
# @Time : 2019/5/25
|
||||
# @Author : Jiaqi&Zecheng
|
||||
# @File : utils.py
|
||||
# @Software: PyCharm
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
|
||||
import copy
|
||||
import numpy as np
|
||||
import os
|
||||
import torch
|
||||
from nltk.stem import WordNetLemmatizer
|
||||
|
||||
from src.dataset import Example
|
||||
from src.rule import lf
|
||||
from src.rule.semQL import Sup, Sel, Order, Root, Filter, A, N, C, T, Root1
|
||||
|
||||
wordnet_lemmatizer = WordNetLemmatizer()
|
||||
|
||||
|
||||
def load_word_emb(file_name, use_small=False):
|
||||
print('Loading word embedding from %s' % file_name)
|
||||
ret = {}
|
||||
with open(file_name) as inf:
|
||||
for idx, line in enumerate(inf):
|
||||
if (use_small and idx >= 500000):
|
||||
break
|
||||
info = line.strip().split(' ')
|
||||
if info[0].lower() not in ret:
|
||||
ret[info[0]] = np.array(list(map(lambda x:float(x), info[1:])))
|
||||
return ret
|
||||
|
||||
def lower_keys(x):
|
||||
if isinstance(x, list):
|
||||
return [lower_keys(v) for v in x]
|
||||
elif isinstance(x, dict):
|
||||
return dict((k.lower(), lower_keys(v)) for k, v in x.items())
|
||||
else:
|
||||
return x
|
||||
|
||||
def get_table_colNames(tab_ids, tab_cols):
|
||||
table_col_dict = {}
|
||||
for ci, cv in zip(tab_ids, tab_cols):
|
||||
if ci != -1:
|
||||
table_col_dict[ci] = table_col_dict.get(ci, []) + cv
|
||||
result = []
|
||||
for ci in range(len(table_col_dict)):
|
||||
result.append(table_col_dict[ci])
|
||||
return result
|
||||
|
||||
def get_col_table_dict(tab_cols, tab_ids, sql):
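# map every column index in sql['col_set'] to the list of table ids that contain a column with that name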
table_dict = {}
|
||||
for c_id, c_v in enumerate(sql['col_set']):
|
||||
for cor_id, cor_val in enumerate(tab_cols):
|
||||
if c_v == cor_val:
|
||||
table_dict[tab_ids[cor_id]] = table_dict.get(tab_ids[cor_id], []) + [c_id]
|
||||
|
||||
col_table_dict = {}
|
||||
for key_item, value_item in table_dict.items():
|
||||
for value in value_item:
|
||||
col_table_dict[value] = col_table_dict.get(value, []) + [key_item]
|
||||
col_table_dict[0] = [x for x in range(len(table_dict) - 1)]
|
||||
return col_table_dict
|
||||
|
||||
|
||||
def schema_linking(question_arg, question_arg_type, one_hot_type, col_set_type, col_set_iter, sql):
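# tag each question token with its schema-linking type (table / column / agg / MORE / MOST / value) and bump the type features of the matched columns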
for count_q, t_q in enumerate(question_arg_type):
|
||||
t = t_q[0]
|
||||
if t == 'NONE':
|
||||
continue
|
||||
elif t == 'table':
|
||||
one_hot_type[count_q][0] = 1
|
||||
question_arg[count_q] = ['table'] + question_arg[count_q]
|
||||
elif t == 'col':
|
||||
one_hot_type[count_q][1] = 1
|
||||
try:
|
||||
col_set_type[col_set_iter.index(question_arg[count_q])][1] = 5
|
||||
question_arg[count_q] = ['column'] + question_arg[count_q]
|
||||
except:
|
||||
print(col_set_iter, question_arg[count_q])
|
||||
raise RuntimeError("not in col set")
|
||||
elif t == 'agg':
|
||||
one_hot_type[count_q][2] = 1
|
||||
elif t == 'MORE':
|
||||
one_hot_type[count_q][3] = 1
|
||||
elif t == 'MOST':
|
||||
one_hot_type[count_q][4] = 1
|
||||
elif t == 'value':
|
||||
one_hot_type[count_q][5] = 1
|
||||
question_arg[count_q] = ['value'] + question_arg[count_q]
|
||||
else:
|
||||
if len(t_q) == 1:
|
||||
for col_probase in t_q:
|
||||
if col_probase == 'asd':
|
||||
continue
|
||||
try:
|
||||
col_set_type[sql['col_set'].index(col_probase)][2] = 5
|
||||
question_arg[count_q] = ['value'] + question_arg[count_q]
|
||||
except:
|
||||
print(sql['col_set'], col_probase)
|
||||
raise RuntimeError('not in col')
|
||||
one_hot_type[count_q][5] = 1
|
||||
else:
|
||||
for col_probase in t_q:
|
||||
if col_probase == 'asd':
|
||||
continue
|
||||
col_set_type[sql['col_set'].index(col_probase)][3] += 1
|
||||
|
||||
def process(sql, table):
|
||||
|
||||
process_dict = {}
|
||||
|
||||
origin_sql = sql['question_toks']
|
||||
table_names = [[wordnet_lemmatizer.lemmatize(v).lower() for v in x.split(' ')] for x in table['table_names']]
|
||||
|
||||
sql['pre_sql'] = copy.deepcopy(sql)
|
||||
|
||||
tab_cols = [col[1] for col in table['column_names']]
|
||||
tab_ids = [col[0] for col in table['column_names']]
|
||||
|
||||
col_set_iter = [[wordnet_lemmatizer.lemmatize(v).lower() for v in x.split(' ')] for x in sql['col_set']]
|
||||
col_iter = [[wordnet_lemmatizer.lemmatize(v).lower() for v in x.split(" ")] for x in tab_cols]
|
||||
q_iter_small = [wordnet_lemmatizer.lemmatize(x).lower() for x in origin_sql]
|
||||
question_arg = copy.deepcopy(sql['question_arg'])
|
||||
question_arg_type = sql['question_arg_type']
|
||||
one_hot_type = np.zeros((len(question_arg_type), 6))
|
||||
|
||||
col_set_type = np.zeros((len(col_set_iter), 4))
|
||||
|
||||
process_dict['col_set_iter'] = col_set_iter
|
||||
process_dict['q_iter_small'] = q_iter_small
|
||||
process_dict['col_set_type'] = col_set_type
|
||||
process_dict['question_arg'] = question_arg
|
||||
process_dict['question_arg_type'] = question_arg_type
|
||||
process_dict['one_hot_type'] = one_hot_type
|
||||
process_dict['tab_cols'] = tab_cols
|
||||
process_dict['tab_ids'] = tab_ids
|
||||
process_dict['col_iter'] = col_iter
|
||||
process_dict['table_names'] = table_names
|
||||
|
||||
return process_dict
|
||||
|
||||
def is_valid(rule_label, col_table_dict, sql):
|
||||
try:
|
||||
lf.build_tree(copy.copy(rule_label))
|
||||
except:
|
||||
print(rule_label)
|
||||
|
||||
flag = False
|
||||
for r_id, rule in enumerate(rule_label):
|
||||
if type(rule) == C:
|
||||
try:
|
||||
assert rule_label[r_id + 1].id_c in col_table_dict[rule.id_c], print(sql['question'])
|
||||
except:
|
||||
flag = True
|
||||
print(sql['question'])
|
||||
return flag is False
|
||||
|
||||
|
||||
def to_batch_seq(sql_data, table_data, idxes, st, ed,
|
||||
is_train=True):
|
||||
"""
|
||||
|
||||
:return:
|
||||
"""
|
||||
examples = []
|
||||
|
||||
for i in range(st, ed):
|
||||
sql = sql_data[idxes[i]]
|
||||
table = table_data[sql['db_id']]
|
||||
|
||||
process_dict = process(sql, table)
|
||||
|
||||
for c_id, col_ in enumerate(process_dict['col_set_iter']):
|
||||
for q_id, ori in enumerate(process_dict['q_iter_small']):
|
||||
if ori in col_:
|
||||
process_dict['col_set_type'][c_id][0] += 1
|
||||
|
||||
schema_linking(process_dict['question_arg'], process_dict['question_arg_type'],
|
||||
process_dict['one_hot_type'], process_dict['col_set_type'], process_dict['col_set_iter'], sql)
|
||||
|
||||
col_table_dict = get_col_table_dict(process_dict['tab_cols'], process_dict['tab_ids'], sql)
|
||||
table_col_name = get_table_colNames(process_dict['tab_ids'], process_dict['col_iter'])
|
||||
|
||||
process_dict['col_set_iter'][0] = ['count', 'number', 'many']
|
||||
|
||||
rule_label = None
|
||||
if 'rule_label' in sql:
|
||||
rule_label = [eval(x) for x in sql['rule_label'].strip().split(' ')]
|
||||
if is_valid(rule_label, col_table_dict=col_table_dict, sql=sql) is False:
|
||||
continue
|
||||
|
||||
example = Example(
|
||||
src_sent=process_dict['question_arg'],
|
||||
col_num=len(process_dict['col_set_iter']),
|
||||
vis_seq=(sql['question'], process_dict['col_set_iter'], sql['query']),
|
||||
tab_cols=process_dict['col_set_iter'],
|
||||
sql=sql['query'],
|
||||
one_hot_type=process_dict['one_hot_type'],
|
||||
col_hot_type=process_dict['col_set_type'],
|
||||
table_names=process_dict['table_names'],
|
||||
table_len=len(process_dict['table_names']),
|
||||
col_table_dict=col_table_dict,
|
||||
cols=process_dict['tab_cols'],
|
||||
table_col_name=table_col_name,
|
||||
table_col_len=len(table_col_name),
|
||||
tokenized_src_sent=process_dict['col_set_type'],
|
||||
tgt_actions=rule_label
|
||||
)
|
||||
example.sql_json = copy.deepcopy(sql)
|
||||
examples.append(example)
|
||||
|
||||
if is_train:
|
||||
examples.sort(key=lambda e: -len(e.src_sent))
|
||||
return examples
|
||||
else:
|
||||
return examples
|
||||
|
||||
def epoch_train(model, optimizer, batch_size, sql_data, table_data,
|
||||
args, epoch=0, loss_epoch_threshold=20, sketch_loss_coefficient=0.2):
|
||||
model.train()
|
||||
# shuffle the training examples each epoch
perm = np.random.permutation(len(sql_data))
|
||||
cum_loss = 0.0
|
||||
st = 0
|
||||
while st < len(sql_data):
|
||||
ed = st+batch_size if st+batch_size < len(perm) else len(perm)
|
||||
examples = to_batch_seq(sql_data, table_data, perm, st, ed)
|
||||
optimizer.zero_grad()
|
||||
|
||||
score = model.forward(examples)
|
||||
loss_sketch = -score[0]
|
||||
loss_lf = -score[1]
|
||||
|
||||
loss_sketch = torch.mean(loss_sketch)
|
||||
loss_lf = torch.mean(loss_lf)
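# after loss_epoch_threshold epochs, the sketch loss is down-weighted by sketch_loss_coefficient so training focuses on the full logical form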
if epoch > loss_epoch_threshold:
|
||||
loss = loss_lf + sketch_loss_coefficient * loss_sketch
|
||||
else:
|
||||
loss = loss_lf + loss_sketch
|
||||
|
||||
loss.backward()
|
||||
if args.clip_grad > 0.:
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
|
||||
optimizer.step()
|
||||
cum_loss += loss.data.cpu().numpy()*(ed - st)
|
||||
st = ed
|
||||
return cum_loss / len(sql_data)
|
||||
|
||||
def epoch_acc(model, batch_size, sql_data, table_data, beam_size=3):
    model.eval()
    perm = list(range(len(sql_data)))
    st = 0

    json_datas = []
    while st < len(sql_data):
        ed = st + batch_size if st + batch_size < len(perm) else len(perm)
        examples = to_batch_seq(sql_data, table_data, perm, st, ed,
                                is_train=False)
        for example in examples:
            results_all = model.parse(example, beam_size=beam_size)
            results = results_all[0]
            list_preds = []
            try:
                # Serialize the best hypothesis' action sequence as the prediction.
                pred = " ".join([str(x) for x in results[0].actions])
                for x in results:
                    list_preds.append(" ".join(str(action) for action in x.actions))
            except Exception as e:
                # print('Epoch Acc: ', e)
                # print(results)
                # print(results_all)
                pred = ""

            simple_json = example.sql_json['pre_sql']
            simple_json['sketch_result'] = " ".join(str(x) for x in results_all[1])
            simple_json['model_result'] = pred

            json_datas.append(simple_json)
        st = ed
    return json_datas

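epoch_acc returns one dict per dev example, augmented with 'sketch_result' and 'model_result'. A minimal sketch of persisting those predictions for later inspection (the helper and file name are illustrative, not part of the repository):

```python
import json

def dump_predictions(json_datas, out_path='predictions.json'):
    # Write the per-example predictions produced by epoch_acc to disk.
    with open(out_path, 'w') as f:
        json.dump(json_datas, f, indent=2)
```
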
def eval_acc(preds, sqls):
    sketch_correct, best_correct = 0, 0
    for i, (pred, sql) in enumerate(zip(preds, sqls)):
        if pred['model_result'] == sql['rule_label']:
            best_correct += 1
    print(best_correct / len(preds))
    return best_correct / len(preds)

def load_data_new(sql_path, table_data, use_small=False):
    sql_data = []

    print("Loading data from %s" % sql_path)
    with open(sql_path) as inf:
        data = lower_keys(json.load(inf))
        sql_data += data

    table_data_new = {table['db_id']: table for table in table_data}

    if use_small:
        return sql_data[:80], table_data_new
    else:
        return sql_data, table_data_new

def load_dataset(dataset_dir, use_small=False):
    print("Loading from datasets...")

    TABLE_PATH = os.path.join(dataset_dir, "tables.json")
    TRAIN_PATH = os.path.join(dataset_dir, "train.json")
    DEV_PATH = os.path.join(dataset_dir, "dev.json")
    with open(TABLE_PATH) as inf:
        print("Loading data from %s" % TABLE_PATH)
        table_data = json.load(inf)

    train_sql_data, train_table_data = load_data_new(TRAIN_PATH, table_data, use_small=use_small)
    val_sql_data, val_table_data = load_data_new(DEV_PATH, table_data, use_small=use_small)

    return train_sql_data, train_table_data, val_sql_data, val_table_data

def save_checkpoint(model, checkpoint_name):
    torch.save(model.state_dict(), checkpoint_name)

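save_checkpoint stores only the model's state dict. A minimal sketch of the corresponding load path (the helper name is illustrative, assuming a model built with the same constructor used in train.py below):

```python
import torch

def load_checkpoint(model, checkpoint_name, device='cpu'):
    # Counterpart to save_checkpoint: restore a state dict written by torch.save.
    state_dict = torch.load(checkpoint_name, map_location=device)
    model.load_state_dict(state_dict)
    return model
```
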
def save_args(args, path):
    with open(path, 'w') as f:
        f.write(json.dumps(vars(args), indent=4))

def init_log_checkpoint_path(args):
    save_path = args.save
    dir_name = save_path + str(int(time.time()))
    save_path = os.path.join(os.path.curdir, 'saved_model', dir_name)
    if os.path.exists(save_path) is False:
        os.makedirs(save_path)
    return save_path

train.py
@@ -0,0 +1,117 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# -*- coding: utf-8 -*-
"""
# @Time    : 2019/5/25
# @Author  : Jiaqi&Zecheng
# @File    : train.py
# @Software: PyCharm
"""

import time
import traceback

import os
import torch
import torch.optim as optim
import tqdm
import copy

from src import args as arg
from src import utils
from src.models.model import IRNet
from src.rule import semQL

def train(args):
    """
    :param args:
    :return:
    """
    grammar = semQL.Grammar()
    sql_data, table_data, val_sql_data, \
        val_table_data = utils.load_dataset(args.dataset, use_small=args.toy)

    model = IRNet(args, grammar)

    if args.cuda: model.cuda()

    # now get the optimizer
    optimizer_cls = eval('torch.optim.%s' % args.optimizer)
    optimizer = optimizer_cls(model.parameters(), lr=args.lr)
    print('Enable Learning Rate Scheduler: ', args.lr_scheduler)
    if args.lr_scheduler:
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[21, 41], gamma=args.lr_scheduler_gammar)
    else:
        scheduler = None

    print('Loss epoch threshold: %d' % args.loss_epoch_threshold)
    print('Sketch loss coefficient: %f' % args.sketch_loss_coefficient)

    if args.load_model:
        print('load pretrained model from %s' % (args.load_model))
        pretrained_model = torch.load(args.load_model,
                                      map_location=lambda storage, loc: storage)
        pretrained_modeled = copy.deepcopy(pretrained_model)
        # Drop parameters that do not exist in the current model before loading.
        for k in pretrained_model.keys():
            if k not in model.state_dict().keys():
                del pretrained_modeled[k]

        model.load_state_dict(pretrained_modeled)

    model.word_emb = utils.load_word_emb(args.glove_embed_path)
    # begin train

    model_save_path = utils.init_log_checkpoint_path(args)
    utils.save_args(args, os.path.join(model_save_path, 'config.json'))
    best_dev_acc = .0

    try:
        with open(os.path.join(model_save_path, 'epoch.log'), 'w') as epoch_fd:
            for epoch in tqdm.tqdm(range(args.epoch)):
                if args.lr_scheduler:
                    scheduler.step()
                epoch_begin = time.time()
                loss = utils.epoch_train(model, optimizer, args.batch_size, sql_data, table_data, args,
                                         loss_epoch_threshold=args.loss_epoch_threshold,
                                         sketch_loss_coefficient=args.sketch_loss_coefficient)
                epoch_end = time.time()
                json_datas = utils.epoch_acc(model, args.batch_size, val_sql_data, val_table_data,
                                             beam_size=args.beam_size)
                acc = utils.eval_acc(json_datas, val_sql_data)

                if acc > best_dev_acc:
                    utils.save_checkpoint(model, os.path.join(model_save_path, 'best_model.model'))
                    best_dev_acc = acc
                utils.save_checkpoint(model, os.path.join(model_save_path, '{%s}_{%s}.model' % (epoch, acc)))

                # eval_acc only computes exact-match accuracy, so the same value is
                # reported for both the sketch and full accuracy fields.
                log_str = 'Epoch: %d, Loss: %f, Sketch Acc: %f, Acc: %f, time: %f\n' % (
                    epoch + 1, loss, acc, acc, epoch_end - epoch_begin)
                tqdm.tqdm.write(log_str)
                epoch_fd.write(log_str)
                epoch_fd.flush()
    except Exception as e:
        # Save model
        utils.save_checkpoint(model, os.path.join(model_save_path, 'end_model.model'))
        print(e)
        tb = traceback.format_exc()
        print(tb)
    else:
        utils.save_checkpoint(model, os.path.join(model_save_path, 'end_model.model'))
        json_datas = utils.epoch_acc(model, args.batch_size, val_sql_data, val_table_data,
                                     beam_size=args.beam_size)
        acc = utils.eval_acc(json_datas, val_sql_data)

        print("Sketch Acc: %f, Acc: %f, Beam Acc: %f" % (acc, acc, acc,))


if __name__ == '__main__':
    arg_parser = arg.init_arg_parser()
    args = arg.init_config(arg_parser)
    print(args)
    train(args)
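train() builds the optimizer by eval-ing the class name taken from args.optimizer. An eval-free equivalent is sketched below; it is shown only as an illustrative alternative, not the repository's code:

```python
import torch

def build_optimizer(model, optimizer_name, lr):
    # eval-free equivalent of eval('torch.optim.%s' % optimizer_name)
    optimizer_cls = getattr(torch.optim, optimizer_name)  # e.g. 'Adam' -> torch.optim.Adam
    return optimizer_cls(model.parameters(), lr=lr)
```
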
train.sh
@@ -0,0 +1,24 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

#!/bin/bash

devices=$1
save_name=$2

CUDA_VISIBLE_DEVICES=$devices nohup python -u train.py --dataset ./data \
--glove_embed_path ./data/glove.42B.300d.txt \
--cuda \
--epoch 50 \
--loss_epoch_threshold 50 \
--sketch_loss_coefficient 1.0 \
--beam_size 1 \
--seed 90 \
--save ${save_name} \
--embed_size 300 \
--sentence_features \
--column_pointer \
--hidden_size 300 \
--lr_scheduler \
--lr_scheduler_gammar 0.5 \
--att_vec_size 300 > ${save_name}".log" &