From 648fc87f25feeb6740679b5a3cb61bbc0465c5c7 Mon Sep 17 00:00:00 2001
From: Alex Polozov
Date: Fri, 14 Aug 2020 19:16:09 -0700
Subject: [PATCH] Fixed the issue where token IDs were not converted to
 word-piece IDs for BERT value linking. Closes #4.

---
 README.md                          |  1 +
 ratsql/datasets/spider.py          |  2 +-
 ratsql/models/spider/spider_enc.py | 22 +++++++++++++++++++---
 3 files changed, 21 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index eacc0d7..52fb281 100644
--- a/README.md
+++ b/README.md
@@ -21,6 +21,7 @@ If you use RAT-SQL in your work, please cite it as follows:
 **2020-08-14:**
 - The Docker image now inherits from a CUDA-enabled base image.
 - Clarified memory and dataset requirements on the image.
+- Fixed the issue where token IDs were not converted to word-piece IDs for BERT value linking.
 
 ## Usage
 
diff --git a/ratsql/datasets/spider.py b/ratsql/datasets/spider.py
index 8ab15e5..41a66b0 100644
--- a/ratsql/datasets/spider.py
+++ b/ratsql/datasets/spider.py
@@ -150,7 +150,7 @@ class SpiderDataset(torch.utils.data.Dataset):
         for db_id, schema in tqdm(self.schemas.items(), desc="DB connections"):
             sqlite_path = Path(db_path) / db_id / f"{db_id}.sqlite"
             source: sqlite3.Connection
-            with sqlite3.connect(sqlite_path) as source:
+            with sqlite3.connect(str(sqlite_path)) as source:
                 dest = sqlite3.connect(':memory:')
                 dest.row_factory = sqlite3.Row
                 source.backup(dest)
diff --git a/ratsql/models/spider/spider_enc.py b/ratsql/models/spider/spider_enc.py
index 9e94c47..fe6b63d 100644
--- a/ratsql/models/spider/spider_enc.py
+++ b/ratsql/models/spider/spider_enc.py
@@ -546,6 +546,7 @@ class Bertokens:
         self.pieces = pieces
 
         self.normalized_pieces = None
+        self.recovered_pieces = None
         self.idx_map = None
 
         self.normalize_toks()
@@ -605,6 +606,7 @@ class Bertokens:
                 normalized_toks.append(lemma_word)
 
         self.normalized_pieces = normalized_toks
+        self.recovered_pieces = new_toks
 
     def bert_schema_linking(self, columns, tables):
         question_tokens = self.normalized_pieces
@@ -624,6 +626,21 @@ class Bertokens:
             new_sc_link[m_type] = _match
         return new_sc_link
 
+    def bert_cv_linking(self, schema):
+        question_tokens = self.recovered_pieces  # Not using normalized tokens here because values usually match exactly
+        cv_link = compute_cell_value_linking(question_tokens, schema)
+
+        new_cv_link = {}
+        for m_type in cv_link:
+            _match = {}
+            for ij_str in cv_link[m_type]:
+                q_id_str, col_tab_id_str = ij_str.split(",")
+                q_id, col_tab_id = int(q_id_str), int(col_tab_id_str)
+                real_q_id = self.idx_map[q_id]
+                _match[f"{real_q_id},{col_tab_id}"] = cv_link[m_type][ij_str]
+            new_cv_link[m_type] = _match
+        return new_cv_link
+
 
 class SpiderEncoderBertPreproc(SpiderEncoderV2Preproc):
 
@@ -667,8 +684,8 @@ class SpiderEncoderBertPreproc(SpiderEncoderV2Preproc):
     def preprocess_item(self, item, validation_info):
         question = self._tokenize(item.text, item.orig['question'])
         preproc_schema = self._preprocess_schema(item.schema)
+        question_bert_tokens = Bertokens(question)
         if self.compute_sc_link:
-            question_bert_tokens = Bertokens(question)
             sc_link = question_bert_tokens.bert_schema_linking(
                 preproc_schema.normalized_column_names,
                 preproc_schema.normalized_table_names
@@ -675,10 +692,9 @@ class SpiderEncoderBertPreproc(SpiderEncoderV2Preproc):
             )
         else:
             sc_link = {"q_col_match": {}, "q_tab_match": {}}
 
         if self.compute_cv_link:
-            question_bert_tokens = Bertokens(question)
-            cv_link = compute_cell_value_linking(question_bert_tokens.normalized_pieces, item.schema)
+            cv_link = question_bert_tokens.bert_cv_linking(item.schema)
         else:
             cv_link = {"num_date_match": {}, "cell_match": {}}
 