This commit is contained in:
haoyanliu 2019-10-17 21:28:27 +08:00
Родитель 813c6b3b1d
Коммит 4f373bb746
12 изменённых файлов: 36 добавлений и 13 удалений

Просмотреть файл

@ -11,7 +11,7 @@ Experimental results on both the original and the re-split Spider dataset show t
#### Enviroment Setup
1. The baseline codes use Python 2.7 and Pytorch 0.2.0 GPU. Install Python dependency: `pip install -r requirements.txt`
Alternatively use docker: `docker push buaa1156/py27torch0.2cuda8vim:latest`
Alternatively use docker: `docker pull buaa1156/py27torch0.2cuda8vim:latest`
2. The preprocess scripts use Python >= 3.5.
@ -46,6 +46,11 @@ Alternatively use docker: `docker push buaa1156/py27torch0.2cuda8vim:latest`
- `DATE`: automatically set as local time while `training` and manually assigned while `testing`
## Question
If you have any question, please go ahead and [open an issue](https://github.com/microsoft/EMNLP2019-Adjective-Knowledge-for-Text-to-SQL/issues).
## Contributing
This project welcomes contributions and suggestions. Most contributions require you to

Просмотреть файл

@ -127,7 +127,7 @@ class CondPredictor(nn.Module):
idx = perm[b]
for i, gt_col in enumerate(chosen_col_gt[b - st]):
dirc_feat = emb_layer.get_direction_feature(max_q_len, idx, gt_col, train)
if self.feats_format == 'mask':
if self.feats_format == 'direct':
# [max_len] (-1/0/1)
mask = (att_prob_qc[b - st, gt_col] * dirc_feat[0])
mask_i = mask.cpu().data.numpy()[0]

Просмотреть файл

@ -164,7 +164,7 @@ class GroupPredictor(nn.Module):
idx = perm[b]
gt_col = chosen_col_gt[b - st]
dirc_feat = emb_layer.get_direction_feature(max_q_len, idx, gt_col, train)
if self.feats_format == 'mask':
if self.feats_format == 'direct':
# [max_len] (-1/0/1)
mask = (att_prob_qc[b - st, gt_col] * dirc_feat[0])
mask_i = mask.cpu().data.numpy()[0]

Просмотреть файл

@ -134,7 +134,7 @@ class OrderPredictor(nn.Module):
idx = perm[b]
gt_col = chosen_col_gt[b - st]
dirc_feat = emb_layer.get_direction_feature(max_q_len, idx, gt_col, train)
if self.feats_format == 'mask':
if self.feats_format == 'direct':
# [max_len] (-1/0/1)
mask = (att_prob_qc[b - st, gt_col] * dirc_feat[0])
mask_i = mask.cpu().data.numpy()[0]

Просмотреть файл

@ -53,7 +53,7 @@ class WordEmbedding(nn.Module):
return [dirc_feats, pos_feats, neg_feats]
def get_direction_feature(self, max_q_len, idx, gt_col, train=True):
if self.feats_format == 'mask':
if self.feats_format == 'direct':
return self.get_direction_feature_mask(max_q_len, idx, gt_col, train)
dirc_feats = np.zeros([max_q_len, len(self.pos_feats)], dtype=np.float32)

Просмотреть файл

@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
'''
python3 preprocess_direction_features.py syntaxSQL|SQLNet singletable|resplitdata weighted|direct
'''
@ -34,6 +37,9 @@ if baseline == 'syntaxSQL':
elif baseline == 'SQLNet':
data_root = r'SQLNet/data/{}/feats'.format(data_type)
if not os.path.exists(data_root):
os.makedirs(data_root)
shutil.copy(os.path.join(know_root, 'pos_feat.json'), os.path.join(data_root, 'pos_feat.json'))
shutil.copy(os.path.join(know_root, 'neg_feat.json'), os.path.join(data_root, 'neg_feat.json'))

Просмотреть файл

@ -42,8 +42,8 @@ if not os.path.exists(train_tgt_path):
train_data = json.load(open(train_data_path))
history_option = "full"
if len(sys.argv) > 2:
history_option = sys.argv[2]
if len(sys.argv) > 3:
history_option = sys.argv[3]
OLD_WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
NEW_WHERE_OPS = ('=','>','<','>=','<=','!=','like','not in','in','between','is')

Просмотреть файл

@ -82,7 +82,7 @@ class DesAscLimitPredictor(nn.Module):
hs_weighted = (hs_enc * att_prob_hc.unsqueeze(2)).sum(1)
# dat_score: (B, 4)
if self.feats_format == 'mask':
if self.feats_format == 'direct':
# [B, max_len] (-1/0/1)
masks = (att_prob_qc * dirc_feats[0]).sum(1)

Просмотреть файл

@ -122,7 +122,7 @@ class OpPredictor(nn.Module):
# Compute prediction scores
# op_score: (B, 10)
if self.feats_format == 'mask':
if self.feats_format == 'direct':
# [B, max_len] (-1/0/1)
masks = (att_prob_qc * dirc_feats[0]).sum(1)

Просмотреть файл

@ -19,7 +19,6 @@ from models.multisql_predictor import MultiSqlPredictor
from models.root_teminal_predictor import RootTeminalPredictor
from models.andor_predictor import AndOrPredictor
from models.op_predictor import OpPredictor
from preprocess_train_dev_data import index_to_column_name
SQL_OPS = ('none','intersect', 'union', 'except')
@ -53,6 +52,12 @@ class Stack:
return self.items.insert(i,x)
def index_to_column_name(index, table):
column_name = table["column_names"][index][1]
table_index = table["column_names"][index][0]
table_name = table["table_names"][table_index]
return table_name, column_name, index
def to_batch_tables(tables, B, table_type):
# col_lens = []
col_seq = []

Просмотреть файл

@ -4,7 +4,14 @@ import json
import numpy as np
import os
import signal
from preprocess_train_dev_data import get_table_dict
def get_table_dict(table_data_path):
data = json.load(open(table_data_path))
table = dict()
for item in data:
table[item["db_id"]] = item
return table
def load_train_dev_dataset(component,train_dev,history, root):

Просмотреть файл

@ -85,7 +85,7 @@ class WordEmbedding(nn.Module):
return [dirc_feats, pos_feats, neg_feats]
def get_direction_feature(self, q_seq, perm, st, ed, gt_col, component=None, train=True):
if self.feats_format == 'mask':
if self.feats_format == 'direct':
return self.get_direction_feature_mask(q_seq, perm, st, ed, gt_col, component, train)
B = len(q_seq)
@ -116,7 +116,7 @@ class WordEmbedding(nn.Module):
return dirc_feats
def get_direction_feature_pred(self, B, q_seq, idx, gt_col):
if self.feats_format == 'mask':
if self.feats_format == 'direct':
return self.get_direction_feature_pred_mask(B, q_seq, idx, gt_col)
dirc_feats = np.zeros([B, len(q_seq[0]) + 2, len(self.pos_feats)], dtype=np.float32)