Fix a minor issue in the tapex lib.

This commit is contained in:
SivilTaram 2021-10-01 18:32:33 +08:00
Родитель d3737c9868
Коммит 61146c5b3c
1 изменённый файл: 13 добавлений и 18 удалений

Просмотреть файл

@ -26,31 +26,26 @@ def setup_translation_binary_arguments(args, data_dir, resource_name, with_test_
def setup_class_input_binary_arguments(args, data_dir, resource_name, with_test_set):
    """Point the preprocessing arguments at the classification input files.

    The source-side settings are assigned unconditionally, so whatever
    ``args`` carried for them before the call is replaced.

    Args:
        args: Mutable namespace of preprocessing options; updated in place.
        data_dir: Directory holding the ``train/valid/test.input0`` files.
        resource_name: Directory containing ``dict.src.txt``; also used as
            the sub-directory of ``data_dir`` that receives the output.
        with_test_set: When true, ``args.testpref`` is set as well.
    """
    args.only_source = True
    args.trainpref = os.path.join(data_dir, "train.input0")
    args.validpref = os.path.join(data_dir, "valid.input0")
    args.destdir = os.path.join(data_dir, resource_name, "input0")
    args.srcdict = os.path.join(resource_name, "dict.src.txt")
    # Keep an explicit worker count; fall back to 20 when it is unset,
    # zero, or the attribute is missing altogether.
    args.workers = getattr(args, "workers", None) or 20
    if with_test_set:
        args.testpref = os.path.join(data_dir, "test.input0")
def setup_class_label_binary_arguments(args, data_dir, resource_name, with_test_set):
    """Point the preprocessing arguments at the classification label files.

    The label-side settings are assigned unconditionally, so whatever
    ``args`` carried for them before the call is replaced. The dictionary
    options (``srcdict`` / ``tgtdict``) are not touched here.

    Args:
        args: Mutable namespace of preprocessing options; updated in place.
        data_dir: Directory holding the ``train/valid/test.label`` files.
        resource_name: Sub-directory of ``data_dir`` that receives the
            output, under ``label``.
        with_test_set: When true, ``args.testpref`` is set as well.
    """
    args.only_source = True
    args.trainpref = os.path.join(data_dir, "train.label")
    args.validpref = os.path.join(data_dir, "valid.label")
    args.destdir = os.path.join(data_dir, resource_name, "label")
    # NOTE(review): srcdict is left as the caller supplied it, presumably so
    # the label vocabulary is not loaded from the input dictionary - confirm
    # against the call order in fairseq_binary_classification.
    # Keep an explicit worker count; fall back to 20 when it is unset,
    # zero, or the attribute is missing altogether.
    args.workers = getattr(args, "workers", None) or 20
    if with_test_set:
        # The label split, not the input split, is the test prefix here.
        args.testpref = os.path.join(data_dir, "test.label")
def fairseq_binary_translation(data_dir, resource_name, with_test_set=True):
@ -76,7 +71,7 @@ def fairseq_binary_classification(data_dir, resource_name, with_test_set=True):
# preprocess data_folder
parser = options.get_preprocessing_parser()
args = parser.parse_args()
setup_class_input_binary_arguments(args, data_dir, resource_name, with_test_set)
preprocess.main(args)
setup_class_label_binary_arguments(args, data_dir, resource_name, with_test_set)
preprocess.main(args)
setup_class_input_binary_arguments(args, data_dir, resource_name, with_test_set)
preprocess.main(args)