Fixed the issue where token IDs were not converted to word-piece IDs for BERT value linking. Closes #4.
Parent: 4a013a9073
Commit: 648fc87f25
@@ -21,6 +21,7 @@ If you use RAT-SQL in your work, please cite it as follows:
 **2020-08-14:**
 - The Docker image now inherits from a CUDA-enabled base image.
 - Clarified memory and dataset requirements on the image.
+- Fixed the issue where token IDs were not converted to word-piece IDs for BERT value linking.
 
 ## Usage
 
@@ -150,7 +150,7 @@ class SpiderDataset(torch.utils.data.Dataset):
         for db_id, schema in tqdm(self.schemas.items(), desc="DB connections"):
             sqlite_path = Path(db_path) / db_id / f"{db_id}.sqlite"
             source: sqlite3.Connection
-            with sqlite3.connect(sqlite_path) as source:
+            with sqlite3.connect(str(sqlite_path)) as source:
                 dest = sqlite3.connect(':memory:')
                 dest.row_factory = sqlite3.Row
                 source.backup(dest)
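The `str()` wrapping is presumably for compatibility across Python versions: `sqlite3.connect()` only accepts path-like objects such as `pathlib.Path` since Python 3.7, and older interpreters require a plain string. A minimal standalone sketch (the file name is illustrative):

```python
import sqlite3
from pathlib import Path

db_file = Path("example.sqlite")  # illustrative; sqlite3 creates it on connect

# sqlite3.connect() accepts os.PathLike only on Python >= 3.7; converting to
# str keeps the call working on older interpreters as well.
conn = sqlite3.connect(str(db_file))
conn.close()
```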
@@ -546,6 +546,7 @@ class Bertokens:
         self.pieces = pieces
 
         self.normalized_pieces = None
+        self.recovered_pieces = None
         self.idx_map = None
 
         self.normalize_toks()
@@ -605,6 +606,7 @@ class Bertokens:
             normalized_toks.append(lemma_word)
 
         self.normalized_pieces = normalized_toks
+        self.recovered_pieces = new_toks
 
     def bert_schema_linking(self, columns, tables):
         question_tokens = self.normalized_pieces
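For orientation, `recovered_pieces` holds the BERT word pieces merged back into whole words, while `normalized_pieces` are additionally lemmatized (hence `lemma_word` above). A toy reconstruction of that merging and of the `idx_map` bookkeeping, inferred from the diff rather than copied from the repository:

```python
# Toy sketch: merge "##"-continuation pieces into words and remember, for each
# recovered word, the index of its first word piece (assumed idx_map shape).
pieces = ["what", "are", "auck", "##land", "citi", "##es"]

recovered, idx_map = [], {}
for i, piece in enumerate(pieces):
    if piece.startswith("##"):
        recovered[-1] += piece[2:]   # continue the current word
    else:
        idx_map[len(recovered)] = i  # word index -> first-piece index
        recovered.append(piece)

print(recovered)  # ['what', 'are', 'auckland', 'cities']
print(idx_map)    # {0: 0, 1: 1, 2: 2, 3: 4}
```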
@@ -624,6 +626,21 @@ class Bertokens:
             new_sc_link[m_type] = _match
         return new_sc_link
 
+    def bert_cv_linking(self, schema):
+        question_tokens = self.recovered_pieces  # Not using normalized tokens here because values usually match exactly
+        cv_link = compute_cell_value_linking(question_tokens, schema)
+
+        new_cv_link = {}
+        for m_type in cv_link:
+            _match = {}
+            for ij_str in cv_link[m_type]:
+                q_id_str, col_tab_id_str = ij_str.split(",")
+                q_id, col_tab_id = int(q_id_str), int(col_tab_id_str)
+                real_q_id = self.idx_map[q_id]
+                _match[f"{real_q_id},{col_tab_id}"] = cv_link[m_type][ij_str]
+            new_cv_link[m_type] = _match
+        return new_cv_link
+
 
 class SpiderEncoderBertPreproc(SpiderEncoderV2Preproc):
 
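`bert_cv_linking` mirrors the remapping that `bert_schema_linking` already performs: `compute_cell_value_linking` returns matches keyed as `"<question index>,<column or table index>"`, and `idx_map` converts the question index from a whole-word position to a word-piece position, which is exactly the conversion the commit message describes. A standalone sketch of the remapping loop with made-up data (the `"CELLMATCH"` tag and the indices are invented; `idx_map` follows the shape sketched above):

```python
# Made-up match: question word 3 ("cities") matched a cell of column 7.
cv_link = {"cell_match": {"3,7": "CELLMATCH"}, "num_date_match": {}}
idx_map = {0: 0, 1: 1, 2: 2, 3: 4}  # word index -> first word-piece index

new_cv_link = {}
for m_type, matches in cv_link.items():
    remapped = {}
    for key, tag in matches.items():
        q_id, col_tab_id = key.split(",")
        # Re-key the match under the word-piece index of the question token.
        remapped[f"{idx_map[int(q_id)]},{col_tab_id}"] = tag
    new_cv_link[m_type] = remapped

print(new_cv_link)  # {'cell_match': {'4,7': 'CELLMATCH'}, 'num_date_match': {}}
```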
@@ -667,8 +684,8 @@ class SpiderEncoderBertPreproc(SpiderEncoderV2Preproc):
     def preprocess_item(self, item, validation_info):
         question = self._tokenize(item.text, item.orig['question'])
         preproc_schema = self._preprocess_schema(item.schema)
+        question_bert_tokens = Bertokens(question)
         if self.compute_sc_link:
-            question_bert_tokens = Bertokens(question)
             sc_link = question_bert_tokens.bert_schema_linking(
                 preproc_schema.normalized_column_names,
                 preproc_schema.normalized_table_names
@@ -677,8 +694,7 @@ class SpiderEncoderBertPreproc(SpiderEncoderV2Preproc):
             sc_link = {"q_col_match": {}, "q_tab_match": {}}
 
         if self.compute_cv_link:
-            question_bert_tokens = Bertokens(question)
-            cv_link = compute_cell_value_linking(question_bert_tokens.normalized_pieces, item.schema)
+            cv_link = question_bert_tokens.bert_cv_linking(item.schema)
         else:
             cv_link = {"num_date_match": {}, "cell_match": {}}