This commit is contained in:
Olive 2019-11-01 10:26:20 +08:00 committed by GitHub
Parent 4e4521e795
Commit de4c3d02ca
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
28 changed files: 4990 additions and 8 deletions


@ -1,14 +1,86 @@
# IRNet
Code for our ACL'19 accepted paper: [Towards Complex Text-to-SQL in Cross-Domain Database with Intermediate Representation](https://arxiv.org/pdf/1905.08205.pdf)
<p align='center'>
<img src='https://zhanzecheng.github.io/architecture.png' width="91%"/>
</p>
## Environment Setup
* `Python 3.6`
* `PyTorch 0.4.0` or higher
Once Python and PyTorch are set up, install the Python dependencies via `pip install -r requirements.txt`.
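For example, a minimal setup in a fresh virtual environment might look like the following sketch (the environment name and the exact PyTorch build are illustrative assumptions, not requirements of this repo):
```
python3.6 -m venv irnet-env            # any Python 3.6 environment works
source irnet-env/bin/activate
pip install torch==0.4.1               # any PyTorch >= 0.4.0 matching your CUDA setup
pip install -r requirements.txt
```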
## Running Code
#### Data preparation
* Download the [Glove Embedding](https://nlp.stanford.edu/data/wordvecs/glove.42B.300d.zip) and put the extracted `glove.42B.300d.txt` under the `./data/` directory
* Download the [Pretrained IRNet](https://drive.google.com/open?id=1VoV28fneYss8HaZmoThGlvYU3A-aK31q) and put `IRNet_pretrained.model` under the `./saved_model/` directory
* Download the preprocessed train/dev datasets from [here](https://drive.google.com/open?id=1YFV1GoLivOMlmunKW0nkzefKULO4wtrn) and put `train.json`, `dev.json` and `tables.json` under the `./data/` directory
##### Generating the train/dev data yourself
You can also preprocess the original [Spider data](https://drive.google.com/uc?export=download&id=11icoH_EA-NYb0OrPTdehRWm_d7-DIzWX) yourself: download it, put `train.json`, `dev.json` and `tables.json` under the `./data/` directory, and follow the instructions in `./preprocess/`. The expected layout after data preparation is sketched below.
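After data preparation, the working tree should roughly contain the following files (a sketch; paths match the download steps above and the defaults used by `eval.sh`):
```
ls ./data
# dev.json  glove.42B.300d.txt  tables.json  train.json
ls ./saved_model
# IRNet_pretrained.model
```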
#### Training
Run `train.sh` to train IRNet.
`sh train.sh [GPU_ID] [SAVE_FOLD]`
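For example, to train on GPU 0 and save checkpoints under a folder of your choice (both values are placeholders):
```
sh train.sh 0 ./saved_model/run1
```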
#### Testing
Run `eval.sh` to evaluate IRNet.
`sh eval.sh [GPU_ID] [OUTPUT_FOLD]`
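For example, to run inference on GPU 0 and write the generated SQL to a location of your choice (both values are placeholders; inside `eval.sh` the second argument is passed to `sem2SQL.py --output_path`):
```
sh eval.sh 0 ./output_result
```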
#### Evaluation
You can follow the standard evaluation procedure described on the [Spider page](https://github.com/taoyds/spider), sketched below.
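As a sketch of that process (flag names follow Spider's `evaluation.py`; the gold and prediction file paths below are placeholders), exact-set-match accuracy can be computed with:
```
git clone https://github.com/taoyds/spider
python spider/evaluation.py --gold dev_gold.sql --pred output_result \
    --etype match --db database/ --table tables.json
```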
## Results
| **Model** | Dev <br /> Exact Set Match <br />Accuracy | Test<br /> Exact Set Match <br />Accuracy |
| ----------- | ------------------------------------- | -------------------------------------- |
| IRNet | 53.2 | 46.7 |
| IRNet+BERT(base) | 61.9 | **54.7** |
## Citation
If you use IRNet, please cite the following work.
```
@article{GuoIRNet2019,
author={Jiaqi Guo and Zecheng Zhan and Yan Gao and Yan Xiao and Jian-Guang Lou and Ting Liu and Dongmei Zhang},
title={Towards Complex Text-to-SQL in Cross-Domain Database with Intermediate Representation},
journal={arXiv preprint arXiv:1905.08205},
year={2019},
note={version 1}
}
```
## Thanks
We would like to thank [Tao Yu](https://taoyds.github.io/) and [Bo Pang](https://www.linkedin.com/in/bo-pang/) for running evaluations on our submitted models.
We are also grateful to the flexible semantic parser [TranX](https://github.com/pcyin/tranX) that inspired our work.
# Contributing
This project welcomes contributions and suggestions. Most contributions require you to agree to a
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
When you submit a pull request, a CLA bot will automatically determine whether you need to provide
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
provided by the bot. You will only need to do this once across all repos using our CLA.
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.

10
__init__.py Normal file

@ -0,0 +1,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# -*- coding: utf-8 -*-
"""
# @Time : 2019/5/25
# @Author : Jiaqi&Zecheng
# @File : __init__.py
# @Software: PyCharm
"""

58
eval.py Normal file

@ -0,0 +1,58 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# -*- coding: utf-8 -*-
"""
# @Time : 2019/5/27
# @Author : Jiaqi&Zecheng
# @File : eval.py
# @Software: PyCharm
"""
import torch
from src import args as arg
from src import utils
from src.models.model import IRNet
from src.rule import semQL
def evaluate(args):
"""
:param args:
:return:
"""
grammar = semQL.Grammar()
sql_data, table_data, val_sql_data,\
val_table_data= utils.load_dataset(args.dataset, use_small=args.toy)
model = IRNet(args, grammar)
if args.cuda: model.cuda()
print('load pretrained model from %s'% (args.load_model))
pretrained_model = torch.load(args.load_model,
map_location=lambda storage, loc: storage)
import copy
pretrained_modeled = copy.deepcopy(pretrained_model)
for k in pretrained_model.keys():
if k not in model.state_dict().keys():
del pretrained_modeled[k]
model.load_state_dict(pretrained_modeled)
model.word_emb = utils.load_word_emb(args.glove_embed_path)
json_datas = utils.epoch_acc(model, args.batch_size, val_sql_data, val_table_data,
beam_size=args.beam_size)
# utils.eval_acc(json_datas, val_sql_data)
import json
with open('./predict_lf.json', 'w') as f:
json.dump(json_datas, f)
if __name__ == '__main__':
arg_parser = arg.init_arg_parser()
args = arg.init_config(arg_parser)
print(args)
evaluate(args)

28
eval.sh Normal file

@ -0,0 +1,28 @@
#!/bin/bash
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
devices=$1
save_name=$2
CUDA_VISIBLE_DEVICES=$devices python -u eval.py --dataset ./data \
--glove_embed_path ./data/glove.42B.300d.txt \
--cuda \
--epoch 50 \
--loss_epoch_threshold 50 \
--sketch_loss_coefficient 1.0 \
--beam_size 5 \
--seed 90 \
--save ${save_name} \
--embed_size 300 \
--sentence_features \
--column_pointer \
--hidden_size 300 \
--lr_scheduler \
--lr_scheduler_gammar 0.5 \
--att_vec_size 300 \
--load_model ./saved_model/IRNet_pretrained.model
python sem2SQL.py --data_path ./data --input_path predict_lf.json --output_path ${save_name}

219
preprocess/data_process.py Normal file

@ -0,0 +1,219 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# -*- coding: utf-8 -*-
"""
# @Time : 2019/5/24
# @Author : Jiaqi&Zecheng
# @File : data_process.py
# @Software: PyCharm
"""
import json
import argparse
import nltk
import os
import pickle
from utils import symbol_filter, re_lemma, fully_part_header, group_header, partial_header, num2year, group_symbol, group_values, group_digital
from utils import AGG, wordnet_lemmatizer
from utils import load_dataSets
def process_datas(datas, args):
"""
:param datas:
:param args:
:return:
"""
with open(os.path.join(args.conceptNet, 'english_RelatedTo.pkl'), 'rb') as f:
english_RelatedTo = pickle.load(f)
with open(os.path.join(args.conceptNet, 'english_IsA.pkl'), 'rb') as f:
english_IsA = pickle.load(f)
# keep a copy of the original question_toks
for d in datas:
if 'origin_question_toks' not in d:
d['origin_question_toks'] = d['question_toks']
for entry in datas:
entry['question_toks'] = symbol_filter(entry['question_toks'])
origin_question_toks = symbol_filter([x for x in entry['origin_question_toks'] if x.lower() != 'the'])
question_toks = [wordnet_lemmatizer.lemmatize(x.lower()) for x in entry['question_toks'] if x.lower() != 'the']
entry['question_toks'] = question_toks
table_names = []
table_names_pattern = []
for y in entry['table_names']:
x = [wordnet_lemmatizer.lemmatize(x.lower()) for x in y.split(' ')]
table_names.append(" ".join(x))
x = [re_lemma(x.lower()) for x in y.split(' ')]
table_names_pattern.append(" ".join(x))
header_toks = []
header_toks_list = []
header_toks_pattern = []
header_toks_list_pattern = []
for y in entry['col_set']:
x = [wordnet_lemmatizer.lemmatize(x.lower()) for x in y.split(' ')]
header_toks.append(" ".join(x))
header_toks_list.append(x)
x = [re_lemma(x.lower()) for x in y.split(' ')]
header_toks_pattern.append(" ".join(x))
header_toks_list_pattern.append(x)
num_toks = len(question_toks)
idx = 0
tok_concol = []
type_concol = []
nltk_result = nltk.pos_tag(question_toks)
while idx < num_toks:
# fully header
end_idx, header = fully_part_header(question_toks, idx, num_toks, header_toks)
if header:
tok_concol.append(question_toks[idx: end_idx])
type_concol.append(["col"])
idx = end_idx
continue
# check for table
end_idx, tname = group_header(question_toks, idx, num_toks, table_names)
if tname:
tok_concol.append(question_toks[idx: end_idx])
type_concol.append(["table"])
idx = end_idx
continue
# check for column
end_idx, header = group_header(question_toks, idx, num_toks, header_toks)
if header:
tok_concol.append(question_toks[idx: end_idx])
type_concol.append(["col"])
idx = end_idx
continue
# check for partial column
end_idx, tname = partial_header(question_toks, idx, header_toks_list)
if tname:
tok_concol.append(tname)
type_concol.append(["col"])
idx = end_idx
continue
# check for aggregation
end_idx, agg = group_header(question_toks, idx, num_toks, AGG)
if agg:
tok_concol.append(question_toks[idx: end_idx])
type_concol.append(["agg"])
idx = end_idx
continue
if nltk_result[idx][1] == 'RBR' or nltk_result[idx][1] == 'JJR':
tok_concol.append([question_toks[idx]])
type_concol.append(['MORE'])
idx += 1
continue
if nltk_result[idx][1] == 'RBS' or nltk_result[idx][1] == 'JJS':
tok_concol.append([question_toks[idx]])
type_concol.append(['MOST'])
idx += 1
continue
# string match for Time Format
if num2year(question_toks[idx]):
question_toks[idx] = 'year'
end_idx, header = group_header(question_toks, idx, num_toks, header_toks)
if header:
tok_concol.append(question_toks[idx: end_idx])
type_concol.append(["col"])
idx = end_idx
continue
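# Look up '_'-joined token n-grams in a ConceptNet relation dict (IsA / RelatedTo) and return the first schema column found among the related terms.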
def get_concept_result(toks, graph):
for begin_id in range(0, len(toks)):
for r_ind in reversed(range(1, len(toks) + 1 - begin_id)):
tmp_query = "_".join(toks[begin_id:r_ind])
if tmp_query in graph:
mi = graph[tmp_query]
for col in entry['col_set']:
if col in mi:
return col
end_idx, symbol = group_symbol(question_toks, idx, num_toks)
if symbol:
tmp_toks = [x for x in question_toks[idx: end_idx]]
assert len(tmp_toks) > 0, print(symbol, question_toks)
pro_result = get_concept_result(tmp_toks, english_IsA)
if pro_result is None:
pro_result = get_concept_result(tmp_toks, english_RelatedTo)
if pro_result is None:
pro_result = "NONE"
for tmp in tmp_toks:
tok_concol.append([tmp])
type_concol.append([pro_result])
pro_result = "NONE"
idx = end_idx
continue
end_idx, values = group_values(origin_question_toks, idx, num_toks)
if values and (len(values) > 1 or question_toks[idx - 1] not in ['?', '.']):
tmp_toks = [wordnet_lemmatizer.lemmatize(x) for x in question_toks[idx: end_idx] if x.isalnum() is True]
assert len(tmp_toks) > 0, print(question_toks[idx: end_idx], values, question_toks, idx, end_idx)
pro_result = get_concept_result(tmp_toks, english_IsA)
if pro_result is None:
pro_result = get_concept_result(tmp_toks, english_RelatedTo)
if pro_result is None:
pro_result = "NONE"
for tmp in tmp_toks:
tok_concol.append([tmp])
type_concol.append([pro_result])
pro_result = "NONE"
idx = end_idx
continue
result = group_digital(question_toks, idx)
if result is True:
tok_concol.append(question_toks[idx: idx + 1])
type_concol.append(["value"])
idx += 1
continue
if question_toks[idx] == 'ha':
question_toks[idx] = 'have'
tok_concol.append([question_toks[idx]])
type_concol.append(['NONE'])
idx += 1
continue
entry['question_arg'] = tok_concol
entry['question_arg_type'] = type_concol
entry['nltk_pos'] = nltk_result
return datas
if __name__ == '__main__':
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument('--data_path', type=str, help='dataset', required=True)
arg_parser.add_argument('--table_path', type=str, help='table dataset', required=True)
arg_parser.add_argument('--output', type=str, help='output data')
args = arg_parser.parse_args()
args.conceptNet = './conceptNet'
# loading dataSets
datas, table = load_dataSets(args)
# process datasets
process_result = process_datas(datas, args)
with open(args.output, 'w') as f:
json.dump(datas, f)

15
preprocess/download_nltk.py Normal file

@ -0,0 +1,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# -*- coding: utf-8 -*-
"""
# @Time : 2019/1/29
# @Author : Jiaqi&Zecheng
# @File : download_nltk.py
# @Software: PyCharm
"""
import nltk
nltk.download('averaged_perceptron_tagger')
nltk.download('punkt')
nltk.download('wordnet')

19
preprocess/run_me.sh Normal file

@ -0,0 +1,19 @@
#!/bin/bash
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
data=$1
table_data=$2
output=$3
echo "Start download NLTK data"
python download_nltk.py
echo "Start process the origin Spider dataset"
python data_process.py --data_path ${data} --table_path ${table_data} --output "process_data.json"
echo "Start generate SemQL from SQL"
python sql2SemQL.py --data_path process_data.json --table_path ${table_data} --output ${data}
rm process_data.json

391
preprocess/sql2SemQL.py Normal file

@ -0,0 +1,391 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# -*- coding: utf-8 -*-
"""
# @Time : 2019/5/24
# @Author : Jiaqi&Zecheng
# @File : sql2SemQL.py
# @Software: PyCharm
"""
import argparse
import json
import sys
import copy
from utils import load_dataSets
sys.path.append("..")
from src.rule.semQL import Root1, Root, N, A, C, T, Sel, Sup, Filter, Order
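# Parser converts a Spider SQL parse (the 'sql' dict) into a sequence of SemQL grammar actions (Root1/Root/Sel/Sup/Filter/Order/N/A/C/T).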
class Parser:
def __init__(self):
self.copy_selec = None
self.sel_result = []
self.colSet = set()
def _init_rule(self):
self.copy_selec = None
self.colSet = set()
def _parse_root(self, sql):
"""
parsing the sql by the grammar
R ::= Select | Select Filter | Select Order | ... |
:return: [R(), states]
"""
use_sup, use_ord, use_fil = True, True, False
if sql['sql']['limit'] == None:
use_sup = False
if sql['sql']['orderBy'] == []:
use_ord = False
elif sql['sql']['limit'] != None:
use_ord = False
# check the where and having
if sql['sql']['where'] != [] or \
sql['sql']['having'] != []:
use_fil = True
if use_fil and use_sup:
return [Root(0)], ['FILTER', 'SUP', 'SEL']
elif use_fil and use_ord:
return [Root(1)], ['ORDER', 'FILTER', 'SEL']
elif use_sup:
return [Root(2)], ['SUP', 'SEL']
elif use_fil:
return [Root(3)], ['FILTER', 'SEL']
elif use_ord:
return [Root(4)], ['ORDER', 'SEL']
else:
return [Root(5)], ['SEL']
def _parser_column0(self, sql, select):
"""
Find table of column '*'
:return: T(table_id)
"""
if len(sql['sql']['from']['table_units']) == 1:
return T(sql['sql']['from']['table_units'][0][1])
else:
table_list = []
for tmp_t in sql['sql']['from']['table_units']:
if type(tmp_t[1]) == int:
table_list.append(tmp_t[1])
table_set, other_set = set(table_list), set()
for sel_p in select:
if sel_p[1][1][1] != 0:
other_set.add(sql['col_table'][sel_p[1][1][1]])
if len(sql['sql']['where']) == 1:
other_set.add(sql['col_table'][sql['sql']['where'][0][2][1][1]])
elif len(sql['sql']['where']) == 3:
other_set.add(sql['col_table'][sql['sql']['where'][0][2][1][1]])
other_set.add(sql['col_table'][sql['sql']['where'][2][2][1][1]])
elif len(sql['sql']['where']) == 5:
other_set.add(sql['col_table'][sql['sql']['where'][0][2][1][1]])
other_set.add(sql['col_table'][sql['sql']['where'][2][2][1][1]])
other_set.add(sql['col_table'][sql['sql']['where'][4][2][1][1]])
table_set = table_set - other_set
if len(table_set) == 1:
return T(list(table_set)[0])
elif len(table_set) == 0 and sql['sql']['groupBy'] != []:
return T(sql['col_table'][sql['sql']['groupBy'][0][1]])
else:
question = sql['question']
self.sel_result.append(question)
print('column * table error')
return T(sql['sql']['from']['table_units'][0][1])
def _parse_select(self, sql):
"""
parsing the sql by the grammar
Select ::= A | AA | AAA | ... |
A ::= agg column table
:return: [Sel(), states]
"""
result = []
select = sql['sql']['select'][1]
result.append(Sel(0))
result.append(N(len(select) - 1))
for sel in select:
result.append(A(sel[0]))
self.colSet.add(sql['col_set'].index(sql['names'][sel[1][1][1]]))
result.append(C(sql['col_set'].index(sql['names'][sel[1][1][1]])))
# now check for the situation with *
if sel[1][1][1] == 0:
result.append(self._parser_column0(sql, select))
else:
result.append(T(sql['col_table'][sel[1][1][1]]))
if not self.copy_selec:
self.copy_selec = [copy.deepcopy(result[-2]), copy.deepcopy(result[-1])]
return result, None
def _parse_sup(self, sql):
"""
parsing the sql by the grammar
Sup ::= Most A | Least A
A ::= agg column table
:return: [Sup(), states]
"""
result = []
select = sql['sql']['select'][1]
if sql['sql']['limit'] == None:
return result, None
if sql['sql']['orderBy'][0] == 'desc':
result.append(Sup(0))
else:
result.append(Sup(1))
result.append(A(sql['sql']['orderBy'][1][0][1][0]))
self.colSet.add(sql['col_set'].index(sql['names'][sql['sql']['orderBy'][1][0][1][1]]))
result.append(C(sql['col_set'].index(sql['names'][sql['sql']['orderBy'][1][0][1][1]])))
if sql['sql']['orderBy'][1][0][1][1] == 0:
result.append(self._parser_column0(sql, select))
else:
result.append(T(sql['col_table'][sql['sql']['orderBy'][1][0][1][1]]))
return result, None
def _parse_filter(self, sql):
"""
parsing the sql by the grammar
Filter ::= and Filter Filter | ... |
A ::= agg column table
:return: [Filter(), states]
"""
result = []
# check the where
if sql['sql']['where'] != [] and sql['sql']['having'] != []:
result.append(Filter(0))
if sql['sql']['where'] != []:
# check the not and/or
if len(sql['sql']['where']) == 1:
result.extend(self.parse_one_condition(sql['sql']['where'][0], sql['names'], sql))
elif len(sql['sql']['where']) == 3:
if sql['sql']['where'][1] == 'or':
result.append(Filter(1))
else:
result.append(Filter(0))
result.extend(self.parse_one_condition(sql['sql']['where'][0], sql['names'], sql))
result.extend(self.parse_one_condition(sql['sql']['where'][2], sql['names'], sql))
else:
if sql['sql']['where'][1] == 'and' and sql['sql']['where'][3] == 'and':
result.append(Filter(0))
result.extend(self.parse_one_condition(sql['sql']['where'][0], sql['names'], sql))
result.append(Filter(0))
result.extend(self.parse_one_condition(sql['sql']['where'][2], sql['names'], sql))
result.extend(self.parse_one_condition(sql['sql']['where'][4], sql['names'], sql))
elif sql['sql']['where'][1] == 'and' and sql['sql']['where'][3] == 'or':
result.append(Filter(1))
result.append(Filter(0))
result.extend(self.parse_one_condition(sql['sql']['where'][0], sql['names'], sql))
result.extend(self.parse_one_condition(sql['sql']['where'][2], sql['names'], sql))
result.extend(self.parse_one_condition(sql['sql']['where'][4], sql['names'], sql))
elif sql['sql']['where'][1] == 'or' and sql['sql']['where'][3] == 'and':
result.append(Filter(1))
result.append(Filter(0))
result.extend(self.parse_one_condition(sql['sql']['where'][2], sql['names'], sql))
result.extend(self.parse_one_condition(sql['sql']['where'][4], sql['names'], sql))
result.extend(self.parse_one_condition(sql['sql']['where'][0], sql['names'], sql))
else:
result.append(Filter(1))
result.append(Filter(1))
result.extend(self.parse_one_condition(sql['sql']['where'][0], sql['names'], sql))
result.extend(self.parse_one_condition(sql['sql']['where'][2], sql['names'], sql))
result.extend(self.parse_one_condition(sql['sql']['where'][4], sql['names'], sql))
# check having
if sql['sql']['having'] != []:
result.extend(self.parse_one_condition(sql['sql']['having'][0], sql['names'], sql))
return result, None
def _parse_order(self, sql):
"""
parsing the sql by the grammar
Order ::= asc A | desc A
A ::= agg column table
:return: [Order(), states]
"""
result = []
if 'order' not in sql['query_toks_no_value'] or 'by' not in sql['query_toks_no_value']:
return result, None
elif 'limit' in sql['query_toks_no_value']:
return result, None
else:
if sql['sql']['orderBy'] == []:
return result, None
else:
select = sql['sql']['select'][1]
if sql['sql']['orderBy'][0] == 'desc':
result.append(Order(0))
else:
result.append(Order(1))
result.append(A(sql['sql']['orderBy'][1][0][1][0]))
self.colSet.add(sql['col_set'].index(sql['names'][sql['sql']['orderBy'][1][0][1][1]]))
result.append(C(sql['col_set'].index(sql['names'][sql['sql']['orderBy'][1][0][1][1]])))
if sql['sql']['orderBy'][1][0][1][1] == 0:
result.append(self._parser_column0(sql, select))
else:
result.append(T(sql['col_table'][sql['sql']['orderBy'][1][0][1][1]]))
return result, None
def parse_one_condition(self, sql_condit, names, sql):
result = []
# check if V(root)
nest_query = True
if type(sql_condit[3]) != dict:
nest_query = False
if sql_condit[0] == True:
if sql_condit[1] == 9:
# not like only with values
fil = Filter(10)
elif sql_condit[1] == 8:
# not in with Root
fil = Filter(19)
else:
print(sql_condit[1])
raise NotImplementedError("not implement for the others FIL")
else:
# check for Filter (<,=,>,!=,between, >=, <=, ...)
single_map = {1:8,2:2,3:5,4:4,5:7,6:6,7:3}
nested_map = {1:15,2:11,3:13,4:12,5:16,6:17,7:14}
if sql_condit[1] in [1, 2, 3, 4, 5, 6, 7]:
if nest_query == False:
fil = Filter(single_map[sql_condit[1]])
else:
fil = Filter(nested_map[sql_condit[1]])
elif sql_condit[1] == 9:
fil = Filter(9)
elif sql_condit[1] == 8:
fil = Filter(18)
else:
print(sql_condit[1])
raise NotImplementedError("not implement for the others FIL")
result.append(fil)
result.append(A(sql_condit[2][1][0]))
self.colSet.add(sql['col_set'].index(sql['names'][sql_condit[2][1][1]]))
result.append(C(sql['col_set'].index(sql['names'][sql_condit[2][1][1]])))
if sql_condit[2][1][1] == 0:
select = sql['sql']['select'][1]
result.append(self._parser_column0(sql, select))
else:
result.append(T(sql['col_table'][sql_condit[2][1][1]]))
# check for the nested value
if type(sql_condit[3]) == dict:
nest_query = {}
nest_query['names'] = names
nest_query['query_toks_no_value'] = ""
nest_query['sql'] = sql_condit[3]
nest_query['col_table'] = sql['col_table']
nest_query['col_set'] = sql['col_set']
nest_query['table_names'] = sql['table_names']
nest_query['question'] = sql['question']
nest_query['query'] = sql['query']
nest_query['keys'] = sql['keys']
result.extend(self.parser(nest_query))
return result
def _parse_step(self, state, sql):
if state == 'ROOT':
return self._parse_root(sql)
if state == 'SEL':
return self._parse_select(sql)
elif state == 'SUP':
return self._parse_sup(sql)
elif state == 'FILTER':
return self._parse_filter(sql)
elif state == 'ORDER':
return self._parse_order(sql)
else:
raise NotImplementedError("Not the right state")
def full_parse(self, query):
sql = query['sql']
nest_query = {}
nest_query['names'] = query['names']
nest_query['query_toks_no_value'] = ""
nest_query['col_table'] = query['col_table']
nest_query['col_set'] = query['col_set']
nest_query['table_names'] = query['table_names']
nest_query['question'] = query['question']
nest_query['query'] = query['query']
nest_query['keys'] = query['keys']
if sql['intersect']:
results = [Root1(0)]
nest_query['sql'] = sql['intersect']
results.extend(self.parser(query))
results.extend(self.parser(nest_query))
return results
if sql['union']:
results = [Root1(1)]
nest_query['sql'] = sql['union']
results.extend(self.parser(query))
results.extend(self.parser(nest_query))
return results
if sql['except']:
results = [Root1(2)]
nest_query['sql'] = sql['except']
results.extend(self.parser(query))
results.extend(self.parser(nest_query))
return results
results = [Root1(3)]
results.extend(self.parser(query))
return results
def parser(self, query):
stack = ["ROOT"]
result = []
while len(stack) > 0:
state = stack.pop()
step_result, step_state = self._parse_step(state, query)
result.extend(step_result)
if step_state:
stack.extend(step_state)
return result
if __name__ == '__main__':
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument('--data_path', type=str, help='dataset', required=True)
arg_parser.add_argument('--table_path', type=str, help='table dataset', required=True)
arg_parser.add_argument('--output', type=str, help='output data', required=True)
args = arg_parser.parse_args()
parser = Parser()
# loading dataSets
datas, table = load_dataSets(args)
processed_data = []
for i, d in enumerate(datas):
if len(datas[i]['sql']['select'][1]) > 5:
continue
r = parser.full_parse(datas[i])
datas[i]['rule_label'] = " ".join([str(x) for x in r])
processed_data.append(datas[i])
print('Finished %s examples and failed %s examples' % (len(processed_data), len(datas) - len(processed_data)))
with open(args.output, 'w', encoding='utf8') as f:
f.write(json.dumps(processed_data))

173
preprocess/utils.py Normal file

@ -0,0 +1,173 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# -*- coding: utf-8 -*-
"""
# @Time : 2019/5/24
# @Author : Jiaqi&Zecheng
# @File : utils.py
# @Software: PyCharm
"""
import os
import json
from pattern.en import lemma
from nltk.stem import WordNetLemmatizer
VALUE_FILTER = ['what', 'how', 'list', 'give', 'show', 'find', 'id', 'order', 'when']
AGG = ['average', 'sum', 'max', 'min', 'minimum', 'maximum', 'between']
wordnet_lemmatizer = WordNetLemmatizer()
def load_dataSets(args):
with open(args.table_path, 'r', encoding='utf8') as f:
table_datas = json.load(f)
with open(args.data_path, 'r', encoding='utf8') as f:
datas = json.load(f)
output_tab = {}
tables = {}
tabel_name = set()
for i in range(len(table_datas)):
table = table_datas[i]
temp = {}
temp['col_map'] = table['column_names']
temp['table_names'] = table['table_names']
tmp_col = []
for cc in [x[1] for x in table['column_names']]:
if cc not in tmp_col:
tmp_col.append(cc)
table['col_set'] = tmp_col
db_name = table['db_id']
tabel_name.add(db_name)
table['schema_content'] = [col[1] for col in table['column_names']]
table['col_table'] = [col[0] for col in table['column_names']]
output_tab[db_name] = temp
tables[db_name] = table
for d in datas:
d['names'] = tables[d['db_id']]['schema_content']
d['table_names'] = tables[d['db_id']]['table_names']
d['col_set'] = tables[d['db_id']]['col_set']
d['col_table'] = tables[d['db_id']]['col_table']
keys = {}
for kv in tables[d['db_id']]['foreign_keys']:
keys[kv[0]] = kv[1]
keys[kv[1]] = kv[0]
for id_k in tables[d['db_id']]['primary_keys']:
keys[id_k] = id_k
d['keys'] = keys
return datas, tables
def group_header(toks, idx, num_toks, header_toks):
for endIdx in reversed(range(idx + 1, num_toks+1)):
sub_toks = toks[idx: endIdx]
sub_toks = " ".join(sub_toks)
if sub_toks in header_toks:
return endIdx, sub_toks
return idx, None
def fully_part_header(toks, idx, num_toks, header_toks):
for endIdx in reversed(range(idx + 1, num_toks+1)):
sub_toks = toks[idx: endIdx]
if len(sub_toks) > 1:
sub_toks = " ".join(sub_toks)
if sub_toks in header_toks:
return endIdx, sub_toks
return idx, None
def partial_header(toks, idx, header_toks):
def check_in(list_one, list_two):
if len(set(list_one) & set(list_two)) == len(list_one) and (len(list_two) <= 3):
return True
for endIdx in reversed(range(idx + 1, len(toks))):
sub_toks = toks[idx: min(endIdx, len(toks))]
if len(sub_toks) > 1:
flag_count = 0
tmp_heads = None
for heads in header_toks:
if check_in(sub_toks, heads):
flag_count += 1
tmp_heads = heads
if flag_count == 1:
return endIdx, tmp_heads
return idx, None
def symbol_filter(questions):
question_tmp_q = []
for q_id, q_val in enumerate(questions):
if len(q_val) > 2 and q_val[0] in ["'", '"', '`', '‘', '“'] and q_val[-1] in ["'", '"', '`', '’']:
question_tmp_q.append("'")
question_tmp_q += ["".join(q_val[1:-1])]
question_tmp_q.append("'")
elif len(q_val) > 2 and q_val[0] in ["'", '"', '`', '‘']:
question_tmp_q.append("'")
question_tmp_q += ["".join(q_val[1:])]
elif len(q_val) > 2 and q_val[-1] in ["'", '"', '`', '’']:
question_tmp_q += ["".join(q_val[0:-1])]
question_tmp_q.append("'")
elif q_val in ["'", '"', '`', '‘', '’', '``', "''"]:
question_tmp_q += ["'"]
else:
question_tmp_q += [q_val]
return question_tmp_q
def group_values(toks, idx, num_toks):
def check_isupper(tok_lists):
for tok_one in tok_lists:
if tok_one[0].isupper() is False:
return False
return True
for endIdx in reversed(range(idx + 1, num_toks + 1)):
sub_toks = toks[idx: endIdx]
if len(sub_toks) > 1 and check_isupper(sub_toks) is True:
return endIdx, sub_toks
if len(sub_toks) == 1:
if sub_toks[0][0].isupper() and sub_toks[0].lower() not in VALUE_FILTER and \
sub_toks[0].lower().isalnum() is True:
return endIdx, sub_toks
return idx, None
def group_digital(toks, idx):
test = toks[idx].replace(':', '')
test = test.replace('.', '')
if test.isdigit():
return True
else:
return False
def group_symbol(toks, idx, num_toks):
if toks[idx-1] == "'":
for i in range(0, min(3, num_toks-idx)):
if toks[i + idx] == "'":
return i + idx, toks[idx:i+idx]
return idx, None
def num2year(tok):
if len(str(tok)) == 4 and str(tok).isdigit() and int(str(tok)[:2]) < 22 and int(str(tok)[:2]) > 15:
return True
return False
def set_header(toks, header_toks, tok_concol, idx, num_toks):
def check_in(list_one, list_two):
if set(list_one) == set(list_two):
return True
for endIdx in range(idx, num_toks):
toks += tok_concol[endIdx]
if len(tok_concol[endIdx]) > 1:
break
for heads in header_toks:
if check_in(toks, heads):
return heads
return None
def re_lemma(string):
lema = lemma(string.lower())
if len(lema) > 0:
return lema
else:
return string.lower()

8
requirements.txt Normal file

@ -0,0 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
nltk==3.4
pattern
numpy==1.14.0
pytorch-pretrained-bert==0.5.1
tqdm==4.31.1

697
sem2SQL.py Normal file

@ -0,0 +1,697 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# -*- coding: utf-8 -*-
"""
# @Time : 2019/5/27
# @Author : Jiaqi&Zecheng
# @File : sem2SQL.py
# @Software: PyCharm
"""
import argparse
import traceback
from src.rule.graph import Graph
from src.rule.semQL import Sup, Sel, Order, Root, Filter, A, N, C, T, Root1
from src.rule.sem_utils import alter_inter, alter_not_in, alter_column0, load_dataSets
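# split_logical_form breaks a flat SemQL logical-form string into its individual actions; every action literal ends with ')'.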
def split_logical_form(lf):
indexs = [i+1 for i, letter in enumerate(lf) if letter == ')']
indexs.insert(0, 0)
components = list()
for i in range(1, len(indexs)):
components.append(lf[indexs[i-1]:indexs[i]].strip())
return components
def pop_front(array):
if len(array) == 0:
return 'None'
return array.pop(0)
def is_end(components, transformed_sql, is_root_processed):
end = False
c = pop_front(components)
c_instance = eval(c)
if isinstance(c_instance, Root) and is_root_processed:
# intersect, union, except
end = True
elif isinstance(c_instance, Filter):
if 'where' not in transformed_sql:
end = True
else:
num_conjunction = 0
for f in transformed_sql['where']:
if isinstance(f, str) and (f == 'and' or f == 'or'):
num_conjunction += 1
current_filters = len(transformed_sql['where'])
valid_filters = current_filters - num_conjunction
if valid_filters >= num_conjunction + 1:
end = True
elif isinstance(c_instance, Order):
if 'order' not in transformed_sql:
end = True
elif len(transformed_sql['order']) == 0:
end = False
else:
end = True
elif isinstance(c_instance, Sup):
if 'sup' not in transformed_sql:
end = True
elif len(transformed_sql['sup']) == 0:
end = False
else:
end = True
components.insert(0, c)
return end
def _transform(components, transformed_sql, col_set, table_names, schema):
processed_root = False
current_table = schema
while len(components) > 0:
if is_end(components, transformed_sql, processed_root):
break
c = pop_front(components)
c_instance = eval(c)
if isinstance(c_instance, Root):
processed_root = True
transformed_sql['select'] = list()
if c_instance.id_c == 0:
transformed_sql['where'] = list()
transformed_sql['sup'] = list()
elif c_instance.id_c == 1:
transformed_sql['where'] = list()
transformed_sql['order'] = list()
elif c_instance.id_c == 2:
transformed_sql['sup'] = list()
elif c_instance.id_c == 3:
transformed_sql['where'] = list()
elif c_instance.id_c == 4:
transformed_sql['order'] = list()
elif isinstance(c_instance, Sel):
continue
elif isinstance(c_instance, N):
for i in range(c_instance.id_c + 1):
agg = eval(pop_front(components))
column = eval(pop_front(components))
_table = pop_front(components)
table = eval(_table)
if not isinstance(table, T):
table = None
components.insert(0, _table)
assert isinstance(agg, A) and isinstance(column, C)
transformed_sql['select'].append((
agg.production.split()[1],
replace_col_with_original_col(col_set[column.id_c], table_names[table.id_c], current_table) if table is not None else col_set[column.id_c],
table_names[table.id_c] if table is not None else table
))
elif isinstance(c_instance, Sup):
transformed_sql['sup'].append(c_instance.production.split()[1])
agg = eval(pop_front(components))
column = eval(pop_front(components))
_table = pop_front(components)
table = eval(_table)
if not isinstance(table, T):
table = None
components.insert(0, _table)
assert isinstance(agg, A) and isinstance(column, C)
transformed_sql['sup'].append(agg.production.split()[1])
if table:
fix_col_id = replace_col_with_original_col(col_set[column.id_c], table_names[table.id_c], current_table)
else:
fix_col_id = col_set[column.id_c]
raise RuntimeError('table not found!')
transformed_sql['sup'].append(fix_col_id)
transformed_sql['sup'].append(table_names[table.id_c] if table is not None else table)
elif isinstance(c_instance, Order):
transformed_sql['order'].append(c_instance.production.split()[1])
agg = eval(pop_front(components))
column = eval(pop_front(components))
_table = pop_front(components)
table = eval(_table)
if not isinstance(table, T):
table = None
components.insert(0, _table)
assert isinstance(agg, A) and isinstance(column, C)
transformed_sql['order'].append(agg.production.split()[1])
transformed_sql['order'].append(replace_col_with_original_col(col_set[column.id_c], table_names[table.id_c], current_table))
transformed_sql['order'].append(table_names[table.id_c] if table is not None else table)
elif isinstance(c_instance, Filter):
op = c_instance.production.split()[1]
if op == 'and' or op == 'or':
transformed_sql['where'].append(op)
else:
# No subquery
agg = eval(pop_front(components))
column = eval(pop_front(components))
_table = pop_front(components)
table = eval(_table)
if not isinstance(table, T):
table = None
components.insert(0, _table)
assert isinstance(agg, A) and isinstance(column, C)
if len(c_instance.production.split()) == 3:
if table:
fix_col_id = replace_col_with_original_col(col_set[column.id_c], table_names[table.id_c], current_table)
else:
fix_col_id = col_set[column.id_c]
raise RuntimeError('table not found!')
transformed_sql['where'].append((
op,
agg.production.split()[1],
fix_col_id,
table_names[table.id_c] if table is not None else table,
None
))
else:
# Subquery
new_dict = dict()
new_dict['sql'] = transformed_sql['sql']
transformed_sql['where'].append((
op,
agg.production.split()[1],
replace_col_with_original_col(col_set[column.id_c], table_names[table.id_c], current_table),
table_names[table.id_c] if table is not None else table,
_transform(components, new_dict, col_set, table_names, schema)
))
return transformed_sql
def transform(query, schema, origin=None):
preprocess_schema(schema)
if origin is None:
lf = query['model_result_replace']
else:
lf = origin
# lf = query['rule_label']
col_set = query['col_set']
table_names = query['table_names']
current_table = schema
current_table['schema_content_clean'] = [x[1] for x in current_table['column_names']]
current_table['schema_content'] = [x[1] for x in current_table['column_names_original']]
components = split_logical_form(lf)
transformed_sql = dict()
transformed_sql['sql'] = query
c = pop_front(components)
c_instance = eval(c)
assert isinstance(c_instance, Root1)
if c_instance.id_c == 0:
transformed_sql['intersect'] = dict()
transformed_sql['intersect']['sql'] = query
_transform(components, transformed_sql, col_set, table_names, schema)
_transform(components, transformed_sql['intersect'], col_set, table_names, schema)
elif c_instance.id_c == 1:
transformed_sql['union'] = dict()
transformed_sql['union']['sql'] = query
_transform(components, transformed_sql, col_set, table_names, schema)
_transform(components, transformed_sql['union'], col_set, table_names, schema)
elif c_instance.id_c == 2:
transformed_sql['except'] = dict()
transformed_sql['except']['sql'] = query
_transform(components, transformed_sql, col_set, table_names, schema)
_transform(components, transformed_sql['except'], col_set, table_names, schema)
else:
_transform(components, transformed_sql, col_set, table_names, schema)
parse_result = to_str(transformed_sql, 1, schema)
parse_result = parse_result.replace('\t', '')
return [parse_result]
def col_to_str(agg, col, tab, table_names, N=1):
_col = col.replace(' ', '_')
if agg == 'none':
if tab not in table_names:
table_names[tab] = 'T' + str(len(table_names) + N)
table_alias = table_names[tab]
if col == '*':
return '*'
return '%s.%s' % (table_alias, _col)
else:
if col == '*':
if tab is not None and tab not in table_names:
table_names[tab] = 'T' + str(len(table_names) + N)
return '%s(%s)' % (agg, _col)
else:
if tab not in table_names:
table_names[tab] = 'T' + str(len(table_names) + N)
table_alias = table_names[tab]
return '%s(%s.%s)' % (agg, table_alias, _col)
def infer_from_clause(table_names, schema, columns):
tables = list(table_names.keys())
# print(table_names)
start_table = None
end_table = None
join_clause = list()
if len(tables) == 1:
join_clause.append((tables[0], table_names[tables[0]]))
elif len(tables) == 2:
use_graph = True
# print(schema['graph'].vertices)
for t in tables:
if t not in schema['graph'].vertices:
use_graph = False
break
if use_graph:
start_table = tables[0]
end_table = tables[1]
_tables = list(schema['graph'].dijkstra(tables[0], tables[1]))
# print('Two tables: ', _tables)
max_key = 1
for t, k in table_names.items():
_k = int(k[1:])
if _k > max_key:
max_key = _k
for t in _tables:
if t not in table_names:
table_names[t] = 'T' + str(max_key + 1)
max_key += 1
join_clause.append((t, table_names[t],))
else:
join_clause = list()
for t in tables:
join_clause.append((t, table_names[t],))
else:
# > 2
# print('More than 2 table')
for t in tables:
join_clause.append((t, table_names[t],))
if len(join_clause) >= 3:
star_table = None
for agg, col, tab in columns:
if col == '*':
star_table = tab
break
if star_table is not None:
star_table_count = 0
for agg, col, tab in columns:
if tab == star_table and col != '*':
star_table_count += 1
if star_table_count == 0 and ((end_table is None or end_table == star_table) or (start_table is None or start_table == star_table)):
# Remove star_table; the remaining tables can still be joined without it
new_join_clause = list()
for t in join_clause:
if t[0] != star_table:
new_join_clause.append(t)
join_clause = new_join_clause
join_clause = ' JOIN '.join(['%s AS %s' % (jc[0], jc[1]) for jc in join_clause])
return 'FROM ' + join_clause
def replace_col_with_original_col(query, col, current_table):
# print(query, col)
if query == '*':
return query
cur_table = col
cur_col = query
single_final_col = None
# print(query, col)
for col_ind, col_name in enumerate(current_table['schema_content_clean']):
if col_name == cur_col:
assert cur_table in current_table['table_names']
if current_table['table_names'][current_table['col_table'][col_ind]] == cur_table:
single_final_col = current_table['column_names_original'][col_ind][1]
break
assert single_final_col
# if query != single_final_col:
# print(query, single_final_col)
return single_final_col
def build_graph(schema):
relations = list()
foreign_keys = schema['foreign_keys']
for (fkey, pkey) in foreign_keys:
fkey_table = schema['table_names_original'][schema['column_names'][fkey][0]]
pkey_table = schema['table_names_original'][schema['column_names'][pkey][0]]
relations.append((fkey_table, pkey_table))
relations.append((pkey_table, fkey_table))
return Graph(relations)
def preprocess_schema(schema):
tmp_col = []
for cc in [x[1] for x in schema['column_names']]:
if cc not in tmp_col:
tmp_col.append(cc)
schema['col_set'] = tmp_col
# print table
schema['schema_content'] = [col[1] for col in schema['column_names']]
schema['col_table'] = [col[0] for col in schema['column_names']]
graph = build_graph(schema)
schema['graph'] = graph
def to_str(sql_json, N_T, schema, pre_table_names=None):
all_columns = list()
select_clause = list()
table_names = dict()
current_table = schema
for (agg, col, tab) in sql_json['select']:
all_columns.append((agg, col, tab))
select_clause.append(col_to_str(agg, col, tab, table_names, N_T))
select_clause_str = 'SELECT ' + ', '.join(select_clause).strip()
sup_clause = ''
order_clause = ''
direction_map = {"des": 'DESC', 'asc': 'ASC'}
if 'sup' in sql_json:
(direction, agg, col, tab,) = sql_json['sup']
all_columns.append((agg, col, tab))
subject = col_to_str(agg, col, tab, table_names, N_T)
sup_clause = ('ORDER BY %s %s LIMIT 1' % (subject, direction_map[direction])).strip()
elif 'order' in sql_json:
(direction, agg, col, tab,) = sql_json['order']
all_columns.append((agg, col, tab))
subject = col_to_str(agg, col, tab, table_names, N_T)
order_clause = ('ORDER BY %s %s' % (subject, direction_map[direction])).strip()
has_group_by = False
where_clause = ''
have_clause = ''
if 'where' in sql_json:
conjunctions = list()
filters = list()
# print(sql_json['where'])
for f in sql_json['where']:
if isinstance(f, str):
conjunctions.append(f)
else:
op, agg, col, tab, value = f
if value:
value['sql'] = sql_json['sql']
all_columns.append((agg, col, tab))
subject = col_to_str(agg, col, tab, table_names, N_T)
if value is None:
where_value = '1'
if op == 'between':
where_value = '1 AND 2'
filters.append('%s %s %s' % (subject, op, where_value))
else:
if op == 'in' and len(value['select']) == 1 and value['select'][0][0] == 'none' \
and 'where' not in value and 'order' not in value and 'sup' not in value:
# and value['select'][0][2] not in table_names:
if value['select'][0][2] not in table_names:
table_names[value['select'][0][2]] = 'T' + str(len(table_names) + N_T)
filters.append(None)
else:
filters.append('%s %s %s' % (subject, op, '(' + to_str(value, len(table_names) + 1, schema) + ')'))
if len(conjunctions):
filters.append(conjunctions.pop())
aggs = ['count(', 'avg(', 'min(', 'max(', 'sum(']
having_filters = list()
idx = 0
while idx < len(filters):
_filter = filters[idx]
if _filter is None:
idx += 1
continue
for agg in aggs:
if _filter.startswith(agg):
having_filters.append(_filter)
filters.pop(idx)
# print(filters)
if 0 < idx and (filters[idx - 1] in ['and', 'or']):
filters.pop(idx - 1)
# print(filters)
break
else:
idx += 1
if len(having_filters) > 0:
have_clause = 'HAVING ' + ' '.join(having_filters).strip()
if len(filters) > 0:
# print(filters)
filters = [_f for _f in filters if _f is not None]
conjun_num = 0
filter_num = 0
for _f in filters:
if _f in ['or', 'and']:
conjun_num += 1
else:
filter_num += 1
if conjun_num > 0 and filter_num != (conjun_num + 1):
# assert 'and' in filters
idx = 0
while idx < len(filters):
if filters[idx] == 'and':
if idx - 1 == 0:
filters.pop(idx)
break
if filters[idx - 1] in ['and', 'or']:
filters.pop(idx)
break
if idx + 1 >= len(filters) - 1:
filters.pop(idx)
break
if filters[idx + 1] in ['and', 'or']:
filters.pop(idx)
break
idx += 1
if len(filters) > 0:
where_clause = 'WHERE ' + ' '.join(filters).strip()
where_clause = where_clause.replace('not_in', 'NOT IN')
else:
where_clause = ''
if len(having_filters) > 0:
has_group_by = True
for agg in ['count(', 'avg(', 'min(', 'max(', 'sum(']:
if (len(sql_json['select']) > 1 and agg in select_clause_str)\
or agg in sup_clause or agg in order_clause:
has_group_by = True
break
group_by_clause = ''
if has_group_by:
if len(table_names) == 1:
# check none agg
is_agg_flag = False
for (agg, col, tab) in sql_json['select']:
if agg == 'none':
group_by_clause = 'GROUP BY ' + col_to_str(agg, col, tab, table_names, N_T)
else:
is_agg_flag = True
if is_agg_flag is False and len(group_by_clause) > 5:
group_by_clause = "GROUP BY"
for (agg, col, tab) in sql_json['select']:
group_by_clause = group_by_clause + ' ' + col_to_str(agg, col, tab, table_names, N_T)
if len(group_by_clause) < 5:
if 'count(*)' in select_clause_str:
current_table = schema
for primary in current_table['primary_keys']:
if current_table['table_names'][current_table['col_table'][primary]] in table_names :
group_by_clause = 'GROUP BY ' + col_to_str('none', current_table['schema_content'][primary],
current_table['table_names'][
current_table['col_table'][primary]],
table_names, N_T)
else:
# if only one select
if len(sql_json['select']) == 1:
agg, col, tab = sql_json['select'][0]
non_lists = [tab]
fix_flag = False
# add tab from other part
for key, value in table_names.items():
if key not in non_lists:
non_lists.append(key)
a = non_lists[0]
b = None
for non in non_lists:
if a != non:
b = non
if b:
for pair in current_table['foreign_keys']:
t1 = current_table['table_names'][current_table['col_table'][pair[0]]]
t2 = current_table['table_names'][current_table['col_table'][pair[1]]]
if t1 in [a, b] and t2 in [a, b]:
if pre_table_names and t1 not in pre_table_names:
assert t2 in pre_table_names
t1 = t2
group_by_clause = 'GROUP BY ' + col_to_str('none',
current_table['schema_content'][pair[0]],
t1,
table_names, N_T)
fix_flag = True
break
if fix_flag is False:
agg, col, tab = sql_json['select'][0]
group_by_clause = 'GROUP BY ' + col_to_str(agg, col, tab, table_names, N_T)
else:
# check if there are only one non agg
non_agg, non_agg_count = None, 0
non_lists = []
for (agg, col, tab) in sql_json['select']:
if agg == 'none':
non_agg = (agg, col, tab)
non_lists.append(tab)
non_agg_count += 1
non_lists = list(set(non_lists))
# print(non_lists)
if non_agg_count == 1:
group_by_clause = 'GROUP BY ' + col_to_str(non_agg[0], non_agg[1], non_agg[2], table_names, N_T)
elif non_agg:
find_flag = False
fix_flag = False
find_primary = None
if len(non_lists) <= 1:
for key, value in table_names.items():
if key not in non_lists:
non_lists.append(key)
if len(non_lists) > 1:
a = non_lists[0]
b = None
for non in non_lists:
if a != non:
b = non
if b:
for pair in current_table['foreign_keys']:
t1 = current_table['table_names'][current_table['col_table'][pair[0]]]
t2 = current_table['table_names'][current_table['col_table'][pair[1]]]
if t1 in [a, b] and t2 in [a, b]:
if pre_table_names and t1 not in pre_table_names:
assert t2 in pre_table_names
t1 = t2
group_by_clause = 'GROUP BY ' + col_to_str('none',
current_table['schema_content'][pair[0]],
t1,
table_names, N_T)
fix_flag = True
break
tab = non_agg[2]
assert tab in current_table['table_names']
for primary in current_table['primary_keys']:
if current_table['table_names'][current_table['col_table'][primary]] == tab:
find_flag = True
find_primary = (current_table['schema_content'][primary], tab)
if fix_flag is False:
if find_flag is False:
# rely on count *
foreign = []
for pair in current_table['foreign_keys']:
if current_table['table_names'][current_table['col_table'][pair[0]]] == tab:
foreign.append(pair[1])
if current_table['table_names'][current_table['col_table'][pair[1]]] == tab:
foreign.append(pair[0])
for pair in foreign:
if current_table['table_names'][current_table['col_table'][pair]] in table_names:
group_by_clause = 'GROUP BY ' + col_to_str('none', current_table['schema_content'][pair],
current_table['table_names'][current_table['col_table'][pair]],
table_names, N_T)
find_flag = True
break
if find_flag is False:
for (agg, col, tab) in sql_json['select']:
if 'id' in col.lower():
group_by_clause = 'GROUP BY ' + col_to_str(agg, col, tab, table_names, N_T)
break
if len(group_by_clause) > 5:
pass
else:
raise RuntimeError('fail to convert')
else:
group_by_clause = 'GROUP BY ' + col_to_str('none', find_primary[0],
find_primary[1],
table_names, N_T)
intersect_clause = ''
if 'intersect' in sql_json:
sql_json['intersect']['sql'] = sql_json['sql']
intersect_clause = 'INTERSECT ' + to_str(sql_json['intersect'], len(table_names) + 1, schema, table_names)
union_clause = ''
if 'union' in sql_json:
sql_json['union']['sql'] = sql_json['sql']
union_clause = 'UNION ' + to_str(sql_json['union'], len(table_names) + 1, schema, table_names)
except_clause = ''
if 'except' in sql_json:
sql_json['except']['sql'] = sql_json['sql']
except_clause = 'EXCEPT ' + to_str(sql_json['except'], len(table_names) + 1, schema, table_names)
# print(current_table['table_names_original'])
table_names_replace = {}
for a, b in zip(current_table['table_names_original'], current_table['table_names']):
table_names_replace[b] = a
new_table_names = {}
for key, value in table_names.items():
if key is None:
continue
new_table_names[table_names_replace[key]] = value
from_clause = infer_from_clause(new_table_names, schema, all_columns).strip()
sql = ' '.join([select_clause_str, from_clause, where_clause, group_by_clause, have_clause, sup_clause, order_clause,
intersect_clause, union_clause, except_clause])
return sql
if __name__ == '__main__':
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument('--data_path', type=str, help='dataset path', required=True)
arg_parser.add_argument('--input_path', type=str, help='predicted logical form', required=True)
arg_parser.add_argument('--output_path', type=str, help='output data')
args = arg_parser.parse_args()
# loading dataSets
datas, schemas = load_dataSets(args)
alter_not_in(datas, schemas=schemas)
alter_inter(datas)
alter_column0(datas)
index = range(len(datas))
count = 0
exception_count = 0
with open(args.output_path, 'w', encoding='utf8') as d:
for i in index:
try:
result = transform(datas[i], schemas[datas[i]['db_id']])
d.write(result[0] + '\n')
count += 1
except Exception as e:
result = transform(datas[i], schemas[datas[i]['db_id']], origin='Root1(3) Root(5) Sel(0) N(0) A(3) C(0) T(0)')
exception_count += 1
d.write(result[0] + '\n')
count += 1
print(e)
print('Exception')
print(traceback.format_exc())
print('===\n\n')
print(count, exception_count)

10
src/__init__.py Normal file

@ -0,0 +1,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# -*- coding: utf-8 -*-
"""
# @Time : 2019/5/25
# @Author : Jiaqi&Zecheng
# @File : __init__.py
# @Software: PyCharm
"""

82
src/args.py Normal file

@ -0,0 +1,82 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# -*- coding: utf-8 -*-
"""
# @Time : 2019/5/25
# @Author : Jiaqi&Zecheng
# @File : args.py
# @Software: PyCharm
"""
import random
import argparse
import torch
import numpy as np
def init_arg_parser():
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument('--seed', default=5783287, type=int, help='random seed')
arg_parser.add_argument('--cuda', action='store_true', help='use gpu')
arg_parser.add_argument('--lr_scheduler', action='store_true', help='use learning rate scheduler')
arg_parser.add_argument('--lr_scheduler_gammar', default=0.5, type=float, help='decay rate of learning rate scheduler')
arg_parser.add_argument('--column_pointer', action='store_true', help='use column pointer')
arg_parser.add_argument('--loss_epoch_threshold', default=20, type=int, help='loss epoch threshold')
arg_parser.add_argument('--sketch_loss_coefficient', default=0.2, type=float, help='sketch loss coefficient')
arg_parser.add_argument('--sentence_features', action='store_true', help='use sentence features')
arg_parser.add_argument('--model_name', choices=['transformer', 'rnn', 'table', 'sketch'], default='rnn',
help='model name')
arg_parser.add_argument('--lstm', choices=['lstm', 'lstm_with_dropout', 'parent_feed'], default='lstm')
arg_parser.add_argument('--load_model', default=None, type=str, help='load a pre-trained model')
arg_parser.add_argument('--glove_embed_path', default="glove.42B.300d.txt", type=str)
arg_parser.add_argument('--batch_size', default=64, type=int, help='batch size')
arg_parser.add_argument('--beam_size', default=5, type=int, help='beam size for beam search')
arg_parser.add_argument('--embed_size', default=300, type=int, help='size of word embeddings')
arg_parser.add_argument('--col_embed_size', default=300, type=int, help='size of column embeddings')
arg_parser.add_argument('--action_embed_size', default=128, type=int, help='size of action embeddings')
arg_parser.add_argument('--type_embed_size', default=128, type=int, help='size of type embeddings')
arg_parser.add_argument('--hidden_size', default=100, type=int, help='size of LSTM hidden states')
arg_parser.add_argument('--att_vec_size', default=100, type=int, help='size of attentional vector')
arg_parser.add_argument('--dropout', default=0.3, type=float, help='dropout rate')
arg_parser.add_argument('--word_dropout', default=0.2, type=float, help='word dropout rate')
# readout layer
arg_parser.add_argument('--no_query_vec_to_action_map', default=False, action='store_true')
arg_parser.add_argument('--readout', default='identity', choices=['identity', 'non_linear'])
arg_parser.add_argument('--query_vec_to_action_diff_map', default=False, action='store_true')
arg_parser.add_argument('--column_att', choices=['dot_prod', 'affine'], default='affine')
arg_parser.add_argument('--decode_max_time_step', default=40, type=int, help='maximum number of time steps used '
'in decoding and sampling')
arg_parser.add_argument('--save_to', default='model', type=str, help='save trained model to')
arg_parser.add_argument('--toy', action='store_true',
help='If set, use small data; used for fast debugging.')
arg_parser.add_argument('--clip_grad', default=5., type=float, help='clip gradients')
arg_parser.add_argument('--max_epoch', default=-1, type=int, help='maximum number of training epochs')
arg_parser.add_argument('--optimizer', default='Adam', type=str, help='optimizer')
arg_parser.add_argument('--lr', default=0.001, type=float, help='learning rate')
arg_parser.add_argument('--dataset', default="./data", type=str)
arg_parser.add_argument('--epoch', default=50, type=int, help='Maximum Epoch')
arg_parser.add_argument('--save', default='./', type=str,
help="Path to save the checkpoint and logs of epoch")
return arg_parser
def init_config(arg_parser):
args = arg_parser.parse_args()
torch.manual_seed(args.seed)
if args.cuda:
torch.cuda.manual_seed(args.seed)
np.random.seed(int(args.seed * 13 / 7))
random.seed(int(args.seed))
return args

206
src/beam.py Normal file

@ -0,0 +1,206 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# -*- coding: utf-8 -*-
"""
# @Time : 2019/5/25
# @Author : Jiaqi&Zecheng
# @File : beam.py
# @Software: PyCharm
"""
import copy
from src.rule import semQL
class ActionInfo(object):
"""sufficient statistics for making a prediction of an action at a time step"""
def __init__(self, action=None):
self.t = 0
self.score = 0
self.parent_t = -1
self.action = action
self.frontier_prod = None
self.frontier_field = None
# for GenToken actions only
self.copy_from_src = False
self.src_token_position = -1
class Beams(object):
def __init__(self, is_sketch=False):
self.actions = []
self.action_infos = []
self.inputs = []
self.score = 0.
self.t = 0
self.is_sketch = is_sketch
self.sketch_step = 0
self.sketch_attention_history = list()
def get_availableClass(self):
"""
return the available action class
:return:
"""
# TODO: this could be sped up
# return the available class using rule
# FIXME: now should change for these 11: "Filter 1 ROOT",
def check_type(lists):
for s in lists:
if type(s) == int:
return False
return True
stack = [semQL.Root1]
for action in self.actions:
infer_action = action.get_next_action(is_sketch=self.is_sketch)
infer_action.reverse()
if stack[-1] is type(action):
stack.pop()
# check if they are non-terminals
if check_type(infer_action):
stack.extend(infer_action)
else:
raise RuntimeError("Not the right action")
result = stack[-1] if len(stack) > 0 else None
return result
@classmethod
def get_parent_action(cls, actions):
"""
:param actions:
:return:
"""
def check_type(lists):
for s in lists:
if type(s) == int:
return False
return True
# check the origin state Root
if len(actions) == 0:
return None
stack = [semQL.Root1]
for id_x, action in enumerate(actions):
infer_action = action.get_next_action()
for ac in infer_action:
ac.parent = action
ac.pt = id_x
infer_action.reverse()
if stack[-1] is type(action):
stack.pop()
# check if they are non-terminals
if check_type(infer_action):
stack.extend(infer_action)
else:
for t in actions:
if type(t) != semQL.C:
print(t, end="")
print()
print(action)
print(stack[-1])
raise RuntimeError("Not the right action")
result = stack[-1] if len(stack) > 0 else None
return result
def apply_action(self, action):
# TODO: not finish implement yet
self.t += 1
self.actions.append(action)
def clone_and_apply_action(self, action):
new_hyp = self.copy()
new_hyp.apply_action(action)
return new_hyp
def clone_and_apply_action_info(self, action_info):
action = action_info.action
action.score = action_info.score
new_hyp = self.clone_and_apply_action(action)
new_hyp.action_infos.append(action_info)
new_hyp.sketch_step = self.sketch_step
new_hyp.sketch_attention_history = copy.copy(self.sketch_attention_history)
return new_hyp
def copy(self):
new_hyp = Beams(is_sketch=self.is_sketch)
# if self.tree:
# new_hyp.tree = self.tree.copy()
new_hyp.actions = list(self.actions)
new_hyp.score = self.score
new_hyp.t = self.t
new_hyp.sketch_step = self.sketch_step
new_hyp.sketch_attention_history = copy.copy(self.sketch_attention_history)
return new_hyp
def infer_n(self):
if len(self.actions) > 4:
prev_action = self.actions[-3]
if isinstance(prev_action, semQL.Filter):
if prev_action.id_c > 11:
# Nested Query, only select 1 column
return ['N A']
if self.actions[0].id_c != 3:
return [self.actions[3].production]
return semQL.N._init_grammar()
@property
def completed(self):
return True if self.get_availableClass() is None else False
@property
def is_valid(self):
actions = self.actions
return self.check_sel_valid(actions)
def check_sel_valid(self, actions):
find_sel = False
sel_actions = list()
for ac in actions:
if type(ac) == semQL.Sel:
find_sel = True
elif find_sel and type(ac) in [semQL.N, semQL.T, semQL.C, semQL.A]:
if type(ac) not in [semQL.N]:
sel_actions.append(ac)
elif find_sel and type(ac) not in [semQL.N, semQL.T, semQL.C, semQL.A]:
break
if find_sel is False:
return True
# not the complete sel lf
if len(sel_actions) % 3 != 0:
return True
sel_string = list()
for i in range(len(sel_actions) // 3):
if (sel_actions[i * 3 + 0].id_c, sel_actions[i * 3 + 1].id_c, sel_actions[i * 3 + 2].id_c) in sel_string:
return False
else:
sel_string.append(
(sel_actions[i * 3 + 0].id_c, sel_actions[i * 3 + 1].id_c, sel_actions[i * 3 + 2].id_c))
return True
if __name__ == '__main__':
test = Beams(is_sketch=True)
# print(semQL.Root1(1).get_next_action())
test.actions.append(semQL.Root1(3))
test.actions.append(semQL.Root(5))
print(str(test.get_availableClass()))
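# Hedged extension of the demo above (added for illustration, not in the original file):
# a hypothesis only counts as completed once no frontier non-terminal is left on the stack.
print(test.completed)  # False here: 'Root Sel' still leaves Sel unexpanded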

140
src/dataset.py Normal file
@@ -0,0 +1,140 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# -*- coding: utf-8 -*-
"""
# @Time : 2019/5/25
# @Author : Jiaqi&Zecheng
# @File : dataset.py
# @Software: PyCharm
"""
import copy
import src.rule.semQL as define_rule
from src.models import nn_utils
class Example:
"""
"""
def __init__(self, src_sent, tgt_actions=None, vis_seq=None, tab_cols=None, col_num=None, sql=None,
one_hot_type=None, col_hot_type=None, schema_len=None, tab_ids=None,
table_names=None, table_len=None, col_table_dict=None, cols=None,
table_col_name=None, table_col_len=None,
col_pred=None, tokenized_src_sent=None,
):
self.src_sent = src_sent
self.tokenized_src_sent = tokenized_src_sent
self.vis_seq = vis_seq
self.tab_cols = tab_cols
self.col_num = col_num
self.sql = sql
self.one_hot_type=one_hot_type
self.col_hot_type = col_hot_type
self.schema_len = schema_len
self.tab_ids = tab_ids
self.table_names = table_names
self.table_len = table_len
self.col_table_dict = col_table_dict
self.cols = cols
self.table_col_name = table_col_name
self.table_col_len = table_col_len
self.col_pred = col_pred
self.tgt_actions = tgt_actions
self.truth_actions = copy.deepcopy(tgt_actions)
self.sketch = list()
if self.truth_actions:
for ta in self.truth_actions:
if isinstance(ta, define_rule.C) or isinstance(ta, define_rule.T) or isinstance(ta, define_rule.A):
continue
self.sketch.append(ta)
class cached_property(object):
""" A property that is only computed once per instance and then replaces
itself with an ordinary attribute. Deleting the attribute resets the
property.
Source: https://github.com/bottlepy/bottle/commit/fa7733e075da0d790d809aa3d2f53071897e6f76
"""
def __init__(self, func):
self.__doc__ = getattr(func, '__doc__')
self.func = func
def __get__(self, obj, cls):
if obj is None:
return self
value = obj.__dict__[self.func.__name__] = self.func(obj)
return value
class Batch(object):
def __init__(self, examples, grammar, cuda=False):
self.examples = examples
if examples[0].tgt_actions:
self.max_action_num = max(len(e.tgt_actions) for e in self.examples)
self.max_sketch_num = max(len(e.sketch) for e in self.examples)
self.src_sents = [e.src_sent for e in self.examples]
self.src_sents_len = [len(e.src_sent) for e in self.examples]
self.tokenized_src_sents = [e.tokenized_src_sent for e in self.examples]
self.tokenized_src_sents_len = [len(e.tokenized_src_sent) for e in examples]
self.src_sents_word = [e.src_sent for e in self.examples]
self.table_sents_word = [[" ".join(x) for x in e.tab_cols] for e in self.examples]
self.schema_sents_word = [[" ".join(x) for x in e.table_names] for e in self.examples]
self.src_type = [e.one_hot_type for e in self.examples]
self.col_hot_type = [e.col_hot_type for e in self.examples]
self.table_sents = [e.tab_cols for e in self.examples]
self.col_num = [e.col_num for e in self.examples]
self.tab_ids = [e.tab_ids for e in self.examples]
self.table_names = [e.table_names for e in self.examples]
self.table_len = [e.table_len for e in examples]
self.col_table_dict = [e.col_table_dict for e in examples]
self.table_col_name = [e.table_col_name for e in examples]
self.table_col_len = [e.table_col_len for e in examples]
self.col_pred = [e.col_pred for e in examples]
self.grammar = grammar
self.cuda = cuda
def __len__(self):
return len(self.examples)
def table_dict_mask(self, table_dict):
return nn_utils.table_dict_to_mask_tensor(self.table_len, table_dict, cuda=self.cuda)
@cached_property
def pred_col_mask(self):
return nn_utils.pred_col_mask(self.col_pred, self.col_num)
@cached_property
def schema_token_mask(self):
return nn_utils.length_array_to_mask_tensor(self.table_len, cuda=self.cuda)
@cached_property
def table_token_mask(self):
return nn_utils.length_array_to_mask_tensor(self.col_num, cuda=self.cuda)
@cached_property
def table_appear_mask(self):
return nn_utils.appear_to_mask_tensor(self.col_num, cuda=self.cuda)
@cached_property
def table_unk_mask(self):
return nn_utils.length_array_to_mask_tensor(self.col_num, cuda=self.cuda, value=None)
@cached_property
def src_token_mask(self):
return nn_utils.length_array_to_mask_tensor(self.src_sents_len,
cuda=self.cuda)
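# Hedged usage sketch (added for illustration, not in the original file); the field
# values below are made up, whereas in the real pipeline they come from the
# preprocessed train/dev JSON produced by the code under ./preprocess.
#
#   ex = Example(src_sent=['show', 'all', 'singers'],
#                tokenized_src_sent=['show', 'all', 'singers'],
#                tab_cols=[['*'], ['singer', 'name']], col_num=2,
#                table_names=[['singer']], table_len=1)
#   batch = Batch([ex], grammar=None, cuda=False)
#   print(len(batch), batch.src_sents_len)   # 1 [3]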

10
src/models/__init__.py Normal file
@@ -0,0 +1,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# -*- coding: utf-8 -*-
"""
# @Time : 2019/5/26
# @Author : Jiaqi&Zecheng
# @File : __init__.py
# @Software: PyCharm
"""

160
src/models/basic_model.py Normal file
@@ -0,0 +1,160 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# -*- coding: utf-8 -*-
"""
# @Time : 2019/5/26
# @Author : Jiaqi&Zecheng
# @File : basic_model.py
# @Software: PyCharm
"""
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils
from torch.autograd import Variable
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from src.rule import semQL as define_rule
class BasicModel(nn.Module):
def __init__(self):
super(BasicModel, self).__init__()
pass
def embedding_cosine(self, src_embedding, table_embedding, table_unk_mask):
embedding_differ = []
for i in range(table_embedding.size(1)):
one_table_embedding = table_embedding[:, i, :]
one_table_embedding = one_table_embedding.unsqueeze(1).expand(table_embedding.size(0),
src_embedding.size(1),
table_embedding.size(2))
topk_val = F.cosine_similarity(one_table_embedding, src_embedding, dim=-1)
embedding_differ.append(topk_val)
embedding_differ = torch.stack(embedding_differ).transpose(1, 0)
embedding_differ.data.masked_fill_(table_unk_mask.unsqueeze(2).expand(
table_embedding.size(0),
table_embedding.size(1),
embedding_differ.size(2)
), 0)
return embedding_differ
def encode(self, src_sents_var, src_sents_len, q_onehot_project=None):
"""
encode the source sequence
:return:
src_encodings: Variable(batch_size, src_sent_len, hidden_size * 2)
last_state, last_cell: Variable(batch_size, hidden_size)
"""
src_token_embed = self.gen_x_batch(src_sents_var)
if q_onehot_project is not None:
src_token_embed = torch.cat([src_token_embed, q_onehot_project], dim=-1)
packed_src_token_embed = pack_padded_sequence(src_token_embed, src_sents_len, batch_first=True)
# src_encodings: (tgt_query_len, batch_size, hidden_size)
src_encodings, (last_state, last_cell) = self.encoder_lstm(packed_src_token_embed)
src_encodings, _ = pad_packed_sequence(src_encodings, batch_first=True)
# src_encodings: (batch_size, tgt_query_len, hidden_size)
# src_encodings = src_encodings.permute(1, 0, 2)
# (batch_size, hidden_size * 2)
last_state = torch.cat([last_state[0], last_state[1]], -1)
last_cell = torch.cat([last_cell[0], last_cell[1]], -1)
return src_encodings, (last_state, last_cell)
def input_type(self, values_list):
B = len(values_list)
val_len = []
for value in values_list:
val_len.append(len(value))
max_len = max(val_len)
# for the Begin and End
val_emb_array = np.zeros((B, max_len, values_list[0].shape[1]), dtype=np.float32)
for i in range(B):
val_emb_array[i, :val_len[i], :] = values_list[i][:, :]
val_inp = torch.from_numpy(val_emb_array)
if self.args.cuda:
val_inp = val_inp.cuda()
val_inp_var = Variable(val_inp)
return val_inp_var
def padding_sketch(self, sketch):
padding_result = []
for action in sketch:
padding_result.append(action)
if type(action) == define_rule.N:
for _ in range(action.id_c + 1):
padding_result.append(define_rule.A(0))
padding_result.append(define_rule.C(0))
padding_result.append(define_rule.T(0))
elif type(action) == define_rule.Filter and 'A' in action.production:
padding_result.append(define_rule.A(0))
padding_result.append(define_rule.C(0))
padding_result.append(define_rule.T(0))
elif type(action) == define_rule.Order or type(action) == define_rule.Sup:
padding_result.append(define_rule.A(0))
padding_result.append(define_rule.C(0))
padding_result.append(define_rule.T(0))
return padding_result
def gen_x_batch(self, q):
B = len(q)
val_embs = []
val_len = np.zeros(B, dtype=np.int64)
is_list = False
if type(q[0][0]) == list:
is_list = True
for i, one_q in enumerate(q):
if not is_list:
q_val = list(
map(lambda x: self.word_emb.get(x, np.zeros(self.args.col_embed_size, dtype=np.float32)), one_q))
else:
q_val = []
for ws in one_q:
emb_list = []
ws_len = len(ws)
for w in ws:
emb_list.append(self.word_emb.get(w, self.word_emb['unk']))
if ws_len == 0:
raise Exception("word list should not be empty!")
elif ws_len == 1:
q_val.append(emb_list[0])
else:
q_val.append(sum(emb_list) / float(ws_len))
val_embs.append(q_val)
val_len[i] = len(q_val)
max_len = max(val_len)
val_emb_array = np.zeros((B, max_len, self.args.col_embed_size), dtype=np.float32)
for i in range(B):
for t in range(len(val_embs[i])):
val_emb_array[i, t, :] = val_embs[i][t]
val_inp = torch.from_numpy(val_emb_array)
if self.args.cuda:
val_inp = val_inp.cuda()
return val_inp
def save(self, path):
dir_name = os.path.dirname(path)
if not os.path.exists(dir_name):
os.makedirs(dir_name)
params = {
'args': self.args,
'vocab': self.vocab,
'grammar': self.grammar,
'state_dict': self.state_dict()
}
torch.save(params, path)
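# Hedged worked example for padding_sketch (added for illustration, not in the
# original file): the sketch decoder predicts only skeleton actions, so A/C/T
# placeholders are inserted wherever the detail decoder must later fill in an
# aggregator, a column and a table (Sel and Filter omitted here for brevity).
#
#   from src.rule import semQL
#   sketch = [semQL.Root1(3), semQL.Root(5), semQL.N(0)]   # N(0) == 'N A'
#   BasicModel().padding_sketch(sketch)
#   # -> [Root1(3), Root(5), N(0), A(0), C(0), T(0)]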

764
src/models/model.py Normal file
@@ -0,0 +1,764 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# -*- coding: utf-8 -*-
"""
# @Time : 2019/5/25
# @Author : Jiaqi&Zecheng
# @File : model.py
# @Software: PyCharm
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils
from torch.autograd import Variable
from src.beam import Beams, ActionInfo
from src.dataset import Batch
from src.models import nn_utils
from src.models.basic_model import BasicModel
from src.models.pointer_net import PointerNet
from src.rule import semQL as define_rule
class IRNet(BasicModel):
def __init__(self, args, grammar):
super(IRNet, self).__init__()
self.args = args
self.grammar = grammar
self.use_column_pointer = args.column_pointer
self.use_sentence_features = args.sentence_features
if args.cuda:
self.new_long_tensor = torch.cuda.LongTensor
self.new_tensor = torch.cuda.FloatTensor
else:
self.new_long_tensor = torch.LongTensor
self.new_tensor = torch.FloatTensor
self.encoder_lstm = nn.LSTM(args.embed_size, args.hidden_size // 2, bidirectional=True,
batch_first=True)
input_dim = args.action_embed_size + \
args.att_vec_size + \
args.type_embed_size
# previous action
# input feeding
# pre type embedding
self.lf_decoder_lstm = nn.LSTMCell(input_dim, args.hidden_size)
self.sketch_decoder_lstm = nn.LSTMCell(input_dim, args.hidden_size)
# initialize the decoder's state and cells with encoder hidden states
self.decoder_cell_init = nn.Linear(args.hidden_size, args.hidden_size)
self.att_sketch_linear = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
self.att_lf_linear = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
self.sketch_att_vec_linear = nn.Linear(args.hidden_size + args.hidden_size, args.att_vec_size, bias=False)
self.lf_att_vec_linear = nn.Linear(args.hidden_size + args.hidden_size, args.att_vec_size, bias=False)
self.prob_att = nn.Linear(args.att_vec_size, 1)
self.prob_len = nn.Linear(1, 1)
self.col_type = nn.Linear(4, args.col_embed_size)
self.sketch_encoder = nn.LSTM(args.action_embed_size, args.action_embed_size // 2, bidirectional=True,
batch_first=True)
self.production_embed = nn.Embedding(len(grammar.prod2id), args.action_embed_size)
self.type_embed = nn.Embedding(len(grammar.type2id), args.type_embed_size)
self.production_readout_b = nn.Parameter(torch.FloatTensor(len(grammar.prod2id)).zero_())
self.att_project = nn.Linear(args.hidden_size + args.type_embed_size, args.hidden_size)
self.N_embed = nn.Embedding(len(define_rule.N._init_grammar()), args.action_embed_size)
self.read_out_act = F.tanh if args.readout == 'non_linear' else nn_utils.identity
self.query_vec_to_action_embed = nn.Linear(args.att_vec_size, args.action_embed_size,
bias=args.readout == 'non_linear')
self.production_readout = lambda q: F.linear(self.read_out_act(self.query_vec_to_action_embed(q)),
self.production_embed.weight, self.production_readout_b)
self.q_att = nn.Linear(args.hidden_size, args.embed_size)
self.column_rnn_input = nn.Linear(args.col_embed_size, args.action_embed_size, bias=False)
self.table_rnn_input = nn.Linear(args.col_embed_size, args.action_embed_size, bias=False)
self.dropout = nn.Dropout(args.dropout)
self.column_pointer_net = PointerNet(args.hidden_size, args.col_embed_size, attention_type=args.column_att)
self.table_pointer_net = PointerNet(args.hidden_size, args.col_embed_size, attention_type=args.column_att)
# initial the embedding layers
nn.init.xavier_normal_(self.production_embed.weight.data)
nn.init.xavier_normal_(self.type_embed.weight.data)
nn.init.xavier_normal_(self.N_embed.weight.data)
print('Use Column Pointer: ', True if self.use_column_pointer else False)
def forward(self, examples):
args = self.args
# batch the input examples
batch = Batch(examples, self.grammar, cuda=self.args.cuda)
table_appear_mask = batch.table_appear_mask
src_encodings, (last_state, last_cell) = self.encode(batch.src_sents, batch.src_sents_len, None)
src_encodings = self.dropout(src_encodings)
utterance_encodings_sketch_linear = self.att_sketch_linear(src_encodings)
utterance_encodings_lf_linear = self.att_lf_linear(src_encodings)
dec_init_vec = self.init_decoder_state(last_cell)
h_tm1 = dec_init_vec
action_probs = [[] for _ in examples]
zero_action_embed = Variable(self.new_tensor(args.action_embed_size).zero_())
zero_type_embed = Variable(self.new_tensor(args.type_embed_size).zero_())
sketch_attention_history = list()
for t in range(batch.max_sketch_num):
if t == 0:
x = Variable(self.new_tensor(len(batch), self.sketch_decoder_lstm.input_size).zero_(),
requires_grad=False)
else:
a_tm1_embeds = []
pre_types = []
for e_id, example in enumerate(examples):
if t < len(example.sketch):
# get the last action
# This is the action embedding
action_tm1 = example.sketch[t - 1]
if type(action_tm1) in [define_rule.Root1,
define_rule.Root,
define_rule.Sel,
define_rule.Filter,
define_rule.Sup,
define_rule.N,
define_rule.Order]:
a_tm1_embed = self.production_embed.weight[self.grammar.prod2id[action_tm1.production]]
else:
print(action_tm1, 'only for sketch')
quit()
a_tm1_embed = zero_action_embed
pass
else:
a_tm1_embed = zero_action_embed
a_tm1_embeds.append(a_tm1_embed)
a_tm1_embeds = torch.stack(a_tm1_embeds)
inputs = [a_tm1_embeds]
for e_id, example in enumerate(examples):
if t < len(example.sketch):
action_tm = example.sketch[t - 1]
pre_type = self.type_embed.weight[self.grammar.type2id[type(action_tm)]]
else:
pre_type = zero_type_embed
pre_types.append(pre_type)
pre_types = torch.stack(pre_types)
inputs.append(att_tm1)
inputs.append(pre_types)
x = torch.cat(inputs, dim=-1)
src_mask = batch.src_token_mask
(h_t, cell_t), att_t, aw = self.step(x, h_tm1, src_encodings,
utterance_encodings_sketch_linear, self.sketch_decoder_lstm,
self.sketch_att_vec_linear,
src_token_mask=src_mask, return_att_weight=True)
sketch_attention_history.append(att_t)
# get the Root possibility
apply_rule_prob = F.softmax(self.production_readout(att_t), dim=-1)
for e_id, example in enumerate(examples):
if t < len(example.sketch):
action_t = example.sketch[t]
act_prob_t_i = apply_rule_prob[e_id, self.grammar.prod2id[action_t.production]]
action_probs[e_id].append(act_prob_t_i)
h_tm1 = (h_t, cell_t)
att_tm1 = att_t
sketch_prob_var = torch.stack(
[torch.stack(action_probs_i, dim=0).log().sum() for action_probs_i in action_probs], dim=0)
table_embedding = self.gen_x_batch(batch.table_sents)
src_embedding = self.gen_x_batch(batch.src_sents)
schema_embedding = self.gen_x_batch(batch.table_names)
# get emb differ
embedding_differ = self.embedding_cosine(src_embedding=src_embedding, table_embedding=table_embedding,
table_unk_mask=batch.table_unk_mask)
schema_differ = self.embedding_cosine(src_embedding=src_embedding, table_embedding=schema_embedding,
table_unk_mask=batch.schema_token_mask)
tab_ctx = (src_encodings.unsqueeze(1) * embedding_differ.unsqueeze(3)).sum(2)
schema_ctx = (src_encodings.unsqueeze(1) * schema_differ.unsqueeze(3)).sum(2)
table_embedding = table_embedding + tab_ctx
schema_embedding = schema_embedding + schema_ctx
col_type = self.input_type(batch.col_hot_type)
col_type_var = self.col_type(col_type)
table_embedding = table_embedding + col_type_var
batch_table_dict = batch.col_table_dict
table_enable = np.zeros(shape=(len(examples)))
action_probs = [[] for _ in examples]
h_tm1 = dec_init_vec
for t in range(batch.max_action_num):
if t == 0:
# x = self.lf_begin_vec.unsqueeze(0).repeat(len(batch), 1)
x = Variable(self.new_tensor(len(batch), self.lf_decoder_lstm.input_size).zero_(), requires_grad=False)
else:
a_tm1_embeds = []
pre_types = []
for e_id, example in enumerate(examples):
if t < len(example.tgt_actions):
action_tm1 = example.tgt_actions[t - 1]
if type(action_tm1) in [define_rule.Root1,
define_rule.Root,
define_rule.Sel,
define_rule.Filter,
define_rule.Sup,
define_rule.N,
define_rule.Order,
]:
a_tm1_embed = self.production_embed.weight[self.grammar.prod2id[action_tm1.production]]
else:
if isinstance(action_tm1, define_rule.C):
a_tm1_embed = self.column_rnn_input(table_embedding[e_id, action_tm1.id_c])
elif isinstance(action_tm1, define_rule.T):
a_tm1_embed = self.column_rnn_input(schema_embedding[e_id, action_tm1.id_c])
elif isinstance(action_tm1, define_rule.A):
a_tm1_embed = self.production_embed.weight[self.grammar.prod2id[action_tm1.production]]
else:
print(action_tm1, 'not implement')
quit()
a_tm1_embed = zero_action_embed
pass
else:
a_tm1_embed = zero_action_embed
a_tm1_embeds.append(a_tm1_embed)
a_tm1_embeds = torch.stack(a_tm1_embeds)
inputs = [a_tm1_embeds]
# tgt t-1 action type
for e_id, example in enumerate(examples):
if t < len(example.tgt_actions):
action_tm = example.tgt_actions[t - 1]
pre_type = self.type_embed.weight[self.grammar.type2id[type(action_tm)]]
else:
pre_type = zero_type_embed
pre_types.append(pre_type)
pre_types = torch.stack(pre_types)
inputs.append(att_tm1)
inputs.append(pre_types)
x = torch.cat(inputs, dim=-1)
src_mask = batch.src_token_mask
(h_t, cell_t), att_t, aw = self.step(x, h_tm1, src_encodings,
utterance_encodings_lf_linear, self.lf_decoder_lstm,
self.lf_att_vec_linear,
src_token_mask=src_mask, return_att_weight=True)
apply_rule_prob = F.softmax(self.production_readout(att_t), dim=-1)
table_appear_mask_val = torch.from_numpy(table_appear_mask)
if self.args.cuda:
table_appear_mask_val = table_appear_mask_val.cuda()
if self.use_column_pointer:
gate = F.sigmoid(self.prob_att(att_t))
weights = self.column_pointer_net(src_encodings=table_embedding, query_vec=att_t.unsqueeze(0),
src_token_mask=None) * table_appear_mask_val * gate + self.column_pointer_net(
src_encodings=table_embedding, query_vec=att_t.unsqueeze(0),
src_token_mask=None) * (1 - table_appear_mask_val) * (1 - gate)
else:
weights = self.column_pointer_net(src_encodings=table_embedding, query_vec=att_t.unsqueeze(0),
src_token_mask=batch.table_token_mask)
weights.data.masked_fill_(batch.table_token_mask, -float('inf'))
column_attention_weights = F.softmax(weights, dim=-1)
table_weights = self.table_pointer_net(src_encodings=schema_embedding, query_vec=att_t.unsqueeze(0),
src_token_mask=None)
schema_token_mask = batch.schema_token_mask.expand_as(table_weights)
table_weights.data.masked_fill_(schema_token_mask, -float('inf'))
table_dict = [batch_table_dict[x_id][int(x)] for x_id, x in enumerate(table_enable.tolist())]
table_mask = batch.table_dict_mask(table_dict)
table_weights.data.masked_fill_(table_mask, -float('inf'))
table_weights = F.softmax(table_weights, dim=-1)
# now get the loss
for e_id, example in enumerate(examples):
if t < len(example.tgt_actions):
action_t = example.tgt_actions[t]
if isinstance(action_t, define_rule.C):
table_appear_mask[e_id, action_t.id_c] = 1
table_enable[e_id] = action_t.id_c
act_prob_t_i = column_attention_weights[e_id, action_t.id_c]
action_probs[e_id].append(act_prob_t_i)
elif isinstance(action_t, define_rule.T):
act_prob_t_i = table_weights[e_id, action_t.id_c]
action_probs[e_id].append(act_prob_t_i)
elif isinstance(action_t, define_rule.A):
act_prob_t_i = apply_rule_prob[e_id, self.grammar.prod2id[action_t.production]]
action_probs[e_id].append(act_prob_t_i)
else:
pass
h_tm1 = (h_t, cell_t)
att_tm1 = att_t
lf_prob_var = torch.stack(
[torch.stack(action_probs_i, dim=0).log().sum() for action_probs_i in action_probs], dim=0)
return [sketch_prob_var, lf_prob_var]
def parse(self, examples, beam_size=5):
"""
one example a time
:param examples:
:param beam_size:
:return:
"""
batch = Batch([examples], self.grammar, cuda=self.args.cuda)
src_encodings, (last_state, last_cell) = self.encode(batch.src_sents, batch.src_sents_len, None)
src_encodings = self.dropout(src_encodings)
utterance_encodings_sketch_linear = self.att_sketch_linear(src_encodings)
utterance_encodings_lf_linear = self.att_lf_linear(src_encodings)
dec_init_vec = self.init_decoder_state(last_cell)
h_tm1 = dec_init_vec
t = 0
beams = [Beams(is_sketch=True)]
completed_beams = []
while len(completed_beams) < beam_size and t < self.args.decode_max_time_step:
hyp_num = len(beams)
exp_src_encodings = src_encodings.expand(hyp_num, src_encodings.size(1),
src_encodings.size(2))
exp_src_encodings_sketch_linear = utterance_encodings_sketch_linear.expand(hyp_num,
utterance_encodings_sketch_linear.size(
1),
utterance_encodings_sketch_linear.size(
2))
if t == 0:
with torch.no_grad():
x = Variable(self.new_tensor(1, self.sketch_decoder_lstm.input_size).zero_())
else:
a_tm1_embeds = []
pre_types = []
for e_id, hyp in enumerate(beams):
action_tm1 = hyp.actions[-1]
if type(action_tm1) in [define_rule.Root1,
define_rule.Root,
define_rule.Sel,
define_rule.Filter,
define_rule.Sup,
define_rule.N,
define_rule.Order]:
a_tm1_embed = self.production_embed.weight[self.grammar.prod2id[action_tm1.production]]
else:
raise ValueError('unknown action %s' % action_tm1)
a_tm1_embeds.append(a_tm1_embed)
a_tm1_embeds = torch.stack(a_tm1_embeds)
inputs = [a_tm1_embeds]
for e_id, hyp in enumerate(beams):
action_tm = hyp.actions[-1]
pre_type = self.type_embed.weight[self.grammar.type2id[type(action_tm)]]
pre_types.append(pre_type)
pre_types = torch.stack(pre_types)
inputs.append(att_tm1)
inputs.append(pre_types)
x = torch.cat(inputs, dim=-1)
(h_t, cell_t), att_t = self.step(x, h_tm1, exp_src_encodings,
exp_src_encodings_sketch_linear, self.sketch_decoder_lstm,
self.sketch_att_vec_linear,
src_token_mask=None)
apply_rule_log_prob = F.log_softmax(self.production_readout(att_t), dim=-1)
new_hyp_meta = []
for hyp_id, hyp in enumerate(beams):
action_class = hyp.get_availableClass()
if action_class in [define_rule.Root1,
define_rule.Root,
define_rule.Sel,
define_rule.Filter,
define_rule.Sup,
define_rule.N,
define_rule.Order]:
possible_productions = self.grammar.get_production(action_class)
for possible_production in possible_productions:
prod_id = self.grammar.prod2id[possible_production]
prod_score = apply_rule_log_prob[hyp_id, prod_id]
new_hyp_score = hyp.score + prod_score.data.cpu()
meta_entry = {'action_type': action_class, 'prod_id': prod_id,
'score': prod_score, 'new_hyp_score': new_hyp_score,
'prev_hyp_id': hyp_id}
new_hyp_meta.append(meta_entry)
else:
raise RuntimeError('No right action class')
if not new_hyp_meta: break
new_hyp_scores = torch.stack([x['new_hyp_score'] for x in new_hyp_meta], dim=0)
top_new_hyp_scores, meta_ids = torch.topk(new_hyp_scores,
k=min(new_hyp_scores.size(0),
beam_size - len(completed_beams)))
live_hyp_ids = []
new_beams = []
for new_hyp_score, meta_id in zip(top_new_hyp_scores.data.cpu(), meta_ids.data.cpu()):
action_info = ActionInfo()
hyp_meta_entry = new_hyp_meta[meta_id]
prev_hyp_id = hyp_meta_entry['prev_hyp_id']
prev_hyp = beams[prev_hyp_id]
action_type_str = hyp_meta_entry['action_type']
prod_id = hyp_meta_entry['prod_id']
if prod_id < len(self.grammar.id2prod):
production = self.grammar.id2prod[prod_id]
action = action_type_str(list(action_type_str._init_grammar()).index(production))
else:
raise NotImplementedError
action_info.action = action
action_info.t = t
action_info.score = hyp_meta_entry['score']
new_hyp = prev_hyp.clone_and_apply_action_info(action_info)
new_hyp.score = new_hyp_score
new_hyp.inputs.extend(prev_hyp.inputs)
if new_hyp.is_valid is False:
continue
if new_hyp.completed:
completed_beams.append(new_hyp)
else:
new_beams.append(new_hyp)
live_hyp_ids.append(prev_hyp_id)
if live_hyp_ids:
h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids])
att_tm1 = att_t[live_hyp_ids]
beams = new_beams
t += 1
else:
break
# now get the sketch result
completed_beams.sort(key=lambda hyp: -hyp.score)
if len(completed_beams) == 0:
return [[], []]
sketch_actions = completed_beams[0].actions
# sketch_actions = examples.sketch
padding_sketch = self.padding_sketch(sketch_actions)
table_embedding = self.gen_x_batch(batch.table_sents)
src_embedding = self.gen_x_batch(batch.src_sents)
schema_embedding = self.gen_x_batch(batch.table_names)
# get emb differ
embedding_differ = self.embedding_cosine(src_embedding=src_embedding, table_embedding=table_embedding,
table_unk_mask=batch.table_unk_mask)
schema_differ = self.embedding_cosine(src_embedding=src_embedding, table_embedding=schema_embedding,
table_unk_mask=batch.schema_token_mask)
tab_ctx = (src_encodings.unsqueeze(1) * embedding_differ.unsqueeze(3)).sum(2)
schema_ctx = (src_encodings.unsqueeze(1) * schema_differ.unsqueeze(3)).sum(2)
table_embedding = table_embedding + tab_ctx
schema_embedding = schema_embedding + schema_ctx
col_type = self.input_type(batch.col_hot_type)
col_type_var = self.col_type(col_type)
table_embedding = table_embedding + col_type_var
batch_table_dict = batch.col_table_dict
h_tm1 = dec_init_vec
t = 0
beams = [Beams(is_sketch=False)]
completed_beams = []
while len(completed_beams) < beam_size and t < self.args.decode_max_time_step:
hyp_num = len(beams)
# expand value
exp_src_encodings = src_encodings.expand(hyp_num, src_encodings.size(1),
src_encodings.size(2))
exp_utterance_encodings_lf_linear = utterance_encodings_lf_linear.expand(hyp_num,
utterance_encodings_lf_linear.size(
1),
utterance_encodings_lf_linear.size(
2))
exp_table_embedding = table_embedding.expand(hyp_num, table_embedding.size(1),
table_embedding.size(2))
exp_schema_embedding = schema_embedding.expand(hyp_num, schema_embedding.size(1),
schema_embedding.size(2))
table_appear_mask = batch.table_appear_mask
table_appear_mask = np.zeros((hyp_num, table_appear_mask.shape[1]), dtype=np.float32)
table_enable = np.zeros(shape=(hyp_num))
for e_id, hyp in enumerate(beams):
for act in hyp.actions:
if type(act) == define_rule.C:
table_appear_mask[e_id][act.id_c] = 1
table_enable[e_id] = act.id_c
if t == 0:
with torch.no_grad():
x = Variable(self.new_tensor(1, self.lf_decoder_lstm.input_size).zero_())
else:
a_tm1_embeds = []
pre_types = []
for e_id, hyp in enumerate(beams):
action_tm1 = hyp.actions[-1]
if type(action_tm1) in [define_rule.Root1,
define_rule.Root,
define_rule.Sel,
define_rule.Filter,
define_rule.Sup,
define_rule.N,
define_rule.Order]:
a_tm1_embed = self.production_embed.weight[self.grammar.prod2id[action_tm1.production]]
hyp.sketch_step += 1
elif isinstance(action_tm1, define_rule.C):
a_tm1_embed = self.column_rnn_input(table_embedding[0, action_tm1.id_c])
elif isinstance(action_tm1, define_rule.T):
a_tm1_embed = self.column_rnn_input(schema_embedding[0, action_tm1.id_c])
elif isinstance(action_tm1, define_rule.A):
a_tm1_embed = self.production_embed.weight[self.grammar.prod2id[action_tm1.production]]
else:
raise ValueError('unknown action %s' % action_tm1)
a_tm1_embeds.append(a_tm1_embed)
a_tm1_embeds = torch.stack(a_tm1_embeds)
inputs = [a_tm1_embeds]
for e_id, hyp in enumerate(beams):
action_tm = hyp.actions[-1]
pre_type = self.type_embed.weight[self.grammar.type2id[type(action_tm)]]
pre_types.append(pre_type)
pre_types = torch.stack(pre_types)
inputs.append(att_tm1)
inputs.append(pre_types)
x = torch.cat(inputs, dim=-1)
(h_t, cell_t), att_t = self.step(x, h_tm1, exp_src_encodings,
exp_utterance_encodings_lf_linear, self.lf_decoder_lstm,
self.lf_att_vec_linear,
src_token_mask=None)
apply_rule_log_prob = F.log_softmax(self.production_readout(att_t), dim=-1)
table_appear_mask_val = torch.from_numpy(table_appear_mask)
if self.args.cuda: table_appear_mask_val = table_appear_mask_val.cuda()
if self.use_column_pointer:
gate = F.sigmoid(self.prob_att(att_t))
weights = self.column_pointer_net(src_encodings=exp_table_embedding, query_vec=att_t.unsqueeze(0),
src_token_mask=None) * table_appear_mask_val * gate + self.column_pointer_net(
src_encodings=exp_table_embedding, query_vec=att_t.unsqueeze(0),
src_token_mask=None) * (1 - table_appear_mask_val) * (1 - gate)
# weights = weights + self.col_attention_out(exp_embedding_differ).squeeze()
else:
weights = self.column_pointer_net(src_encodings=exp_table_embedding, query_vec=att_t.unsqueeze(0),
src_token_mask=batch.table_token_mask)
# weights.data.masked_fill_(exp_col_pred_mask, -float('inf'))
column_selection_log_prob = F.log_softmax(weights, dim=-1)
table_weights = self.table_pointer_net(src_encodings=exp_schema_embedding, query_vec=att_t.unsqueeze(0),
src_token_mask=None)
# table_weights = self.table_pointer_net(src_encodings=exp_schema_embedding, query_vec=att_t.unsqueeze(0), src_token_mask=None)
schema_token_mask = batch.schema_token_mask.expand_as(table_weights)
table_weights.data.masked_fill_(schema_token_mask, -float('inf'))
table_dict = [batch_table_dict[0][int(x)] for x_id, x in enumerate(table_enable.tolist())]
table_mask = batch.table_dict_mask(table_dict)
table_weights.data.masked_fill_(table_mask, -float('inf'))
table_weights = F.log_softmax(table_weights, dim=-1)
new_hyp_meta = []
for hyp_id, hyp in enumerate(beams):
# TODO: should change this
if type(padding_sketch[t]) == define_rule.A:
possible_productions = self.grammar.get_production(define_rule.A)
for possible_production in possible_productions:
prod_id = self.grammar.prod2id[possible_production]
prod_score = apply_rule_log_prob[hyp_id, prod_id]
new_hyp_score = hyp.score + prod_score.data.cpu()
meta_entry = {'action_type': define_rule.A, 'prod_id': prod_id,
'score': prod_score, 'new_hyp_score': new_hyp_score,
'prev_hyp_id': hyp_id}
new_hyp_meta.append(meta_entry)
elif type(padding_sketch[t]) == define_rule.C:
for col_id, _ in enumerate(batch.table_sents[0]):
col_sel_score = column_selection_log_prob[hyp_id, col_id]
new_hyp_score = hyp.score + col_sel_score.data.cpu()
meta_entry = {'action_type': define_rule.C, 'col_id': col_id,
'score': col_sel_score, 'new_hyp_score': new_hyp_score,
'prev_hyp_id': hyp_id}
new_hyp_meta.append(meta_entry)
elif type(padding_sketch[t]) == define_rule.T:
for t_id, _ in enumerate(batch.table_names[0]):
t_sel_score = table_weights[hyp_id, t_id]
new_hyp_score = hyp.score + t_sel_score.data.cpu()
meta_entry = {'action_type': define_rule.T, 't_id': t_id,
'score': t_sel_score, 'new_hyp_score': new_hyp_score,
'prev_hyp_id': hyp_id}
new_hyp_meta.append(meta_entry)
else:
prod_id = self.grammar.prod2id[padding_sketch[t].production]
new_hyp_score = hyp.score + torch.tensor(0.0)
meta_entry = {'action_type': type(padding_sketch[t]), 'prod_id': prod_id,
'score': torch.tensor(0.0), 'new_hyp_score': new_hyp_score,
'prev_hyp_id': hyp_id}
new_hyp_meta.append(meta_entry)
if not new_hyp_meta: break
new_hyp_scores = torch.stack([x['new_hyp_score'] for x in new_hyp_meta], dim=0)
top_new_hyp_scores, meta_ids = torch.topk(new_hyp_scores,
k=min(new_hyp_scores.size(0),
beam_size - len(completed_beams)))
live_hyp_ids = []
new_beams = []
for new_hyp_score, meta_id in zip(top_new_hyp_scores.data.cpu(), meta_ids.data.cpu()):
action_info = ActionInfo()
hyp_meta_entry = new_hyp_meta[meta_id]
prev_hyp_id = hyp_meta_entry['prev_hyp_id']
prev_hyp = beams[prev_hyp_id]
action_type_str = hyp_meta_entry['action_type']
if 'prod_id' in hyp_meta_entry:
prod_id = hyp_meta_entry['prod_id']
if action_type_str == define_rule.C:
col_id = hyp_meta_entry['col_id']
action = define_rule.C(col_id)
elif action_type_str == define_rule.T:
t_id = hyp_meta_entry['t_id']
action = define_rule.T(t_id)
elif prod_id < len(self.grammar.id2prod):
production = self.grammar.id2prod[prod_id]
action = action_type_str(list(action_type_str._init_grammar()).index(production))
else:
raise NotImplementedError
action_info.action = action
action_info.t = t
action_info.score = hyp_meta_entry['score']
new_hyp = prev_hyp.clone_and_apply_action_info(action_info)
new_hyp.score = new_hyp_score
new_hyp.inputs.extend(prev_hyp.inputs)
if new_hyp.is_valid is False:
continue
if new_hyp.completed:
completed_beams.append(new_hyp)
else:
new_beams.append(new_hyp)
live_hyp_ids.append(prev_hyp_id)
if live_hyp_ids:
h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids])
att_tm1 = att_t[live_hyp_ids]
beams = new_beams
t += 1
else:
break
completed_beams.sort(key=lambda hyp: -hyp.score)
return [completed_beams, sketch_actions]
def step(self, x, h_tm1, src_encodings, src_encodings_att_linear, decoder, attention_func, src_token_mask=None,
return_att_weight=False):
# h_t: (batch_size, hidden_size)
h_t, cell_t = decoder(x, h_tm1)
ctx_t, alpha_t = nn_utils.dot_prod_attention(h_t,
src_encodings, src_encodings_att_linear,
mask=src_token_mask)
att_t = F.tanh(attention_func(torch.cat([h_t, ctx_t], 1)))
att_t = self.dropout(att_t)
if return_att_weight:
return (h_t, cell_t), att_t, alpha_t
else:
return (h_t, cell_t), att_t
def init_decoder_state(self, enc_last_cell):
h_0 = self.decoder_cell_init(enc_last_cell)
h_0 = F.tanh(h_0)
return h_0, Variable(self.new_tensor(h_0.size()).zero_())
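# Hedged training-step sketch (added for illustration, not part of the original file);
# `args`, `grammar`, `examples` and the pretrained word embeddings are assumed to be
# prepared by the training script elsewhere in the repo, which may differ in detail.
#
#   model = IRNet(args, grammar)
#   model.word_emb = word_emb            # GloVe vectors loaded elsewhere
#   optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
#   sketch_prob, lf_prob = model.forward(examples)   # per-example log-likelihoods
#   loss = -(sketch_prob + lf_prob).mean()
#   loss.backward()
#   torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
#   optimizer.step()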

236
src/models/nn_utils.py Normal file
@@ -0,0 +1,236 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# -*- coding: utf-8 -*-
"""
# @Time : 2019/5/25
# @Author : Jiaqi&Zecheng
# @File : nn_utils.py
# @Software: PyCharm
"""
import torch.nn.functional as F
import torch.nn.init as init
import numpy as np
import torch
from torch.autograd import Variable
from six.moves import xrange
def dot_prod_attention(h_t, src_encoding, src_encoding_att_linear, mask=None):
"""
:param h_t: (batch_size, hidden_size)
:param src_encoding: (batch_size, src_sent_len, hidden_size * 2)
:param src_encoding_att_linear: (batch_size, src_sent_len, hidden_size)
:param mask: (batch_size, src_sent_len)
"""
# (batch_size, src_sent_len)
att_weight = torch.bmm(src_encoding_att_linear, h_t.unsqueeze(2)).squeeze(2)
if mask is not None:
att_weight.data.masked_fill_(mask, -float('inf'))
att_weight = F.softmax(att_weight, dim=-1)
att_view = (att_weight.size(0), 1, att_weight.size(1))
# (batch_size, hidden_size)
ctx_vec = torch.bmm(att_weight.view(*att_view), src_encoding).squeeze(1)
return ctx_vec, att_weight
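# Hedged shape check for dot_prod_attention (added for illustration, not in the
# original file): with batch_size=2, src_sent_len=5 and hidden size 8 the context
# vector is (2, 8) and the attention weights are (2, 5).
#
#   h = torch.randn(2, 8)
#   enc = torch.randn(2, 5, 8)
#   ctx, alpha = dot_prod_attention(h, enc, enc)   # reuse enc as the projected encodings for simplicity
#   assert ctx.shape == (2, 8) and alpha.shape == (2, 5)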
def length_array_to_mask_tensor(length_array, cuda=False, value=None):
max_len = max(length_array)
batch_size = len(length_array)
mask = np.ones((batch_size, max_len), dtype=np.uint8)
for i, seq_len in enumerate(length_array):
mask[i][:seq_len] = 0
if value is not None:
for b_id in range(len(value)):
for c_id, c in enumerate(value[b_id]):
if value[b_id][c_id] == [3]:
mask[b_id][c_id] = 1
mask = torch.ByteTensor(mask)
return mask.cuda() if cuda else mask
def table_dict_to_mask_tensor(length_array, table_dict, cuda=False ):
max_len = max(length_array)
batch_size = len(table_dict)
mask = np.ones((batch_size, max_len), dtype=np.uint8)
for i, ta_val in enumerate(table_dict):
for tt in ta_val:
mask[i][tt] = 0
mask = torch.ByteTensor(mask)
return mask.cuda() if cuda else mask
def length_position_tensor(length_array, cuda=False, value=None):
max_len = max(length_array)
batch_size = len(length_array)
mask = np.zeros((batch_size, max_len), dtype=np.float32)
for b_id in range(batch_size):
for len_c in range(length_array[b_id]):
mask[b_id][len_c] = len_c + 1
mask = torch.LongTensor(mask)
return mask.cuda() if cuda else mask
def appear_to_mask_tensor(length_array, cuda=False, value=None):
max_len = max(length_array)
batch_size = len(length_array)
mask = np.zeros((batch_size, max_len), dtype=np.float32)
return mask
def pred_col_mask(value, max_len):
max_len = max(max_len)
batch_size = len(value)
mask = np.ones((batch_size, max_len), dtype=np.uint8)
for v_ind, v_val in enumerate(value):
for v in v_val:
mask[v_ind][v] = 0
mask = torch.ByteTensor(mask)
return mask.cuda()
def input_transpose(sents, pad_token):
"""
pad each sequence in the input list to max_sent_len, returning the padded
sequences of size (batch_size, max_sent_len) together with their masks
"""
max_len = max(len(s) for s in sents)
batch_size = len(sents)
sents_t = []
masks = []
for e_id in range(batch_size):
if type(sents[0][0]) != list:
sents_t.append([sents[e_id][i] if len(sents[e_id]) > i else pad_token for i in range(max_len)])
else:
sents_t.append([sents[e_id][i] if len(sents[e_id]) > i else [pad_token] for i in range(max_len)])
masks.append([1 if len(sents[e_id]) > i else 0 for i in range(max_len)])
return sents_t, masks
def word2id(sents, vocab):
if type(sents[0]) == list:
if type(sents[0][0]) != list:
return [[vocab[w] for w in s] for s in sents]
else:
return [[[vocab[w] for w in s] for s in v] for v in sents ]
else:
return [vocab[w] for w in sents]
def id2word(sents, vocab):
if type(sents[0]) == list:
return [[vocab.id2word[w] for w in s] for s in sents]
else:
return [vocab.id2word[w] for w in sents]
def to_input_variable(sequences, vocab, cuda=False, training=True):
"""
given a list of sequences,
return a tensor of shape (batch_size, max_sent_len)
"""
word_ids = word2id(sequences, vocab)
sents_t, masks = input_transpose(word_ids, vocab['<pad>'])
if type(sents_t[0][0]) != list:
with torch.no_grad():
sents_var = Variable(torch.LongTensor(sents_t), requires_grad=False)
if cuda:
sents_var = sents_var.cuda()
else:
sents_var = sents_t
return sents_var
def variable_constr(x, v, cuda=False):
# x is the name of a tensor constructor, e.g. 'LongTensor'
return Variable(getattr(torch.cuda, x)(v)) if cuda else Variable(getattr(torch, x)(v))
def batch_iter(examples, batch_size, shuffle=False):
index_arr = np.arange(len(examples))
if shuffle:
np.random.shuffle(index_arr)
batch_num = int(np.ceil(len(examples) / float(batch_size)))
for batch_id in xrange(batch_num):
batch_ids = index_arr[batch_size * batch_id: batch_size * (batch_id + 1)]
batch_examples = [examples[i] for i in batch_ids]
yield batch_examples
def isnan(data):
data = data.cpu().numpy()
return np.isnan(data).any() or np.isinf(data).any()
def log_sum_exp(inputs, dim=None, keepdim=False):
"""Numerically stable logsumexp.
source: https://github.com/pytorch/pytorch/issues/2591
Args:
inputs: A Variable with any shape.
dim: An integer.
keepdim: A boolean.
Returns:
Equivalent of log(sum(exp(inputs), dim=dim, keepdim=keepdim)).
"""
# For a 1-D array x (any array along a single dimension),
# log sum exp(x) = s + log sum exp(x - s)
# with s = max(x) being a common choice.
if dim is None:
inputs = inputs.view(-1)
dim = 0
s, _ = torch.max(inputs, dim=dim, keepdim=True)
outputs = s + (inputs - s).exp().sum(dim=dim, keepdim=True).log()
if not keepdim:
outputs = outputs.squeeze(dim)
return outputs
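# Hedged sanity check (added for illustration, not in the original file): on any
# reasonably recent PyTorch this should agree with the built-in torch.logsumexp.
#
#   x = torch.randn(3, 4)
#   assert torch.allclose(log_sum_exp(x, dim=1), torch.logsumexp(x, dim=1))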
def uniform_init(lower, upper, params):
for p in params:
p.data.uniform_(lower, upper)
def glorot_init(params):
for p in params:
if len(p.data.size()) > 1:
init.xavier_normal_(p.data)
def identity(x):
return x
def pad_matrix(matrixs, cuda=False):
"""
:param matrixs:
:return: [batch_size, max_shape, max_shape], [batch_size]
"""
shape = [m.shape[0] for m in matrixs]
max_shape = max(shape)
tensors = list()
for s, m in zip(shape, matrixs):
delta = max_shape - s
if s > 0:
tensors.append(torch.as_tensor(np.pad(m, [(0, delta), (0, delta)], mode='constant'), dtype=torch.float))
else:
tensors.append(torch.as_tensor(m, dtype=torch.float))
tensors = torch.stack(tensors)
if cuda:
tensors = tensors.cuda()
return tensors
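# Hedged example for pad_matrix (added for illustration, not in the original file):
# a 2x2 and a 3x3 adjacency matrix are zero-padded into one (2, 3, 3) float tensor.
#
#   out = pad_matrix([np.eye(2), np.eye(3)])
#   assert out.shape == (2, 3, 3)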

106
src/models/pointer_net.py Normal file
@@ -0,0 +1,106 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# coding=utf8
import torch
import torch.nn as nn
import torch.nn.utils
from torch.nn import Parameter
class AuxiliaryPointerNet(nn.Module):
def __init__(self, query_vec_size, src_encoding_size, attention_type='affine'):
super(AuxiliaryPointerNet, self).__init__()
assert attention_type in ('affine', 'dot_prod')
if attention_type == 'affine':
self.src_encoding_linear = nn.Linear(src_encoding_size, query_vec_size, bias=False)
self.auxiliary_encoding_linear = nn.Linear(src_encoding_size, query_vec_size, bias=False)
self.attention_type = attention_type
def forward(self, src_encodings, src_context_encodings, src_token_mask, query_vec):
"""
:param src_context_encodings: Variable(batch_size, src_sent_len, src_encoding_size)
:param src_encodings: Variable(batch_size, src_sent_len, src_encoding_size)
:param src_token_mask: Variable(batch_size, src_sent_len)
:param query_vec: Variable(tgt_action_num, batch_size, query_vec_size)
:return: Variable(tgt_action_num, batch_size, src_sent_len)
"""
# (batch_size, 1, src_sent_len, query_vec_size)
encodings = src_encodings.clone()
context_encodings = src_context_encodings.clone()
if self.attention_type == 'affine':
encodings = self.src_encoding_linear(src_encodings)
context_encodings = self.auxiliary_encoding_linear(src_context_encodings)
encodings = encodings.unsqueeze(1)
context_encodings = context_encodings.unsqueeze(1)
# (batch_size, tgt_action_num, query_vec_size, 1)
q = query_vec.permute(1, 0, 2).unsqueeze(3)
# (batch_size, tgt_action_num, src_sent_len)
weights = torch.matmul(encodings, q).squeeze(3)
context_weights = torch.matmul(context_encodings, q).squeeze(3)
# (tgt_action_num, batch_size, src_sent_len)
weights = weights.permute(1, 0, 2)
context_weights = context_weights.permute(1, 0, 2)
if src_token_mask is not None:
# (tgt_action_num, batch_size, src_sent_len)
src_token_mask = src_token_mask.unsqueeze(0).expand_as(weights)
weights.data.masked_fill_(src_token_mask, -float('inf'))
context_weights.data.masked_fill_(src_token_mask, -float('inf'))
sigma = 0.1
return weights.squeeze(0) + sigma * context_weights.squeeze(0)
class PointerNet(nn.Module):
def __init__(self, query_vec_size, src_encoding_size, attention_type='affine'):
super(PointerNet, self).__init__()
assert attention_type in ('affine', 'dot_prod')
if attention_type == 'affine':
self.src_encoding_linear = nn.Linear(src_encoding_size, query_vec_size, bias=False)
self.attention_type = attention_type
self.input_linear = nn.Linear(query_vec_size, query_vec_size)
self.type_linear = nn.Linear(32, query_vec_size)
self.V = Parameter(torch.FloatTensor(query_vec_size), requires_grad=True)
self.tanh = nn.Tanh()
self.context_linear = nn.Conv1d(src_encoding_size, query_vec_size, 1, 1)
self.coverage_linear = nn.Conv1d(1, query_vec_size, 1, 1)
nn.init.uniform_(self.V, -1, 1)
def forward(self, src_encodings, src_token_mask, query_vec):
"""
:param src_encodings: Variable(batch_size, src_sent_len, hidden_size * 2)
:param src_token_mask: Variable(batch_size, src_sent_len)
:param query_vec: Variable(tgt_action_num, batch_size, query_vec_size)
:return: Variable(tgt_action_num, batch_size, src_sent_len)
"""
# (batch_size, 1, src_sent_len, query_vec_size)
if self.attention_type == 'affine':
src_encodings = self.src_encoding_linear(src_encodings)
src_encodings = src_encodings.unsqueeze(1)
# (batch_size, tgt_action_num, query_vec_size, 1)
q = query_vec.permute(1, 0, 2).unsqueeze(3)
weights = torch.matmul(src_encodings, q).squeeze(3)
weights = weights.permute(1, 0, 2)
if src_token_mask is not None:
src_token_mask = src_token_mask.unsqueeze(0).expand_as(weights)
weights.data.masked_fill_(src_token_mask, -float('inf'))
return weights.squeeze(0)
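# Hedged shape sketch for PointerNet (added for illustration, not in the original
# file): scoring 6 candidate columns against a single decoder query vector.
#
#   net = PointerNet(query_vec_size=300, src_encoding_size=300, attention_type='affine')
#   cols = torch.randn(1, 6, 300)   # (batch, n_columns, col_embed_size)
#   q = torch.randn(1, 1, 300)      # (tgt_action_num, batch, query_vec_size)
#   scores = net(src_encodings=cols, query_vec=q, src_token_mask=None)
#   assert scores.shape == (1, 6)   # one unnormalized score per column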

10
src/rule/__init__.py Normal file
@@ -0,0 +1,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# -*- coding: utf-8 -*-
"""
# @Time : 2019/5/27
# @Author : Jiaqi&Zecheng
# @File : __init__.py
# @Software: PyCharm
"""

130
src/rule/graph.py Normal file
@@ -0,0 +1,130 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# -*- coding: utf-8 -*-
"""
# @Time : 2019/5/25
# @Author : Jiaqi&Zecheng
# @File : graph.py
# @Software: PyCharm
"""
from collections import deque, namedtuple
# we'll use infinity as a default distance to nodes.
inf = float('inf')
Edge = namedtuple('Edge', 'start, end, cost')
def make_edge(start, end, cost=1):
return Edge(start, end, cost)
class Graph:
def __init__(self, edges):
# let's check that the data is right
wrong_edges = [i for i in edges if len(i) not in [2, 3]]
if wrong_edges:
raise ValueError('Wrong edges data: {}'.format(wrong_edges))
self.edges = [make_edge(*edge) for edge in edges]
@property
def vertices(self):
return set(
# this piece of magic turns ([1,2], [3,4]) into [1, 2, 3, 4]
# the set above makes its elements unique.
sum(
([edge.start, edge.end] for edge in self.edges), []
)
)
def get_node_pairs(self, n1, n2, both_ends=True):
if both_ends:
node_pairs = [[n1, n2], [n2, n1]]
else:
node_pairs = [[n1, n2]]
return node_pairs
def remove_edge(self, n1, n2, both_ends=True):
node_pairs = self.get_node_pairs(n1, n2, both_ends)
edges = self.edges[:]
for edge in edges:
if [edge.start, edge.end] in node_pairs:
self.edges.remove(edge)
def add_edge(self, n1, n2, cost=1, both_ends=True):
node_pairs = self.get_node_pairs(n1, n2, both_ends)
for edge in self.edges:
if [edge.start, edge.end] in node_pairs:
raise ValueError('Edge {} {} already exists'.format(n1, n2))
self.edges.append(Edge(start=n1, end=n2, cost=cost))
if both_ends:
self.edges.append(Edge(start=n2, end=n1, cost=cost))
@property
def neighbours(self):
neighbours = {vertex: set() for vertex in self.vertices}
for edge in self.edges:
neighbours[edge.start].add((edge.end, edge.cost))
return neighbours
def dijkstra(self, source, dest):
assert source in self.vertices, 'Such source node doesn\'t exist'
assert dest in self.vertices, 'Such destination node doesn\'t exist'
# 1. Mark all nodes unvisited and store them.
# 2. Set the distance to zero for our initial node
# and to infinity for other nodes.
distances = {vertex: inf for vertex in self.vertices}
previous_vertices = {
vertex: None for vertex in self.vertices
}
distances[source] = 0
vertices = self.vertices.copy()
while vertices:
# 3. Select the unvisited node with the smallest distance,
# it becomes the current node.
current_vertex = min(
vertices, key=lambda vertex: distances[vertex])
# 6. Stop, if the smallest distance
# among the unvisited nodes is infinity.
if distances[current_vertex] == inf:
break
# 4. Find unvisited neighbors for the current node
# and calculate their distances through the current node.
for neighbour, cost in self.neighbours[current_vertex]:
alternative_route = distances[current_vertex] + cost
# Compare the newly calculated distance to the assigned
# and save the smaller one.
if alternative_route < distances[neighbour]:
distances[neighbour] = alternative_route
previous_vertices[neighbour] = current_vertex
# 5. Mark the current node as visited
# and remove it from the unvisited set.
vertices.remove(current_vertex)
path, current_vertex = deque(), dest
while previous_vertices[current_vertex] is not None:
path.appendleft(current_vertex)
current_vertex = previous_vertices[current_vertex]
if path:
path.appendleft(current_vertex)
return path
if __name__ == '__main__':
graph = Graph([
("a", "b", 7), ("a", "c", 9), ("a", "f", 14), ("b", "c", 10),
("b", "d", 15), ("c", "d", 11), ("c", "f", 2), ("d", "e", 6),
("e", "f", 9)])
print(graph.dijkstra("a", "e"))

229
src/rule/lf.py Normal file
@@ -0,0 +1,229 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# -*- coding: utf-8 -*-
"""
# @Time : 2019/5/25
# @Author : Jiaqi&Zecheng
# @File : lf.py
# @Software: PyCharm
"""
import copy
import json
import numpy as np
from src.rule import semQL as define_rule
from src.rule.semQL import Sup, Sel, Order, Root, Filter, A, N, C, T, Root1
def _build_single_filter(lf, f):
# No conjunction
agg = lf.pop(0)
column = lf.pop(0)
if len(lf) == 0:
table = None
else:
table = lf.pop(0)
if not isinstance(table, define_rule.T):
lf.insert(0, table)
table = None
assert isinstance(agg, define_rule.A) and isinstance(column, define_rule.C)
if len(f.production.split()) == 3:
f.add_children(agg)
agg.set_parent(f)
agg.add_children(column)
column.set_parent(agg)
if table is not None:
column.add_children(table)
table.set_parent(column)
else:
# Subquery
f.add_children(agg)
agg.set_parent(f)
agg.add_children(column)
column.set_parent(agg)
if table is not None:
column.add_children(table)
table.set_parent(column)
_root = _build(lf)
f.add_children(_root)
_root.set_parent(f)
def _build_filter(lf, root_filter):
assert isinstance(root_filter, define_rule.Filter)
op = root_filter.production.split()[1]
if op == 'and' or op == 'or':
for i in range(2):
child = lf.pop(0)
op = child.production.split()[1]
if op == 'and' or op == 'or':
_f = _build_filter(lf, child)
root_filter.add_children(_f)
_f.set_parent(root_filter)
else:
_build_single_filter(lf, child)
root_filter.add_children(child)
child.set_parent(root_filter)
else:
_build_single_filter(lf, root_filter)
return root_filter
def _build(lf):
root = lf.pop(0)
assert isinstance(root, define_rule.Root)
length = len(root.production.split()) - 1
while len(root.children) != length:
c_instance = lf.pop(0)
if isinstance(c_instance, define_rule.Sel):
sel_instance = c_instance
root.add_children(sel_instance)
sel_instance.set_parent(root)
# define_rule.N
c_instance = lf.pop(0)
c_instance.set_parent(sel_instance)
sel_instance.add_children(c_instance)
assert isinstance(c_instance, define_rule.N)
for i in range(c_instance.id_c + 1):
agg = lf.pop(0)
column = lf.pop(0)
if len(lf) == 0:
table = None
else:
table = lf.pop(0)
if not isinstance(table, define_rule.T):
lf.insert(0, table)
table = None
assert isinstance(agg, define_rule.A) and isinstance(column, define_rule.C)
c_instance.add_children(agg)
agg.set_parent(c_instance)
agg.add_children(column)
column.set_parent(agg)
if table is not None:
column.add_children(table)
table.set_parent(column)
elif isinstance(c_instance, define_rule.Sup) or isinstance(c_instance, define_rule.Order):
root.add_children(c_instance)
c_instance.set_parent(root)
agg = lf.pop(0)
column = lf.pop(0)
if len(lf) == 0:
table = None
else:
table = lf.pop(0)
if not isinstance(table, define_rule.T):
lf.insert(0, table)
table = None
assert isinstance(agg, define_rule.A) and isinstance(column, define_rule.C)
c_instance.add_children(agg)
agg.set_parent(c_instance)
agg.add_children(column)
column.set_parent(agg)
if table is not None:
column.add_children(table)
table.set_parent(column)
elif isinstance(c_instance, define_rule.Filter):
_build_filter(lf, c_instance)
root.add_children(c_instance)
c_instance.set_parent(root)
return root
def build_tree(lf):
root = lf.pop(0)
assert isinstance(root, define_rule.Root1)
if root.id_c == 0 or root.id_c == 1 or root.id_c == 2:
root_1 = _build(lf)
root_2 = _build(lf)
root.add_children(root_1)
root.add_children(root_2)
root_1.set_parent(root)
root_2.set_parent(root)
else:
root_1 = _build(lf)
root.add_children(root_1)
root_1.set_parent(root)
verify(root)
# eliminate_parent(root)
def eliminate_parent(node):
for child in node.children:
eliminate_parent(child)
node.children = list()
def verify(node):
if isinstance(node, C) and len(node.children) > 0:
table = node.children[0]
assert table is None or isinstance(table, T)
if isinstance(node, T):
return
children_num = len(node.children)
if isinstance(node, Root1):
if node.id_c == 0 or node.id_c == 1 or node.id_c == 2:
assert children_num == 2
else:
assert children_num == 1
elif isinstance(node, Root):
assert children_num == len(node.production.split()) - 1
elif isinstance(node, N):
assert children_num == int(node.id_c) + 1
elif isinstance(node, Sup) or isinstance(node, Order) or isinstance(node, Sel):
assert children_num == 1
elif isinstance(node, Filter):
op = node.production.split()[1]
if op == 'and' or op == 'or':
assert children_num == 2
else:
if len(node.production.split()) == 3:
assert children_num == 1
else:
assert children_num == 2
for child in node.children:
assert child.parent == node
verify(child)
def label_matrix(lf, matrix, node):
nindex = lf.index(node)
for child in node.children:
if child not in lf:
continue
index = lf.index(child)
matrix[nindex][index] = 1
label_matrix(lf, matrix, child)
def build_adjacency_matrix(lf, symmetry=False):
_lf = list()
for rule in lf:
if isinstance(rule, A) or isinstance(rule, C) or isinstance(rule, T):
continue
_lf.append(rule)
length = len(_lf)
matrix = np.zeros((length, length,))
label_matrix(_lf, matrix, _lf[0])
if symmetry:
matrix += matrix.T
return matrix
if __name__ == '__main__':
with open('../data/train.json', 'r') as f:
data = json.load(f)
for d in data:
rule_label = [eval(x) for x in d['rule_label'].strip().split(' ')]
print(d['question'])
print(rule_label)
build_tree(copy.copy(rule_label))
adjacency_matrix = build_adjacency_matrix(rule_label, symmetry=True)
print(adjacency_matrix)
print('===\n\n')

399
src/rule/semQL.py Normal file
@@ -0,0 +1,399 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# -*- coding: utf-8 -*-
"""
# @Time : 2019/5/24
# @Author : Jiaqi&Zecheng
# @File : semQL.py
# @Software: PyCharm
"""
Keywords = ['des', 'asc', 'and', 'or', 'sum', 'min', 'max', 'avg', 'none', '=', '!=', '<', '>', '<=', '>=', 'between', 'like', 'not_like'] + [
'in', 'not_in', 'count', 'intersect', 'union', 'except'
]
class Grammar(object):
def __init__(self, is_sketch=False):
self.begin = 0
self.type_id = 0
self.is_sketch = is_sketch
self.prod2id = {}
self.type2id = {}
self._init_grammar(Sel)
self._init_grammar(Root)
self._init_grammar(Sup)
self._init_grammar(Filter)
self._init_grammar(Order)
self._init_grammar(N)
self._init_grammar(Root1)
if not self.is_sketch:
self._init_grammar(A)
self._init_id2prod()
self.type2id[C] = self.type_id
self.type_id += 1
self.type2id[T] = self.type_id
def _init_grammar(self, Cls):
"""
get the production of class Cls
:param Cls:
:return:
"""
production = Cls._init_grammar()
for p in production:
self.prod2id[p] = self.begin
self.begin += 1
self.type2id[Cls] = self.type_id
self.type_id += 1
def _init_id2prod(self):
self.id2prod = {}
for key, value in self.prod2id.items():
self.id2prod[value] = key
def get_production(self, Cls):
return Cls._init_grammar()
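# Hedged usage sketch (added for illustration, not in the original file): Grammar
# simply assigns an integer id to every production string and every action type.
#
#   g = Grammar()
#   pid = g.prod2id['Root1 Root']
#   assert g.id2prod[pid] == 'Root1 Root'
#   assert C in g.type2id and T in g.type2id   # C and T get type ids but no productions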
class Action(object):
def __init__(self):
self.pt = 0
self.production = None
self.children = list()
def get_next_action(self, is_sketch=False):
actions = list()
for x in self.production.split(' ')[1:]:
if x not in Keywords:
rule_type = eval(x)
if is_sketch:
if rule_type is not A:
actions.append(rule_type)
else:
actions.append(rule_type)
return actions
def set_parent(self, parent):
self.parent = parent
def add_children(self, child):
self.children.append(child)
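# Root1 is the root action of every SemQL tree: ids 0-2 combine two Root
# sub-trees with INTERSECT / UNION / EXCEPT, while id 3 wraps a single Root.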
class Root1(Action):
def __init__(self, id_c, parent=None):
super(Root1, self).__init__()
self.parent = parent
self.id_c = id_c
self._init_grammar()
self.production = self.grammar_dict[id_c]
@classmethod
def _init_grammar(self):
# TODO: should add Root grammar to this
self.grammar_dict = {
0: 'Root1 intersect Root Root',
1: 'Root1 union Root Root',
2: 'Root1 except Root Root',
3: 'Root1 Root',
}
self.production_id = {}
for id_x, value in enumerate(self.grammar_dict.values()):
self.production_id[value] = id_x
return self.grammar_dict.values()
def __str__(self):
return 'Root1(' + str(self.id_c) + ')'
def __repr__(self):
return 'Root1(' + str(self.id_c) + ')'
class Root(Action):
def __init__(self, id_c, parent=None):
super(Root, self).__init__()
self.parent = parent
self.id_c = id_c
self._init_grammar()
self.production = self.grammar_dict[id_c]
@classmethod
def _init_grammar(self):
# TODO: should add Root grammar to this
self.grammar_dict = {
0: 'Root Sel Sup Filter',
1: 'Root Sel Filter Order',
2: 'Root Sel Sup',
3: 'Root Sel Filter',
4: 'Root Sel Order',
5: 'Root Sel'
}
self.production_id = {}
for id_x, value in enumerate(self.grammar_dict.values()):
self.production_id[value] = id_x
return self.grammar_dict.values()
def __str__(self):
return 'Root(' + str(self.id_c) + ')'
def __repr__(self):
return 'Root(' + str(self.id_c) + ')'
class N(Action):
"""
Number of columns: N(i) expands into i + 1 aggregated-column (A) children.
"""
def __init__(self, id_c, parent=None):
super(N, self).__init__()
self.parent = parent
self.id_c = id_c
self._init_grammar()
self.production = self.grammar_dict[id_c]
@classmethod
def _init_grammar(self):
self.grammar_dict = {
0: 'N A',
1: 'N A A',
2: 'N A A A',
3: 'N A A A A',
4: 'N A A A A A'
}
self.production_id = {}
for id_x, value in enumerate(self.grammar_dict.values()):
self.production_id[value] = id_x
return self.grammar_dict.values()
def __str__(self):
return 'N(' + str(self.id_c) + ')'
def __repr__(self):
return 'N(' + str(self.id_c) + ')'
class C(Action):
"""
Column
"""
def __init__(self, id_c, parent=None):
super(C, self).__init__()
self.parent = parent
self.id_c = id_c
self.production = 'C T'
self.table = None
def __str__(self):
return 'C(' + str(self.id_c) + ')'
def __repr__(self):
return 'C(' + str(self.id_c) + ')'
class T(Action):
"""
Table
"""
def __init__(self, id_c, parent=None):
super(T, self).__init__()
self.parent = parent
self.id_c = id_c
self.production = 'T min'
self.table = None
def __str__(self):
return 'T(' + str(self.id_c) + ')'
def __repr__(self):
return 'T(' + str(self.id_c) + ')'
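# Unlike the other actions, C and T are leaves: their id_c indexes the
# question's col_set / table list rather than a production id, so Grammar
# registers them in type2id only and they carry fixed placeholder production
# strings ('C T' and 'T min').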
class A(Action):
"""
Aggregator
"""
def __init__(self, id_c, parent=None):
super(A, self).__init__()
self.parent = parent
self.id_c = id_c
self._init_grammar()
self.production = self.grammar_dict[id_c]
@classmethod
def _init_grammar(self):
# TODO: should add Root grammar to this
self.grammar_dict = {
0: 'A none C',
1: 'A max C',
2: "A min C",
3: "A count C",
4: "A sum C",
5: "A avg C"
}
self.production_id = {}
for id_x, value in enumerate(self.grammar_dict.values()):
self.production_id[value] = id_x
return self.grammar_dict.values()
def __str__(self):
return 'A(' + str(self.id_c) + ')'
def __repr__(self):
return 'A(' + str(self.grammar_dict[self.id_c].split(' ')[1]) + ')'
class Sel(Action):
"""
Select
"""
def __init__(self, id_c, parent=None):
super(Sel, self).__init__()
self.parent = parent
self.id_c = id_c
self._init_grammar()
self.production = self.grammar_dict[id_c]
@classmethod
def _init_grammar(self):
self.grammar_dict = {
0: 'Sel N',
}
self.production_id = {}
for id_x, value in enumerate(self.grammar_dict.values()):
self.production_id[value] = id_x
return self.grammar_dict.values()
def __str__(self):
return 'Sel(' + str(self.id_c) + ')'
def __repr__(self):
return 'Sel(' + str(self.id_c) + ')'
class Filter(Action):
"""
Filter (WHERE condition): ids 0-1 are and/or over two Filters, ids 2-10 compare
an A against a literal value, ids 11-17 compare an A against a nested Root
(sub-query), and ids 18-19 are IN / NOT IN against a nested Root.
"""
def __init__(self, id_c, parent=None):
super(Filter, self).__init__()
self.parent = parent
self.id_c = id_c
self._init_grammar()
self.production = self.grammar_dict[id_c]
@classmethod
def _init_grammar(self):
self.grammar_dict = {
# 0: "Filter 1"
0: 'Filter and Filter Filter',
1: 'Filter or Filter Filter',
2: 'Filter = A',
3: 'Filter != A',
4: 'Filter < A',
5: 'Filter > A',
6: 'Filter <= A',
7: 'Filter >= A',
8: 'Filter between A',
9: 'Filter like A',
10: 'Filter not_like A',
# comparisons against a nested Root (sub-query)
11: 'Filter = A Root',
12: 'Filter < A Root',
13: 'Filter > A Root',
14: 'Filter != A Root',
15: 'Filter between A Root',
16: 'Filter >= A Root',
17: 'Filter <= A Root',
# IN / NOT IN against a nested Root (sub-query)
18: 'Filter in A Root',
19: 'Filter not_in A Root'
}
self.production_id = {}
for id_x, value in enumerate(self.grammar_dict.values()):
self.production_id[value] = id_x
return self.grammar_dict.values()
def __str__(self):
return 'Filter(' + str(self.id_c) + ')'
def __repr__(self):
return 'Filter(' + str(self.grammar_dict[self.id_c]) + ')'
class Sup(Action):
"""
Superlative
"""
def __init__(self, id_c, parent=None):
super(Sup, self).__init__()
self.parent = parent
self.id_c = id_c
self._init_grammar()
self.production = self.grammar_dict[id_c]
@classmethod
def _init_grammar(self):
self.grammar_dict = {
0: 'Sup des A',
1: 'Sup asc A',
}
self.production_id = {}
for id_x, value in enumerate(self.grammar_dict.values()):
self.production_id[value] = id_x
return self.grammar_dict.values()
def __str__(self):
return 'Sup(' + str(self.id_c) + ')'
def __repr__(self):
return 'Sup(' + str(self.id_c) + ')'
class Order(Action):
"""
Order
"""
def __init__(self, id_c, parent=None):
super(Order, self).__init__()
self.parent = parent
self.id_c = id_c
self._init_grammar()
self.production = self.grammar_dict[id_c]
@classmethod
def _init_grammar(self):
self.grammar_dict = {
0: 'Order des A',
1: 'Order asc A',
}
self.production_id = {}
for id_x, value in enumerate(self.grammar_dict.values()):
self.production_id[value] = id_x
return self.grammar_dict.values()
def __str__(self):
return 'Order(' + str(self.id_c) + ')'
def __repr__(self):
return 'Order(' + str(self.id_c) + ')'
if __name__ == '__main__':
print(list(Root._init_grammar()))

321
src/rule/sem_utils.py Normal file

@@ -0,0 +1,321 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# -*- coding: utf-8 -*-
"""
# @Time : 2019/5/27
# @Author : Jiaqi&Zecheng
# @File : sem_utils.py
# @Software: PyCharm
"""
import os
import json
import argparse
import re as regex
from nltk.stem import WordNetLemmatizer
from pattern.en import lemma
wordnet_lemmatizer = WordNetLemmatizer()
def load_dataSets(args):
with open(args.input_path, 'r') as f:
datas = json.load(f)
with open(os.path.join(args.data_path, 'tables.json'), 'r', encoding='utf8') as f:
table_datas = json.load(f)
schemas = dict()
for i in range(len(table_datas)):
schemas[table_datas[i]['db_id']] = table_datas[i]
return datas, schemas
def partial_match(query, table_name):
query = [lemma(x) for x in query]
table_name = [lemma(x) for x in table_name]
if query in table_name:
return True
return False
def is_partial_match(query, table_names):
query = lemma(query)
table_names = [[lemma(x) for x in names.split(' ') ] for names in table_names]
same_count = 0
result = None
for names in table_names:
if query in names:
same_count += 1
result = names
return result if same_count == 1 else False
def multi_option(question, q_ind, names, N):
for i in range(q_ind + 1, q_ind + N + 1):
if i < len(question):
re = is_partial_match(question[i][0], names)
if re is not False:
return re
return False
def multi_equal(question, q_ind, names, N):
for i in range(q_ind + 1, q_ind + N + 1):
if i < len(question):
if question[i] == names:
return i
return False
def random_choice(question_arg, question_arg_type, names, ground_col_labels, q_ind, N, origin_name):
# first check whether another table is explicitly mentioned in the question
for t_ind, t_val in enumerate(question_arg_type):
if t_val == ['table']:
return names[origin_name.index(question_arg[t_ind])]
for i in range(q_ind + 1, q_ind + N + 1):
if i < len(question_arg):
if len(ground_col_labels) == 0:
for n in names:
if partial_match(question_arg[i][0], n) is True:
return n
else:
for n_id, n in enumerate(names):
if n_id in ground_col_labels and partial_match(question_arg[i][0], n) is True:
return n
if len(ground_col_labels) > 0:
return names[ground_col_labels[0]]
else:
return names[0]
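# find_table heuristically picks the second table involved in a NOT IN rewrite:
# prefer the last table mentioned in the question that differs from cur_table,
# then a table whose name partially matches an untyped question token, and
# finally fall back to the last table mentioned (or cur_table itself).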
def find_table(cur_table, origin_table_names, question_arg_type, question_arg):
h_table = None
for i in range(len(question_arg_type))[::-1]:
if question_arg_type[i] == ['table']:
h_table = question_arg[i]
h_table = origin_table_names.index(h_table)
if h_table != cur_table:
break
if h_table != cur_table:
return h_table
# find partial
for i in range(len(question_arg_type))[::-1]:
if question_arg_type[i] == ['NONE']:
for t_id, table_name in enumerate(origin_table_names):
if partial_match(question_arg[i], table_name) is True and t_id != h_table:
return t_id
# fall back: return the last table mentioned in the question
for i in range(len(question_arg_type))[::-1]:
if question_arg_type[i] == ['table']:
h_table = question_arg[i]
h_table = origin_table_names.index(h_table)
return h_table
return cur_table
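# alter_not_in post-processes predictions containing Filter(19)
# ('Filter not_in A Root', i.e. a NOT IN sub-query): using the schema's primary
# and foreign keys it rewrites the column/table slots on both sides of the
# sub-query so that the predicate joins the current table and a second table
# (chosen by find_table above) over a valid key pair.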
def alter_not_in(datas, schemas):
for d in datas:
if 'Filter(19)' in d['model_result']:
current_table = schemas[d['db_id']]
current_table['schema_content_clean'] = [x[1] for x in current_table['column_names']]
current_table['col_table'] = [col[0] for col in current_table['column_names']]
origin_table_names = [[wordnet_lemmatizer.lemmatize(x.lower()) for x in names.split(' ')] for names in
d['table_names']]
question_arg_type = d['question_arg_type']
question_arg = d['question_arg']
pred_label = d['model_result'].split(' ')
# get the potential table
cur_table = None
for label_id, label_val in enumerate(pred_label):
if label_val in ['Filter(19)']:
cur_table = int(pred_label[label_id - 1][2:-1])
break
h_table = find_table(cur_table, origin_table_names, question_arg_type, question_arg)
for label_id, label_val in enumerate(pred_label):
if label_val in ['Filter(19)']:
for primary in current_table['primary_keys']:
if int(current_table['col_table'][primary]) == int(pred_label[label_id - 1][2:-1]):
pred_label[label_id + 2] = 'C(' + str(
d['col_set'].index(current_table['schema_content_clean'][primary])) + ')'
break
for pair in current_table['foreign_keys']:
if int(current_table['col_table'][pair[0]]) == h_table and d['col_set'].index(
current_table['schema_content_clean'][pair[1]]) == int(pred_label[label_id + 2][2:-1]):
pred_label[label_id + 8] = 'C(' + str(
d['col_set'].index(current_table['schema_content_clean'][pair[0]])) + ')'
pred_label[label_id + 9] = 'T(' + str(h_table) + ')'
break
elif int(current_table['col_table'][pair[1]]) == h_table and d['col_set'].index(
current_table['schema_content_clean'][pair[0]]) == int(pred_label[label_id + 2][2:-1]):
pred_label[label_id + 8] = 'C(' + str(
d['col_set'].index(current_table['schema_content_clean'][pair[1]])) + ')'
pred_label[label_id + 9] = 'T(' + str(h_table) + ')'
break
pred_label[label_id + 3] = pred_label[label_id - 1]
d['model_result'] = " ".join(pred_label)
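# alter_inter rewrites predictions that open a conjunction with Filter(0)
# ('Filter and Filter Filter'): when the first two column tokens after
# Filter(0) refer to the same column, the AND is split into Root1(0)
# (INTERSECT) over two copies of the query, one per conjunct.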
def alter_inter(datas):
for d in datas:
if 'Filter(0)' in d['model_result']:
now_result = d['model_result'].split(' ')
index = now_result.index('Filter(0)')
c1 = None
c2 = None
for i in range(index + 1, len(now_result)):
if c1 is None and 'C(' in now_result[i]:
c1 = now_result[i]
elif c1 is not None and c2 is None and 'C(' in now_result[i]:
c2 = now_result[i]
if c1 != c2 or c1 is None or c2 is None:
continue
replace_result = ['Root1(0)'] + now_result[1:now_result.index('Filter(0)')]
for r_id, r_val in enumerate(now_result[now_result.index('Filter(0)') + 2:]):
if 'Filter' in r_val:
break
replace_result = replace_result + now_result[now_result.index('Filter(0)') + 1:r_id + now_result.index(
'Filter(0)') + 2]
replace_result = replace_result + now_result[1:now_result.index('Filter(0)')]
replace_result = replace_result + now_result[r_id + now_result.index('Filter(0)') + 2:]
replace_result = " ".join(replace_result)
d['model_result'] = replace_result
def alter_column0(datas):
"""
Attach column * table
:return: model_result_replace
"""
zero_count = 0
count = 0
result = []
for d in datas:
if 'C(0)' in d['model_result']:
pattern = regex.compile('C\(.*?\) T\(.*?\)')
result_pattern = list(set(pattern.findall(d['model_result'])))
ground_col_labels = []
for pa in result_pattern:
pa = pa.split(' ')
if pa[0] != 'C(0)':
index = int(pa[1][2:-1])
ground_col_labels.append(index)
ground_col_labels = list(set(ground_col_labels))
question_arg_type = d['question_arg_type']
question_arg = d['question_arg']
table_names = [[lemma(x) for x in names.split(' ')] for names in d['table_names']]
origin_table_names = [[wordnet_lemmatizer.lemmatize(x.lower()) for x in names.split(' ')] for names in
d['table_names']]
count += 1
easy_flag = False
for q_ind, q in enumerate(d['question_arg']):
q = [lemma(x) for x in q]
q_str = " ".join(" ".join(x) for x in d['question_arg'])
if 'how many' in q_str or 'number of' in q_str or 'count of' in q_str:
easy_flag = True
if easy_flag:
# check whether a table word follows the counting phrase
for q_ind, q in enumerate(d['question_arg']):
if (q_ind > 0 and q == ['many'] and d['question_arg'][q_ind - 1] == ['how']) or (
q_ind > 0 and q == ['of'] and d['question_arg'][q_ind - 1] == ['number']) or (
q_ind > 0 and q == ['of'] and d['question_arg'][q_ind - 1] == ['count']):
re = multi_equal(question_arg_type, q_ind, ['table'], 2)
if re is not False:
# this step handles the 'number of [table]' pattern
table_result = table_names[origin_table_names.index(question_arg[re])]
result.append((d['query'], d['question'], table_result, d))
break
else:
re = multi_option(question_arg, q_ind, d['table_names'], 2)
if re is not False:
table_result = re
result.append((d['query'], d['question'], table_result, d))
pass
else:
re = multi_equal(question_arg_type, q_ind, ['table'], len(question_arg_type))
if re is not False:
# this step handles the 'number of [table]' pattern
table_result = table_names[origin_table_names.index(question_arg[re])]
result.append((d['query'], d['question'], table_result, d))
break
pass
table_result = random_choice(question_arg=question_arg,
question_arg_type=question_arg_type,
names=table_names,
ground_col_labels=ground_col_labels, q_ind=q_ind, N=2,
origin_name=origin_table_names)
result.append((d['query'], d['question'], table_result, d))
zero_count += 1
break
else:
M_OP = False
for q_ind, q in enumerate(d['question_arg']):
if M_OP is False and q in [['than'], ['least'], ['most'], ['msot'], ['fewest']] or \
question_arg_type[q_ind] == ['M_OP']:
M_OP = True
re = multi_equal(question_arg_type, q_ind, ['table'], 3)
if re is not False:
# this step handles the 'number of [table]' pattern
table_result = table_names[origin_table_names.index(question_arg[re])]
result.append((d['query'], d['question'], table_result, d))
break
else:
re = multi_option(question_arg, q_ind, d['table_names'], 3)
if re is not False:
table_result = re
# print(table_result)
result.append((d['query'], d['question'], table_result, d))
pass
else:
# zero_count += 1
re = multi_equal(question_arg_type, q_ind, ['table'], len(question_arg_type))
if re is not False:
# this step handles the 'number of [table]' pattern
table_result = table_names[origin_table_names.index(question_arg[re])]
result.append((d['query'], d['question'], table_result, d))
break
table_result = random_choice(question_arg=question_arg,
question_arg_type=question_arg_type,
names=table_names,
ground_col_labels=ground_col_labels, q_ind=q_ind, N=2,
origin_name=origin_table_names)
result.append((d['query'], d['question'], table_result, d))
pass
if M_OP is False:
table_result = random_choice(question_arg=question_arg,
question_arg_type=question_arg_type,
names=table_names, ground_col_labels=ground_col_labels, q_ind=q_ind,
N=2,
origin_name=origin_table_names)
result.append((d['query'], d['question'], table_result, d))
for re in result:
table_names = [[lemma(x) for x in names.split(' ')] for names in re[3]['table_names']]
origin_table_names = [[x for x in names.split(' ')] for names in re[3]['table_names']]
if re[2] in table_names:
re[3]['rule_count'] = table_names.index(re[2])
else:
re[3]['rule_count'] = origin_table_names.index(re[2])
for data in datas:
if 'rule_count' in data:
str_replace = 'C(0) T(' + str(data['rule_count']) + ')'
replace_result = regex.sub('C\(0\) T\(.\)', str_replace, data['model_result'])
data['model_result_replace'] = replace_result
else:
data['model_result_replace'] = data['model_result']

348
src/utils.py Normal file

@@ -0,0 +1,348 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# -*- coding: utf-8 -*-
"""
# @Time : 2019/5/25
# @Author : Jiaqi&Zecheng
# @File : utils.py
# @Software: PyCharm
"""
import json
import time
import copy
import numpy as np
import os
import torch
from nltk.stem import WordNetLemmatizer
from src.dataset import Example
from src.rule import lf
from src.rule.semQL import Sup, Sel, Order, Root, Filter, A, N, C, T, Root1
wordnet_lemmatizer = WordNetLemmatizer()
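# load_word_emb reads a GloVe-style text file into a dict mapping each word to
# its embedding vector (np.ndarray); with use_small=True only the first 500k
# lines are read, for quick experiments.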
def load_word_emb(file_name, use_small=False):
print ('Loading word embedding from %s'%file_name)
ret = {}
with open(file_name) as inf:
for idx, line in enumerate(inf):
if (use_small and idx >= 500000):
break
info = line.strip().split(' ')
if info[0].lower() not in ret:
ret[info[0]] = np.array(list(map(lambda x:float(x), info[1:])))
return ret
def lower_keys(x):
if isinstance(x, list):
return [lower_keys(v) for v in x]
elif isinstance(x, dict):
return dict((k.lower(), lower_keys(v)) for k, v in x.items())
else:
return x
def get_table_colNames(tab_ids, tab_cols):
table_col_dict = {}
for ci, cv in zip(tab_ids, tab_cols):
if ci != -1:
table_col_dict[ci] = table_col_dict.get(ci, []) + cv
result = []
for ci in range(len(table_col_dict)):
result.append(table_col_dict[ci])
return result
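# get_col_table_dict maps each column id in sql['col_set'] to the ids of the
# tables that contain a column with that name; the '*' column (id 0) is then
# mapped to the whole range of table ids so it can attach to any table.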
def get_col_table_dict(tab_cols, tab_ids, sql):
table_dict = {}
for c_id, c_v in enumerate(sql['col_set']):
for cor_id, cor_val in enumerate(tab_cols):
if c_v == cor_val:
table_dict[tab_ids[cor_id]] = table_dict.get(tab_ids[cor_id], []) + [c_id]
col_table_dict = {}
for key_item, value_item in table_dict.items():
for value in value_item:
col_table_dict[value] = col_table_dict.get(value, []) + [key_item]
col_table_dict[0] = [x for x in range(len(table_dict) - 1)]
return col_table_dict
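# schema_linking records, per question token, the schema-linking tag it received
# during preprocessing (table / column / agg / MORE / MOST / value) as a one-hot
# feature in one_hot_type, bumps the matching column's entry in col_set_type for
# exact column mentions and for value matches, and prefixes the token with a
# 'table' / 'column' / 'value' marker so the encoder can see the link.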
def schema_linking(question_arg, question_arg_type, one_hot_type, col_set_type, col_set_iter, sql):
for count_q, t_q in enumerate(question_arg_type):
t = t_q[0]
if t == 'NONE':
continue
elif t == 'table':
one_hot_type[count_q][0] = 1
question_arg[count_q] = ['table'] + question_arg[count_q]
elif t == 'col':
one_hot_type[count_q][1] = 1
try:
col_set_type[col_set_iter.index(question_arg[count_q])][1] = 5
question_arg[count_q] = ['column'] + question_arg[count_q]
except:
print(col_set_iter, question_arg[count_q])
raise RuntimeError("not in col set")
elif t == 'agg':
one_hot_type[count_q][2] = 1
elif t == 'MORE':
one_hot_type[count_q][3] = 1
elif t == 'MOST':
one_hot_type[count_q][4] = 1
elif t == 'value':
one_hot_type[count_q][5] = 1
question_arg[count_q] = ['value'] + question_arg[count_q]
else:
if len(t_q) == 1:
for col_probase in t_q:
if col_probase == 'asd':
continue
try:
col_set_type[sql['col_set'].index(col_probase)][2] = 5
question_arg[count_q] = ['value'] + question_arg[count_q]
except:
print(sql['col_set'], col_probase)
raise RuntimeError('not in col')
one_hot_type[count_q][5] = 1
else:
for col_probase in t_q:
if col_probase == 'asd':
continue
col_set_type[sql['col_set'].index(col_probase)][3] += 1
def process(sql, table):
process_dict = {}
origin_sql = sql['question_toks']
table_names = [[wordnet_lemmatizer.lemmatize(v).lower() for v in x.split(' ')] for x in table['table_names']]
sql['pre_sql'] = copy.deepcopy(sql)
tab_cols = [col[1] for col in table['column_names']]
tab_ids = [col[0] for col in table['column_names']]
col_set_iter = [[wordnet_lemmatizer.lemmatize(v).lower() for v in x.split(' ')] for x in sql['col_set']]
col_iter = [[wordnet_lemmatizer.lemmatize(v).lower() for v in x.split(" ")] for x in tab_cols]
q_iter_small = [wordnet_lemmatizer.lemmatize(x).lower() for x in origin_sql]
question_arg = copy.deepcopy(sql['question_arg'])
question_arg_type = sql['question_arg_type']
one_hot_type = np.zeros((len(question_arg_type), 6))
col_set_type = np.zeros((len(col_set_iter), 4))
process_dict['col_set_iter'] = col_set_iter
process_dict['q_iter_small'] = q_iter_small
process_dict['col_set_type'] = col_set_type
process_dict['question_arg'] = question_arg
process_dict['question_arg_type'] = question_arg_type
process_dict['one_hot_type'] = one_hot_type
process_dict['tab_cols'] = tab_cols
process_dict['tab_ids'] = tab_ids
process_dict['col_iter'] = col_iter
process_dict['table_names'] = table_names
return process_dict
def is_valid(rule_label, col_table_dict, sql):
try:
lf.build_tree(copy.copy(rule_label))
except:
print(rule_label)
flag = False
for r_id, rule in enumerate(rule_label):
if type(rule) == C:
try:
assert rule_label[r_id + 1].id_c in col_table_dict[rule.id_c], print(sql['question'])
except:
flag = True
print(sql['question'])
return flag is False
def to_batch_seq(sql_data, table_data, idxes, st, ed,
is_train=True):
"""
:return:
"""
examples = []
for i in range(st, ed):
sql = sql_data[idxes[i]]
table = table_data[sql['db_id']]
process_dict = process(sql, table)
for c_id, col_ in enumerate(process_dict['col_set_iter']):
for q_id, ori in enumerate(process_dict['q_iter_small']):
if ori in col_:
process_dict['col_set_type'][c_id][0] += 1
schema_linking(process_dict['question_arg'], process_dict['question_arg_type'],
process_dict['one_hot_type'], process_dict['col_set_type'], process_dict['col_set_iter'], sql)
col_table_dict = get_col_table_dict(process_dict['tab_cols'], process_dict['tab_ids'], sql)
table_col_name = get_table_colNames(process_dict['tab_ids'], process_dict['col_iter'])
process_dict['col_set_iter'][0] = ['count', 'number', 'many']
rule_label = None
if 'rule_label' in sql:
rule_label = [eval(x) for x in sql['rule_label'].strip().split(' ')]
if is_valid(rule_label, col_table_dict=col_table_dict, sql=sql) is False:
continue
example = Example(
src_sent=process_dict['question_arg'],
col_num=len(process_dict['col_set_iter']),
vis_seq=(sql['question'], process_dict['col_set_iter'], sql['query']),
tab_cols=process_dict['col_set_iter'],
sql=sql['query'],
one_hot_type=process_dict['one_hot_type'],
col_hot_type=process_dict['col_set_type'],
table_names=process_dict['table_names'],
table_len=len(process_dict['table_names']),
col_table_dict=col_table_dict,
cols=process_dict['tab_cols'],
table_col_name=table_col_name,
table_col_len=len(table_col_name),
tokenized_src_sent=process_dict['col_set_type'],
tgt_actions=rule_label
)
example.sql_json = copy.deepcopy(sql)
examples.append(example)
if is_train:
examples.sort(key=lambda e: -len(e.src_sent))
return examples
else:
return examples
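# epoch_train runs one training epoch over a shuffled permutation of sql_data.
# model.forward returns a sketch score and a full logical-form score per batch;
# both become negative log-likelihood losses, added with equal weight while
# epoch <= loss_epoch_threshold and with the sketch term scaled by
# sketch_loss_coefficient afterwards. Gradients are optionally clipped before
# the optimizer step.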
def epoch_train(model, optimizer, batch_size, sql_data, table_data,
args, epoch=0, loss_epoch_threshold=20, sketch_loss_coefficient=0.2):
model.train()
# shuffle the training examples
perm=np.random.permutation(len(sql_data))
cum_loss = 0.0
st = 0
while st < len(sql_data):
ed = st+batch_size if st+batch_size < len(perm) else len(perm)
examples = to_batch_seq(sql_data, table_data, perm, st, ed)
optimizer.zero_grad()
score = model.forward(examples)
loss_sketch = -score[0]
loss_lf = -score[1]
loss_sketch = torch.mean(loss_sketch)
loss_lf = torch.mean(loss_lf)
if epoch > loss_epoch_threshold:
loss = loss_lf + sketch_loss_coefficient * loss_sketch
else:
loss = loss_lf + loss_sketch
loss.backward()
if args.clip_grad > 0.:
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
optimizer.step()
cum_loss += loss.data.cpu().numpy()*(ed - st)
st = ed
return cum_loss / len(sql_data)
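# epoch_acc decodes the evaluation set with beam search and stores, per example,
# the best predicted action sequence in 'model_result' and the predicted sketch
# in 'sketch_result' (model_result is left empty when decoding fails).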
def epoch_acc(model, batch_size, sql_data, table_data, beam_size=3):
model.eval()
perm = list(range(len(sql_data)))
st = 0
json_datas = []
while st < len(sql_data):
ed = st+batch_size if st+batch_size < len(perm) else len(perm)
examples = to_batch_seq(sql_data, table_data, perm, st, ed,
is_train=False)
for example in examples:
results_all = model.parse(example, beam_size=beam_size)
results = results_all[0]
list_preds = []
try:
pred = " ".join([str(x) for x in results[0].actions])
for x in results:
list_preds.append(" ".join(str(x.actions)))
except Exception as e:
# print('Epoch Acc: ', e)
# print(results)
# print(results_all)
pred = ""
simple_json = example.sql_json['pre_sql']
simple_json['sketch_result'] = " ".join(str(x) for x in results_all[1])
simple_json['model_result'] = pred
json_datas.append(simple_json)
st = ed
return json_datas
def eval_acc(preds, sqls):
sketch_correct, best_correct = 0, 0
for i, (pred, sql) in enumerate(zip(preds, sqls)):
if pred['model_result'] == sql['rule_label']:
best_correct += 1
print(best_correct / len(preds))
return best_correct / len(preds)
def load_data_new(sql_path, table_data, use_small=False):
sql_data = []
print("Loading data from %s" % sql_path)
with open(sql_path) as inf:
data = lower_keys(json.load(inf))
sql_data += data
table_data_new = {table['db_id']: table for table in table_data}
if use_small:
return sql_data[:80], table_data_new
else:
return sql_data, table_data_new
def load_dataset(dataset_dir, use_small=False):
print("Loading from datasets...")
TABLE_PATH = os.path.join(dataset_dir, "tables.json")
TRAIN_PATH = os.path.join(dataset_dir, "train.json")
DEV_PATH = os.path.join(dataset_dir, "dev.json")
with open(TABLE_PATH) as inf:
print("Loading data from %s"%TABLE_PATH)
table_data = json.load(inf)
train_sql_data, train_table_data = load_data_new(TRAIN_PATH, table_data, use_small=use_small)
val_sql_data, val_table_data = load_data_new(DEV_PATH, table_data, use_small=use_small)
return train_sql_data, train_table_data, val_sql_data, val_table_data
def save_checkpoint(model, checkpoint_name):
torch.save(model.state_dict(), checkpoint_name)
def save_args(args, path):
with open(path, 'w') as f:
f.write(json.dumps(vars(args), indent=4))
def init_log_checkpoint_path(args):
save_path = args.save
dir_name = save_path + str(int(time.time()))
save_path = os.path.join(os.path.curdir, 'saved_model', dir_name)
if os.path.exists(save_path) is False:
os.makedirs(save_path)
return save_path

117
train.py Normal file

@@ -0,0 +1,117 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# -*- coding: utf-8 -*-
"""
# @Time : 2019/5/25
# @Author : Jiaqi&Zecheng
# @File : train.py
# @Software: PyCharm
"""
import time
import traceback
import os
import torch
import torch.optim as optim
import tqdm
import copy
from src import args as arg
from src import utils
from src.models.model import IRNet
from src.rule import semQL
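# train() builds the SemQL grammar and the IRNet model, optionally warm-starts
# from a pretrained checkpoint (parameters unknown to the current model are
# dropped before load_state_dict), then trains with an optional MultiStepLR
# schedule, keeping the best dev-accuracy weights in best_model.model and a
# final snapshot in end_model.model.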
def train(args):
"""
:param args:
:return:
"""
grammar = semQL.Grammar()
sql_data, table_data, val_sql_data,\
val_table_data= utils.load_dataset(args.dataset, use_small=args.toy)
model = IRNet(args, grammar)
if args.cuda: model.cuda()
# now get the optimizer
optimizer_cls = eval('torch.optim.%s' % args.optimizer)
optimizer = optimizer_cls(model.parameters(), lr=args.lr)
print('Enable Learning Rate Scheduler: ', args.lr_scheduler)
if args.lr_scheduler:
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[21, 41], gamma=args.lr_scheduler_gammar)
else:
scheduler = None
print('Loss epoch threshold: %d' % args.loss_epoch_threshold)
print('Sketch loss coefficient: %f' % args.sketch_loss_coefficient)
if args.load_model:
print('load pretrained model from %s'% (args.load_model))
pretrained_model = torch.load(args.load_model,
map_location=lambda storage, loc: storage)
pretrained_modeled = copy.deepcopy(pretrained_model)
for k in pretrained_model.keys():
if k not in model.state_dict().keys():
del pretrained_modeled[k]
model.load_state_dict(pretrained_modeled)
model.word_emb = utils.load_word_emb(args.glove_embed_path)
# begin training
model_save_path = utils.init_log_checkpoint_path(args)
utils.save_args(args, os.path.join(model_save_path, 'config.json'))
best_dev_acc = .0
try:
with open(os.path.join(model_save_path, 'epoch.log'), 'w') as epoch_fd:
for epoch in tqdm.tqdm(range(args.epoch)):
if args.lr_scheduler:
scheduler.step()
epoch_begin = time.time()
loss = utils.epoch_train(model, optimizer, args.batch_size, sql_data, table_data, args,
loss_epoch_threshold=args.loss_epoch_threshold,
sketch_loss_coefficient=args.sketch_loss_coefficient)
epoch_end = time.time()
json_datas = utils.epoch_acc(model, args.batch_size, val_sql_data, val_table_data,
beam_size=args.beam_size)
acc = utils.eval_acc(json_datas, val_sql_data)
if acc > best_dev_acc:
utils.save_checkpoint(model, os.path.join(model_save_path, 'best_model.model'))
best_dev_acc = acc
utils.save_checkpoint(model, os.path.join(model_save_path, '{%s}_{%s}.model') % (epoch, acc))
log_str = 'Epoch: %d, Loss: %f, Sketch Acc: %f, Acc: %f, time: %f\n' % (
epoch + 1, loss, acc, acc, epoch_end - epoch_begin)
tqdm.tqdm.write(log_str)
epoch_fd.write(log_str)
epoch_fd.flush()
except Exception as e:
# Save model
utils.save_checkpoint(model, os.path.join(model_save_path, 'end_model.model'))
print(e)
tb = traceback.format_exc()
print(tb)
else:
utils.save_checkpoint(model, os.path.join(model_save_path, 'end_model.model'))
json_datas = utils.epoch_acc(model, args.batch_size, val_sql_data, val_table_data,
beam_size=args.beam_size)
acc = utils.eval_acc(json_datas, val_sql_data)
print("Sketch Acc: %f, Acc: %f, Beam Acc: %f" % (acc, acc, acc,))
if __name__ == '__main__':
arg_parser = arg.init_arg_parser()
args = arg.init_config(arg_parser)
print(args)
train(args)

24
train.sh Normal file

@@ -0,0 +1,24 @@
#!/bin/bash
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
devices=$1
save_name=$2
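# $1 selects the GPU(s) exposed through CUDA_VISIBLE_DEVICES; $2 is the save
# folder passed to --save and also names the nohup log file (${save_name}.log).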
CUDA_VISIBLE_DEVICES=$devices nohup python -u train.py --dataset ./data \
--glove_embed_path ./data/glove.42B.300d.txt \
--cuda \
--epoch 50 \
--loss_epoch_threshold 50 \
--sketch_loss_coefficient 1.0 \
--beam_size 1 \
--seed 90 \
--save ${save_name} \
--embed_size 300 \
--sentence_features \
--column_pointer \
--hidden_size 300 \
--lr_scheduler \
--lr_scheduler_gammar 0.5 \
--att_vec_size 300 > ${save_name}".log" &