# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os, random

import numpy as np
import torch
from torch.utils.data import TensorDataset, Dataset
from itertools import groupby
from torch.utils.data import (DataLoader, RandomSampler,
                              SequentialSampler, BatchSampler)
from seqeval.metrics import classification_report, f1_score
from collections import defaultdict
from seqeval.scheme import IOB2
import re


def set_seed(seed):
    """Seed every RNG used in training for reproducible runs."""
    random.seed(seed)
    torch.cuda.cudnn_enabled = False
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    return


class NERDataset(Dataset):
    """Wraps a dict of pre-built tensors; only the keys that are present are returned per item."""

    def __init__(self, tensor_mp):
        super(NERDataset, self).__init__()
        self.data = tensor_mp
        return

    def __getitem__(self, idx):
        inputs = ()
        for key in ['idx', 'words', 'word_masks', 'word_to_piece_inds',
                    'word_to_piece_ends', 'sent_lens', 'labels',
                    'weights', 'valid_masks']:
            if (key in self.data) and (self.data[key] is not None):
                inputs = inputs + (self.data[key][idx],)
        return inputs

    def __len__(self):
        return len(self.data["words"])


class DataUtils:
    def __init__(self, tokenizer, max_seq_length):
        self.tokenizer = tokenizer
        self.max_length = max_seq_length
        self.label2idx = None
        self.idx2label = None
        return

    def read_conll(self, file_name):
        """Read a CoNLL-style file (one word/tag pair per line, blank line between
        sentences), print corpus statistics, and build the label map on the first call."""

        def is_empty_line(line_pack):
            return len(line_pack.strip()) == 0

        sent_len = 0
        ent_len = 0
        sentences = []
        labels = []
        type_mp = defaultdict(int)
        token_num = 0
        bio_err = 0
        with open(file_name, mode="r", encoding="utf-8") as fp:
            lines = [line.strip("\n") for line in fp]
        # groupby splits the token stream into sentence blocks separated by empty lines
        groups = groupby(lines, is_empty_line)
        for is_empty, pack in groups:
            if is_empty is False:
                cur_sent = []
                cur_labels = []
                prev = "O"
                for wt_pair in pack:
                    tmp = wt_pair.split("\t")
                    if len(tmp) == 1:
                        # fall back to whitespace-separated "word tag" pairs
                        w, t = tmp[0].split()
                    else:
                        w, t = tmp
                    if t != "O":
                        ent_len += 1
                    if prev == "O" and t[0] != "B" and t != "O":
                        # an I-* tag directly after "O" violates the BIO scheme
                        bio_err += 1
                    cur_sent.append(w)
                    cur_labels.append(t)
                    token_num += 1
                    prev = t
                    if t[0] == "B" or t[0] == "S":
                        type_mp[t[2:]] += 1
                sent_len += len(cur_sent)
                sentences.append(cur_sent)
                labels.append(cur_labels)
        ent_num = sum(type_mp.values())
        sent_len /= len(sentences)
        ent_len /= ent_num
        print("{}: ent {} token {} sent {}".format(file_name, ent_num, token_num, len(sentences)))
        print("avg sent len: {} avg ent len: {}".format(sent_len, ent_len))
        print("bio error {}".format(bio_err))
        print(sorted([(k, v) for k, v in type_mp.items()], key=lambda x: x[0]))
        if self.label2idx is None:
            print("get label map !!!!!!!!!!!!!!!")
            self.get_label_map(labels)
            print(self.label2idx)
        return sentences, labels

    def read_examples_from_file(self, file_name):
        """Alternative space-separated reader that also skips -DOCSTART- lines."""
        guid_index = 0
        sentences = []
        labels = []
        with open(file_name, encoding="utf-8") as f:
            cur_words = []
            cur_labels = []
            for line in f:
                line = line.strip().replace("\t", " ")
                if line.startswith("-DOCSTART-") or line == "" or line == "\n":
                    if cur_words:
                        sentences.append(cur_words)
                        labels.append(cur_labels)
                        guid_index += 1
                        cur_words = []
                        cur_labels = []
                else:
                    splits = line.split(" ")
                    cur_words.append(splits[0])
                    cur_labels.append(splits[-1].replace("\n", ""))
        if cur_words:
            sentences.append(cur_words)
            labels.append(cur_labels)
            guid_index += 1
        print("number of sentences : {}".format(len(sentences)))
        return sentences, labels

    def get_label_map(self, labels):
        """Build label <-> id mappings from the tags seen in the data."""
        labels = sorted(set([y for x in labels for y in x]))
        label2id_mp = {}
        id2label_mp = {}
        for i, tag in enumerate(labels):
            label2id_mp[tag] = i
            id2label_mp[i] = tag
        self.label2idx = label2id_mp
        self.idx2label = id2label_mp
        return

    def get_ner_dataset(self, filename):
        sentences, labels = self.read_conll(filename)
        print(f"reading {filename}")
        dataset = self.process_data(sentences, labels, mode='ner')
        return dataset

    def get_loader(self, dataset, shuffle=True, batch_size=32):
        if shuffle:
            sampler = RandomSampler(dataset)
        else:
            sampler = SequentialSampler(dataset)
        data_loader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)
        return data_loader

    def process_data(self, sentences, labels=None, mode='ner'):
        """Tokenize each word into subword pieces and record, for every word, the index
        of its first and last piece so word-level labels can be aligned with the piece
        sequence. Sentences longer than max_length are truncated."""
        sent_num = len(sentences)
        data_word = np.full((sent_num, self.max_length), self.tokenizer.pad_token_id, dtype=np.int32)
        data_mask = np.zeros((sent_num, self.max_length), dtype=np.int32)
        data_w2p_ind = np.zeros((sent_num, self.max_length), dtype=np.int32)
        data_w2p_end = np.zeros((sent_num, self.max_length), dtype=np.int32)
        data_length = np.zeros((sent_num), dtype=np.int32)
        data_labels = np.full((sent_num, self.max_length), -100, dtype=np.int32)
        for i in range(sent_num):
            cur_tokens = [self.tokenizer.cls_token]
            cur_word_to_piece_ind = []
            cur_word_to_piece_end = []
            cur_seq_len = None
            for j, word in enumerate(sentences[i]):
                tag = labels[i][j] if labels else "O"
                if j > 0:
                    word = ' ' + word
                data_labels[i][j] = self.label2idx[tag]
                word_tokens = self.tokenizer.tokenize(word)
                if len(word_tokens) == 0:
                    word_tokens = [self.tokenizer.unk_token]
                    print("error token")
                if len(cur_tokens) + len(word_tokens) + 1 > self.max_length:
                    # sentence too long: keep as many pieces of this word as fit,
                    # close the sequence with the sep token, and drop the rest
                    part_len = self.max_length - 1 - len(cur_tokens)
                    if part_len > 0:
                        cur_word_to_piece_ind.append(len(cur_tokens))
                        cur_tokens.extend(word_tokens[:part_len])
                        cur_word_to_piece_end.append(len(cur_tokens) - 1)
                    cur_tokens.append(self.tokenizer.sep_token)
                    cur_seq_len = len(cur_word_to_piece_ind)
                    break
                else:
                    cur_word_to_piece_ind.append(len(cur_tokens))
                    cur_tokens.extend(word_tokens)
                    cur_word_to_piece_end.append(len(cur_tokens) - 1)
            if cur_seq_len is None:
                # not full, append padding tokens
                assert len(cur_tokens) + 1 <= self.max_length
                cur_tokens.append(self.tokenizer.sep_token)
                cur_seq_len = len(cur_word_to_piece_ind)
            assert cur_seq_len == len(cur_word_to_piece_end)
            indexed_tokens = self.tokenizer.convert_tokens_to_ids(cur_tokens)
            while len(indexed_tokens) < self.max_length:
                indexed_tokens.append(0)
            indexed_tokens = np.array(indexed_tokens)
            while len(cur_word_to_piece_ind) < self.max_length:
                cur_word_to_piece_ind.append(0)
                cur_word_to_piece_end.append(0)
            data_word[i, :] = indexed_tokens
            data_mask[i, :len(cur_tokens)] = 1
            data_w2p_ind[i, :] = cur_word_to_piece_ind
            data_w2p_end[i, :] = cur_word_to_piece_end
            data_length[i] = cur_seq_len
        return NERDataset({'idx': torch.arange(0, len(sentences)),
                           'words': torch.LongTensor(data_word),
                           'word_masks': torch.tensor(data_mask),
                           'word_to_piece_inds': torch.LongTensor(data_w2p_ind),
                           'word_to_piece_ends': torch.LongTensor(data_w2p_end),
                           'sent_lens': torch.LongTensor(data_length),
                           'labels': torch.LongTensor(data_labels),
                           'ori_sents': sentences,
                           'ori_labels': labels})

    def performance_report(self, y_pred_id, y_true_id, y_pred_probs=None, print_log=False):
        # convert id to tags
        y_pred_id = y_pred_id.detach().cpu().tolist()
        y_true_id = y_true_id.detach().cpu().tolist()
        if y_pred_probs is not None:
            y_pred_probs = y_pred_probs.detach().cpu().tolist()
        y_true = []
        y_pred = []
        y_probs = []
        sid = 0
        for tids, pids in zip(y_true_id, y_pred_id):
            y_true.append([])
            y_pred.append([])
            y_probs.append([])
            for k in range(len(tids)):
                if tids[k] == -100:
                    break
                y_true[-1].append(self.idx2label[tids[k]])
                if k < len(pids):
                    y_pred[-1].append(self.idx2label[pids[k]])
                    y_probs[-1].append(y_pred_probs[sid][k] if y_pred_probs is not None else None)
                else:
                    y_pred[-1].append("O")
                    y_probs[-1].append(None)
            sid += 1
        if print_log:
            print("============conlleval mode==================")
            report = classification_report(y_true, y_pred, digits=4)
            print(report)
        f1 = f1_score(y_true, y_pred, scheme=IOB2)
        if y_pred_probs is not None:
            return y_pred, y_true, y_probs, f1
        return y_pred, y_true, f1


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = None
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        if self.sum is None:
            self.sum = val * n
        else:
            self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
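
# ------------------------------------------------------------------------------
# Usage sketch (not part of the original utilities): a minimal, hedged example of
# how DataUtils is typically driven. It assumes a Hugging Face tokenizer exposing
# cls_token/sep_token/pad_token (e.g. "roberta-base", which matches the leading-space
# word handling in process_data) and a CoNLL-formatted file; the checkpoint name and
# "data/train.txt" path below are illustrative placeholders, not fixed by this module.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    set_seed(42)
    tokenizer = AutoTokenizer.from_pretrained("roberta-base")  # assumed checkpoint
    utils = DataUtils(tokenizer, max_seq_length=128)
    train_set = utils.get_ner_dataset("data/train.txt")  # hypothetical path
    train_loader = utils.get_loader(train_set, shuffle=True, batch_size=32)
    for batch in train_loader:
        # NERDataset.__getitem__ returns, in order: idx, words, word_masks,
        # word_to_piece_inds, word_to_piece_ends, sent_lens, labels
        idx, words, word_masks, w2p_inds, w2p_ends, sent_lens, label_ids = batch
        print(words.shape, label_ids.shape)  # (batch, max_len) each; -100 marks padded labels
        break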