#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Tokenize tables (the reader output) to generate model inputs with pre-processing for three pre-training objectives: MLM: Masked Language Model CLC: Cell Level Cloze TCR: Table Context Retrieval """ from typing import Dict, List import os import random import collections import unicodedata from utils import UNZIPS # %% Special Token Identifications # additionally defined token ids VAL_ID = 1 EMP_ID = 2 # placeholder for emtpy cells # adopt from bert PAD_ID = 0 UNK_ID = 100 CLS_ID = 101 SEP_ID = 102 MASK_ID = 103 # corresponding token-strings VAL_TOKEN = "[VAL]" EMP_TOKEN = "[EMP]" PAD_TOKEN = "[PAD]" UNK_TOKEN = "[UNK]" CLS_TOKEN = "[CLS]" SEP_TOKEN = "[SEP]" MASK_TOKEN = "[MASK]" # %% Cleansing functions def whitespace_split(text): text = text.strip() if not text: return [] return text.split() def _is_whitespace(char): if char == " " or char == "\t" or char == "\n" or char == "\r": return True cat = unicodedata.category(char) if cat == "Zs": return True return False def _is_control(char): if char == "\t" or char == "\n" or char == "\r": return False cat = unicodedata.category(char) if cat in ("Cc", "Cf"): return True return False def _is_punctuation(char): cp = ord(char) if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): return True cat = unicodedata.category(char) if cat.startswith("P"): return True return False def convert_to_unicode(text): """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" if isinstance(text, str): return text elif isinstance(text, bytes): return text.decode("utf-8", "ignore") else: raise ValueError("Unsupported string type: {}".format(type(text))) def load_vocab(vocab_path): vocab = collections.OrderedDict() with open(vocab_path, "r", encoding='utf-8') as reader: for index,line in enumerate(reader): token = convert_to_unicode(line.strip()) if not token: continue vocab[token] = index print("Successfully Loaded Vocabulary from {}".format(vocab_path)) return vocab # %% Basic class BasicTokenizer(object): """ Basic: separate words at whitespaces, punctuations, and numbers""" def __init__(self, do_lower_case=True): self.do_lower_case = do_lower_case self.digit_set = set(["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]) self.numeric_mark = set([',', '.']) def tokenize(self, text): text = convert_to_unicode(text) text = self._clean_text(text) text = self._tokenize_chinese_chars(text) org_text_list, org_type_list = self._group_digit_chars(text) # cut numbers apart text_list, type_list = [], [] for stext, stype in zip(org_text_list, org_type_list): if stype: # digit span (with , and .) stext = self.remove_comma_in_number(stext) slist, stype_list = self.split_by_dot(stext) text_list.extend(slist) type_list.extend(stype_list) else: # text span stext_list = whitespace_split(stext) for item in stext_list: if self.do_lower_case: item = item.lower() item = self._run_strip_accents(item) # str removed with accent tokens item = self._run_split_on_punc(item) text_list.extend([ii.strip() for ii in item]) type_list.extend([stype for _ in item]) stripped_text, stripped_type = [], [] for stext, stype in zip(text_list, type_list): if stext.strip(): stripped_text.append(stext.strip()) stripped_type.append(stype) stripped_text = whitespace_split(" ".join(stripped_text)) # not necessary assert len(stripped_text) == len(stripped_type), "sizes of text and type don't match." return stripped_text, stripped_type # list of texts without whitespaces def remove_comma_in_number(self, text): """ convert '3,000' to '3000' """ chunks = text.split(',') return "".join([c for c in chunks if c]) def split_by_dot(self, digit_text): """ Handle digits in numbers, e.g. '3.00', '3' or '192.0.0.0'. treat as respective numbers if >=2 dots, value strings remain otherwise. """ text_list, type_list = [], [] parts = [p for p in digit_text.strip().split('.') if p] if len(parts) > 2: # not a value for part in parts: text_list.append(part) type_list.append(True) text_list.append('.') type_list.append(False) text_list = text_list[: -1] type_list = type_list[: -1] elif len(parts) > 0: # integer or decimal text_list.append(digit_text) type_list.append(True) return text_list, type_list def _group_digit_chars(self, text): """split numbers apart from surrounding texts """ output, type_list = [[]], [] digit_flag = False if text and (text[0] in self.digit_set): digit_flag = True for char in text: if char in self.numeric_mark: output[-1].append(char) elif char in self.digit_set: # current item is digit if digit_flag is True: # previous item also a digit output[-1].append(char) else: # alter from text to digit type_list.append(digit_flag) # type mark of previous span digit_flag = True output.append([char]) else: # current item is text if digit_flag is False: # previous item also text output[-1].append(char) else: # alter from digit to text type_list.append(digit_flag) digit_flag = False output.append([char]) text_list = ["".join(span) for span in output] type_list.append(digit_flag) return text_list, type_list def _run_strip_accents(self, text): text = unicodedata.normalize("NFD", text) output = [] for char in text: cat = unicodedata.category(char) if cat == "Mn": continue output.append(char) return "".join(output) def _run_split_on_punc(self, text): chars = list(text) i = 0 start_new_word = True output = [] while i < len(chars): char = chars[i] if _is_punctuation(char): output.append([char]) start_new_word = True else: if start_new_word: output.append([]) start_new_word = False output[-1].append(char) i += 1 return ["".join(x) for x in output] def _tokenize_chinese_chars(self, text): output = [] for char in text: cp = ord(char) if self._is_chinese_char(cp): output.append(" ") output.append(char) output.append(" ") else: output.append(char) return "".join(output) def _is_chinese_char(self, cp): if ((cp >= 0x4E00 and cp <= 0x9FFF) or # (cp >= 0x3400 and cp <= 0x4DBF) or # (cp >= 0x20000 and cp <= 0x2A6DF) or # (cp >= 0x2A700 and cp <= 0x2B73F) or # (cp >= 0x2B740 and cp <= 0x2B81F) or # (cp >= 0x2B820 and cp <= 0x2CEAF) or (cp >= 0xF900 and cp <= 0xFAFF) or # (cp >= 0x2F800 and cp <= 0x2FA1F)): # return True return False def _clean_text(self, text): cleaned_text = "" for char in text: cp = ord(char) if cp == 0 or cp == 0xfffd or _is_control(char): continue if _is_whitespace(char): cleaned_text += " " else: cleaned_text += char return cleaned_text # %% Word piecing class WordpieceTokenizer(object): """Cut words into token pieces """ def __init__(self, vocab, unk_token="[UNK]", val_token="[VAL]", mag_size=10, pre_size=10, top_digit=10, low_digit=10, max_cell_len=8, max_token_chars=64 ): self.vocab = vocab self.unk_token = unk_token self.val_token = val_token self.mag_size = mag_size self.pre_size = pre_size self.top_digit = top_digit self.low_digit = low_digit self.default_num = ( self.mag_size + 1, self.pre_size + 1, self.top_digit + 1, self.low_digit + 1 ) self.max_cell_len = max_cell_len self.max_token_chars = max_token_chars def tokenize(self, word_text, word_type): """ Cut a single word into tokens using the vocabulary """ word_text = convert_to_unicode(word_text) if word_type is True: # digit span return self.tokenize_digit(word_text) output_tokens = [] for text in whitespace_split(word_text): chars = list(text) if len(chars) > self.max_token_chars: output_tokens.append(self.unk_token) continue is_bad = False start = 0 sub_tokens = [] while start < len(chars): end = len(chars) cur_substr = None while start < end: substr = "".join(chars[start:end]) if start > 0: substr = "##" + substr if substr in self.vocab: cur_substr = substr break end -= 1 if cur_substr is None: is_bad = True break sub_tokens.append(cur_substr) start = end if is_bad: output_tokens.append(self.unk_token) else: output_tokens.extend(sub_tokens) return output_tokens, [self.default_num for _ in output_tokens] def tokenize_digit(self, digit_text): """ Tokenize numeric strings. Default settings: magnitude: +inf (= magnitude_size), nan (= magnitude_size + 1) precision: +inf (= precision_size), nan (= precision_size + 1) """ parts = [p for p in digit_text.strip().split('.') if p] token, magnitude, precision = self.val_token, 0, 0 if len(parts) == 1: # integer if digit_text in self.vocab: token = digit_text magnitude = len(digit_text) if len(parts) == 2: # decimal magnitude = len(parts[0]) precision = len(parts[1]) top, low = int(parts[0][0]), int(parts[-1][-1]) # post-process magnitude = min(magnitude, self.mag_size) precision = min(precision, self.pre_size) top = min(max(0, top), self.top_digit) low = min(max(0, low), self.low_digit) return [token], [(magnitude, precision, top, low)] # %% Leveraged for (matrix) tables class TableTokenizer(object): """Basic definitions, string tokenization, repository building. """ def __init__(self, args, do_lower_case=True): self.do_lower_case = do_lower_case self.vocab = load_vocab(args.vocab_path) # get id from token (str) self.inv_vocab = {v: k for k, v in self.vocab.items()} # get token str from id self.punc_ids = set( [index-1 for index in range(1000, 1004)] + [index-1 for index in range(1025, 1037)] + [index-1 for index in range(1064, 1996)] + [index-1 for index in range(29613, 30522)] ) self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) self.wordpiece_tokenizer = WordpieceTokenizer( vocab=self.vocab, mag_size=args.magnitude_size, pre_size=args.precision_size, top_digit=args.top_digit_size, low_digit=args.low_digit_size ) self.row_size = args.row_size self.column_size = args.column_size self.tree_depth = args.tree_depth self.node_degree = args.node_degree self.max_cell_length = args.max_cell_length self.max_seq_len = args.max_seq_len self.text_threshold = args.text_threshold self.value_threshold = args.value_threshold self.clc_rate= args.clc_rate self.wcm_rate = args.wcm_rate # whole-cell-mask rate self.num_format_feature = args.num_format_feature self.default_format = [0.25, 0.25, 0., 0., 0., 0., 0., 0., 0., 1., 1.] self.format_range = [4., 4., 1., 1., 1., 1., 1., 1., 1., 1., 1.] self.default_mlm_label = -1 self.default_clc_label = 0 self.total_node = sum(self.node_degree) self.default_tree_position = [-1 for _ in self.node_degree] # load pre-collected context tuples and cell strings in repository self.context_repo = [] self.load_repository(self.context_repo, args.context_repo_path) self.cellstr_repo = {keylen: [] for keylen in range(2, args.max_cell_length + 1)} self.build_cellstr_dict(self.cellstr_repo, args.cellstr_repo_path) def load_repository(self, repository, repo_path): if repo_path is None: return with open(repo_path, "r", encoding='utf-8') as fr: lines = fr.readlines() repository.extend( [line.strip() for line in lines] ) print("tokenizer collected {} context pieces in repository.".format(len(self.context_repo))) def build_cellstr_dict(self, repository, repo_path): if repo_path is None: return with open(repo_path, "r", encoding='utf-8') as fr: lines = fr.readlines() print("#Num of Lines: ", len(lines)) for line in lines: cellstr = line.strip() tokens, _ = self.tokenize_text(cellstr, add_separate=False) key_len = len(tokens) if key_len < 2: continue for small_key_len in range(len(tokens), 1, -1): if len(repository[small_key_len]) < len(repository[key_len]): key_len = small_key_len offset = random.randint(0, len(tokens) - key_len) repository[key_len].append(tokens[0 + offset: key_len + offset]) # sampling, then print repository message counts = [] for key_len in repository.keys(): m = min(len(repository[key_len]), 500) counts.append(m) repository[key_len] = random.sample(repository[key_len], m) msg = ", ".join([str(l) for l in counts]) print("Replace Repository: {}".format(msg)) def tokenize_string_matrix(self, string_matrix, add_separate, max_cell_len=8): """ Return cell-wise text/number ids as lists. """ token_matrix, number_matrix = [], [] for string_row in string_matrix: token_matrix.append([]) number_matrix.append([]) for cell_string in string_row: cell_token, cell_number = self.tokenize_text(cell_string, add_separate, max_cell_len) token_matrix[-1].append(cell_token) number_matrix[-1].append(cell_number) return token_matrix, number_matrix def tokenize_text(self, cell_string, add_separate=True, max_cell_len=8): """ Tokenize cell strings (as text) upto max_cell_len. cell_number = (magnitude, precision, highest, lowest), or default value """ cell_token, cell_number = [], [] if add_separate == True: cell_token.append("[SEP]") cell_number.append(self.wordpiece_tokenizer.default_num) text_list, type_list = self.basic_tokenizer.tokenize(cell_string) for word_text, word_type in zip(text_list, type_list): token_list, num_list = self.wordpiece_tokenizer.tokenize(word_text, word_type) cell_token.extend(token_list) cell_number.extend(num_list) cell_token = self.convert_tokens_to_ids(cell_token) assert len(cell_token) == len(cell_number), "Token number doesn't match Magnitudes." if len(cell_token) == 1 and cell_token[0] == SEP_ID: cell_token.append(EMP_ID) cell_number.append(self.wordpiece_tokenizer.default_num) cell_token = cell_token[: max_cell_len] cell_number = cell_number[: max_cell_len] if add_separate == True: assert len(cell_token) > 1, "Token list too short: {} -> {}".format(cell_string, cell_token) return cell_token, cell_number def convert_tokens_to_ids(self, tokens): ids = [] for token in tokens: ids.append(self.vocab.get(token, UNK_ID)) return ids def convert_ids_to_tokens(self, ids): tokens = [] for token_id in ids: token = self.inv_vocab.get(token_id, '[UNK]') tokens.append(token) return tokens class TutaTokenizer(TableTokenizer): def check_valid(self, tokens, numbers, in_header): """ Check if a cell is valid enough as table input remove empty cell, in both header and data regions for all-number cell, keep in header (do not mask), sampling from data for half-text&number cell, keep in header (can mask), sampling from data for text cell, keep in header (can mask), sampling from data (resp.) """ if EMP_ID in tokens: return 0 if self.wordpiece_tokenizer.default_num not in set(numbers[1: ]): # pure numeric if (in_header == True) or (random.random() < self.value_threshold): return 1 elif set([self.wordpiece_tokenizer.default_num]) == set(numbers[1: ]): # pure text if (in_header == True) or (random.random() < self.text_threshold): return 2 else: # halfway if (in_header == True) or (random.random() < self.value_threshold): return 2 return 0 def sampling(self, token_matrix, number_matrix, header_info, max_disturb_num=20, disturb_prob=0.2, clc_rate=0.3 ): """ Mark each cell: '0' as dumped, '1' as input, '2' as blanked, '3' as masked. """ header_rows, header_columns = header_info sampling_mask = [[1 for _ in token_matrix[0]] for _ in token_matrix] divide_prob = random.random() divide_prob = 0.4 + divide_prob / 5 mask = [] valid_cell_count = 0 blank_buckets = [set() for _ in range(header_rows + header_columns + 2)] # coarse marking for irow, token_row in enumerate(token_matrix): for icol, tokens in enumerate(token_row): in_header = (irow < header_rows) or (icol < header_columns) cell_valid = self.check_valid(tokens, number_matrix[irow][icol], in_header) valid_cell_count += cell_valid if cell_valid == 0: sampling_mask[irow][icol] = 0 elif cell_valid > 1: prob = random.random() if prob < (disturb_prob * divide_prob): # mask mask.append( (irow, icol) ) else: # blank or none if irow < header_rows: if icol < header_columns: blank_buckets[0].add( (irow, icol) ) else: blank_buckets[1+irow].add( (irow, icol) ) elif icol < header_columns: blank_buckets[header_rows+1+icol].add( (irow, icol) ) else: blank_buckets[-1].add( (irow, icol) ) max_disturb_num = min(max_disturb_num, int(disturb_prob * valid_cell_count)) # refine mask marks mask_num = min(len(mask), int(max_disturb_num * divide_prob)) mask = random.sample(mask, mask_num) for irow, icol in mask: sampling_mask[irow][icol] = 3 # refine blank marks blank = [] for bucket in blank_buckets: if len(bucket) == 0: continue if (len(bucket) < 6) and (random.random() < clc_rate): blank.extend( random.sample(bucket, 1) ) else: bucket_blank_num = min(max(1, int(len(bucket) * clc_rate)), 3) bucket_blank = random.sample(bucket, bucket_blank_num) blank.extend(bucket_blank) blank_num = min(len(blank), max_disturb_num - mask_num) if blank_num > 1: blank = random.sample(blank, blank_num) for irow, icol in blank: sampling_mask[irow][icol] = 2 return sampling_mask def init_table_seq(self, root_context): """Initialize table sequence with CLS_ID at head, add context if provided. """ context_tokens, context_number = self.tokenize_text(cell_string=root_context, add_separate=False, max_cell_len=8) if len(context_tokens) > 0: # do context mlm if available gold_labels, masked_tokens, _ = self.easy_mask(context_tokens, context_number) token_list = [ [CLS_ID] + masked_tokens ] mlm_label = [ [self.default_mlm_label] + gold_labels ] else: token_list = [ [CLS_ID] ] mlm_label = [ [self.default_mlm_label] ] num_list = [ [self.wordpiece_tokenizer.default_num] + context_number ] pos_list = [ (self.row_size, self.column_size, self.default_tree_position, self.default_tree_position) ] format_list = [ self.default_format ] indicator = [ [-1] + [-2 for _ in context_tokens] ] clc_label = [ [self.default_clc_label for _ in token_list[0]] ] cell_num = 1 seq_len = len(token_list[0]) return token_list, num_list, pos_list, format_list, indicator, mlm_label, clc_label, cell_num, seq_len def get_text_choices(self, context_truths, length_left, max_pair_num=3): choice_label_pairs = [] context_length, num = 0, 0 for truth_text in context_truths: truth_token, truth_num = self.tokenize_text(truth_text, True, 20) disturb_text = self.context_repo.pop() # pair each truth with a disturb disturb_token, disturb_num = self.tokenize_text(disturb_text) context_length += len(truth_token) + len(disturb_token) if (context_length > length_left) or (num > max_pair_num): self.context_repo.append(disturb_text) # push the disturb back if not used break num += 1 choice_label_pairs.append( ((truth_token, truth_num), 1) ) choice_label_pairs.append( ((disturb_token, disturb_num), 0) ) new_repo = random.sample(context_truths, num) self.context_repo.extend( new_repo ) random.shuffle(choice_label_pairs) return choice_label_pairs def objective_preprocess(self, sampling_matrix, token_matrix, number_matrix, position_lists, format_matrix=None, context=None, add_sep=True ): top_pos_list, left_pos_list = position_lists row_number, column_number = len(token_matrix), len(token_matrix[0]) # rewrite a fake format_matrix (all default values) for wiki & wdc tables if format_matrix is None: format_matrix = [[[0. for _ in range(self.num_format_feature)] for _ in range(column_number)] for _ in range(row_number)] # get all context chunks context_truths = self.get_context_truth(context) root_context = "" if len(context_truths) > 0: root_context = context_truths[0] context_truths = context_truths[1: ] token_list, num_list, pos_list, format_list, \ indicator, mlm_label, clc_label, cell_num, seq_len = self.init_table_seq(root_context=root_context) paste_token_list, paste_num_list, paste_clc_label = [], [], [] for irow, sampling_row in enumerate(sampling_matrix): for icol, mark in enumerate(sampling_row): if mark == 0: # dumped continue cell_num += 1 tokens = token_matrix[irow][icol] number = number_matrix[irow][icol] cell_len = len(tokens) if mark == 1: # normal token_list.append(tokens) num_list.append(number) mlm_label.append( [self.default_mlm_label for _ in tokens] ) clc_label.append( [self.default_clc_label for _ in tokens] ) elif mark == 2: # blank paste_token_list.append(tokens) paste_num_list.append(number) if add_sep == True: token_list.append( [SEP_ID, EMP_ID] ) num_list.append( [self.wordpiece_tokenizer.default_num] * 2 ) mlm_label.append( [self.default_mlm_label, self.default_mlm_label] ) clc_label.append( [-seq_len, -seq_len-1] ) assert cell_len > 1 real_clc_label = [seq_len, seq_len+1] + [0] * (cell_len-2) paste_clc_label.append( real_clc_label ) cell_len = 2 else: token_list.append( [EMP_ID] ) num_list.append( [self.wordpiece_tokenizer.default_num] ) mlm_label.append( [self.default_mlm_label] ) clc_label.append( [-seq_len] ) assert cell_len > 1 real_clc_label = [seq_len] + [0] * (cell_len-1) paste_clc_label.append( real_clc_label ) cell_len = 1 elif mark == 3: # mask if add_sep == True: if random.random() < self.wcm_rate: gold, fake, sep_label = self.whole_mask(tokens[1: ], number[1: ]) else: gold, fake, sep_label = self.easy_mask(tokens[1: ], number[1: ]) token_list.append( [SEP_ID] + fake ) mlm_label.append([sep_label] + gold) else: gold, fake = self.easy_mask(tokens, number) token_list.append( fake ) mlm_label.append( gold ) num_list.append(number) clc_label.append( [self.default_clc_label for _ in tokens] ) # add corresponding format vector format_vector = [] for ivec, vec in enumerate(format_matrix[irow][icol]): format_vector.append( min(vec, self.format_range[ivec]) / self.format_range[ivec] ) format_list.append( format_vector ) icell = irow * column_number + icol # check if has null position if (max(top_pos_list[icell]) == -1) or (max(left_pos_list[icell]) == -1): return None pos_list.append( (irow, icol, top_pos_list[icell], left_pos_list[icell]) ) indicator.append( [cell_num*2 for _ in range(cell_len)] ) if add_sep: assert (cell_len > 1) or (token_list[-1] == [CLS_ID]), "Mini cell: {}".format(token_list[-1]) indicator[-1][0] -= 1 seq_len += cell_len paste_position = (self.row_size, self.column_size, [-1] * self.tree_depth, [-1] * self.tree_depth) for tokens in paste_token_list: cell_num += 1 seq_len += len(tokens) # ramdom positions for pasted cells to enlarge the attention distance with other cells paste_position = (self.row_size, self.column_size, [random.randint(0, 31)] * self.tree_depth, [random.randint(0, 31)] * self.tree_depth) pos_list.append(paste_position) format_list.append(self.default_format) mlm_label.append( [self.default_mlm_label for _ in tokens] ) paste_ind = [cell_num*2] * len(tokens) if add_sep: paste_ind[0] -= 1 indicator.append( [-dd for dd in paste_ind] ) token_list.extend(paste_token_list) num_list.extend(paste_num_list) clc_label.extend(paste_clc_label) # add table-level context choices tcr_label = [] for clc in clc_label: tcr_label.append( [-1 for _ in clc] ) context_choice_label_pairs = self.get_text_choices(context_truths, self.max_seq_len - seq_len) # adjust number of choices based on current sequence length for (token, number), label in context_choice_label_pairs: cell_num += 1 token_list.append(token) num_list.append(number) mlm_label.append( [self.default_mlm_label for _ in token] ) paste_position = (self.row_size, self.column_size, [random.randint(0, 31)] * self.tree_depth, [random.randint(0, 31)] * self.tree_depth) pos_list.append(paste_position) format_list.append(self.default_format) indicator.append([-cell_num*2 for _ in token]) if add_sep: indicator[-1][0] += 1 clc_label.append( [self.default_clc_label for _ in token] ) tcr_label.append( [-1 for _ in token] ) tcr_label[-1][1] = label if add_sep: tcr_label[-1][0] = label return token_list, num_list, pos_list, format_list, indicator, mlm_label, clc_label, tcr_label def get_replace_token(self, token_id): prob = random.random() if prob < 0.8: return MASK_ID elif prob < 0.9: return random.randint(1996, 29611) else: return token_id def get_mlm_index(self, cell_ids, cell_num): """ Select one token in list, prefer text over punctuations over numbers. """ nonum_indexes, text_indexes = [], [] for ii, (ids, num) in enumerate(zip(cell_ids, cell_num)): if self.wordpiece_tokenizer.default_num == num: nonum_indexes.append(ii) if (ids not in self.punc_ids): text_indexes.append(ii) if len(text_indexes) > 0: index = random.sample(text_indexes, 1)[0] elif len(nonum_indexes) > 0: index = random.sample(nonum_indexes, 1)[0] else: index = random.randint(0, len(cell_ids) - 1) return index def easy_mask(self, cell_ids, cell_num): """ Mask only one token in the given token list. get a random index, mark golden truth, and build fake token list inputs: token_ids and numerical_feats of a cell return: gold labels, fake token list, [SEP] label """ index = self.get_mlm_index(cell_ids, cell_num) gold = [-1 for _ in cell_ids] gold[index] = cell_ids[index] fake = cell_ids[: index] + [self.get_replace_token(cell_ids[index])] + cell_ids[index + 1: ] sep_label = -1 # cell_ids[index] return gold, fake, sep_label def get_mlm_index_whole(self, cell_ids, cell_num): """ Record all viable tokens in list, prefer text over punctuations over numbers. """ nonum_indexes, text_indexes = [], [] indexes = [] for ii, (ids, num) in enumerate(zip(cell_ids, cell_num)): if self.wordpiece_tokenizer.default_num == num: nonum_indexes.append(ii) if (ids not in self.punc_ids): text_indexes.append(ii) if len(text_indexes) > 0: indexes = text_indexes elif len(nonum_indexes) > 0: indexes.append(random.sample(nonum_indexes, 1)[0]) else: indexes.append(random.randint(0, len(cell_ids) - 1)) return indexes def whole_mask(self, cell_ids, cell_num): """ Mask all of the tokens in the given token list get a random index, mark golden truth, and build fake token list inputs: token_ids and numerical_feats of a cell return: gold labels, fake token list, [SEP] label """ indexes = self.get_mlm_index_whole(cell_ids, cell_num) gold = [-1 for _ in cell_ids] fake = [cell_id for cell_id in cell_ids] for index in indexes: gold[index] = cell_ids[index] fake[index] = self.get_replace_token(cell_ids[index]) sep_label = random.sample(cell_ids, 1)[0] return gold, fake, sep_label def get_context_truth(self, context): truths = [] if context is None: # print("context not available.") return truths else: for text in context[: 2]: truths.extend( self.chunk_text(text) ) random.shuffle(truths) return truths def chunk_text(self, text, max_snippet_len=10): words = text.strip().split() num_snippets = (len(words) + max_snippet_len - 1) // max_snippet_len # dump last span spans = [] for i in range(num_snippets): spans.append( " ".join(words[i*max_snippet_len: (i+1)*max_snippet_len]) ) return spans class CtcTokenizer(TableTokenizer): def no_sampling(self, token_matrix): sampling_mask = [[1 for _ in token_matrix[0]] for _ in token_matrix] return sampling_mask def check_valid(self, tokens, numbers, in_header, ctc_label): if EMP_ID in tokens: return 0 if 2 <= ctc_label <= 4: return 1 if self.wordpiece_tokenizer.default_num not in set(numbers[1: ]): # pure numeric if (in_header == True) or (random.random() < self.value_threshold): return 1 elif set([self.wordpiece_tokenizer.default_num]) == set(numbers[1: ]): # pure text if (in_header == True) or (random.random() < self.text_threshold): return 2 else: # halfway if (in_header == True) or (random.random() < self.value_threshold): return 2 return 0 def sampling(self, token_matrix, number_matrix, header_info, label_matrix): """mark each cell: '0' as dumped, '1' as input. """ header_rows, header_columns = header_info sampling_mask = [[1 for _ in token_matrix[0]] for _ in token_matrix] for irow, token_row in enumerate(token_matrix): for icol, tokens in enumerate(token_row): in_header = (irow < header_rows) or (icol < header_columns) ctc_label = int(label_matrix[irow][icol]) cell_valid = self.check_valid(tokens, number_matrix[irow][icol], in_header, ctc_label) sampling_mask[irow][icol] = cell_valid return sampling_mask def init_table_seq(self, root_context=""): """Initialize table sequence with CLS_ID at head, add context if provided. """ context_tokens, context_number = self.tokenize_text(cell_string=root_context, add_separate=False, max_cell_len=64) token_list = [ [CLS_ID] + context_tokens ] num_list = [ [self.wordpiece_tokenizer.default_num] + context_number ] pos_list = [ (self.row_size, self.column_size, [-1] * self.tree_depth, [-1] * self.tree_depth) ] fmt_list = [ self.default_format ] ind_list = [ [-1] + [-2 for _ in context_tokens] ] label_list = [[-1] + [-1 for ct in context_tokens]] cell_num = 1 seq_len = len(token_list[0]) return token_list, num_list, pos_list, fmt_list, ind_list, label_list, cell_num, seq_len def create_table_seq(self, sampling_matrix, token_matrix, number_matrix, position_lists, format_matrix, label_matrix, sep_or_tok, add_sep=True ): top_pos_list, left_pos_list = position_lists row_number, column_number = len(token_matrix), len(token_matrix[0]) if format_matrix is None: format_matrix = [[self.default_format for _ in range(column_number)] for _ in range(row_number)] token_list, num_list, pos_list, fmt_list, ind_list, label_list, cell_num, seq_len = self.init_table_seq(root_context="") for irow, token_row in enumerate(token_matrix): for icol, token_cell in enumerate(token_row): if sampling_matrix[irow][icol] == 0: continue token_list.append( token_cell ) num_list.append( number_matrix[irow][icol] ) cell_len = len(token_cell) icell = irow * column_number + icol pos_list.append( (irow, icol, top_pos_list[icell], left_pos_list[icell]) ) format_vector = [] for ivec, vec in enumerate(format_matrix[irow][icol]): format_vector.append( min(vec, self.format_range[ivec]) / self.format_range[ivec] ) fmt_list.append( format_vector ) ind_list.append( [cell_num*2 for _ in range(cell_len)] ) ind_list[-1][0] -= 1 ctc_label = int(label_matrix[irow][icol]) if (ctc_label < 2) or (ctc_label > 4): ctc_label = -1 else: ctc_label -= 2 label_list.append( [-1 for _ in token_cell] ) label_list[-1][sep_or_tok] = ctc_label label_list[-1][1-sep_or_tok] = ctc_label seq_len += cell_len cell_num += 1 return token_list, num_list, pos_list, fmt_list, ind_list, label_list class TtcTokenizer(TableTokenizer): """Adapted tokenizer for Table Type Classification.""" def __init__(self, args): super(TtcTokenizer, self).__init__(args) def check_valid(self, tokens, numbers, in_header): if EMP_ID in tokens: return 0 if self.wordpiece_tokenizer.default_num not in set(numbers[1: ]): # pure numeric if (in_header == True) or (random.random() < self.value_threshold): return 1 elif set([self.wordpiece_tokenizer.default_num]) == set(numbers[1: ]): # pure text if (in_header == True) or (random.random() < self.text_threshold): return 2 else: # halfway if (in_header == True) or (random.random() < self.value_threshold): return 2 return 0 def sampling(self, token_matrix, number_matrix): """Mark each cell: '0' as dumped, other cases as input. """ sampling_mask = [[1 for cell in row] for row in token_matrix] header_rows, header_columns = 1, 1 for irow, token_row in enumerate(token_matrix): for icol, tokens in enumerate(token_row): # in_header = (irow < header_rows) or (icol < header_columns) cell_valid = self.check_valid(tokens, number_matrix[irow][icol], in_header=True) sampling_mask[irow][icol] = cell_valid return sampling_mask def init_table_lists(self, root_text: str = "") -> Dict: """Initialize table sequences with CLS_ID at head for TTC prediction, add context if provided.""" context_tokens, context_number = self.tokenize_text(cell_string=root_text, add_separate=False, max_cell_len=32) token_list = [ [CLS_ID] + context_tokens ] num_list = [ [self.wordpiece_tokenizer.default_num] + context_number ] pos_list = [(self.row_size, self.column_size, [-1]*self.tree_depth, [-1]*self.tree_depth)] fmt_list = [ self.default_format ] ind_list = [ [-1] + [-2 for _ in context_tokens] ] cell_num = 1 seq_len = len(token_list[0]) return token_list, num_list, pos_list, fmt_list, ind_list, cell_num, seq_len def create_table_lists( self, string_matrix: List[List[str]], top_position_list: List[List[List[int]]], left_position_list: List[List[List[int]]], title: str, label: int, format_matrix: List = None, add_separate: bool = True, **kwargs ): token_list, num_list, pos_list, fmt_list, ind_list, cell_num, seq_len = self.init_table_lists(title) nrows, ncols = len(string_matrix), len(string_matrix[0]) token_matrix, number_matrix = self.tokenize_string_matrix(string_matrix, add_separate) sampling_matrix = self.sampling(token_matrix, number_matrix) assert len(token_matrix) == len(number_matrix) == len(sampling_matrix) == nrows assert len(token_matrix[0]) == len(number_matrix[0]) == len(sampling_matrix[0]) == ncols if format_matrix is None: format_matrix = [[self.default_format for _ in range(ncols)] for _ in range(nrows)] for irow, token_row in enumerate(token_matrix): for icol, token_cell in enumerate(token_row): if sampling_matrix[irow][icol] == 0: continue token_list.append(token_cell) num_list.append(number_matrix[irow][icol]) icell = irow * ncols + icol pos_list.append( (irow, icol, top_position_list[icell], left_position_list[icell]) ) cell_len = len(token_cell) format_vector = [] for ivec, vec in enumerate(format_matrix[irow][icol]): format_vector.append( min(vec, self.format_range[ivec]) / self.format_range[ivec] ) fmt_list.append( format_vector ) ind_list.append( [cell_num*2 for _ in range(cell_len)] ) ind_list[-1][0] -= 1 seq_len += cell_len cell_num += 1 return token_list, num_list, pos_list, fmt_list, ind_list, label @staticmethod def table_lists_to_seq(lists, target): """Serialize lists of loaded samples to model input sequences.""" token_list, num_list, pos_list, fmt_list, ind_list, ttc_label = lists token_id, num_mag, num_pre, num_top, num_low = [], [], [], [], [] token_order, pos_row, pos_col, pos_top, pos_left = [], [], [], [], [] fmt_vec, indicator = [], [] for tokens, num_feats, (r, c, t, l), fmt, ind in zip(token_list, num_list, pos_list, fmt_list, ind_list): cell_len = len(tokens) cell_len = min(8, cell_len) token_id.extend(tokens[:cell_len]) num_mag.extend([f[0] for f in num_feats[:cell_len]]) num_pre.extend([f[1] for f in num_feats[:cell_len]]) num_top.extend([f[2] for f in num_feats[:cell_len]]) num_low.extend([f[3] for f in num_feats[:cell_len]]) token_order.extend([ii for ii in range(cell_len)]) pos_row.extend([r for _ in range(cell_len)]) pos_col.extend([c for _ in range(cell_len)]) entire_top = UNZIPS[target](t) pos_top.extend([entire_top for _ in range(cell_len)]) entire_left = UNZIPS[target](l) fmt_vec.extend( [fmt for _ in range(cell_len)] ) pos_left.extend([entire_left for _ in range(cell_len)]) indicator.extend(ind[:cell_len]) if len(token_id) > 256: return None return ( token_id, num_mag, num_pre, num_top, num_low, token_order, pos_row, pos_col, pos_top, pos_left, fmt_vec, indicator, ttc_label )