# NeuronBlocks/data_encoding.py
#
# 61 lines
# 3.1 KiB
# Python

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
# add the project root to python path
import os
from settings import ProblemTypes, version
import argparse
import logging
from ModelConf import ModelConf
from problem import Problem
from utils.common_utils import log_set, dump_to_pkl, load_from_pkl
def main(params, data_path, save_path):
    """Encode a data file with a cached ``Problem`` and pickle the result.

    Builds a ``ModelConf`` from ``params.conf_path``, constructs a ``Problem``
    whose token options depend on the configured problem type, loads its
    saved state via ``problem.load_problem(conf.problem_path)``, encodes
    ``data_path`` (as TSV) and dumps the encoded output to ``save_path``.

    Args:
        params: Parsed command-line namespace. ``params.conf_path`` is read
            here and the whole namespace is forwarded to ``ModelConf``.
        data_path: Path of the input data file handed to ``Problem.encode``.
        save_path: Destination pickle path. The pickled object is a dict with
            keys ``'data'``, ``'length'`` and ``'target'``. A missing parent
            directory is created first.

    Raises:
        ValueError: If the configured problem type is not sequence tagging,
            classification or regression (no encoding setup exists for it here).
        Exception: If the cache file ``conf.problem_path`` does not exist.
    """
    conf = ModelConf("cache", params.conf_path, version, params)
    problem_type = ProblemTypes[conf.problem_type]

    if problem_type == ProblemTypes.sequence_tagging:
        # Tagging: both source and target sequences get start/end/unk/pad
        # tokens, and the configured tagging scheme is passed through.
        problem = Problem(
            conf.problem_type, conf.input_types, conf.answer_column_name,
            source_with_start=True, source_with_end=True,
            source_with_unk=True, source_with_pad=True,
            target_with_start=True, target_with_end=True,
            target_with_unk=True, target_with_pad=True,
            same_length=True, with_bos_eos=conf.add_start_end_for_seq,
            tagging_scheme=conf.tagging_scheme, tokenizer=conf.tokenizer,
            remove_stopwords=conf.remove_stopwords, DBC2SBC=conf.DBC2SBC,
            unicode_fix=conf.unicode_fix)
    elif problem_type in (ProblemTypes.classification, ProblemTypes.regression):
        # Classification/regression: only the source side gets special
        # tokens; all target_with_* options are disabled.
        problem = Problem(
            conf.problem_type, conf.input_types, conf.answer_column_name,
            source_with_start=True, source_with_end=True,
            source_with_unk=True, source_with_pad=True,
            target_with_start=False, target_with_end=False,
            target_with_unk=False, target_with_pad=False,
            same_length=True, with_bos_eos=conf.add_start_end_for_seq,
            tokenizer=conf.tokenizer, remove_stopwords=conf.remove_stopwords,
            DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix)
    else:
        # Previously this fell through and crashed later with an
        # UnboundLocalError on `problem`; fail early with a clear message.
        raise ValueError("Problem type %s is not supported by data encoding!" % conf.problem_type)

    if not os.path.isfile(conf.problem_path):
        raise Exception("Cache does not exist!")
    problem.load_problem(conf.problem_path)
    logging.info("Cache loaded!")
    logging.debug("Cache loaded from %s", conf.problem_path)

    data, length, target = problem.encode(
        data_path, conf.file_columns, conf.input_types, conf.file_with_col_header,
        conf.object_inputs, conf.answer_column_name, conf.min_sentence_len,
        extra_feature=conf.extra_feature, max_lengths=conf.max_lengths,
        file_format='tsv', cpu_num_workers=conf.cpu_num_workers)

    # os.path.dirname() returns '' for a bare file name such as "out.pkl";
    # os.makedirs('') would raise, so only create a directory if one is named.
    save_dir = os.path.dirname(save_path)
    if save_dir and not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    dump_to_pkl({'data': data, 'length': length, 'target': target}, save_path)
def str2bool(value):
    """Convert a command-line string such as 'True', 'no' or '0' to a bool.

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (including "False"), so boolean options need an
    explicit converter. Values that are already bools are returned unchanged.

    Raises:
        argparse.ArgumentTypeError: If the string is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected, got %r" % value)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Data encoding')
    parser.add_argument("data_path", type=str)
    parser.add_argument("save_path", type=str)
    parser.add_argument("--conf_path", type=str, default='conf.json', help="configuration path")
    # str2bool instead of type=bool so that "--debug False" really means False.
    parser.add_argument("--debug", type=str2bool, default=False)
    # NOTE(review): --force is not read in this file; presumably consumed via
    # `params` inside ModelConf — verify.
    parser.add_argument("--force", type=str2bool, default=False)

    log_set('encoding_data.log')
    # parse_known_args: unrecognised extra arguments are ignored, not rejected.
    params, _ = parser.parse_known_args()

    if params.debug:
        # Project-local module; the name is never used afterwards, so it is
        # imported only for its import-time side effects.
        import debugger
    main(params, params.data_path, params.save_path)