Use lazy loading instead of one-off loading (#28)
* Add new config for knowledge distillation for query binary classifier
* Remove inference results in knowledge distillation for query binary classifier
* Add AUC.py in tools folder
* Add test_data_path into conf_kdqbc_bilstmattn_cnn.json
* Modify AUC.py
* Rename AUC.py to calculate_AUC.py
* Modify test & calculate-AUC commands for knowledge distillation for query binary classifier
* Add cpu_thread_num parameter in conf.training_params
* Rename cpu_thread_num to cpu_num_workers
* Update comments in ModelConf.py
* Add cpu_num_workers in model_zoo/advanced/conf.json
* Add the description of cpu_num_workers in Tutorial.md
* Update inference speed of compressed model
* Add ProcessorsScheduler class
* Add license in ProcessorScheduler.py
* Use lazy loading instead of one-off loading
* Remove debug info in problem.py
* Use open instead of codecs.open
* Update the inference of build dictionary for classification
* Fix typos
This commit is contained in:
Parent: dddafc168e
Commit: fa9e8e2f19
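The core change is the loading pattern: instead of materializing the whole training file as one list, data is now read lazily in fixed-size chunks, and each chunk is processed before the next one is read. A minimal sketch of the idea, separate from the repository code (the function name, file name, and the commented-out caller are illustrative, not part of the codebase):

def iter_chunks(file_path, chunk_size=1000000):
    """Yield lists of at most chunk_size stripped lines instead of one huge list."""
    chunk = []
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.rstrip()
            if not line:
                break  # stop at the first empty line, mirroring the diffed code
            chunk.append(line)
            if len(chunk) == chunk_size:
                yield chunk
                chunk = []
    if chunk:
        yield chunk

# Peak memory now depends on chunk_size, not on the file size:
# for chunk in iter_chunks("train.tsv"):
#     preprocess(chunk)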
@@ -66,23 +66,23 @@ class CellDict(object):

        return cell in self.cell_id_map

    def build(self, docs, threshold, max_vocabulary_num=800000):

    def update(self, docs):
        for doc in docs:
            # add type judge
            if isinstance(doc, list):
                self.cell_doc_count.update(set(doc))
            else:
                self.cell_doc_count.update([doc])

    def build(self, threshold, max_vocabulary_num=800000):
        # restrict the vocabulary size to prevent embedding dict weight size
        self.cell_doc_count = Counter(dict(self.cell_doc_count.most_common(max_vocabulary_num)))
        for cell in self.cell_doc_count:

            if cell not in self.cell_id_map and self.cell_doc_count[cell] >= threshold:
                id = len(self.cell_id_map)
                self.cell_id_map[cell] = id
                self.id_cell_map[id] = cell
        self.cell_doc_count = Counter()

    def id_unk(self, cell):
        """
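With the one-shot build(docs, threshold, ...) split into update(docs) plus a final build(threshold, ...), a dictionary can accumulate counts chunk by chunk and assign ids once at the end. A rough usage sketch under that reading of the diff (the constructor call and the tokenized_chunks iterable are placeholders, not the repository's actual call sites):

word_dict = CellDict()                      # hypothetical construction; the real class takes more options
for docs_chunk in tokenized_chunks:         # each chunk: a list of token lists from one data chunk
    word_dict.update(docs_chunk)            # only accumulates document-frequency counts
# Assign ids once all chunks have been counted; the counter is reset afterwards.
word_dict.build(threshold=3, max_vocabulary_num=800000)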
problem.py (228 changed lines)
@@ -10,7 +10,6 @@ import nltk
nltk.download('punkt', quiet=True)
nltk.download('stopwords', quiet=True)
from utils.BPEEncoder import BPEEncoder
import codecs
import os
import pickle as pkl
from utils.common_utils import load_from_pkl, dump_to_pkl
@@ -94,16 +93,21 @@ class Problem():
        else:
            return None

    def get_data_list_from_file(self, fin, file_with_col_header):
        data_list = list()
        for index, line in enumerate(fin):
            if file_with_col_header and index == 0:
                continue
            line = line.rstrip()
            if not line:
                break
            data_list.append(line)
        return data_list
    def get_data_generator_from_file(self, file_path, file_with_col_header, chunk_size=1000000):
        with open(file_path, "r", encoding='utf-8') as f:
            if file_with_col_header:
                f.readline()
            data_list = list()
            for index, line in enumerate(f):
                line = line.rstrip()
                if not line:
                    break
                data_list.append(line)
                if (index + 1) % chunk_size == 0:
                    yield data_list
                    data_list = list()
            if len(data_list) > 0:
                yield data_list

    def build_training_data_list(self, training_data_list, file_columns, input_types, answer_column_name, bpe_encoder=None):
        docs = dict() # docs of each type of input
@@ -165,33 +169,35 @@ class Problem():
pass
return docs, target_docs, cnt_legal, cnt_illegal

def build_training_multi_processor(self, training_data_list, cpu_num_workers, file_columns, input_types, answer_column_name, bpe_encoder=None):
scheduler = ProcessorsScheduler(cpu_num_workers)
func_args = (training_data_list, file_columns, input_types, answer_column_name, bpe_encoder)
res = scheduler.run_data_parallel(self.build_training_data_list, func_args)
def build_training_multi_processor(self, training_data_generator, cpu_num_workers, file_columns, input_types, answer_column_name, bpe_encoder=None):
for data in training_data_generator:
# multi-Processing
scheduler = ProcessorsScheduler(cpu_num_workers)
func_args = (data, file_columns, input_types, answer_column_name, bpe_encoder)
res = scheduler.run_data_parallel(self.build_training_data_list, func_args)
# aggregate
docs = dict() # docs of each type of input
target_docs = []
cnt_legal = 0
cnt_illegal = 0
for (index, j) in res:
#logging.info("collect proccesor %d result" % index)
tmp_docs, tmp_target_docs, tmp_cnt_legal, tmp_cnt_illegal = j.get()
if len(docs) == 0:
docs = tmp_docs
else:
for key, value in tmp_docs.items():
docs[key].extend(value)
if len(target_docs) == 0:
target_docs = tmp_target_docs
else:
for single_type in tmp_target_docs:
target_docs[single_type].extend(tmp_target_docs[single_type])
# target_docs.extend(tmp_target_docs)
cnt_legal += tmp_cnt_legal
cnt_illegal += tmp_cnt_illegal

docs = dict() # docs of each type of input
target_docs = []
cnt_legal = 0
cnt_illegal = 0
for (index, j) in res:
#logging.info("collect proccesor %d result" % index)
tmp_docs, tmp_target_docs, tmp_cnt_legal, tmp_cnt_illegal = j.get()
if len(docs) == 0:
docs = tmp_docs
else:
for key, value in tmp_docs.items():
docs[key].extend(value)
if len(target_docs) == 0:
target_docs = tmp_target_docs
else:
for single_type in tmp_target_docs:
target_docs[single_type].extend(tmp_target_docs[single_type])
# target_docs.extend(tmp_target_docs)
cnt_legal += tmp_cnt_legal
cnt_illegal += tmp_cnt_illegal

return docs, target_docs, cnt_legal, cnt_illegal
yield docs, target_docs, cnt_legal, cnt_illegal

def build(self, training_data_path, file_columns, input_types, file_with_col_header, answer_column_name, word2vec_path=None, word_emb_dim=None,
format=None, file_type=None, involve_all_words=None, file_format="tsv", show_progress=True,
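In the new build_training_multi_processor, each chunk yielded by the data generator is split across worker processes, the partial results are merged, and the merged result is yielded before the next chunk is read. A generic sketch of that per-chunk fan-out/fan-in pattern, with the repository's ProcessorsScheduler replaced by a plain multiprocessing.Pool and a toy per-chunk job (all names here are illustrative):

from multiprocessing import Pool

def _count_tokens(lines):
    """Toy per-chunk job standing in for the real per-worker preprocessing."""
    return sum(len(line.split()) for line in lines)

def build_lazily(chunk_generator, workers=4):
    # Call from under `if __name__ == "__main__":` on platforms that spawn processes.
    total = 0
    with Pool(processes=workers) as pool:
        for chunk in chunk_generator:
            # Split one chunk across workers, aggregate, then move on,
            # so only one chunk is ever resident in memory.
            parts = [chunk[i::workers] for i in range(workers)]
            total += sum(pool.map(_count_tokens, parts))
    return total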
@@ -245,38 +251,47 @@ class Problem():
bpe_encoder = None

self.file_column_num = len(file_columns)
with open(training_data_path, "r", encoding='utf-8') as f:
progress = self.get_data_list_from_file(f, file_with_col_header)
docs, target_docs, cnt_legal, cnt_illegal = self.build_training_multi_processor(progress, cpu_num_workers, file_columns, input_types, answer_column_name, bpe_encoder=bpe_encoder)
progress = self.get_data_generator_from_file(training_data_path, file_with_col_header)
preprocessed_data_generator= self.build_training_multi_processor(progress, cpu_num_workers, file_columns, input_types, answer_column_name, bpe_encoder=bpe_encoder)

# update symbol universe
total_cnt_legal, total_cnt_illegal = 0, 0
for docs, target_docs, cnt_legal, cnt_illegal in tqdm(preprocessed_data_generator):
total_cnt_legal += cnt_legal
total_cnt_illegal += cnt_illegal

logging.info("Corpus imported: %d legal lines, %d illegal lines." % (cnt_legal, cnt_illegal))

if word2vec_path and involve_all_words is True:
logging.info("Getting pre-trained embeddings...")
word_emb_dict = load_embedding(word2vec_path, word_emb_dim, format, file_type, with_head=False, word_set=None)
self.input_dicts['word'].build([list(word_emb_dict.keys())], max_vocabulary_num=len(word_emb_dict), threshold=0)
# input_type
for input_type in input_types:
self.input_dicts[input_type].update(docs[input_type])

# problem_type
if ProblemTypes[self.problem_type] == ProblemTypes.classification or \
ProblemTypes[self.problem_type] == ProblemTypes.sequence_tagging:
self.output_dict.update(list(target_docs.values())[0])
elif ProblemTypes[self.problem_type] == ProblemTypes.regression or \
ProblemTypes[self.problem_type] == ProblemTypes.mrc:
pass
logging.info("Corpus imported: %d legal lines, %d illegal lines." % (total_cnt_legal, total_cnt_illegal))

# build dictionary
for input_type in input_types:
if input_type != 'word':
self.input_dicts[input_type].build(docs[input_type], max_vocabulary_num=max_vocabulary, threshold=word_frequency)
else:
self.input_dicts[input_type].build(docs[input_type], max_vocabulary_num=max_vocabulary, threshold=word_frequency)
logging.info("%d types in %s" % (self.input_dicts[input_type].cell_num(), input_type))
if ProblemTypes[self.problem_type] == ProblemTypes.classification:
self.output_dict.build(list(target_docs.values())[0], threshold=0)
elif ProblemTypes[self.problem_type] == ProblemTypes.sequence_tagging:
self.output_dict.build(list(target_docs.values())[0], threshold=0)
elif ProblemTypes[self.problem_type] == ProblemTypes.regression or \
ProblemTypes[self.problem_type] == ProblemTypes.mrc:
pass

self.input_dicts[input_type].build(threshold=word_frequency, max_vocabulary_num=max_vocabulary)
logging.info("%d types in %s column" % (self.input_dicts[input_type].cell_num(), input_type))
if self.output_dict:
logging.info("%d types in target" % (self.output_dict.cell_num()))

logging.debug("Cell dict built")
self.output_dict.build(threshold=0)
logging.info("%d types in target column" % (self.output_dict.cell_num()))
logging.debug("training data dict built")

# embedding
word_emb_matrix = None
if word2vec_path:
if not involve_all_words:
logging.info("Getting pre-trained embeddings...")
logging.info("Getting pre-trained embeddings...")
word_emb_dict = None
if involve_all_words is True:
word_emb_dict = load_embedding(word2vec_path, word_emb_dim, format, file_type, with_head=False, word_set=None)
self.input_dicts['word'].update([list(word_emb_dict.keys())])
self.input_dicts['word'].build(threshold=0, max_vocabulary_num=len(word_emb_dict))
else:
word_emb_dict = load_embedding(word2vec_path, word_emb_dim, format, file_type, with_head=False, word_set=self.input_dicts['word'].cell_id_map.keys())

for word in word_emb_dict:
@@ -285,6 +300,7 @@ class Problem():

                assert loaded_emb_dim == word_emb_dim, "The dimension of defined word embedding is inconsistent with the pretrained embedding provided!"

            logging.info("constructing embedding table")
            if self.input_dicts['word'].with_unk:
                word_emb_dict['<unk>'] = np.random.random(size=word_emb_dim)
            if self.input_dicts['word'].with_pad:
@@ -302,57 +318,54 @@ class Problem():
logging.info("word embedding matrix shape:(%d, %d); unknown word count: %d;" %
(len(word_emb_matrix), len(word_emb_matrix[0]), unknown_word_count))
logging.info("Word embedding loaded")
else:
word_emb_matrix = None

return word_emb_matrix

def encode_data_multi_processor(self, data_list, cpu_num_workers, file_columns, input_types, object_inputs,
def encode_data_multi_processor(self, data_generator, cpu_num_workers, file_columns, input_types, object_inputs,
answer_column_name, min_sentence_len, extra_feature, max_lengths=None, fixed_lengths=None, file_format="tsv", bpe_encoder=None):
def judge_dict(obj):
return True if isinstance(obj, dict) else False

scheduler = ProcessorsScheduler(cpu_num_workers)
func_args = (data_list, file_columns, input_types, object_inputs,
answer_column_name, min_sentence_len, extra_feature, max_lengths, fixed_lengths, file_format, bpe_encoder)
res = scheduler.run_data_parallel(self.encode_data_list, func_args)

data = dict()
cnt_legal, cnt_illegal = 0, 0
output_data = dict()
lengths = dict()
target = dict()
cnt_legal = 0
cnt_illegal = 0
for data in tqdm(data_generator):
scheduler = ProcessorsScheduler(cpu_num_workers)
func_args = (data, file_columns, input_types, object_inputs,
answer_column_name, min_sentence_len, extra_feature, max_lengths, fixed_lengths, file_format, bpe_encoder)
res = scheduler.run_data_parallel(self.encode_data_list, func_args)

for (index, j) in res:
# logging.info("collect proccesor %d result"%index)
tmp_data, tmp_lengths, tmp_target, tmp_cnt_legal, tmp_cnt_illegal = j.get()

for (index, j) in res:
# logging.info("collect proccesor %d result"%index)
tmp_data, tmp_lengths, tmp_target, tmp_cnt_legal, tmp_cnt_illegal = j.get()

if len(data) == 0:
data = tmp_data
else:
for branch in tmp_data:
for input_type in data[branch]:
data[branch][input_type].extend(tmp_data[branch][input_type])
if len(lengths) == 0:
lengths = tmp_lengths
else:
for branch in tmp_lengths:
if judge_dict(tmp_lengths[branch]):
for type_branch in tmp_lengths[branch]:
lengths[branch][type_branch].extend(tmp_lengths[branch][type_branch])
else:
lengths[branch].extend(tmp_lengths[branch])
if not tmp_target:
target = None
else:
if len(target) == 0:
target = tmp_target
if len(output_data) == 0:
output_data = tmp_data
else:
for single_type in tmp_target:
target[single_type].extend(tmp_target[single_type])
cnt_legal += tmp_cnt_legal
cnt_illegal += tmp_cnt_illegal
for branch in tmp_data:
for input_type in output_data[branch]:
output_data[branch][input_type].extend(tmp_data[branch][input_type])
if len(lengths) == 0:
lengths = tmp_lengths
else:
for branch in tmp_lengths:
if judge_dict(tmp_lengths[branch]):
for type_branch in tmp_lengths[branch]:
lengths[branch][type_branch].extend(tmp_lengths[branch][type_branch])
else:
lengths[branch].extend(tmp_lengths[branch])
if not tmp_target:
target = None
else:
if len(target) == 0:
target = tmp_target
else:
for single_type in tmp_target:
target[single_type].extend(tmp_target[single_type])
cnt_legal += tmp_cnt_legal
cnt_illegal += tmp_cnt_illegal

return data, lengths, target, cnt_legal, cnt_illegal
return output_data, lengths, target, cnt_legal, cnt_illegal

def encode_data_list(self, data_list, file_columns, input_types, object_inputs, answer_column_name, min_sentence_len,
extra_feature, max_lengths=None, fixed_lengths=None, file_format="tsv", bpe_encoder=None):
@@ -661,9 +674,8 @@ class Problem():
else:
bpe_encoder = None

with open(data_path, 'r', encoding='utf-8') as fin:
progress = self.get_data_list_from_file(fin, file_with_col_header)
data, lengths, target, cnt_legal, cnt_illegal = self.encode_data_multi_processor(progress, cpu_num_workers,
progress = self.get_data_generator_from_file(data_path, file_with_col_header)
data, lengths, target, cnt_legal, cnt_illegal = self.encode_data_multi_processor(progress, cpu_num_workers,
file_columns, input_types, object_inputs, answer_column_name, min_sentence_len, extra_feature, max_lengths,
fixed_lengths, file_format, bpe_encoder=bpe_encoder)
logging.info("%s: %d legal samples, %d illegal samples" % (data_path, cnt_legal, cnt_illegal))
tools/calculate_AUC.py
@@ -2,14 +2,13 @@
# Licensed under the MIT license.

import argparse
import codecs
from sklearn.metrics import roc_auc_score

def read_tsv(params):
    prediction, label = [], []
    predict_index, label_index = int(params.predict_index), int(params.label_index)
    min_column_num = max(predict_index, label_index) + 1
    with codecs.open(params.input_file, mode='r', encoding='utf-8') as f:
    with open(params.input_file, mode='r', encoding='utf-8') as f:
        for index, line in enumerate(f):
            if params.header and index == 0:
                continue
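The codecs.open removals here and in problem.py rely on the fact that Python 3's built-in open decodes text directly through its encoding argument, so the codecs wrapper adds nothing. A small equivalent-behavior sketch (read_rows is an illustrative helper, not part of the repository):

def read_rows(path):
    """Same behavior as the old codecs.open version: UTF-8 text, split on tabs."""
    with open(path, mode="r", encoding="utf-8") as f:  # built-in open handles the decoding
        return [line.rstrip("\n").split("\t") for line in f]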