init commit
This commit is contained in:
Родитель
813c6b3b1d
Коммит
4f373bb746
|
@ -11,7 +11,7 @@ Experimental results on both the original and the re-split Spider dataset show t
|
|||
#### Enviroment Setup
|
||||
|
||||
1. The baseline codes use Python 2.7 and Pytorch 0.2.0 GPU. Install Python dependency: `pip install -r requirements.txt`
|
||||
Alternatively use docker: `docker push buaa1156/py27torch0.2cuda8vim:latest`
|
||||
Alternatively use docker: `docker pull buaa1156/py27torch0.2cuda8vim:latest`
|
||||
2. The preprocess scripts use Python >= 3.5.
|
||||
|
||||
|
||||
|
@ -46,6 +46,11 @@ Alternatively use docker: `docker push buaa1156/py27torch0.2cuda8vim:latest`
|
|||
- `DATE`: automatically set as local time while `training` and manually assigned while `testing`
|
||||
|
||||
|
||||
## Question
|
||||
|
||||
If you have any question, please go ahead and [open an issue](https://github.com/microsoft/EMNLP2019-Adjective-Knowledge-for-Text-to-SQL/issues).
|
||||
|
||||
|
||||
## Contributing
|
||||
|
||||
This project welcomes contributions and suggestions. Most contributions require you to
|
||||
|
|
|
@ -127,7 +127,7 @@ class CondPredictor(nn.Module):
|
|||
idx = perm[b]
|
||||
for i, gt_col in enumerate(chosen_col_gt[b - st]):
|
||||
dirc_feat = emb_layer.get_direction_feature(max_q_len, idx, gt_col, train)
|
||||
if self.feats_format == 'mask':
|
||||
if self.feats_format == 'direct':
|
||||
# [max_len] (-1/0/1)
|
||||
mask = (att_prob_qc[b - st, gt_col] * dirc_feat[0])
|
||||
mask_i = mask.cpu().data.numpy()[0]
|
||||
|
|
|
@ -164,7 +164,7 @@ class GroupPredictor(nn.Module):
|
|||
idx = perm[b]
|
||||
gt_col = chosen_col_gt[b - st]
|
||||
dirc_feat = emb_layer.get_direction_feature(max_q_len, idx, gt_col, train)
|
||||
if self.feats_format == 'mask':
|
||||
if self.feats_format == 'direct':
|
||||
# [max_len] (-1/0/1)
|
||||
mask = (att_prob_qc[b - st, gt_col] * dirc_feat[0])
|
||||
mask_i = mask.cpu().data.numpy()[0]
|
||||
|
|
|
@ -134,7 +134,7 @@ class OrderPredictor(nn.Module):
|
|||
idx = perm[b]
|
||||
gt_col = chosen_col_gt[b - st]
|
||||
dirc_feat = emb_layer.get_direction_feature(max_q_len, idx, gt_col, train)
|
||||
if self.feats_format == 'mask':
|
||||
if self.feats_format == 'direct':
|
||||
# [max_len] (-1/0/1)
|
||||
mask = (att_prob_qc[b - st, gt_col] * dirc_feat[0])
|
||||
mask_i = mask.cpu().data.numpy()[0]
|
||||
|
|
|
@ -53,7 +53,7 @@ class WordEmbedding(nn.Module):
|
|||
return [dirc_feats, pos_feats, neg_feats]
|
||||
|
||||
def get_direction_feature(self, max_q_len, idx, gt_col, train=True):
|
||||
if self.feats_format == 'mask':
|
||||
if self.feats_format == 'direct':
|
||||
return self.get_direction_feature_mask(max_q_len, idx, gt_col, train)
|
||||
|
||||
dirc_feats = np.zeros([max_q_len, len(self.pos_feats)], dtype=np.float32)
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
'''
|
||||
python3 preprocess_direction_features.py syntaxSQL|SQLNet singletable|resplitdata weighted|direct
|
||||
'''
|
||||
|
@ -34,6 +37,9 @@ if baseline == 'syntaxSQL':
|
|||
elif baseline == 'SQLNet':
|
||||
data_root = r'SQLNet/data/{}/feats'.format(data_type)
|
||||
|
||||
if not os.path.exists(data_root):
|
||||
os.makedirs(data_root)
|
||||
|
||||
shutil.copy(os.path.join(know_root, 'pos_feat.json'), os.path.join(data_root, 'pos_feat.json'))
|
||||
shutil.copy(os.path.join(know_root, 'neg_feat.json'), os.path.join(data_root, 'neg_feat.json'))
|
||||
|
||||
|
|
|
@ -42,8 +42,8 @@ if not os.path.exists(train_tgt_path):
|
|||
|
||||
train_data = json.load(open(train_data_path))
|
||||
history_option = "full"
|
||||
if len(sys.argv) > 2:
|
||||
history_option = sys.argv[2]
|
||||
if len(sys.argv) > 3:
|
||||
history_option = sys.argv[3]
|
||||
|
||||
OLD_WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
|
||||
NEW_WHERE_OPS = ('=','>','<','>=','<=','!=','like','not in','in','between','is')
|
||||
|
|
|
@ -82,7 +82,7 @@ class DesAscLimitPredictor(nn.Module):
|
|||
hs_weighted = (hs_enc * att_prob_hc.unsqueeze(2)).sum(1)
|
||||
# dat_score: (B, 4)
|
||||
|
||||
if self.feats_format == 'mask':
|
||||
if self.feats_format == 'direct':
|
||||
# [B, max_len] (-1/0/1)
|
||||
masks = (att_prob_qc * dirc_feats[0]).sum(1)
|
||||
|
||||
|
|
|
@ -122,7 +122,7 @@ class OpPredictor(nn.Module):
|
|||
# Compute prediction scores
|
||||
# op_score: (B, 10)
|
||||
|
||||
if self.feats_format == 'mask':
|
||||
if self.feats_format == 'direct':
|
||||
# [B, max_len] (-1/0/1)
|
||||
masks = (att_prob_qc * dirc_feats[0]).sum(1)
|
||||
|
||||
|
|
|
@ -19,7 +19,6 @@ from models.multisql_predictor import MultiSqlPredictor
|
|||
from models.root_teminal_predictor import RootTeminalPredictor
|
||||
from models.andor_predictor import AndOrPredictor
|
||||
from models.op_predictor import OpPredictor
|
||||
from preprocess_train_dev_data import index_to_column_name
|
||||
|
||||
|
||||
SQL_OPS = ('none','intersect', 'union', 'except')
|
||||
|
@ -53,6 +52,12 @@ class Stack:
|
|||
return self.items.insert(i,x)
|
||||
|
||||
|
||||
def index_to_column_name(index, table):
|
||||
column_name = table["column_names"][index][1]
|
||||
table_index = table["column_names"][index][0]
|
||||
table_name = table["table_names"][table_index]
|
||||
return table_name, column_name, index
|
||||
|
||||
def to_batch_tables(tables, B, table_type):
|
||||
# col_lens = []
|
||||
col_seq = []
|
||||
|
|
|
@ -4,7 +4,14 @@ import json
|
|||
import numpy as np
|
||||
import os
|
||||
import signal
|
||||
from preprocess_train_dev_data import get_table_dict
|
||||
|
||||
|
||||
def get_table_dict(table_data_path):
|
||||
data = json.load(open(table_data_path))
|
||||
table = dict()
|
||||
for item in data:
|
||||
table[item["db_id"]] = item
|
||||
return table
|
||||
|
||||
|
||||
def load_train_dev_dataset(component,train_dev,history, root):
|
||||
|
|
|
@ -85,7 +85,7 @@ class WordEmbedding(nn.Module):
|
|||
return [dirc_feats, pos_feats, neg_feats]
|
||||
|
||||
def get_direction_feature(self, q_seq, perm, st, ed, gt_col, component=None, train=True):
|
||||
if self.feats_format == 'mask':
|
||||
if self.feats_format == 'direct':
|
||||
return self.get_direction_feature_mask(q_seq, perm, st, ed, gt_col, component, train)
|
||||
|
||||
B = len(q_seq)
|
||||
|
@ -116,7 +116,7 @@ class WordEmbedding(nn.Module):
|
|||
return dirc_feats
|
||||
|
||||
def get_direction_feature_pred(self, B, q_seq, idx, gt_col):
|
||||
if self.feats_format == 'mask':
|
||||
if self.feats_format == 'direct':
|
||||
return self.get_direction_feature_pred_mask(B, q_seq, idx, gt_col)
|
||||
|
||||
dirc_feats = np.zeros([B, len(q_seq[0]) + 2, len(self.pos_feats)], dtype=np.float32)
|
||||
|
|
Загрузка…
Ссылка в новой задаче