From de4c3d02ca331fa5c2ef2e2334c86474a6117108 Mon Sep 17 00:00:00 2001 From: Olive Date: Fri, 1 Nov 2019 10:26:20 +0800 Subject: [PATCH] add files --- README.md | 88 ++++- __init__.py | 10 + eval.py | 58 +++ eval.sh | 28 ++ preprocess/data_process.py | 219 +++++++++++ preprocess/download_nltk.py | 15 + preprocess/run_me.sh | 19 + preprocess/sql2SemQL.py | 391 ++++++++++++++++++ preprocess/utils.py | 173 ++++++++ requirements.txt | 8 + sem2SQL.py | 697 ++++++++++++++++++++++++++++++++ src/__init__.py | 10 + src/args.py | 82 ++++ src/beam.py | 206 ++++++++++ src/dataset.py | 140 +++++++ src/models/__init__.py | 10 + src/models/basic_model.py | 160 ++++++++ src/models/model.py | 764 ++++++++++++++++++++++++++++++++++++ src/models/nn_utils.py | 236 +++++++++++ src/models/pointer_net.py | 106 +++++ src/rule/__init__.py | 10 + src/rule/graph.py | 130 ++++++ src/rule/lf.py | 229 +++++++++++ src/rule/semQL.py | 399 +++++++++++++++++++ src/rule/sem_utils.py | 321 +++++++++++++++ src/utils.py | 348 ++++++++++++++++ train.py | 117 ++++++ train.sh | 24 ++ 28 files changed, 4990 insertions(+), 8 deletions(-) create mode 100644 __init__.py create mode 100644 eval.py create mode 100644 eval.sh create mode 100644 preprocess/data_process.py create mode 100644 preprocess/download_nltk.py create mode 100644 preprocess/run_me.sh create mode 100644 preprocess/sql2SemQL.py create mode 100644 preprocess/utils.py create mode 100644 requirements.txt create mode 100644 sem2SQL.py create mode 100644 src/__init__.py create mode 100644 src/args.py create mode 100644 src/beam.py create mode 100644 src/dataset.py create mode 100644 src/models/__init__.py create mode 100644 src/models/basic_model.py create mode 100644 src/models/model.py create mode 100644 src/models/nn_utils.py create mode 100644 src/models/pointer_net.py create mode 100644 src/rule/__init__.py create mode 100644 src/rule/graph.py create mode 100644 src/rule/lf.py create mode 100644 src/rule/semQL.py create mode 100644 src/rule/sem_utils.py create mode 100644 src/utils.py create mode 100644 train.py create mode 100644 train.sh diff --git a/README.md b/README.md index b81a84e..8137eae 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,86 @@ +# IRNet +Code for our ACL'19 accepted paper: [Towards Complex Text-to-SQL in Cross-Domain Database with Intermediate Representation](https://arxiv.org/pdf/1905.08205.pdf) + +

+ +

+ +## Environment Setup + +* `Python3.6` +* `Pytorch 0.4.0` or higher + +Install the Python dependencies via `pip install -r requirements.txt` once the Python and PyTorch environment is set up. + +## Running Code + +#### Data preparation + +* Download the [Glove Embedding](https://nlp.stanford.edu/data/wordvecs/glove.42B.300d.zip) and put `glove.42B.300d` under the `./data/` directory +* Download the [Pretrained IRNet](https://drive.google.com/open?id=1VoV28fneYss8HaZmoThGlvYU3A-aK31q) and put `IRNet_pretrained.model` under the `./saved_model/` directory +* Download the preprocessed train/dev datasets from [here](https://drive.google.com/open?id=1YFV1GoLivOMlmunKW0nkzefKULO4wtrn) and put `train.json`, `dev.json` and `tables.json` under the `./data/` directory + +##### Generating train/dev data by yourself +You can also process the original [Spider Data](https://drive.google.com/uc?export=download&id=11icoH_EA-NYb0OrPTdehRWm_d7-DIzWX) yourself. Download it, put `train.json`, `dev.json` and `tables.json` under the `./data/` directory, and follow the instructions in `./preprocess/`. + +#### Training + +Run `train.sh` to train IRNet. + +`sh train.sh [GPU_ID] [SAVE_FOLD]` + +#### Testing + +Run `eval.sh` to evaluate IRNet. + +`sh eval.sh [GPU_ID] [OUTPUT_FOLD]` + +#### Evaluation + +You can follow the general evaluation process described on the [Spider page](https://github.com/taoyds/spider). + +## Results +| **Model** | Dev
Exact Set Match Accuracy | Test Exact Set Match
Accuracy | +| ----------- | ------------------------------------- | -------------------------------------- | +| IRNet | 53.2 | 46.7 | +| IRNet+BERT(base) | 61.9 | **54.7** | + + +## Citation + +If you use IRNet, please cite the following work. + +``` +@article{GuoIRNet2019, + author={Jiaqi Guo and Zecheng Zhan and Yan Gao and Yan Xiao and Jian-Guang Lou and Ting Liu and Dongmei Zhang}, + title={Towards Complex Text-to-SQL in Cross-Domain Database with Intermediate Representation}, + journal={arXiv preprint arXiv:1905.08205}, + year={2019}, + note={version 1} +} +``` + +## Thanks +We would like to thank [Tao Yu](https://taoyds.github.io/) and [Bo Pang](https://www.linkedin.com/in/bo-pang/) for running evaluations on our submitted models. +We are also grateful to the flexible semantic parser [TranX](https://github.com/pcyin/tranX) that inspires our works. # Contributing -This project welcomes contributions and suggestions. Most contributions require you to agree to a -Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us -the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. +This project welcomes contributions and suggestions. Most contributions require you to +agree to a Contributor License Agreement (CLA) declaring that you have the right to, +and actually do, grant us the rights to use your contribution. For details, visit +https://cla.microsoft.com. -When you submit a pull request, a CLA bot will automatically determine whether you need to provide -a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions -provided by the bot. You will only need to do this once across all repos using our CLA. +When you submit a pull request, a CLA-bot will automatically determine whether you need +to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the +instructions provided by the bot. You will only need to do this once across all repositories using our CLA. This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). -For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or -contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. +For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) +or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. \ No newline at end of file diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..a668d78 --- /dev/null +++ b/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +# -*- coding: utf-8 -*- +""" +# @Time : 2019/5/25 +# @Author : Jiaqi&Zecheng +# @File : __init__.py +# @Software: PyCharm +""" \ No newline at end of file diff --git a/eval.py b/eval.py new file mode 100644 index 0000000..03346d6 --- /dev/null +++ b/eval.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
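+# eval.py loads a pretrained IRNet checkpoint, decodes SemQL logical forms for the dev set with beam search,
+# and writes the predictions to ./predict_lf.json, which sem2SQL.py later converts into SQL (see eval.sh).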
+ +# -*- coding: utf-8 -*- +""" +# @Time : 2019/5/27 +# @Author : Jiaqi&Zecheng +# @File : eval.py +# @Software: PyCharm +""" + + +import torch +from src import args as arg +from src import utils +from src.models.model import IRNet +from src.rule import semQL + + +def evaluate(args): + """ + :param args: + :return: + """ + + grammar = semQL.Grammar() + sql_data, table_data, val_sql_data,\ + val_table_data= utils.load_dataset(args.dataset, use_small=args.toy) + + model = IRNet(args, grammar) + + if args.cuda: model.cuda() + + print('load pretrained model from %s'% (args.load_model)) + pretrained_model = torch.load(args.load_model, + map_location=lambda storage, loc: storage) + import copy + pretrained_modeled = copy.deepcopy(pretrained_model) + for k in pretrained_model.keys(): + if k not in model.state_dict().keys(): + del pretrained_modeled[k] + + model.load_state_dict(pretrained_modeled) + + model.word_emb = utils.load_word_emb(args.glove_embed_path) + + json_datas = utils.epoch_acc(model, args.batch_size, val_sql_data, val_table_data, + beam_size=args.beam_size) + # utils.eval_acc(json_datas, val_sql_data) + import json + with open('./predict_lf.json', 'w') as f: + json.dump(json_datas, f) + +if __name__ == '__main__': + arg_parser = arg.init_arg_parser() + args = arg.init_config(arg_parser) + print(args) + evaluate(args) \ No newline at end of file diff --git a/eval.sh b/eval.sh new file mode 100644 index 0000000..362682b --- /dev/null +++ b/eval.sh @@ -0,0 +1,28 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +#!/bin/bash + +devices=$1 +save_name=$2 + +CUDA_VISIBLE_DEVICES=$devices python -u eval.py --dataset ./data \ +--glove_embed_path ./data/glove.42B.300d.txt \ +--cuda \ +--epoch 50 \ +--loss_epoch_threshold 50 \ +--sketch_loss_coefficie 1.0 \ +--beam_size 5 \ +--seed 90 \ +--save ${save_name} \ +--embed_size 300 \ +--sentence_features \ +--column_pointer \ +--hidden_size 300 \ +--lr_scheduler \ +--lr_scheduler_gammar 0.5 \ +--att_vec_size 300 \ +--load_model ./saved_model/IRNet_pretrained.model + +python sem2SQL.py --data_path ./data --input_path predict_lf.json --output_path ${save_name} + diff --git a/preprocess/data_process.py b/preprocess/data_process.py new file mode 100644 index 0000000..43efb8b --- /dev/null +++ b/preprocess/data_process.py @@ -0,0 +1,219 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
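+# data_process.py tags each question token with a coarse type (column, table, agg, value, MORE/MOST, ...) by
+# matching tokens against the schema headers, NLTK POS tags and ConceptNet relations; the annotations are
+# stored in the 'question_arg', 'question_arg_type' and 'nltk_pos' fields of every example.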
+ +# -*- coding: utf-8 -*- +""" +# @Time : 2019/5/24 +# @Author : Jiaqi&Zecheng +# @File : data_process.py +# @Software: PyCharm +""" +import json +import argparse +import nltk +import os +import pickle +from utils import symbol_filter, re_lemma, fully_part_header, group_header, partial_header, num2year, group_symbol, group_values, group_digital +from utils import AGG, wordnet_lemmatizer +from utils import load_dataSets + +def process_datas(datas, args): + """ + + :param datas: + :param args: + :return: + """ + with open(os.path.join(args.conceptNet, 'english_RelatedTo.pkl'), 'rb') as f: + english_RelatedTo = pickle.load(f) + + with open(os.path.join(args.conceptNet, 'english_IsA.pkl'), 'rb') as f: + english_IsA = pickle.load(f) + + # copy of the origin question_toks + for d in datas: + if 'origin_question_toks' not in d: + d['origin_question_toks'] = d['question_toks'] + + for entry in datas: + entry['question_toks'] = symbol_filter(entry['question_toks']) + origin_question_toks = symbol_filter([x for x in entry['origin_question_toks'] if x.lower() != 'the']) + question_toks = [wordnet_lemmatizer.lemmatize(x.lower()) for x in entry['question_toks'] if x.lower() != 'the'] + + entry['question_toks'] = question_toks + + table_names = [] + table_names_pattern = [] + + for y in entry['table_names']: + x = [wordnet_lemmatizer.lemmatize(x.lower()) for x in y.split(' ')] + table_names.append(" ".join(x)) + x = [re_lemma(x.lower()) for x in y.split(' ')] + table_names_pattern.append(" ".join(x)) + + header_toks = [] + header_toks_list = [] + + header_toks_pattern = [] + header_toks_list_pattern = [] + + for y in entry['col_set']: + x = [wordnet_lemmatizer.lemmatize(x.lower()) for x in y.split(' ')] + header_toks.append(" ".join(x)) + header_toks_list.append(x) + + x = [re_lemma(x.lower()) for x in y.split(' ')] + header_toks_pattern.append(" ".join(x)) + header_toks_list_pattern.append(x) + + num_toks = len(question_toks) + idx = 0 + tok_concol = [] + type_concol = [] + nltk_result = nltk.pos_tag(question_toks) + + while idx < num_toks: + + # fully header + end_idx, header = fully_part_header(question_toks, idx, num_toks, header_toks) + if header: + tok_concol.append(question_toks[idx: end_idx]) + type_concol.append(["col"]) + idx = end_idx + continue + + # check for table + end_idx, tname = group_header(question_toks, idx, num_toks, table_names) + if tname: + tok_concol.append(question_toks[idx: end_idx]) + type_concol.append(["table"]) + idx = end_idx + continue + + # check for column + end_idx, header = group_header(question_toks, idx, num_toks, header_toks) + if header: + tok_concol.append(question_toks[idx: end_idx]) + type_concol.append(["col"]) + idx = end_idx + continue + + # check for partial column + end_idx, tname = partial_header(question_toks, idx, header_toks_list) + if tname: + tok_concol.append(tname) + type_concol.append(["col"]) + idx = end_idx + continue + + # check for aggregation + end_idx, agg = group_header(question_toks, idx, num_toks, AGG) + if agg: + tok_concol.append(question_toks[idx: end_idx]) + type_concol.append(["agg"]) + idx = end_idx + continue + + if nltk_result[idx][1] == 'RBR' or nltk_result[idx][1] == 'JJR': + tok_concol.append([question_toks[idx]]) + type_concol.append(['MORE']) + idx += 1 + continue + + if nltk_result[idx][1] == 'RBS' or nltk_result[idx][1] == 'JJS': + tok_concol.append([question_toks[idx]]) + type_concol.append(['MOST']) + idx += 1 + continue + + # string match for Time Format + if num2year(question_toks[idx]): + question_toks[idx] = 'year' 
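+ # the 4-digit token was just normalized to 'year' above; retry the column-header match with the normalized token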
+ end_idx, header = group_header(question_toks, idx, num_toks, header_toks) + if header: + tok_concol.append(question_toks[idx: end_idx]) + type_concol.append(["col"]) + idx = end_idx + continue + + def get_concept_result(toks, graph): + for begin_id in range(0, len(toks)): + for r_ind in reversed(range(1, len(toks) + 1 - begin_id)): + tmp_query = "_".join(toks[begin_id:r_ind]) + if tmp_query in graph: + mi = graph[tmp_query] + for col in entry['col_set']: + if col in mi: + return col + + end_idx, symbol = group_symbol(question_toks, idx, num_toks) + if symbol: + tmp_toks = [x for x in question_toks[idx: end_idx]] + assert len(tmp_toks) > 0, print(symbol, question_toks) + pro_result = get_concept_result(tmp_toks, english_IsA) + if pro_result is None: + pro_result = get_concept_result(tmp_toks, english_RelatedTo) + if pro_result is None: + pro_result = "NONE" + for tmp in tmp_toks: + tok_concol.append([tmp]) + type_concol.append([pro_result]) + pro_result = "NONE" + idx = end_idx + continue + + end_idx, values = group_values(origin_question_toks, idx, num_toks) + if values and (len(values) > 1 or question_toks[idx - 1] not in ['?', '.']): + tmp_toks = [wordnet_lemmatizer.lemmatize(x) for x in question_toks[idx: end_idx] if x.isalnum() is True] + assert len(tmp_toks) > 0, print(question_toks[idx: end_idx], values, question_toks, idx, end_idx) + pro_result = get_concept_result(tmp_toks, english_IsA) + if pro_result is None: + pro_result = get_concept_result(tmp_toks, english_RelatedTo) + if pro_result is None: + pro_result = "NONE" + for tmp in tmp_toks: + tok_concol.append([tmp]) + type_concol.append([pro_result]) + pro_result = "NONE" + idx = end_idx + continue + + result = group_digital(question_toks, idx) + if result is True: + tok_concol.append(question_toks[idx: idx + 1]) + type_concol.append(["value"]) + idx += 1 + continue + if question_toks[idx] == ['ha']: + question_toks[idx] = ['have'] + + tok_concol.append([question_toks[idx]]) + type_concol.append(['NONE']) + idx += 1 + continue + + entry['question_arg'] = tok_concol + entry['question_arg_type'] = type_concol + entry['nltk_pos'] = nltk_result + + return datas + + +if __name__ == '__main__': + arg_parser = argparse.ArgumentParser() + arg_parser.add_argument('--data_path', type=str, help='dataset', required=True) + arg_parser.add_argument('--table_path', type=str, help='table dataset', required=True) + arg_parser.add_argument('--output', type=str, help='output data') + args = arg_parser.parse_args() + args.conceptNet = './conceptNet' + + # loading dataSets + datas, table = load_dataSets(args) + + # process datasets + process_result = process_datas(datas, args) + + with open(args.output, 'w') as f: + json.dump(datas, f) + + diff --git a/preprocess/download_nltk.py b/preprocess/download_nltk.py new file mode 100644 index 0000000..7dc8c6b --- /dev/null +++ b/preprocess/download_nltk.py @@ -0,0 +1,15 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +# -*- coding: utf-8 -*- +""" +# @Time : 2019/1/29 +# @Author : Jiaqi&Zecheng +# @File : download_nltk.py +# @Software: PyCharm +""" +import nltk +nltk.download('averaged_perceptron_tagger') +nltk.download('punkt') +nltk.download('wordnet') + diff --git a/preprocess/run_me.sh b/preprocess/run_me.sh new file mode 100644 index 0000000..d870ab6 --- /dev/null +++ b/preprocess/run_me.sh @@ -0,0 +1,19 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
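+# run_me.sh chains the preprocessing steps: download the NLTK models, annotate the questions with
+# data_process.py, then label each example with its SemQL rule sequence via sql2SemQL.py.
+# Example invocation (arguments follow the positional parameters below): sh run_me.sh [DATA_JSON] [TABLES_JSON] [OUTPUT_JSON]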
+ +#!/bin/bash + +data=$1 +table_data=$2 +output=$3 + +echo "Start download NLTK data" +python download_nltk.py + +echo "Start process the origin Spider dataset" +python data_process.py --data_path ${data} --table_path ${table_data} --output "process_data.json" + +echo "Start generate SemQL from SQL" +python sql2SemQL.py --data_path process_data.json --table_path ${table_data} --output ${data} + +rm process_data.json diff --git a/preprocess/sql2SemQL.py b/preprocess/sql2SemQL.py new file mode 100644 index 0000000..d8bc806 --- /dev/null +++ b/preprocess/sql2SemQL.py @@ -0,0 +1,391 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +# -*- coding: utf-8 -*- +""" +# @Time : 2019/5/24 +# @Author : Jiaqi&Zecheng +# @File : sql2SemQL.py +# @Software: PyCharm +""" + +import argparse +import json +import sys + +import copy +from utils import load_dataSets + +sys.path.append("..") +from src.rule.semQL import Root1, Root, N, A, C, T, Sel, Sup, Filter, Order + +class Parser: + def __init__(self): + self.copy_selec = None + self.sel_result = [] + self.colSet = set() + + def _init_rule(self): + self.copy_selec = None + self.colSet = set() + + def _parse_root(self, sql): + """ + parsing the sql by the grammar + R ::= Select | Select Filter | Select Order | ... | + :return: [R(), states] + """ + use_sup, use_ord, use_fil = True, True, False + + if sql['sql']['limit'] == None: + use_sup = False + + if sql['sql']['orderBy'] == []: + use_ord = False + elif sql['sql']['limit'] != None: + use_ord = False + + # check the where and having + if sql['sql']['where'] != [] or \ + sql['sql']['having'] != []: + use_fil = True + + if use_fil and use_sup: + return [Root(0)], ['FILTER', 'SUP', 'SEL'] + elif use_fil and use_ord: + return [Root(1)], ['ORDER', 'FILTER', 'SEL'] + elif use_sup: + return [Root(2)], ['SUP', 'SEL'] + elif use_fil: + return [Root(3)], ['FILTER', 'SEL'] + elif use_ord: + return [Root(4)], ['ORDER', 'SEL'] + else: + return [Root(5)], ['SEL'] + + def _parser_column0(self, sql, select): + """ + Find table of column '*' + :return: T(table_id) + """ + if len(sql['sql']['from']['table_units']) == 1: + return T(sql['sql']['from']['table_units'][0][1]) + else: + table_list = [] + for tmp_t in sql['sql']['from']['table_units']: + if type(tmp_t[1]) == int: + table_list.append(tmp_t[1]) + table_set, other_set = set(table_list), set() + for sel_p in select: + if sel_p[1][1][1] != 0: + other_set.add(sql['col_table'][sel_p[1][1][1]]) + + if len(sql['sql']['where']) == 1: + other_set.add(sql['col_table'][sql['sql']['where'][0][2][1][1]]) + elif len(sql['sql']['where']) == 3: + other_set.add(sql['col_table'][sql['sql']['where'][0][2][1][1]]) + other_set.add(sql['col_table'][sql['sql']['where'][2][2][1][1]]) + elif len(sql['sql']['where']) == 5: + other_set.add(sql['col_table'][sql['sql']['where'][0][2][1][1]]) + other_set.add(sql['col_table'][sql['sql']['where'][2][2][1][1]]) + other_set.add(sql['col_table'][sql['sql']['where'][4][2][1][1]]) + table_set = table_set - other_set + if len(table_set) == 1: + return T(list(table_set)[0]) + elif len(table_set) == 0 and sql['sql']['groupBy'] != []: + return T(sql['col_table'][sql['sql']['groupBy'][0][1]]) + else: + question = sql['question'] + self.sel_result.append(question) + print('column * table error') + return T(sql['sql']['from']['table_units'][0][1]) + + def _parse_select(self, sql): + """ + parsing the sql by the grammar + Select ::= A | AA | AAA | ... 
| + A ::= agg column table + :return: [Sel(), states] + """ + result = [] + select = sql['sql']['select'][1] + result.append(Sel(0)) + result.append(N(len(select) - 1)) + + for sel in select: + result.append(A(sel[0])) + self.colSet.add(sql['col_set'].index(sql['names'][sel[1][1][1]])) + result.append(C(sql['col_set'].index(sql['names'][sel[1][1][1]]))) + # now check for the situation with * + if sel[1][1][1] == 0: + result.append(self._parser_column0(sql, select)) + else: + result.append(T(sql['col_table'][sel[1][1][1]])) + if not self.copy_selec: + self.copy_selec = [copy.deepcopy(result[-2]), copy.deepcopy(result[-1])] + + return result, None + + def _parse_sup(self, sql): + """ + parsing the sql by the grammar + Sup ::= Most A | Least A + A ::= agg column table + :return: [Sup(), states] + """ + result = [] + select = sql['sql']['select'][1] + if sql['sql']['limit'] == None: + return result, None + if sql['sql']['orderBy'][0] == 'desc': + result.append(Sup(0)) + else: + result.append(Sup(1)) + + result.append(A(sql['sql']['orderBy'][1][0][1][0])) + self.colSet.add(sql['col_set'].index(sql['names'][sql['sql']['orderBy'][1][0][1][1]])) + result.append(C(sql['col_set'].index(sql['names'][sql['sql']['orderBy'][1][0][1][1]]))) + if sql['sql']['orderBy'][1][0][1][1] == 0: + result.append(self._parser_column0(sql, select)) + else: + result.append(T(sql['col_table'][sql['sql']['orderBy'][1][0][1][1]])) + return result, None + + def _parse_filter(self, sql): + """ + parsing the sql by the grammar + Filter ::= and Filter Filter | ... | + A ::= agg column table + :return: [Filter(), states] + """ + result = [] + # check the where + if sql['sql']['where'] != [] and sql['sql']['having'] != []: + result.append(Filter(0)) + + if sql['sql']['where'] != []: + # check the not and/or + if len(sql['sql']['where']) == 1: + result.extend(self.parse_one_condition(sql['sql']['where'][0], sql['names'], sql)) + elif len(sql['sql']['where']) == 3: + if sql['sql']['where'][1] == 'or': + result.append(Filter(1)) + else: + result.append(Filter(0)) + result.extend(self.parse_one_condition(sql['sql']['where'][0], sql['names'], sql)) + result.extend(self.parse_one_condition(sql['sql']['where'][2], sql['names'], sql)) + else: + if sql['sql']['where'][1] == 'and' and sql['sql']['where'][3] == 'and': + result.append(Filter(0)) + result.extend(self.parse_one_condition(sql['sql']['where'][0], sql['names'], sql)) + result.append(Filter(0)) + result.extend(self.parse_one_condition(sql['sql']['where'][2], sql['names'], sql)) + result.extend(self.parse_one_condition(sql['sql']['where'][4], sql['names'], sql)) + elif sql['sql']['where'][1] == 'and' and sql['sql']['where'][3] == 'or': + result.append(Filter(1)) + result.append(Filter(0)) + result.extend(self.parse_one_condition(sql['sql']['where'][0], sql['names'], sql)) + result.extend(self.parse_one_condition(sql['sql']['where'][2], sql['names'], sql)) + result.extend(self.parse_one_condition(sql['sql']['where'][4], sql['names'], sql)) + elif sql['sql']['where'][1] == 'or' and sql['sql']['where'][3] == 'and': + result.append(Filter(1)) + result.append(Filter(0)) + result.extend(self.parse_one_condition(sql['sql']['where'][2], sql['names'], sql)) + result.extend(self.parse_one_condition(sql['sql']['where'][4], sql['names'], sql)) + result.extend(self.parse_one_condition(sql['sql']['where'][0], sql['names'], sql)) + else: + result.append(Filter(1)) + result.append(Filter(1)) + result.extend(self.parse_one_condition(sql['sql']['where'][0], sql['names'], sql)) + 
result.extend(self.parse_one_condition(sql['sql']['where'][2], sql['names'], sql)) + result.extend(self.parse_one_condition(sql['sql']['where'][4], sql['names'], sql)) + + # check having + if sql['sql']['having'] != []: + result.extend(self.parse_one_condition(sql['sql']['having'][0], sql['names'], sql)) + return result, None + + def _parse_order(self, sql): + """ + parsing the sql by the grammar + Order ::= asc A | desc A + A ::= agg column table + :return: [Order(), states] + """ + result = [] + + if 'order' not in sql['query_toks_no_value'] or 'by' not in sql['query_toks_no_value']: + return result, None + elif 'limit' in sql['query_toks_no_value']: + return result, None + else: + if sql['sql']['orderBy'] == []: + return result, None + else: + select = sql['sql']['select'][1] + if sql['sql']['orderBy'][0] == 'desc': + result.append(Order(0)) + else: + result.append(Order(1)) + result.append(A(sql['sql']['orderBy'][1][0][1][0])) + self.colSet.add(sql['col_set'].index(sql['names'][sql['sql']['orderBy'][1][0][1][1]])) + result.append(C(sql['col_set'].index(sql['names'][sql['sql']['orderBy'][1][0][1][1]]))) + if sql['sql']['orderBy'][1][0][1][1] == 0: + result.append(self._parser_column0(sql, select)) + else: + result.append(T(sql['col_table'][sql['sql']['orderBy'][1][0][1][1]])) + return result, None + + + def parse_one_condition(self, sql_condit, names, sql): + result = [] + # check if V(root) + nest_query = True + if type(sql_condit[3]) != dict: + nest_query = False + + if sql_condit[0] == True: + if sql_condit[1] == 9: + # not like only with values + fil = Filter(10) + elif sql_condit[1] == 8: + # not in with Root + fil = Filter(19) + else: + print(sql_condit[1]) + raise NotImplementedError("not implement for the others FIL") + else: + # check for Filter (<,=,>,!=,between, >=, <=, ...) 
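+ # single_map / nested_map translate Spider condition-operator ids (1..7) into SemQL Filter ids;
+ # the nested variants are used when the comparison value is itself a sub-query (nest_query is True)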
+ single_map = {1:8,2:2,3:5,4:4,5:7,6:6,7:3} + nested_map = {1:15,2:11,3:13,4:12,5:16,6:17,7:14} + if sql_condit[1] in [1, 2, 3, 4, 5, 6, 7]: + if nest_query == False: + fil = Filter(single_map[sql_condit[1]]) + else: + fil = Filter(nested_map[sql_condit[1]]) + elif sql_condit[1] == 9: + fil = Filter(9) + elif sql_condit[1] == 8: + fil = Filter(18) + else: + print(sql_condit[1]) + raise NotImplementedError("not implement for the others FIL") + + result.append(fil) + result.append(A(sql_condit[2][1][0])) + self.colSet.add(sql['col_set'].index(sql['names'][sql_condit[2][1][1]])) + result.append(C(sql['col_set'].index(sql['names'][sql_condit[2][1][1]]))) + if sql_condit[2][1][1] == 0: + select = sql['sql']['select'][1] + result.append(self._parser_column0(sql, select)) + else: + result.append(T(sql['col_table'][sql_condit[2][1][1]])) + + # check for the nested value + if type(sql_condit[3]) == dict: + nest_query = {} + nest_query['names'] = names + nest_query['query_toks_no_value'] = "" + nest_query['sql'] = sql_condit[3] + nest_query['col_table'] = sql['col_table'] + nest_query['col_set'] = sql['col_set'] + nest_query['table_names'] = sql['table_names'] + nest_query['question'] = sql['question'] + nest_query['query'] = sql['query'] + nest_query['keys'] = sql['keys'] + result.extend(self.parser(nest_query)) + + return result + + def _parse_step(self, state, sql): + + if state == 'ROOT': + return self._parse_root(sql) + + if state == 'SEL': + return self._parse_select(sql) + + elif state == 'SUP': + return self._parse_sup(sql) + + elif state == 'FILTER': + return self._parse_filter(sql) + + elif state == 'ORDER': + return self._parse_order(sql) + else: + raise NotImplementedError("Not the right state") + + def full_parse(self, query): + sql = query['sql'] + nest_query = {} + nest_query['names'] = query['names'] + nest_query['query_toks_no_value'] = "" + nest_query['col_table'] = query['col_table'] + nest_query['col_set'] = query['col_set'] + nest_query['table_names'] = query['table_names'] + nest_query['question'] = query['question'] + nest_query['query'] = query['query'] + nest_query['keys'] = query['keys'] + + if sql['intersect']: + results = [Root1(0)] + nest_query['sql'] = sql['intersect'] + results.extend(self.parser(query)) + results.extend(self.parser(nest_query)) + return results + + if sql['union']: + results = [Root1(1)] + nest_query['sql'] = sql['union'] + results.extend(self.parser(query)) + results.extend(self.parser(nest_query)) + return results + + if sql['except']: + results = [Root1(2)] + nest_query['sql'] = sql['except'] + results.extend(self.parser(query)) + results.extend(self.parser(nest_query)) + return results + + results = [Root1(3)] + results.extend(self.parser(query)) + + return results + + def parser(self, query): + stack = ["ROOT"] + result = [] + while len(stack) > 0: + state = stack.pop() + step_result, step_state = self._parse_step(state, query) + result.extend(step_result) + if step_state: + stack.extend(step_state) + return result + +if __name__ == '__main__': + arg_parser = argparse.ArgumentParser() + arg_parser.add_argument('--data_path', type=str, help='dataset', required=True) + arg_parser.add_argument('--table_path', type=str, help='table dataset', required=True) + arg_parser.add_argument('--output', type=str, help='output data', required=True) + args = arg_parser.parse_args() + + parser = Parser() + + # loading dataSets + datas, table = load_dataSets(args) + processed_data = [] + + for i, d in enumerate(datas): + if len(datas[i]['sql']['select'][1]) > 5: 
+ continue + r = parser.full_parse(datas[i]) + datas[i]['rule_label'] = " ".join([str(x) for x in r]) + processed_data.append(datas[i]) + + print('Finished %s datas and failed %s datas' % (len(processed_data), len(datas) - len(processed_data))) + with open(args.output, 'w', encoding='utf8') as f: + f.write(json.dumps(processed_data)) + diff --git a/preprocess/utils.py b/preprocess/utils.py new file mode 100644 index 0000000..166a8e8 --- /dev/null +++ b/preprocess/utils.py @@ -0,0 +1,173 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +# -*- coding: utf-8 -*- +""" +# @Time : 2019/5/24 +# @Author : Jiaqi&Zecheng +# @File : utils.py +# @Software: PyCharm +""" +import os +import json +from pattern.en import lemma +from nltk.stem import WordNetLemmatizer + +VALUE_FILTER = ['what', 'how', 'list', 'give', 'show', 'find', 'id', 'order', 'when'] +AGG = ['average', 'sum', 'max', 'min', 'minimum', 'maximum', 'between'] + +wordnet_lemmatizer = WordNetLemmatizer() + +def load_dataSets(args): + with open(args.table_path, 'r', encoding='utf8') as f: + table_datas = json.load(f) + with open(args.data_path, 'r', encoding='utf8') as f: + datas = json.load(f) + + output_tab = {} + tables = {} + tabel_name = set() + for i in range(len(table_datas)): + table = table_datas[i] + temp = {} + temp['col_map'] = table['column_names'] + temp['table_names'] = table['table_names'] + tmp_col = [] + for cc in [x[1] for x in table['column_names']]: + if cc not in tmp_col: + tmp_col.append(cc) + table['col_set'] = tmp_col + db_name = table['db_id'] + tabel_name.add(db_name) + table['schema_content'] = [col[1] for col in table['column_names']] + table['col_table'] = [col[0] for col in table['column_names']] + output_tab[db_name] = temp + tables[db_name] = table + + for d in datas: + d['names'] = tables[d['db_id']]['schema_content'] + d['table_names'] = tables[d['db_id']]['table_names'] + d['col_set'] = tables[d['db_id']]['col_set'] + d['col_table'] = tables[d['db_id']]['col_table'] + keys = {} + for kv in tables[d['db_id']]['foreign_keys']: + keys[kv[0]] = kv[1] + keys[kv[1]] = kv[0] + for id_k in tables[d['db_id']]['primary_keys']: + keys[id_k] = id_k + d['keys'] = keys + return datas, tables + +def group_header(toks, idx, num_toks, header_toks): + for endIdx in reversed(range(idx + 1, num_toks+1)): + sub_toks = toks[idx: endIdx] + sub_toks = " ".join(sub_toks) + if sub_toks in header_toks: + return endIdx, sub_toks + return idx, None + +def fully_part_header(toks, idx, num_toks, header_toks): + for endIdx in reversed(range(idx + 1, num_toks+1)): + sub_toks = toks[idx: endIdx] + if len(sub_toks) > 1: + sub_toks = " ".join(sub_toks) + if sub_toks in header_toks: + return endIdx, sub_toks + return idx, None + +def partial_header(toks, idx, header_toks): + def check_in(list_one, list_two): + if len(set(list_one) & set(list_two)) == len(list_one) and (len(list_two) <= 3): + return True + for endIdx in reversed(range(idx + 1, len(toks))): + sub_toks = toks[idx: min(endIdx, len(toks))] + if len(sub_toks) > 1: + flag_count = 0 + tmp_heads = None + for heads in header_toks: + if check_in(sub_toks, heads): + flag_count += 1 + tmp_heads = heads + if flag_count == 1: + return endIdx, tmp_heads + return idx, None + +def symbol_filter(questions): + question_tmp_q = [] + for q_id, q_val in enumerate(questions): + if len(q_val) > 2 and q_val[0] in ["'", '"', '`', '鈥�', '鈥�'] and q_val[-1] in ["'", '"', '`', '鈥�']: + question_tmp_q.append("'") + question_tmp_q += ["".join(q_val[1:-1])] + 
question_tmp_q.append("'") + elif len(q_val) > 2 and q_val[0] in ["'", '"', '`', '鈥�'] : + question_tmp_q.append("'") + question_tmp_q += ["".join(q_val[1:])] + elif len(q_val) > 2 and q_val[-1] in ["'", '"', '`', '鈥�']: + question_tmp_q += ["".join(q_val[0:-1])] + question_tmp_q.append("'") + elif q_val in ["'", '"', '`', '鈥�', '鈥�', '``', "''"]: + question_tmp_q += ["'"] + else: + question_tmp_q += [q_val] + return question_tmp_q + + +def group_values(toks, idx, num_toks): + def check_isupper(tok_lists): + for tok_one in tok_lists: + if tok_one[0].isupper() is False: + return False + return True + + for endIdx in reversed(range(idx + 1, num_toks + 1)): + sub_toks = toks[idx: endIdx] + + if len(sub_toks) > 1 and check_isupper(sub_toks) is True: + return endIdx, sub_toks + if len(sub_toks) == 1: + if sub_toks[0][0].isupper() and sub_toks[0].lower() not in VALUE_FILTER and \ + sub_toks[0].lower().isalnum() is True: + return endIdx, sub_toks + return idx, None + + +def group_digital(toks, idx): + test = toks[idx].replace(':', '') + test = test.replace('.', '') + if test.isdigit(): + return True + else: + return False + +def group_symbol(toks, idx, num_toks): + if toks[idx-1] == "'": + for i in range(0, min(3, num_toks-idx)): + if toks[i + idx] == "'": + return i + idx, toks[idx:i+idx] + return idx, None + + +def num2year(tok): + if len(str(tok)) == 4 and str(tok).isdigit() and int(str(tok)[:2]) < 22 and int(str(tok)[:2]) > 15: + return True + return False + +def set_header(toks, header_toks, tok_concol, idx, num_toks): + def check_in(list_one, list_two): + if set(list_one) == set(list_two): + return True + for endIdx in range(idx, num_toks): + toks += tok_concol[endIdx] + if len(tok_concol[endIdx]) > 1: + break + for heads in header_toks: + if check_in(toks, heads): + return heads + return None + +def re_lemma(string): + lema = lemma(string.lower()) + if len(lema) > 0: + return lema + else: + return string.lower() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..a0e23ba --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +nltk==3.4 +pattern +numpy==1.14.0 +pytorch-pretrained-bert==0.5.1 +tqdm==4.31.1 \ No newline at end of file diff --git a/sem2SQL.py b/sem2SQL.py new file mode 100644 index 0000000..75fe651 --- /dev/null +++ b/sem2SQL.py @@ -0,0 +1,697 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
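+# sem2SQL.py converts SemQL logical forms back into executable SQL: it walks the action sequence
+# (Root1/Root/Sel/Sup/Filter/Order/N/A/C/T), rebuilds the SELECT/WHERE/GROUP BY/ORDER BY clauses, and infers
+# the FROM clause by joining tables along shortest paths in the foreign-key graph of the schema.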
+ +# -*- coding: utf-8 -*- +""" +# @Time : 2019/5/27 +# @Author : Jiaqi&Zecheng +# @File : sem2SQL.py +# @Software: PyCharm +""" + +import argparse +import traceback + +from src.rule.graph import Graph +from src.rule.semQL import Sup, Sel, Order, Root, Filter, A, N, C, T, Root1 +from src.rule.sem_utils import alter_inter, alter_not_in, alter_column0, load_dataSets + + +def split_logical_form(lf): + indexs = [i+1 for i, letter in enumerate(lf) if letter == ')'] + indexs.insert(0, 0) + components = list() + for i in range(1, len(indexs)): + components.append(lf[indexs[i-1]:indexs[i]].strip()) + return components + + +def pop_front(array): + if len(array) == 0: + return 'None' + return array.pop(0) + + +def is_end(components, transformed_sql, is_root_processed): + end = False + c = pop_front(components) + c_instance = eval(c) + + if isinstance(c_instance, Root) and is_root_processed: + # intersect, union, except + end = True + elif isinstance(c_instance, Filter): + if 'where' not in transformed_sql: + end = True + else: + num_conjunction = 0 + for f in transformed_sql['where']: + if isinstance(f, str) and (f == 'and' or f == 'or'): + num_conjunction += 1 + current_filters = len(transformed_sql['where']) + valid_filters = current_filters - num_conjunction + if valid_filters >= num_conjunction + 1: + end = True + elif isinstance(c_instance, Order): + if 'order' not in transformed_sql: + end = True + elif len(transformed_sql['order']) == 0: + end = False + else: + end = True + elif isinstance(c_instance, Sup): + if 'sup' not in transformed_sql: + end = True + elif len(transformed_sql['sup']) == 0: + end = False + else: + end = True + components.insert(0, c) + return end + + +def _transform(components, transformed_sql, col_set, table_names, schema): + processed_root = False + current_table = schema + + while len(components) > 0: + if is_end(components, transformed_sql, processed_root): + break + c = pop_front(components) + c_instance = eval(c) + if isinstance(c_instance, Root): + processed_root = True + transformed_sql['select'] = list() + if c_instance.id_c == 0: + transformed_sql['where'] = list() + transformed_sql['sup'] = list() + elif c_instance.id_c == 1: + transformed_sql['where'] = list() + transformed_sql['order'] = list() + elif c_instance.id_c == 2: + transformed_sql['sup'] = list() + elif c_instance.id_c == 3: + transformed_sql['where'] = list() + elif c_instance.id_c == 4: + transformed_sql['order'] = list() + elif isinstance(c_instance, Sel): + continue + elif isinstance(c_instance, N): + for i in range(c_instance.id_c + 1): + agg = eval(pop_front(components)) + column = eval(pop_front(components)) + _table = pop_front(components) + table = eval(_table) + if not isinstance(table, T): + table = None + components.insert(0, _table) + assert isinstance(agg, A) and isinstance(column, C) + + transformed_sql['select'].append(( + agg.production.split()[1], + replace_col_with_original_col(col_set[column.id_c], table_names[table.id_c], current_table) if table is not None else col_set[column.id_c], + table_names[table.id_c] if table is not None else table + )) + + elif isinstance(c_instance, Sup): + transformed_sql['sup'].append(c_instance.production.split()[1]) + agg = eval(pop_front(components)) + column = eval(pop_front(components)) + _table = pop_front(components) + table = eval(_table) + if not isinstance(table, T): + table = None + components.insert(0, _table) + assert isinstance(agg, A) and isinstance(column, C) + + transformed_sql['sup'].append(agg.production.split()[1]) + if table: + 
fix_col_id = replace_col_with_original_col(col_set[column.id_c], table_names[table.id_c], current_table) + else: + fix_col_id = col_set[column.id_c] + raise RuntimeError('not found table !!!!') + transformed_sql['sup'].append(fix_col_id) + transformed_sql['sup'].append(table_names[table.id_c] if table is not None else table) + + elif isinstance(c_instance, Order): + transformed_sql['order'].append(c_instance.production.split()[1]) + agg = eval(pop_front(components)) + column = eval(pop_front(components)) + _table = pop_front(components) + table = eval(_table) + if not isinstance(table, T): + table = None + components.insert(0, _table) + assert isinstance(agg, A) and isinstance(column, C) + transformed_sql['order'].append(agg.production.split()[1]) + transformed_sql['order'].append(replace_col_with_original_col(col_set[column.id_c], table_names[table.id_c], current_table)) + transformed_sql['order'].append(table_names[table.id_c] if table is not None else table) + + elif isinstance(c_instance, Filter): + op = c_instance.production.split()[1] + if op == 'and' or op == 'or': + transformed_sql['where'].append(op) + else: + # No Supquery + agg = eval(pop_front(components)) + column = eval(pop_front(components)) + _table = pop_front(components) + table = eval(_table) + if not isinstance(table, T): + table = None + components.insert(0, _table) + assert isinstance(agg, A) and isinstance(column, C) + if len(c_instance.production.split()) == 3: + if table: + fix_col_id = replace_col_with_original_col(col_set[column.id_c], table_names[table.id_c], current_table) + else: + fix_col_id = col_set[column.id_c] + raise RuntimeError('not found table !!!!') + transformed_sql['where'].append(( + op, + agg.production.split()[1], + fix_col_id, + table_names[table.id_c] if table is not None else table, + None + )) + else: + # Subquery + new_dict = dict() + new_dict['sql'] = transformed_sql['sql'] + transformed_sql['where'].append(( + op, + agg.production.split()[1], + replace_col_with_original_col(col_set[column.id_c], table_names[table.id_c], current_table), + table_names[table.id_c] if table is not None else table, + _transform(components, new_dict, col_set, table_names, schema) + )) + + return transformed_sql + + +def transform(query, schema, origin=None): + preprocess_schema(schema) + if origin is None: + lf = query['model_result_replace'] + else: + lf = origin + # lf = query['rule_label'] + col_set = query['col_set'] + table_names = query['table_names'] + current_table = schema + + current_table['schema_content_clean'] = [x[1] for x in current_table['column_names']] + current_table['schema_content'] = [x[1] for x in current_table['column_names_original']] + + components = split_logical_form(lf) + + transformed_sql = dict() + transformed_sql['sql'] = query + c = pop_front(components) + c_instance = eval(c) + assert isinstance(c_instance, Root1) + if c_instance.id_c == 0: + transformed_sql['intersect'] = dict() + transformed_sql['intersect']['sql'] = query + + _transform(components, transformed_sql, col_set, table_names, schema) + _transform(components, transformed_sql['intersect'], col_set, table_names, schema) + elif c_instance.id_c == 1: + transformed_sql['union'] = dict() + transformed_sql['union']['sql'] = query + _transform(components, transformed_sql, col_set, table_names, schema) + _transform(components, transformed_sql['union'], col_set, table_names, schema) + elif c_instance.id_c == 2: + transformed_sql['except'] = dict() + transformed_sql['except']['sql'] = query + _transform(components, 
transformed_sql, col_set, table_names, schema) + _transform(components, transformed_sql['except'], col_set, table_names, schema) + else: + _transform(components, transformed_sql, col_set, table_names, schema) + + parse_result = to_str(transformed_sql, 1, schema) + + parse_result = parse_result.replace('\t', '') + return [parse_result] + +def col_to_str(agg, col, tab, table_names, N=1): + _col = col.replace(' ', '_') + if agg == 'none': + if tab not in table_names: + table_names[tab] = 'T' + str(len(table_names) + N) + table_alias = table_names[tab] + if col == '*': + return '*' + return '%s.%s' % (table_alias, _col) + else: + if col == '*': + if tab is not None and tab not in table_names: + table_names[tab] = 'T' + str(len(table_names) + N) + return '%s(%s)' % (agg, _col) + else: + if tab not in table_names: + table_names[tab] = 'T' + str(len(table_names) + N) + table_alias = table_names[tab] + return '%s(%s.%s)' % (agg, table_alias, _col) + + +def infer_from_clause(table_names, schema, columns): + tables = list(table_names.keys()) + # print(table_names) + start_table = None + end_table = None + join_clause = list() + if len(tables) == 1: + join_clause.append((tables[0], table_names[tables[0]])) + elif len(tables) == 2: + use_graph = True + # print(schema['graph'].vertices) + for t in tables: + if t not in schema['graph'].vertices: + use_graph = False + break + if use_graph: + start_table = tables[0] + end_table = tables[1] + _tables = list(schema['graph'].dijkstra(tables[0], tables[1])) + # print('Two tables: ', _tables) + max_key = 1 + for t, k in table_names.items(): + _k = int(k[1:]) + if _k > max_key: + max_key = _k + for t in _tables: + if t not in table_names: + table_names[t] = 'T' + str(max_key + 1) + max_key += 1 + join_clause.append((t, table_names[t],)) + else: + join_clause = list() + for t in tables: + join_clause.append((t, table_names[t],)) + else: + # > 2 + # print('More than 2 table') + for t in tables: + join_clause.append((t, table_names[t],)) + + if len(join_clause) >= 3: + star_table = None + for agg, col, tab in columns: + if col == '*': + star_table = tab + break + if star_table is not None: + star_table_count = 0 + for agg, col, tab in columns: + if tab == star_table and col != '*': + star_table_count += 1 + if star_table_count == 0 and ((end_table is None or end_table == star_table) or (start_table is None or start_table == star_table)): + # Remove the table the rest tables still can join without star_table + new_join_clause = list() + for t in join_clause: + if t[0] != star_table: + new_join_clause.append(t) + join_clause = new_join_clause + + join_clause = ' JOIN '.join(['%s AS %s' % (jc[0], jc[1]) for jc in join_clause]) + return 'FROM ' + join_clause + +def replace_col_with_original_col(query, col, current_table): + # print(query, col) + if query == '*': + return query + + cur_table = col + cur_col = query + single_final_col = None + # print(query, col) + for col_ind, col_name in enumerate(current_table['schema_content_clean']): + if col_name == cur_col: + assert cur_table in current_table['table_names'] + if current_table['table_names'][current_table['col_table'][col_ind]] == cur_table: + single_final_col = current_table['column_names_original'][col_ind][1] + break + + assert single_final_col + # if query != single_final_col: + # print(query, single_final_col) + return single_final_col + + +def build_graph(schema): + relations = list() + foreign_keys = schema['foreign_keys'] + for (fkey, pkey) in foreign_keys: + fkey_table = 
schema['table_names_original'][schema['column_names'][fkey][0]] + pkey_table = schema['table_names_original'][schema['column_names'][pkey][0]] + relations.append((fkey_table, pkey_table)) + relations.append((pkey_table, fkey_table)) + return Graph(relations) + + +def preprocess_schema(schema): + tmp_col = [] + for cc in [x[1] for x in schema['column_names']]: + if cc not in tmp_col: + tmp_col.append(cc) + schema['col_set'] = tmp_col + # print table + schema['schema_content'] = [col[1] for col in schema['column_names']] + schema['col_table'] = [col[0] for col in schema['column_names']] + graph = build_graph(schema) + schema['graph'] = graph + + + + +def to_str(sql_json, N_T, schema, pre_table_names=None): + all_columns = list() + select_clause = list() + table_names = dict() + current_table = schema + for (agg, col, tab) in sql_json['select']: + all_columns.append((agg, col, tab)) + select_clause.append(col_to_str(agg, col, tab, table_names, N_T)) + select_clause_str = 'SELECT ' + ', '.join(select_clause).strip() + + sup_clause = '' + order_clause = '' + direction_map = {"des": 'DESC', 'asc': 'ASC'} + + if 'sup' in sql_json: + (direction, agg, col, tab,) = sql_json['sup'] + all_columns.append((agg, col, tab)) + subject = col_to_str(agg, col, tab, table_names, N_T) + sup_clause = ('ORDER BY %s %s LIMIT 1' % (subject, direction_map[direction])).strip() + elif 'order' in sql_json: + (direction, agg, col, tab,) = sql_json['order'] + all_columns.append((agg, col, tab)) + subject = col_to_str(agg, col, tab, table_names, N_T) + order_clause = ('ORDER BY %s %s' % (subject, direction_map[direction])).strip() + + has_group_by = False + where_clause = '' + have_clause = '' + if 'where' in sql_json: + conjunctions = list() + filters = list() + # print(sql_json['where']) + for f in sql_json['where']: + if isinstance(f, str): + conjunctions.append(f) + else: + op, agg, col, tab, value = f + if value: + value['sql'] = sql_json['sql'] + all_columns.append((agg, col, tab)) + subject = col_to_str(agg, col, tab, table_names, N_T) + if value is None: + where_value = '1' + if op == 'between': + where_value = '1 AND 2' + filters.append('%s %s %s' % (subject, op, where_value)) + else: + if op == 'in' and len(value['select']) == 1 and value['select'][0][0] == 'none' \ + and 'where' not in value and 'order' not in value and 'sup' not in value: + # and value['select'][0][2] not in table_names: + if value['select'][0][2] not in table_names: + table_names[value['select'][0][2]] = 'T' + str(len(table_names) + N_T) + filters.append(None) + + else: + filters.append('%s %s %s' % (subject, op, '(' + to_str(value, len(table_names) + 1, schema) + ')')) + if len(conjunctions): + filters.append(conjunctions.pop()) + + aggs = ['count(', 'avg(', 'min(', 'max(', 'sum('] + having_filters = list() + idx = 0 + while idx < len(filters): + _filter = filters[idx] + if _filter is None: + idx += 1 + continue + for agg in aggs: + if _filter.startswith(agg): + having_filters.append(_filter) + filters.pop(idx) + # print(filters) + if 0 < idx and (filters[idx - 1] in ['and', 'or']): + filters.pop(idx - 1) + # print(filters) + break + else: + idx += 1 + if len(having_filters) > 0: + have_clause = 'HAVING ' + ' '.join(having_filters).strip() + if len(filters) > 0: + # print(filters) + filters = [_f for _f in filters if _f is not None] + conjun_num = 0 + filter_num = 0 + for _f in filters: + if _f in ['or', 'and']: + conjun_num += 1 + else: + filter_num += 1 + if conjun_num > 0 and filter_num != (conjun_num + 1): + # assert 'and' in filters + 
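+ # moving aggregate predicates into HAVING above can leave a dangling 'and'/'or' in the WHERE list;
+ # scan once and drop the first conjunction that no longer has filters on both sides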
idx = 0 + while idx < len(filters): + if filters[idx] == 'and': + if idx - 1 == 0: + filters.pop(idx) + break + if filters[idx - 1] in ['and', 'or']: + filters.pop(idx) + break + if idx + 1 >= len(filters) - 1: + filters.pop(idx) + break + if filters[idx + 1] in ['and', 'or']: + filters.pop(idx) + break + idx += 1 + if len(filters) > 0: + where_clause = 'WHERE ' + ' '.join(filters).strip() + where_clause = where_clause.replace('not_in', 'NOT IN') + else: + where_clause = '' + + if len(having_filters) > 0: + has_group_by = True + + for agg in ['count(', 'avg(', 'min(', 'max(', 'sum(']: + if (len(sql_json['select']) > 1 and agg in select_clause_str)\ + or agg in sup_clause or agg in order_clause: + has_group_by = True + break + + group_by_clause = '' + if has_group_by: + if len(table_names) == 1: + # check none agg + is_agg_flag = False + for (agg, col, tab) in sql_json['select']: + + if agg == 'none': + group_by_clause = 'GROUP BY ' + col_to_str(agg, col, tab, table_names, N_T) + else: + is_agg_flag = True + + if is_agg_flag is False and len(group_by_clause) > 5: + group_by_clause = "GROUP BY" + for (agg, col, tab) in sql_json['select']: + group_by_clause = group_by_clause + ' ' + col_to_str(agg, col, tab, table_names, N_T) + + if len(group_by_clause) < 5: + if 'count(*)' in select_clause_str: + current_table = schema + for primary in current_table['primary_keys']: + if current_table['table_names'][current_table['col_table'][primary]] in table_names : + group_by_clause = 'GROUP BY ' + col_to_str('none', current_table['schema_content'][primary], + current_table['table_names'][ + current_table['col_table'][primary]], + table_names, N_T) + else: + # if only one select + if len(sql_json['select']) == 1: + agg, col, tab = sql_json['select'][0] + non_lists = [tab] + fix_flag = False + # add tab from other part + for key, value in table_names.items(): + if key not in non_lists: + non_lists.append(key) + + a = non_lists[0] + b = None + for non in non_lists: + if a != non: + b = non + if b: + for pair in current_table['foreign_keys']: + t1 = current_table['table_names'][current_table['col_table'][pair[0]]] + t2 = current_table['table_names'][current_table['col_table'][pair[1]]] + if t1 in [a, b] and t2 in [a, b]: + if pre_table_names and t1 not in pre_table_names: + assert t2 in pre_table_names + t1 = t2 + group_by_clause = 'GROUP BY ' + col_to_str('none', + current_table['schema_content'][pair[0]], + t1, + table_names, N_T) + fix_flag = True + break + + if fix_flag is False: + agg, col, tab = sql_json['select'][0] + group_by_clause = 'GROUP BY ' + col_to_str(agg, col, tab, table_names, N_T) + + else: + # check if there are only one non agg + non_agg, non_agg_count = None, 0 + non_lists = [] + for (agg, col, tab) in sql_json['select']: + if agg == 'none': + non_agg = (agg, col, tab) + non_lists.append(tab) + non_agg_count += 1 + + non_lists = list(set(non_lists)) + # print(non_lists) + if non_agg_count == 1: + group_by_clause = 'GROUP BY ' + col_to_str(non_agg[0], non_agg[1], non_agg[2], table_names, N_T) + elif non_agg: + find_flag = False + fix_flag = False + find_primary = None + if len(non_lists) <= 1: + for key, value in table_names.items(): + if key not in non_lists: + non_lists.append(key) + if len(non_lists) > 1: + a = non_lists[0] + b = None + for non in non_lists: + if a != non: + b = non + if b: + for pair in current_table['foreign_keys']: + t1 = current_table['table_names'][current_table['col_table'][pair[0]]] + t2 = current_table['table_names'][current_table['col_table'][pair[1]]] + if 
t1 in [a, b] and t2 in [a, b]: + if pre_table_names and t1 not in pre_table_names: + assert t2 in pre_table_names + t1 = t2 + group_by_clause = 'GROUP BY ' + col_to_str('none', + current_table['schema_content'][pair[0]], + t1, + table_names, N_T) + fix_flag = True + break + tab = non_agg[2] + assert tab in current_table['table_names'] + + for primary in current_table['primary_keys']: + if current_table['table_names'][current_table['col_table'][primary]] == tab: + find_flag = True + find_primary = (current_table['schema_content'][primary], tab) + if fix_flag is False: + if find_flag is False: + # rely on count * + foreign = [] + for pair in current_table['foreign_keys']: + if current_table['table_names'][current_table['col_table'][pair[0]]] == tab: + foreign.append(pair[1]) + if current_table['table_names'][current_table['col_table'][pair[1]]] == tab: + foreign.append(pair[0]) + + for pair in foreign: + if current_table['table_names'][current_table['col_table'][pair]] in table_names: + group_by_clause = 'GROUP BY ' + col_to_str('none', current_table['schema_content'][pair], + current_table['table_names'][current_table['col_table'][pair]], + table_names, N_T) + find_flag = True + break + if find_flag is False: + for (agg, col, tab) in sql_json['select']: + if 'id' in col.lower(): + group_by_clause = 'GROUP BY ' + col_to_str(agg, col, tab, table_names, N_T) + break + if len(group_by_clause) > 5: + pass + else: + raise RuntimeError('fail to convert') + else: + group_by_clause = 'GROUP BY ' + col_to_str('none', find_primary[0], + find_primary[1], + table_names, N_T) + intersect_clause = '' + if 'intersect' in sql_json: + sql_json['intersect']['sql'] = sql_json['sql'] + intersect_clause = 'INTERSECT ' + to_str(sql_json['intersect'], len(table_names) + 1, schema, table_names) + union_clause = '' + if 'union' in sql_json: + sql_json['union']['sql'] = sql_json['sql'] + union_clause = 'UNION ' + to_str(sql_json['union'], len(table_names) + 1, schema, table_names) + except_clause = '' + if 'except' in sql_json: + sql_json['except']['sql'] = sql_json['sql'] + except_clause = 'EXCEPT ' + to_str(sql_json['except'], len(table_names) + 1, schema, table_names) + + # print(current_table['table_names_original']) + table_names_replace = {} + for a, b in zip(current_table['table_names_original'], current_table['table_names']): + table_names_replace[b] = a + new_table_names = {} + for key, value in table_names.items(): + if key is None: + continue + new_table_names[table_names_replace[key]] = value + from_clause = infer_from_clause(new_table_names, schema, all_columns).strip() + + sql = ' '.join([select_clause_str, from_clause, where_clause, group_by_clause, have_clause, sup_clause, order_clause, + intersect_clause, union_clause, except_clause]) + + return sql + + + + +if __name__ == '__main__': + + arg_parser = argparse.ArgumentParser() + arg_parser.add_argument('--data_path', type=str, help='dataset path', required=True) + arg_parser.add_argument('--input_path', type=str, help='predicted logical form', required=True) + arg_parser.add_argument('--output_path', type=str, help='output data') + args = arg_parser.parse_args() + + # loading dataSets + datas, schemas = load_dataSets(args) + alter_not_in(datas, schemas=schemas) + alter_inter(datas) + alter_column0(datas) + + + index = range(len(datas)) + count = 0 + exception_count = 0 + with open(args.output_path, 'w', encoding='utf8') as d: + for i in index: + try: + result = transform(datas[i], schemas[datas[i]['db_id']]) + d.write(result[0] + '\n') + count += 1 + 
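+ # the except branch below falls back to a fixed logical form
+ # ('Root1(3) Root(5) Sel(0) N(0) A(3) C(0) T(0)', roughly SELECT count(*) over the first table)
+ # so that the output file still contains one SQL line per example even when the transform fails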
except Exception as e: + result = transform(datas[i], schemas[datas[i]['db_id']], origin='Root1(3) Root(5) Sel(0) N(0) A(3) C(0) T(0)') + exception_count += 1 + d.write(result[0] + '\n') + count += 1 + print(e) + print('Exception') + print(traceback.format_exc()) + print('===\n\n') + + print(count, exception_count) diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..a668d78 --- /dev/null +++ b/src/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +# -*- coding: utf-8 -*- +""" +# @Time : 2019/5/25 +# @Author : Jiaqi&Zecheng +# @File : __init__.py +# @Software: PyCharm +""" \ No newline at end of file diff --git a/src/args.py b/src/args.py new file mode 100644 index 0000000..83f9fa4 --- /dev/null +++ b/src/args.py @@ -0,0 +1,82 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +# -*- coding: utf-8 -*- +""" +# @Time : 2019/5/25 +# @Author : Jiaqi&Zecheng +# @File : args.py +# @Software: PyCharm +""" + +import random +import argparse +import torch +import numpy as np + +def init_arg_parser(): + arg_parser = argparse.ArgumentParser() + arg_parser.add_argument('--seed', default=5783287, type=int, help='random seed') + arg_parser.add_argument('--cuda', action='store_true', help='use gpu') + arg_parser.add_argument('--lr_scheduler', action='store_true', help='use learning rate scheduler') + arg_parser.add_argument('--lr_scheduler_gammar', default=0.5, type=float, help='decay rate of learning rate scheduler') + arg_parser.add_argument('--column_pointer', action='store_true', help='use column pointer') + arg_parser.add_argument('--loss_epoch_threshold', default=20, type=int, help='loss epoch threshold') + arg_parser.add_argument('--sketch_loss_coefficient', default=0.2, type=float, help='sketch loss coefficient') + arg_parser.add_argument('--sentence_features', action='store_true', help='use sentence features') + arg_parser.add_argument('--model_name', choices=['transformer', 'rnn', 'table', 'sketch'], default='rnn', + help='model name') + + arg_parser.add_argument('--lstm', choices=['lstm', 'lstm_with_dropout', 'parent_feed'], default='lstm') + + arg_parser.add_argument('--load_model', default=None, type=str, help='load a pre-trained model') + arg_parser.add_argument('--glove_embed_path', default="glove.42B.300d.txt", type=str) + + arg_parser.add_argument('--batch_size', default=64, type=int, help='batch size') + arg_parser.add_argument('--beam_size', default=5, type=int, help='beam size for beam search') + arg_parser.add_argument('--embed_size', default=300, type=int, help='size of word embeddings') + arg_parser.add_argument('--col_embed_size', default=300, type=int, help='size of word embeddings') + + arg_parser.add_argument('--action_embed_size', default=128, type=int, help='size of word embeddings') + arg_parser.add_argument('--type_embed_size', default=128, type=int, help='size of word embeddings') + arg_parser.add_argument('--hidden_size', default=100, type=int, help='size of LSTM hidden states') + arg_parser.add_argument('--att_vec_size', default=100, type=int, help='size of attentional vector') + arg_parser.add_argument('--dropout', default=0.3, type=float, help='dropout rate') + arg_parser.add_argument('--word_dropout', default=0.2, type=float, help='word dropout rate') + + # readout layer + arg_parser.add_argument('--no_query_vec_to_action_map', default=False, action='store_true') + arg_parser.add_argument('--readout', default='identity', choices=['identity', 'non_linear']) 
+ arg_parser.add_argument('--query_vec_to_action_diff_map', default=False, action='store_true') + + + arg_parser.add_argument('--column_att', choices=['dot_prod', 'affine'], default='affine') + + arg_parser.add_argument('--decode_max_time_step', default=40, type=int, help='maximum number of time steps used ' + 'in decoding and sampling') + + + arg_parser.add_argument('--save_to', default='model', type=str, help='save trained model to') + arg_parser.add_argument('--toy', action='store_true', + help='If set, use small data; used for fast debugging.') + arg_parser.add_argument('--clip_grad', default=5., type=float, help='clip gradients') + arg_parser.add_argument('--max_epoch', default=-1, type=int, help='maximum number of training epoches') + arg_parser.add_argument('--optimizer', default='Adam', type=str, help='optimizer') + arg_parser.add_argument('--lr', default=0.001, type=float, help='learning rate') + + arg_parser.add_argument('--dataset', default="./data", type=str) + + arg_parser.add_argument('--epoch', default=50, type=int, help='Maximum Epoch') + arg_parser.add_argument('--save', default='./', type=str, + help="Path to save the checkpoint and logs of epoch") + + return arg_parser + +def init_config(arg_parser): + args = arg_parser.parse_args() + torch.manual_seed(args.seed) + if args.cuda: + torch.cuda.manual_seed(args.seed) + np.random.seed(int(args.seed * 13 / 7)) + random.seed(int(args.seed)) + return args diff --git a/src/beam.py b/src/beam.py new file mode 100644 index 0000000..45ede09 --- /dev/null +++ b/src/beam.py @@ -0,0 +1,206 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +# -*- coding: utf-8 -*- +""" +# @Time : 2019/5/25 +# @Author : Jiaqi&Zecheng +# @File : beam.py +# @Software: PyCharm +""" + +import copy + +from src.rule import semQL + + +class ActionInfo(object): + """sufficient statistics for making a prediction of an action at a time step""" + + def __init__(self, action=None): + self.t = 0 + self.score = 0 + self.parent_t = -1 + self.action = action + self.frontier_prod = None + self.frontier_field = None + + # for GenToken actions only + self.copy_from_src = False + self.src_token_position = -1 + + +class Beams(object): + def __init__(self, is_sketch=False): + self.actions = [] + self.action_infos = [] + self.inputs = [] + self.score = 0. 
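+        # t counts decoding steps taken so far; sketch_step tracks how far this
+        # hypothesis has advanced through the predicted sketch during lf decoding.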
+ self.t = 0 + self.is_sketch = is_sketch + self.sketch_step = 0 + self.sketch_attention_history = list() + + def get_availableClass(self): + """ + return the available action class + :return: + """ + + # TODO: it could be update by speed + # return the available class using rule + # FIXME: now should change for these 11: "Filter 1 ROOT", + def check_type(lists): + for s in lists: + if type(s) == int: + return False + return True + + stack = [semQL.Root1] + for action in self.actions: + infer_action = action.get_next_action(is_sketch=self.is_sketch) + infer_action.reverse() + if stack[-1] is type(action): + stack.pop() + # check if the are non-terminal + if check_type(infer_action): + stack.extend(infer_action) + else: + raise RuntimeError("Not the right action") + + result = stack[-1] if len(stack) > 0 else None + + return result + + @classmethod + def get_parent_action(cls, actions): + """ + + :param actions: + :return: + """ + + def check_type(lists): + for s in lists: + if type(s) == int: + return False + return True + + # check the origin state Root + if len(actions) == 0: + return None + + stack = [semQL.Root1] + for id_x, action in enumerate(actions): + infer_action = action.get_next_action() + for ac in infer_action: + ac.parent = action + ac.pt = id_x + infer_action.reverse() + if stack[-1] is type(action): + stack.pop() + # check if the are non-terminal + if check_type(infer_action): + stack.extend(infer_action) + else: + for t in actions: + if type(t) != semQL.C: + print(t, end="") + print('asd') + print(action) + print(stack[-1]) + raise RuntimeError("Not the right action") + result = stack[-1] if len(stack) > 0 else None + + return result + + def apply_action(self, action): + # TODO: not finish implement yet + self.t += 1 + self.actions.append(action) + + def clone_and_apply_action(self, action): + new_hyp = self.copy() + new_hyp.apply_action(action) + + return new_hyp + + def clone_and_apply_action_info(self, action_info): + action = action_info.action + action.score = action_info.score + new_hyp = self.clone_and_apply_action(action) + new_hyp.action_infos.append(action_info) + new_hyp.sketch_step = self.sketch_step + new_hyp.sketch_attention_history = copy.copy(self.sketch_attention_history) + + return new_hyp + + def copy(self): + new_hyp = Beams(is_sketch=self.is_sketch) + # if self.tree: + # new_hyp.tree = self.tree.copy() + + new_hyp.actions = list(self.actions) + new_hyp.score = self.score + new_hyp.t = self.t + new_hyp.sketch_step = self.sketch_step + new_hyp.sketch_attention_history = copy.copy(self.sketch_attention_history) + + return new_hyp + + def infer_n(self): + if len(self.actions) > 4: + prev_action = self.actions[-3] + if isinstance(prev_action, semQL.Filter): + if prev_action.id_c > 11: + # Nested Query, only select 1 column + return ['N A'] + if self.actions[0].id_c != 3: + return [self.actions[3].production] + return semQL.N._init_grammar() + + @property + def completed(self): + return True if self.get_availableClass() is None else False + + @property + def is_valid(self): + actions = self.actions + return self.check_sel_valid(actions) + + def check_sel_valid(self, actions): + find_sel = False + sel_actions = list() + for ac in actions: + if type(ac) == semQL.Sel: + find_sel = True + elif find_sel and type(ac) in [semQL.N, semQL.T, semQL.C, semQL.A]: + if type(ac) not in [semQL.N]: + sel_actions.append(ac) + elif find_sel and type(ac) not in [semQL.N, semQL.T, semQL.C, semQL.A]: + break + + if find_sel is False: + return True + + # not the complete sel lf 
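+        # every selected column is decoded as an (A, C, T) triple, so an
+        # incomplete triple only means the Sel clause is still being generated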
+ if len(sel_actions) % 3 != 0: + return True + + sel_string = list() + for i in range(len(sel_actions) // 3): + if (sel_actions[i * 3 + 0].id_c, sel_actions[i * 3 + 1].id_c, sel_actions[i * 3 + 2].id_c) in sel_string: + return False + else: + sel_string.append( + (sel_actions[i * 3 + 0].id_c, sel_actions[i * 3 + 1].id_c, sel_actions[i * 3 + 2].id_c)) + return True + + +if __name__ == '__main__': + test = Beams(is_sketch=True) + # print(semQL.Root1(1).get_next_action()) + test.actions.append(semQL.Root1(3)) + test.actions.append(semQL.Root(5)) + + print(str(test.get_availableClass())) diff --git a/src/dataset.py b/src/dataset.py new file mode 100644 index 0000000..d83ad75 --- /dev/null +++ b/src/dataset.py @@ -0,0 +1,140 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +# -*- coding: utf-8 -*- +""" +# @Time : 2019/5/25 +# @Author : Jiaqi&Zecheng +# @File : utils.py +# @Software: PyCharm +""" + +import copy + +import src.rule.semQL as define_rule +from src.models import nn_utils + + +class Example: + """ + + """ + def __init__(self, src_sent, tgt_actions=None, vis_seq=None, tab_cols=None, col_num=None, sql=None, + one_hot_type=None, col_hot_type=None, schema_len=None, tab_ids=None, + table_names=None, table_len=None, col_table_dict=None, cols=None, + table_col_name=None, table_col_len=None, + col_pred=None, tokenized_src_sent=None, + ): + + self.src_sent = src_sent + self.tokenized_src_sent = tokenized_src_sent + self.vis_seq = vis_seq + self.tab_cols = tab_cols + self.col_num = col_num + self.sql = sql + self.one_hot_type=one_hot_type + self.col_hot_type = col_hot_type + self.schema_len = schema_len + self.tab_ids = tab_ids + self.table_names = table_names + self.table_len = table_len + self.col_table_dict = col_table_dict + self.cols = cols + self.table_col_name = table_col_name + self.table_col_len = table_col_len + self.col_pred = col_pred + self.tgt_actions = tgt_actions + self.truth_actions = copy.deepcopy(tgt_actions) + + self.sketch = list() + if self.truth_actions: + for ta in self.truth_actions: + if isinstance(ta, define_rule.C) or isinstance(ta, define_rule.T) or isinstance(ta, define_rule.A): + continue + self.sketch.append(ta) + + +class cached_property(object): + """ A property that is only computed once per instance and then replaces + itself with an ordinary attribute. Deleting the attribute resets the + property. 
+ + Source: https://github.com/bottlepy/bottle/commit/fa7733e075da0d790d809aa3d2f53071897e6f76 + """ + + def __init__(self, func): + self.__doc__ = getattr(func, '__doc__') + self.func = func + + def __get__(self, obj, cls): + if obj is None: + return self + value = obj.__dict__[self.func.__name__] = self.func(obj) + return value + + +class Batch(object): + def __init__(self, examples, grammar, cuda=False): + self.examples = examples + + if examples[0].tgt_actions: + self.max_action_num = max(len(e.tgt_actions) for e in self.examples) + self.max_sketch_num = max(len(e.sketch) for e in self.examples) + + self.src_sents = [e.src_sent for e in self.examples] + self.src_sents_len = [len(e.src_sent) for e in self.examples] + self.tokenized_src_sents = [e.tokenized_src_sent for e in self.examples] + self.tokenized_src_sents_len = [len(e.tokenized_src_sent) for e in examples] + self.src_sents_word = [e.src_sent for e in self.examples] + self.table_sents_word = [[" ".join(x) for x in e.tab_cols] for e in self.examples] + + self.schema_sents_word = [[" ".join(x) for x in e.table_names] for e in self.examples] + + self.src_type = [e.one_hot_type for e in self.examples] + self.col_hot_type = [e.col_hot_type for e in self.examples] + self.table_sents = [e.tab_cols for e in self.examples] + self.col_num = [e.col_num for e in self.examples] + self.tab_ids = [e.tab_ids for e in self.examples] + self.table_names = [e.table_names for e in self.examples] + self.table_len = [e.table_len for e in examples] + self.col_table_dict = [e.col_table_dict for e in examples] + self.table_col_name = [e.table_col_name for e in examples] + self.table_col_len = [e.table_col_len for e in examples] + self.col_pred = [e.col_pred for e in examples] + + self.grammar = grammar + self.cuda = cuda + + def __len__(self): + return len(self.examples) + + + def table_dict_mask(self, table_dict): + return nn_utils.table_dict_to_mask_tensor(self.table_len, table_dict, cuda=self.cuda) + + @cached_property + def pred_col_mask(self): + return nn_utils.pred_col_mask(self.col_pred, self.col_num) + + @cached_property + def schema_token_mask(self): + return nn_utils.length_array_to_mask_tensor(self.table_len, cuda=self.cuda) + + @cached_property + def table_token_mask(self): + return nn_utils.length_array_to_mask_tensor(self.col_num, cuda=self.cuda) + + @cached_property + def table_appear_mask(self): + return nn_utils.appear_to_mask_tensor(self.col_num, cuda=self.cuda) + + @cached_property + def table_unk_mask(self): + return nn_utils.length_array_to_mask_tensor(self.col_num, cuda=self.cuda, value=None) + + @cached_property + def src_token_mask(self): + return nn_utils.length_array_to_mask_tensor(self.src_sents_len, + cuda=self.cuda) + + diff --git a/src/models/__init__.py b/src/models/__init__.py new file mode 100644 index 0000000..60a98bb --- /dev/null +++ b/src/models/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +# -*- coding: utf-8 -*- +""" +# @Time : 2019/5/26 +# @Author : Jiaqi&Zecheng +# @File : __init__.py.py +# @Software: PyCharm +""" \ No newline at end of file diff --git a/src/models/basic_model.py b/src/models/basic_model.py new file mode 100644 index 0000000..ea145ae --- /dev/null +++ b/src/models/basic_model.py @@ -0,0 +1,160 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
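+#
+# Shared building blocks for IRNet: question encoding, word-embedding lookup,
+# column-type features, sketch padding and checkpoint saving.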
+ +# -*- coding: utf-8 -*- +""" +# @Time : 2019/5/26 +# @Author : Jiaqi&Zecheng +# @File : basic_model.py +# @Software: PyCharm +""" + +import numpy as np +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.utils +from torch.autograd import Variable +from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence + +from src.rule import semQL as define_rule + + +class BasicModel(nn.Module): + + def __init__(self): + super(BasicModel, self).__init__() + pass + + def embedding_cosine(self, src_embedding, table_embedding, table_unk_mask): + embedding_differ = [] + for i in range(table_embedding.size(1)): + one_table_embedding = table_embedding[:, i, :] + one_table_embedding = one_table_embedding.unsqueeze(1).expand(table_embedding.size(0), + src_embedding.size(1), + table_embedding.size(2)) + + topk_val = F.cosine_similarity(one_table_embedding, src_embedding, dim=-1) + + embedding_differ.append(topk_val) + embedding_differ = torch.stack(embedding_differ).transpose(1, 0) + embedding_differ.data.masked_fill_(table_unk_mask.unsqueeze(2).expand( + table_embedding.size(0), + table_embedding.size(1), + embedding_differ.size(2) + ), 0) + + return embedding_differ + + def encode(self, src_sents_var, src_sents_len, q_onehot_project=None): + """ + encode the source sequence + :return: + src_encodings: Variable(batch_size, src_sent_len, hidden_size * 2) + last_state, last_cell: Variable(batch_size, hidden_size) + """ + src_token_embed = self.gen_x_batch(src_sents_var) + + if q_onehot_project is not None: + src_token_embed = torch.cat([src_token_embed, q_onehot_project], dim=-1) + + packed_src_token_embed = pack_padded_sequence(src_token_embed, src_sents_len, batch_first=True) + # src_encodings: (tgt_query_len, batch_size, hidden_size) + src_encodings, (last_state, last_cell) = self.encoder_lstm(packed_src_token_embed) + src_encodings, _ = pad_packed_sequence(src_encodings, batch_first=True) + # src_encodings: (batch_size, tgt_query_len, hidden_size) + # src_encodings = src_encodings.permute(1, 0, 2) + # (batch_size, hidden_size * 2) + last_state = torch.cat([last_state[0], last_state[1]], -1) + last_cell = torch.cat([last_cell[0], last_cell[1]], -1) + + return src_encodings, (last_state, last_cell) + + def input_type(self, values_list): + B = len(values_list) + val_len = [] + for value in values_list: + val_len.append(len(value)) + max_len = max(val_len) + # for the Begin and End + val_emb_array = np.zeros((B, max_len, values_list[0].shape[1]), dtype=np.float32) + for i in range(B): + val_emb_array[i, :val_len[i], :] = values_list[i][:, :] + + val_inp = torch.from_numpy(val_emb_array) + if self.args.cuda: + val_inp = val_inp.cuda() + val_inp_var = Variable(val_inp) + return val_inp_var + + def padding_sketch(self, sketch): + padding_result = [] + for action in sketch: + padding_result.append(action) + if type(action) == define_rule.N: + for _ in range(action.id_c + 1): + padding_result.append(define_rule.A(0)) + padding_result.append(define_rule.C(0)) + padding_result.append(define_rule.T(0)) + elif type(action) == define_rule.Filter and 'A' in action.production: + padding_result.append(define_rule.A(0)) + padding_result.append(define_rule.C(0)) + padding_result.append(define_rule.T(0)) + elif type(action) == define_rule.Order or type(action) == define_rule.Sup: + padding_result.append(define_rule.A(0)) + padding_result.append(define_rule.C(0)) + padding_result.append(define_rule.T(0)) + + return padding_result + + def gen_x_batch(self, q): + B = 
len(q) + val_embs = [] + val_len = np.zeros(B, dtype=np.int64) + is_list = False + if type(q[0][0]) == list: + is_list = True + for i, one_q in enumerate(q): + if not is_list: + q_val = list( + map(lambda x: self.word_emb.get(x, np.zeros(self.args.col_embed_size, dtype=np.float32)), one_q)) + else: + q_val = [] + for ws in one_q: + emb_list = [] + ws_len = len(ws) + for w in ws: + emb_list.append(self.word_emb.get(w, self.word_emb['unk'])) + if ws_len == 0: + raise Exception("word list should not be empty!") + elif ws_len == 1: + q_val.append(emb_list[0]) + else: + q_val.append(sum(emb_list) / float(ws_len)) + + val_embs.append(q_val) + val_len[i] = len(q_val) + max_len = max(val_len) + + val_emb_array = np.zeros((B, max_len, self.args.col_embed_size), dtype=np.float32) + for i in range(B): + for t in range(len(val_embs[i])): + val_emb_array[i, t, :] = val_embs[i][t] + val_inp = torch.from_numpy(val_emb_array) + if self.args.cuda: + val_inp = val_inp.cuda() + return val_inp + + def save(self, path): + dir_name = os.path.dirname(path) + if not os.path.exists(dir_name): + os.makedirs(dir_name) + + params = { + 'args': self.args, + 'vocab': self.vocab, + 'grammar': self.grammar, + 'state_dict': self.state_dict() + } + torch.save(params, path) diff --git a/src/models/model.py b/src/models/model.py new file mode 100644 index 0000000..4c747a5 --- /dev/null +++ b/src/models/model.py @@ -0,0 +1,764 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +# -*- coding: utf-8 -*- +""" +# @Time : 2019/5/25 +# @Author : Jiaqi&Zecheng +# @File : model.py +# @Software: PyCharm +""" +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.utils +from torch.autograd import Variable + +from src.beam import Beams, ActionInfo +from src.dataset import Batch +from src.models import nn_utils +from src.models.basic_model import BasicModel +from src.models.pointer_net import PointerNet +from src.rule import semQL as define_rule + + +class IRNet(BasicModel): + + def __init__(self, args, grammar): + super(IRNet, self).__init__() + self.args = args + self.grammar = grammar + self.use_column_pointer = args.column_pointer + self.use_sentence_features = args.sentence_features + + if args.cuda: + self.new_long_tensor = torch.cuda.LongTensor + self.new_tensor = torch.cuda.FloatTensor + else: + self.new_long_tensor = torch.LongTensor + self.new_tensor = torch.FloatTensor + + self.encoder_lstm = nn.LSTM(args.embed_size, args.hidden_size // 2, bidirectional=True, + batch_first=True) + + input_dim = args.action_embed_size + \ + args.att_vec_size + \ + args.type_embed_size + # previous action + # input feeding + # pre type embedding + + self.lf_decoder_lstm = nn.LSTMCell(input_dim, args.hidden_size) + + self.sketch_decoder_lstm = nn.LSTMCell(input_dim, args.hidden_size) + + # initialize the decoder's state and cells with encoder hidden states + self.decoder_cell_init = nn.Linear(args.hidden_size, args.hidden_size) + + self.att_sketch_linear = nn.Linear(args.hidden_size, args.hidden_size, bias=False) + self.att_lf_linear = nn.Linear(args.hidden_size, args.hidden_size, bias=False) + + self.sketch_att_vec_linear = nn.Linear(args.hidden_size + args.hidden_size, args.att_vec_size, bias=False) + self.lf_att_vec_linear = nn.Linear(args.hidden_size + args.hidden_size, args.att_vec_size, bias=False) + + self.prob_att = nn.Linear(args.att_vec_size, 1) + self.prob_len = nn.Linear(1, 1) + + self.col_type = nn.Linear(4, args.col_embed_size) + self.sketch_encoder = 
nn.LSTM(args.action_embed_size, args.action_embed_size // 2, bidirectional=True, + batch_first=True) + + self.production_embed = nn.Embedding(len(grammar.prod2id), args.action_embed_size) + self.type_embed = nn.Embedding(len(grammar.type2id), args.type_embed_size) + self.production_readout_b = nn.Parameter(torch.FloatTensor(len(grammar.prod2id)).zero_()) + + self.att_project = nn.Linear(args.hidden_size + args.type_embed_size, args.hidden_size) + + self.N_embed = nn.Embedding(len(define_rule.N._init_grammar()), args.action_embed_size) + + self.read_out_act = F.tanh if args.readout == 'non_linear' else nn_utils.identity + + self.query_vec_to_action_embed = nn.Linear(args.att_vec_size, args.action_embed_size, + bias=args.readout == 'non_linear') + + self.production_readout = lambda q: F.linear(self.read_out_act(self.query_vec_to_action_embed(q)), + self.production_embed.weight, self.production_readout_b) + + self.q_att = nn.Linear(args.hidden_size, args.embed_size) + + self.column_rnn_input = nn.Linear(args.col_embed_size, args.action_embed_size, bias=False) + self.table_rnn_input = nn.Linear(args.col_embed_size, args.action_embed_size, bias=False) + + self.dropout = nn.Dropout(args.dropout) + + self.column_pointer_net = PointerNet(args.hidden_size, args.col_embed_size, attention_type=args.column_att) + + self.table_pointer_net = PointerNet(args.hidden_size, args.col_embed_size, attention_type=args.column_att) + + # initial the embedding layers + nn.init.xavier_normal_(self.production_embed.weight.data) + nn.init.xavier_normal_(self.type_embed.weight.data) + nn.init.xavier_normal_(self.N_embed.weight.data) + print('Use Column Pointer: ', True if self.use_column_pointer else False) + + def forward(self, examples): + args = self.args + # now should implement the examples + batch = Batch(examples, self.grammar, cuda=self.args.cuda) + + table_appear_mask = batch.table_appear_mask + + + src_encodings, (last_state, last_cell) = self.encode(batch.src_sents, batch.src_sents_len, None) + + src_encodings = self.dropout(src_encodings) + + utterance_encodings_sketch_linear = self.att_sketch_linear(src_encodings) + utterance_encodings_lf_linear = self.att_lf_linear(src_encodings) + + dec_init_vec = self.init_decoder_state(last_cell) + h_tm1 = dec_init_vec + action_probs = [[] for _ in examples] + + zero_action_embed = Variable(self.new_tensor(args.action_embed_size).zero_()) + zero_type_embed = Variable(self.new_tensor(args.type_embed_size).zero_()) + + sketch_attention_history = list() + + for t in range(batch.max_sketch_num): + if t == 0: + x = Variable(self.new_tensor(len(batch), self.sketch_decoder_lstm.input_size).zero_(), + requires_grad=False) + else: + a_tm1_embeds = [] + pre_types = [] + for e_id, example in enumerate(examples): + + if t < len(example.sketch): + # get the last action + # This is the action embedding + action_tm1 = example.sketch[t - 1] + if type(action_tm1) in [define_rule.Root1, + define_rule.Root, + define_rule.Sel, + define_rule.Filter, + define_rule.Sup, + define_rule.N, + define_rule.Order]: + a_tm1_embed = self.production_embed.weight[self.grammar.prod2id[action_tm1.production]] + else: + print(action_tm1, 'only for sketch') + quit() + a_tm1_embed = zero_action_embed + pass + else: + a_tm1_embed = zero_action_embed + + a_tm1_embeds.append(a_tm1_embed) + + a_tm1_embeds = torch.stack(a_tm1_embeds) + inputs = [a_tm1_embeds] + + for e_id, example in enumerate(examples): + if t < len(example.sketch): + action_tm = example.sketch[t - 1] + pre_type = 
self.type_embed.weight[self.grammar.type2id[type(action_tm)]] + else: + pre_type = zero_type_embed + pre_types.append(pre_type) + + pre_types = torch.stack(pre_types) + + inputs.append(att_tm1) + inputs.append(pre_types) + x = torch.cat(inputs, dim=-1) + + src_mask = batch.src_token_mask + + (h_t, cell_t), att_t, aw = self.step(x, h_tm1, src_encodings, + utterance_encodings_sketch_linear, self.sketch_decoder_lstm, + self.sketch_att_vec_linear, + src_token_mask=src_mask, return_att_weight=True) + sketch_attention_history.append(att_t) + + # get the Root possibility + apply_rule_prob = F.softmax(self.production_readout(att_t), dim=-1) + + for e_id, example in enumerate(examples): + if t < len(example.sketch): + action_t = example.sketch[t] + act_prob_t_i = apply_rule_prob[e_id, self.grammar.prod2id[action_t.production]] + action_probs[e_id].append(act_prob_t_i) + + h_tm1 = (h_t, cell_t) + att_tm1 = att_t + + sketch_prob_var = torch.stack( + [torch.stack(action_probs_i, dim=0).log().sum() for action_probs_i in action_probs], dim=0) + + table_embedding = self.gen_x_batch(batch.table_sents) + src_embedding = self.gen_x_batch(batch.src_sents) + schema_embedding = self.gen_x_batch(batch.table_names) + + # get emb differ + embedding_differ = self.embedding_cosine(src_embedding=src_embedding, table_embedding=table_embedding, + table_unk_mask=batch.table_unk_mask) + + schema_differ = self.embedding_cosine(src_embedding=src_embedding, table_embedding=schema_embedding, + table_unk_mask=batch.schema_token_mask) + + tab_ctx = (src_encodings.unsqueeze(1) * embedding_differ.unsqueeze(3)).sum(2) + schema_ctx = (src_encodings.unsqueeze(1) * schema_differ.unsqueeze(3)).sum(2) + + + table_embedding = table_embedding + tab_ctx + + schema_embedding = schema_embedding + schema_ctx + + col_type = self.input_type(batch.col_hot_type) + + col_type_var = self.col_type(col_type) + + table_embedding = table_embedding + col_type_var + + batch_table_dict = batch.col_table_dict + table_enable = np.zeros(shape=(len(examples))) + action_probs = [[] for _ in examples] + + h_tm1 = dec_init_vec + + for t in range(batch.max_action_num): + if t == 0: + # x = self.lf_begin_vec.unsqueeze(0).repeat(len(batch), 1) + x = Variable(self.new_tensor(len(batch), self.lf_decoder_lstm.input_size).zero_(), requires_grad=False) + else: + a_tm1_embeds = [] + pre_types = [] + + for e_id, example in enumerate(examples): + if t < len(example.tgt_actions): + action_tm1 = example.tgt_actions[t - 1] + if type(action_tm1) in [define_rule.Root1, + define_rule.Root, + define_rule.Sel, + define_rule.Filter, + define_rule.Sup, + define_rule.N, + define_rule.Order, + ]: + + a_tm1_embed = self.production_embed.weight[self.grammar.prod2id[action_tm1.production]] + + else: + if isinstance(action_tm1, define_rule.C): + a_tm1_embed = self.column_rnn_input(table_embedding[e_id, action_tm1.id_c]) + elif isinstance(action_tm1, define_rule.T): + a_tm1_embed = self.column_rnn_input(schema_embedding[e_id, action_tm1.id_c]) + elif isinstance(action_tm1, define_rule.A): + a_tm1_embed = self.production_embed.weight[self.grammar.prod2id[action_tm1.production]] + else: + print(action_tm1, 'not implement') + quit() + a_tm1_embed = zero_action_embed + pass + + else: + a_tm1_embed = zero_action_embed + a_tm1_embeds.append(a_tm1_embed) + + a_tm1_embeds = torch.stack(a_tm1_embeds) + + inputs = [a_tm1_embeds] + + # tgt t-1 action type + for e_id, example in enumerate(examples): + if t < len(example.tgt_actions): + action_tm = example.tgt_actions[t - 1] + pre_type = 
self.type_embed.weight[self.grammar.type2id[type(action_tm)]] + else: + pre_type = zero_type_embed + pre_types.append(pre_type) + + pre_types = torch.stack(pre_types) + + inputs.append(att_tm1) + + inputs.append(pre_types) + + x = torch.cat(inputs, dim=-1) + + src_mask = batch.src_token_mask + + (h_t, cell_t), att_t, aw = self.step(x, h_tm1, src_encodings, + utterance_encodings_lf_linear, self.lf_decoder_lstm, + self.lf_att_vec_linear, + src_token_mask=src_mask, return_att_weight=True) + + apply_rule_prob = F.softmax(self.production_readout(att_t), dim=-1) + table_appear_mask_val = torch.from_numpy(table_appear_mask) + if self.cuda: + table_appear_mask_val = table_appear_mask_val.cuda() + + if self.use_column_pointer: + gate = F.sigmoid(self.prob_att(att_t)) + weights = self.column_pointer_net(src_encodings=table_embedding, query_vec=att_t.unsqueeze(0), + src_token_mask=None) * table_appear_mask_val * gate + self.column_pointer_net( + src_encodings=table_embedding, query_vec=att_t.unsqueeze(0), + src_token_mask=None) * (1 - table_appear_mask_val) * (1 - gate) + else: + weights = self.column_pointer_net(src_encodings=table_embedding, query_vec=att_t.unsqueeze(0), + src_token_mask=batch.table_token_mask) + + weights.data.masked_fill_(batch.table_token_mask, -float('inf')) + + column_attention_weights = F.softmax(weights, dim=-1) + + table_weights = self.table_pointer_net(src_encodings=schema_embedding, query_vec=att_t.unsqueeze(0), + src_token_mask=None) + + schema_token_mask = batch.schema_token_mask.expand_as(table_weights) + table_weights.data.masked_fill_(schema_token_mask, -float('inf')) + table_dict = [batch_table_dict[x_id][int(x)] for x_id, x in enumerate(table_enable.tolist())] + table_mask = batch.table_dict_mask(table_dict) + table_weights.data.masked_fill_(table_mask, -float('inf')) + + table_weights = F.softmax(table_weights, dim=-1) + # now get the loss + for e_id, example in enumerate(examples): + if t < len(example.tgt_actions): + action_t = example.tgt_actions[t] + if isinstance(action_t, define_rule.C): + table_appear_mask[e_id, action_t.id_c] = 1 + table_enable[e_id] = action_t.id_c + act_prob_t_i = column_attention_weights[e_id, action_t.id_c] + action_probs[e_id].append(act_prob_t_i) + elif isinstance(action_t, define_rule.T): + act_prob_t_i = table_weights[e_id, action_t.id_c] + action_probs[e_id].append(act_prob_t_i) + elif isinstance(action_t, define_rule.A): + act_prob_t_i = apply_rule_prob[e_id, self.grammar.prod2id[action_t.production]] + action_probs[e_id].append(act_prob_t_i) + else: + pass + h_tm1 = (h_t, cell_t) + att_tm1 = att_t + lf_prob_var = torch.stack( + [torch.stack(action_probs_i, dim=0).log().sum() for action_probs_i in action_probs], dim=0) + + return [sketch_prob_var, lf_prob_var] + + def parse(self, examples, beam_size=5): + """ + one example a time + :param examples: + :param beam_size: + :return: + """ + batch = Batch([examples], self.grammar, cuda=self.args.cuda) + + src_encodings, (last_state, last_cell) = self.encode(batch.src_sents, batch.src_sents_len, None) + src_encodings = self.dropout(src_encodings) + + utterance_encodings_sketch_linear = self.att_sketch_linear(src_encodings) + utterance_encodings_lf_linear = self.att_lf_linear(src_encodings) + + dec_init_vec = self.init_decoder_state(last_cell) + h_tm1 = dec_init_vec + + t = 0 + beams = [Beams(is_sketch=True)] + completed_beams = [] + + while len(completed_beams) < beam_size and t < self.args.decode_max_time_step: + hyp_num = len(beams) + exp_src_enconding = src_encodings.expand(hyp_num, 
src_encodings.size(1), + src_encodings.size(2)) + exp_src_encodings_sketch_linear = utterance_encodings_sketch_linear.expand(hyp_num, + utterance_encodings_sketch_linear.size( + 1), + utterance_encodings_sketch_linear.size( + 2)) + if t == 0: + with torch.no_grad(): + x = Variable(self.new_tensor(1, self.sketch_decoder_lstm.input_size).zero_()) + else: + a_tm1_embeds = [] + pre_types = [] + for e_id, hyp in enumerate(beams): + action_tm1 = hyp.actions[-1] + if type(action_tm1) in [define_rule.Root1, + define_rule.Root, + define_rule.Sel, + define_rule.Filter, + define_rule.Sup, + define_rule.N, + define_rule.Order]: + a_tm1_embed = self.production_embed.weight[self.grammar.prod2id[action_tm1.production]] + else: + raise ValueError('unknown action %s' % action_tm1) + + a_tm1_embeds.append(a_tm1_embed) + a_tm1_embeds = torch.stack(a_tm1_embeds) + inputs = [a_tm1_embeds] + + for e_id, hyp in enumerate(beams): + action_tm = hyp.actions[-1] + pre_type = self.type_embed.weight[self.grammar.type2id[type(action_tm)]] + pre_types.append(pre_type) + + pre_types = torch.stack(pre_types) + + inputs.append(att_tm1) + inputs.append(pre_types) + x = torch.cat(inputs, dim=-1) + + (h_t, cell_t), att_t = self.step(x, h_tm1, exp_src_enconding, + exp_src_encodings_sketch_linear, self.sketch_decoder_lstm, + self.sketch_att_vec_linear, + src_token_mask=None) + + apply_rule_log_prob = F.log_softmax(self.production_readout(att_t), dim=-1) + + new_hyp_meta = [] + for hyp_id, hyp in enumerate(beams): + action_class = hyp.get_availableClass() + if action_class in [define_rule.Root1, + define_rule.Root, + define_rule.Sel, + define_rule.Filter, + define_rule.Sup, + define_rule.N, + define_rule.Order]: + possible_productions = self.grammar.get_production(action_class) + for possible_production in possible_productions: + prod_id = self.grammar.prod2id[possible_production] + prod_score = apply_rule_log_prob[hyp_id, prod_id] + new_hyp_score = hyp.score + prod_score.data.cpu() + meta_entry = {'action_type': action_class, 'prod_id': prod_id, + 'score': prod_score, 'new_hyp_score': new_hyp_score, + 'prev_hyp_id': hyp_id} + new_hyp_meta.append(meta_entry) + else: + raise RuntimeError('No right action class') + + if not new_hyp_meta: break + + new_hyp_scores = torch.stack([x['new_hyp_score'] for x in new_hyp_meta], dim=0) + top_new_hyp_scores, meta_ids = torch.topk(new_hyp_scores, + k=min(new_hyp_scores.size(0), + beam_size - len(completed_beams))) + + live_hyp_ids = [] + new_beams = [] + for new_hyp_score, meta_id in zip(top_new_hyp_scores.data.cpu(), meta_ids.data.cpu()): + action_info = ActionInfo() + hyp_meta_entry = new_hyp_meta[meta_id] + prev_hyp_id = hyp_meta_entry['prev_hyp_id'] + prev_hyp = beams[prev_hyp_id] + action_type_str = hyp_meta_entry['action_type'] + prod_id = hyp_meta_entry['prod_id'] + if prod_id < len(self.grammar.id2prod): + production = self.grammar.id2prod[prod_id] + action = action_type_str(list(action_type_str._init_grammar()).index(production)) + else: + raise NotImplementedError + + action_info.action = action + action_info.t = t + action_info.score = hyp_meta_entry['score'] + new_hyp = prev_hyp.clone_and_apply_action_info(action_info) + new_hyp.score = new_hyp_score + new_hyp.inputs.extend(prev_hyp.inputs) + + if new_hyp.is_valid is False: + continue + + if new_hyp.completed: + completed_beams.append(new_hyp) + else: + new_beams.append(new_hyp) + live_hyp_ids.append(prev_hyp_id) + + if live_hyp_ids: + h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids]) + att_tm1 = att_t[live_hyp_ids] + beams = 
new_beams + t += 1 + else: + break + + # now get the sketch result + completed_beams.sort(key=lambda hyp: -hyp.score) + if len(completed_beams) == 0: + return [[], []] + + sketch_actions = completed_beams[0].actions + # sketch_actions = examples.sketch + + padding_sketch = self.padding_sketch(sketch_actions) + + table_embedding = self.gen_x_batch(batch.table_sents) + src_embedding = self.gen_x_batch(batch.src_sents) + schema_embedding = self.gen_x_batch(batch.table_names) + + # get emb differ + embedding_differ = self.embedding_cosine(src_embedding=src_embedding, table_embedding=table_embedding, + table_unk_mask=batch.table_unk_mask) + + schema_differ = self.embedding_cosine(src_embedding=src_embedding, table_embedding=schema_embedding, + table_unk_mask=batch.schema_token_mask) + + tab_ctx = (src_encodings.unsqueeze(1) * embedding_differ.unsqueeze(3)).sum(2) + schema_ctx = (src_encodings.unsqueeze(1) * schema_differ.unsqueeze(3)).sum(2) + + table_embedding = table_embedding + tab_ctx + + schema_embedding = schema_embedding + schema_ctx + + col_type = self.input_type(batch.col_hot_type) + + col_type_var = self.col_type(col_type) + + table_embedding = table_embedding + col_type_var + + batch_table_dict = batch.col_table_dict + + h_tm1 = dec_init_vec + + t = 0 + beams = [Beams(is_sketch=False)] + completed_beams = [] + + while len(completed_beams) < beam_size and t < self.args.decode_max_time_step: + hyp_num = len(beams) + + # expand value + exp_src_encodings = src_encodings.expand(hyp_num, src_encodings.size(1), + src_encodings.size(2)) + exp_utterance_encodings_lf_linear = utterance_encodings_lf_linear.expand(hyp_num, + utterance_encodings_lf_linear.size( + 1), + utterance_encodings_lf_linear.size( + 2)) + exp_table_embedding = table_embedding.expand(hyp_num, table_embedding.size(1), + table_embedding.size(2)) + + exp_schema_embedding = schema_embedding.expand(hyp_num, schema_embedding.size(1), + schema_embedding.size(2)) + + + table_appear_mask = batch.table_appear_mask + table_appear_mask = np.zeros((hyp_num, table_appear_mask.shape[1]), dtype=np.float32) + table_enable = np.zeros(shape=(hyp_num)) + for e_id, hyp in enumerate(beams): + for act in hyp.actions: + if type(act) == define_rule.C: + table_appear_mask[e_id][act.id_c] = 1 + table_enable[e_id] = act.id_c + + if t == 0: + with torch.no_grad(): + x = Variable(self.new_tensor(1, self.lf_decoder_lstm.input_size).zero_()) + else: + a_tm1_embeds = [] + pre_types = [] + for e_id, hyp in enumerate(beams): + action_tm1 = hyp.actions[-1] + if type(action_tm1) in [define_rule.Root1, + define_rule.Root, + define_rule.Sel, + define_rule.Filter, + define_rule.Sup, + define_rule.N, + define_rule.Order]: + + a_tm1_embed = self.production_embed.weight[self.grammar.prod2id[action_tm1.production]] + hyp.sketch_step += 1 + elif isinstance(action_tm1, define_rule.C): + a_tm1_embed = self.column_rnn_input(table_embedding[0, action_tm1.id_c]) + elif isinstance(action_tm1, define_rule.T): + a_tm1_embed = self.column_rnn_input(schema_embedding[0, action_tm1.id_c]) + elif isinstance(action_tm1, define_rule.A): + a_tm1_embed = self.production_embed.weight[self.grammar.prod2id[action_tm1.production]] + else: + raise ValueError('unknown action %s' % action_tm1) + + a_tm1_embeds.append(a_tm1_embed) + + a_tm1_embeds = torch.stack(a_tm1_embeds) + + inputs = [a_tm1_embeds] + + for e_id, hyp in enumerate(beams): + action_tm = hyp.actions[-1] + pre_type = self.type_embed.weight[self.grammar.type2id[type(action_tm)]] + pre_types.append(pre_type) + + pre_types = 
torch.stack(pre_types) + + inputs.append(att_tm1) + inputs.append(pre_types) + x = torch.cat(inputs, dim=-1) + + (h_t, cell_t), att_t = self.step(x, h_tm1, exp_src_encodings, + exp_utterance_encodings_lf_linear, self.lf_decoder_lstm, + self.lf_att_vec_linear, + src_token_mask=None) + + apply_rule_log_prob = F.log_softmax(self.production_readout(att_t), dim=-1) + + table_appear_mask_val = torch.from_numpy(table_appear_mask) + + if self.args.cuda: table_appear_mask_val = table_appear_mask_val.cuda() + + if self.use_column_pointer: + gate = F.sigmoid(self.prob_att(att_t)) + weights = self.column_pointer_net(src_encodings=exp_table_embedding, query_vec=att_t.unsqueeze(0), + src_token_mask=None) * table_appear_mask_val * gate + self.column_pointer_net( + src_encodings=exp_table_embedding, query_vec=att_t.unsqueeze(0), + src_token_mask=None) * (1 - table_appear_mask_val) * (1 - gate) + # weights = weights + self.col_attention_out(exp_embedding_differ).squeeze() + else: + weights = self.column_pointer_net(src_encodings=exp_table_embedding, query_vec=att_t.unsqueeze(0), + src_token_mask=batch.table_token_mask) + # weights.data.masked_fill_(exp_col_pred_mask, -float('inf')) + + column_selection_log_prob = F.log_softmax(weights, dim=-1) + + table_weights = self.table_pointer_net(src_encodings=exp_schema_embedding, query_vec=att_t.unsqueeze(0), + src_token_mask=None) + # table_weights = self.table_pointer_net(src_encodings=exp_schema_embedding, query_vec=att_t.unsqueeze(0), src_token_mask=None) + + schema_token_mask = batch.schema_token_mask.expand_as(table_weights) + table_weights.data.masked_fill_(schema_token_mask, -float('inf')) + + table_dict = [batch_table_dict[0][int(x)] for x_id, x in enumerate(table_enable.tolist())] + table_mask = batch.table_dict_mask(table_dict) + table_weights.data.masked_fill_(table_mask, -float('inf')) + + table_weights = F.log_softmax(table_weights, dim=-1) + + new_hyp_meta = [] + for hyp_id, hyp in enumerate(beams): + # TODO: should change this + if type(padding_sketch[t]) == define_rule.A: + possible_productions = self.grammar.get_production(define_rule.A) + for possible_production in possible_productions: + prod_id = self.grammar.prod2id[possible_production] + prod_score = apply_rule_log_prob[hyp_id, prod_id] + + new_hyp_score = hyp.score + prod_score.data.cpu() + meta_entry = {'action_type': define_rule.A, 'prod_id': prod_id, + 'score': prod_score, 'new_hyp_score': new_hyp_score, + 'prev_hyp_id': hyp_id} + new_hyp_meta.append(meta_entry) + + elif type(padding_sketch[t]) == define_rule.C: + for col_id, _ in enumerate(batch.table_sents[0]): + col_sel_score = column_selection_log_prob[hyp_id, col_id] + new_hyp_score = hyp.score + col_sel_score.data.cpu() + meta_entry = {'action_type': define_rule.C, 'col_id': col_id, + 'score': col_sel_score, 'new_hyp_score': new_hyp_score, + 'prev_hyp_id': hyp_id} + new_hyp_meta.append(meta_entry) + elif type(padding_sketch[t]) == define_rule.T: + for t_id, _ in enumerate(batch.table_names[0]): + t_sel_score = table_weights[hyp_id, t_id] + new_hyp_score = hyp.score + t_sel_score.data.cpu() + + meta_entry = {'action_type': define_rule.T, 't_id': t_id, + 'score': t_sel_score, 'new_hyp_score': new_hyp_score, + 'prev_hyp_id': hyp_id} + new_hyp_meta.append(meta_entry) + else: + prod_id = self.grammar.prod2id[padding_sketch[t].production] + new_hyp_score = hyp.score + torch.tensor(0.0) + meta_entry = {'action_type': type(padding_sketch[t]), 'prod_id': prod_id, + 'score': torch.tensor(0.0), 'new_hyp_score': new_hyp_score, + 'prev_hyp_id': 
hyp_id} + new_hyp_meta.append(meta_entry) + + if not new_hyp_meta: break + + new_hyp_scores = torch.stack([x['new_hyp_score'] for x in new_hyp_meta], dim=0) + top_new_hyp_scores, meta_ids = torch.topk(new_hyp_scores, + k=min(new_hyp_scores.size(0), + beam_size - len(completed_beams))) + + live_hyp_ids = [] + new_beams = [] + for new_hyp_score, meta_id in zip(top_new_hyp_scores.data.cpu(), meta_ids.data.cpu()): + action_info = ActionInfo() + hyp_meta_entry = new_hyp_meta[meta_id] + prev_hyp_id = hyp_meta_entry['prev_hyp_id'] + prev_hyp = beams[prev_hyp_id] + + action_type_str = hyp_meta_entry['action_type'] + if 'prod_id' in hyp_meta_entry: + prod_id = hyp_meta_entry['prod_id'] + if action_type_str == define_rule.C: + col_id = hyp_meta_entry['col_id'] + action = define_rule.C(col_id) + elif action_type_str == define_rule.T: + t_id = hyp_meta_entry['t_id'] + action = define_rule.T(t_id) + elif prod_id < len(self.grammar.id2prod): + production = self.grammar.id2prod[prod_id] + action = action_type_str(list(action_type_str._init_grammar()).index(production)) + else: + raise NotImplementedError + + action_info.action = action + action_info.t = t + action_info.score = hyp_meta_entry['score'] + + new_hyp = prev_hyp.clone_and_apply_action_info(action_info) + new_hyp.score = new_hyp_score + new_hyp.inputs.extend(prev_hyp.inputs) + + if new_hyp.is_valid is False: + continue + + if new_hyp.completed: + completed_beams.append(new_hyp) + else: + new_beams.append(new_hyp) + live_hyp_ids.append(prev_hyp_id) + + if live_hyp_ids: + h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids]) + att_tm1 = att_t[live_hyp_ids] + + beams = new_beams + t += 1 + else: + break + + completed_beams.sort(key=lambda hyp: -hyp.score) + + return [completed_beams, sketch_actions] + + def step(self, x, h_tm1, src_encodings, src_encodings_att_linear, decoder, attention_func, src_token_mask=None, + return_att_weight=False): + # h_t: (batch_size, hidden_size) + h_t, cell_t = decoder(x, h_tm1) + + ctx_t, alpha_t = nn_utils.dot_prod_attention(h_t, + src_encodings, src_encodings_att_linear, + mask=src_token_mask) + + att_t = F.tanh(attention_func(torch.cat([h_t, ctx_t], 1))) + att_t = self.dropout(att_t) + + if return_att_weight: + return (h_t, cell_t), att_t, alpha_t + else: + return (h_t, cell_t), att_t + + def init_decoder_state(self, enc_last_cell): + h_0 = self.decoder_cell_init(enc_last_cell) + h_0 = F.tanh(h_0) + + return h_0, Variable(self.new_tensor(h_0.size()).zero_()) + diff --git a/src/models/nn_utils.py b/src/models/nn_utils.py new file mode 100644 index 0000000..17d32a4 --- /dev/null +++ b/src/models/nn_utils.py @@ -0,0 +1,236 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
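+#
+# Tensor helpers shared by the models: dot-product attention, length and table
+# masks, batching, padding and a numerically stable log-sum-exp.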
+ +# -*- coding: utf-8 -*- +""" +# @Time : 2019/5/25 +# @Author : Jiaqi&Zecheng +# @File : utils.py +# @Software: PyCharm +""" +import torch.nn.functional as F +import torch.nn.init as init +import numpy as np +import torch +from torch.autograd import Variable +from six.moves import xrange + +def dot_prod_attention(h_t, src_encoding, src_encoding_att_linear, mask=None): + """ + :param h_t: (batch_size, hidden_size) + :param src_encoding: (batch_size, src_sent_len, hidden_size * 2) + :param src_encoding_att_linear: (batch_size, src_sent_len, hidden_size) + :param mask: (batch_size, src_sent_len) + """ + # (batch_size, src_sent_len) + att_weight = torch.bmm(src_encoding_att_linear, h_t.unsqueeze(2)).squeeze(2) + if mask is not None: + att_weight.data.masked_fill_(mask, -float('inf')) + att_weight = F.softmax(att_weight, dim=-1) + + att_view = (att_weight.size(0), 1, att_weight.size(1)) + # (batch_size, hidden_size) + ctx_vec = torch.bmm(att_weight.view(*att_view), src_encoding).squeeze(1) + + return ctx_vec, att_weight + + +def length_array_to_mask_tensor(length_array, cuda=False, value=None): + max_len = max(length_array) + batch_size = len(length_array) + + mask = np.ones((batch_size, max_len), dtype=np.uint8) + for i, seq_len in enumerate(length_array): + mask[i][:seq_len] = 0 + + if value != None: + for b_id in range(len(value)): + for c_id, c in enumerate(value[b_id]): + if value[b_id][c_id] == [3]: + mask[b_id][c_id] = 1 + + mask = torch.ByteTensor(mask) + return mask.cuda() if cuda else mask + + +def table_dict_to_mask_tensor(length_array, table_dict, cuda=False ): + max_len = max(length_array) + batch_size = len(table_dict) + + mask = np.ones((batch_size, max_len), dtype=np.uint8) + for i, ta_val in enumerate(table_dict): + for tt in ta_val: + mask[i][tt] = 0 + + mask = torch.ByteTensor(mask) + return mask.cuda() if cuda else mask + + +def length_position_tensor(length_array, cuda=False, value=None): + max_len = max(length_array) + batch_size = len(length_array) + + mask = np.zeros((batch_size, max_len), dtype=np.float32) + + for b_id in range(batch_size): + for len_c in range(length_array[b_id]): + mask[b_id][len_c] = len_c + 1 + + mask = torch.LongTensor(mask) + return mask.cuda() if cuda else mask + + +def appear_to_mask_tensor(length_array, cuda=False, value=None): + max_len = max(length_array) + batch_size = len(length_array) + mask = np.zeros((batch_size, max_len), dtype=np.float32) + return mask + +def pred_col_mask(value, max_len): + max_len = max(max_len) + batch_size = len(value) + mask = np.ones((batch_size, max_len), dtype=np.uint8) + for v_ind, v_val in enumerate(value): + for v in v_val: + mask[v_ind][v] = 0 + mask = torch.ByteTensor(mask) + return mask.cuda() + + +def input_transpose(sents, pad_token): + """ + transform the input List[sequence] of size (batch_size, max_sent_len) + into a list of size (batch_size, max_sent_len), with proper padding + """ + max_len = max(len(s) for s in sents) + batch_size = len(sents) + sents_t = [] + masks = [] + for e_id in range(batch_size): + if type(sents[0][0]) != list: + sents_t.append([sents[e_id][i] if len(sents[e_id]) > i else pad_token for i in range(max_len)]) + else: + sents_t.append([sents[e_id][i] if len(sents[e_id]) > i else [pad_token] for i in range(max_len)]) + + masks.append([1 if len(sents[e_id]) > i else 0 for i in range(max_len)]) + + return sents_t, masks + + +def word2id(sents, vocab): + if type(sents[0]) == list: + if type(sents[0][0]) != list: + return [[vocab[w] for w in s] for s in sents] + else: + return 
[[[vocab[w] for w in s] for s in v] for v in sents ] + else: + return [vocab[w] for w in sents] + + +def id2word(sents, vocab): + if type(sents[0]) == list: + return [[vocab.id2word[w] for w in s] for s in sents] + else: + return [vocab.id2word[w] for w in sents] + + +def to_input_variable(sequences, vocab, cuda=False, training=True): + """ + given a list of sequences, + return a tensor of shape (max_sent_len, batch_size) + """ + word_ids = word2id(sequences, vocab) + sents_t, masks = input_transpose(word_ids, vocab['']) + + if type(sents_t[0][0]) != list: + with torch.no_grad(): + sents_var = Variable(torch.LongTensor(sents_t), requires_grad=False) + if cuda: + sents_var = sents_var.cuda() + else: + sents_var = sents_t + + return sents_var + + +def variable_constr(x, v, cuda=False): + return Variable(torch.cuda.x(v)) if cuda else Variable(torch.x(v)) + + +def batch_iter(examples, batch_size, shuffle=False): + index_arr = np.arange(len(examples)) + if shuffle: + np.random.shuffle(index_arr) + + batch_num = int(np.ceil(len(examples) / float(batch_size))) + for batch_id in xrange(batch_num): + batch_ids = index_arr[batch_size * batch_id: batch_size * (batch_id + 1)] + batch_examples = [examples[i] for i in batch_ids] + + yield batch_examples + + +def isnan(data): + data = data.cpu().numpy() + return np.isnan(data).any() or np.isinf(data).any() + + +def log_sum_exp(inputs, dim=None, keepdim=False): + """Numerically stable logsumexp. + source: https://github.com/pytorch/pytorch/issues/2591 + + Args: + inputs: A Variable with any shape. + dim: An integer. + keepdim: A boolean. + + Returns: + Equivalent of log(sum(exp(inputs), dim=dim, keepdim=keepdim)). + """ + # For a 1-D array x (any array along a single dimension), + # log sum exp(x) = s + log sum exp(x - s) + # with s = max(x) being a common choice. + + if dim is None: + inputs = inputs.view(-1) + dim = 0 + s, _ = torch.max(inputs, dim=dim, keepdim=True) + outputs = s + (inputs - s).exp().sum(dim=dim, keepdim=True).log() + if not keepdim: + outputs = outputs.squeeze(dim) + return outputs + + +def uniform_init(lower, upper, params): + for p in params: + p.data.uniform_(lower, upper) + + +def glorot_init(params): + for p in params: + if len(p.data.size()) > 1: + init.xavier_normal(p.data) + + +def identity(x): + return x + + +def pad_matrix(matrixs, cuda=False): + """ + :param matrixs: + :return: [batch_size, max_shape, max_shape], [batch_size] + """ + shape = [m.shape[0] for m in matrixs] + max_shape = max(shape) + tensors = list() + for s, m in zip(shape, matrixs): + delta = max_shape - s + if s > 0: + tensors.append(torch.as_tensor(np.pad(m, [(0, delta), (0, delta)], mode='constant'), dtype=torch.float)) + else: + tensors.append(torch.as_tensor(m, dtype=torch.float)) + tensors = torch.stack(tensors) + if cuda: + tensors = tensors.cuda() + return tensors diff --git a/src/models/pointer_net.py b/src/models/pointer_net.py new file mode 100644 index 0000000..a962beb --- /dev/null +++ b/src/models/pointer_net.py @@ -0,0 +1,106 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
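+#
+# Pointer networks that score schema columns and tables against the decoder's
+# query vector, using affine or dot-product attention.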
+ +# coding=utf8 + +import torch +import torch.nn as nn +import torch.nn.utils +from torch.nn import Parameter + + +class AuxiliaryPointerNet(nn.Module): + + def __init__(self, query_vec_size, src_encoding_size, attention_type='affine'): + super(AuxiliaryPointerNet, self).__init__() + + assert attention_type in ('affine', 'dot_prod') + if attention_type == 'affine': + self.src_encoding_linear = nn.Linear(src_encoding_size, query_vec_size, bias=False) + self.auxiliary_encoding_linear = nn.Linear(src_encoding_size, query_vec_size, bias=False) + self.attention_type = attention_type + + def forward(self, src_encodings, src_context_encodings, src_token_mask, query_vec): + """ + :param src_context_encodings: Variable(batch_size, src_sent_len, src_encoding_size) + :param src_encodings: Variable(batch_size, src_sent_len, src_encoding_size) + :param src_token_mask: Variable(batch_size, src_sent_len) + :param query_vec: Variable(tgt_action_num, batch_size, query_vec_size) + :return: Variable(tgt_action_num, batch_size, src_sent_len) + """ + + # (batch_size, 1, src_sent_len, query_vec_size) + encodings = src_encodings.clone() + context_encodings = src_context_encodings.clone() + if self.attention_type == 'affine': + encodings = self.src_encoding_linear(src_encodings) + context_encodings = self.auxiliary_encoding_linear(src_context_encodings) + encodings = encodings.unsqueeze(1) + context_encodings = context_encodings.unsqueeze(1) + + # (batch_size, tgt_action_num, query_vec_size, 1) + q = query_vec.permute(1, 0, 2).unsqueeze(3) + + # (batch_size, tgt_action_num, src_sent_len) + weights = torch.matmul(encodings, q).squeeze(3) + context_weights = torch.matmul(context_encodings, q).squeeze(3) + + # (tgt_action_num, batch_size, src_sent_len) + weights = weights.permute(1, 0, 2) + context_weights = context_weights.permute(1, 0, 2) + + if src_token_mask is not None: + # (tgt_action_num, batch_size, src_sent_len) + src_token_mask = src_token_mask.unsqueeze(0).expand_as(weights) + weights.data.masked_fill_(src_token_mask, -float('inf')) + context_weights.data.masked_fill_(src_token_mask, -float('inf')) + + sigma = 0.1 + return weights.squeeze(0) + sigma * context_weights.squeeze(0) + + +class PointerNet(nn.Module): + def __init__(self, query_vec_size, src_encoding_size, attention_type='affine'): + super(PointerNet, self).__init__() + + assert attention_type in ('affine', 'dot_prod') + if attention_type == 'affine': + self.src_encoding_linear = nn.Linear(src_encoding_size, query_vec_size, bias=False) + + self.attention_type = attention_type + self.input_linear = nn.Linear(query_vec_size, query_vec_size) + self.type_linear = nn.Linear(32, query_vec_size) + self.V = Parameter(torch.FloatTensor(query_vec_size), requires_grad=True) + self.tanh = nn.Tanh() + self.context_linear = nn.Conv1d(src_encoding_size, query_vec_size, 1, 1) + self.coverage_linear = nn.Conv1d(1, query_vec_size, 1, 1) + + + nn.init.uniform_(self.V, -1, 1) + + def forward(self, src_encodings, src_token_mask, query_vec): + """ + :param src_encodings: Variable(batch_size, src_sent_len, hidden_size * 2) + :param src_token_mask: Variable(batch_size, src_sent_len) + :param query_vec: Variable(tgt_action_num, batch_size, query_vec_size) + :return: Variable(tgt_action_num, batch_size, src_sent_len) + """ + + # (batch_size, 1, src_sent_len, query_vec_size) + + if self.attention_type == 'affine': + src_encodings = self.src_encoding_linear(src_encodings) + src_encodings = src_encodings.unsqueeze(1) + + # (batch_size, tgt_action_num, query_vec_size, 1) + 
q = query_vec.permute(1, 0, 2).unsqueeze(3) + + weights = torch.matmul(src_encodings, q).squeeze(3) + + weights = weights.permute(1, 0, 2) + + if src_token_mask is not None: + src_token_mask = src_token_mask.unsqueeze(0).expand_as(weights) + weights.data.masked_fill_(src_token_mask, -float('inf')) + + return weights.squeeze(0) \ No newline at end of file diff --git a/src/rule/__init__.py b/src/rule/__init__.py new file mode 100644 index 0000000..d3c4dff --- /dev/null +++ b/src/rule/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +# -*- coding: utf-8 -*- +""" +# @Time : 2019/5/27 +# @Author : Jiaqi&Zecheng +# @File : __init__.py.py +# @Software: PyCharm +""" \ No newline at end of file diff --git a/src/rule/graph.py b/src/rule/graph.py new file mode 100644 index 0000000..49ffead --- /dev/null +++ b/src/rule/graph.py @@ -0,0 +1,130 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +# -*- coding: utf-8 -*- +""" +# @Time : 2019/5/25 +# @Author : Jiaqi&Zecheng +# @File : utils.py +# @Software: PyCharm +""" + +from collections import deque, namedtuple + + +# we'll use infinity as a default distance to nodes. +inf = float('inf') +Edge = namedtuple('Edge', 'start, end, cost') + + +def make_edge(start, end, cost=1): + return Edge(start, end, cost) + + +class Graph: + def __init__(self, edges): + # let's check that the data is right + wrong_edges = [i for i in edges if len(i) not in [2, 3]] + if wrong_edges: + raise ValueError('Wrong edges data: {}'.format(wrong_edges)) + + self.edges = [make_edge(*edge) for edge in edges] + + @property + def vertices(self): + return set( + # this piece of magic turns ([1,2], [3,4]) into [1, 2, 3, 4] + # the set above makes it's elements unique. + sum( + ([edge.start, edge.end] for edge in self.edges), [] + ) + ) + + def get_node_pairs(self, n1, n2, both_ends=True): + if both_ends: + node_pairs = [[n1, n2], [n2, n1]] + else: + node_pairs = [[n1, n2]] + return node_pairs + + def remove_edge(self, n1, n2, both_ends=True): + node_pairs = self.get_node_pairs(n1, n2, both_ends) + edges = self.edges[:] + for edge in edges: + if [edge.start, edge.end] in node_pairs: + self.edges.remove(edge) + + def add_edge(self, n1, n2, cost=1, both_ends=True): + node_pairs = self.get_node_pairs(n1, n2, both_ends) + for edge in self.edges: + if [edge.start, edge.end] in node_pairs: + return ValueError('Edge {} {} already exists'.format(n1, n2)) + + self.edges.append(Edge(start=n1, end=n2, cost=cost)) + if both_ends: + self.edges.append(Edge(start=n2, end=n1, cost=cost)) + + @property + def neighbours(self): + neighbours = {vertex: set() for vertex in self.vertices} + for edge in self.edges: + neighbours[edge.start].add((edge.end, edge.cost)) + + return neighbours + + def dijkstra(self, source, dest): + assert source in self.vertices, 'Such source node doesn\'t exist' + assert dest in self.vertices, 'Such source node doesn\'t exis' + + # 1. Mark all nodes unvisited and store them. + # 2. Set the distance to zero for our initial node + # and to infinity for other nodes. + distances = {vertex: inf for vertex in self.vertices} + previous_vertices = { + vertex: None for vertex in self.vertices + } + distances[source] = 0 + vertices = self.vertices.copy() + + while vertices: + # 3. Select the unvisited node with the smallest distance, + # it's current node now. + current_vertex = min( + vertices, key=lambda vertex: distances[vertex]) + + # 6. 
Stop, if the smallest distance + # among the unvisited nodes is infinity. + if distances[current_vertex] == inf: + break + + # 4. Find unvisited neighbors for the current node + # and calculate their distances through the current node. + for neighbour, cost in self.neighbours[current_vertex]: + alternative_route = distances[current_vertex] + cost + + # Compare the newly calculated distance to the assigned + # and save the smaller one. + if alternative_route < distances[neighbour]: + distances[neighbour] = alternative_route + previous_vertices[neighbour] = current_vertex + + # 5. Mark the current node as visited + # and remove it from the unvisited set. + vertices.remove(current_vertex) + + path, current_vertex = deque(), dest + while previous_vertices[current_vertex] is not None: + path.appendleft(current_vertex) + current_vertex = previous_vertices[current_vertex] + if path: + path.appendleft(current_vertex) + return path + + +if __name__ == '__main__': + graph = Graph([ + ("a", "b", 7), ("a", "c", 9), ("a", "f", 14), ("b", "c", 10), + ("b", "d", 15), ("c", "d", 11), ("c", "f", 2), ("d", "e", 6), + ("e", "f", 9)]) + + print(graph.dijkstra("a", "e")) \ No newline at end of file diff --git a/src/rule/lf.py b/src/rule/lf.py new file mode 100644 index 0000000..c74b0fe --- /dev/null +++ b/src/rule/lf.py @@ -0,0 +1,229 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +# -*- coding: utf-8 -*- +""" +# @Time : 2019/5/25 +# @Author : Jiaqi&Zecheng +# @File : utils.py +# @Software: PyCharm +""" +import copy +import json + +import numpy as np + +from src.rule import semQL as define_rule +from src.rule.semQL import Sup, Sel, Order, Root, Filter, A, N, C, T, Root1 + + +def _build_single_filter(lf, f): + # No conjunction + agg = lf.pop(0) + column = lf.pop(0) + if len(lf) == 0: + table = None + else: + table = lf.pop(0) + if not isinstance(table, define_rule.T): + lf.insert(0, table) + table = None + assert isinstance(agg, define_rule.A) and isinstance(column, define_rule.C) + if len(f.production.split()) == 3: + f.add_children(agg) + agg.set_parent(f) + agg.add_children(column) + column.set_parent(agg) + if table is not None: + column.add_children(table) + table.set_parent(column) + else: + # Subquery + f.add_children(agg) + agg.set_parent(f) + agg.add_children(column) + column.set_parent(agg) + if table is not None: + column.add_children(table) + table.set_parent(column) + _root = _build(lf) + f.add_children(_root) + _root.set_parent(f) + + +def _build_filter(lf, root_filter): + assert isinstance(root_filter, define_rule.Filter) + op = root_filter.production.split()[1] + if op == 'and' or op == 'or': + for i in range(2): + child = lf.pop(0) + op = child.production.split()[1] + if op == 'and' or op == 'or': + _f = _build_filter(lf, child) + root_filter.add_children(_f) + _f.set_parent(root_filter) + else: + _build_single_filter(lf, child) + root_filter.add_children(child) + child.set_parent(root_filter) + else: + _build_single_filter(lf, root_filter) + return root_filter + + +def _build(lf): + root = lf.pop(0) + assert isinstance(root, define_rule.Root) + length = len(root.production.split()) - 1 + while len(root.children) != length: + c_instance = lf.pop(0) + if isinstance(c_instance, define_rule.Sel): + sel_instance = c_instance + root.add_children(sel_instance) + sel_instance.set_parent(root) + + # define_rule.N + c_instance = lf.pop(0) + c_instance.set_parent(sel_instance) + sel_instance.add_children(c_instance) + assert isinstance(c_instance, define_rule.N) + for i 
in range(c_instance.id_c + 1): + agg = lf.pop(0) + column = lf.pop(0) + if len(lf) == 0: + table = None + else: + table = lf.pop(0) + if not isinstance(table, define_rule.T): + lf.insert(0, table) + table = None + assert isinstance(agg, define_rule.A) and isinstance(column, define_rule.C) + c_instance.add_children(agg) + agg.set_parent(c_instance) + agg.add_children(column) + column.set_parent(agg) + if table is not None: + column.add_children(table) + table.set_parent(column) + + elif isinstance(c_instance, define_rule.Sup) or isinstance(c_instance, define_rule.Order): + root.add_children(c_instance) + c_instance.set_parent(root) + + agg = lf.pop(0) + column = lf.pop(0) + if len(lf) == 0: + table = None + else: + table = lf.pop(0) + if not isinstance(table, define_rule.T): + lf.insert(0, table) + table = None + assert isinstance(agg, define_rule.A) and isinstance(column, define_rule.C) + c_instance.add_children(agg) + agg.set_parent(c_instance) + agg.add_children(column) + column.set_parent(agg) + if table is not None: + column.add_children(table) + table.set_parent(column) + + elif isinstance(c_instance, define_rule.Filter): + _build_filter(lf, c_instance) + root.add_children(c_instance) + c_instance.set_parent(root) + + return root + + +def build_tree(lf): + root = lf.pop(0) + assert isinstance(root, define_rule.Root1) + if root.id_c == 0 or root.id_c == 1 or root.id_c == 2: + root_1 = _build(lf) + root_2 = _build(lf) + root.add_children(root_1) + root.add_children(root_2) + root_1.set_parent(root) + root_2.set_parent(root) + else: + root_1 = _build(lf) + root.add_children(root_1) + root_1.set_parent(root) + verify(root) + # eliminate_parent(root) + + +def eliminate_parent(node): + for child in node.children: + eliminate_parent(child) + node.children = list() + + +def verify(node): + if isinstance(node, C) and len(node.children) > 0: + table = node.children[0] + assert table is None or isinstance(table, T) + if isinstance(node, T): + return + children_num = len(node.children) + if isinstance(node, Root1): + if node.id_c == 0 or node.id_c == 1 or node.id_c == 2: + assert children_num == 2 + else: + assert children_num == 1 + elif isinstance(node, Root): + assert children_num == len(node.production.split()) - 1 + elif isinstance(node, N): + assert children_num == int(node.id_c) + 1 + elif isinstance(node, Sup) or isinstance(node, Order) or isinstance(node, Sel): + assert children_num == 1 + elif isinstance(node, Filter): + op = node.production.split()[1] + if op == 'and' or op == 'or': + assert children_num == 2 + else: + if len(node.production.split()) == 3: + assert children_num == 1 + else: + assert children_num == 2 + for child in node.children: + assert child.parent == node + verify(child) + + +def label_matrix(lf, matrix, node): + nindex = lf.index(node) + for child in node.children: + if child not in lf: + continue + index = lf.index(child) + matrix[nindex][index] = 1 + label_matrix(lf, matrix, child) + + +def build_adjacency_matrix(lf, symmetry=False): + _lf = list() + for rule in lf: + if isinstance(rule, A) or isinstance(rule, C) or isinstance(rule, T): + continue + _lf.append(rule) + length = len(_lf) + matrix = np.zeros((length, length,)) + label_matrix(_lf, matrix, _lf[0]) + if symmetry: + matrix += matrix.T + return matrix + + +if __name__ == '__main__': + with open(r'..\data\train.json', 'r') as f: + data = json.load(f) + for d in data: + rule_label = [eval(x) for x in d['rule_label'].strip().split(' ')] + print(d['question']) + print(rule_label) + 
build_tree(copy.copy(rule_label)) + adjacency_matrix = build_adjacency_matrix(rule_label, symmetry=True) + print(adjacency_matrix) + print('===\n\n') diff --git a/src/rule/semQL.py b/src/rule/semQL.py new file mode 100644 index 0000000..38ca7a0 --- /dev/null +++ b/src/rule/semQL.py @@ -0,0 +1,399 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +# -*- coding: utf-8 -*- +""" +# @Time : 2019/5/24 +# @Author : Jiaqi&Zecheng +# @File : semQL.py +# @Software: PyCharm +""" + +Keywords = ['des', 'asc', 'and', 'or', 'sum', 'min', 'max', 'avg', 'none', '=', '!=', '<', '>', '<=', '>=', 'between', 'like', 'not_like'] + [ + 'in', 'not_in', 'count', 'intersect', 'union', 'except' +] + + +class Grammar(object): + def __init__(self, is_sketch=False): + self.begin = 0 + self.type_id = 0 + self.is_sketch = is_sketch + self.prod2id = {} + self.type2id = {} + self._init_grammar(Sel) + self._init_grammar(Root) + self._init_grammar(Sup) + self._init_grammar(Filter) + self._init_grammar(Order) + self._init_grammar(N) + self._init_grammar(Root1) + + if not self.is_sketch: + self._init_grammar(A) + + self._init_id2prod() + self.type2id[C] = self.type_id + self.type_id += 1 + self.type2id[T] = self.type_id + + def _init_grammar(self, Cls): + """ + get the production of class Cls + :param Cls: + :return: + """ + production = Cls._init_grammar() + for p in production: + self.prod2id[p] = self.begin + self.begin += 1 + self.type2id[Cls] = self.type_id + self.type_id += 1 + + def _init_id2prod(self): + self.id2prod = {} + for key, value in self.prod2id.items(): + self.id2prod[value] = key + + def get_production(self, Cls): + return Cls._init_grammar() + + +class Action(object): + def __init__(self): + self.pt = 0 + self.production = None + self.children = list() + + def get_next_action(self, is_sketch=False): + actions = list() + for x in self.production.split(' ')[1:]: + if x not in Keywords: + rule_type = eval(x) + if is_sketch: + if rule_type is not A: + actions.append(rule_type) + else: + actions.append(rule_type) + return actions + + def set_parent(self, parent): + self.parent = parent + + def add_children(self, child): + self.children.append(child) + + +class Root1(Action): + def __init__(self, id_c, parent=None): + super(Root1, self).__init__() + self.parent = parent + self.id_c = id_c + self._init_grammar() + self.production = self.grammar_dict[id_c] + + @classmethod + def _init_grammar(self): + # TODO: should add Root grammar to this + self.grammar_dict = { + 0: 'Root1 intersect Root Root', + 1: 'Root1 union Root Root', + 2: 'Root1 except Root Root', + 3: 'Root1 Root', + } + self.production_id = {} + for id_x, value in enumerate(self.grammar_dict.values()): + self.production_id[value] = id_x + + return self.grammar_dict.values() + + def __str__(self): + return 'Root1(' + str(self.id_c) + ')' + + def __repr__(self): + return 'Root1(' + str(self.id_c) + ')' + + +class Root(Action): + def __init__(self, id_c, parent=None): + super(Root, self).__init__() + self.parent = parent + self.id_c = id_c + self._init_grammar() + self.production = self.grammar_dict[id_c] + + @classmethod + def _init_grammar(self): + # TODO: should add Root grammar to this + self.grammar_dict = { + 0: 'Root Sel Sup Filter', + 1: 'Root Sel Filter Order', + 2: 'Root Sel Sup', + 3: 'Root Sel Filter', + 4: 'Root Sel Order', + 5: 'Root Sel' + } + self.production_id = {} + for id_x, value in enumerate(self.grammar_dict.values()): + self.production_id[value] = id_x + + return self.grammar_dict.values() + + def 
__str__(self): + return 'Root(' + str(self.id_c) + ')' + + def __repr__(self): + return 'Root(' + str(self.id_c) + ')' + + +class N(Action): + """ + Number of Columns + """ + def __init__(self, id_c, parent=None): + super(N, self).__init__() + self.parent = parent + self.id_c = id_c + self._init_grammar() + self.production = self.grammar_dict[id_c] + + @classmethod + def _init_grammar(self): + self.grammar_dict = { + 0: 'N A', + 1: 'N A A', + 2: 'N A A A', + 3: 'N A A A A', + 4: 'N A A A A A' + } + self.production_id = {} + for id_x, value in enumerate(self.grammar_dict.values()): + self.production_id[value] = id_x + + return self.grammar_dict.values() + + def __str__(self): + return 'N(' + str(self.id_c) + ')' + + def __repr__(self): + return 'N(' + str(self.id_c) + ')' + +class C(Action): + """ + Column + """ + def __init__(self, id_c, parent=None): + super(C, self).__init__() + self.parent = parent + self.id_c = id_c + self.production = 'C T' + self.table = None + + def __str__(self): + return 'C(' + str(self.id_c) + ')' + + def __repr__(self): + return 'C(' + str(self.id_c) + ')' + + +class T(Action): + """ + Table + """ + def __init__(self, id_c, parent=None): + super(T, self).__init__() + + self.parent = parent + self.id_c = id_c + self.production = 'T min' + self.table = None + + def __str__(self): + return 'T(' + str(self.id_c) + ')' + + def __repr__(self): + return 'T(' + str(self.id_c) + ')' + + +class A(Action): + """ + Aggregator + """ + def __init__(self, id_c, parent=None): + super(A, self).__init__() + + self.parent = parent + self.id_c = id_c + self._init_grammar() + self.production = self.grammar_dict[id_c] + + @classmethod + def _init_grammar(self): + # TODO: should add Root grammar to this + self.grammar_dict = { + 0: 'A none C', + 1: 'A max C', + 2: "A min C", + 3: "A count C", + 4: "A sum C", + 5: "A avg C" + } + self.production_id = {} + for id_x, value in enumerate(self.grammar_dict.values()): + self.production_id[value] = id_x + + return self.grammar_dict.values() + + def __str__(self): + return 'A(' + str(self.id_c) + ')' + + def __repr__(self): + return 'A(' + str(self.grammar_dict[self.id_c].split(' ')[1]) + ')' + + +class Sel(Action): + """ + Select + """ + def __init__(self, id_c, parent=None): + super(Sel, self).__init__() + + self.parent = parent + self.id_c = id_c + self._init_grammar() + self.production = self.grammar_dict[id_c] + + @classmethod + def _init_grammar(self): + self.grammar_dict = { + 0: 'Sel N', + } + self.production_id = {} + for id_x, value in enumerate(self.grammar_dict.values()): + self.production_id[value] = id_x + + return self.grammar_dict.values() + + def __str__(self): + return 'Sel(' + str(self.id_c) + ')' + + def __repr__(self): + return 'Sel(' + str(self.id_c) + ')' + +class Filter(Action): + """ + Filter + """ + def __init__(self, id_c, parent=None): + super(Filter, self).__init__() + + self.parent = parent + self.id_c = id_c + self._init_grammar() + self.production = self.grammar_dict[id_c] + + @classmethod + def _init_grammar(self): + self.grammar_dict = { + # 0: "Filter 1" + 0: 'Filter and Filter Filter', + 1: 'Filter or Filter Filter', + 2: 'Filter = A', + 3: 'Filter != A', + 4: 'Filter < A', + 5: 'Filter > A', + 6: 'Filter <= A', + 7: 'Filter >= A', + 8: 'Filter between A', + 9: 'Filter like A', + 10: 'Filter not_like A', + # now begin root + 11: 'Filter = A Root', + 12: 'Filter < A Root', + 13: 'Filter > A Root', + 14: 'Filter != A Root', + 15: 'Filter between A Root', + 16: 'Filter >= A Root', + 17: 'Filter <= A Root', + # 
now for In + 18: 'Filter in A Root', + 19: 'Filter not_in A Root' + + } + self.production_id = {} + for id_x, value in enumerate(self.grammar_dict.values()): + self.production_id[value] = id_x + + return self.grammar_dict.values() + + def __str__(self): + return 'Filter(' + str(self.id_c) + ')' + + def __repr__(self): + return 'Filter(' + str(self.grammar_dict[self.id_c]) + ')' + + +class Sup(Action): + """ + Superlative + """ + def __init__(self, id_c, parent=None): + super(Sup, self).__init__() + + self.parent = parent + self.id_c = id_c + self._init_grammar() + self.production = self.grammar_dict[id_c] + + @classmethod + def _init_grammar(self): + self.grammar_dict = { + 0: 'Sup des A', + 1: 'Sup asc A', + } + self.production_id = {} + for id_x, value in enumerate(self.grammar_dict.values()): + self.production_id[value] = id_x + + return self.grammar_dict.values() + + def __str__(self): + return 'Sup(' + str(self.id_c) + ')' + + def __repr__(self): + return 'Sup(' + str(self.id_c) + ')' + + +class Order(Action): + """ + Order + """ + def __init__(self, id_c, parent=None): + super(Order, self).__init__() + + self.parent = parent + self.id_c = id_c + self._init_grammar() + self.production = self.grammar_dict[id_c] + + @classmethod + def _init_grammar(self): + self.grammar_dict = { + 0: 'Order des A', + 1: 'Order asc A', + } + self.production_id = {} + for id_x, value in enumerate(self.grammar_dict.values()): + self.production_id[value] = id_x + + return self.grammar_dict.values() + + def __str__(self): + return 'Order(' + str(self.id_c) + ')' + + def __repr__(self): + return 'Order(' + str(self.id_c) + ')' + + +if __name__ == '__main__': + print(list(Root._init_grammar())) diff --git a/src/rule/sem_utils.py b/src/rule/sem_utils.py new file mode 100644 index 0000000..c85c334 --- /dev/null +++ b/src/rule/sem_utils.py @@ -0,0 +1,321 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +# -*- coding: utf-8 -*- +""" +# @Time : 2019/5/27 +# @Author : Jiaqi&Zecheng +# @File : sem_utils.py +# @Software: PyCharm +""" + +import os +import json +import argparse +import re as regex +from nltk.stem import WordNetLemmatizer +from pattern.en import lemma +wordnet_lemmatizer = WordNetLemmatizer() + + +def load_dataSets(args): + with open(args.input_path, 'r') as f: + datas = json.load(f) + with open(os.path.join(args.data_path, 'tables.json'), 'r', encoding='utf8') as f: + table_datas = json.load(f) + schemas = dict() + for i in range(len(table_datas)): + schemas[table_datas[i]['db_id']] = table_datas[i] + return datas, schemas + + +def partial_match(query, table_name): + query = [lemma(x) for x in query] + table_name = [lemma(x) for x in table_name] + if query in table_name: + return True + return False + + +def is_partial_match(query, table_names): + query = lemma(query) + table_names = [[lemma(x) for x in names.split(' ') ] for names in table_names] + same_count = 0 + result = None + for names in table_names: + if query in names: + same_count += 1 + result = names + return result if same_count == 1 else False + + +def multi_option(question, q_ind, names, N): + for i in range(q_ind + 1, q_ind + N + 1): + if i < len(question): + re = is_partial_match(question[i][0], names) + if re is not False: + return re + return False + + +def multi_equal(question, q_ind, names, N): + for i in range(q_ind + 1, q_ind + N + 1): + if i < len(question): + if question[i] == names: + return i + return False + + +def random_choice(question_arg, question_arg_type, names, ground_col_labels, q_ind, N, origin_name): + # first try if there are other table + for t_ind, t_val in enumerate(question_arg_type): + if t_val == ['table']: + return names[origin_name.index(question_arg[t_ind])] + for i in range(q_ind + 1, q_ind + N + 1): + if i < len(question_arg): + if len(ground_col_labels) == 0: + for n in names: + if partial_match(question_arg[i][0], n) is True: + return n + else: + for n_id, n in enumerate(names): + if n_id in ground_col_labels and partial_match(question_arg[i][0], n) is True: + return n + if len(ground_col_labels) > 0: + return names[ground_col_labels[0]] + else: + return names[0] + + +def find_table(cur_table, origin_table_names, question_arg_type, question_arg): + h_table = None + for i in range(len(question_arg_type))[::-1]: + if question_arg_type[i] == ['table']: + h_table = question_arg[i] + h_table = origin_table_names.index(h_table) + if h_table != cur_table: + break + if h_table != cur_table: + return h_table + + # find partial + for i in range(len(question_arg_type))[::-1]: + if question_arg_type[i] == ['NONE']: + for t_id, table_name in enumerate(origin_table_names): + if partial_match(question_arg[i], table_name) is True and t_id != h_table: + return t_id + + # random return + for i in range(len(question_arg_type))[::-1]: + if question_arg_type[i] == ['table']: + h_table = question_arg[i] + h_table = origin_table_names.index(h_table) + return h_table + + return cur_table + + +def alter_not_in(datas, schemas): + for d in datas: + if 'Filter(19)' in d['model_result']: + current_table = schemas[d['db_id']] + current_table['schema_content_clean'] = [x[1] for x in current_table['column_names']] + current_table['col_table'] = [col[0] for col in current_table['column_names']] + origin_table_names = [[wordnet_lemmatizer.lemmatize(x.lower()) for x in names.split(' ')] for names in + d['table_names']] + question_arg_type = d['question_arg_type'] + question_arg = d['question_arg'] + 
pred_label = d['model_result'].split(' ') + + # get potiantial table + cur_table = None + for label_id, label_val in enumerate(pred_label): + if label_val in ['Filter(19)']: + cur_table = int(pred_label[label_id - 1][2:-1]) + break + + h_table = find_table(cur_table, origin_table_names, question_arg_type, question_arg) + + for label_id, label_val in enumerate(pred_label): + if label_val in ['Filter(19)']: + for primary in current_table['primary_keys']: + if int(current_table['col_table'][primary]) == int(pred_label[label_id - 1][2:-1]): + pred_label[label_id + 2] = 'C(' + str( + d['col_set'].index(current_table['schema_content_clean'][primary])) + ')' + break + for pair in current_table['foreign_keys']: + if int(current_table['col_table'][pair[0]]) == h_table and d['col_set'].index( + current_table['schema_content_clean'][pair[1]]) == int(pred_label[label_id + 2][2:-1]): + pred_label[label_id + 8] = 'C(' + str( + d['col_set'].index(current_table['schema_content_clean'][pair[0]])) + ')' + pred_label[label_id + 9] = 'T(' + str(h_table) + ')' + break + elif int(current_table['col_table'][pair[1]]) == h_table and d['col_set'].index( + current_table['schema_content_clean'][pair[0]]) == int(pred_label[label_id + 2][2:-1]): + pred_label[label_id + 8] = 'C(' + str( + d['col_set'].index(current_table['schema_content_clean'][pair[1]])) + ')' + pred_label[label_id + 9] = 'T(' + str(h_table) + ')' + break + pred_label[label_id + 3] = pred_label[label_id - 1] + + d['model_result'] = " ".join(pred_label) + + +def alter_inter(datas): + for d in datas: + if 'Filter(0)' in d['model_result']: + now_result = d['model_result'].split(' ') + index = now_result.index('Filter(0)') + c1 = None + c2 = None + for i in range(index + 1, len(now_result)): + if c1 is None and 'C(' in now_result[i]: + c1 = now_result[i] + elif c1 is not None and c2 is None and 'C(' in now_result[i]: + c2 = now_result[i] + + if c1 != c2 or c1 is None or c2 is None: + continue + replace_result = ['Root1(0)'] + now_result[1:now_result.index('Filter(0)')] + for r_id, r_val in enumerate(now_result[now_result.index('Filter(0)') + 2:]): + if 'Filter' in r_val: + break + + replace_result = replace_result + now_result[now_result.index('Filter(0)') + 1:r_id + now_result.index( + 'Filter(0)') + 2] + replace_result = replace_result + now_result[1:now_result.index('Filter(0)')] + + replace_result = replace_result + now_result[r_id + now_result.index('Filter(0)') + 2:] + replace_result = " ".join(replace_result) + d['model_result'] = replace_result + + +def alter_column0(datas): + """ + Attach column * table + :return: model_result_replace + """ + zero_count = 0 + count = 0 + result = [] + for d in datas: + if 'C(0)' in d['model_result']: + pattern = regex.compile('C\(.*?\) T\(.*?\)') + result_pattern = list(set(pattern.findall(d['model_result']))) + ground_col_labels = [] + for pa in result_pattern: + pa = pa.split(' ') + if pa[0] != 'C(0)': + index = int(pa[1][2:-1]) + ground_col_labels.append(index) + + ground_col_labels = list(set(ground_col_labels)) + question_arg_type = d['question_arg_type'] + question_arg = d['question_arg'] + table_names = [[lemma(x) for x in names.split(' ')] for names in d['table_names']] + origin_table_names = [[wordnet_lemmatizer.lemmatize(x.lower()) for x in names.split(' ')] for names in + d['table_names']] + count += 1 + easy_flag = False + for q_ind, q in enumerate(d['question_arg']): + q = [lemma(x) for x in q] + q_str = " ".join(" ".join(x) for x in d['question_arg']) + if 'how many' in q_str or 'number of' in q_str 
or 'count of' in q_str: + easy_flag = True + if easy_flag: + # check for the last one is a table word + for q_ind, q in enumerate(d['question_arg']): + if (q_ind > 0 and q == ['many'] and d['question_arg'][q_ind - 1] == ['how']) or ( + q_ind > 0 and q == ['of'] and d['question_arg'][q_ind - 1] == ['number']) or ( + q_ind > 0 and q == ['of'] and d['question_arg'][q_ind - 1] == ['count']): + re = multi_equal(question_arg_type, q_ind, ['table'], 2) + if re is not False: + # This step work for the number of [table] example + table_result = table_names[origin_table_names.index(question_arg[re])] + result.append((d['query'], d['question'], table_result, d)) + break + else: + re = multi_option(question_arg, q_ind, d['table_names'], 2) + if re is not False: + table_result = re + result.append((d['query'], d['question'], table_result, d)) + pass + else: + re = multi_equal(question_arg_type, q_ind, ['table'], len(question_arg_type)) + if re is not False: + # This step work for the number of [table] example + table_result = table_names[origin_table_names.index(question_arg[re])] + result.append((d['query'], d['question'], table_result, d)) + break + pass + table_result = random_choice(question_arg=question_arg, + question_arg_type=question_arg_type, + names=table_names, + ground_col_labels=ground_col_labels, q_ind=q_ind, N=2, + origin_name=origin_table_names) + result.append((d['query'], d['question'], table_result, d)) + + zero_count += 1 + break + + else: + M_OP = False + for q_ind, q in enumerate(d['question_arg']): + if M_OP is False and q in [['than'], ['least'], ['most'], ['msot'], ['fewest']] or \ + question_arg_type[q_ind] == ['M_OP']: + M_OP = True + re = multi_equal(question_arg_type, q_ind, ['table'], 3) + if re is not False: + # This step work for the number of [table] example + table_result = table_names[origin_table_names.index(question_arg[re])] + result.append((d['query'], d['question'], table_result, d)) + break + else: + re = multi_option(question_arg, q_ind, d['table_names'], 3) + if re is not False: + table_result = re + # print(table_result) + result.append((d['query'], d['question'], table_result, d)) + pass + else: + # zero_count += 1 + re = multi_equal(question_arg_type, q_ind, ['table'], len(question_arg_type)) + if re is not False: + # This step work for the number of [table] example + table_result = table_names[origin_table_names.index(question_arg[re])] + result.append((d['query'], d['question'], table_result, d)) + break + + table_result = random_choice(question_arg=question_arg, + question_arg_type=question_arg_type, + names=table_names, + ground_col_labels=ground_col_labels, q_ind=q_ind, N=2, + origin_name=origin_table_names) + result.append((d['query'], d['question'], table_result, d)) + + pass + if M_OP is False: + table_result = random_choice(question_arg=question_arg, + question_arg_type=question_arg_type, + names=table_names, ground_col_labels=ground_col_labels, q_ind=q_ind, + N=2, + origin_name=origin_table_names) + result.append((d['query'], d['question'], table_result, d)) + + for re in result: + table_names = [[lemma(x) for x in names.split(' ')] for names in re[3]['table_names']] + origin_table_names = [[x for x in names.split(' ')] for names in re[3]['table_names']] + if re[2] in table_names: + re[3]['rule_count'] = table_names.index(re[2]) + else: + re[3]['rule_count'] = origin_table_names.index(re[2]) + + for data in datas: + if 'rule_count' in data: + str_replace = 'C(0) T(' + str(data['rule_count']) + ')' + replace_result = regex.sub('C\(0\) T\(.\)', 
str_replace, data['model_result']) + data['model_result_replace'] = replace_result + else: + data['model_result_replace'] = data['model_result'] + + diff --git a/src/utils.py b/src/utils.py new file mode 100644 index 0000000..c6b6da1 --- /dev/null +++ b/src/utils.py @@ -0,0 +1,348 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +# -*- coding: utf-8 -*- +""" +# @Time : 2019/5/25 +# @Author : Jiaqi&Zecheng +# @File : utils.py +# @Software: PyCharm +""" + +import json +import time + +import copy +import numpy as np +import os +import torch +from nltk.stem import WordNetLemmatizer + +from src.dataset import Example +from src.rule import lf +from src.rule.semQL import Sup, Sel, Order, Root, Filter, A, N, C, T, Root1 + +wordnet_lemmatizer = WordNetLemmatizer() + + +def load_word_emb(file_name, use_small=False): + print ('Loading word embedding from %s'%file_name) + ret = {} + with open(file_name) as inf: + for idx, line in enumerate(inf): + if (use_small and idx >= 500000): + break + info = line.strip().split(' ') + if info[0].lower() not in ret: + ret[info[0]] = np.array(list(map(lambda x:float(x), info[1:]))) + return ret + +def lower_keys(x): + if isinstance(x, list): + return [lower_keys(v) for v in x] + elif isinstance(x, dict): + return dict((k.lower(), lower_keys(v)) for k, v in x.items()) + else: + return x + +def get_table_colNames(tab_ids, tab_cols): + table_col_dict = {} + for ci, cv in zip(tab_ids, tab_cols): + if ci != -1: + table_col_dict[ci] = table_col_dict.get(ci, []) + cv + result = [] + for ci in range(len(table_col_dict)): + result.append(table_col_dict[ci]) + return result + +def get_col_table_dict(tab_cols, tab_ids, sql): + table_dict = {} + for c_id, c_v in enumerate(sql['col_set']): + for cor_id, cor_val in enumerate(tab_cols): + if c_v == cor_val: + table_dict[tab_ids[cor_id]] = table_dict.get(tab_ids[cor_id], []) + [c_id] + + col_table_dict = {} + for key_item, value_item in table_dict.items(): + for value in value_item: + col_table_dict[value] = col_table_dict.get(value, []) + [key_item] + col_table_dict[0] = [x for x in range(len(table_dict) - 1)] + return col_table_dict + + +def schema_linking(question_arg, question_arg_type, one_hot_type, col_set_type, col_set_iter, sql): + + for count_q, t_q in enumerate(question_arg_type): + t = t_q[0] + if t == 'NONE': + continue + elif t == 'table': + one_hot_type[count_q][0] = 1 + question_arg[count_q] = ['table'] + question_arg[count_q] + elif t == 'col': + one_hot_type[count_q][1] = 1 + try: + col_set_type[col_set_iter.index(question_arg[count_q])][1] = 5 + question_arg[count_q] = ['column'] + question_arg[count_q] + except: + print(col_set_iter, question_arg[count_q]) + raise RuntimeError("not in col set") + elif t == 'agg': + one_hot_type[count_q][2] = 1 + elif t == 'MORE': + one_hot_type[count_q][3] = 1 + elif t == 'MOST': + one_hot_type[count_q][4] = 1 + elif t == 'value': + one_hot_type[count_q][5] = 1 + question_arg[count_q] = ['value'] + question_arg[count_q] + else: + if len(t_q) == 1: + for col_probase in t_q: + if col_probase == 'asd': + continue + try: + col_set_type[sql['col_set'].index(col_probase)][2] = 5 + question_arg[count_q] = ['value'] + question_arg[count_q] + except: + print(sql['col_set'], col_probase) + raise RuntimeError('not in col') + one_hot_type[count_q][5] = 1 + else: + for col_probase in t_q: + if col_probase == 'asd': + continue + col_set_type[sql['col_set'].index(col_probase)][3] += 1 + +def process(sql, table): + + process_dict = {} + + origin_sql = 
sql['question_toks'] + table_names = [[wordnet_lemmatizer.lemmatize(v).lower() for v in x.split(' ')] for x in table['table_names']] + + sql['pre_sql'] = copy.deepcopy(sql) + + tab_cols = [col[1] for col in table['column_names']] + tab_ids = [col[0] for col in table['column_names']] + + col_set_iter = [[wordnet_lemmatizer.lemmatize(v).lower() for v in x.split(' ')] for x in sql['col_set']] + col_iter = [[wordnet_lemmatizer.lemmatize(v).lower() for v in x.split(" ")] for x in tab_cols] + q_iter_small = [wordnet_lemmatizer.lemmatize(x).lower() for x in origin_sql] + question_arg = copy.deepcopy(sql['question_arg']) + question_arg_type = sql['question_arg_type'] + one_hot_type = np.zeros((len(question_arg_type), 6)) + + col_set_type = np.zeros((len(col_set_iter), 4)) + + process_dict['col_set_iter'] = col_set_iter + process_dict['q_iter_small'] = q_iter_small + process_dict['col_set_type'] = col_set_type + process_dict['question_arg'] = question_arg + process_dict['question_arg_type'] = question_arg_type + process_dict['one_hot_type'] = one_hot_type + process_dict['tab_cols'] = tab_cols + process_dict['tab_ids'] = tab_ids + process_dict['col_iter'] = col_iter + process_dict['table_names'] = table_names + + return process_dict + +def is_valid(rule_label, col_table_dict, sql): + try: + lf.build_tree(copy.copy(rule_label)) + except: + print(rule_label) + + flag = False + for r_id, rule in enumerate(rule_label): + if type(rule) == C: + try: + assert rule_label[r_id + 1].id_c in col_table_dict[rule.id_c], print(sql['question']) + except: + flag = True + print(sql['question']) + return flag is False + + +def to_batch_seq(sql_data, table_data, idxes, st, ed, + is_train=True): + """ + + :return: + """ + examples = [] + + for i in range(st, ed): + sql = sql_data[idxes[i]] + table = table_data[sql['db_id']] + + process_dict = process(sql, table) + + for c_id, col_ in enumerate(process_dict['col_set_iter']): + for q_id, ori in enumerate(process_dict['q_iter_small']): + if ori in col_: + process_dict['col_set_type'][c_id][0] += 1 + + schema_linking(process_dict['question_arg'], process_dict['question_arg_type'], + process_dict['one_hot_type'], process_dict['col_set_type'], process_dict['col_set_iter'], sql) + + col_table_dict = get_col_table_dict(process_dict['tab_cols'], process_dict['tab_ids'], sql) + table_col_name = get_table_colNames(process_dict['tab_ids'], process_dict['col_iter']) + + process_dict['col_set_iter'][0] = ['count', 'number', 'many'] + + rule_label = None + if 'rule_label' in sql: + rule_label = [eval(x) for x in sql['rule_label'].strip().split(' ')] + if is_valid(rule_label, col_table_dict=col_table_dict, sql=sql) is False: + continue + + example = Example( + src_sent=process_dict['question_arg'], + col_num=len(process_dict['col_set_iter']), + vis_seq=(sql['question'], process_dict['col_set_iter'], sql['query']), + tab_cols=process_dict['col_set_iter'], + sql=sql['query'], + one_hot_type=process_dict['one_hot_type'], + col_hot_type=process_dict['col_set_type'], + table_names=process_dict['table_names'], + table_len=len(process_dict['table_names']), + col_table_dict=col_table_dict, + cols=process_dict['tab_cols'], + table_col_name=table_col_name, + table_col_len=len(table_col_name), + tokenized_src_sent=process_dict['col_set_type'], + tgt_actions=rule_label + ) + example.sql_json = copy.deepcopy(sql) + examples.append(example) + + if is_train: + examples.sort(key=lambda e: -len(e.src_sent)) + return examples + else: + return examples + +def epoch_train(model, optimizer, batch_size, 
sql_data, table_data, + args, epoch=0, loss_epoch_threshold=20, sketch_loss_coefficient=0.2): + model.train() + # shuffe + perm=np.random.permutation(len(sql_data)) + cum_loss = 0.0 + st = 0 + while st < len(sql_data): + ed = st+batch_size if st+batch_size < len(perm) else len(perm) + examples = to_batch_seq(sql_data, table_data, perm, st, ed) + optimizer.zero_grad() + + score = model.forward(examples) + loss_sketch = -score[0] + loss_lf = -score[1] + + loss_sketch = torch.mean(loss_sketch) + loss_lf = torch.mean(loss_lf) + + if epoch > loss_epoch_threshold: + loss = loss_lf + sketch_loss_coefficient * loss_sketch + else: + loss = loss_lf + loss_sketch + + loss.backward() + if args.clip_grad > 0.: + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad) + optimizer.step() + cum_loss += loss.data.cpu().numpy()*(ed - st) + st = ed + return cum_loss / len(sql_data) + +def epoch_acc(model, batch_size, sql_data, table_data, beam_size=3): + model.eval() + perm = list(range(len(sql_data))) + st = 0 + + json_datas = [] + while st < len(sql_data): + ed = st+batch_size if st+batch_size < len(perm) else len(perm) + examples = to_batch_seq(sql_data, table_data, perm, st, ed, + is_train=False) + for example in examples: + results_all = model.parse(example, beam_size=beam_size) + results = results_all[0] + list_preds = [] + try: + + pred = " ".join([str(x) for x in results[0].actions]) + for x in results: + list_preds.append(" ".join(str(x.actions))) + except Exception as e: + # print('Epoch Acc: ', e) + # print(results) + # print(results_all) + pred = "" + + simple_json = example.sql_json['pre_sql'] + + simple_json['sketch_result'] = " ".join(str(x) for x in results_all[1]) + simple_json['model_result'] = pred + + json_datas.append(simple_json) + st = ed + return json_datas + +def eval_acc(preds, sqls): + sketch_correct, best_correct = 0, 0 + for i, (pred, sql) in enumerate(zip(preds, sqls)): + if pred['model_result'] == sql['rule_label']: + best_correct += 1 + print(best_correct / len(preds)) + return best_correct / len(preds) + + +def load_data_new(sql_path, table_data, use_small=False): + sql_data = [] + + print("Loading data from %s" % sql_path) + with open(sql_path) as inf: + data = lower_keys(json.load(inf)) + sql_data += data + + table_data_new = {table['db_id']: table for table in table_data} + + if use_small: + return sql_data[:80], table_data_new + else: + return sql_data, table_data_new + + +def load_dataset(dataset_dir, use_small=False): + print("Loading from datasets...") + + TABLE_PATH = os.path.join(dataset_dir, "tables.json") + TRAIN_PATH = os.path.join(dataset_dir, "train.json") + DEV_PATH = os.path.join(dataset_dir, "dev.json") + with open(TABLE_PATH) as inf: + print("Loading data from %s"%TABLE_PATH) + table_data = json.load(inf) + + train_sql_data, train_table_data = load_data_new(TRAIN_PATH, table_data, use_small=use_small) + val_sql_data, val_table_data = load_data_new(DEV_PATH, table_data, use_small=use_small) + + return train_sql_data, train_table_data, val_sql_data, val_table_data + + +def save_checkpoint(model, checkpoint_name): + torch.save(model.state_dict(), checkpoint_name) + + +def save_args(args, path): + with open(path, 'w') as f: + f.write(json.dumps(vars(args), indent=4)) + +def init_log_checkpoint_path(args): + save_path = args.save + dir_name = save_path + str(int(time.time())) + save_path = os.path.join(os.path.curdir, 'saved_model', dir_name) + if os.path.exists(save_path) is False: + os.makedirs(save_path) + return save_path \ No newline 
at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000..1a3548c --- /dev/null +++ b/train.py @@ -0,0 +1,117 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +# -*- coding: utf-8 -*- +""" +# @Time : 2019/5/25 +# @Author : Jiaqi&Zecheng +# @File : train.py +# @Software: PyCharm +""" + +import time +import traceback + +import os +import torch +import torch.optim as optim +import tqdm +import copy + +from src import args as arg +from src import utils +from src.models.model import IRNet +from src.rule import semQL + + +def train(args): + """ + :param args: + :return: + """ + + grammar = semQL.Grammar() + sql_data, table_data, val_sql_data,\ + val_table_data= utils.load_dataset(args.dataset, use_small=args.toy) + + model = IRNet(args, grammar) + + + if args.cuda: model.cuda() + + # now get the optimizer + optimizer_cls = eval('torch.optim.%s' % args.optimizer) + optimizer = optimizer_cls(model.parameters(), lr=args.lr) + print('Enable Learning Rate Scheduler: ', args.lr_scheduler) + if args.lr_scheduler: + scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[21, 41], gamma=args.lr_scheduler_gammar) + else: + scheduler = None + + print('Loss epoch threshold: %d' % args.loss_epoch_threshold) + print('Sketch loss coefficient: %f' % args.sketch_loss_coefficient) + + if args.load_model: + print('load pretrained model from %s'% (args.load_model)) + pretrained_model = torch.load(args.load_model, + map_location=lambda storage, loc: storage) + pretrained_modeled = copy.deepcopy(pretrained_model) + for k in pretrained_model.keys(): + if k not in model.state_dict().keys(): + del pretrained_modeled[k] + + model.load_state_dict(pretrained_modeled) + + model.word_emb = utils.load_word_emb(args.glove_embed_path) + # begin train + + model_save_path = utils.init_log_checkpoint_path(args) + utils.save_args(args, os.path.join(model_save_path, 'config.json')) + best_dev_acc = .0 + + try: + with open(os.path.join(model_save_path, 'epoch.log'), 'w') as epoch_fd: + for epoch in tqdm.tqdm(range(args.epoch)): + if args.lr_scheduler: + scheduler.step() + epoch_begin = time.time() + loss = utils.epoch_train(model, optimizer, args.batch_size, sql_data, table_data, args, + loss_epoch_threshold=args.loss_epoch_threshold, + sketch_loss_coefficient=args.sketch_loss_coefficient) + epoch_end = time.time() + json_datas = utils.epoch_acc(model, args.batch_size, val_sql_data, val_table_data, + beam_size=args.beam_size) + acc = utils.eval_acc(json_datas, val_sql_data) + + if acc > best_dev_acc: + utils.save_checkpoint(model, os.path.join(model_save_path, 'best_model.model')) + best_dev_acc = acc + utils.save_checkpoint(model, os.path.join(model_save_path, '{%s}_{%s}.model') % (epoch, acc)) + + log_str = 'Epoch: %d, Loss: %f, Sketch Acc: %f, Acc: %f, time: %f\n' % ( + epoch + 1, loss, acc, acc, epoch_end - epoch_begin) + tqdm.tqdm.write(log_str) + epoch_fd.write(log_str) + epoch_fd.flush() + except Exception as e: + # Save model + utils.save_checkpoint(model, os.path.join(model_save_path, 'end_model.model')) + print(e) + tb = traceback.format_exc() + print(tb) + else: + utils.save_checkpoint(model, os.path.join(model_save_path, 'end_model.model')) + json_datas = utils.epoch_acc(model, args.batch_size, val_sql_data, val_table_data, + beam_size=args.beam_size) + acc = utils.eval_acc(json_datas, val_sql_data) + + print("Sketch Acc: %f, Acc: %f, Beam Acc: %f" % (acc, acc, acc,)) + + +if 
__name__ == '__main__': + arg_parser = arg.init_arg_parser() + args = arg.init_config(arg_parser) + print(args) + train(args) \ No newline at end of file diff --git a/train.sh b/train.sh new file mode 100644 index 0000000..7f10800 --- /dev/null +++ b/train.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +devices=$1 +save_name=$2 + +CUDA_VISIBLE_DEVICES=$devices nohup python -u train.py --dataset ./data \ +--glove_embed_path ./data/glove.42B.300d.txt \ +--cuda \ +--epoch 50 \ +--loss_epoch_threshold 50 \ +--sketch_loss_coefficient 1.0 \ +--beam_size 1 \ +--seed 90 \ +--save ${save_name} \ +--embed_size 300 \ +--sentence_features \ +--column_pointer \ +--hidden_size 300 \ +--lr_scheduler \ +--lr_scheduler_gammar 0.5 \ +--att_vec_size 300 > ${save_name}".log" &
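
For reference, a minimal sketch of how the grammar classes defined in `src/rule/semQL.py` compose into a tree that the helpers in `src/rule/lf.py` can validate. The natural-language query, the table index `T(1)`, and treating `C(0)` as the `*` column are illustrative assumptions, not values fixed by this patch.

```python
# Illustrative sketch only -- not part of the patch.
# Build the SemQL action sequence for a query like "SELECT count(*) FROM singer"
# and check it with the helpers from src/rule/lf.py.
from src.rule import lf
from src.rule.semQL import Root1, Root, Sel, N, A, C, T

# Root1(3)='Root1 Root', Root(5)='Root Sel', Sel(0)='Sel N',
# N(0)='N A', A(3)='A count C'; C(0)/T(1) are schema indices assumed for this example.
rule_label = [Root1(3), Root(5), Sel(0), N(0), A(3), C(0), T(1)]

# build_tree consumes the list it is given, so pass a shallow copy; it links
# parent/child pointers on the action objects and asserts the tree is well formed.
lf.build_tree(list(rule_label))

# The adjacency matrix is computed over the sketch actions [Root1, Root, Sel, N]
# (A/C/T are filtered out) using the children set by build_tree above.
matrix = lf.build_adjacency_matrix(rule_label, symmetry=True)
print(matrix)
```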