# TUTA
TUTA is a unified pretrained model for understanding generally structured tables. TUTA introduces two mechanisms to utilize structural information: (1) explicit and implicit positional encoding based on the bi-tree structure; (2) structure-aware attention to aggregate neighboring contexts.
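As a rough illustration of the bi-tree idea, each cell can be located by its path in the top (column-header) tree and its path in the left (row-header) tree, and a distance between such paths is the kind of signal structure-aware attention can use to decide which cells count as neighboring contexts. The sketch below is illustrative only; the names and the distance rule are assumptions, not the code in the ./model/ directory.

```python
# Illustrative sketch only -- hypothetical names, not the repository's implementation.
from dataclasses import dataclass
from typing import List

@dataclass
class CellPosition:
    top_coords: List[int]   # path from the root of the top (column-header) tree
    left_coords: List[int]  # path from the root of the left (row-header) tree

def tree_distance(a: List[int], b: List[int]) -> int:
    """Steps from each node up to the lowest common ancestor of the two paths."""
    common = 0
    for x, y in zip(a, b):
        if x != y:
            break
        common += 1
    return (len(a) - common) + (len(b) - common)

# Cells under the same header subtree are closer than cells under different subtrees,
# so an attention mask can keep only pairs within a chosen tree distance.
c1 = CellPosition(top_coords=[1, 0], left_coords=[2])
c2 = CellPosition(top_coords=[1, 1], left_coords=[2])
print(tree_distance(c1.top_coords, c2.top_coords))  # -> 2
```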
## Models
We provide three variants of pre-trained TUTA models: TUTA (-implicit), TUTA-explicit, and TUTA-base. These pre-trained TUTA variants can be downloaded from:
## Training
To run pretraining tasks, simply run
```bash
python train.py \
    --dataset_paths="./dataset.pt" \
    --pretrained_model_path="${tuta_model_dir}/tuta.bin" \
    --output_model_path="${tuta_model_dir}/trained-tuta.bin"

# to enable a quick test, one can run
python train.py --batch_size 1 --chunk_size 10 --buffer_size 10 --report_steps 1 --total_steps 20

# to enable multi-gpu distributed training, additionally specify
--world_size 4 --gpu_ranks 0 1 2 3
```
Do make sure that the number of input `dataset_paths` is no less than the `world_size` (i.e. the number of `gpu_ranks`).
One can find more adjustable arguments in the main procedure.
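The constraint exists because distributed training hands each GPU rank its own dataset file. The sketch below shows this assumed one-shard-per-rank assignment; the shard names and the round-robin rule are illustrative, not the exact logic in train.py.

```python
# Minimal sketch of why len(dataset_paths) must be >= world_size -- assumed behavior,
# not the actual assignment code in train.py; the shard names are hypothetical.
dataset_paths = ["./dataset-0.pt", "./dataset-1.pt", "./dataset-2.pt", "./dataset-3.pt"]
world_size = 4  # equals the number of gpu_ranks

assert len(dataset_paths) >= world_size, "need at least one dataset file per GPU rank"

for rank in range(world_size):
    shard = dataset_paths[rank % len(dataset_paths)]  # one shard per rank (round-robin)
    print(f"rank {rank} reads {shard}")
```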
## Downstream tasks
### Cell Type Classification (CTC)
To perform the downstream task of cell type classification:
- for data processing, use `SheetReader` in reader.py and `CtcTokenizer` in tokenizer.py;
- for fine-tuning, use the `CtcHead` and `TUTA(base)forCTC` in the ./model/ directory (a toy illustration of the task output follows below).
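As a toy illustration of what this task produces (not repository code), each cell of a small table is mapped to a type label; the rule-based function below merely stands in for the predictions a fine-tuned TUTA with `CtcHead` would make, and the label set shown is just an example.

```python
# Toy illustration of cell type classification output -- not repository code.
table = [
    ["Year", "Revenue", "Growth"],
    ["2019", "1.2M",    "5%"],
    ["2020", "1.4M",    "17%"],
]

def naive_cell_type(row_idx: int, value: str) -> str:
    """A trivial rule-based stand-in for per-cell predictions; example labels only."""
    if row_idx == 0:
        return "header"
    return "derived" if value.endswith("%") else "data"

for r, row in enumerate(table):
    print([(cell, naive_cell_type(r, cell)) for cell in row])
```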
### Table Type Classification (TTC)
To perform the downstream task of table type classification:
- for data processing, use `SheetReader` in reader.py and `TtcTokenizer` in tokenizer.py;
- for fine-tuning, use the `TtcHead` and `TUTA(base)forTTC` in the ./model/ directory.
For an end-to-end trial, run:
```bash
python ctc_finetune.py \
    --folds_path="${dataset_dir}/folds_deex5.json" \
    --data_file="${dataset_dir}/deex.json" \
    --pretrained_model_path="${tuta_model_dir}/tuta.bin" \
    --output_model_path="${tuta_model_dir}/tuta-ctc.bin" \
    --target="tuta" \
    --device_id=0 \
    --batch_size=2 \
    --max_seq_len=512 \
    --max_cell_num=256 \
    --epochs_num=40 \
    --attention_distance=2
```
A preprocessed dataset of DeEx can be downloaded from:
## Data Pre-processing
For a sample raw table file input, run one of the following commands:

```bash
# for SpreadSheet
python prepare.py \
    --input_dir ./data/pretrain/spreadsheet \
    --source_type sheet \
    --output_path ./dataset.pt

# for WikiTable
python prepare.py \
    --input_path ./data/pretrain/wiki-table-samples.json \
    --source_type wiki \
    --output_path ./dataset.pt

# for WDCTable
python prepare.py \
    --input_dir ./data/pretrain/wdc \
    --source_type wdc \
    --output_path ./dataset.pt
```

This will generate a semi-processed version for pre-training inputs.
Pass this data file as an argument to the pre-training script; the data loader will then dynamically process it for the three pre-training objectives, namely Masked Language Model (MLM), Cell-Level Cloze (CLC), and Table Context Retrieval (TCR).
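To give a sense of what the dynamic processing does for the MLM objective, the sketch below applies BERT-style random masking to one cell's token ids on the fly; the mask id, masking rate, and function are illustrative assumptions, not the actual logic in dynamic_data.py.

```python
# Illustrative sketch of on-the-fly MLM masking -- assumptions, not dynamic_data.py.
import random

MASK_ID = 1          # hypothetical [MASK] token id
MASK_PROB = 0.15     # BERT-style masking rate (assumed value)

def mask_tokens(token_ids):
    """Return (masked_input, labels); labels are -1 for positions the loss ignores."""
    masked, labels = [], []
    for tid in token_ids:
        if random.random() < MASK_PROB:
            masked.append(MASK_ID)
            labels.append(tid)      # predict the original token here
        else:
            masked.append(tid)
            labels.append(-1)       # not masked, excluded from the loss
    return masked, labels

cell_tokens = [101, 2054, 2003, 102]   # toy token ids for one cell
print(mask_tokens(cell_tokens))
```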