update table_linearize processor for table processing.
This commit is contained in:
Родитель
fe690930c3
Коммит
c623dc26e8
|
@ -1,28 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
"""
|
||||
Utils for linearizing the table content into a flatten sequence
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
|
||||
|
||||
def concat_input_with_table(_table_content: Dict, input_query: str, start_row_idx: int = 0):
|
||||
compact_str = input_query.lower().strip() + " " + linearize_schema(_table_content, start_row_idx)
|
||||
return compact_str.strip()
|
||||
|
||||
|
||||
def linearize_schema(_table_content: Dict, start_row_idx: int):
|
||||
"""
|
||||
Data format: col: col1 | col2 | col 3 row 1 : val1 | val2 | val3 row 2 : ...
|
||||
"""
|
||||
_table_str = "col : " + " | ".join(_table_content["header"]) + " "
|
||||
_table_str = _table_str.lower()
|
||||
for i, row_example in enumerate(_table_content["rows"]):
|
||||
_table_str += "row " + str(start_row_idx + i + 1) + " : "
|
||||
row_cell_values = [str(cell_value) if isinstance(cell_value, int) else cell_value.lower()
|
||||
for cell_value in row_example]
|
||||
_table_str += " | ".join(row_cell_values) + " "
|
||||
|
||||
return _table_str
|
|
@ -0,0 +1,53 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
"""
|
||||
Utils for linearizing the table content into a flatten sequence
|
||||
"""
|
||||
import abc
|
||||
import abc
|
||||
from typing import Dict
|
||||
|
||||
|
||||
class TableLinearize(abc.ABC):
|
||||
|
||||
PROMPT_MESSAGE = """
|
||||
Please check that your table must follow the following format:
|
||||
{"header": ["col1", "col2", "col3"], "rows": [["row11", "row12", "row13"], ["row21", "row22", "row23"]]}
|
||||
"""
|
||||
|
||||
def __init__(self, lower_case):
|
||||
# if lower case, return the uncased table str; otherwise the cased.
|
||||
self.lower_case = lower_case
|
||||
|
||||
def process_table(self, table_content: Dict):
|
||||
"""
|
||||
Given a table, TableLinearize aims at converting it into a flatten sequence with special symbols.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class IndexedRowTableLinearize(TableLinearize):
|
||||
"""
|
||||
FORMAT: col: col1 | col2 | col 3 row 1 : val1 | val2 | val3 row 2 : ...
|
||||
"""
|
||||
|
||||
def process_table(self, table_content: Dict):
|
||||
"""
|
||||
Given a table, TableLinearize aims at converting it into a flatten sequence with special symbols.
|
||||
"""
|
||||
assert "header" in table_content and "rows" in table_content, self.PROMPT_MESSAGE
|
||||
_table_str = "col : " + " | ".join(table_content["header"]) + " "
|
||||
for i, row_example in enumerate(table_content["rows"]):
|
||||
# start from row 1 not from row 0
|
||||
_table_str += "row " + str(i + 1) + " : "
|
||||
row_cell_values = []
|
||||
for cell_value in row_example:
|
||||
if isinstance(cell_value, int):
|
||||
row_cell_values.append(str(cell_value))
|
||||
elif self.lower_case:
|
||||
row_cell_values.append(cell_value.lower())
|
||||
else:
|
||||
row_cell_values.append(cell_value)
|
||||
_table_str += " | ".join(row_cell_values) + " "
|
||||
return _table_str
|
Загрузка…
Ссылка в новой задаче