# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import logging
import numpy as np
from core.CellDict import CellDict
from tqdm import tqdm
from utils.corpus_utils import load_embedding

import nltk
nltk.download('punkt', quiet=True)
nltk.download('stopwords', quiet=True)

from utils.BPEEncoder import BPEEncoder
import os
import pickle as pkl
from utils.common_utils import load_from_pkl, dump_to_pkl, load_from_json, dump_to_json, prepare_dir, md5

from settings import ProblemTypes, Setting as st

import math
from utils.ProcessorsScheduler import ProcessorsScheduler

from core.EnglishTokenizer import EnglishTokenizer
from core.ChineseTokenizer import ChineseTokenizer
from core.EnglishTextPreprocessor import EnglishTextPreprocessor
from utils.exceptions import PreprocessError
import torch
import torch.nn as nn


class Problem():

    def __init__(self, phase, problem_type, input_types, answer_column_name=None, lowercase=False, with_bos_eos=True,
                 tagging_scheme=None, tokenizer="nltk", remove_stopwords=False, DBC2SBC=True, unicode_fix=True):
"""
Args :
input_types : {
" word " : [ " word1 " , " word1 " ] ,
" postag " : [ " postag_feature1 " , " postag_feature2 " ]
}
answer_column_name : " label " after v1 .0 .0 answer_column_name change to list
source_with_start :
source_with_end :
source_with_unk :
source_with_pad :
target_with_start :
target_with_end :
target_with_unk :
target_with_pad :
same_length :
with_bos_eos : whether to add bos and eos when encoding
"""
        # init
        source_with_start, source_with_end, source_with_unk, source_with_pad, \
            target_with_start, target_with_end, target_with_unk, target_with_pad, \
            same_length = (True, ) * 9
        if ProblemTypes[problem_type] == ProblemTypes.sequence_tagging:
            pass
        elif ProblemTypes[problem_type] == ProblemTypes.classification or \
                ProblemTypes[problem_type] == ProblemTypes.regression:
            target_with_start, target_with_end, target_with_unk, target_with_pad, same_length = (False, ) * 5
            if phase != 'train':
                same_length = True
        elif ProblemTypes[problem_type] == ProblemTypes.mrc:
            target_with_start, target_with_end, target_with_unk, target_with_pad, same_length = (False, ) * 5
            with_bos_eos = False

        if ProblemTypes[problem_type] == ProblemTypes.sequence_tagging:
            target_with_start = False
            target_with_end = False
            target_with_unk = False

        self.lowercase = lowercase
        self.problem_type = problem_type
        self.tagging_scheme = tagging_scheme
        self.with_bos_eos = with_bos_eos
        self.source_with_start = source_with_start
        self.source_with_end = source_with_end
        self.source_with_unk = source_with_unk
        self.source_with_pad = source_with_pad
        self.target_with_start = target_with_start
        self.target_with_end = target_with_end
        self.target_with_unk = target_with_unk
        self.target_with_pad = target_with_pad

        self.input_dicts = dict()
        for input_type in input_types:
            self.input_dicts[input_type] = CellDict(with_unk=source_with_unk, with_pad=source_with_pad,
                                                    with_start=source_with_start, with_end=source_with_end)
        if ProblemTypes[self.problem_type] == ProblemTypes.sequence_tagging or \
                ProblemTypes[self.problem_type] == ProblemTypes.classification:
            self.output_dict = CellDict(with_unk=target_with_unk, with_pad=target_with_pad,
                                        with_start=target_with_start, with_end=target_with_end)
        elif ProblemTypes[self.problem_type] == ProblemTypes.regression or \
                ProblemTypes[self.problem_type] == ProblemTypes.mrc:
            self.output_dict = None

        self.file_column_num = None

        if tokenizer in ['nltk']:
            self.tokenizer = EnglishTokenizer(tokenizer=tokenizer, remove_stopwords=remove_stopwords)
        elif tokenizer in ['jieba']:
            self.tokenizer = ChineseTokenizer(tokenizer=tokenizer, remove_stopwords=remove_stopwords)

        self.text_preprocessor = EnglishTextPreprocessor(DBC2SBC=DBC2SBC, unicode_fix=unicode_fix)

    def input_word_num(self):
        return self.input_word_dict.cell_num()

    def output_target_num(self):
        if ProblemTypes[self.problem_type] == ProblemTypes.sequence_tagging or \
                ProblemTypes[self.problem_type] == ProblemTypes.classification:
            return self.output_dict.cell_num()
        else:
            return None

    def get_data_generator_from_file(self, data_path, file_with_col_header, chunk_size=1000000):
        data_list = list()
        with open(data_path, "r", encoding='utf-8') as f:
            if file_with_col_header:
                f.readline()
            for index, line in enumerate(f):
                line = line.rstrip()
                # an empty line terminates reading; lines are yielded in chunks of chunk_size
                if not line:
                    break
                data_list.append(line)
                if (index + 1) % chunk_size == 0:
                    yield data_list
                    data_list = list()
            if len(data_list) > 0:
                yield data_list

    def build_training_data_list(self, training_data_list, file_columns, input_types, answer_column_name, bpe_encoder=None):
        docs = dict()               # docs of each type of input
        col_index_types = dict()    # input type of each column, e.g. col_index_types[0] = 'word'/'postag'
        target_docs = {}            # after v1.0.0, target_docs became a dict to support multiple labels
        columns_to_target = {}
        for single_target in answer_column_name:
            target_docs[single_target] = []
            columns_to_target[file_columns[single_target]] = single_target
        for input_type in input_types:
            docs[input_type] = []
            # char is not in file_columns
            if input_type == 'char':
                continue
            for col in input_types[input_type]['cols']:
                col_index_types[file_columns[col]] = input_type

        cnt_legal = 0
        cnt_illegal = 0
        for line in training_data_list:
            # line_split = list(filter(lambda x: len(x) > 0, line.rstrip().split('\t')))
            line_split = line.rstrip().split('\t')
            if len(line_split) != len(file_columns):
                logging.warning("Current line is inconsistent with configuration/inputs/file_header. Ignored. %s" % line)
                cnt_illegal += 1
                continue
            cnt_legal += 1
            for i in range(len(line_split)):
                if i in col_index_types:
                    if self.lowercase:
                        line_split[i] = line_split[i].lower()
                    line_split[i] = self.text_preprocessor.preprocess(line_split[i])
                    if col_index_types[i] == 'word':
                        if ProblemTypes[self.problem_type] == ProblemTypes.sequence_tagging:
                            token_list = line_split[i].split(" ")
                        else:
                            token_list = self.tokenizer.tokenize(line_split[i])
                        docs[col_index_types[i]].append(token_list)
                        if 'char' in docs:
                            # add char
                            docs['char'].append([single_char for single_char in ''.join(token_list)])
                    elif col_index_types[i] == 'bpe':
                        bpe_tokens = []
                        for token in self.tokenizer.tokenize(line_split[i]):
                            bpe_tokens.extend(bpe_encoder.bpe(token))
                        docs[col_index_types[i]].append(bpe_tokens)
                    else:
                        docs[col_index_types[i]].append(line_split[i].split(" "))
                # target_docs is a dict
                elif i in columns_to_target.keys():
                    curr_target = columns_to_target[i]
                    if ProblemTypes[self.problem_type] == ProblemTypes.classification:
                        target_docs[curr_target].append(line_split[i])
                    elif ProblemTypes[self.problem_type] == ProblemTypes.sequence_tagging:
                        target_docs[curr_target].append(line_split[i].split(" "))
                    elif ProblemTypes[self.problem_type] == ProblemTypes.regression or \
                            ProblemTypes[self.problem_type] == ProblemTypes.mrc:
                        pass

        return docs, target_docs, cnt_legal, cnt_illegal
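
    # Illustrative note (hypothetical column and label names, not from the original source):
    # for a two-column classification tsv with file_columns = {"query_text": 0, "label": 1},
    # the returned structures would look roughly like
    #   docs        = {'word': [['hello', 'world'], ...]}
    #   target_docs = {'label': ['positive', ...]}
    # with cnt_legal / cnt_illegal counting well-formed and malformed lines respectively.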

    def build_training_multi_processor(self, training_data_generator, cpu_num_workers, file_columns, input_types, answer_column_name, bpe_encoder=None):
        for data in training_data_generator:
            # multi-processing
            scheduler = ProcessorsScheduler(cpu_num_workers)
            func_args = (data, file_columns, input_types, answer_column_name, bpe_encoder)
            res = scheduler.run_data_parallel(self.build_training_data_list, func_args)

            # aggregate
            docs = dict()       # docs of each type of input
            target_docs = []
            cnt_legal = 0
            cnt_illegal = 0
            for (index, j) in res:
                # logging.info("collect processor %d result" % index)
                tmp_docs, tmp_target_docs, tmp_cnt_legal, tmp_cnt_illegal = j.get()
                if len(docs) == 0:
                    docs = tmp_docs
                else:
                    for key, value in tmp_docs.items():
                        docs[key].extend(value)
                if len(target_docs) == 0:
                    target_docs = tmp_target_docs
                else:
                    for single_type in tmp_target_docs:
                        target_docs[single_type].extend(tmp_target_docs[single_type])
                    # target_docs.extend(tmp_target_docs)
                cnt_legal += tmp_cnt_legal
                cnt_illegal += tmp_cnt_illegal

            yield docs, target_docs, cnt_legal, cnt_illegal

    def build(self, data_path_list, file_columns, input_types, file_with_col_header, answer_column_name,
              word2vec_path=None, word_emb_dim=None, format=None, file_type=None, involve_all_words=None,
              file_format="tsv", show_progress=True, cpu_num_workers=-1, max_vocabulary=800000,
              word_frequency=3, max_building_lines=1000*1000):
"""
Args :
2019-08-02 15:59:00 +03:00
data_path_list :
2019-04-20 14:17:30 +03:00
file_columns : {
" word1 " : 0 ,
" word2 " : 1 ,
" label " : 2 ,
" postag_feature1 " : 3 ,
" postag_feature2 " : 4
} ,
input_types :
e . g .
{
" word " : {
" cols " : [ " word1 " , " word2 " ] ,
" dim " : 300
} ,
" postag " : {
" cols " : [ " postag_feature1 " , " postag_feature2 " ] ,
" dim " : 20
} ,
}
or
{
" bpe " : {
" cols " : [ " word1 " , " word2 " ] ,
" dim " : 100
" bpe_path " : " xxx.bpe "
}
}
word2vec_path :
word_emb_dim :
involve_all_word : involve all words that show up in the pretrained embedding
file_format : " tsv " , or " json " . Note " json " means each sample is represented by a json string .
Returns :
"""
        # parameter check
        bpe_encoder = self._check_bpe_encoder(input_types)

        self.file_column_num = len(file_columns)
        for data_path in data_path_list:
            if data_path:
                progress = self.get_data_generator_from_file(data_path, file_with_col_header, chunk_size=max_building_lines)
                preprocessed_data_generator = self.build_training_multi_processor(progress, cpu_num_workers,
                        file_columns, input_types, answer_column_name, bpe_encoder=bpe_encoder)

                # update the symbol universe
                docs, target_docs, cnt_legal, cnt_illegal = next(preprocessed_data_generator)

                # input types
                for input_type in input_types:
                    self.input_dicts[input_type].update(docs[input_type])

                # problem type
                if ProblemTypes[self.problem_type] == ProblemTypes.classification or \
                        ProblemTypes[self.problem_type] == ProblemTypes.sequence_tagging:
                    self.output_dict.update(list(target_docs.values())[0])
                elif ProblemTypes[self.problem_type] == ProblemTypes.regression or \
                        ProblemTypes[self.problem_type] == ProblemTypes.mrc:
                    pass
                logging.info("[Building Dictionary] %s: at most %d lines imported; %d legal lines, %d illegal lines."
                             % (data_path, max_building_lines, cnt_legal, cnt_illegal))

        # build dictionary
        for input_type in input_types:
            self.input_dicts[input_type].build(threshold=word_frequency, max_vocabulary_num=max_vocabulary)
            logging.info("%d types in %s column" % (self.input_dicts[input_type].cell_num(), input_type))
        if self.output_dict:
            self.output_dict.build(threshold=0)
            if ProblemTypes[self.problem_type] == ProblemTypes.sequence_tagging:
                self.output_dict.cell_id_map["<start>"] = len(self.output_dict.cell_id_map)
                self.output_dict.id_cell_map[len(self.output_dict.id_cell_map)] = "<start>"
                self.output_dict.cell_id_map["<eos>"] = len(self.output_dict.cell_id_map)
                self.output_dict.id_cell_map[len(self.output_dict.id_cell_map)] = "<eos>"
            logging.info("%d types in target column" % (self.output_dict.cell_num()))
        logging.debug("training data dict built")

        # embedding
        word_emb_matrix = None
        if word2vec_path:
            logging.info("Getting pre-trained embeddings...")
            word_emb_dict = None
            if involve_all_words is True:
                word_emb_dict = load_embedding(word2vec_path, word_emb_dim, format, file_type, with_head=False, word_set=None)
                self.input_dicts['word'].update([list(word_emb_dict.keys())])
                self.input_dicts['word'].build(threshold=0, max_vocabulary_num=len(word_emb_dict))
            else:
                extend_vocabulary = set()
                for single_word in self.input_dicts['word'].cell_id_map.keys():
                    extend_vocabulary.add(single_word)
                    if single_word.lower() != single_word:
                        extend_vocabulary.add(single_word.lower())
                word_emb_dict = load_embedding(word2vec_path, word_emb_dim, format, file_type, with_head=False, word_set=extend_vocabulary)

            # peek at the first vector to check the dimension of the loaded embedding
            for word in word_emb_dict:
                loaded_emb_dim = len(word_emb_dict[word])
                break
            assert loaded_emb_dim == word_emb_dim, \
                "The dimension of the configured word embedding is inconsistent with the pretrained embedding provided!"

            logging.info("constructing embedding table")
            if self.input_dicts['word'].with_unk:
                word_emb_dict['<unk>'] = np.random.random(size=word_emb_dim)
            if self.input_dicts['word'].with_pad:
                word_emb_dict['<pad>'] = np.random.random(size=word_emb_dim)

            word_emb_matrix = []
            unknown_word_count = 0
            scale = np.sqrt(3.0 / word_emb_dim)     # uniform init range for words without a pretrained vector
            for i in range(self.input_dicts['word'].cell_num()):
                single_word = self.input_dicts['word'].id_cell_map[i]
                if single_word in word_emb_dict:
                    word_emb_matrix.append(word_emb_dict[single_word])
                elif single_word.lower() in word_emb_dict:
                    word_emb_matrix.append(word_emb_dict[single_word.lower()])
                else:
                    word_emb_matrix.append(np.random.uniform(-scale, scale, word_emb_dim))
                    unknown_word_count += 1
            word_emb_matrix = np.array(word_emb_matrix)
            logging.info("word embedding matrix shape: (%d, %d); unknown word count: %d;" %
                         (len(word_emb_matrix), len(word_emb_matrix[0]), unknown_word_count))
            logging.info("Word embedding loaded")

        return word_emb_matrix
2019-05-23 02:55:17 +03:00
@staticmethod
def _merge_encode_data ( dest_dict , src_dict ) :
if len ( dest_dict ) == 0 :
dest_dict = src_dict
else :
for branch in src_dict :
for input_type in dest_dict [ branch ] :
dest_dict [ branch ] [ input_type ] . extend ( src_dict [ branch ] [ input_type ] )
return dest_dict
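
    # Illustrative note (hypothetical branch/type names, not from the original source):
    # dest_dict and src_dict are nested as {branch: {input_type: list_of_samples}}, e.g.
    #   {'query': {'query_index': [[3, 17, 5], ...]}}
    # so merging simply extends each per-type list with the chunk produced by one worker.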

    @staticmethod
    def _merge_encode_lengths(dest_dict, src_dict):
        def judge_dict(obj):
            return isinstance(obj, dict)

        if len(dest_dict) == 0:
            dest_dict = src_dict
        else:
            for branch in src_dict:
                if judge_dict(src_dict[branch]):
                    for type_branch in src_dict[branch]:
                        dest_dict[branch][type_branch].extend(src_dict[branch][type_branch])
                else:
                    dest_dict[branch].extend(src_dict[branch])
        return dest_dict

    @staticmethod
    def _merge_target(dest_dict, src_dict):
        if not src_dict:
            return src_dict
        if len(dest_dict) == 0:
            dest_dict = src_dict
        else:
            for single_type in src_dict:
                dest_dict[single_type].extend(src_dict[single_type])
        return dest_dict

    def encode_data_multi_processor(self, data_generator, cpu_num_workers, file_columns, input_types, object_inputs,
                                    answer_column_name, min_sentence_len, extra_feature, max_lengths=None,
                                    fixed_lengths=None, file_format="tsv", bpe_encoder=None):
        for data in data_generator:
            scheduler = ProcessorsScheduler(cpu_num_workers)
            func_args = (data, file_columns, input_types, object_inputs, answer_column_name, min_sentence_len,
                         extra_feature, max_lengths, fixed_lengths, file_format, bpe_encoder)
            res = scheduler.run_data_parallel(self.encode_data_list, func_args)

            output_data, lengths, target = dict(), dict(), dict()
            cnt_legal, cnt_illegal = 0, 0
            for (index, j) in res:
                # logging.info("collect processor %d result" % index)
                tmp_data, tmp_lengths, tmp_target, tmp_cnt_legal, tmp_cnt_illegal = j.get()
                output_data = self._merge_encode_data(output_data, tmp_data)
                lengths = self._merge_encode_lengths(lengths, tmp_lengths)
                target = self._merge_target(target, tmp_target)
                cnt_legal += tmp_cnt_legal
                cnt_illegal += tmp_cnt_illegal

            yield output_data, lengths, target, cnt_legal, cnt_illegal

    def encode_data_list(self, data_list, file_columns, input_types, object_inputs, answer_column_name, min_sentence_len,
                         extra_feature, max_lengths=None, fixed_lengths=None, file_format="tsv", bpe_encoder=None,
                         predict_mode='batch'):
        data = dict()
        lengths = dict()
        char_emb = True if 'char' in [single_input_type.lower() for single_input_type in input_types] else False
        if answer_column_name is not None and len(answer_column_name) > 0:
            target = {}
            lengths['target'] = {}
            columns_to_target = {}
            for single_target in answer_column_name:
                target[single_target] = []
                columns_to_target[file_columns[single_target]] = single_target
                lengths['target'][single_target] = []
        else:
            target = None

        col_index_types = dict()    # input type of each column, namely the inverse of file_columns, e.g. col_index_types[0] = 'query_index'
        type2cluster = dict()       # e.g. type2cluster['query_index'] = 'word'
        type_branches = dict()      # branch of each input type, e.g. type_branches['query_index'] = 'query'

        # for char: do not split these special tokens into characters
        word_no_split = ['<start>', '<pad>', '<eos>', '<unk>']

        for branch in object_inputs:
            data[branch] = dict()
            lengths[branch] = dict()
            lengths[branch]['sentence_length'] = []
            temp_branch_char = False
            for input_type in object_inputs[branch]:
                type_branches[input_type] = branch
                data[branch][input_type] = []
                if 'char' in input_type.lower():
                    temp_branch_char = True
            if char_emb and temp_branch_char:
                lengths[branch]['word_length'] = []

        # extra info for the mrc task
        if ProblemTypes[self.problem_type] == ProblemTypes.mrc:
            extra_info_type = 'passage'
            if extra_info_type not in object_inputs:
                raise Exception('The MRC task needs a passage in model_inputs, given: {0}'.format('; '.join(list(object_inputs.keys()))))
            data[extra_info_type]['extra_passage_text'] = []
            data[extra_info_type]['extra_passage_token_offsets'] = []

        for input_type in input_types:
            for col_name in input_types[input_type]['cols']:
                type2cluster[col_name] = input_type
                if col_name in file_columns:
                    col_index_types[file_columns[col_name]] = col_name

        cnt_legal = 0
        cnt_illegal = 0
        # cnt_length_inconsistent = 0
        cnt_all = 0
        for line in data_list:
            # line_split = list(filter(lambda x: len(x) > 0, line.rstrip().split('\t')))
            line_split = line.rstrip().split('\t')
            cnt_all += 1
            if len(line_split) != len(file_columns):
                if predict_mode == 'batch':
                    cnt_illegal += 1
                    if cnt_illegal / cnt_all > 0.33:
                        raise PreprocessError('Too many illegal samples. Please check the number of data columns or the text tokenization.')
                    continue
                else:
                    print('\tThe input sample is illegal! Please check it and try again!')
                    return [None] * 5
            # cnt_legal += 1
            length_appended_set = set()     # stores branches whose lengths have already been appended to lengths[branch]
            if ProblemTypes[self.problem_type] == ProblemTypes.mrc:
                passage_token_offsets = None
            for i in range(len(line_split)):
                line_split[i] = line_split[i].strip()
                if i in col_index_types:
                    # this column holds input data
                    branch = type_branches[col_index_types[i]]
                    input_type = []
                    input_type.append(col_index_types[i])
                    if type2cluster[col_index_types[i]] == 'word' and char_emb:
                        temp_col_char = col_index_types[i].split('_')[0] + '_' + 'char'
                        if temp_col_char in input_types['char']['cols']:
                            input_type.append(temp_col_char)
                    if type2cluster[col_index_types[i]] == 'word' or type2cluster[col_index_types[i]] == 'bpe':
                        if self.lowercase:
                            line_split[i] = line_split[i].lower()
                        line_split[i] = self.text_preprocessor.preprocess(line_split[i])
                    if type2cluster[col_index_types[i]] == 'word':
                        if ProblemTypes[self.problem_type] == ProblemTypes.mrc:
                            token_offsets = self.tokenizer.span_tokenize(line_split[i])
                            tokens = [line_split[i][span[0]:span[1]] for span in token_offsets]
                            if branch == 'passage':
                                passage_token_offsets = token_offsets
                                data[extra_info_type]['extra_passage_text'].append(line_split[i])
                                data[extra_info_type]['extra_passage_token_offsets'].append(passage_token_offsets)
                        else:
                            if extra_feature == False and ProblemTypes[self.problem_type] != ProblemTypes.sequence_tagging:
                                tokens = self.tokenizer.tokenize(line_split[i])
                            else:
                                tokens = line_split[i].split(' ')
                    elif type2cluster[col_index_types[i]] == 'bpe':
                        tokens = bpe_encoder.encode(line_split[i])
                    else:
                        tokens = line_split[i].split(' ')

                    # for the sequence tagging task, the recorded length must be the true corpus length
                    if ProblemTypes[self.problem_type] == ProblemTypes.sequence_tagging:
                        if not branch in length_appended_set:
                            lengths[branch]['sentence_length'].append(len(tokens))
                            length_appended_set.add(branch)
                        else:
                            if len(tokens) != lengths[branch]['sentence_length'][-1]:
                                # logging.warning("The lengths of the inputs are inconsistent. Ignored. %s" % line)
                                cnt_illegal += 1
                                if cnt_illegal / cnt_all > 0.33:
                                    raise PreprocessError("Too many illegal samples. Please check the number of data columns or the text tokenization.")
                                lengths[branch]['sentence_length'].pop()
                                true_len = len(lengths[branch]['sentence_length'])
                                # the last (inconsistent) example needs to be deleted
                                check_list = ['data', 'lengths', 'target']
                                for single_check in check_list:
                                    single_check = eval(single_check)
                                    self.delete_example(single_check, true_len)
                                break

                    if fixed_lengths and type_branches[input_type[0]] in fixed_lengths:
                        if len(tokens) >= fixed_lengths[type_branches[input_type[0]]]:
                            tokens = tokens[:fixed_lengths[type_branches[input_type[0]]]]
                        else:
                            tokens = tokens + ['<pad>'] * (fixed_lengths[type_branches[input_type[0]]] - len(tokens))
                    else:
                        if max_lengths and type_branches[input_type[0]] in max_lengths:     # cut sequences that are too long
                            tokens = tokens[:max_lengths[type_branches[input_type[0]]]]
                        if len(tokens) < min_sentence_len:
                            tokens = tokens + ['<pad>'] * (min_sentence_len - len(tokens))

                    if self.with_bos_eos is True:
                        tokens = ['<start>'] + tokens + ['<eos>']   # so source_with_start and source_with_end should be True

                    # for other tasks, the recorded length must equal the data length after the fixed/max length operations
                    if not ProblemTypes[self.problem_type] == ProblemTypes.sequence_tagging:
                        if not branch in length_appended_set:
                            lengths[branch]['sentence_length'].append(len(tokens))
                            length_appended_set.add(branch)
                        else:
                            if len(tokens) != lengths[branch]['sentence_length'][-1]:
                                # logging.warning("The lengths of the inputs are inconsistent. Ignored. %s" % line)
                                cnt_illegal += 1
                                if cnt_illegal / cnt_all > 0.33:
                                    raise PreprocessError("Too many illegal samples. Please check the number of data columns or the text tokenization.")
                                lengths[branch]['sentence_length'].pop()
                                true_len = len(lengths[branch]['sentence_length'])
                                # the last (inconsistent) example needs to be deleted
                                check_list = ['data', 'lengths', 'target']
                                for single_check in check_list:
                                    single_check = eval(single_check)
                                    self.delete_example(single_check, true_len)
                                break

                    for single_input_type in input_type:
                        if 'char' in single_input_type:
                            temp_word_char = []
                            temp_word_length = []
                            for single_token in tokens:
                                if single_token in word_no_split:
                                    # temp_word_length.append(1)
                                    temp_id = [self.input_dicts[type2cluster[single_input_type]].id(single_token)]
                                else:
                                    temp_id = self.input_dicts[type2cluster[single_input_type]].lookup(single_token)
                                if fixed_lengths and 'word' in fixed_lengths:
                                    if len(temp_id) >= fixed_lengths['word']:
                                        temp_id = temp_id[:fixed_lengths['word']]
                                    else:
                                        temp_id = temp_id + [self.input_dicts[type2cluster[single_input_type]].id('<pad>')] * (fixed_lengths['word'] - len(temp_id))
                                temp_word_char.append(temp_id)
                                temp_word_length.append(len(temp_id))
                            data[branch][single_input_type].append(temp_word_char)
                            lengths[branch]['word_length'].append(temp_word_length)
                        else:
                            data[branch][single_input_type].\
                                append(self.input_dicts[type2cluster[single_input_type]].lookup(tokens))
                else:
                    # check whether this column is a target
                    if answer_column_name is not None and len(answer_column_name) > 0:
                        if i in columns_to_target.keys():
                            # this is a target column
                            curr_target = columns_to_target[i]
                            if ProblemTypes[self.problem_type] == ProblemTypes.mrc:
                                try:
                                    trans2int = int(line_split[i])
                                except ValueError:
                                    target[curr_target].append(line_split[i])
                                else:
                                    target[curr_target].append(trans2int)
                                lengths['target'][curr_target].append(1)
                            if ProblemTypes[self.problem_type] == ProblemTypes.sequence_tagging:
                                target_tags = line_split[i].split(" ")
                                if fixed_lengths and "target" in fixed_lengths:
                                    if len(target_tags) >= fixed_lengths[type_branches[input_type[0]]]:
                                        target_tags = target_tags[:fixed_lengths[type_branches[input_type[0]]]]
                                    else:
                                        target_tags = target_tags + ['<pad>'] * (fixed_lengths[type_branches[input_type[0]]] - len(target_tags))
                                else:
                                    if max_lengths and "target" in max_lengths:     # cut sequences that are too long
                                        target_tags = target_tags[:max_lengths["target"]]
                                if self.with_bos_eos is True:
                                    target_tags = ['O'] + target_tags + ['O']
                                target[curr_target].append(self.output_dict.lookup(target_tags))
                                lengths['target'][curr_target].append(len(target_tags))
                            elif ProblemTypes[self.problem_type] == ProblemTypes.classification:
                                target[curr_target].append(self.output_dict.id(line_split[i]))
                                lengths['target'][curr_target].append(1)
                            elif ProblemTypes[self.problem_type] == ProblemTypes.regression:
                                target[curr_target].append(float(line_split[i]))
                                lengths['target'][curr_target].append(1)
                        else:
                            # these columns are useless in the configuration
                            pass

            cnt_legal += 1
            if ProblemTypes[self.problem_type] == ProblemTypes.mrc and target is not None:
                if passage_token_offsets:
                    if 'start_label' not in target or 'end_label' not in target:
                        raise Exception('The MRC task needs start_label and end_label.')
                    start_char_label = target['start_label'][-1]
                    end_char_label = target['end_label'][-1]
                    start_word_label = 0
                    end_word_label = len(passage_token_offsets) - 1
                    # map character-level labels to word-level labels via the passage token offsets
                    for i in range(len(passage_token_offsets)):
                        token_s, token_e = passage_token_offsets[i]
                        if token_s <= start_char_label <= token_e:
                            start_word_label = i
                        if token_s <= end_char_label - 1 <= token_e:
                            end_word_label = i
                    target['start_label'][-1] = start_word_label
                    target['end_label'][-1] = end_word_label
                else:
                    raise Exception('The MRC task needs a passage column.')

        return data, lengths, target, cnt_legal, cnt_illegal

    def encode(self, data_path, file_columns, input_types, file_with_col_header, object_inputs, answer_column_name,
               min_sentence_len, extra_feature, max_lengths=None, fixed_lengths=None, file_format="tsv",
               show_progress=True, cpu_num_workers=-1, chunk_size=1000*1000):
"""
Args :
data_path :
file_columns : {
" word1 " : 0 ,
" word2 " : 1 ,
" label " : 2 ,
" postag_feature1 " : 3 ,
" postag_feature2 " : 4
} ,
input_types :
{
" word " : {
" cols " : [
" word1 " ,
" word2 "
] ,
" dim " : 300
} ,
" postag " : {
" cols " : [ " postag_feature1 " , " postag_feature2 " ] ,
" dim " : 20
}
}
or
{
" bpe " : {
" cols " : [ " word1 " , " word2 " ] ,
" dim " : 100
" bpe_path " : " xxx.bpe "
}
}
object_inputs : {
" string1 " : [
" word1 " ,
" postag_feature1 "
] ,
" string2 " : [
" word2 " ,
" postag_feature2 "
]
} ,
answer_column_name : ' label ' / None . None means there is no target and it is used for prediction only .
max_lengths : if it is a dict , firstly cut the sequences if they exceed the max length . Then , pad all the sequences to the length of longest string .
{
" string1 " : 25 ,
" string2 " : 100
}
fixed_lengths : if it is a dict , cut or pad the sequences to the fixed lengths .
{
" string1 " : 25 ,
" string2 " : 100
}
file_format :
Returns :
data : indices , padded
{
' string1 ' : {
' word1 ' : [ . . . ] ,
' postage_feature1 ' : [ . . ]
}
' string2 ' : {
' word1 ' : [ . . . ] ,
' postage_feature1 ' : [ . . ]
}
lengths : real length of data
{
' string1 ' : [ . . . ] ,
' string2 ' : [ . . . ]
}
target : [ . . . ]
"""
        bpe_encoder = self._check_bpe_encoder(input_types)

        progress = self.get_data_generator_from_file(data_path, file_with_col_header, chunk_size=chunk_size)
        encode_generator = self.encode_data_multi_processor(progress, cpu_num_workers,
                file_columns, input_types, object_inputs, answer_column_name, min_sentence_len, extra_feature,
                max_lengths, fixed_lengths, file_format, bpe_encoder=bpe_encoder)

        data, lengths, target = dict(), dict(), dict()
        cnt_legal, cnt_illegal = 0, 0
        for temp_data, temp_lengths, temp_target, temp_cnt_legal, temp_cnt_illegal in tqdm(encode_generator):
            data = self._merge_encode_data(data, temp_data)
            lengths = self._merge_encode_lengths(lengths, temp_lengths)
            target = self._merge_target(target, temp_target)
            cnt_legal += temp_cnt_legal
            cnt_illegal += temp_cnt_illegal

        logging.info("%s: %d legal samples, %d illegal samples" % (data_path, cnt_legal, cnt_illegal))
        return data, lengths, target

    def build_encode_cache(self, conf, file_format="tsv"):
        logging.info("[Cache] building the encoding cache")
        build_encode_cache_generator = self.get_encode_generator(conf, build_cache=True, file_format=file_format)
        for _ in build_encode_cache_generator:
            continue
        logging.info("[Cache] the encoding cache is saved to %s" % conf.encoding_cache_dir)

    def get_encode_generator(self, conf, build_cache=True, file_format="tsv"):
        # parameter check
        if build_cache:
            assert conf.encoding_cache_dir, 'The conf object has no property encoding_cache_dir'
            assert conf.encoding_cache_index_file_path, 'The conf object has no property encoding_cache_index_file_path'
            assert conf.encoding_cache_index_file_md5_path, 'The conf object has no property encoding_cache_index_file_md5_path'

        bpe_encoder = self._check_bpe_encoder(conf.input_types)
        data_generator = self.get_data_generator_from_file(conf.train_data_path, conf.file_with_col_header, chunk_size=conf.chunk_size)
        encode_generator = self.encode_data_multi_processor(data_generator, conf.cpu_num_workers,
                conf.file_columns, conf.input_types, conf.object_inputs, conf.answer_column_name,
                conf.min_sentence_len, conf.extra_feature, conf.max_lengths,
                conf.fixed_lengths, file_format, bpe_encoder=bpe_encoder)

        file_index = []
        total_cnt_legal, total_cnt_illegal = 0, 0
        for part_number, encode_data in enumerate(encode_generator):
            data, lengths, target, cnt_legal, cnt_illegal = encode_data
            if build_cache:
                total_cnt_legal = total_cnt_legal + cnt_legal
                total_cnt_illegal = total_cnt_illegal + cnt_illegal
                file_name = st.cencoding_file_name_pattern % (part_number)
                file_path = os.path.join(conf.encoding_cache_dir, file_name)
                dump_to_pkl((data, lengths, target), file_path)
                file_index.append([file_name, md5([file_path])])
                logging.info("Up to now, in %s: %d legal samples, %d illegal samples" % (conf.train_data_path, total_cnt_legal, total_cnt_illegal))
            yield data, lengths, target

        if build_cache:
            cache_index = dict()
            cache_index[st.cencoding_key_index] = file_index
            cache_index[st.cencoding_key_legal_cnt] = total_cnt_legal
            cache_index[st.cencoding_key_illegal_cnt] = total_cnt_illegal
            dump_to_json(cache_index, conf.encoding_cache_index_file_path)
            dump_to_json(md5([conf.encoding_cache_index_file_path]), conf.encoding_cache_index_file_md5_path)
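
    # Illustrative note (a sketch, not from the original source): with the code above, the
    # index file at conf.encoding_cache_index_file_path ends up holding roughly
    #   {st.cencoding_key_index: [['<part-0 file name>', '<md5>'], ...],
    #    st.cencoding_key_legal_cnt: <int>, st.cencoding_key_illegal_cnt: <int>}
    # where each part file name follows st.cencoding_file_name_pattern and each part pickle
    # stores one (data, lengths, target) tuple produced by encode_data_multi_processor.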

    @staticmethod
    def _check_bpe_encoder(input_types):
        bpe_encoder = None
        if 'bpe' in input_types:
            try:
                bpe_encoder = BPEEncoder(input_types['bpe']['bpe_path'])
            except KeyError:
                raise Exception('Please define a bpe path at the embedding layer.')
        return bpe_encoder

    def decode(self, model_output, lengths=None, batch_data=None):
        """ Decode the model output, either a batch of outputs or a single output.

        Args:
            model_output: target indices.
                If it is a 1-d array, it is the output of a single sample;
                if it is a 2-d array, it holds the outputs of a batch of samples.
            lengths: if not None, the shape of lengths should be consistent with model_output.

        Returns:
            the original output
        """
        if ProblemTypes[self.problem_type] == ProblemTypes.classification:
            if isinstance(model_output, int):   # output of a single sample
                return self.output_dict.cell(model_output)
            else:                               # output of a batch
                return self.output_dict.decode(model_output)
        elif ProblemTypes[self.problem_type] == ProblemTypes.sequence_tagging:
            if isinstance(model_output, dict):
                model_output = list(model_output.values())[0]
            if not isinstance(model_output, np.ndarray):
                model_output = np.array(model_output)
            if len(model_output.shape) == 1:    # output of a single sample
                if lengths is None:
                    outputs = np.array(self.output_dict.decode(model_output))
                else:
                    outputs = np.array(self.output_dict.decode(model_output[:lengths]))
                if self.with_bos_eos:
                    outputs = outputs[1:-1]
            elif len(model_output.shape) == 2:  # output of a batch of sequences
                outputs = []
                if lengths is None:
                    for sample in model_output:
                        if self.with_bos_eos:
                            outputs.append(self.output_dict.decode(sample[1:-1]))
                        else:
                            outputs.append(self.output_dict.decode(sample))
                else:
                    for sample, length in zip(model_output, lengths):
                        if self.with_bos_eos:
                            outputs.append(self.output_dict.decode(sample[:length][1:-1]))
                        else:
                            outputs.append(self.output_dict.decode(sample[:length]))
            return outputs
        elif ProblemTypes[self.problem_type] == ProblemTypes.mrc:
            # for mrc, model_output is a dict
            answers = []
            p1, p2 = list(model_output.values())[0], list(model_output.values())[1]
            batch_size, c_len = p1.size()
            passage_length = lengths.numpy()
            padding_mask = np.ones((batch_size, c_len))
            for i, single_len in enumerate(passage_length):
                padding_mask[i][:single_len] = 0
            device = p1.device
            padding_mask = torch.from_numpy(padding_mask).byte().to(device)
            p1.data.masked_fill_(padding_mask.data, float('-inf'))
            p2.data.masked_fill_(padding_mask.data, float('-inf'))
            ls = nn.LogSoftmax(dim=1)
            mask = (torch.ones(c_len, c_len) * float('-inf')).to(device).tril(-1).unsqueeze(0).expand(batch_size, -1, -1)
            score = (ls(p1).unsqueeze(2) + ls(p2).unsqueeze(1)) + mask
            score, s_idx = score.max(dim=1)
            score, e_idx = score.max(dim=1)
            s_idx = torch.gather(s_idx, 1, e_idx.view(-1, 1)).squeeze()
            # extract the mrc answer text from the character offsets
            passage_text = 'extra_passage_text'
            passage_token_offsets = 'extra_passage_token_offsets'
            for i in range(batch_size):
                char_s_idx, _ = batch_data[passage_token_offsets][i][s_idx[i]]
                _, char_e_idx = batch_data[passage_token_offsets][i][e_idx[i]]
                answer = batch_data[passage_text][i][char_s_idx:char_e_idx]
                answers.append(answer)
            return answers

    def get_vocab_sizes(self):
        """ Get the size of each vocabulary, e.g. word embedding, postag ...

        Returns:
            {
              'word': xxx,
              'postag': xxx
            }
        """
        vocab_sizes = dict()
        for input in self.input_dicts:
            vocab_sizes[input] = self.input_dicts[input].cell_num()
        return vocab_sizes

    def export_problem(self, save_path, ret_without_save=False):
        if not os.path.exists(os.path.dirname(save_path)):
            os.makedirs(os.path.dirname(save_path))
        problem = dict()
        for name, value in vars(self).items():
            if name.startswith("__") is False:
                if isinstance(value, CellDict):
                    problem[name] = value.export_cell_dict()
                else:
                    problem[name] = value
        if ret_without_save is False:
            with open(save_path, 'wb') as fout:
                pkl.dump(problem, fout, protocol=pkl.HIGHEST_PROTOCOL)
            logging.debug("Problem saved to %s" % save_path)
            return None
        else:
            return problem

    def load_problem(self, problem_path):
        info_dict = load_from_pkl(problem_path)
        for name in info_dict:
            if isinstance(getattr(self, name), CellDict):
                getattr(self, name).load_cell_dict(info_dict[name])
            else:
                setattr(self, name, info_dict[name])
            # the type of input_dicts is dict
            # elif name == 'input_dicts' and isinstance(getattr(self, name), type(info_dict[name])):
            #     setattr(self, name, info_dict[name])
        logging.debug("Problem loaded")

    def delete_example(self, data, true_len):
        if isinstance(data, list):
            if len(data) > true_len:
                data.pop()
        else:
            # data is a dict
            for single_value in data.values():
                self.delete_example(single_value, true_len)
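

# Illustrative end-to-end usage (a hedged sketch, not part of the original module; file names,
# column names, and configuration values below are hypothetical):
#
#   problem = Problem('train', 'classification',
#                     {'word': {'cols': ['query_text'], 'dim': 300}},
#                     answer_column_name=['label'], with_bos_eos=False)
#   problem.build(['train.tsv'], {'query_text': 0, 'label': 1},
#                 {'word': {'cols': ['query_text'], 'dim': 300}},
#                 file_with_col_header=True, answer_column_name=['label'])
#   data, lengths, target = problem.encode('train.tsv', {'query_text': 0, 'label': 1},
#                                          {'word': {'cols': ['query_text'], 'dim': 300}},
#                                          True, {'string1': ['query_text']}, ['label'],
#                                          min_sentence_len=2, extra_feature=False)
#   problem.export_problem('models/problem.pkl')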