Fixed the issue where token IDs were not converted to word-piece IDs for BERT value linking. Closes #4.

This commit is contained in:
Alex Polozov 2020-08-14 19:16:09 -07:00
Родитель 4a013a9073
Коммит 648fc87f25
3 изменённых файлов: 21 добавлений и 4 удалений

Просмотреть файл

@ -21,6 +21,7 @@ If you use RAT-SQL in your work, please cite it as follows:
**2020-08-14:**
- The Docker image now inherits from a CUDA-enabled base image.
- Clarified memory and dataset requirements on the image.
- Fixed the issue where token IDs were not converted to word-piece IDs for BERT value linking.
## Usage

Просмотреть файл

@ -150,7 +150,7 @@ class SpiderDataset(torch.utils.data.Dataset):
for db_id, schema in tqdm(self.schemas.items(), desc="DB connections"):
sqlite_path = Path(db_path) / db_id / f"{db_id}.sqlite"
source: sqlite3.Connection
with sqlite3.connect(sqlite_path) as source:
with sqlite3.connect(str(sqlite_path)) as source:
dest = sqlite3.connect(':memory:')
dest.row_factory = sqlite3.Row
source.backup(dest)

Просмотреть файл

@ -546,6 +546,7 @@ class Bertokens:
self.pieces = pieces
self.normalized_pieces = None
self.recovered_pieces = None
self.idx_map = None
self.normalize_toks()
@ -605,6 +606,7 @@ class Bertokens:
normalized_toks.append(lemma_word)
self.normalized_pieces = normalized_toks
self.recovered_pieces = new_toks
def bert_schema_linking(self, columns, tables):
question_tokens = self.normalized_pieces
@ -624,6 +626,21 @@ class Bertokens:
new_sc_link[m_type] = _match
return new_sc_link
def bert_cv_linking(self, schema):
question_tokens = self.recovered_pieces # Not using normalized tokens here because values usually match exactly
cv_link = compute_cell_value_linking(question_tokens, schema)
new_cv_link = {}
for m_type in cv_link:
_match = {}
for ij_str in cv_link[m_type]:
q_id_str, col_tab_id_str = ij_str.split(",")
q_id, col_tab_id = int(q_id_str), int(col_tab_id_str)
real_q_id = self.idx_map[q_id]
_match[f"{real_q_id},{col_tab_id}"] = cv_link[m_type][ij_str]
new_cv_link[m_type] = _match
return new_cv_link
class SpiderEncoderBertPreproc(SpiderEncoderV2Preproc):
@ -667,8 +684,8 @@ class SpiderEncoderBertPreproc(SpiderEncoderV2Preproc):
def preprocess_item(self, item, validation_info):
question = self._tokenize(item.text, item.orig['question'])
preproc_schema = self._preprocess_schema(item.schema)
question_bert_tokens = Bertokens(question)
if self.compute_sc_link:
question_bert_tokens = Bertokens(question)
sc_link = question_bert_tokens.bert_schema_linking(
preproc_schema.normalized_column_names,
preproc_schema.normalized_table_names
@ -677,8 +694,7 @@ class SpiderEncoderBertPreproc(SpiderEncoderV2Preproc):
sc_link = {"q_col_match": {}, "q_tab_match": {}}
if self.compute_cv_link:
question_bert_tokens = Bertokens(question)
cv_link = compute_cell_value_linking(question_bert_tokens.normalized_pieces, item.schema)
cv_link = question_bert_tokens.bert_cv_linking(item.schema)
else:
cv_link = {"num_date_match": {}, "cell_match": {}}