ModelConf Reorg (#80)
Parent: 58ad563a23
Commit: a291d40aac

ModelConf.py (685 changed lines)
@@ -15,10 +15,56 @@ import shutil
from losses.BaseLossConf import BaseLossConf
#import traceback
from settings import LanguageTypes, ProblemTypes, TaggingSchemes, SupportedMetrics, PredictionTypes, DefaultPredictionFields, ConstantStatic
from utils.common_utils import log_set, prepare_dir, md5
from utils.common_utils import log_set, prepare_dir, md5, load_from_json, dump_to_json
from utils.exceptions import ConfigurationError
import numpy as np

class ConstantStaticItems(ConstantStatic):
    @staticmethod
    def concat_key_desc(key_prefix_desc, key):
        return key_prefix_desc + '.' + key

    @staticmethod
    def get_value_by_key(json, key, key_prefix='', use_default=False, default=None):
        """
        Args:
            json: a json object
            key: a key pointing to the value wanted to acquire
            use_default: if you really want to use default value when key can not be found in json object, set use_default=True
            default: if key is not found and default is None, we would raise an Exception, except that use_default is True
        Returns:
            value:
        """
        try:
            value = json[key]
        except:
            if not use_default:
                raise ConfigurationError("key[%s] can not be found in configuration file" % (key_prefix + key))
            else:
                value = default
        return value

    @staticmethod
    def add_item(item_name, use_default=False, default=None):
        def add_item_loading_func(use_default, default, func_get_value_by_key):
            @classmethod
            def load_data(cls, obj, json, key_prefix_desc='', use_default=use_default, default=default, func_get_value_by_key=func_get_value_by_key):
                obj.__dict__[cls.__name__] = func_get_value_by_key(json, cls.__name__, key_prefix_desc, use_default, default)
                return obj
            return load_data
        return type(item_name, (ConstantStatic, ), dict(load_data=add_item_loading_func(use_default, default, __class__.get_value_by_key)))

    @classmethod
    def load_data(cls, obj, json, key_prefix_desc=''):
        if cls.__name__ in json.keys():
            json = json[cls.__name__]
        for key in cls.__dict__.keys():
            if not hasattr(cls.__dict__[key], 'load_data'):
                continue
            item = cls.__dict__[key]
            obj = item.load_data(obj, json, cls.concat_key_desc(key_prefix_desc, item.__name__))
        return obj
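
For illustration only (not part of this commit): a minimal sketch of how the ConstantStaticItems helpers above are meant to be used, mirroring the self.Conf.load_data(self, {'Conf': self.conf}, key_prefix_desc='Conf') call made later in load_from_file. The names SampleConf and Holder are made up for the example.

# Hypothetical usage sketch, assuming the ConstantStaticItems class above.
class SampleConf(ConstantStaticItems):
    tool_version = ConstantStaticItems.add_item('tool_version')                               # required key
    language = ConstantStaticItems.add_item('language', use_default=True, default='english')  # optional key

class Holder(object):
    pass

holder = SampleConf.load_data(Holder(), {'SampleConf': {'tool_version': '1.1.0'}}, key_prefix_desc='SampleConf')
# holder.tool_version == '1.1.0'; holder.language == 'english' (default applied).
# A missing required key raises ConfigurationError naming the missing item.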

class ModelConf(object):
    def __init__(self, phase, conf_path, nb_version, params=None, mode='normal'):
        """ loading configuration from configuration file and argparse parameters

@@ -36,6 +82,7 @@ class ModelConf(object):
        self.params = params
        self.mode = mode.lower()
        assert self.mode in set(['normal', 'philly']), 'Your mode %s is illegal, supported modes are: normal and philly!'

        self.load_from_file(conf_path)

        self.check_version_compat(nb_version, self.tool_version)

@@ -51,321 +98,335 @@ class ModelConf(object):
            logging.debug('%s: %s' % (str(name), str(value)))
        logging.debug('=' * 80)

    class Conf(ConstantStaticItems):
        license = ConstantStaticItems.add_item('license')
        tool_version = ConstantStaticItems.add_item('tool_version')
        model_description = ConstantStaticItems.add_item('model_description')
        language = ConstantStaticItems.add_item('language', use_default=True, default='english')

        class inputs(ConstantStaticItems):
            use_cache = ConstantStaticItems.add_item('use_cache', use_default=True, default=True)
            dataset_type = ConstantStaticItems.add_item('dataset_type')
            tagging_scheme = ConstantStaticItems.add_item('tagging_scheme', use_default=True, default=None)

            class data_paths(ConstantStaticItems):
                train_data_path = ConstantStaticItems.add_item('train_data_path', use_default=True, default=None)
                valid_data_path = ConstantStaticItems.add_item('valid_data_path', use_default=True, default=None)
                test_data_path = ConstantStaticItems.add_item('test_data_path', use_default=True, default=None)
                predict_data_path = ConstantStaticItems.add_item('predict_data_path', use_default=True, default=None)
                pre_trained_emb = ConstantStaticItems.add_item('pre_trained_emb', use_default=True, default=None)
                pretrained_model_path = ConstantStaticItems.add_item('pretrained_model_path', use_default=True, default=None)

            file_with_col_header = ConstantStaticItems.add_item('file_with_col_header', use_default=True, default=False)
            pretrained_emb_type = ConstantStaticItems.add_item('pretrained_emb_type', use_default=True, default='glove')
            pretrained_emb_binary_or_text = ConstantStaticItems.add_item('pretrained_emb_binary_or_text', use_default=True, default='text')
            involve_all_words_in_pretrained_emb = ConstantStaticItems.add_item('involve_all_words_in_pretrained_emb', use_default=True, default=False)
            add_start_end_for_seq = ConstantStaticItems.add_item('add_start_end_for_seq', use_default=True, default=False)
            file_header = ConstantStaticItems.add_item('file_header', use_default=True, default=None)
            predict_file_header = ConstantStaticItems.add_item('predict_file_header', use_default=True, default=None)
            model_inputs = ConstantStaticItems.add_item('model_inputs')
            target = ConstantStaticItems.add_item('target', use_default=True, default=None)
            positive_label = ConstantStaticItems.add_item('positive_label', use_default=True, default=None)

        class outputs(ConstantStaticItems):
            save_base_dir = ConstantStaticItems.add_item('save_base_dir', use_default=True, default=None)
            model_name = ConstantStaticItems.add_item('model_name')

            train_log_name = ConstantStaticItems.add_item('train_log_name', use_default=True, default=None)
            test_log_name = ConstantStaticItems.add_item('test_log_name', use_default=True, default=None)
            predict_log_name = ConstantStaticItems.add_item('predict_log_name', use_default=True, default=None)
            predict_fields = ConstantStaticItems.add_item('predict_fields', use_default=True, default=None)
            predict_output_name = ConstantStaticItems.add_item('predict_output_name', use_default=True, default='predict.tsv')
            cache_dir = ConstantStaticItems.add_item('cache_dir', use_default=True, default=None)

        class training_params(ConstantStaticItems):
            class vocabulary(ConstantStaticItems):
                min_word_frequency = ConstantStaticItems.add_item('min_word_frequency', use_default=True, default=3)
                max_vocabulary = ConstantStaticItems.add_item('max_vocabulary', use_default=True, default=800 * 1000)
                max_building_lines = ConstantStaticItems.add_item('max_building_lines', use_default=True, default=1000 * 1000)

            optimizer = ConstantStaticItems.add_item('optimizer', use_default=True, default=None)
            clip_grad_norm_max_norm = ConstantStaticItems.add_item('clip_grad_norm_max_norm', use_default=True, default=-1)
            chunk_size = ConstantStaticItems.add_item('chunk_size', use_default=True, default=1000 * 1000)
            lr_decay = ConstantStaticItems.add_item('lr_decay', use_default=True, default=1)
            minimum_lr = ConstantStaticItems.add_item('minimum_lr', use_default=True, default=0)
            epoch_start_lr_decay = ConstantStaticItems.add_item('epoch_start_lr_decay', use_default=True, default=1)
            use_gpu = ConstantStaticItems.add_item('use_gpu', use_default=True, default=False)
            cpu_num_workers = ConstantStaticItems.add_item('cpu_num_workers', use_default=True, default=-1)  # by default, use all workers cpu supports
            batch_size = ConstantStaticItems.add_item('batch_size', use_default=True, default=1)
            batch_num_to_show_results = ConstantStaticItems.add_item('batch_num_to_show_results', use_default=True, default=10)
            max_epoch = ConstantStaticItems.add_item('max_epoch', use_default=True, default=float('inf'))
            valid_times_per_epoch = ConstantStaticItems.add_item('valid_times_per_epoch', use_default=True, default=None)
            steps_per_validation = ConstantStaticItems.add_item('steps_per_validation', use_default=True, default=10)
            text_preprocessing = ConstantStaticItems.add_item('text_preprocessing', use_default=True, default=list())
            max_lengths = ConstantStaticItems.add_item('max_lengths', use_default=True, default=None)
            fixed_lengths = ConstantStaticItems.add_item('fixed_lengths', use_default=True, default=None)
            tokenizer = ConstantStaticItems.add_item('tokenizer', use_default=True, default=None)

        architecture = ConstantStaticItems.add_item('architecture')
        loss = ConstantStaticItems.add_item('loss', use_default=True, default=None)
        metrics = ConstantStaticItems.add_item('metrics', use_default=True, default=None)
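
For illustration only (not part of this commit): a minimal configuration dict that the Conf schema above would accept. Every concrete value and the sample architecture entry are invented for the example; items declared with use_default=True may be omitted.

# Hypothetical minimal config covering the required (no-default) items above.
sample_conf = {
    "license": "MIT License",
    "tool_version": "1.1.0",
    "model_description": "toy example",
    "inputs": {
        "dataset_type": "classification",
        "data_paths": {"train_data_path": "./data/train.tsv"},
        "file_header": {"query_text": 0, "label": 1},
        "model_inputs": {"query": ["query_text"]},
        "target": ["label"]
    },
    "outputs": {"save_base_dir": "./models/demo/", "model_name": "model.nb"},
    "training_params": {"optimizer": {"name": "Adam", "params": {}}, "batch_size": 32},
    "architecture": [{"layer": "Embedding", "conf": {"word": {"dim": 300, "cols": ["query_text"]}}}]
}
# ModelConf would then load it via: self.Conf.load_data(self, {'Conf': sample_conf}, key_prefix_desc='Conf')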

    def raise_configuration_error(self, key):
        raise ConfigurationError(
            "The configuration file %s is illegal. the item [%s] is not found." % (self.conf_path, key))

    def load_from_file(self, conf_path):
        with codecs.open(conf_path, 'r', encoding='utf-8') as fin:
            try:
                self.conf = json.load(fin)
            except Exception as e:
                raise ConfigurationError("%s is not a legal JSON file, please check your JSON format!" % conf_path)
        # load file
        self.conf = load_from_json(conf_path, debug=False)
        self = self.Conf.load_data(self, {'Conf' : self.conf}, key_prefix_desc='Conf')
        self.language = self.language.lower()
        self.configurate_outputs()
        self.configurate_inputs()
        self.configurate_training_params()
        self.configurate_architecture()
        self.configurate_loss()
        self.configurate_cache()

        self.tool_version = self.get_item(['tool_version'])
        self.language = self.get_item(['language'], default='english').lower()
        self.problem_type = self.get_item(['inputs', 'dataset_type']).lower()
        #if ProblemTypes[self.problem_type] == ProblemTypes.sequence_tagging:
        self.tagging_scheme = self.get_item(['inputs', 'tagging_scheme'], default=None, use_default=True)
    def configurate_outputs(self):
        def configurate_logger(self):
            if self.phase == 'cache':
                return

        if self.mode == 'normal':
            self.use_cache = self.get_item(['inputs', 'use_cache'], True)
        elif self.mode == 'philly':
            self.use_cache = True
            # dir
            if hasattr(self.params, 'log_dir') and self.params.log_dir:
                self.log_dir = self.params.log_dir
                prepare_dir(self.log_dir, True, allow_overwrite=True)
            else:
                self.log_dir = self.save_base_dir

            # path
            self.train_log_path = os.path.join(self.log_dir, self.train_log_name)
            self.test_log_path = os.path.join(self.log_dir, self.test_log_name)
            self.predict_log_path = os.path.join(self.log_dir, self.predict_log_name)
            if self.phase == 'train':
                log_path = self.train_log_path
            elif self.phase == 'test':
                log_path = self.test_log_path
            elif self.phase == 'predict':
                log_path = self.predict_log_path
            if log_path is None:
                self.raise_configuration_error(self.phase + '_log_name')

        # OUTPUTS
            # log level
            if self.mode == 'philly' or self.params.debug:
                log_set(log_path, console_level='DEBUG', console_detailed=True, disable_log_file=self.params.disable_log_file)
            else:
                log_set(log_path, disable_log_file=self.params.disable_log_file)

        # save base dir
        if hasattr(self.params, 'model_save_dir') and self.params.model_save_dir:
            self.save_base_dir = self.params.model_save_dir
        else:
            self.save_base_dir = self.get_item(['outputs', 'save_base_dir'])

        if self.phase == 'train':
            # in train.py, it is called pretrained_model_path
            if hasattr(self.params, 'pretrained_model_path') and self.params.pretrained_model_path:
                self.pretrained_model_path = self.previous_model_path = self.params.pretrained_model_path
            else:
                self.pretrained_model_path = self.previous_model_path = self.get_item(['inputs', 'data_paths', 'pretrained_model_path'], default=None, use_default=True)
        elif self.phase == 'test' or self.phase == 'predict':
            # in test.py and predict.py, it is called pretrained_model_path
            if hasattr(self.params, 'previous_model_path') and self.params.previous_model_path:
                self.previous_model_path = self.pretrained_model_path = self.params.previous_model_path
            else:
                self.previous_model_path = self.pretrained_model_path = os.path.join(self.save_base_dir, self.get_item(['outputs', 'model_name']))  # namely, the model_save_path

        if hasattr(self, 'pretrained_model_path') and self.pretrained_model_path:  # namely self.previous_model_path
            tmp_saved_problem_path = os.path.join(os.path.dirname(self.pretrained_model_path), '.necessary_cache', 'problem.pkl')
            self.saved_problem_path = tmp_saved_problem_path if os.path.isfile(tmp_saved_problem_path) \
                else os.path.join(os.path.dirname(self.pretrained_model_path), 'necessary_cache', 'problem.pkl')
            if not (os.path.isfile(self.pretrained_model_path) and os.path.isfile(self.saved_problem_path)):
                raise Exception('Previous trained model %s or its dictionaries %s does not exist!' % (self.pretrained_model_path, self.saved_problem_path))
        elif self.save_base_dir is None:
            self.raise_configuration_error('save_base_dir')

        # prepare save base dir
        if self.phase != 'cache':
            prepare_dir(self.save_base_dir, True, allow_overwrite=self.params.force or self.mode == 'philly',
                extra_info='will overwrite model file and train.log' if self.phase=='train' else 'will add %s.log and predict file'%self.phase)

        if hasattr(self.params, 'log_dir') and self.params.log_dir:
            self.log_dir = self.params.log_dir
            if self.phase != 'cache':
                prepare_dir(self.log_dir, True, allow_overwrite=True)
        else:
            self.log_dir = self.save_base_dir
        # logger
        configurate_logger(self)

        if self.phase == 'train':
            self.train_log_path = os.path.join(self.log_dir, self.get_item(['outputs', 'train_log_name']))
            if self.mode == 'philly' or self.params.debug:
                log_set(self.train_log_path, console_level='DEBUG', console_detailed=True, disable_log_file=self.params.disable_log_file)
            else:
                log_set(self.train_log_path, disable_log_file=self.params.disable_log_file)
        elif self.phase == 'test':
            self.test_log_path = os.path.join(self.log_dir, self.get_item(['outputs', 'test_log_name']))
            if self.mode == 'philly' or self.params.debug:
                log_set(self.test_log_path, console_level='DEBUG', console_detailed=True, disable_log_file=self.params.disable_log_file)
            else:
                log_set(self.test_log_path, disable_log_file=self.params.disable_log_file)
        elif self.phase == 'predict':
            self.predict_log_path = os.path.join(self.log_dir, self.get_item(['outputs', 'predict_log_name']))
            if self.mode == 'philly' or self.params.debug:
                log_set(self.predict_log_path, console_level='DEBUG', console_detailed=True, disable_log_file=self.params.disable_log_file)
            else:
                log_set(self.predict_log_path, disable_log_file=self.params.disable_log_file)
        # predict output path
        if self.phase != 'cache':
            self.predict_output_path = self.params.predict_output_path if self.params.predict_output_path else os.path.join(self.save_base_dir, self.get_item(['outputs', 'predict_output_name'], default='predict.tsv'))
            if self.params.predict_output_path:
                self.predict_output_path = self.params.predict_output_path
            else:
                self.predict_output_path = os.path.join(self.save_base_dir, self.predict_output_name)
            logging.debug('Prepare dir for: %s' % self.predict_output_path)
            prepare_dir(self.predict_output_path, False, allow_overwrite=self.params.force or self.mode == 'philly')
        self.predict_fields = self.get_item(['outputs', 'predict_fields'], default=DefaultPredictionFields[ProblemTypes[self.problem_type]])

        self.model_save_path = os.path.join(self.save_base_dir, self.get_item(['outputs', 'model_name']))
        if self.predict_fields is None:
            self.predict_fields = DefaultPredictionFields[ProblemTypes[self.problem_type]]

        # INPUTS
        if hasattr(self.params, 'train_data_path') and self.params.train_data_path:
            self.train_data_path = self.params.train_data_path
        else:
            if self.mode == 'normal':
                self.train_data_path = self.get_item(['inputs', 'data_paths', 'train_data_path'], default=None, use_default=True)
            else:
                self.model_save_path = os.path.join(self.save_base_dir, self.model_name)

    def configurate_inputs(self):

        def configurate_data_path(self):
            self.pretrained_emb_path =self.pre_trained_emb

            if self.mode != "normal":
                self.train_data_path = None
            if hasattr(self.params, 'valid_data_path') and self.params.valid_data_path:
                self.valid_data_path = self.params.valid_data_path
            else:
                if self.mode == 'normal':
                    self.valid_data_path = self.get_item(['inputs', 'data_paths', 'valid_data_path'], default=None, use_default=True)
                else:
                    self.valid_data_path = None
            if hasattr(self.params, 'test_data_path') and self.params.test_data_path:
                self.test_data_path = self.params.test_data_path
            else:
                if self.mode == 'normal':
                    self.test_data_path = self.get_item(['inputs', 'data_paths', 'test_data_path'], default=None, use_default=True)
                else:
                    self.test_data_path = None

        if self.phase == 'predict':
            if self.params.predict_data_path:
                self.predict_data_path = self.params.predict_data_path
            else:
                if self.mode == 'normal':
                    self.predict_data_path = self.get_item(['inputs', 'data_paths', 'predict_data_path'], default=None, use_default=True)
                else:
                    self.predict_data_path = None

        if self.phase == 'train' or self.phase == 'cache':
            if self.valid_data_path is None and self.test_data_path is not None:
                # We support test_data_path == None, if someone set valid_data_path to None while test_data_path is not None,
                # swap the valid_data_path and test_data_path
                self.valid_data_path = self.test_data_path
                self.test_data_path = None
        elif self.phase == 'predict':
            if self.predict_data_path is None and self.test_data_path is not None:
                self.predict_data_path = self.test_data_path
                self.test_data_path = None

        if self.phase == 'train' or self.phase == 'test' or self.phase == 'cache':
            self.file_columns = self.get_item(['inputs', 'file_header'])
        else:
            self.file_columns = self.get_item(['inputs', 'file_header'], default=None, use_default=True)

        if self.phase == 'predict':
            if self.file_columns is None:
                self.predict_file_columns = self.get_item(['inputs', 'predict_file_header'])
            else:
                self.predict_file_columns = self.get_item(['inputs', 'predict_file_header'], default=None, use_default=True)
                if self.predict_file_columns is None:
                    self.predict_file_columns = self.file_columns

        if self.phase != 'predict':
            if self.phase == 'cache':
                self.answer_column_name = self.get_item(['inputs', 'target'], default=None, use_default=True)
            else:
                self.answer_column_name = self.get_item(['inputs', 'target'])
        self.input_types = self.get_item(['architecture', 0, 'conf'])
        # add extra feature
        feature_all = set([_.lower() for _ in self.input_types.keys()])
        formal_feature = set(['word', 'char'])
        self.extra_feature = len(feature_all - formal_feature) != 0

        # add char embedding config
        # char_emb_type = None
        # char_emb_type_cols = None
        # for single_type in self.input_types:
        #     if single_type.lower() == 'char':
        #         char_emb_type = single_type
        #         char_emb_type_cols = [single_col.lower() for single_col in self.input_types[single_type]['cols']]
        #         break
        self.object_inputs = self.get_item(['inputs', 'model_inputs'])
        # if char_emb_type and char_emb_type_cols:
        #     for single_input in self.object_inputs:
        #         for single_col in char_emb_type_cols:
        #             if single_input.lower() in single_col:
        #                 self.object_inputs[single_input].append(single_col)

        self.object_inputs_names = [name for name in self.object_inputs]

        # vocabulary setting
        self.max_vocabulary = self.get_item(['training_params', 'vocabulary', 'max_vocabulary'], default=800000, use_default=True)
        self.min_word_frequency = self.get_item(['training_params', 'vocabulary', 'min_word_frequency'], default=3, use_default=True)
        self.max_building_lines = self.get_item(['training_params', 'vocabulary', 'max_building_lines'], default=1000 * 1000, use_default=True)

        # chunk_size
        self.chunk_size = self.get_item(['training_params', 'chunk_size'], default=1000 * 1000, use_default=True)

        # file column header setting
        self.file_with_col_header = self.get_item(['inputs', 'file_with_col_header'], default=False, use_default=True)

        if ProblemTypes[self.problem_type] == ProblemTypes.sequence_tagging:
            self.add_start_end_for_seq = self.get_item(['inputs', 'add_start_end_for_seq'], default=True)
        else:
            self.add_start_end_for_seq = self.get_item(['inputs', 'add_start_end_for_seq'], default=False)

        if hasattr(self.params, 'pretrained_emb_path') and self.params.pretrained_emb_path:
            self.pretrained_emb_path = self.params.pretrained_emb_path
        else:
            if self.mode == 'normal':
                self.pretrained_emb_path = self.get_item(['inputs', 'data_paths', 'pre_trained_emb'], default=None, use_default=True)
            else:
                self.predict_data_path = None
                self.pretrained_emb_path = None

        if 'word' in self.get_item(['architecture', 0, 'conf']) and self.pretrained_emb_path:
            if hasattr(self.params, 'involve_all_words_in_pretrained_emb') and self.params.involve_all_words_in_pretrained_emb:
                self.involve_all_words_in_pretrained_emb = self.params.involve_all_words_in_pretrained_emb
            if hasattr(self.params, 'train_data_path') and self.params.train_data_path:
                self.train_data_path = self.params.train_data_path
            if hasattr(self.params, 'valid_data_path') and self.params.valid_data_path:
                self.valid_data_path = self.params.valid_data_path
            if hasattr(self.params, 'test_data_path') and self.params.test_data_path:
                self.test_data_path = self.params.test_data_path
            if hasattr(self.params, 'predict_data_path') and self.params.predict_data_path:
                self.predict_data_path = self.params.predict_data_path
            if hasattr(self.params, 'pretrained_emb_path') and self.params.pretrained_emb_path:
                self.pretrained_emb_path = self.params.pretrained_emb_path

            if self.phase == 'train' or self.phase == 'cache':
                if self.valid_data_path is None and self.test_data_path is not None:
                    # We support test_data_path == None, if someone set valid_data_path to None while test_data_path is not None,
                    # swap the valid_data_path and test_data_path
                    self.valid_data_path = self.test_data_path
                    self.test_data_path = None
            elif self.phase == 'predict':
                if self.predict_data_path is None and self.test_data_path is not None:
                    self.predict_data_path = self.test_data_path
                    self.test_data_path = None

            return self

        def configurate_data_format(self):
            # file columns
            if self.phase == 'train' or self.phase == 'test' or self.phase == 'cache':
                self.file_columns = self.file_header
                if self.file_columns is None:
                    self.raise_configuration_error('file_columns')
            if self.phase == 'predict':
                if self.file_columns is None and self.predict_file_columns is None:
                    self.raise_configuration_error('predict_file_columns')
                if self.file_columns and self.predict_file_columns is None:
                    self.predict_file_columns = self.file_columns

            # target
            if self.phase != 'predict':
                self.answer_column_name = self.target
                if self.target is None and self.phase != 'cache':
                    self.raise_configuration_error('target')

            if ProblemTypes[self.problem_type] == ProblemTypes.sequence_tagging and self.add_start_end_for_seq is None:
                self.add_start_end_for_seq = True

            # pretrained embedding
            if 'word' in self.architecture[0]['conf'] and self.pretrained_emb_path:
                if hasattr(self.params, 'involve_all_words_in_pretrained_emb') and self.params.involve_all_words_in_pretrained_emb:
                    self.involve_all_words_in_pretrained_emb = self.params.involve_all_words_in_pretrained_emb
                if hasattr(self.params, 'pretrained_emb_type') and self.params.pretrained_emb_type:
                    self.pretrained_emb_type = self.params.pretrained_emb_type
                if hasattr(self.params, 'pretrained_emb_binary_or_text') and self.params.pretrained_emb_binary_or_text:
                    self.pretrained_emb_binary_or_text = self.params.pretrained_emb_binary_or_text
                self.pretrained_emb_dim = self.architecture[0]['conf']['word']['dim']
            else:
                self.involve_all_words_in_pretrained_emb = self.get_item(['inputs', 'involve_all_words_in_pretrained_emb'], default=False)
                if hasattr(self.params, 'pretrained_emb_type') and self.params.pretrained_emb_type:
                    self.pretrained_emb_type = self.params.pretrained_emb_type
                else:
                    self.pretrained_emb_type = self.get_item(['inputs', 'pretrained_emb_type'], default='glove')
                if hasattr(self.params, 'pretrained_emb_binary_or_text') and self.params.pretrained_emb_binary_or_text:
                    self.pretrained_emb_binary_or_text = self.params.pretrained_emb_binary_or_text
                else:
                    self.pretrained_emb_binary_or_text = self.get_item(['inputs', 'pretrained_emb_binary_or_text'], default='text')
                self.pretrained_emb_dim = self.get_item(['architecture', 0, 'conf', 'word', 'dim'])
                self.pretrained_emb_path = None
                self.involve_all_words_in_pretrained_emb = None
                self.pretrained_emb_type = None
                self.pretrained_emb_binary_or_text = None
                self.pretrained_emb_dim = None

            return self

        def configurate_model_input(self):
            self.object_inputs = self.model_inputs
            self.object_inputs_names = [name for name in self.object_inputs]

            return self

        self.problem_type = self.dataset_type.lower()

        # previous model path
        if hasattr(self.params, 'previous_model_path') and self.params.previous_model_path:
            self.previous_model_path = self.params.previous_model_path
        else:
            self.pretrained_emb_path = None
            self.involve_all_words_in_pretrained_emb = None
            self.pretrained_emb_binary_or_text = None
            self.pretrained_emb_dim = None
            self.pretrained_emb_type = None
            self.previous_model_path = os.path.join(self.save_base_dir, self.model_name)

        # pretrained model path
        if hasattr(self.params, 'pretrained_model_path') and self.params.pretrained_model_path:
            self.pretrained_model_path = self.params.pretrained_model_path

        # saved problem path
        model_path = None
        if self.phase == 'train':
            if hasattr(self.params, 'cache_dir') and self.params.cache_dir:
                # for aether
                self.cache_dir = self.params.cache_dir
            else:
                if self.mode == 'normal':
                    if self.use_cache:
                        self.cache_dir = self.get_item(['outputs', 'cache_dir'])
                    else:
                        self.cache_dir = os.path.join(tempfile.gettempdir(), 'neuron_blocks', ''.join(random.sample(string.ascii_letters+string.digits, 16)))
                else:
                    # for philly mode, we can only save files in model_path or scratch_path
                    self.cache_dir = os.path.join(self.save_base_dir, 'cache')
            model_path = self.pretrained_model_path
        elif self.phase == 'test' or self.phase == 'predict':
            model_path = self.previous_model_path
        if model_path:
            model_path_dir = os.path.dirname(model_path)
            self.saved_problem_path = os.path.join(model_path_dir, '.necessary_cache', 'problem.pkl')
            if not os.path.isfile(self.saved_problem_path):
                self.saved_problem_path = os.path.join(model_path_dir, 'necessary_cache', 'problem.pkl')
            if not (os.path.isfile(model_path) and os.path.isfile(self.saved_problem_path)):
                raise Exception('Previous trained model %s or its dictionaries %s does not exist!' % (model_path, self.saved_problem_path))

            self.problem_path = os.path.join(self.cache_dir, 'problem.pkl')
            if self.pretrained_emb_path is not None:
                self.emb_pkl_path = os.path.join(self.cache_dir, 'emb.pkl')
            else:
                self.emb_pkl_path = None
        else:
            tmp_problem_path = os.path.join(self.save_base_dir, '.necessary_cache', 'problem.pkl')
            self.problem_path = tmp_problem_path if os.path.isfile(tmp_problem_path) else os.path.join(self.save_base_dir, 'necessary_cache', 'problem.pkl')

        # cache configuration
        self._load_cache_config_from_conf()

        # training params
        self.training_params = self.get_item(['training_params'])
        configurate_data_path(self)
        configurate_data_format(self)
        configurate_model_input(self)

    def configurate_training_params(self):
        # optimizer
        if self.phase == 'train':
            self.optimizer_name = self.get_item(['training_params', 'optimizer', 'name'])
            self.optimizer_params = self.get_item(['training_params', 'optimizer', 'params'])

            self.clip_grad_norm_max_norm = self.get_item(['training_params', 'clip_grad_norm_max_norm'], default=-1)

            if self.optimizer is None:
                self.raise_configuration_error('training_params.optimizer')
            if 'name' not in self.optimizer.keys():
                self.raise_configuration_error('training_params.optimizer.name')
            self.optimizer_name = self.optimizer['name']
            if 'params' not in self.optimizer.keys():
                self.raise_configuration_error('training_params.optimizer.params')
            self.optimizer_params = self.optimizer['params']
            if hasattr(self.params, 'learning_rate') and self.params.learning_rate:
                self.optimizer_params['lr'] = self.params.learning_rate

        # batch size
        self.batch_size_each_gpu = self.batch_size  # the batch_size in conf file is the batch_size on each GPU
        if hasattr(self.params, 'batch_size') and self.params.batch_size:
            self.batch_size_each_gpu = self.params.batch_size
        else:
            self.batch_size_each_gpu = self.get_item(['training_params', 'batch_size'])  # the batch_size in conf file is the batch_size on each GPU
        self.lr_decay = self.get_item(['training_params', 'lr_decay'], default=1)  # by default, no decay
        self.minimum_lr = self.get_item(['training_params', 'minimum_lr'], default=0)
        self.epoch_start_lr_decay = self.get_item(['training_params', 'epoch_start_lr_decay'], default=1)
        if self.batch_size_each_gpu is None:
            self.raise_configuration_error('training_params.batch_size')
        self.batch_size_total = self.batch_size_each_gpu
        if torch.cuda.device_count() > 1:
            self.batch_size_total = torch.cuda.device_count() * self.batch_size_each_gpu
            self.batch_num_to_show_results = self.batch_num_to_show_results // torch.cuda.device_count()

        if hasattr(self.params, 'max_epoch') and self.params.max_epoch:
            self.max_epoch = self.params.max_epoch
        else:
            self.max_epoch = self.get_item(['training_params', 'max_epoch'], default=float('inf'))
        if 'valid_times_per_epoch' in self.conf['training_params']:

        if self.valid_times_per_epoch is not None:
            logging.info("configuration[training_params][valid_times_per_epoch] is deprecated, please use configuration[training_params][steps_per_validation] instead")
        self.steps_per_validation = self.get_item(['training_params', 'steps_per_validation'], default=10)
        self.batch_num_to_show_results = self.get_item(['training_params', 'batch_num_to_show_results'], default=10)
        self.max_lengths = self.get_item(['training_params', 'max_lengths'], default=None, use_default=True)
        self.fixed_lengths = self.get_item(['training_params', 'fixed_lengths'], default=None, use_default=True)

        # sequence length
        if self.fixed_lengths:
            self.max_lengths = None
        if ProblemTypes[self.problem_type] == ProblemTypes.sequence_tagging:
            self.fixed_lengths = None
            self.max_lengths = None

        if torch.cuda.device_count() > 1:
            self.batch_size_total = torch.cuda.device_count() * self.training_params['batch_size']
            self.batch_num_to_show_results = self.batch_num_to_show_results // torch.cuda.device_count()
        else:
            self.batch_size_total = self.batch_size_each_gpu

        self.cpu_num_workers = self.get_item(['training_params', 'cpu_num_workers'], default=-1)  # by default, use all workers cpu supports

        # text preprocessing
        self.__text_preprocessing = self.get_item(['training_params', 'text_preprocessing'], default=list())
        self.__text_preprocessing = self.text_preprocessing
        self.DBC2SBC = True if 'DBC2SBC' in self.__text_preprocessing else False
        self.unicode_fix = True if 'unicode_fix' in self.__text_preprocessing else False
        self.remove_stopwords = True if 'remove_stopwords' in self.__text_preprocessing else False

        # tokenzier
        if self.language == 'chinese':
            self.tokenizer = self.get_item(['training_params', 'tokenizer'], default='jieba')
        else:
            self.tokenizer = self.get_item(['training_params', 'tokenizer'], default='nltk')

        if self.extra_feature:
            if self.DBC2SBC:
                logging.warning("Detect the extra feature %s, set the DBC2sbc is False." % ''.join(list(feature_all-formal_feature)))
            if self.unicode_fix:
                logging.warning("Detect the extra feature %s, set the unicode_fix is False." % ''.join(list(feature_all-formal_feature)))
            if self.remove_stopwords:
                logging.warning("Detect the extra feature %s, set the remove_stopwords is False." % ''.join(list(feature_all-formal_feature)))

        if ProblemTypes[self.problem_type] == ProblemTypes.sequence_tagging:
            if self.unicode_fix:
                logging.warning('For sequence tagging task, unicode_fix may change the number of words.')
            if self.remove_stopwords:
                self.remove_stopwords = False
                logging.warning('For sequence tagging task, remove stopwords is forbidden! It is disabled now.')

        if self.tokenizer is None:
            self.tokenizer = 'jieba' if self.language == 'chinese' else 'nltk'

        # GPU/CPU
        if self.phase != 'cache':
            if torch.cuda.is_available() and torch.cuda.device_count() > 0 and self.training_params.get('use_gpu', True):
                self.use_gpu = True
            if torch.cuda.is_available() and torch.cuda.device_count() > 0 and self.use_gpu:
                logging.info("Activating GPU mode, there are %d GPUs available" % torch.cuda.device_count())
            else:
                self.use_gpu = False
                logging.info("Activating CPU mode")

        self.architecture = self.get_item(['architecture'])
    def configurate_architecture(self):
        self.input_types = self.architecture[0]['conf']

        # extra feature
        feature_all = set([_.lower() for _ in self.input_types.keys()])
        formal_feature = set(['word', 'char'])
        extra_feature_num = feature_all - formal_feature
        self.extra_feature = len(extra_feature_num) != 0
        if self.extra_feature:
            if self.DBC2SBC:
                logging.warning("Detect the extra feature %s, set the DBC2sbc is False." % ''.join(list(extra_feature_num)))
            if self.unicode_fix:
                logging.warning("Detect the extra feature %s, set the unicode_fix is False." % ''.join(list(extra_feature_num)))
            if self.remove_stopwords:
                logging.warning("Detect the extra feature %s, set the remove_stopwords is False." % ''.join(list(extra_feature_num)))

        # output layer
        self.output_layer_id = []
        for single_layer in self.architecture:
            if 'output_layer_flag' in single_layer and single_layer['output_layer_flag']:

@@ -384,42 +445,59 @@ class ModelConf(object):
                    self.min_sentence_len = max(self.min_sentence_len, np.max(np.array([single_conf_value])))
                    break

        if self.phase == 'train' or self.phase == 'test':
            self.loss = BaseLossConf.get_conf(**self.get_item(['loss']))
            self.metrics = self.get_item(['metrics'])
            if 'auc' in self.metrics and ProblemTypes[self.problem_type] == ProblemTypes.classification:
                self.pos_label = self.get_item(['inputs', 'positive_label'], default=None, use_default=True)
    def configurate_loss(self):
        if self.phase != 'train' and self.phase != 'test':
            return

        if self.loss is None or self.metrics is None:
            self.raise_configuration_error('loss/metrics')
        self.loss = BaseLossConf.get_conf(**self.loss)

    def get_item(self, keys, default=None, use_default=False):
        """
        if 'auc' in self.metrics and ProblemTypes[self.problem_type] == ProblemTypes.classification:
            self.pos_label = self.positive_label

        Args:
            keys:
            default: if some key is not found and default is None, we would raise an Exception, except that use_default is True
            use_default: if you really want to set default to None, set use_default=True
    def configurate_cache(self):
        # whether use cache
        if self.mode == 'philly':
            self.use_cache = True

        Returns:

        """
        item = self.conf
        valid_keys = []
        try:
            for key in keys:
                item = item[key]
                valid_keys.append(key)
        except:
            error_keys = copy.deepcopy(valid_keys)
            error_keys.append(key)
            if default is None and use_default is False:
                raise ConfigurationError(
                    "The configuration file %s is illegal. There should be an item configuration[%s], "
                    "but the item %s is not found." % (self.conf_path, "][".join(error_keys), key))
        # cache dir
        if self.phase == 'train':
            if hasattr(self.params, 'cache_dir') and self.params.cache_dir:
                self.cache_dir = self.params.cache_dir
            else:
                # print("configuration[%s] is not found in %s, use default value %s" %
                #       ("][".join(error_keys), self.conf_path, repr(default)))
                item = default
                if self.mode == 'normal':
                    if self.use_cache is False:
                        self.cache_dir = os.path.join(tempfile.gettempdir(), 'neuron_blocks', ''.join(random.sample(string.ascii_letters+string.digits, 16)))
                else:
                    # for philly mode, we can only save files in model_path or scratch_path
                    self.cache_dir = os.path.join(self.save_base_dir, 'cache')

        return item
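
For illustration only (not part of this commit): how the get_item semantics documented above play out, assuming mc is an already-loaded ModelConf instance.

# Hypothetical calls against a loaded ModelConf instance mc.
mc.get_item(['training_params', 'batch_size'])               # walks conf['training_params']['batch_size']
mc.get_item(['inputs', 'tagging_scheme'], use_default=True)  # returns None instead of raising when the chain is missing
mc.get_item(['outputs', 'model_name'])                       # raises ConfigurationError if the key chain is missing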
        self.problem_path = os.path.join(self.cache_dir, 'problem.pkl')
        if self.pretrained_emb_path is not None:
            self.emb_pkl_path = os.path.join(self.cache_dir, 'emb.pkl')
        else:
            self.emb_pkl_path = None
        else:
            tmp_problem_path = os.path.join(self.save_base_dir, '.necessary_cache', 'problem.pkl')
            self.problem_path = tmp_problem_path if os.path.isfile(tmp_problem_path) else os.path.join(self.save_base_dir, 'necessary_cache', 'problem.pkl')

        # md5 of training data and problem
        self.train_data_md5 = None
        if self.phase == 'train' and self.train_data_path:
            logging.info("Calculating the md5 of traing data ...")
            self.train_data_md5 = md5([self.train_data_path])
            logging.info("the md5 of traing data is %s"%(self.train_data_md5))
        self.problem_md5 = None

        # encoding
        self.encoding_cache_dir = None
        self.encoding_cache_index_file_path = None
        self.encoding_cache_index_file_md5_path = None
        self.encoding_file_index = None
        self.encoding_cache_legal_line_cnt = 0
        self.encoding_cache_illegal_line_cnt = 0
        self.load_encoding_cache_generator = None

    def check_conf(self):
        """ verify if the configuration is legal or not

@@ -537,24 +615,3 @@ class ModelConf(object):
    def back_up(self, params):
        shutil.copy(params.conf_path, self.save_base_dir)
        logging.info('Configuration file is backed up to %s' % (self.save_base_dir))

    def _load_cache_config_from_conf(self):
        # training data
        self.train_data_md5 = None
        if self.phase == 'train' and self.train_data_path:
            logging.info("Calculating the md5 of traing data ...")
            self.train_data_md5 = md5([self.train_data_path])
            logging.info("the md5 of traing data is %s"%(self.train_data_md5))

        # problem
        self.problem_md5 = None

        # encoding
        self.encoding_cache_dir = None
        self.encoding_cache_index_file_path = None
        self.encoding_cache_index_file_md5_path = None
        self.encoding_file_index = None
        self.encoding_cache_legal_line_cnt = 0
        self.encoding_cache_illegal_line_cnt = 0
        self.load_encoding_cache_generator = None

utils/common_utils.py
@@ -12,6 +12,7 @@ import time
import tempfile
import subprocess
import hashlib
from .exceptions import ConfigurationError

def log_set(log_path, console_level='INFO', console_detailed=False, disable_log_file=False):
    """

@@ -38,29 +39,36 @@ def log_set(log_path, console_level='INFO', console_detailed=False, disable_log_
    logging.getLogger().addHandler(console)


def load_from_pkl(pkl_path):
def load_from_pkl(pkl_path, debug=True):
    with open(pkl_path, 'rb') as fin:
        obj = pkl.load(fin)
    logging.debug("%s loaded!" % pkl_path)
    if debug:
        logging.debug("%s loaded!" % pkl_path)
    return obj


def dump_to_pkl(obj, pkl_path):
def dump_to_pkl(obj, pkl_path, debug=True):
    with open(pkl_path, 'wb') as fout:
        pkl.dump(obj, fout, protocol=pkl.HIGHEST_PROTOCOL)
    logging.debug("Obj dumped to %s!" % pkl_path)
    if debug:
        logging.debug("Obj dumped to %s!" % pkl_path)

def load_from_json(json_path):
def load_from_json(json_path, debug=True):
    data = None
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.loads(f.read())
    logging.debug("%s loaded!" % json_path)
        try:
            data = json.loads(f.read())
        except Exception as e:
            raise ConfigurationError("%s is not a legal JSON file, please check your JSON format!" % json_path)
    if debug:
        logging.debug("%s loaded!" % json_path)
    return data

def dump_to_json(obj, json_path):
def dump_to_json(obj, json_path, debug=True):
    with open(json_path, 'w', encoding='utf-8') as f:
        f.write(json.dumps(obj))
    logging.debug("Obj dumped to %s!" % json_path)
    if debug:
        logging.debug("Obj dumped to %s!" % json_path)
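
For illustration only (not part of this commit): a small usage sketch of the new debug switch on these helpers; the paths are hypothetical.

# debug=True (the default) keeps the old behaviour of emitting a debug log line;
# debug=False, as ModelConf.load_from_file now passes, silences it.
conf_dict = load_from_json('./conf/sample_conf.json', debug=False)   # no "... loaded!" debug message
dump_to_json(conf_dict, './conf/sample_conf.copy.json')              # logs "Obj dumped to ..." as before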

def get_trainable_param_num(model):
    """ get the number of trainable parameters