Feixiang Cheng 2019-05-18 23:20:21 +08:00
Parent 609cca813f
Commit 1aa208bc1f
5 changed files with 164 additions and 115 deletions

View file

@@ -13,7 +13,7 @@ import codecs
import pickle as pkl
from utils.common_utils import dump_to_pkl, load_from_pkl, get_param_num, get_trainable_param_num, \
transfer_to_gpu, transform_params2tensors
transfer_to_gpu, transform_params2tensors, load_from_json, dump_to_json
from utils.philly_utils import HDFSDirectTransferer, open_and_move, convert_to_tmppath, \
convert_to_hdfspath, move_from_local_to_hdfs
from Model import Model
@@ -22,7 +22,7 @@ from metrics.Evaluator import Evaluator
from utils.corpus_utils import get_batches
from core.StreamingRecorder import StreamingRecorder
from core.LRScheduler import LRScheduler
from settings import ProblemTypes
from settings import ProblemTypes, Setting as st
from block_zoo import Linear
@@ -83,38 +83,15 @@ class LearningMachine(object):
def train(self, optimizer, loss_fn):
self.model.train()
if not self.conf.train_data_path.endswith('.pkl'):
if self.conf.use_cache:
training_data_func = self.conf.encoding_generator_func
else:
train_data, train_length, train_target = self.problem.encode(self.conf.train_data_path, self.conf.file_columns,
self.conf.input_types, self.conf.file_with_col_header, self.conf.object_inputs, self.conf.answer_column_name, max_lengths=self.conf.max_lengths,
min_sentence_len = self.conf.min_sentence_len, extra_feature=self.conf.extra_feature,fixed_lengths=self.conf.fixed_lengths, file_format='tsv',
show_progress=True if self.conf.mode == 'normal' else False, cpu_num_workers=self.conf.cpu_num_workers)
training_data_func = lambda : [(train_data, train_length, train_target)]
else:
train_pkl_data = load_from_pkl(self.conf.train_data_path)
train_data, train_length, train_target = train_pkl_data['data'], train_pkl_data['length'], train_pkl_data['target']
training_data_func = lambda : [(train_data, train_length, train_target)]
if not self.conf.valid_data_path.endswith('.pkl'):
valid_data, valid_length, valid_target = self.problem.encode(self.conf.valid_data_path, self.conf.file_columns,
self.conf.input_types, self.conf.file_with_col_header, self.conf.object_inputs, self.conf.answer_column_name, max_lengths=self.conf.max_lengths,
min_sentence_len = self.conf.min_sentence_len, extra_feature = self.conf.extra_feature,fixed_lengths=self.conf.fixed_lengths, file_format='tsv',
show_progress=True if self.conf.mode == 'normal' else False, cpu_num_workers=self.conf.cpu_num_workers)
else:
valid_pkl_data = load_from_pkl(self.conf.valid_data_path)
valid_data, valid_length, valid_target = valid_pkl_data['data'], valid_pkl_data['length'], valid_pkl_data['target']
if self.conf.test_data_path is not None:
if not self.conf.test_data_path.endswith('.pkl'):
test_data, test_length, test_target = self.problem.encode(self.conf.test_data_path, self.conf.file_columns, self.conf.input_types,
self.conf.file_with_col_header, self.conf.object_inputs, self.conf.answer_column_name, max_lengths=self.conf.max_lengths,
min_sentence_len = self.conf.min_sentence_len, extra_feature = self.conf.extra_feature,fixed_lengths=self.conf.fixed_lengths,
file_format='tsv', show_progress=True if self.conf.mode == 'normal' else False, cpu_num_workers=self.conf.cpu_num_workers)
else:
test_pkl_data = load_from_pkl(self.conf.test_data_path)
test_data, test_length, test_target = test_pkl_data['data'], test_pkl_data['length'], test_pkl_data['target']
valid_data, valid_length, valid_target = self.problem.encode(self.conf.valid_data_path, self.conf.file_columns,
self.conf.input_types, self.conf.file_with_col_header, self.conf.object_inputs, self.conf.answer_column_name, max_lengths=self.conf.max_lengths,
min_sentence_len = self.conf.min_sentence_len, extra_feature = self.conf.extra_feature,fixed_lengths=self.conf.fixed_lengths, file_format='tsv',
show_progress=True if self.conf.mode == 'normal' else False, cpu_num_workers=self.conf.cpu_num_workers)
test_data, test_length, test_target = self.problem.encode(self.conf.test_data_path, self.conf.file_columns, self.conf.input_types,
self.conf.file_with_col_header, self.conf.object_inputs, self.conf.answer_column_name, max_lengths=self.conf.max_lengths,
min_sentence_len = self.conf.min_sentence_len, extra_feature = self.conf.extra_feature,fixed_lengths=self.conf.fixed_lengths,
file_format='tsv', show_progress=True if self.conf.mode == 'normal' else False, cpu_num_workers=self.conf.cpu_num_workers)
stop_training = False
epoch = 1
@@ -131,14 +108,13 @@ class LearningMachine(object):
elif ProblemTypes[self.problem.problem_type] == ProblemTypes.mrc:
streaming_recoder = StreamingRecorder(['prediction', 'answer_text'])
part_num = len(self.conf.encoding_cache_index)
while not stop_training and epoch <= self.conf.max_epoch:
logging.info('Training: Epoch ' + str(epoch))
train_data_generator = training_data_func()
train_data_generator = self._get_training_data_generator()
part_index = 1
for train_data, train_length, train_target in train_data_generator:
logging.info('Training: Epoch %s Part (%s/%s)'%(epoch, part_index, part_num))
logging.debug('Training: Epoch %s Part %s'%(epoch, part_index))
part_index += 1
data_batches, length_batches, target_batches = \
get_batches(self.problem, train_data, train_length, train_target, self.conf.batch_size_total,
@@ -156,7 +132,7 @@
streaming_recoder.clear_records()
all_costs = []
logging.info('There are %d batches during an epoch; validation are conducted every %d batch' % (small_batch_num, valid_batch_num_show))
logging.info('There are %d batches during current period; validation are conducted every %d batch' % (small_batch_num, valid_batch_num_show))
if self.conf.mode == 'normal':
progress = tqdm(range(len(target_batches)))
@@ -762,5 +738,19 @@ class LearningMachine(object):
logging.info("Model %s loaded!" % model_path)
logging.info("Total trainable parameters: %d" % (get_trainable_param_num(self.model)))
def _get_training_data_generator(self):
if not self.conf.use_cache:
return self.problem.get_encode_generator(self.conf, build_cache=False)
if not self.conf.encoding_file_index:
return self._get_save_encode_generator()
assert self.conf.load_encoding_cache_generator, 'function conf.load_encoding_cache_generator is not defined'
return self.conf.load_encoding_cache_generator(self.conf.encoding_cache_dir, self.conf.encoding_file_index)
def _get_save_encode_generator(self):
load_save_encode_generator = self.problem.get_encode_generator(self.conf, build_cache=True)
for data, lengths, target in load_save_encode_generator:
yield data, lengths, target
cache_index = load_from_json(self.conf.encoding_cache_index_file_path)
self.conf.encoding_file_index = cache_index[st.cencoding_key_index]
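
For orientation, here is a minimal standalone sketch (not part of the commit) of the training flow this refactor enables: the epoch loop now pulls encoded data one cached part at a time instead of materializing the whole training set up front. encode_parts below is a hypothetical stand-in for Problem.get_encode_generator or conf.load_encoding_cache_generator.

def encode_parts():
    # hypothetical stand-in: yields one (data, lengths, target) triple per cached pkl part
    for part in range(3):
        yield {'query': [[part]]}, {'query': [[1]]}, [part]

for epoch in range(1, 3):
    for data, lengths, target in encode_parts():
        pass  # batch, forward, backward on this part only; the part is then freed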

View file

@@ -538,9 +538,10 @@ class ModelConf(object):
# encoding
self.encoding_cache_dir = None
self.encoding_cache_index_file_name = 'index.json'
self.encoding_cache_index_file_path = None
self.encoding_cache_index = None
self.encoding_cache_index_file_md5_name = 'index_md5.json'
self.encoding_cache_index_file_md5_path = None
self.encoding_file_index = None
self.encoding_cache_legal_line_cnt = 0
self.encoding_cache_illegal_line_cnt = 0
self.load_encoding_cache_generator = None

View file

@@ -14,7 +14,7 @@ import os
import pickle as pkl
from utils.common_utils import load_from_pkl, dump_to_pkl, load_from_json, dump_to_json, prepare_dir, md5
from settings import ProblemTypes
from settings import ProblemTypes, Setting as st
import math
from utils.ProcessorsScheduler import ProcessorsScheduler
@@ -264,27 +264,29 @@ class Problem():
bpe_encoder = None
self.file_column_num = len(file_columns)
progress = self.get_data_generator_from_file(training_data_path, file_with_col_header)
chunk_size = st.chunk_size
progress = self.get_data_generator_from_file(training_data_path, file_with_col_header, chunk_size=chunk_size)
preprocessed_data_generator= self.build_training_multi_processor(progress, cpu_num_workers, file_columns, input_types, answer_column_name, bpe_encoder=bpe_encoder)
# update symbol universe
total_cnt_legal, total_cnt_illegal = 0, 0
for docs, target_docs, cnt_legal, cnt_illegal in tqdm(preprocessed_data_generator):
total_cnt_legal += cnt_legal
total_cnt_illegal += cnt_illegal
#for docs, target_docs, cnt_legal, cnt_illegal in tqdm(preprocessed_data_generator):
docs, target_docs, cnt_legal, cnt_illegal = next(preprocessed_data_generator)
total_cnt_legal += cnt_legal
total_cnt_illegal += cnt_illegal
# input_type
for input_type in input_types:
self.input_dicts[input_type].update(docs[input_type])
# problem_type
if ProblemTypes[self.problem_type] == ProblemTypes.classification or \
ProblemTypes[self.problem_type] == ProblemTypes.sequence_tagging:
self.output_dict.update(list(target_docs.values())[0])
elif ProblemTypes[self.problem_type] == ProblemTypes.regression or \
ProblemTypes[self.problem_type] == ProblemTypes.mrc:
pass
logging.info("Corpus imported: %d legal lines, %d illegal lines." % (total_cnt_legal, total_cnt_illegal))
# input_type
for input_type in input_types:
self.input_dicts[input_type].update(docs[input_type])
# problem_type
if ProblemTypes[self.problem_type] == ProblemTypes.classification or \
ProblemTypes[self.problem_type] == ProblemTypes.sequence_tagging:
self.output_dict.update(list(target_docs.values())[0])
elif ProblemTypes[self.problem_type] == ProblemTypes.regression or \
ProblemTypes[self.problem_type] == ProblemTypes.mrc:
pass
logging.info("[Corpus] at most %d lines imported: %d legal lines, %d illegal lines." % (chunk_size, total_cnt_legal, total_cnt_illegal))
# build dictionary
for input_type in input_types:
@@ -702,13 +704,13 @@ class Problem():
bpe_encoder = None
progress = self.get_data_generator_from_file(data_path, file_with_col_header)
encoder_generator = self.encode_data_multi_processor(progress, cpu_num_workers,
encode_generator = self.encode_data_multi_processor(progress, cpu_num_workers,
file_columns, input_types, object_inputs, answer_column_name, min_sentence_len, extra_feature, max_lengths,
fixed_lengths, file_format, bpe_encoder=bpe_encoder)
data, lengths, target = dict(), dict(), dict()
cnt_legal, cnt_illegal = 0, 0
for temp_data, temp_lengths, temp_target, temp_cnt_legal, temp_cnt_illegal in tqdm(encoder_generator):
for temp_data, temp_lengths, temp_target, temp_cnt_legal, temp_cnt_illegal in tqdm(encode_generator):
data = self._merge_encode_data(data, temp_data)
lengths = self._merge_encode_lengths(lengths, temp_lengths)
target = self._merge_target(target, temp_target)
@@ -718,37 +720,58 @@ class Problem():
logging.info("%s: %d legal samples, %d illegal samples" % (data_path, cnt_legal, cnt_illegal))
return data, lengths, target
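
The loop above folds each worker chunk into the running result via _merge_encode_data, _merge_encode_lengths and _merge_target. A rough sketch of that merge pattern, assuming a flat dict-of-lists layout (the real helpers handle one extra level of nesting):

def merge(acc, part):
    # extend the accumulated list for each key with the chunk's samples
    for key, values in part.items():
        acc.setdefault(key, []).extend(values)
    return acc

data = merge({'query': [[1, 2]]}, {'query': [[3]]})  # {'query': [[1, 2], [3]]}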
def build_encode_cache(self, conf, file_format="tsv", cpu_num_workers = -1):
if 'bpe' in conf.input_types:
try:
bpe_encoder = BPEEncoder(conf.input_types['bpe']['bpe_path'])
except KeyError:
raise Exception('Please define a bpe path at the embedding layer.')
else:
bpe_encoder = None
def build_encode_cache(self, conf, file_format="tsv"):
logging.info("[Cache] building encoding cache")
build_encode_cache_generator = self.get_encode_generator(conf, build_cache=True, file_format=file_format)
for _ in build_encode_cache_generator:
continue
logging.info("[Cache] encoding is saved to %s" % conf.encoding_cache_dir)
def get_encode_generator(self, conf, build_cache=True, file_format="tsv"):
# parameter check
if build_cache:
assert conf.encoding_cache_dir, 'There is no property encoding_cache_dir in object conf'
assert conf.encoding_cache_index_file_path, 'There is no property encoding_cache_index_file_path in object conf'
assert conf.encoding_cache_index_file_md5_path, 'There is no property encoding_cache_index_file_md5_path in object conf'
progress = self.get_data_generator_from_file(conf.train_data_path, conf.file_with_col_header)
encoder_generator = self.encode_data_multi_processor(progress, cpu_num_workers,
bpe_encoder = self._check_bpe_encoder(conf.input_types)
data_generator = self.get_data_generator_from_file(conf.train_data_path, conf.file_with_col_header)
encode_generator = self.encode_data_multi_processor(data_generator, conf.cpu_num_workers,
conf.file_columns, conf.input_types, conf.object_inputs, conf.answer_column_name,
conf.min_sentence_len, conf.extra_feature, conf.max_lengths,
conf.fixed_lengths, file_format, bpe_encoder=bpe_encoder)
file_index = []
total_cnt_legal, total_cnt_illegal = 0, 0
for part_number, encode_data in enumerate(encode_generator):
data, lengths, target, cnt_legal, cnt_illegal = encode_data
if build_cache:
total_cnt_legal = total_cnt_legal + cnt_legal
total_cnt_illegal = total_cnt_illegal + cnt_illegal
file_name = st.cencoding_file_name_pattern % (part_number)
file_path = os.path.join(conf.encoding_cache_dir, file_name)
dump_to_pkl((data, lengths, target), file_path)
file_index.append([file_name, md5([file_path])])
logging.info("Up to now, in %s: %d legal samples, %d illegal samples" % (conf.train_data_path, total_cnt_legal, total_cnt_illegal))
yield data, lengths, target
cnt_legal, cnt_illegal = 0, 0
cache_index = []
file_name_pattern = 'encoding_cache_%s.pkl'
part_number = 0
for temp_data, temp_lengths, temp_target, temp_cnt_legal, temp_cnt_illegal in tqdm(encoder_generator):
data = (temp_data, temp_lengths, temp_target)
file_name = file_name_pattern % (part_number)
file_path = os.path.join(conf.encoding_cache_dir, file_name)
dump_to_pkl(data, file_path)
cache_index.append([file_name, md5([file_path])])
cnt_legal += temp_cnt_legal
cnt_illegal += temp_cnt_illegal
part_number += 1
dump_to_json(cache_index, conf.encoding_cache_index_file_path)
dump_to_json(md5([conf.encoding_cache_index_file_path]), conf.encoding_cache_index_file_md5_path)
logging.info("%s: %d legal samples, %d illegal samples" % (conf.train_data_path, cnt_legal, cnt_illegal))
if build_cache:
cache_index = dict()
cache_index[st.cencoding_key_index] = file_index
cache_index[st.cencoding_key_legal_cnt] = total_cnt_legal
cache_index[st.cencoding_key_illegal_cnt] = total_cnt_illegal
dump_to_json(cache_index, conf.encoding_cache_index_file_path)
dump_to_json(md5([conf.encoding_cache_index_file_path]), conf.encoding_cache_index_file_md5_path)
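
With build_cache=True the generator leaves an index.json behind whose shape, inferred from the Setting keys used above, is roughly:

{
    "index": [["encoding_cache_0.pkl", "<md5>"], ["encoding_cache_1.pkl", "<md5>"]],
    "legal_line_cnt": 123456,
    "illegal_line_cnt": 7
}

index_md5.json then stores the md5 of index.json itself, which _check_encoding in train.py verifies before trusting the cache.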
@staticmethod
def _check_bpe_encoder(input_types):
bpe_encoder = None
if 'bpe' in input_types:
try:
bpe_encoder = BPEEncoder(input_types['bpe']['bpe_path'])
except KeyError:
raise Exception('Please define a bpe path at the embedding layer.')
return bpe_encoder
def decode(self, model_output, lengths=None, batch_data=None):
""" decode the model output, either a batch of output or a single output

View file

@@ -53,3 +53,29 @@ DefaultPredictionFields = {
# nltk's models
nltk.data.path.append(os.path.join(os.getcwd(), 'dataset', 'nltk_data'))
class Constant(type):
def __setattr__(self, name, value):
raise AttributeError("Class %s can not be modified"%(self.__name__))
class ConstantStatic(metaclass=Constant):
def __init__(self, *args,**kwargs):
raise Exception("Class %s can not be instantiated"%(self.__class__.__name__))
class Setting(ConstantStatic):
# cache
## cencoding (cache_encoding)
cencodig_index_file_name = 'index.json'
cencoding_index_md5_file_name = 'index_md5.json'
cencoding_file_name_pattern = 'encoding_cache_%s.pkl'
cencoding_key_finish = 'finish'
cencoding_key_index = 'index'
cencoding_key_legal_cnt = 'legal_line_cnt'
cencoding_key_illegal_cnt = 'illegal_line_cnt'
# lazy load
chunk_size = 1000 * 1000
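
A short demonstration (not in the commit) of what the Constant/ConstantStatic pair enforces: class attributes stay readable, while assignment and instantiation both raise.

print(Setting.chunk_size)    # 1000000
try:
    Setting.chunk_size = 5   # intercepted by Constant.__setattr__
except AttributeError as err:
    print(err)               # Class Setting can not be modified
try:
    Setting()                # blocked by ConstantStatic.__init__
except Exception as err:
    print(err)               # Class Setting can not be instantiated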

View file

@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
from settings import ProblemTypes, version
from settings import ProblemTypes, version, Setting as st
import os
import argparse
@@ -56,6 +56,7 @@ class Cache:
self.embedding_invalid = False
self.dictionary_invalid = False
logging.info('[Cache] dictionary found')
return True
def _check_encoding(self, conf):
@@ -75,26 +76,32 @@ class Cache:
# check the valid of encoding cache
## encoding cache dir
conf.encoding_cache_dir = os.path.join(conf.cache_dir, conf.train_data_md5 + conf.problem_md5)
logging.info('conf.encoding_cache_dir %s'%(conf.encoding_cache_dir))
logging.debug('[Cache] conf.encoding_cache_dir %s' % (conf.encoding_cache_dir))
if not os.path.exists(conf.encoding_cache_dir):
return False
## encoding cache index
conf.encoding_cache_index_file_path = os.path.join(conf.encoding_cache_dir, conf.encoding_cache_index_file_name)
conf.encoding_cache_index_file_md5_path = os.path.join(conf.encoding_cache_dir, conf.encoding_cache_index_file_md5_name)
conf.encoding_cache_index_file_path = os.path.join(conf.encoding_cache_dir, st.cencodig_index_file_name)
conf.encoding_cache_index_file_md5_path = os.path.join(conf.encoding_cache_dir, st.cencoding_index_md5_file_name)
if not os.path.exists(conf.encoding_cache_index_file_path) or not os.path.exists(conf.encoding_cache_index_file_md5_path):
return False
if md5([conf.encoding_cache_index_file_path]) != load_from_json(conf.encoding_cache_index_file_md5_path):
return False
conf.encoding_cache_index = load_from_json(conf.encoding_cache_index_file_path)
cache_index = load_from_json(conf.encoding_cache_index_file_path)
## encoding cache content
for index in conf.encoding_cache_index:
for index in cache_index[st.cencoding_key_index]:
file_name, file_md5 = index[0], index[1]
if file_md5 != md5([os.path.join(conf.encoding_cache_dir, file_name)]):
return False
if st.cencoding_key_legal_cnt in cache_index and st.cencoding_key_illegal_cnt in cache_index:
conf.encoding_cache_legal_line_cnt = cache_index[st.cencoding_key_legal_cnt]
conf.encoding_cache_illegal_line_cnt = cache_index[st.cencoding_key_illegal_cnt]
self.encoding_invalid = False
logging.info('[Cache] encoding found')
logging.info('%s: %d legal samples, %d illegal samples' % (conf.train_data_path, conf.encoding_cache_legal_line_cnt, conf.encoding_cache_illegal_line_cnt))
return True
def check(self, conf, params):
@@ -107,7 +114,7 @@ class Cache:
self._renew_cache(params, conf.encoding_cache_dir)
def load(self, conf, problem, emb_matrix):
# load dictionary when (not finetune) and (cache invalid)
# load dictionary when (not finetune) and (cache valid)
if not conf.pretrained_model_path and not self.dictionary_invalid:
problem.load_problem(conf.problem_path)
if not self.embedding_invalid:
@@ -143,11 +150,7 @@ class Cache:
# encoding
if self.encoding_invalid:
self._prepare_encoding_cache(conf, problem, build=True)
logging.info("[Cache] encoding is saved to %s" % conf.encoding_cache_dir)
logging.info("[Cache] encoding_cache_index_file_path is %s" % conf.encoding_cache_index_file_path)
logging.info("[Cache] encoding_cache_index_file_md5_path is %s" % conf.encoding_cache_index_file_md5_path)
logging.info("[Cache] encoding_cache_index is %s" % conf.encoding_cache_index)
self._prepare_encoding_cache(conf, problem, build=params.make_cache_only)
def back_up(self, conf, problem):
cache_bakup_path = os.path.join(conf.save_base_dir, 'necessary_cache/')
@@ -194,23 +197,33 @@ class Cache:
return flag
def _prepare_encoding_cache(self, conf, problem, build=False):
# encoding cache dir
problem_path = conf.problem_path if not conf.pretrained_model_path else conf.saved_problem_path
conf.problem_md5 = md5([problem_path])
conf.encoding_cache_dir = os.path.join(conf.cache_dir, conf.train_data_md5 + conf.problem_md5)
conf.encoding_cache_index_file_path = os.path.join(conf.encoding_cache_dir, conf.encoding_cache_index_file_name)
conf.encoding_cache_index_file_md5_path = os.path.join(conf.encoding_cache_dir, conf.encoding_cache_index_file_md5_name)
if not os.path.exists(conf.encoding_cache_dir):
os.makedirs(conf.encoding_cache_dir)
# encoding cache files
conf.encoding_cache_index_file_path = os.path.join(conf.encoding_cache_dir, st.cencodig_index_file_name)
conf.encoding_cache_index_file_md5_path = os.path.join(conf.encoding_cache_dir, st.cencoding_index_md5_file_name)
conf.load_encoding_cache_generator = self._load_encoding_cache_generator
if build:
prepare_dir(conf.encoding_cache_dir, True, allow_overwrite=True, clear_dir_if_exist=True)
problem.build_encode_cache(conf, file_format="tsv", cpu_num_workers=conf.cpu_num_workers)
conf.encoding_cache_index = load_from_json(conf.encoding_cache_index_file_path)
conf.encoding_generator_func = lambda : self._load_encoding_cache(conf)
@staticmethod
def _load_encoding_cache(conf):
for cache_index in conf.encoding_cache_index:
file_path = os.path.join(conf.encoding_cache_dir, cache_index[0])
yield load_from_pkl(file_path)
problem.build_encode_cache(conf)
self.encoding_invalid = False
if not self.encoding_invalid:
cache_index = load_from_json(conf.encoding_cache_index_file_path)
conf.encoding_file_index = cache_index[st.cencoding_key_index]
@staticmethod
def _load_encoding_cache_generator(cache_dir, file_index):
for index in file_index:
file_path = os.path.join(cache_dir, index[0])
yield load_from_pkl(file_path)
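
Taken together, the cache contract is: index.json lists [file_name, md5] pairs, and each part file is a pickled (data, lengths, target) triple. The commit performs the md5 check once in _check_encoding rather than on every load; the sketch below folds both steps together and uses stdlib stand-ins for the repo's md5 and load_from_pkl helpers.

import hashlib
import os
import pickle

def file_md5(path):
    with open(path, 'rb') as f:
        return hashlib.md5(f.read()).hexdigest()

def stream_cache(cache_dir, file_index):
    # verify each cached part against its recorded md5, then stream it
    for file_name, expected_md5 in file_index:
        path = os.path.join(cache_dir, file_name)
        if file_md5(path) != expected_md5:
            raise IOError('corrupted cache part: %s' % path)
        with open(path, 'rb') as f:
            yield pickle.load(f)  # (data, lengths, target)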
def main(params):
# init
conf = ModelConf("train", params.conf_path, version, params, mode=params.mode)
@@ -241,10 +254,6 @@ def main(params):
show_progress=True if params.mode == 'normal' else False, cpu_num_workers = conf.cpu_num_workers,
max_vocabulary=conf.max_vocabulary, word_frequency=conf.min_word_frequency)
## encode rawdata when do not use cache
if conf.use_cache == False:
pass
# environment preparing
## cache save
if conf.use_cache:
@@ -270,7 +279,7 @@
vocab_info, initialize = None, False
if not conf.pretrained_model_path:
vocab_info, initialize = get_vocab_info(problem, emb_matrix), True
print(initialize)
lm = LearningMachine('train', conf, problem, vocab_info=vocab_info, initialize=initialize, use_gpu=conf.use_gpu)
if conf.pretrained_model_path:
logging.info('Loading the pretrained model: %s...' % conf.pretrained_model_path)