update scripts and the base setting of tapex.base.

This commit is contained in:
SivilTaram 2021-08-28 09:41:18 +08:00
Parent 0b87efa253
Commit 4e1adfe68b
5 changed files: 20 additions and 17 deletions

File 1 of 5: tableqa README

@@ -24,16 +24,18 @@ Now we support the **one-stop services** for the following datasets, and you can
Note that the one-stop service includes downloading the datasets and pre-trained TAPEX models, truncating long inputs, converting to the fairseq machine-translation format, applying BPE tokenization, and preprocessing for fairseq model training.
By default, these scripts will process data using the dictionary of `tapex.base`. If you want to switch pre-trained models, please change the variable `MODEL_NAME` at line 21.
Once a dataset is prepared, you can run the `tableqa/run_model.py` script to train your TableQA model on it.
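For example, preparing WikiSQL end to end is a single script run. The script name below is an assumption (one preparation script per dataset in this directory); check the repository for the exact file names:
```shell
# Hypothetical script name -- one data preparation script exists per dataset.
# It downloads the dataset and the pre-trained model, truncates long inputs,
# converts to the fairseq translation format, applies BPE, and binarizes.
$ python process_wikisql_data.py
```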
### 🍳 Train
-To train a model, you could simply run the following command, where `<dataset_dir>/<model_name>` refers to directory which contains a `bin` folder such as `dataset/wikisql/bart.base`, `<model_path>` refers to a pre-trained model path such as `bart.base/model.pt`, `<model_arch>` refers to a pre-defined model architecture in fairseq such as `bart_base`.
+To train a model, you could simply run the following command, where `<dataset_dir>` refers to the directory which contains a `bin` folder such as `dataset/wikisql/tapex.base`, `<model_path>` refers to a pre-trained model path such as `tapex.base/model.pt`, and `<model_arch>` refers to a pre-defined model architecture in fairseq such as `bart_base`.
**HINT**: for `tapex.base` or `tapex.large`, `<model_arch>` should be `bart_base` or `bart_large` respectively.
```shell
-$ python run_model.py train --dataset-dir <dataset_dir>/<model_name> --model-path <model_path> --model-arch <model_arch>
+$ python run_model.py train --dataset-dir <dataset_dir> --model-path <model_path> --model-arch <model_arch>
```
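For instance, with WikiSQL prepared under `dataset/wikisql/tapex.base` and the pre-trained `tapex.base` checkpoint downloaded alongside (illustrative paths), the command becomes:
```shell
$ python run_model.py train --dataset-dir dataset/wikisql/tapex.base --model-path tapex.base/model.pt --model-arch bart_base
```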
A full list of training arguments can be found below:
@@ -63,10 +65,10 @@ A full list of training arguments can be found below:
### 🍪 Evaluate
-Once the model is fine-tuned, we can evaluate it by runing the following command, where `<dataset_dir>/<model_name>` refers to directory which contains a `bin` folder such as `dataset/wikisql/bart.base`, and `<model_path>` refers to a fine-tuned model path such as `checkpoints/checkpoint_best.pt`.
+Once the model is fine-tuned, we can evaluate it by running the following command, where `<dataset_dir>` refers to the directory which contains a `bin` folder such as `dataset/wikisql/tapex.base`, and `<model_path>` refers to a fine-tuned model path such as `checkpoints/checkpoint_best.pt`.
```shell
-$ python run_model.py eval --dataset-dir <dataset_dir>/<model_name> --model-path <model_path>
+$ python run_model.py eval --dataset-dir <dataset_dir> --model-path <model_path>
```
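Continuing the illustrative WikiSQL setup, evaluating the best checkpoint on the default `valid` split looks like:
```shell
$ python run_model.py eval --dataset-dir dataset/wikisql/tapex.base --model-path checkpoints/checkpoint_best.pt
```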
A full list of evaluation arguments can be found below:
@@ -106,6 +108,7 @@ Then you can predict the answer online with the following command, where `<model
```shell
$ python run_model.py predict --resource-dir <resource_dir> --checkpoint-name <model_name>
```
+> Note that if `<resource_dir>` is under the current working directory, you should still add the `./` prefix so that it looks like a local path (e.g., `./tapex.base`); otherwise, fairseq will treat it as a model name.
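For example, with the `tapex.base` resources sitting in the current working directory (note the `./` prefix), a prediction session starts with:
```shell
$ python run_model.py predict --resource-dir ./tapex.base --checkpoint-name model.pt
```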
## 🔎 Table Fact Verification (Released by Sep. 5)

File 2 of 5: SQA data preparation script

@@ -18,7 +18,7 @@ PROCESSED_DATASET_FOLDER = "dataset"
TABLE_PATH = os.path.join(RAW_DATASET_FOLDER, "sqa")
TABLE_PROCESSOR = get_default_processor(max_cell_length=15, max_input_length=1024)
# Options: bart.base, bart.large, tapex.base, tapex.large
-MODEL_NAME = "bart.large"
+MODEL_NAME = "tapex.base"
logger = logging.getLogger(__name__)
@@ -132,5 +132,5 @@ if __name__ == '__main__':
    build_sqa_huggingface_dataset(processed_sqa_data_dir)
    logger.info("*" * 80)
-    logger.info("Now you can train models using {} as the passed argument. "
-                "More details in `train_generation.py`.".format(processed_sqa_data_dir))
+    logger.info("Now you can train models using {} as the <data_dir> argument. "
+                "More details in `run_model.py`.".format(os.path.join(processed_sqa_data_dir, MODEL_NAME)))

File 3 of 5: WikiSQL data preparation script

@@ -19,7 +19,7 @@ RAW_DATASET_FOLDER = "raw_dataset"
PROCESSED_DATASET_FOLDER = "dataset"
TABLE_PROCESSOR = get_default_processor(max_cell_length=15, max_input_length=1024)
# Options: bart.base, bart.large, tapex.base, tapex.large
-MODEL_NAME = "bart.base"
+MODEL_NAME = "tapex.base"
logger = logging.getLogger(__name__)
@@ -129,5 +129,5 @@ if __name__ == '__main__':
    build_wikisql_huggingface_dataset(processed_wikisql_data_dir)
    logger.info("*" * 80)
-    logger.info("Now you can train models using {} as the passed argument. "
-                "More details in `train_generation.py`.".format(processed_wikisql_data_dir))
+    logger.info("Now you can train models using {} as the <data_dir> argument. "
+                "More details in `run_model.py`.".format(os.path.join(processed_wikisql_data_dir, MODEL_NAME)))

File 4 of 5: WTQ data preparation script

@@ -18,7 +18,7 @@ PROCESSED_DATASET_FOLDER = "dataset"
TABLE_PATH = os.path.join(RAW_DATASET_FOLDER, "wtq")
TABLE_PROCESSOR = get_default_processor(max_cell_length=15, max_input_length=1024)
# Options: bart.base, bart.large, tapex.base, tapex.large
-MODEL_NAME = "bart.base"
+MODEL_NAME = "tapex.base"
logger = logging.getLogger(__name__)
@@ -126,5 +126,5 @@ if __name__ == '__main__':
    build_wtq_huggingface_dataset(processed_wtq_data_dir)
    logger.info("*" * 80)
-    logger.info("Now you can train models using {} as the passed argument. "
-                "More details in `train_generation.py`.".format(processed_wtq_data_dir))
+    logger.info("Now you can train models using {} as the <data_dir> argument. "
+                "More details in `run_model.py`.".format(os.path.join(processed_wtq_data_dir, MODEL_NAME)))

File 5 of 5: run_model.py

@@ -21,7 +21,7 @@ def set_train_parser(parser_group):
                              help="dataset directory where train.src is located")
    train_parser.add_argument("--exp-dir", type=str, default="checkpoints",
                              help="experiment directory which stores the checkpoint weights")
-    train_parser.add_argument("--model-path", type=str, default="bart.base/model.pt",
+    train_parser.add_argument("--model-path", type=str, default="tapex.base/model.pt",
                              help="the path of the pre-trained model")
    train_parser.add_argument("--model-arch", type=str, default="bart_base", choices=["bart_large", "bart_base"],
                              help="tapex.large should correspond to bart_large, and tapex.base to bart_base")
@@ -41,8 +41,8 @@ def set_eval_parser(parser_group):
    eval_parser = parser_group.add_parser("eval")
    eval_parser.add_argument("--dataset-dir", type=str, required=True, default="",
                             help="dataset directory where train.src is located")
-    eval_parser.add_argument("--model-path", type=str, default="wikisql.tapex.base/model.pt",
-                             help="the directory of fine-tuned model path such as wikisql.tapex.base")
+    eval_parser.add_argument("--model-path", type=str, default="tapex.base.wikisql/model.pt",
+                             help="the path of the fine-tuned model, such as tapex.base.wikisql/model.pt")
    eval_parser.add_argument("--sub-dir", type=str, default="valid", choices=["train", "valid", "test"],
                             help="the dataset split (sub-directory) to evaluate on: train, valid or test")
@@ -54,7 +54,7 @@ def set_eval_parser(parser_group):
def set_predict_parser(parser_group):
    predict_parser = parser_group.add_parser("predict")
-    predict_parser.add_argument("--resource-dir", type=str, required=True, default="",
+    predict_parser.add_argument("--resource-dir", type=str, required=True, default="./tapex.base",
                                help="the resource dir which contains the model weights, vocab.bpe, "
                                     "dict.src.txt, dict.tgt.txt and encoder.json.")
    predict_parser.add_argument("--checkpoint-name", type=str, default="model.pt",