From a291d40aac9deb8374126bd0a34ac55ec80b038b Mon Sep 17 00:00:00 2001
From: Flyer Cheng
Date: Tue, 3 Sep 2019 14:15:25 +0800
Subject: [PATCH] ModelConf Reorg (#80)

---
 ModelConf.py          | 685 +++++++++++++++++++++++-------------------
 utils/common_utils.py |  26 +-
 2 files changed, 388 insertions(+), 323 deletions(-)

diff --git a/ModelConf.py b/ModelConf.py
index 1b30ecf..bb77273 100644
--- a/ModelConf.py
+++ b/ModelConf.py
@@ -15,10 +15,56 @@ import shutil
 from losses.BaseLossConf import BaseLossConf
 #import traceback
 from settings import LanguageTypes, ProblemTypes, TaggingSchemes, SupportedMetrics, PredictionTypes, DefaultPredictionFields, ConstantStatic
-from utils.common_utils import log_set, prepare_dir, md5
+from utils.common_utils import log_set, prepare_dir, md5, load_from_json, dump_to_json
 from utils.exceptions import ConfigurationError
 import numpy as np

+class ConstantStaticItems(ConstantStatic):
+    @staticmethod
+    def concat_key_desc(key_prefix_desc, key):
+        return key_prefix_desc + '.' + key
+
+    @staticmethod
+    def get_value_by_key(json, key, key_prefix='', use_default=False, default=None):
+        """
+        Args:
+            json: a json object
+            key: the key whose value should be fetched
+            use_default: set use_default=True to fall back to `default` when the key cannot be found in the json object
+            default: the fallback value; if the key is not found and use_default is False, a ConfigurationError is raised
+        Returns:
+            value:
+        """
+        try:
+            value = json[key]
+        except:
+            if not use_default:
+                raise ConfigurationError("key[%s] can not be found in configuration file" % (key_prefix + key))
+            else:
+                value = default
+        return value
+
+    @staticmethod
+    def add_item(item_name, use_default=False, default=None):
+        def add_item_loading_func(use_default, default, func_get_value_by_key):
+            @classmethod
+            def load_data(cls, obj, json, key_prefix_desc='', use_default=use_default, default=default, func_get_value_by_key=func_get_value_by_key):
+                obj.__dict__[cls.__name__] = func_get_value_by_key(json, cls.__name__, key_prefix_desc, use_default, default)
+                return obj
+            return load_data
+        return type(item_name, (ConstantStatic, ), dict(load_data=add_item_loading_func(use_default, default, __class__.get_value_by_key)))
+
+    @classmethod
+    def load_data(cls, obj, json, key_prefix_desc=''):
+        if cls.__name__ in json.keys():
+            json = json[cls.__name__]
+        for key in cls.__dict__.keys():
+            if not hasattr(cls.__dict__[key], 'load_data'):
+                continue
+            item = cls.__dict__[key]
+            obj = item.load_data(obj, json, cls.concat_key_desc(key_prefix_desc, item.__name__))
+        return obj
+
 class ModelConf(object):
     def __init__(self, phase, conf_path, nb_version, params=None, mode='normal'):
         """ loading configuration from configuration file and argparse parameters
@@ -36,6 +82,7 @@ class ModelConf(object):
         self.params = params
         self.mode = mode.lower()
         assert self.mode in set(['normal', 'philly']), 'Your mode %s is illegal, supported modes are: normal and philly!'
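+        # load_from_file() parses the JSON config and maps it onto this object through the
+        # declarative schema in ModelConf.Conf below: each Conf item is a small class built by
+        # ConstantStaticItems.add_item, and its load_data() copies the matching JSON key onto
+        # self, raising ConfigurationError when a required key (one without a default) is missing.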
+        self.load_from_file(conf_path)
         self.check_version_compat(nb_version, self.tool_version)
@@ -51,321 +98,335 @@ class ModelConf(object):
             logging.debug('%s: %s' % (str(name), str(value)))
         logging.debug('=' * 80)
+    class Conf(ConstantStaticItems):
+        license = ConstantStaticItems.add_item('license')
+        tool_version = ConstantStaticItems.add_item('tool_version')
+        model_description = ConstantStaticItems.add_item('model_description')
+        language = ConstantStaticItems.add_item('language', use_default=True, default='english')
+
+        class inputs(ConstantStaticItems):
+            use_cache = ConstantStaticItems.add_item('use_cache', use_default=True, default=True)
+            dataset_type = ConstantStaticItems.add_item('dataset_type')
+            tagging_scheme = ConstantStaticItems.add_item('tagging_scheme', use_default=True, default=None)
+
+            class data_paths(ConstantStaticItems):
+                train_data_path = ConstantStaticItems.add_item('train_data_path', use_default=True, default=None)
+                valid_data_path = ConstantStaticItems.add_item('valid_data_path', use_default=True, default=None)
+                test_data_path = ConstantStaticItems.add_item('test_data_path', use_default=True, default=None)
+                predict_data_path = ConstantStaticItems.add_item('predict_data_path', use_default=True, default=None)
+                pre_trained_emb = ConstantStaticItems.add_item('pre_trained_emb', use_default=True, default=None)
+                pretrained_model_path = ConstantStaticItems.add_item('pretrained_model_path', use_default=True, default=None)
+
+            file_with_col_header = ConstantStaticItems.add_item('file_with_col_header', use_default=True, default=False)
+            pretrained_emb_type = ConstantStaticItems.add_item('pretrained_emb_type', use_default=True, default='glove')
+            pretrained_emb_binary_or_text = ConstantStaticItems.add_item('pretrained_emb_binary_or_text', use_default=True, default='text')
+            involve_all_words_in_pretrained_emb = ConstantStaticItems.add_item('involve_all_words_in_pretrained_emb', use_default=True, default=False)
+            add_start_end_for_seq = ConstantStaticItems.add_item('add_start_end_for_seq', use_default=True, default=False)
+            file_header = ConstantStaticItems.add_item('file_header', use_default=True, default=None)
+            predict_file_header = ConstantStaticItems.add_item('predict_file_header', use_default=True, default=None)
+            model_inputs = ConstantStaticItems.add_item('model_inputs')
+            target = ConstantStaticItems.add_item('target', use_default=True, default=None)
+            positive_label = ConstantStaticItems.add_item('positive_label', use_default=True, default=None)
+
+        class outputs(ConstantStaticItems):
+            save_base_dir = ConstantStaticItems.add_item('save_base_dir', use_default=True, default=None)
+            model_name = ConstantStaticItems.add_item('model_name')
+
+            train_log_name = ConstantStaticItems.add_item('train_log_name', use_default=True, default=None)
+            test_log_name = ConstantStaticItems.add_item('test_log_name', use_default=True, default=None)
+            predict_log_name = ConstantStaticItems.add_item('predict_log_name', use_default=True, default=None)
+            predict_fields = ConstantStaticItems.add_item('predict_fields', use_default=True, default=None)
+            predict_output_name = ConstantStaticItems.add_item('predict_output_name', use_default=True, default='predict.tsv')
+            cache_dir = ConstantStaticItems.add_item('cache_dir', use_default=True, default=None)
+
+        class training_params(ConstantStaticItems):
+            class vocabulary(ConstantStaticItems):
+                min_word_frequency = ConstantStaticItems.add_item('min_word_frequency', use_default=True, default=3)
+                max_vocabulary = ConstantStaticItems.add_item('max_vocabulary', use_default=True, default=800 * 1000)
+                max_building_lines = ConstantStaticItems.add_item('max_building_lines', use_default=True, default=1000 * 1000)
+
+            optimizer = ConstantStaticItems.add_item('optimizer', use_default=True, default=None)
+            clip_grad_norm_max_norm = ConstantStaticItems.add_item('clip_grad_norm_max_norm', use_default=True, default=-1)
+            chunk_size = ConstantStaticItems.add_item('chunk_size', use_default=True, default=1000 * 1000)
+            lr_decay = ConstantStaticItems.add_item('lr_decay', use_default=True, default=1)
+            minimum_lr = ConstantStaticItems.add_item('minimum_lr', use_default=True, default=0)
+            epoch_start_lr_decay = ConstantStaticItems.add_item('epoch_start_lr_decay', use_default=True, default=1)
+            use_gpu = ConstantStaticItems.add_item('use_gpu', use_default=True, default=False)
+            cpu_num_workers = ConstantStaticItems.add_item('cpu_num_workers', use_default=True, default=-1)  # by default, use all the workers the CPU supports
+            batch_size = ConstantStaticItems.add_item('batch_size', use_default=True, default=1)
+            batch_num_to_show_results = ConstantStaticItems.add_item('batch_num_to_show_results', use_default=True, default=10)
+            max_epoch = ConstantStaticItems.add_item('max_epoch', use_default=True, default=float('inf'))
+            valid_times_per_epoch = ConstantStaticItems.add_item('valid_times_per_epoch', use_default=True, default=None)
+            steps_per_validation = ConstantStaticItems.add_item('steps_per_validation', use_default=True, default=10)
+            text_preprocessing = ConstantStaticItems.add_item('text_preprocessing', use_default=True, default=list())
+            max_lengths = ConstantStaticItems.add_item('max_lengths', use_default=True, default=None)
+            fixed_lengths = ConstantStaticItems.add_item('fixed_lengths', use_default=True, default=None)
+            tokenizer = ConstantStaticItems.add_item('tokenizer', use_default=True, default=None)
+
+        architecture = ConstantStaticItems.add_item('architecture')
+        loss = ConstantStaticItems.add_item('loss', use_default=True, default=None)
+        metrics = ConstantStaticItems.add_item('metrics', use_default=True, default=None)
+
+    def raise_configuration_error(self, key):
+        raise ConfigurationError(
+            "The configuration file %s is illegal. The item [%s] is not found." % (self.conf_path, key))
+
     def load_from_file(self, conf_path):
-        with codecs.open(conf_path, 'r', encoding='utf-8') as fin:
-            try:
-                self.conf = json.load(fin)
-            except Exception as e:
-                raise ConfigurationError("%s is not a legal JSON file, please check your JSON format!" % conf_path)
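+        # debug=False silences load_from_json()'s "loaded" debug message; logging is only set up
+        # later, in configurate_outputs(). Malformed JSON still surfaces as a ConfigurationError.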
+        # load file
+        self.conf = load_from_json(conf_path, debug=False)
+        self = self.Conf.load_data(self, {'Conf' : self.conf}, key_prefix_desc='Conf')
+        self.language = self.language.lower()
+        self.configurate_outputs()
+        self.configurate_inputs()
+        self.configurate_training_params()
+        self.configurate_architecture()
+        self.configurate_loss()
+        self.configurate_cache()
-        self.tool_version = self.get_item(['tool_version'])
-        self.language = self.get_item(['language'], default='english').lower()
-        self.problem_type = self.get_item(['inputs', 'dataset_type']).lower()
-        #if ProblemTypes[self.problem_type] == ProblemTypes.sequence_tagging:
-        self.tagging_scheme = self.get_item(['inputs', 'tagging_scheme'], default=None, use_default=True)
+    def configurate_outputs(self):
+        def configurate_logger(self):
+            if self.phase == 'cache':
+                return
-        if self.mode == 'normal':
-            self.use_cache = self.get_item(['inputs', 'use_cache'], True)
-        elif self.mode == 'philly':
-            self.use_cache = True
+            # dir
+            if hasattr(self.params, 'log_dir') and self.params.log_dir:
+                self.log_dir = self.params.log_dir
+                prepare_dir(self.log_dir, True, allow_overwrite=True)
+            else:
+                self.log_dir = self.save_base_dir
+
+            # path
+            self.train_log_path = os.path.join(self.log_dir, self.train_log_name)
+            self.test_log_path = os.path.join(self.log_dir, self.test_log_name)
+            self.predict_log_path = os.path.join(self.log_dir, self.predict_log_name)
+            if self.phase == 'train':
+                log_path = self.train_log_path
+            elif self.phase == 'test':
+                log_path = self.test_log_path
+            elif self.phase == 'predict':
+                log_path = self.predict_log_path
+            if log_path is None:
+                self.raise_configuration_error(self.phase + '_log_name')
-        # OUTPUTS
+            # log level
+            if self.mode == 'philly' or self.params.debug:
+                log_set(log_path, console_level='DEBUG', console_detailed=True, disable_log_file=self.params.disable_log_file)
+            else:
+                log_set(log_path, disable_log_file=self.params.disable_log_file)
+
+        # save base dir
         if hasattr(self.params, 'model_save_dir') and self.params.model_save_dir:
             self.save_base_dir = self.params.model_save_dir
-        else:
-            self.save_base_dir = self.get_item(['outputs', 'save_base_dir'])
-
-        if self.phase == 'train':
-            # in train.py, it is called pretrained_model_path
-            if hasattr(self.params, 'pretrained_model_path') and self.params.pretrained_model_path:
-                self.pretrained_model_path = self.previous_model_path = self.params.pretrained_model_path
-            else:
-                self.pretrained_model_path = self.previous_model_path = self.get_item(['inputs', 'data_paths', 'pretrained_model_path'], default=None, use_default=True)
-        elif self.phase == 'test' or self.phase == 'predict':
-            # in test.py and predict.py, it is called pretrained_model_path
-            if hasattr(self.params, 'previous_model_path') and self.params.previous_model_path:
-                self.previous_model_path = self.pretrained_model_path = self.params.previous_model_path
-            else:
-                self.previous_model_path = self.pretrained_model_path = os.path.join(self.save_base_dir, self.get_item(['outputs', 'model_name'])) # namely, the model_save_path
-
-        if hasattr(self, 'pretrained_model_path') and self.pretrained_model_path: # namely self.previous_model_path
-            tmp_saved_problem_path = os.path.join(os.path.dirname(self.pretrained_model_path), '.necessary_cache', 'problem.pkl')
-            self.saved_problem_path = tmp_saved_problem_path if os.path.isfile(tmp_saved_problem_path) \
-                else os.path.join(os.path.dirname(self.pretrained_model_path), 'necessary_cache', 'problem.pkl')
-            if not (os.path.isfile(self.pretrained_model_path) and os.path.isfile(self.saved_problem_path)):
-                raise Exception('Previous trained model %s or its dictionaries %s does not exist!' % (self.pretrained_model_path, self.saved_problem_path))
+        elif self.save_base_dir is None:
+            self.raise_configuration_error('save_base_dir')
+        # prepare save base dir
         if self.phase != 'cache':
             prepare_dir(self.save_base_dir, True, allow_overwrite=self.params.force or self.mode == 'philly', extra_info='will overwrite model file and train.log' if self.phase=='train' else 'will add %s.log and predict file'%self.phase)
-        if hasattr(self.params, 'log_dir') and self.params.log_dir:
-            self.log_dir = self.params.log_dir
-            if self.phase != 'cache':
-                prepare_dir(self.log_dir, True, allow_overwrite=True)
-        else:
-            self.log_dir = self.save_base_dir
+        # logger
+        configurate_logger(self)
-        if self.phase == 'train':
-            self.train_log_path = os.path.join(self.log_dir, self.get_item(['outputs', 'train_log_name']))
-            if self.mode == 'philly' or self.params.debug:
-                log_set(self.train_log_path, console_level='DEBUG', console_detailed=True, disable_log_file=self.params.disable_log_file)
-            else:
-                log_set(self.train_log_path, disable_log_file=self.params.disable_log_file)
-        elif self.phase == 'test':
-            self.test_log_path = os.path.join(self.log_dir, self.get_item(['outputs', 'test_log_name']))
-            if self.mode == 'philly' or self.params.debug:
-                log_set(self.test_log_path, console_level='DEBUG', console_detailed=True, disable_log_file=self.params.disable_log_file)
-            else:
-                log_set(self.test_log_path, disable_log_file=self.params.disable_log_file)
-        elif self.phase == 'predict':
-            self.predict_log_path = os.path.join(self.log_dir, self.get_item(['outputs', 'predict_log_name']))
-            if self.mode == 'philly' or self.params.debug:
-                log_set(self.predict_log_path, console_level='DEBUG', console_detailed=True, disable_log_file=self.params.disable_log_file)
-            else:
-                log_set(self.predict_log_path, disable_log_file=self.params.disable_log_file)
+        # predict output path
         if self.phase != 'cache':
-            self.predict_output_path = self.params.predict_output_path if self.params.predict_output_path else os.path.join(self.save_base_dir, self.get_item(['outputs', 'predict_output_name'], default='predict.tsv'))
+            if self.params.predict_output_path:
+                self.predict_output_path = self.params.predict_output_path
+            else:
+                self.predict_output_path = os.path.join(self.save_base_dir, self.predict_output_name)
             logging.debug('Prepare dir for: %s' % self.predict_output_path)
             prepare_dir(self.predict_output_path, False, allow_overwrite=self.params.force or self.mode == 'philly')
-            self.predict_fields = self.get_item(['outputs', 'predict_fields'], default=DefaultPredictionFields[ProblemTypes[self.problem_type]])
-            self.model_save_path = os.path.join(self.save_base_dir, self.get_item(['outputs', 'model_name']))
+            if self.predict_fields is None:
+                self.predict_fields = DefaultPredictionFields[ProblemTypes[self.problem_type]]
-        # INPUTS
-        if hasattr(self.params, 'train_data_path') and self.params.train_data_path:
-            self.train_data_path = self.params.train_data_path
-        else:
-            if self.mode == 'normal':
-                self.train_data_path = self.get_item(['inputs', 'data_paths', 'train_data_path'], default=None, use_default=True)
-            else:
+        self.model_save_path = os.path.join(self.save_base_dir, self.model_name)
+
+    def configurate_inputs(self):
+
+        def configurate_data_path(self):
+            self.pretrained_emb_path = self.pre_trained_emb
+
+            if self.mode != "normal":
                 self.train_data_path = None
-        if hasattr(self.params, 'valid_data_path') and self.params.valid_data_path:
-            self.valid_data_path = self.params.valid_data_path
-        else:
-            if self.mode == 'normal':
-                self.valid_data_path = self.get_item(['inputs', 'data_paths', 'valid_data_path'], default=None, use_default=True)
-            else:
                 self.valid_data_path = None
-        if hasattr(self.params, 'test_data_path') and self.params.test_data_path:
-            self.test_data_path = self.params.test_data_path
-        else:
-            if self.mode == 'normal':
-                self.test_data_path = self.get_item(['inputs', 'data_paths', 'test_data_path'], default=None, use_default=True)
-            else:
                 self.test_data_path = None
-
-        if self.phase == 'predict':
-            if self.params.predict_data_path:
-                self.predict_data_path = self.params.predict_data_path
-            else:
-                if self.mode == 'normal':
-                    self.predict_data_path = self.get_item(['inputs', 'data_paths', 'predict_data_path'], default=None, use_default=True)
-                else:
-                    self.predict_data_path = None
-
-        if self.phase == 'train' or self.phase == 'cache':
-            if self.valid_data_path is None and self.test_data_path is not None:
-                # We support test_data_path == None, if someone set valid_data_path to None while test_data_path is not None,
-                # swap the valid_data_path and test_data_path
-                self.valid_data_path = self.test_data_path
-                self.test_data_path = None
-        elif self.phase == 'predict':
-            if self.predict_data_path is None and self.test_data_path is not None:
-                self.predict_data_path = self.test_data_path
-                self.test_data_path = None
-
-        if self.phase == 'train' or self.phase == 'test' or self.phase == 'cache':
-            self.file_columns = self.get_item(['inputs', 'file_header'])
-        else:
-            self.file_columns = self.get_item(['inputs', 'file_header'], default=None, use_default=True)
-
-        if self.phase == 'predict':
-            if self.file_columns is None:
-                self.predict_file_columns = self.get_item(['inputs', 'predict_file_header'])
-            else:
-                self.predict_file_columns = self.get_item(['inputs', 'predict_file_header'], default=None, use_default=True)
-            if self.predict_file_columns is None:
-                self.predict_file_columns = self.file_columns
-
-        if self.phase != 'predict':
-            if self.phase == 'cache':
-                self.answer_column_name = self.get_item(['inputs', 'target'], default=None, use_default=True)
-            else:
-                self.answer_column_name = self.get_item(['inputs', 'target'])
-        self.input_types = self.get_item(['architecture', 0, 'conf'])
-        # add extra feature
-        feature_all = set([_.lower() for _ in self.input_types.keys()])
-        formal_feature = set(['word', 'char'])
-        self.extra_feature = len(feature_all - formal_feature) != 0
-
-        # add char embedding config
-        # char_emb_type = None
-        # char_emb_type_cols = None
-        # for single_type in self.input_types:
-        #     if single_type.lower() == 'char':
-        #         char_emb_type = single_type
-        #         char_emb_type_cols = [single_col.lower() for single_col in self.input_types[single_type]['cols']]
-        #         break
-        self.object_inputs = self.get_item(['inputs', 'model_inputs'])
-        # if char_emb_type and char_emb_type_cols:
-        #     for single_input in self.object_inputs:
-        #         for single_col in char_emb_type_cols:
-        #             if single_input.lower() in single_col:
-        #                 self.object_inputs[single_input].append(single_col)
-
-        self.object_inputs_names = [name for name in self.object_inputs]
-
-        # vocabulary setting
-        self.max_vocabulary = self.get_item(['training_params', 'vocabulary', 'max_vocabulary'], default=800000, use_default=True)
-        self.min_word_frequency = self.get_item(['training_params', 'vocabulary', 'min_word_frequency'], default=3, use_default=True)
-        self.max_building_lines = self.get_item(['training_params', 'vocabulary', 'max_building_lines'], default=1000 * 1000, use_default=True)
-
-        # chunk_size
-        self.chunk_size = self.get_item(['training_params', 'chunk_size'], default=1000 * 1000, use_default=True)
-
-        # file column header setting
-        self.file_with_col_header = self.get_item(['inputs', 'file_with_col_header'], default=False, use_default=True)
-
-        if ProblemTypes[self.problem_type] == ProblemTypes.sequence_tagging:
-            self.add_start_end_for_seq = self.get_item(['inputs', 'add_start_end_for_seq'], default=True)
-        else:
-            self.add_start_end_for_seq = self.get_item(['inputs', 'add_start_end_for_seq'], default=False)
-
-        if hasattr(self.params, 'pretrained_emb_path') and self.params.pretrained_emb_path:
-            self.pretrained_emb_path = self.params.pretrained_emb_path
-        else:
-            if self.mode == 'normal':
-                self.pretrained_emb_path = self.get_item(['inputs', 'data_paths', 'pre_trained_emb'], default=None, use_default=True)
-            else:
+                self.predict_data_path = None
                 self.pretrained_emb_path = None
-        if 'word' in self.get_item(['architecture', 0, 'conf']) and self.pretrained_emb_path:
-            if hasattr(self.params, 'involve_all_words_in_pretrained_emb') and self.params.involve_all_words_in_pretrained_emb:
-                self.involve_all_words_in_pretrained_emb = self.params.involve_all_words_in_pretrained_emb
+            if hasattr(self.params, 'train_data_path') and self.params.train_data_path:
+                self.train_data_path = self.params.train_data_path
+            if hasattr(self.params, 'valid_data_path') and self.params.valid_data_path:
+                self.valid_data_path = self.params.valid_data_path
+            if hasattr(self.params, 'test_data_path') and self.params.test_data_path:
+                self.test_data_path = self.params.test_data_path
+            if hasattr(self.params, 'predict_data_path') and self.params.predict_data_path:
+                self.predict_data_path = self.params.predict_data_path
+            if hasattr(self.params, 'pretrained_emb_path') and self.params.pretrained_emb_path:
+                self.pretrained_emb_path = self.params.pretrained_emb_path
+
+            if self.phase == 'train' or self.phase == 'cache':
+                if self.valid_data_path is None and self.test_data_path is not None:
+                    # We support test_data_path == None, if someone set valid_data_path to None while test_data_path is not None,
+                    # swap the valid_data_path and test_data_path
+                    self.valid_data_path = self.test_data_path
+                    self.test_data_path = None
+            elif self.phase == 'predict':
+                if self.predict_data_path is None and self.test_data_path is not None:
+                    self.predict_data_path = self.test_data_path
+                    self.test_data_path = None
+
+            return self
+
+        def configurate_data_format(self):
+            # file columns
+            if self.phase == 'train' or self.phase == 'test' or self.phase == 'cache':
+                self.file_columns = self.file_header
+                if self.file_columns is None:
+                    self.raise_configuration_error('file_columns')
+            if self.phase == 'predict':
+                if self.file_columns is None and self.predict_file_columns is None:
+                    self.raise_configuration_error('predict_file_columns')
+                if self.file_columns and self.predict_file_columns is None:
+                    self.predict_file_columns = self.file_columns
+
+            # target
+            if self.phase != 'predict':
+                self.answer_column_name = self.target
+                if self.target is None and self.phase != 'cache':
+                    self.raise_configuration_error('target')
+
+            if ProblemTypes[self.problem_type] == ProblemTypes.sequence_tagging and self.add_start_end_for_seq is None:
+                self.add_start_end_for_seq = True
+
+            # pretrained embedding
+            if 'word' in self.architecture[0]['conf'] and self.pretrained_emb_path:
+                if hasattr(self.params, 'involve_all_words_in_pretrained_emb') and self.params.involve_all_words_in_pretrained_emb:
+                    self.involve_all_words_in_pretrained_emb = self.params.involve_all_words_in_pretrained_emb
+                if hasattr(self.params, 'pretrained_emb_type') and self.params.pretrained_emb_type:
+                    self.pretrained_emb_type = self.params.pretrained_emb_type
+                if hasattr(self.params, 'pretrained_emb_binary_or_text') and self.params.pretrained_emb_binary_or_text:
+                    self.pretrained_emb_binary_or_text = self.params.pretrained_emb_binary_or_text
+                self.pretrained_emb_dim = self.architecture[0]['conf']['word']['dim']
             else:
-                self.involve_all_words_in_pretrained_emb = self.get_item(['inputs', 'involve_all_words_in_pretrained_emb'], default=False)
-            if hasattr(self.params, 'pretrained_emb_type') and self.params.pretrained_emb_type:
-                self.pretrained_emb_type = self.params.pretrained_emb_type
-            else:
-                self.pretrained_emb_type = self.get_item(['inputs', 'pretrained_emb_type'], default='glove')
-            if hasattr(self.params, 'pretrained_emb_binary_or_text') and self.params.pretrained_emb_binary_or_text:
-                self.pretrained_emb_binary_or_text = self.params.pretrained_emb_binary_or_text
-            else:
-                self.pretrained_emb_binary_or_text = self.get_item(['inputs', 'pretrained_emb_binary_or_text'], default='text')
-            self.pretrained_emb_dim = self.get_item(['architecture', 0, 'conf', 'word', 'dim'])
+                self.pretrained_emb_path = None
+                self.involve_all_words_in_pretrained_emb = None
+                self.pretrained_emb_type = None
+                self.pretrained_emb_binary_or_text = None
+                self.pretrained_emb_dim = None
+
+            return self
+
+        def configurate_model_input(self):
+            self.object_inputs = self.model_inputs
+            self.object_inputs_names = [name for name in self.object_inputs]
+
+            return self
+
+        self.problem_type = self.dataset_type.lower()
+
+        # previous model path
+        if hasattr(self.params, 'previous_model_path') and self.params.previous_model_path:
+            self.previous_model_path = self.params.previous_model_path
         else:
-            self.pretrained_emb_path = None
-            self.involve_all_words_in_pretrained_emb = None
-            self.pretrained_emb_binary_or_text = None
-            self.pretrained_emb_dim = None
-            self.pretrained_emb_type = None
+            self.previous_model_path = os.path.join(self.save_base_dir, self.model_name)
+        # pretrained model path
+        if hasattr(self.params, 'pretrained_model_path') and self.params.pretrained_model_path:
+            self.pretrained_model_path = self.params.pretrained_model_path
+
+        # saved problem path
+        model_path = None
         if self.phase == 'train':
-            if hasattr(self.params, 'cache_dir') and self.params.cache_dir:
-                # for aether
-                self.cache_dir = self.params.cache_dir
-            else:
-                if self.mode == 'normal':
-                    if self.use_cache:
-                        self.cache_dir = self.get_item(['outputs', 'cache_dir'])
-                    else:
-                        self.cache_dir = os.path.join(tempfile.gettempdir(), 'neuron_blocks', ''.join(random.sample(string.ascii_letters+string.digits, 16)))
-                else:
-                    # for philly mode, we can only save files in model_path or scratch_path
-                    self.cache_dir = os.path.join(self.save_base_dir, 'cache')
+            model_path = self.pretrained_model_path
+        elif self.phase == 'test' or self.phase == 'predict':
+            model_path = self.previous_model_path
+        if model_path:
+            model_path_dir = os.path.dirname(model_path)
+            self.saved_problem_path = os.path.join(model_path_dir, '.necessary_cache', 'problem.pkl')
+            if not os.path.isfile(self.saved_problem_path):
+                self.saved_problem_path = os.path.join(model_path_dir, 'necessary_cache', 'problem.pkl')
+            if not (os.path.isfile(model_path) and os.path.isfile(self.saved_problem_path)):
+                raise Exception('Previous trained model %s or its dictionaries %s does not exist!' % (model_path, self.saved_problem_path))
-            self.problem_path = os.path.join(self.cache_dir, 'problem.pkl')
-            if self.pretrained_emb_path is not None:
-                self.emb_pkl_path = os.path.join(self.cache_dir, 'emb.pkl')
-            else:
-                self.emb_pkl_path = None
-        else:
-            tmp_problem_path = os.path.join(self.save_base_dir, '.necessary_cache', 'problem.pkl')
-            self.problem_path = tmp_problem_path if os.path.isfile(tmp_problem_path) else os.path.join(self.save_base_dir, 'necessary_cache', 'problem.pkl')
-
-        # cache configuration
-        self._load_cache_config_from_conf()
-
-        # training params
-        self.training_params = self.get_item(['training_params'])
+        configurate_data_path(self)
+        configurate_data_format(self)
+        configurate_model_input(self)
+    def configurate_training_params(self):
+        # optimizer
         if self.phase == 'train':
-            self.optimizer_name = self.get_item(['training_params', 'optimizer', 'name'])
-            self.optimizer_params = self.get_item(['training_params', 'optimizer', 'params'])
-
-            self.clip_grad_norm_max_norm = self.get_item(['training_params', 'clip_grad_norm_max_norm'], default=-1)
-
+            if self.optimizer is None:
+                self.raise_configuration_error('training_params.optimizer')
+            if 'name' not in self.optimizer.keys():
+                self.raise_configuration_error('training_params.optimizer.name')
+            self.optimizer_name = self.optimizer['name']
+            if 'params' not in self.optimizer.keys():
+                self.raise_configuration_error('training_params.optimizer.params')
+            self.optimizer_params = self.optimizer['params']
             if hasattr(self.params, 'learning_rate') and self.params.learning_rate:
                 self.optimizer_params['lr'] = self.params.learning_rate
-
+
+        # batch size
+        self.batch_size_each_gpu = self.batch_size  # the batch_size in conf file is the batch_size on each GPU
         if hasattr(self.params, 'batch_size') and self.params.batch_size:
            self.batch_size_each_gpu = self.params.batch_size
-        else:
-            self.batch_size_each_gpu = self.get_item(['training_params', 'batch_size']) #the batch_size in conf file is the batch_size on each GPU
-        self.lr_decay = self.get_item(['training_params', 'lr_decay'], default=1)  # by default, no decay
-        self.minimum_lr = self.get_item(['training_params', 'minimum_lr'], default=0)
-        self.epoch_start_lr_decay = self.get_item(['training_params', 'epoch_start_lr_decay'], default=1)
+        if self.batch_size_each_gpu is None:
+            self.raise_configuration_error('training_params.batch_size')
+        self.batch_size_total = self.batch_size_each_gpu
+        if torch.cuda.device_count() > 1:
+            self.batch_size_total = torch.cuda.device_count() * self.batch_size_each_gpu
+            self.batch_num_to_show_results = self.batch_num_to_show_results // torch.cuda.device_count()
+
        if hasattr(self.params, 'max_epoch') and self.params.max_epoch:
            self.max_epoch = self.params.max_epoch
-        else:
-            self.max_epoch = self.get_item(['training_params', 'max_epoch'], default=float('inf'))
-        if 'valid_times_per_epoch' in self.conf['training_params']:
+
+        if self.valid_times_per_epoch is not None:
            logging.info("configuration[training_params][valid_times_per_epoch] is deprecated, please use configuration[training_params][steps_per_validation] instead")
-        self.steps_per_validation = self.get_item(['training_params', 'steps_per_validation'], default=10)
-        self.batch_num_to_show_results = self.get_item(['training_params', 'batch_num_to_show_results'], default=10)
-        self.max_lengths = self.get_item(['training_params', 'max_lengths'], default=None, use_default=True)
-        self.fixed_lengths = self.get_item(['training_params', 'fixed_lengths'], default=None, use_default=True)
+
+        # sequence length
         if self.fixed_lengths:
             self.max_lengths = None
         if ProblemTypes[self.problem_type] == ProblemTypes.sequence_tagging:
             self.fixed_lengths = None
             self.max_lengths = None
-        if torch.cuda.device_count() > 1:
-            self.batch_size_total = torch.cuda.device_count() * self.training_params['batch_size']
-            self.batch_num_to_show_results = self.batch_num_to_show_results // torch.cuda.device_count()
-        else:
-            self.batch_size_total = self.batch_size_each_gpu
-
-        self.cpu_num_workers = self.get_item(['training_params', 'cpu_num_workers'], default=-1) #by default, use all workers cpu supports
-
         # text preprocessing
-        self.__text_preprocessing = self.get_item(['training_params', 'text_preprocessing'], default=list())
+        self.__text_preprocessing = self.text_preprocessing
         self.DBC2SBC = True if 'DBC2SBC' in self.__text_preprocessing else False
         self.unicode_fix = True if 'unicode_fix' in self.__text_preprocessing else False
         self.remove_stopwords = True if 'remove_stopwords' in self.__text_preprocessing else False
         # tokenizer
-        if self.language == 'chinese':
-            self.tokenizer = self.get_item(['training_params', 'tokenizer'], default='jieba')
-        else:
-            self.tokenizer = self.get_item(['training_params', 'tokenizer'], default='nltk')
-
-        if self.extra_feature:
-            if self.DBC2SBC:
-                logging.warning("Detect the extra feature %s, set the DBC2sbc is False." % ''.join(list(feature_all-formal_feature)))
-            if self.unicode_fix:
-                logging.warning("Detect the extra feature %s, set the unicode_fix is False." % ''.join(list(feature_all-formal_feature)))
-            if self.remove_stopwords:
-                logging.warning("Detect the extra feature %s, set the remove_stopwords is False." % ''.join(list(feature_all-formal_feature)))
-
-        if ProblemTypes[self.problem_type] == ProblemTypes.sequence_tagging:
-            if self.unicode_fix:
-                logging.warning('For sequence tagging task, unicode_fix may change the number of words.')
-            if self.remove_stopwords:
-                self.remove_stopwords = True
-                logging.warning('For sequence tagging task, remove stopwords is forbidden! It is disabled now.')
-
+        if self.tokenizer is None:
+            self.tokenizer = 'jieba' if self.language == 'chinese' else 'nltk'
+
+        # GPU/CPU
         if self.phase != 'cache':
-            if torch.cuda.is_available() and torch.cuda.device_count() > 0 and self.training_params.get('use_gpu', True):
-                self.use_gpu = True
+            if torch.cuda.is_available() and torch.cuda.device_count() > 0 and self.use_gpu:
                 logging.info("Activating GPU mode, there are %d GPUs available" % torch.cuda.device_count())
             else:
                 self.use_gpu = False
                 logging.info("Activating CPU mode")
-        self.architecture = self.get_item(['architecture'])
+    def configurate_architecture(self):
+        self.input_types = self.architecture[0]['conf']
+
+        # extra feature
+        feature_all = set([_.lower() for _ in self.input_types.keys()])
+        formal_feature = set(['word', 'char'])
+        extra_feature_num = feature_all - formal_feature
+        self.extra_feature = len(extra_feature_num) != 0
+        if self.extra_feature:
+            if self.DBC2SBC:
+                logging.warning("Detected extra feature %s, setting DBC2SBC to False." % ''.join(list(extra_feature_num)))
+            if self.unicode_fix:
+                logging.warning("Detected extra feature %s, setting unicode_fix to False." % ''.join(list(extra_feature_num)))
+            if self.remove_stopwords:
+                logging.warning("Detected extra feature %s, setting remove_stopwords to False." % ''.join(list(extra_feature_num)))
+
+        # output layer
         self.output_layer_id = []
         for single_layer in self.architecture:
             if 'output_layer_flag' in single_layer and single_layer['output_layer_flag']:
@@ -384,42 +445,59 @@ class ModelConf(object):
                     self.min_sentence_len = max(self.min_sentence_len, np.max(np.array([single_conf_value])))
                     break
-        if self.phase == 'train' or self.phase == 'test':
-            self.loss = BaseLossConf.get_conf(**self.get_item(['loss']))
-            self.metrics = self.get_item(['metrics'])
-            if 'auc' in self.metrics and ProblemTypes[self.problem_type] == ProblemTypes.classification:
-                self.pos_label = self.get_item(['inputs', 'positive_label'], default=None, use_default=True)
+    def configurate_loss(self):
+        if self.phase != 'train' and self.phase != 'test':
+            return
+
+        if self.loss is None or self.metrics is None:
+            self.raise_configuration_error('loss/metrics')
+        self.loss = BaseLossConf.get_conf(**self.loss)
-    def get_item(self, keys, default=None, use_default=False):
-        """
+        if 'auc' in self.metrics and ProblemTypes[self.problem_type] == ProblemTypes.classification:
+            self.pos_label = self.positive_label
-        Args:
-            keys:
-            default: if some key is not found and default is None, we would raise an Exception, except that use_default is True
-            use_default: if you really want to set default to None, set use_default=True
+    def configurate_cache(self):
+        # whether use cache
+        if self.mode == 'philly':
+            self.use_cache = True
-        Returns:
-
-        """
-        item = self.conf
-        valid_keys = []
-        try:
-            for key in keys:
-                item = item[key]
-                valid_keys.append(key)
-        except:
-            error_keys = copy.deepcopy(valid_keys)
-            error_keys.append(key)
-            if default is None and use_default is False:
-                raise ConfigurationError(
-                    "The configuration file %s is illegal. There should be an item configuration[%s], "
-                    "but the item %s is not found." % (self.conf_path, "][".join(error_keys), key))
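+        # When caching is disabled in normal mode the cache goes to a throwaway temp directory;
+        # philly jobs can only write under the model/scratch path, so they cache under save_base_dir.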
+        # cache dir
+        if self.phase == 'train':
+            if hasattr(self.params, 'cache_dir') and self.params.cache_dir:
+                self.cache_dir = self.params.cache_dir
             else:
-                # print("configuration[%s] is not found in %s, use default value %s" %
-                #       ("][".join(error_keys), self.conf_path, repr(default)))
-                item = default
+                if self.mode == 'normal':
+                    if self.use_cache is False:
+                        self.cache_dir = os.path.join(tempfile.gettempdir(), 'neuron_blocks', ''.join(random.sample(string.ascii_letters+string.digits, 16)))
+                else:
+                    # for philly mode, we can only save files in model_path or scratch_path
+                    self.cache_dir = os.path.join(self.save_base_dir, 'cache')
-        return item
+            self.problem_path = os.path.join(self.cache_dir, 'problem.pkl')
+            if self.pretrained_emb_path is not None:
+                self.emb_pkl_path = os.path.join(self.cache_dir, 'emb.pkl')
+            else:
+                self.emb_pkl_path = None
+        else:
+            tmp_problem_path = os.path.join(self.save_base_dir, '.necessary_cache', 'problem.pkl')
+            self.problem_path = tmp_problem_path if os.path.isfile(tmp_problem_path) else os.path.join(self.save_base_dir, 'necessary_cache', 'problem.pkl')
+
+        # md5 of training data and problem
+        self.train_data_md5 = None
+        if self.phase == 'train' and self.train_data_path:
+            logging.info("Calculating the md5 of training data ...")
+            self.train_data_md5 = md5([self.train_data_path])
+            logging.info("the md5 of training data is %s"%(self.train_data_md5))
+        self.problem_md5 = None
+
+        # encoding
+        self.encoding_cache_dir = None
+        self.encoding_cache_index_file_path = None
+        self.encoding_cache_index_file_md5_path = None
+        self.encoding_file_index = None
+        self.encoding_cache_legal_line_cnt = 0
+        self.encoding_cache_illegal_line_cnt = 0
+        self.load_encoding_cache_generator = None
     def check_conf(self):
         """ verify if the configuration is legal or not
@@ -537,24 +615,3 @@ class ModelConf(object):
     def back_up(self, params):
         shutil.copy(params.conf_path, self.save_base_dir)
         logging.info('Configuration file is backed up to %s' % (self.save_base_dir))
-
-    def _load_cache_config_from_conf(self):
-        # training data
-        self.train_data_md5 = None
-        if self.phase == 'train' and self.train_data_path:
-            logging.info("Calculating the md5 of traing data ...")
-            self.train_data_md5 = md5([self.train_data_path])
-            logging.info("the md5 of traing data is %s"%(self.train_data_md5))
-
-        # problem
-        self.problem_md5 = None
-
-        # encoding
-        self.encoding_cache_dir = None
-        self.encoding_cache_index_file_path = None
-        self.encoding_cache_index_file_md5_path = None
-        self.encoding_file_index = None
-        self.encoding_cache_legal_line_cnt = 0
-        self.encoding_cache_illegal_line_cnt = 0
-        self.load_encoding_cache_generator = None
-
diff --git a/utils/common_utils.py b/utils/common_utils.py
index 3e790ba..928d4aa 100644
--- a/utils/common_utils.py
+++ b/utils/common_utils.py
@@ -12,6 +12,7 @@ import time
 import tempfile
 import subprocess
 import hashlib
+from .exceptions import ConfigurationError

 def log_set(log_path, console_level='INFO', console_detailed=False, disable_log_file=False):
     """
@@ -38,29 +39,36 @@ def log_set(log_path, console_level='INFO', console_detailed=False, disable_log_
     logging.getLogger().addHandler(console)

-def load_from_pkl(pkl_path):
+def load_from_pkl(pkl_path, debug=True):
     with open(pkl_path, 'rb') as fin:
         obj = pkl.load(fin)
-    logging.debug("%s loaded!" % pkl_path)
+    if debug:
+        logging.debug("%s loaded!" % pkl_path)
     return obj

-def dump_to_pkl(obj, pkl_path):
+def dump_to_pkl(obj, pkl_path, debug=True):
     with open(pkl_path, 'wb') as fout:
         pkl.dump(obj, fout, protocol=pkl.HIGHEST_PROTOCOL)
-    logging.debug("Obj dumped to %s!" % pkl_path)
+    if debug:
+        logging.debug("Obj dumped to %s!" % pkl_path)

-def load_from_json(json_path):
+def load_from_json(json_path, debug=True):
     data = None
     with open(json_path, 'r', encoding='utf-8') as f:
-        data = json.loads(f.read())
-        logging.debug("%s loaded!" % json_path)
+        try:
+            data = json.loads(f.read())
+        except Exception as e:
+            raise ConfigurationError("%s is not a legal JSON file, please check your JSON format!" % json_path)
+    if debug:
+        logging.debug("%s loaded!" % json_path)
     return data

-def dump_to_json(obj, json_path):
+def dump_to_json(obj, json_path, debug=True):
     with open(json_path, 'w', encoding='utf-8') as f:
         f.write(json.dumps(obj))
-    logging.debug("Obj dumped to %s!" % json_path)
+    if debug:
+        logging.debug("Obj dumped to %s!" % json_path)

 def get_trainable_param_num(model):
     """ get the number of trainable parameters