TUTA

TUTA is a unified pre-trained model for understanding generally structured tables. TUTA introduces two mechanisms to utilize structural information: (1) explicit and implicit positional encoding based on the bi-tree structure; (2) structure-aware attention to aggregate neighboring contexts.
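
The snippet below is only a conceptual illustration of these two ideas, not the repository's implementation: each cell is addressed by a top-tree and a left-tree coordinate path, and attention is restricted to cell pairs whose combined tree distance stays within a threshold. The coordinate layout and distance function are simplified assumptions.

import torch

def tree_distance(coord_a, coord_b):
    # Steps up to the deepest common ancestor plus steps down to the other node.
    common = 0
    for a, b in zip(coord_a, coord_b):
        if a != b:
            break
        common += 1
    return (len(coord_a) - common) + (len(coord_b) - common)

def structure_attention_mask(top_coords, left_coords, max_distance=2):
    # Allow attention only between cells whose top-tree + left-tree distance
    # is within max_distance (mirrors the --attention_distance idea).
    n = len(top_coords)
    mask = torch.zeros(n, n, dtype=torch.bool)
    for i in range(n):
        for j in range(n):
            d = (tree_distance(top_coords[i], top_coords[j])
                 + tree_distance(left_coords[i], left_coords[j]))
            mask[i, j] = d <= max_distance
    return mask

# Two toy cells sharing a top-header parent but under different left headers.
print(structure_attention_mask([(0, 1), (0, 1)], [(0,), (1,)], max_distance=2))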

Models

We provide three variants of pre-trained TUTA models: TUTA (-implicit), TUTA-explicit, and TUTA-base. These pre-trained TUTA variants can be downloaded from:
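
Once a checkpoint is downloaded, it can be inspected as an ordinary PyTorch weight file. The sketch below assumes the file is a plain state dict saved with torch.save and that the matching model class lives under ./model/; the file name and the non-strict load are assumptions, not documented behavior.

import torch

# Illustrative only: peek at a downloaded TUTA checkpoint before training.
state_dict = torch.load("tuta.bin", map_location="cpu")
print(len(state_dict), "tensors in the checkpoint")
for name, tensor in list(state_dict.items())[:5]:
    print(name, tuple(tensor.shape))

# When a model from ./model/ is constructed, a non-strict load tolerates
# extra or missing fine-tuning heads (an assumption about the format).
# model.load_state_dict(state_dict, strict=False)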

Training

To run the pre-training tasks, simply run

python train.py                                           \
--dataset_paths="./dataset.pt"                              \
--pretrained_model_path="${tuta_model_dir}/tuta.bin"      \
--output_model_path="${tuta_model_dir}/trained-tuta.bin"

# to enable a quick test, one can run
python train.py  --batch_size 1  --chunk_size 10  --buffer_size 10  --report_steps 1  --total_steps 20

# to enable multi-gpu distributed training, additionally specify 
--world_size 4  --gpu_ranks 0 1 2 3

Make sure that the number of input dataset_paths is no less than the world_size (i.e. the number of gpu_ranks).
One can find more adjustable arguments in the main procedure of train.py.
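
One way to satisfy that constraint is to pre-split the processed data into one file per GPU rank before launching distributed training. The sketch below assumes dataset.pt was written with torch.save (suggested by the .pt extension, not a documented guarantee) and holds a list of examples; the shard file names are made up for illustration.

import torch

examples = torch.load("dataset.pt")      # assumed: a list of table examples
world_size = 4
for rank in range(world_size):
    shard = examples[rank::world_size]               # round-robin split
    torch.save(shard, f"dataset_rank{rank}.pt")      # one file per gpu_rank

The resulting shard paths would then be supplied through --dataset_paths so that every rank has at least one data file.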

Downstream tasks

Cell Type Classification (CTC)

To perform the downstream task of cell type classification:

  • for data processing, use SheetReader in reader.py and CtcTokenizer in tokenizer.py;
  • for fine-tuning, use the CtcHead and TUTA(base)forCTC in the ./model/ directory (a simplified sketch of the head follows below).
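
Purely as an illustration of what a cell-level classification head does, and not the actual CtcHead in ./model/, a minimal version maps each cell's encoding to per-cell type logits; the hidden size and number of cell types are placeholder values.

import torch
import torch.nn as nn

class ToyCellTypeHead(nn.Module):
    # Illustrative stand-in: one linear layer over each cell's hidden vector.
    def __init__(self, hidden_size=768, num_cell_types=5):
        super().__init__()
        self.classifier = nn.Linear(hidden_size, num_cell_types)

    def forward(self, cell_encodings):
        # cell_encodings: (batch, num_cells, hidden_size)
        return self.classifier(cell_encodings)  # (batch, num_cells, num_cell_types)

print(ToyCellTypeHead()(torch.randn(2, 256, 768)).shape)  # torch.Size([2, 256, 5])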

Table Type Classification (TTC)

To perform the downstream task of table type classification:

  • for data processing, use SheetReader in reader.py and TtcTokenizer in tokenizer.py;
  • for fine-tuning, use the TtcHead and TUTA(base)forTTC in the ./model/ directory (sketched below).
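
The table-level head differs mainly in that the cell encodings are pooled into a single table vector before classification. Again, this is a simplified stand-in rather than the repository's TtcHead, with placeholder sizes.

import torch
import torch.nn as nn

class ToyTableTypeHead(nn.Module):
    # Illustrative stand-in: mean-pool cell encodings, then classify the table.
    def __init__(self, hidden_size=768, num_table_types=5):
        super().__init__()
        self.classifier = nn.Linear(hidden_size, num_table_types)

    def forward(self, cell_encodings):
        # cell_encodings: (batch, num_cells, hidden_size)
        table_vector = cell_encodings.mean(dim=1)    # (batch, hidden_size)
        return self.classifier(table_vector)         # (batch, num_table_types)

print(ToyTableTypeHead()(torch.randn(2, 256, 768)).shape)  # torch.Size([2, 5])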

For an end-to-end trial, run:

python ctc_finetune.py                                           \
--folds_path="${dataset_dir}/folds_deex5.json"                    \
--data_file="${dataset_dir}/deex.json"                            \
--pretrained_model_path="${tuta_model_dir}/tuta.bin"             \
--output_model_path="${tuta_model_dir}/tuta-ctc.bin"              \
--target="tuta"                                                   \
--device_id=0                                                   \
--batch_size=2                                                   \
--max_seq_len=512                                                 \
--max_cell_num=256                                                 \
--epochs_num=40                                                   \
--attention_distance=2                                             

A preprocessed dataset of DeEx can be downloaded from:

Data Pre-processing

To process a sample raw table file into pre-training input, run one of the following:

# for SpreadSheet
python prepare.py                          \
--input_dir ./data/pretrain/spreadsheet   \
--source_type sheet                        \
--output_path ./dataset.pt

# for WikiTable
python prepare.py                                      \
--input_path ./data/pretrain/wiki-table-samples.json  \
--source_type wiki                                     \
--output_path ./dataset.pt

# for WDCTable
python prepare.py                         \
--input_dir ./data/pretrain/wdc          \
--source_type wdc                         \
--output_path ./dataset.pt

Each command generates a semi-processed data file for the pre-training inputs.
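
If the output file was written with torch.save (an assumption based on the .pt extension rather than documented behavior), a quick sanity check could look like this; the exact structure of the saved object is defined by prepare.py.

import torch

data = torch.load("dataset.pt", map_location="cpu")
print(type(data))
if isinstance(data, (list, tuple)) and data:
    print(len(data), "examples; first example type:", type(data[0]))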

Pass this data file to the pre-training script via --dataset_paths; the data loader will then dynamically build training instances for the three pre-training objectives, namely Masked Language Model (MLM), Cell-Level Cloze (CLC), and Table Context Retrieval (TCR).
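
To make the difference between the first two objectives concrete, the toy function below masks individual tokens for MLM but whole cells for CLC. The per-cell token layout and the mask id are illustrative assumptions, not the logic implemented in dynamic_data.py.

import random

MASK_ID = 1  # hypothetical [MASK] token id

def mask_for_objectives(cells, mlm_prob=0.15, clc_prob=0.1):
    # cells: list of token-id lists, one list per cell (toy layout).
    mlm_cells, clc_cells = [], []
    for tokens in cells:
        # MLM: mask individual tokens inside a cell.
        mlm_cells.append([MASK_ID if random.random() < mlm_prob else t
                          for t in tokens])
        # CLC: mask an entire cell at once.
        if random.random() < clc_prob:
            clc_cells.append([MASK_ID] * len(tokens))
        else:
            clc_cells.append(list(tokens))
    return mlm_cells, clc_cells

print(mask_for_objectives([[5, 6, 7], [8, 9], [10]]))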