optimize the preprocessing procedure.
This commit is contained in:
Родитель
bca5d76d2e
Коммит
48714e91dc
|
@ -28,10 +28,12 @@ After one dataset is prepared, you can run the `tableqa/run_model.py` script to
|
|||
|
||||
### 🍳 Train
|
||||
|
||||
To train a model, you could simply run the following command, where `<dataset_dir>` refers to dirs such as `dataset/wikisql`, and `<model_path>` refers to a pre-trained model path such as `bart.base/model.pt`.
|
||||
To train a model, you could simply run the following command, where `<dataset_dir>/<model_name>` refers to directory which contains a `bin` folder such as `dataset/wikisql/bart.base`, `<model_path>` refers to a pre-trained model path such as `bart.base/model.pt`, `<model_arch>` refers to a pre-defined model architecture in fairseq such as `bart_base`.
|
||||
|
||||
**HINT**: for `tapex.base` or `tapex.large`, `<model_arch>` should be `bart_base` or `bart_large` respectively.
|
||||
|
||||
```shell
|
||||
$ python run_model.py train --dataset-dir <dataset_dir> --model-path <model_path>
|
||||
$ python run_model.py train --dataset-dir <dataset_dir>/<model_name> --model-path <model_path> --model-arch <model_arch>
|
||||
```
|
||||
|
||||
A full list of training arguments can be seen as below:
|
||||
|
@ -61,10 +63,10 @@ A full list of training arguments can be seen as below:
|
|||
|
||||
### 🍪 Evaluate
|
||||
|
||||
Once the model is fine-tuned, we can evaluate it by runing the following command, where `<dataset_dir>` refers to dirs such as `dataset/wikisql`, and `<model_path>` refers to a fine-tuned model path such as `checkpoints/checkpoint_best.pt`.
|
||||
Once the model is fine-tuned, we can evaluate it by runing the following command, where `<dataset_dir>/<model_name>` refers to directory which contains a `bin` folder such as `dataset/wikisql/bart.base`, and `<model_path>` refers to a fine-tuned model path such as `checkpoints/checkpoint_best.pt`.
|
||||
|
||||
```shell
|
||||
$ python run_model.py eval --dataset-dir <dataset_dir> --model-path <model_path>
|
||||
$ python run_model.py eval --dataset-dir <dataset_dir>/<model_name> --model-path <model_path>
|
||||
```
|
||||
|
||||
A full list of evaluating arguments can be seen as below:
|
||||
|
|
|
@ -17,8 +17,8 @@ RAW_DATASET_FOLDER = "raw_dataset"
|
|||
PROCESSED_DATASET_FOLDER = "dataset"
|
||||
TABLE_PATH = os.path.join(RAW_DATASET_FOLDER, "sqa")
|
||||
TABLE_PROCESSOR = get_default_processor(max_cell_length=15, max_input_length=1024)
|
||||
# Options: bart.base, bart.large, tapex.bart.base, tapex.bart.large
|
||||
RESOURCE_NAME = "bart.base"
|
||||
# Options: bart.base, bart.large, tapex.base, tapex.large
|
||||
MODEL_NAME = "bart.large"
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -101,12 +101,12 @@ def build_sqa_huggingface_dataset(fairseq_data_dir):
|
|||
|
||||
|
||||
def preprocess_sqa_dataset(processed_data_dir):
|
||||
fairseq_bpe_translation(processed_data_dir, resource_name=RESOURCE_NAME)
|
||||
fairseq_binary_translation(processed_data_dir, resource_name=RESOURCE_NAME)
|
||||
fairseq_bpe_translation(processed_data_dir, resource_name=MODEL_NAME)
|
||||
fairseq_binary_translation(processed_data_dir, resource_name=MODEL_NAME)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
logger.info("You are using the setting of {}".format(RESOURCE_NAME))
|
||||
logger.info("You are using the setting of {}".format(MODEL_NAME))
|
||||
|
||||
logger.info("*" * 80)
|
||||
logger.info("Prepare to download SQA dataset from the official link...")
|
||||
|
|
|
@ -18,8 +18,8 @@ from tapex.data_utils.format_converter import convert_fairseq_to_hf
|
|||
RAW_DATASET_FOLDER = "raw_dataset"
|
||||
PROCESSED_DATASET_FOLDER = "dataset"
|
||||
TABLE_PROCESSOR = get_default_processor(max_cell_length=15, max_input_length=1024)
|
||||
# Options: bart.base, bart.large, tapex.bart.base, tapex.bart.large
|
||||
RESOURCE_NAME = "bart.base"
|
||||
# Options: bart.base, bart.large, tapex.base, tapex.large
|
||||
MODEL_NAME = "bart.base"
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -98,12 +98,12 @@ def build_wikisql_huggingface_dataset(fairseq_data_dir):
|
|||
|
||||
|
||||
def preprocess_wikisql_dataset(processed_data_dir):
|
||||
fairseq_bpe_translation(processed_data_dir, resource_name=RESOURCE_NAME)
|
||||
fairseq_binary_translation(processed_data_dir, resource_name=RESOURCE_NAME)
|
||||
fairseq_bpe_translation(processed_data_dir, resource_name=MODEL_NAME)
|
||||
fairseq_binary_translation(processed_data_dir, resource_name=MODEL_NAME)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
logger.info("You are using the setting of {}".format(RESOURCE_NAME))
|
||||
logger.info("You are using the setting of {}".format(MODEL_NAME))
|
||||
|
||||
logger.info("*" * 80)
|
||||
logger.info("Prepare to download WikiSQL from the official link...")
|
||||
|
|
|
@ -17,8 +17,8 @@ RAW_DATASET_FOLDER = "raw_dataset"
|
|||
PROCESSED_DATASET_FOLDER = "dataset"
|
||||
TABLE_PATH = os.path.join(RAW_DATASET_FOLDER, "wtq")
|
||||
TABLE_PROCESSOR = get_default_processor(max_cell_length=15, max_input_length=1024)
|
||||
# Options: bart.base, bart.large, tapex.bart.base, tapex.bart.large
|
||||
RESOURCE_NAME = "bart.base"
|
||||
# Options: bart.base, bart.large, tapex.base, tapex.large
|
||||
MODEL_NAME = "bart.base"
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -95,12 +95,12 @@ def build_wtq_huggingface_dataset(fairseq_data_dir):
|
|||
|
||||
|
||||
def preprocess_wtq_dataset(processed_data_dir):
|
||||
fairseq_bpe_translation(processed_data_dir, resource_name=RESOURCE_NAME)
|
||||
fairseq_binary_translation(processed_data_dir, resource_name=RESOURCE_NAME)
|
||||
fairseq_bpe_translation(processed_data_dir, resource_name=MODEL_NAME)
|
||||
fairseq_binary_translation(processed_data_dir, resource_name=MODEL_NAME)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
logger.info("You are using the setting of {}".format(RESOURCE_NAME))
|
||||
logger.info("You are using the setting of {}".format(MODEL_NAME))
|
||||
|
||||
logger.info("*" * 80)
|
||||
logger.info("Prepare to download WikiTableQuestions from the official link...")
|
||||
|
|
|
@ -61,7 +61,7 @@ def set_predict_parser(parser_group):
|
|||
help="the model weight's name in the resource directory")
|
||||
|
||||
|
||||
def train_fariseq_model(args):
|
||||
def train_fairseq_model(args):
|
||||
cmd = f"""
|
||||
fairseq-train {args.dataset_dir}/bin \
|
||||
--save-dir {args.exp_dir} \
|
||||
|
@ -171,7 +171,7 @@ if __name__ == '__main__':
|
|||
|
||||
args = parser.parse_args()
|
||||
if args.subcommand == "train":
|
||||
train_fariseq_model(args)
|
||||
train_fairseq_model(args)
|
||||
elif args.subcommand == "eval":
|
||||
evaluate_fairseq_model(args)
|
||||
elif args.subcommand == "predict":
|
||||
|
|
|
@ -13,9 +13,9 @@ logger = logging.getLogger(__name__)
|
|||
# Resources are obtained and modified from https://github.com/pytorch/fairseq/tree/master/examples/bart
|
||||
RESOURCE_DICT = {
|
||||
"bart.large": "https://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz",
|
||||
"tapex.bart.large": "https://github.com/microsoft/Table-Pretraining/releases/download/v1.0/tapex.large.tar.gz",
|
||||
"tapex.large": "https://github.com/microsoft/Table-Pretraining/releases/download/v1.0/tapex.large.tar.gz",
|
||||
"bart.base": "https://dl.fbaipublicfiles.com/fairseq/models/bart.base.tar.gz",
|
||||
"tapex.bart.base": "https://github.com/microsoft/Table-Pretraining/releases/download/v1.0/tapex.base.tar.gz"
|
||||
"tapex.base": "https://github.com/microsoft/Table-Pretraining/releases/download/v1.0/tapex.base.tar.gz"
|
||||
}
|
||||
|
||||
DEFAULT_ENCODER_JSON = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json"
|
||||
|
@ -33,6 +33,8 @@ def download_file(url, download_dir=None):
|
|||
os.makedirs(download_dir)
|
||||
with requests.get(url, stream=True) as r:
|
||||
file_name = os.path.join(download_dir, local_filename)
|
||||
if os.path.exists(file_name):
|
||||
os.remove(file_name)
|
||||
write_f = open(file_name, "wb")
|
||||
for data in tqdm(r.iter_content()):
|
||||
write_f.write(data)
|
||||
|
|
|
@ -10,76 +10,73 @@ from fairseq import options
|
|||
from fairseq_cli import preprocess
|
||||
|
||||
|
||||
def setup_translation_binary_arguments(args, data_dir, resource_dir, with_test_set):
|
||||
def setup_translation_binary_arguments(args, data_dir, resource_name, with_test_set):
|
||||
args.source_lang = args.source_lang if getattr(args, "source_lang") else "src"
|
||||
args.target_lang = args.target_lang if getattr(args, "target_lang") else "tgt"
|
||||
args.trainpref = args.trainpref if getattr(args, "trainpref") else os.path.join(data_dir, "train.bpe")
|
||||
args.validpref = args.validpref if getattr(args, "validpref") else os.path.join(data_dir, "valid.bpe")
|
||||
args.destdir = args.destdir if getattr(args, "destdir") not in [None, "data-bin"] else os.path.join(data_dir, "bin")
|
||||
args.srcdict = args.srcdict if getattr(args, "srcdict") else os.path.join(resource_dir, "dict.src.txt")
|
||||
args.tgtdict = args.tgtdict if getattr(args, "tgtdict") else os.path.join(resource_dir, "dict.tgt.txt")
|
||||
args.destdir = args.destdir if getattr(args, "destdir") not in [None, "data-bin"]\
|
||||
else os.path.join(data_dir, resource_name, "bin")
|
||||
args.srcdict = args.srcdict if getattr(args, "srcdict") else os.path.join(resource_name, "dict.src.txt")
|
||||
args.tgtdict = args.tgtdict if getattr(args, "tgtdict") else os.path.join(resource_name, "dict.tgt.txt")
|
||||
args.workers = args.workers if getattr(args, "workers") else 20
|
||||
|
||||
if with_test_set:
|
||||
args.testpref = args.testpref if getattr(args, "testpref") else os.path.join(data_dir, "test.bpe")
|
||||
|
||||
|
||||
def setup_class_input_binary_arguments(args, data_dir, resource_dir, with_test_set):
|
||||
def setup_class_input_binary_arguments(args, data_dir, resource_name, with_test_set):
|
||||
args.only_source = args.only_source if getattr(args, "only_source") else True
|
||||
args.trainpref = args.trainpref if getattr(args, "trainpref") else os.path.join(data_dir, "train.input0")
|
||||
args.validpref = args.validpref if getattr(args, "validpref") else os.path.join(data_dir, "valid.input0")
|
||||
args.destdir = args.destdir if getattr(args, "destdir") not in [None, "data-bin"] else os.path.join(data_dir, "input0")
|
||||
args.srcdict = args.srcdict if getattr(args, "srcdict") else os.path.join(resource_dir, "dict.src.txt")
|
||||
args.tgtdict = args.tgtdict if getattr(args, "tgtdict") else os.path.join(resource_dir, "dict.tgt.txt")
|
||||
args.destdir = args.destdir if getattr(args, "destdir") not in [None, "data-bin"]\
|
||||
else os.path.join(data_dir, resource_name, "input0")
|
||||
args.srcdict = args.srcdict if getattr(args, "srcdict") else os.path.join(resource_name, "dict.src.txt")
|
||||
args.tgtdict = args.tgtdict if getattr(args, "tgtdict") else os.path.join(resource_name, "dict.tgt.txt")
|
||||
args.workers = args.workers if getattr(args, "workers") else 20
|
||||
|
||||
if with_test_set:
|
||||
args.testpref = args.testpref if getattr(args, "testpref") else os.path.join(data_dir, "test.input0")
|
||||
|
||||
|
||||
def setup_class_label_binary_arguments(args, data_dir, resource_dir, with_test_set):
|
||||
def setup_class_label_binary_arguments(args, data_dir, resource_name, with_test_set):
|
||||
args.only_source = args.only_source if getattr(args, "only_source") else True
|
||||
args.trainpref = args.trainpref if getattr(args, "trainpref") else os.path.join(data_dir, "train.label")
|
||||
args.validpref = args.validpref if getattr(args, "validpref") else os.path.join(data_dir, "valid.label")
|
||||
args.destdir = args.destdir if getattr(args, "destdir") else os.path.join(data_dir, "label")
|
||||
args.srcdict = args.srcdict if getattr(args, "srcdict") else os.path.join(resource_dir, "dict.src.txt")
|
||||
args.tgtdict = args.tgtdict if getattr(args, "tgtdict") else os.path.join(resource_dir, "dict.tgt.txt")
|
||||
args.destdir = args.destdir if getattr(args, "destdir") not in [None, "data-bin"]\
|
||||
else os.path.join(data_dir, resource_name, "label")
|
||||
args.srcdict = args.srcdict if getattr(args, "srcdict") else os.path.join(resource_name, "dict.src.txt")
|
||||
args.tgtdict = args.tgtdict if getattr(args, "tgtdict") else os.path.join(resource_name, "dict.tgt.txt")
|
||||
args.workers = args.workers if getattr(args, "workers") else 20
|
||||
|
||||
if with_test_set:
|
||||
args.testpref = args.testpref if getattr(args, "testpref") else os.path.join(data_dir, "test.input0")
|
||||
|
||||
|
||||
def fairseq_binary_translation(data_dir, resource_name, resource_dir=None, with_test_set=True):
|
||||
def fairseq_binary_translation(data_dir, resource_name, with_test_set=True):
|
||||
"""
|
||||
Execute fairseq using default arguments, more arguments can be seen at
|
||||
https://fairseq.readthedocs.io/en/latest/
|
||||
:return:
|
||||
"""
|
||||
if resource_dir is None:
|
||||
resource_dir = resource_name
|
||||
|
||||
# preprocess data_folder
|
||||
parser = options.get_preprocessing_parser()
|
||||
args = parser.parse_args()
|
||||
setup_translation_binary_arguments(args, data_dir, resource_dir, with_test_set)
|
||||
setup_translation_binary_arguments(args, data_dir, resource_name, with_test_set)
|
||||
# pass by args and preprocess the dataset
|
||||
preprocess.main(args)
|
||||
|
||||
|
||||
def fairseq_binary_classification(data_dir, resource_name, resource_dir=None, with_test_set=True):
|
||||
def fairseq_binary_classification(data_dir, resource_name, with_test_set=True):
|
||||
"""
|
||||
Execute fairseq using default arguments, more arguments can be seen at
|
||||
https://fairseq.readthedocs.io/en/latest/
|
||||
:return:
|
||||
"""
|
||||
if resource_dir is None:
|
||||
resource_dir = resource_name
|
||||
|
||||
# preprocess data_folder
|
||||
parser = options.get_preprocessing_parser()
|
||||
args = parser.parse_args()
|
||||
setup_class_input_binary_arguments(args, data_dir, resource_dir, with_test_set)
|
||||
setup_class_input_binary_arguments(args, data_dir, resource_name, with_test_set)
|
||||
preprocess.main(args)
|
||||
setup_class_label_binary_arguments(args, data_dir, resource_dir, with_test_set)
|
||||
setup_class_label_binary_arguments(args, data_dir, resource_name, with_test_set)
|
||||
preprocess.main(args)
|
||||
|
|
|
@ -86,15 +86,15 @@ def fairseq_bpe_translation(data_dir, resource_name, resource_dir=None, with_tes
|
|||
:param data_dir: the directory which stores the dataset files, including `train.src`, `train.tgt` and so on.
|
||||
:param resource_dir: the cached directory for `resource_name`.
|
||||
:param resource_name: corresponding resource files will be automatically downloaded by specifying this parameter.
|
||||
You must select one from the choices of `bart.base`, `bart.large`, `tapex.bart.base` and `tapex.bart.large`.
|
||||
You must select one from the choices of `bart.base`, `bart.large`, `tapex.base` and `tapex.large`.
|
||||
:param with_test_set: if true, process the test set; otherwise not.
|
||||
"""
|
||||
if resource_dir is None:
|
||||
resource_dir = os.path.abspath(resource_name)
|
||||
|
||||
assert resource_name in ["bart.base", "bart.large", "tapex.bart.base", "tapex.bart.large"],\
|
||||
assert resource_name in ["bart.base", "bart.large", "tapex.base", "tapex.large"],\
|
||||
"You must specify `download_resource_from` in " \
|
||||
"`bart.base`, `bart.large`, `tapex.bart.base` and `tapex.bart.large`."
|
||||
"`bart.base`, `bart.large`, `tapex.base` and `tapex.large`."
|
||||
|
||||
if not os.path.exists(os.path.join(resource_dir, "model.pt")):
|
||||
# download file into resource folder
|
||||
|
@ -150,15 +150,15 @@ def fairseq_bpe_classification(data_dir, resource_name, resource_dir=None, with_
|
|||
:param data_dir: the directory which stores the dataset files, including `train.src`, `train.tgt` and so on.
|
||||
:param resource_dir: the cached folder for `resource_name`.
|
||||
:param resource_name: corresponding resource files will be automatically downloaded by specifying this parameter.
|
||||
You must select one from the choices of `bart.base`, `bart.large`, `tapex.bart.base` and `tapex.bart.large`.
|
||||
You must select one from the choices of `bart.base`, `bart.large`, `tapex.base` and `tapex.large`.
|
||||
:param with_test_set: if true, process the test set; otherwise not.
|
||||
"""
|
||||
if resource_dir is None:
|
||||
resource_dir = os.path.abspath(resource_name)
|
||||
|
||||
assert resource_name in ["bart.base", "bart.large", "tapex.bart.base", "tapex.bart.large"], \
|
||||
assert resource_name in ["bart.base", "bart.large", "tapex.base", "tapex.large"], \
|
||||
"You must specify `download_resource_from` in " \
|
||||
"`bart.base`, `bart.large`, `tapex.bart.base` and `tapex.bart.large`."
|
||||
"`bart.base`, `bart.large`, `tapex.base` and `tapex.large`."
|
||||
|
||||
if not os.path.exists(os.path.join(resource_dir, "model.pt")):
|
||||
# download file into resource folder
|
||||
|
|
Загрузка…
Ссылка в новой задаче