From 56ec3bf5f8f9773728918a2870c738f9c6c4c82a Mon Sep 17 00:00:00 2001 From: berlino Date: Tue, 14 Jul 2020 10:22:33 +0100 Subject: [PATCH] Fix bug of not converting token ids to piece id --- ratsql/models/spider/spider_enc.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/ratsql/models/spider/spider_enc.py b/ratsql/models/spider/spider_enc.py index 9e94c47..3860dae 100644 --- a/ratsql/models/spider/spider_enc.py +++ b/ratsql/models/spider/spider_enc.py @@ -624,6 +624,22 @@ class Bertokens: new_sc_link[m_type] = _match return new_sc_link + def bert_cv_linking(self, schema): + question_tokens = self.normalized_pieces + cv_link = compute_cell_value_linking(question_tokens, schema) + + new_cv_link = {} + for m_type in cv_link: + _match = {} + for ij_str in cv_link[m_type]: + q_id_str, col_tab_id_str = ij_str.split(",") + q_id, col_tab_id = int(q_id_str), int(col_tab_id_str) + real_q_id = self.idx_map[q_id] + _match[f"{real_q_id},{col_tab_id}"] = cv_link[m_type][ij_str] + + new_cv_link[m_type] = _match + return new_cv_link + class SpiderEncoderBertPreproc(SpiderEncoderV2Preproc): @@ -678,7 +694,7 @@ class SpiderEncoderBertPreproc(SpiderEncoderV2Preproc): if self.compute_cv_link: question_bert_tokens = Bertokens(question) - cv_link = compute_cell_value_linking(question_bert_tokens.normalized_pieces, item.schema) + cv_link = question_bert_tokens.bert_cv_linking(item.schema) else: cv_link = {"num_date_match": {}, "cell_match": {}}