TUTA_table_understanding/tuta/ctc_finetune.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Tabular Cell Type Classification
"""
from ast import parse
import os
import json
import random
import argparse
import json_lines
import torch
import torch.nn as nn
import utils as ut
import reader as rdr
import tokenizer as tknr
import model.backbones as bbs
import model.act_funcs as act
import model.pretrains as ptm
from optimizer import AdamW
import sys
random.seed(0)
torch.manual_seed(0)
# %% Reader class; not used here since the raw data are pre-processed into JSON
class CTCReader(rdr.SheetReader):
    def __init__(self, args):
        self.tree_depth = args.tree_depth
        self.node_degree = args.node_degree
        self.row_size = args.row_size
        self.column_size = args.column_size
        self.target = args.target

    def get_inputs(self, hier_table):
        cell_matrix, merged_regions = hier_table["Cells"], hier_table["MergedRegions"]
        row_number, column_number = len(cell_matrix), len(cell_matrix[0])
        if (row_number > self.row_size) or (column_number > self.column_size):
            # print("Fail for extreme sizes: {} rows, {} columns ".format(row_number, column_number))
            return None
        try:
            top_root, left_root = hier_table["TopTreeRoot"], hier_table["LeftTreeRoot"]
            top_position_list, left_position_list = None, None
            top_position_list = self.read_header(top_root, merged_regions, row_number, column_number, True)
            left_position_list = self.read_header(left_root, merged_regions, row_number, column_number, False)
            if top_position_list is None:
                top_position_list = self.read_header(None, merged_regions, row_number, column_number, True)
            if left_position_list is None:
                left_position_list = self.read_header(None, merged_regions, row_number, column_number, False)
            if top_position_list is None:
                print("Top Position List is None!")
            if left_position_list is None:
                print("Left Position List is None!")
        except:
            print("Error in read header.")
            return None
        string_matrix, format_matrix = self.info_from_matrix(cell_matrix, merged_regions)
        header_rows, header_columns = hier_table["TopHeaderRowsNumber"], hier_table["LeftHeaderColumnsNumber"]
        return string_matrix, (top_position_list, left_position_list), (header_rows, header_columns), format_matrix
# %% Tokenizer Class
class CTCTokenizer(tknr.TableTokenizer):
    def no_sampling(self, token_matrix):
        sampling_mask = [[1 for _ in token_matrix[0]] for _ in token_matrix]
        return sampling_mask

    def simple_sampling(self, token_matrix, label_matrix, sample_rate=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]):
        sampling_mask = [[1 for _ in token_matrix[0]] for _ in token_matrix]
        for irow, token_row in enumerate(token_matrix):
            for icol, tokens in enumerate(token_row):
                if tknr.EMP_ID in tokens:   # the cell is empty, e.g. [102, 2], 102 means CLS
                    ctc_label = -1
                ctc_label = label_matrix[irow][icol]
                for i, sr in enumerate(sample_rate):
                    if (ctc_label == i) and (random.random() < sr):
                        sampling_mask[irow][icol] = 0
        return sampling_mask
# %% Tokenizer class that builds the table input sequences
class CTCTok(CTCTokenizer):
    def init_table_seq(self, root_context=""):
        """Initialize the table sequence with CLS_ID at the head; append context tokens if provided."""
        context_tokens, context_number = self.tokenize_text(cell_string=root_context, add_separate=False, max_cell_len=8)
        token_list = [[tknr.CLS_ID] + context_tokens]
        num_list = [[self.wordpiece_tokenizer.default_num] + context_number]
        pos_list = [(self.row_size, self.column_size, [-1] * self.tree_depth, [-1] * self.tree_depth)]
        format_list = [self.default_format]
        ind_list = [[-1] + [-2 for _ in context_tokens]]
        label_list = [[-1] + [-1 for _ in context_tokens]]
        cell_num = 1
        seq_len = len(token_list[0])
        return token_list, num_list, pos_list, format_list, ind_list, label_list, cell_num, seq_len

    def create_table_seq(self, sampling_matrix, token_matrix, number_matrix, position_lists, format_matrix, label_matrix,
                         range, max_seq_len, max_cell_length, add_separate=True):
        seqs = []
        label_dict = {}
        start_row = 0
        # split tables that exceed the maximum sequence length into smaller chunks
        while start_row < len(token_matrix):
            token_list, num_list, pos_list, format_list, ind_list, label_list, cell_num, seq_len = self.init_table_seq(root_context="")
            top_pos_list, left_pos_list = position_lists
            top, bottom, left, right = range
            icell = 0
            mark_exceed_len = False
            for irow, token_row in enumerate(token_matrix):
                if mark_exceed_len:
                    break
                if irow < start_row:
                    continue
                for icol, token_cell in enumerate(token_row):
                    if sampling_matrix[irow][icol] == 0:
                        continue
                    token_cell = token_cell[:max_cell_length]
                    cell_len = len(token_cell)
                    if cell_len + seq_len >= max_seq_len:
                        if irow > start_row:
                            start_row = max(start_row + 1, irow - 3)
                        else:
                            start_row = irow + 1
                        seqs.append([token_list, num_list, pos_list, format_list, ind_list, label_list])
                        mark_exceed_len = True
                        break
                    if (top <= irow <= bottom) and (left <= icol <= right):
                        pos_list.append((irow, icol, top_pos_list[icell], left_pos_list[icell]))
                        icell += 1
                        format_vector = []
                        for ivec, vec in enumerate(format_matrix[irow - top][icol - left]):
                            format_vector.append(min(vec, self.format_range[ivec]) / self.format_range[ivec])
                        format_list.append(format_vector)
                    else:
                        if irow < top:
                            pos_list.append((irow, icol, [-1, 31, 31, irow % 256], [-1, 31, 31, icol % 256]))
                        else:
                            pos_list.append((irow, icol, [-1, 31, 63, irow % 256], [-1, 31, 63, icol % 256]))
                        format_list.append(self.default_format)
                    token_list.append(token_cell)
                    num_list.append(number_matrix[irow][icol][:cell_len])
                    ind_list.append([cell_num * 2] * cell_len)
                    ctc_label = label_matrix[irow][icol]
                    if str(irow) + "_" + str(icol) in label_dict:
                        ctc_label = -1
                    label_list.append([-1 for _ in token_cell])
                    label_list[-1][0] = ctc_label
                    if add_separate == True:
                        ind_list[-1][0] -= 1
                        label_list[-1][1] = ctc_label
                    label_dict[str(irow) + "_" + str(icol)] = True
                    seq_len += cell_len
                    cell_num += 1
            if mark_exceed_len:
                continue
            seqs.append([token_list, num_list, pos_list, format_list, ind_list, label_list])
            start_row = len(token_matrix)
        return seqs
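# Note on the sequence layout built above: every chunk starts with [CLS]
# (indicator -1, label -1). Each sampled cell then appends its tokens with
# indicator id 2*cell_num, and with add_separate=True the cell's leading
# separator token is re-tagged with the odd id 2*cell_num - 1. The cell label
# is written onto the separator position and, when add_separate is set, onto
# the first content token as well, which is what lets CtcHead later recover
# one separator-level and one token-level prediction per labeled cell.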
annotations_mapping_dict = {"metadata": 0, "notes": 0, "data": 0, "attributes": 0, "header": 0, "derived": 0, None: -1}
def lists_to_inputs(lists, target, max_seq_len, args):
    token_list, num_list, pos_list, format_list, ind_list, label_list = lists
    token_id, num_mag, num_pre, num_top, num_low = [], [], [], [], []
    token_order, pos_row, pos_col, pos_top, pos_left = [], [], [], [], []
    format_vec, indicator, ctc_labels = [], [], []
    for tokens, num_feats, (row, col, ttop, tleft), fmt, ind, label in zip(token_list, num_list, pos_list, format_list, ind_list, label_list):
        cell_len = len(tokens)
        token_id.extend(tokens)
        num_mag.extend([f[0] for f in num_feats])
        num_pre.extend([f[1] for f in num_feats])
        num_top.extend([f[2] for f in num_feats])
        num_low.extend([f[3] for f in num_feats])
        token_order.extend([ii for ii in range(cell_len)])
        pos_row.extend([row for _ in range(cell_len)])
        pos_col.extend([col for _ in range(cell_len)])
        entire_top = ut.UNZIPS[target](zipped=ttop, node_degree=args.node_degree, total_node=args.total_node)
        pos_top.extend([entire_top for _ in range(cell_len)])
        entire_left = ut.UNZIPS[target](zipped=tleft, node_degree=args.node_degree, total_node=args.total_node)
        pos_left.extend([entire_left for _ in range(cell_len)])
        format_vec.extend([fmt for _ in range(cell_len)])
        indicator.extend(ind)
        ctc_labels.extend(label)
    if (len(token_id) > max_seq_len) or (max(ctc_labels) == -1):
        # print(":( current sequence length {} exceeds the upper bound {}".format(len(token_id), max_seq_len))
        return None
    # print(":) current sequence length {} meets the upper bound {}".format(len(token_id), max_seq_len))
    return (token_id, num_mag, num_pre, num_top, num_low, token_order, pos_row, pos_col, pos_top, pos_left, format_vec, indicator, ctc_labels)
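# Note: lists_to_inputs flattens the per-cell lists into parallel per-token
# sequences (token ids, numeric features, positions, formats, indicators,
# labels); the zipped tree coordinates are passed through ut.UNZIPS[target],
# which presumably expands them against node_degree / total_node. A chunk is
# dropped (None is returned) when it is still longer than max_seq_len or when
# it contains no labeled cell at all.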
def create_sample(table, tokenizer, args):
    # get matrices, positions, and labels from the pre-processed table
    string_matrix = table["string_matrix"]
    format_matrix = table["format_matrix"]
    position_lists = table["position_lists"]
    label_matrix = table["label_matrix"]
    range = table["range"]
    token_matrix, number_matrix = tokenizer.tokenize_string_matrix(
        string_matrix=string_matrix, add_separate=True, max_cell_len=sys.maxsize
    )
    # default all-zero sample rates, i.e. effectively no down-sampling
    sampling_matrix = tokenizer.simple_sampling(token_matrix, label_matrix)
    input_tuples = []
    seqs = tokenizer.create_table_seq(
        sampling_matrix=sampling_matrix,
        token_matrix=token_matrix,
        number_matrix=number_matrix,
        position_lists=position_lists,
        format_matrix=format_matrix,
        label_matrix=label_matrix,
        range=range, max_seq_len=args.max_seq_len, max_cell_length=args.max_cell_length
    )
    for seq in seqs:
        token_list, num_list, pos_list, format_list, ind_list, label_list = seq
        input_tuple = lists_to_inputs((token_list, num_list, pos_list, format_list, ind_list, label_list), args.target, sys.maxsize, args)
        input_tuples.append(input_tuple)
    return input_tuples
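# Note: create_sample ties the pieces together: tokenize the string matrix,
# build a (here effectively all-ones) sampling mask, cut the table into
# length-bounded sequences, and flatten each chunk with lists_to_inputs.
# Chunks that fail the length / label checks come back as None and are
# filtered out later in build_datadict.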
def create_hier_id(blobname, sheetname):
    blobname = blobname.split('.')[0]
    unique_id = blobname + '-' + sheetname
    return unique_id
def build_datadict(data_file, tokenizer, args):
    datadict = {}
    with open(data_file, "r") as fr_hier:
        tables = json.load(fr_hier)
    for key in tables.keys():
        table = tables[key]
        model_inputs = create_sample(table, tokenizer, args)
        blobname, sheetname = table["BlobName"], table["SheetName"]
        unique_id = create_hier_id(blobname, sheetname)
        for i, model_input in enumerate(model_inputs):
            if model_input is None:
                continue
            datadict[unique_id + str(i)] = model_input
    print("create a datadict of size: ", len(datadict))
    return datadict
def source_content(content_list, datadict):
    dataset = []
    for content in content_list:
        fname, sname = content["fname"], content["sname"]
        for i in range(1000):
            tmpid = create_hier_id(fname, sname) + str(i)
            if tmpid not in datadict:
                # print("Not Found Error: Blobname ({}), SheetName ({}) unable to Match!".format(fname, sname))
                continue
            if datadict[tmpid] is None:
                # print("Value Error: Data in Blobname ({}), SheetName ({}) fail at Pre-processing!".format(fname, sname))
                continue
            dataset.append(datadict[tmpid])
    return dataset
def stat_dataset(dataset, name):
    counts = [0 for _ in range(6)]
    for sample in dataset:
        ctc_labels = sample[-1]
        for lbl in ctc_labels:
            if lbl < 0 or lbl > 5:
                continue
            counts[lbl] += 1
    print("Dataset: {} counts: {}".format(name, counts))
def create_dynamic_dataset_folds(folds_path, data_file, tokenizer, args):
    print("execute create_dynamic_dataset_folds")
    with open(folds_path, "r") as fr:
        folds = json.load(fr)   # folds: a list of five fold dicts with "train"/"test" contents
    dataset_couples = [[] for _ in range(5)]
    for iepoch in range(args.dataset_num):
        dyn_datadict = build_datadict(data_file, tokenizer, args)   # sampled from the intersection between data_file and the folds
        print("iepoch:", iepoch)
        for i, fold in enumerate(folds):
            dyn_trainset = source_content(fold["train"], dyn_datadict)
            stat_dataset(dyn_trainset, "Train #{}".format(i))
            dyn_testset = source_content(fold["test"], dyn_datadict)
            stat_dataset(dyn_testset, "Test #{}".format(i))
            print("Train Size: {}, Test Size: {}".format(len(dyn_trainset), len(dyn_testset)))
            dataset_couples[i].append((dyn_trainset, dyn_testset))
    print("Build Dataset Couples: {}".format([len(dc) for dc in dataset_couples]))
    return dataset_couples
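# Note: dataset_couples is indexed by fold; dataset_couples[i] holds
# args.dataset_num (trainset, testset) pairs for fold i, each built from an
# independently re-built datadict. Pipeline later cycles through these pairs
# with dataset_couples[iepoch % args.dataset_num].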
# %% CTC Head Class
class CtcHead(nn.Module):
    def __init__(self, config):
        super(CtcHead, self).__init__()
        self.uniform_linear_tok = nn.Linear(config.hidden_size, config.hidden_size)
        self.uniform_linear_sep = nn.Linear(config.hidden_size, config.hidden_size)
        self.act_fn = act.ACT_FCN[config.hidden_act]
        self.tanh = nn.Tanh()
        self.predict_linear = nn.Linear(config.hidden_size, config.num_ctc_type)
        self.loss = nn.CrossEntropyLoss()
        self.aggregator = config.aggregator
        self.aggr_funcs = {"sum": self.token_sum, "avg": self.token_avg}

    def token_sum(self, token_states, indicator):
        """Take the sum of token encodings (not including [SEP]s) as cell encodings."""
        x_mask = indicator.unsqueeze(1)            # [batch_size, 1, seq_len]
        y_mask = x_mask.transpose(-1, -2)          # [batch_size, seq_len, 1]
        mask_matrix = y_mask.eq(x_mask).float()    # [batch_size, seq_len, seq_len]
        sum_states = mask_matrix.matmul(token_states)   # [batch_size, seq_len, hidden_size]
        return sum_states

    def token_avg(self, token_states, indicator):
        """Take the average of token encodings (not including [SEP]s) as cell encodings."""
        x_mask = indicator.unsqueeze(1)            # [batch_size, 1, seq_len]
        y_mask = x_mask.transpose(-1, -2)          # [batch_size, seq_len, 1]
        mask_matrix = y_mask.eq(x_mask).float()    # [batch_size, seq_len, seq_len]
        sum_matrix = torch.sum(mask_matrix, dim=-1)
        mask_matrix = mask_matrix.true_divide(sum_matrix.unsqueeze(-1))
        cell_states = mask_matrix.matmul(token_states)  # [batch_size, seq_len, hidden_size]
        return cell_states

    def forward(self, encoded_states, indicator, ctc_label):
        # get cell encodings from the token sequence
        cell_states = self.aggr_funcs[self.aggregator](encoded_states, indicator)
        ctc_label = ctc_label.contiguous().view(-1)
        cell_states = cell_states.contiguous().view(ctc_label.size()[0], -1)
        ctc_logits = cell_states[ctc_label > -1, :]   # [batch_total_cell_num, hidden_size]
        ctc_label = ctc_label[ctc_label > -1]
        # separator-based prediction
        sep_logits = self.uniform_linear_sep(ctc_logits[0::2, :])
        sep_logits = self.tanh(sep_logits)
        sep_logits = self.predict_linear(sep_logits)
        sep_predict = sep_logits.argmax(dim=-1)
        sep_labels = ctc_label[0::2]
        sep_loss = self.loss(sep_logits, sep_labels)
        # token-aggregation-based prediction
        tok_logits = self.uniform_linear_tok(ctc_logits[1::2, :])
        tok_logits = self.tanh(tok_logits)
        tok_logits = self.predict_linear(tok_logits)
        tok_predict = tok_logits.argmax(dim=-1)
        tok_labels = ctc_label[1::2]   # [batch-variant number of labeled cells]
        tok_loss = self.loss(tok_logits, tok_labels)   # scalar
        return (sep_loss, sep_predict, sep_labels), (tok_loss, tok_predict, tok_labels)
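# Illustration (hypothetical values): for indicator = [[-1, 1, 2, 2, 2]],
# mask_matrix in token_sum/token_avg is 1 wherever two positions share an
# indicator id, so position 1 keeps only its own (separator) state while
# positions 2-4 pool the cell's content tokens. After keeping only positions
# whose label is > -1, each labeled cell contributes two consecutive rows
# (separator first, aggregated content second), hence the [0::2] / [1::2]
# slicing in forward().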
# %% Model Architecture Class
class TUTAForCTC(nn.Module):
    def __init__(self, config):
        super(TUTAForCTC, self).__init__()
        self.backbone = bbs.BACKBONES[config.target](config)
        self.ctc_head = CtcHead(config)

    def forward(self, token_id, num_mag, num_pre, num_top, num_low,
                token_order, pos_row, pos_col, pos_top, pos_left, format_vec,
                indicator, ctc_label):
        encoded_states = self.backbone(token_id, num_mag, num_pre, num_top, num_low,
                                       token_order, pos_row, pos_col, pos_top, pos_left, format_vec, indicator)
        sep_triple, tok_triple = self.ctc_head(encoded_states, indicator, ctc_label)
        return sep_triple, tok_triple
# %% Training and testing pipeline
def Pipeline(args, model, dataset_couples, no_decay=['bias', 'gamma', 'beta']):
    def evaluate(args, testset):
        # print("Start Evaluation with {} Instances: ".format(len(testset)))
        model.eval()
        sep_confusion_matrix = [[0 for _ in range(args.num_ctc_type)] for _ in range(args.num_ctc_type)]
        tok_confusion_matrix = [[0 for _ in range(args.num_ctc_type)] for _ in range(args.num_ctc_type)]
        for i, tensors in enumerate(
            ut.load_dataset_batch_withpad(
                dataset=testset,
                batch_size=args.batch_size,
                defaults=[0, 11, 11, 11, 11, 0, 256, 256, args.default_tree_position, args.default_tree_position, args.default_format, 0, -1],
                device_id=args.device_id
            )
        ):
            with torch.no_grad():
                token_id, num_mag, num_pre, num_top, num_low, token_order, pos_row, pos_col, pos_top, pos_left, fmt_vec, ind, ctc = tensors
                (_, sep_pred, sep_gold), (_, tok_pred, tok_gold) = model(
                    token_id=token_id, num_mag=num_mag, num_pre=num_pre, num_top=num_top, num_low=num_low,
                    token_order=token_order, pos_row=pos_row, pos_col=pos_col, pos_top=pos_top, pos_left=pos_left,
                    format_vec=fmt_vec.float(), indicator=ind, ctc_label=ctc
                )
                for spd, sgd in zip(sep_pred.tolist(), sep_gold.tolist()):
                    sep_confusion_matrix[spd][sgd] += 1
                for tpd, tgd in zip(tok_pred.tolist(), tok_gold.tolist()):
                    tok_confusion_matrix[tpd][tgd] += 1
        # compute precision, recall, and F1 from the confusion matrices
        sep_precision, sep_recall = [], []
        tok_precision, tok_recall = [], []
        for iclass in range(args.num_ctc_type):
            class_sep_precision = sep_confusion_matrix[iclass][iclass] / (sum(sep_confusion_matrix[iclass]) + 1e-6)
            sep_precision.append(class_sep_precision)
            class_sep_recall = sep_confusion_matrix[iclass][iclass] / (sum([line[iclass] for line in sep_confusion_matrix]) + 1e-6)
            sep_recall.append(class_sep_recall)
            class_tok_precision = tok_confusion_matrix[iclass][iclass] / (sum(tok_confusion_matrix[iclass]) + 1e-6)
            tok_precision.append(class_tok_precision)
            class_tok_recall = tok_confusion_matrix[iclass][iclass] / (sum([line[iclass] for line in tok_confusion_matrix]) + 1e-6)
            tok_recall.append(class_tok_recall)
        sep_f1 = [(2 * p * r) / (p + r + 1e-6) for p, r in zip(sep_precision, sep_recall)]
        print("[SEP] f1: ", [round(value, 3) for value in sep_f1])
        tok_f1 = [(2 * p * r) / (p + r + 1e-6) for p, r in zip(tok_precision, tok_recall)]
        print("[TOK] f1: ", [round(value, 3) for value in tok_f1], sum(tok_f1) / 6)
        if args.sep_or_tok != 0:
            return [s for s, t in zip(sep_f1, tok_f1)]
        else:
            return [t for s, t in zip(sep_f1, tok_f1)]

    # do training
    param_optimizer = list(model.named_parameters())
    print("tuning all of the model parameters (backbone + ctc_head)")
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, correct_bias=False)
    total_sep_loss, total_tok_loss = 0.0, 0.0
    best_result = [0.0] * args.num_ctc_type
    early_stopping_count = 0
    for iepoch in range(args.epochs_num):
        print(len(dataset_couples))
        trainset, testset = dataset_couples[iepoch % args.dataset_num]
        random.shuffle(trainset)
        print("Start Training", iepoch, args.report_steps)
        model.train()
        for ii, tensors in enumerate(
            ut.load_dataset_batch_withpad(
                dataset=trainset,
                batch_size=args.batch_size,
                defaults=[0, 11, 11, 11, 11, 0, 256, 256, args.default_tree_position, args.default_tree_position, args.default_format, 0, -1],
                device_id=args.device_id
            )
        ):
            model.zero_grad()
            token_id, num_mag, num_pre, num_top, num_low, token_order, pos_row, pos_col, pos_top, pos_left, fmt_vec, ind, ctc = tensors
            (sep_loss, _, _), (tok_loss, _, _) = model(
                token_id=token_id, num_mag=num_mag, num_pre=num_pre, num_top=num_top, num_low=num_low,
                token_order=token_order, pos_row=pos_row, pos_col=pos_col, pos_top=pos_top, pos_left=pos_left,
                format_vec=fmt_vec.float(), indicator=ind, ctc_label=ctc
            )
            loss = sep_loss * args.sep_weight + tok_loss * (1. - args.sep_weight)
            total_sep_loss += sep_loss.item()
            total_tok_loss += tok_loss.item()
            if (ii + 1) % args.report_steps == 0:
                print("Epoch id: {}, Training steps: {}, Avg loss: [SEP] {:.3f}, [TOK] {:.3f}".format(
                    iepoch, ii + 1, total_sep_loss / args.report_steps, total_tok_loss / args.report_steps))
                total_sep_loss, total_tok_loss = 0.0, 0.0
            loss.backward()
            optimizer.step()
        result = evaluate(args, testset)
        if sum(result) >= sum(best_result):
            best_result = result
            ut.save_model(model, args.output_model_path)
        else:
            early_stopping_count += 1
            if early_stopping_count > args.early_stopping_bound:
                break
    return result
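# Note: evaluate() returns one F1 score per cell type, taken from the
# token-aggregated head by default and from the separator head when
# --sep_or_tok is non-zero. The checkpoint with the best summed F1 is saved
# to args.output_model_path, while the list handed back to main() is the one
# from the final evaluated epoch.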
# %% Main Procedure
def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # i/o paths
    parser.add_argument("--folds_path", type=str, default="example/folds_saus5.json", help="Path of the split folds content json.")
    parser.add_argument("--data_file", type=str, default="deex.json", help="Path of the pre-processed hierarchical table json file.")
    # add train/dev/test files
    parser.add_argument("--pretrained_model_path", type=str, default="sqall-2-512-tab-1007.bin-1000000", help="Path of the pretrained bert/ts model.")
    parser.add_argument("--output_model_path", type=str, default="ctc_out.bin", help="Fine-tuned model path.")
    # model configurations
    parser.add_argument("--vocab_path", type=str, default="./vocab/bert_vocab.txt", help="Path of the vocabulary file.")
    parser.add_argument("--context_repo_path", type=str, default="./vocab/context_repo_init.txt", help="TXT of pre-collected context pieces.")
    parser.add_argument("--cellstr_repo_path", type=str, default="./vocab/cellstr_repo_init.txt", help="TXT of pre-collected cell strings.")
    parser.add_argument("--hidden_size", type=int, default=768, help="Size of the hidden states.")
    parser.add_argument("--intermediate_size", type=int, default=3072, help="Size of the intermediate layer.")
    parser.add_argument("--magnitude_size", type=int, default=10, help="Max magnitude of numeric values.")
    parser.add_argument("--precision_size", type=int, default=10, help="Max precision of numeric values.")
    parser.add_argument("--top_digit_size", type=int, default=10, help="Most significant digit from '0' to '9'.")
    parser.add_argument("--low_digit_size", type=int, default=10, help="Least significant digit from '0' to '9'.")
    parser.add_argument("--max_cell_length", type=int, default=8, help="Maximum number of tokens in one cell string.")
    parser.add_argument("--row_size", type=int, default=2560, help="Max number of rows in a table.")
    parser.add_argument("--column_size", type=int, default=2560, help="Max number of columns in a table.")
    parser.add_argument("--tree_depth", type=int, default=4, help="Maximum depth of the top & left header trees.")
    parser.add_argument("--node_degree", type=str, default="32,32,64,256", help="Maximum number of children of each tree node.")
    parser.add_argument("--attention_distance", type=int, default=2, help="Maximum distance for attention visibility.")
    parser.add_argument("--attention_step", type=int, default=0, help="Step size of attention distance to add for each layer.")
    parser.add_argument("--num_attention_heads", type=int, default=12, help="Number of the attention heads.")
    parser.add_argument("--num_encoder_layers", type=int, default=12, help="Number of the encoding layers.")
    parser.add_argument("--hidden_dropout_prob", type=float, default=0.1, help="Dropout probability for hidden layers.")
    parser.add_argument("--attention_dropout_prob", type=float, default=0.1, help="Dropout probability for attention.")
    parser.add_argument("--layer_norm_eps", type=float, default=1e-6)
    parser.add_argument("--hidden_act", type=str, default="gelu", help="Activation function for hidden layers.")
    parser.add_argument("--learning_rate", type=float, default=8e-6, help="Learning rate during fine-tuning.")
    parser.add_argument("--max_seq_len", type=int, default=512, help="Maximum length of the table sequence.")
    parser.add_argument("--max_cell_num", type=int, default=256, help="Maximum cell number.")  # useful ??
    parser.add_argument("--text_threshold", type=float, default=0.5, help="Probability threshold to sample text in the data region.")
    parser.add_argument("--value_threshold", type=float, default=0.1, help="Probability threshold to sample values in the data region.")
    parser.add_argument("--clc_rate", type=float, default=0.3)
    parser.add_argument("--wcm_rate", type=float, default=0.3, help="Proportion of masked cells doing whole-cell-masking.")
    parser.add_argument("--add_separate", type=bool, default=True, help="Whether to add [SEP] as the aggregate cell representation.")
    parser.add_argument("--num_ctc_type", type=int, default=6, help="Number of cell types for classification.")
    parser.add_argument("--attn_method", type=str, default="add", choices=["max", "add"])
    parser.add_argument("--hier_or_flat", type=str, default="both", choices=["hier", "flat", "both"])
    parser.add_argument("--org_or_weigh", type=str, default="original", choices=["original", "weighted"])
    parser.add_argument("--num_format_feature", type=int, default=11)
    parser.add_argument("--sep_or_tok", type=int, default=0, choices=[0, 1])
    parser.add_argument("--sep_weight", type=float, default=0.0, help="Weight to be multiplied on the SEP loss.")
    parser.add_argument("--aggregator", type=str, default="sum", choices=["sum", "avg"], help="Aggregation method from token to cell.")
    # model choices
    parser.add_argument("--target", type=str, default="tuta", help="Pre-training objectives.")
    # training options
    parser.add_argument("--batch_size", type=int, default=4, help="Size of the input batch.")
    parser.add_argument("--report_steps", type=int, default=200, help="Number of training steps between loss reports.")
    parser.add_argument("--epochs_num", type=int, default=30, help="Number of epochs for fine-tuning.")
    parser.add_argument("--dataset_num", type=int, default=1, help="Number of distinct data samplings.")
    parser.add_argument("--early_stopping_bound", type=int, default=100)
    parser.add_argument("--device_id", type=int, default=None, help="Designated GPU id if not None.")
    args = parser.parse_args()

    args.node_degree = [int(degree) for degree in args.node_degree.split(',')]
    args.total_node = sum(args.node_degree)
    args.default_tree_position = [args.total_node for _ in args.node_degree]
    print("node degree: ", args.node_degree)

    # reader = CTCReader(args)
    tokenizer = CTCTok(args)
    args.vocab_size = len(tokenizer.vocab)
    args.default_format = tokenizer.default_format
    print("Default Format: ", args.default_format)
    model = TUTAForCTC(args)
    if args.device_id is not None:
        print("Using device {} for fine-tuning".format(args.device_id))
        model.cuda(args.device_id)

    # create datasets for five folds
    dataset_couples = create_dynamic_dataset_folds(args.folds_path, args.data_file, tokenizer, args)
    folds_results = []
    for iset, fold_couples in enumerate(dataset_couples):   # run the pipeline on every fold
        print("\nGo On to Couple #{}".format(iset + 1))
        print("fold_length:", len(dataset_couples))
        ut.init_tuta_loose(model=model, tuta_path=args.pretrained_model_path)
        f1_list = Pipeline(args, model, fold_couples)
        print("F1 List: ", [round(fl, 3) for fl in f1_list], "\n\n")
        folds_results.append(f1_list)

    # aggregate f1 results across folds
    average_f1 = []
    for ilbl in range(args.num_ctc_type):
        f1_collection = [res[ilbl] for res in folds_results]
        f1_collection_no_zero = [f1 for f1 in f1_collection if (f1 > 1e-6)]
        f1_collection_with_zero = [f1 for f1 in f1_collection]
        print(ilbl, f1_collection_with_zero)
        average_f1.append(sum(f1_collection_no_zero) / (len(f1_collection_no_zero) + 1e-6))
    print("Average F1: ", [round(af, 3) for af in average_f1])
    print("Macro Acc.: {:.4f}".format(sum(average_f1) / (len(average_f1) + 1e-6)))


if __name__ == "__main__":
    main()
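# Example invocation (paths are illustrative; adjust them to your local setup):
#   python ctc_finetune.py \
#       --folds_path example/folds_saus5.json \
#       --data_file deex.json \
#       --pretrained_model_path <path-to-pretrained-tuta-checkpoint> \
#       --output_model_path ctc_out.bin \
#       --device_id 0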