Code for the DecomposedMetaNER paper at Findings of the ACL 2022 (#22)
Co-authored-by: Huiqiang Jiang <hjiang@microsoft.com>
This commit is contained in:
Parent
9c8eedc5af
Commit
dc14a4f1bf
|
@ -0,0 +1,142 @@
|
|||
# Decomposed Meta-Learning for Few-Shot Named Entity Recognition
|
||||
|
||||
This repository contains the open-sourced official implementation of the paper:
|
||||
|
||||
[Decomposed Meta-Learning for Few-Shot Named Entity Recognition](https://arxiv.org/abs/2204.05751) (Findings of the ACL 2022).
|
||||
|
||||
_Tingting Ma, Huiqiang Jiang, Qianhui Wu, Tiejun Zhao, and Chin-Yew Lin_
|
||||
|
||||
If you find this repo helpful, please cite the following paper:
|
||||
|
||||
```bibtex
|
||||
@inproceedings{ma2022decomposedmetaner,
|
||||
title="{D}ecomposed {M}eta-{L}earning for {F}ew-{S}hot {N}amed {E}ntity {R}ecognition",
|
||||
author={Tingting Ma and Huiqiang Jiang and Qianhui Wu and Tiejun Zhao and Chin-Yew Lin},
|
||||
year={2022},
|
||||
month={aug},
|
||||
booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics},
|
||||
publisher={Association for Computational Linguistics},
|
||||
url={https://arxiv.org/abs/2204.05751},
|
||||
}
|
||||
```
|
||||
|
||||
For any questions/comments, please feel free to open GitHub issues.
|
||||
|
||||
## 🎥 Overview
|
||||
|
||||
Few-shot named entity recognition (NER) systems aim at recognizing novel-class named entities based on only a few labeled examples. In this paper, we present a decomposed meta-learning approach which addresses the problem of few-shot NER by sequentially tackling few-shot span detection and few-shot entity typing using meta-learning. In particular, we formulate few-shot span detection as a sequence labeling problem and train the span detector with the model-agnostic meta-learning (MAML) algorithm to find a good model parameter initialization that can quickly adapt to new entity classes. For few-shot entity typing, we propose MAML-ProtoNet, i.e., MAML-enhanced prototypical networks, to find a good embedding space that can better distinguish text span representations from different entity classes. Extensive experiments on various benchmarks show that our approach achieves superior performance over prior methods.
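
As a rough illustration of the typing stage, the sketch below builds one prototype per entity class from support-span embeddings and scores query spans by scaled cosine similarity. It is a minimal, self-contained example with made-up tensors, not the repository's implementation (the real code lives in `modeling.py` and additionally handles span masks and the `similar_k` scaling configured on the command line):

```python
import torch
import torch.nn.functional as F

def build_prototypes(support_emb, support_labels, n_types):
    """Average the support-span embeddings of each entity type into one prototype."""
    return torch.stack(
        [support_emb[support_labels == t].mean(dim=0) for t in range(n_types)]
    )  # [n_types, hidden]

def type_scores(query_emb, prototypes, similar_k=10.0):
    """Scaled cosine similarity between every query span and every class prototype."""
    q = F.normalize(query_emb, dim=-1)   # [n_query, hidden]
    p = F.normalize(prototypes, dim=-1)  # [n_types, hidden]
    return similar_k * q @ p.t()         # [n_query, n_types]

# Toy episode: 3 entity types, 6 support spans, 2 query spans, hidden size 8.
support_emb, support_labels = torch.randn(6, 8), torch.tensor([0, 0, 1, 1, 2, 2])
query_emb = torch.randn(2, 8)
scores = type_scores(query_emb, build_prototypes(support_emb, support_labels, 3))
print(scores.argmax(dim=-1))  # predicted type index for each query span
```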
|
||||
|
||||
![image](./images/framework_DecomposedMetaNER.png)
|
||||
|
||||
## 🎯 Quick Start
|
||||
|
||||
### Requirements
|
||||
|
||||
- Python 3.9
|
||||
- PyTorch 1.9.0+cu111
|
||||
- [HuggingFace Transformers 4.10.0](https://github.com/huggingface/transformers)
|
||||
- Few-NERD dataset
|
||||
|
||||
Other pip packages are listed in `requirements.txt`.
|
||||
|
||||
```bash
|
||||
pip3 install -r requirements.txt
|
||||
```
|
||||
|
||||
The code may work with other Python and PyTorch versions; however, all experiments were run in the environment above.
|
||||
|
||||
### Train and Evaluate
|
||||
|
||||
For _Linux_ machines,
|
||||
|
||||
```bash
|
||||
bash scripts/run.sh
|
||||
```
|
||||
|
||||
For _Windows_ machines,
|
||||
|
||||
```cmd
|
||||
call scripts\run.bat
|
||||
```
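
The run scripts ultimately call `main.py`. For reference, a direct invocation looks roughly like the following; the flag values are illustrative (mostly the defaults from `main.py`'s argument parser) and may differ from what `scripts/run.sh` actually passes:

```bash
python3 main.py --dataset FewNERD --mode intra --N 5 --K 1 \
  --tagging_scheme BIO --similar_k 10 --inner_similar_k 10 \
  --max_meta_steps 5001 --eval_every_meta_steps 500 --name my-run
```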
|
||||
|
||||
If you only want to run prediction with the trained models, you can download them from [Azure blob](https://kcpapers.blob.core.windows.net/dmn-findings-acl-2022/dmn-all.zip).
|
||||
|
||||
Put the `models-{N}-{K}-{mode}` folder in the `papers/DecomposedMetaNER` directory.
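
For example, assuming `wget` and `unzip` are available and that the archive unpacks into those `models-{N}-{K}-{mode}` folders:

```bash
wget https://kcpapers.blob.core.windows.net/dmn-findings-acl-2022/dmn-all.zip
unzip dmn-all.zip -d papers/DecomposedMetaNER/
```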
|
||||
|
||||
For _Linux_ machines,
|
||||
|
||||
```bash
|
||||
bash scripts/run_pred.sh
|
||||
```
|
||||
|
||||
For _Windows_ machines,
|
||||
|
||||
```cmd
|
||||
call scripts\run_pred.bat
|
||||
```
|
||||
|
||||
## 🍯 Datasets
|
||||
|
||||
We use the following widely-used benchmark datasets for the experiments:
|
||||
|
||||
- Few-NERD [Ding et al., 2021](https://aclanthology.org/2021.acl-long.248) for few-shot NER in the Intra and Inter modes;
|
||||
- Cross-Dataset [Hou et al., 2020](https://www.aclweb.org/anthology/2020.acl-main.128) for cross-domain few-shot NER over four domains.
|
||||
|
||||
The Few-NERD dataset is annotated with 8 coarse-grained and 66 fine-grained entity types. The Cross-Dataset benchmark is annotated with 4, 11, 6, and 18 entity types in its different domains. Each dataset is split into training, dev, and test sets.
|
||||
|
||||
All datasets use the N-way K~2K-shot setting and the IO tagging scheme. We do not redistribute any data in this repo due to licensing; you can download the datasets from their respective websites: [Few-NERD](https://cloud.tsinghua.edu.cn/f/8483dc1a34da4a34ab58/?dl=1) and [Cross-Dataset](https://atmahou.github.io/attachments/ACL2020data.zip).
|
||||
Then place them in the repository root: `./`.
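
For example (illustrative commands; the archives are assumed to extract into the `episode-data` and `ACL2020data` folders that `main.py` expects):

```bash
wget -O episode-data.zip "https://cloud.tsinghua.edu.cn/f/8483dc1a34da4a34ab58/?dl=1"
wget -O ACL2020data.zip "https://atmahou.github.io/attachments/ACL2020data.zip"
unzip episode-data.zip && unzip ACL2020data.zip
```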
|
||||
|
||||
## 📋 Results
|
||||
|
||||
We report the few-shot NER results of the proposed DecomposedMetaNER in the 1~2-shot and 5~10-shot settings, alongside those reported by prior state-of-the-art methods.
|
||||
|
||||
### Few-NERD ACL Version
|
||||
|
||||
Available from [Ding et al., 2021](https://cloud.tsinghua.edu.cn/f/8483dc1a34da4a34ab58/?dl=1).
|
||||
|
||||
| | Intra 5-1 | Intra 10-1 | Intra 5-5 | Intra 10-5 | Inter 5-1 | Inter 10-1 | Inter 5-5 | Inter 10-5 |
|
||||
| -------------------------------------------------------- | ---------------- | ---------------- | ---------------- | ---------------- | ---------------- | ---------------- | ---------------- | ---------------- |
|
||||
| [ProtoBERT](https://aclanthology.org/2021.acl-long.248) | 23.45 ± 0.92 | 19.76 ± 0.59 | 41.93 ± 0.55 | 34.61 ± 0.59 | 44.44 ± 0.11 | 39.09 ± 0.87 | 58.80 ± 1.42 | 53.97 ± 0.38 |
|
||||
| [NNShot](https://aclanthology.org/2021.acl-long.248) | 31.01 ± 1.21 | 21.88 ± 0.23 | 35.74 ± 2.36 | 27.67 ± 1.06 | 54.29 ± 0.40 | 46.98 ± 1.96 | 50.56 ± 3.33 | 50.00 ± 0.36 |
|
||||
| [StructShot](https://aclanthology.org/2021.acl-long.248) | 35.92 ± 0.69 | 25.38 ± 0.84 | 38.83 ± 1.72 | 26.39 ± 2.59 | 57.33 ± 0.53 | 49.46 ± 0.53 | 57.16 ± 2.09 | 49.39 ± 1.77 |
|
||||
| [CONTAINER](https://arxiv.org/abs/2109.07589) | 40.43 | 33.84 | 53.70 | 47.49 | 55.95 | 48.35 | 61.83 | 57.12 |
|
||||
| [ESD](https://arxiv.org/abs/2109.13023v1) | 41.44 ± 1.16 | 32.29 ± 1.10 | 50.68 ± 0.94 | 42.92 ± 0.75 | 66.46 ± 0.49 | 59.95 ± 0.69 | **74.14** ± 0.80 | 67.91 ± 1.41 |
|
||||
| **DecomposedMetaNER** | **52.04** ± 0.44 | **43.50** ± 0.59 | **63.23** ± 0.45 | **56.84** ± 0.14 | **68.77** ± 0.24 | **63.26** ± 0.40 | 71.62 ± 0.16 | **68.32** ± 0.10 |
|
||||
|
||||
### Few-NERD arXiv Version
|
||||
|
||||
Available from [Ding et al., 2021](https://cloud.tsinghua.edu.cn/f/0e38bd108d7b49808cc4/?dl=1).
|
||||
|
||||
| | Intra 5-1 | Intra 10-1 | Intra 5-5 | Intra 10-5 | Inter 5-1 | Inter 10-1 | Inter 5-5 | Inter 10-5 |
|
||||
| ---------------------------------------------- | ---------------- | ---------------- | ---------------- | ---------------- | ---------------- | ---------------- | ---------------- | ---------------- |
|
||||
| [ProtoBERT](https://arxiv.org/abs/2105.07464) | 20.76 ± 0.84 | 15.05 ± 0.44 | 42.54 ± 0.94 | 35.40 ± 0.13 | 38.83 ± 1.49 | 32.45 ± 0.79 | 58.79 ± 0.44 | 52.92 ± 0.37 |
|
||||
| [NNShot](https://arxiv.org/abs/2105.07464) | 25.78 ± 0.91 | 18.27 ± 0.41 | 36.18 ± 0.79 | 27.38 ± 0.53 | 47.24 ± 1.00 | 38.87 ± 0.21 | 55.64 ± 0.63 | 49.57 ± 2.73 |
|
||||
| [StructShot](https://arxiv.org/abs/2105.07464) | 30.21 ± 0.90 | 21.03 ± 1.13 | 38.00 ± 1.29 | 26.42 ± 0.60 | 51.88 ± 0.69 | 43.34 ± 0.10 | 57.32 ± 0.63 | 49.57 ± 3.08 |
|
||||
| [ESD](https://arxiv.org/abs/2109.13023) | 36.08 ± 1.60 | 30.00 ± 0.70 | 52.14 ± 1.50 | 42.15 ± 2.60 | 59.29 ± 1.25 | 52.16 ± 0.79 | 69.06 ± 0.80 | 64.00 ± 0.43 |
|
||||
| **DecomposedMetaNER** | **49.48** ± 0.85 | **42.84** ± 0.46 | **62.92** ± 0.57 | **57.31** ± 0.25 | **64.75** ± 0.35 | **58.65** ± 0.43 | **71.49** ± 0.47 | **68.11** ± 0.05 |
|
||||
|
||||
### Cross-Dataset
|
||||
|
||||
| | News 1-shot | Wiki 1-shot | Social 1-shot | Mixed 1-shot | News 5-shot | Wiki 5-shot | Social 5-shot | Mixed 5-shot |
|
||||
| ---------------------------------------------------------------------- | ---------------- | ---------------- | --------------- | ---------------- | ---------------- | ---------------- | ---------------- | ---------------- |
|
||||
| [TransferBERT](https://www.aclweb.org/anthology/2020.acl-main.128) | 4.75 ± 1.42 | 0.57 ± 0.32 | 2.71 ± 0.72 | 3.46 ± 0.54 | 15.36 ± 2.81 | 3.62 ± 0.57 | 11.08 ± 0.57 | 35.49 ± 7.60 |
|
||||
| [SimBERT](https://www.aclweb.org/anthology/2020.acl-main.128) | 19.22 ± 0.00 | 6.91 ± 0.00 | 5.18 ± 0.00 | 13.99 ± 0.00 | 32.01 ± 0.00 | 10.63 ± 0.00 | 8.20 ± 0.00 | 21.14 ± 0.00 |
|
||||
| [Matching Network](https://www.aclweb.org/anthology/2020.acl-main.128) | 19.50 ± 0.35 | 4.73 ± 0.16 | 17.23 ± 2.75 | 15.06 ± 1.61 | 19.85 ± 0.74 | 5.58 ± 0.23 | 6.61 ± 1.75 | 8.08 ± 0.47 |
|
||||
| [ProtoBERT](https://www.aclweb.org/anthology/2020.acl-main.128) | 32.49 ± 2.01 | 3.89 ± 0.24 | 10.68 ± 1.40 | 6.67 ± 0.46 | 50.06 ± 1.57 | 9.54 ± 0.44 | 17.26 ± 2.65 | 13.59 ± 1.61 |
|
||||
| [L-TapNet+CDT](https://www.aclweb.org/anthology/2020.acl-main.128) | 44.30 ± 3.15 | 12.04 ± 0.65 | 20.80 ± 1.06 | 15.17 ± 1.25 | 45.35 ± 2.67 | 11.65 ± 2.34 | 23.30 ± 2.80 | 20.95 ± 2.81 |
|
||||
| **DecomposedMetaNER** | **46.09** ± 0.44 | **17.54** ± 0.98 | **25.1** ± 0.24 | **34.13** ± 0.92 | **58.18** ± 0.87 | **31.36** ± 0.91 | **31.02** ± 1.28 | **45.55** ± 0.90 |
|
||||
|
||||
## Contributing
|
||||
|
||||
This project welcomes contributions and suggestions. Most contributions require you to agree to a
|
||||
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
|
||||
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
|
||||
|
||||
When you submit a pull request, a CLA bot will automatically determine whether you need to provide
|
||||
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
|
||||
provided by the bot. You will only need to do this once across all repos using our CLA.
|
||||
|
||||
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
||||
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
|
||||
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
|
|
@ -0,0 +1,87 @@
|
|||
{
|
||||
"O": [
|
||||
"O"
|
||||
],
|
||||
"art": [
|
||||
"art-broadcastprogram",
|
||||
"art-film",
|
||||
"art-music",
|
||||
"art-other",
|
||||
"art-painting",
|
||||
"art-writtenart"
|
||||
],
|
||||
"building": [
|
||||
"building-airport",
|
||||
"building-hospital",
|
||||
"building-hotel",
|
||||
"building-library",
|
||||
"building-other",
|
||||
"building-restaurant",
|
||||
"building-sportsfacility",
|
||||
"building-theater"
|
||||
],
|
||||
"event": [
|
||||
"event-attack/battle/war/militaryconflict",
|
||||
"event-disaster",
|
||||
"event-election",
|
||||
"event-other",
|
||||
"event-protest",
|
||||
"event-sportsevent"
|
||||
],
|
||||
"location": [
|
||||
"location-GPE",
|
||||
"location-bodiesofwater",
|
||||
"location-island",
|
||||
"location-mountain",
|
||||
"location-other",
|
||||
"location-park",
|
||||
"location-road/railway/highway/transit"
|
||||
],
|
||||
"organization": [
|
||||
"organization-company",
|
||||
"organization-education",
|
||||
"organization-government/governmentagency",
|
||||
"organization-media/newspaper",
|
||||
"organization-other",
|
||||
"organization-politicalparty",
|
||||
"organization-religion",
|
||||
"organization-showorganization",
|
||||
"organization-sportsleague",
|
||||
"organization-sportsteam"
|
||||
],
|
||||
"other": [
|
||||
"other-astronomything",
|
||||
"other-award",
|
||||
"other-biologything",
|
||||
"other-chemicalthing",
|
||||
"other-currency",
|
||||
"other-disease",
|
||||
"other-educationaldegree",
|
||||
"other-god",
|
||||
"other-language",
|
||||
"other-law",
|
||||
"other-livingthing",
|
||||
"other-medical"
|
||||
],
|
||||
"person": [
|
||||
"person-actor",
|
||||
"person-artist/author",
|
||||
"person-athlete",
|
||||
"person-director",
|
||||
"person-other",
|
||||
"person-politician",
|
||||
"person-scholar",
|
||||
"person-soldier"
|
||||
],
|
||||
"product": [
|
||||
"product-airplane",
|
||||
"product-car",
|
||||
"product-food",
|
||||
"product-game",
|
||||
"product-other",
|
||||
"product-ship",
|
||||
"product-software",
|
||||
"product-train",
|
||||
"product-weapon"
|
||||
]
|
||||
}
|
|
@ -0,0 +1,52 @@
|
|||
{
|
||||
"O": [
|
||||
"O"
|
||||
],
|
||||
"Wiki": [
|
||||
"abstract",
|
||||
"animal",
|
||||
"event",
|
||||
"object",
|
||||
"organization",
|
||||
"person",
|
||||
"place",
|
||||
"plant",
|
||||
"quantity",
|
||||
"substance",
|
||||
"time"
|
||||
],
|
||||
"SocialMedia": [
|
||||
"corporation",
|
||||
"creative-work",
|
||||
"group",
|
||||
"location",
|
||||
"person",
|
||||
"product"
|
||||
],
|
||||
"News": [
|
||||
"LOC",
|
||||
"MISC",
|
||||
"ORG",
|
||||
"PER"
|
||||
],
|
||||
"OntoNotes": [
|
||||
"CARDINAL",
|
||||
"DATE",
|
||||
"EVENT",
|
||||
"FAC",
|
||||
"GPE",
|
||||
"LANGUAGE",
|
||||
"LAW",
|
||||
"LOC",
|
||||
"MONEY",
|
||||
"NORP",
|
||||
"ORDINAL",
|
||||
"ORG",
|
||||
"PERCENT",
|
||||
"PERSON",
|
||||
"PRODUCT",
|
||||
"QUANTITY",
|
||||
"TIME",
|
||||
"WORK_OF_ART"
|
||||
]
|
||||
}
|
Binary file not shown.
After Width: | Height: | Size: 345 KiB
|
@ -0,0 +1,803 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
import joblib
|
||||
from modeling import BertForTokenClassification_
|
||||
from transformers import CONFIG_NAME, PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME
|
||||
from transformers import AdamW as BertAdam
|
||||
from transformers import get_linear_schedule_with_warmup
|
||||
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
|
||||
class Learner(nn.Module):
|
||||
ignore_token_label_id = torch.nn.CrossEntropyLoss().ignore_index
|
||||
pad_token_label_id = -1
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bert_model,
|
||||
label_list,
|
||||
freeze_layer,
|
||||
logger,
|
||||
lr_meta,
|
||||
lr_inner,
|
||||
warmup_prop_meta,
|
||||
warmup_prop_inner,
|
||||
max_meta_steps,
|
||||
model_dir="",
|
||||
cache_dir="",
|
||||
gpu_no=0,
|
||||
py_alias="python",
|
||||
args=None,
|
||||
):
|
||||
|
||||
super(Learner, self).__init__()
|
||||
|
||||
self.lr_meta = lr_meta
|
||||
self.lr_inner = lr_inner
|
||||
self.warmup_prop_meta = warmup_prop_meta
|
||||
self.warmup_prop_inner = warmup_prop_inner
|
||||
self.max_meta_steps = max_meta_steps
|
||||
|
||||
self.bert_model = bert_model
|
||||
self.label_list = label_list
|
||||
self.py_alias = py_alias
|
||||
self.entity_types = args.entity_types
|
||||
self.is_debug = args.debug
|
||||
self.train_mode = args.train_mode
|
||||
self.eval_mode = args.eval_mode
|
||||
self.model_dir = model_dir
|
||||
self.args = args
|
||||
self.freeze_layer = freeze_layer
|
||||
|
||||
num_labels = len(label_list)
|
||||
|
||||
# load model
|
||||
if model_dir != "":
|
||||
if self.eval_mode != "two-stage":
|
||||
self.load_model(self.eval_mode)
|
||||
else:
|
||||
logger.info("********** Loading pre-trained model **********")
|
||||
cache_dir = cache_dir if cache_dir else str(PYTORCH_PRETRAINED_BERT_CACHE)
|
||||
self.model = BertForTokenClassification_.from_pretrained(
|
||||
bert_model,
|
||||
cache_dir=cache_dir,
|
||||
num_labels=num_labels,
|
||||
output_hidden_states=True,
|
||||
).to(args.device)
|
||||
|
||||
if self.eval_mode != "two-stage":
|
||||
self.model.set_config(
|
||||
args.use_classify,
|
||||
args.distance_mode,
|
||||
args.similar_k,
|
||||
args.shared_bert,
|
||||
self.train_mode,
|
||||
)
|
||||
self.model.to(args.device)
|
||||
self.layer_set()
|
||||
|
||||
def layer_set(self):
|
||||
# layer freezing
|
||||
no_grad_param_names = ["embeddings", "pooler"] + [
|
||||
"layer.{}.".format(i) for i in range(self.freeze_layer)
|
||||
]
|
||||
logger.info("The frozen parameters are:")
|
||||
for name, param in self.model.named_parameters():
|
||||
if any(no_grad_pn in name for no_grad_pn in no_grad_param_names):
|
||||
param.requires_grad = False
|
||||
logger.info(" {}".format(name))
|
||||
|
||||
self.opt = BertAdam(self.get_optimizer_grouped_parameters(), lr=self.lr_meta)
|
||||
self.scheduler = get_linear_schedule_with_warmup(
|
||||
self.opt,
|
||||
num_warmup_steps=int(self.max_meta_steps * self.warmup_prop_meta),
|
||||
num_training_steps=self.max_meta_steps,
|
||||
)
|
||||
|
||||
def get_optimizer_grouped_parameters(self):
|
||||
param_optimizer = list(self.model.named_parameters())
|
||||
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in param_optimizer
|
||||
if not any(nd in n for nd in no_decay) and p.requires_grad
|
||||
],
|
||||
"weight_decay": 0.01,
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in param_optimizer
|
||||
if any(nd in n for nd in no_decay) and p.requires_grad
|
||||
],
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
return optimizer_grouped_parameters
|
||||
|
||||
def get_names(self):
|
||||
names = [n for n, p in self.model.named_parameters() if p.requires_grad]
|
||||
return names
|
||||
|
||||
def get_params(self):
|
||||
params = [p for p in self.model.parameters() if p.requires_grad]
|
||||
return params
|
||||
|
||||
def load_weights(self, names, params):
|
||||
model_params = self.model.state_dict()
|
||||
for n, p in zip(names, params):
|
||||
model_params[n].data.copy_(p.data)
|
||||
|
||||
def load_gradients(self, names, grads):
|
||||
model_params = self.model.state_dict(keep_vars=True)
|
||||
for n, g in zip(names, grads):
|
||||
if model_params[n].grad is None:
|
||||
continue
|
||||
model_params[n].grad.data.add_(g.data) # accumulate
|
||||
|
||||
def get_learning_rate(self, lr, progress, warmup, schedule="linear"):
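        # Linear warmup then linear decay: scale the inner-loop LR up during the first
        # `warmup` fraction of training progress, then anneal it linearly toward zero.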
|
||||
if schedule == "linear":
|
||||
if progress < warmup:
|
||||
lr *= progress / warmup
|
||||
else:
|
||||
lr *= max((progress - 1.0) / (warmup - 1.0), 0.0)
|
||||
return lr
|
||||
|
||||
def inner_update(self, data_support, lr_curr, inner_steps, no_grad: bool = False):
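        # MAML-style inner loop: adapt the current parameters to the support set for
        # `inner_steps` gradient steps with a task-specific optimizer at rate `lr_curr`
        # (the gradient update is skipped when no_grad=True).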
|
||||
inner_opt = BertAdam(self.get_optimizer_grouped_parameters(), lr=self.lr_inner)
|
||||
self.model.train()
|
||||
|
||||
for i in range(inner_steps):
|
||||
inner_opt.param_groups[0]["lr"] = lr_curr
|
||||
inner_opt.param_groups[1]["lr"] = lr_curr
|
||||
|
||||
inner_opt.zero_grad()
|
||||
_, _, loss, type_loss = self.model.forward_wuqh(
|
||||
input_ids=data_support["input_ids"],
|
||||
attention_mask=data_support["input_mask"],
|
||||
token_type_ids=data_support["segment_ids"],
|
||||
labels=data_support["label_ids"],
|
||||
e_mask=data_support["e_mask"],
|
||||
e_type_ids=data_support["e_type_ids"],
|
||||
e_type_mask=data_support["e_type_mask"],
|
||||
entity_types=self.entity_types,
|
||||
is_update_type_embedding=True,
|
||||
lambda_max_loss=self.args.inner_lambda_max_loss,
|
||||
sim_k=self.args.inner_similar_k,
|
||||
)
|
||||
if loss is None:
|
||||
loss = type_loss
|
||||
elif type_loss is not None:
|
||||
loss = loss + type_loss
|
||||
if no_grad:
|
||||
continue
|
||||
loss.backward()
|
||||
inner_opt.step()
|
||||
|
||||
return loss.item()
|
||||
|
||||
def forward_supervise(self, batch_query, batch_support, progress, inner_steps):
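        # Conventional supervised training (used with --use_supervise): one optimizer
        # step per task on the query batch and one on the support batch, with no
        # inner-loop adaptation.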
|
||||
span_losses, type_losses = [], []
|
||||
task_num = len(batch_query)
|
||||
|
||||
for task_id in range(task_num):
|
||||
_, _, loss, type_loss = self.model.forward_wuqh(
|
||||
input_ids=batch_query[task_id]["input_ids"],
|
||||
attention_mask=batch_query[task_id]["input_mask"],
|
||||
token_type_ids=batch_query[task_id]["segment_ids"],
|
||||
labels=batch_query[task_id]["label_ids"],
|
||||
e_mask=batch_query[task_id]["e_mask"],
|
||||
e_type_ids=batch_query[task_id]["e_type_ids"],
|
||||
e_type_mask=batch_query[task_id]["e_type_mask"],
|
||||
entity_types=self.entity_types,
|
||||
lambda_max_loss=self.args.lambda_max_loss,
|
||||
)
|
||||
if loss is not None:
|
||||
span_losses.append(loss.item())
|
||||
if type_loss is not None:
|
||||
type_losses.append(type_loss.item())
|
||||
if loss is None:
|
||||
loss = type_loss
|
||||
elif type_loss is not None:
|
||||
loss = loss + type_loss
|
||||
|
||||
loss.backward()
|
||||
self.opt.step()
|
||||
self.scheduler.step()
|
||||
self.model.zero_grad()
|
||||
|
||||
for task_id in range(task_num):
|
||||
_, _, loss, type_loss = self.model.forward_wuqh(
|
||||
input_ids=batch_support[task_id]["input_ids"],
|
||||
attention_mask=batch_support[task_id]["input_mask"],
|
||||
token_type_ids=batch_support[task_id]["segment_ids"],
|
||||
labels=batch_support[task_id]["label_ids"],
|
||||
e_mask=batch_support[task_id]["e_mask"],
|
||||
e_type_ids=batch_support[task_id]["e_type_ids"],
|
||||
e_type_mask=batch_support[task_id]["e_type_mask"],
|
||||
entity_types=self.entity_types,
|
||||
lambda_max_loss=self.args.lambda_max_loss,
|
||||
)
|
||||
if loss is not None:
|
||||
span_losses.append(loss.item())
|
||||
if type_loss is not None:
|
||||
type_losses.append(type_loss.item())
|
||||
if loss is None:
|
||||
loss = type_loss
|
||||
elif type_loss is not None:
|
||||
loss = loss + type_loss
|
||||
|
||||
loss.backward()
|
||||
self.opt.step()
|
||||
self.scheduler.step()
|
||||
self.model.zero_grad()
|
||||
|
||||
return (
|
||||
np.mean(span_losses) if span_losses else 0,
|
||||
np.mean(type_losses) if type_losses else 0,
|
||||
)
|
||||
|
||||
def forward_meta(self, batch_query, batch_support, progress, inner_steps):
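        # MAML outer loop: for each task, adapt on the support set, back-propagate the
        # query loss and store its gradient, then restore the pre-adaptation weights and
        # apply the accumulated gradients in a single meta-optimizer step
        # (a first-order approximation of MAML).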
|
||||
names = self.get_names()
|
||||
params = self.get_params()
|
||||
weights = deepcopy(params)
|
||||
|
||||
meta_grad = []
|
||||
span_losses, type_losses = [], []
|
||||
|
||||
task_num = len(batch_query)
|
||||
lr_inner = self.get_learning_rate(
|
||||
self.lr_inner, progress, self.warmup_prop_inner
|
||||
)
|
||||
|
||||
# compute meta_grad of each task
|
||||
for task_id in range(task_num):
|
||||
self.inner_update(batch_support[task_id], lr_inner, inner_steps=inner_steps)
|
||||
_, _, loss, type_loss = self.model.forward_wuqh(
|
||||
input_ids=batch_query[task_id]["input_ids"],
|
||||
attention_mask=batch_query[task_id]["input_mask"],
|
||||
token_type_ids=batch_query[task_id]["segment_ids"],
|
||||
labels=batch_query[task_id]["label_ids"],
|
||||
e_mask=batch_query[task_id]["e_mask"],
|
||||
e_type_ids=batch_query[task_id]["e_type_ids"],
|
||||
e_type_mask=batch_query[task_id]["e_type_mask"],
|
||||
entity_types=self.entity_types,
|
||||
lambda_max_loss=self.args.lambda_max_loss,
|
||||
)
|
||||
if loss is not None:
|
||||
span_losses.append(loss.item())
|
||||
if type_loss is not None:
|
||||
type_losses.append(type_loss.item())
|
||||
if loss is None:
|
||||
loss = type_loss
|
||||
elif type_loss is not None:
|
||||
loss = loss + type_loss
|
||||
grad = torch.autograd.grad(loss, params)
|
||||
meta_grad.append(grad)
|
||||
|
||||
self.load_weights(names, weights)
|
||||
|
||||
# accumulate grads of all tasks to param.grad
|
||||
self.opt.zero_grad()
|
||||
|
||||
# similar to backward()
|
||||
for g in meta_grad:
|
||||
self.load_gradients(names, g)
|
||||
self.opt.step()
|
||||
self.scheduler.step()
|
||||
|
||||
return (
|
||||
np.mean(span_losses) if span_losses else 0,
|
||||
np.mean(type_losses) if type_losses else 0,
|
||||
)
|
||||
|
||||
# ---------------------------------------- Evaluation -------------------------------------- #
|
||||
def write_result(self, words, y_true, y_pred, tmp_fn):
|
||||
assert len(y_pred) == len(y_true)
|
||||
with open(tmp_fn, "w", encoding="utf-8") as fw:
|
||||
for i, sent in enumerate(y_true):
|
||||
for j, word in enumerate(sent):
|
||||
fw.write("{} {} {}\n".format(words[i][j], word, y_pred[i][j]))
|
||||
fw.write("\n")
|
||||
|
||||
def batch_test(self, data):
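        # Run the model over an episode in mini-batches of B=16 to bound memory; keys
        # not listed in BATCH_KEY (e.g. entity_types) are passed through unchanged.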
|
||||
N = data["input_ids"].shape[0]
|
||||
B = 16
|
||||
BATCH_KEY = [
|
||||
"input_ids",
|
||||
"attention_mask",
|
||||
"token_type_ids",
|
||||
"labels",
|
||||
"e_mask",
|
||||
"e_type_ids",
|
||||
"e_type_mask",
|
||||
]
|
||||
|
||||
logits, e_logits, loss, type_loss = [], [], 0, 0
|
||||
for i in range((N - 1) // B + 1):
|
||||
tmp = {
|
||||
ii: jj if ii not in BATCH_KEY else jj[i * B : (i + 1) * B]
|
||||
for ii, jj in data.items()
|
||||
}
|
||||
tmp_l, tmp_el, tmp_loss, tmp_eval_type_loss = self.model.forward_wuqh(**tmp)
|
||||
if tmp_l is not None:
|
||||
logits.extend(tmp_l.detach().cpu().numpy())
|
||||
if tmp_el is not None:
|
||||
e_logits.extend(tmp_el.detach().cpu().numpy())
|
||||
if tmp_loss is not None:
|
||||
loss += tmp_loss
|
||||
if tmp_eval_type_loss is not None:
|
||||
type_loss += tmp_eval_type_loss
|
||||
return logits, e_logits, loss, type_loss
|
||||
|
||||
def evaluate_meta_(
|
||||
self,
|
||||
corpus,
|
||||
logger,
|
||||
lr,
|
||||
steps,
|
||||
mode,
|
||||
set_type,
|
||||
type_steps: int = None,
|
||||
viterbi_decoder=None,
|
||||
):
|
||||
if not type_steps:
|
||||
type_steps = steps
|
||||
if self.is_debug:
|
||||
self.save_model(self.args.result_dir, "begin", self.args.max_seq_len, "all")
|
||||
|
||||
logger.info("Begin first Stage.")
|
||||
if self.eval_mode == "two-stage":
|
||||
self.load_model("span")
|
||||
names = self.get_names()
|
||||
params = self.get_params()
|
||||
weights = deepcopy(params)
|
||||
|
||||
eval_loss = 0.0
|
||||
nb_eval_steps = 0
|
||||
preds = None
|
||||
t_tmp = time.time()
|
||||
targets, predes, spans, lss, type_preds, type_g = [], [], [], [], [], []
|
||||
for item_id in range(corpus.n_total):
|
||||
eval_query, eval_support = corpus.get_batch_meta(
|
||||
batch_size=1, shuffle=False
|
||||
)
|
||||
|
||||
# train on support examples
|
||||
if not self.args.nouse_inner_ft:
|
||||
self.inner_update(eval_support[0], lr_curr=lr, inner_steps=steps)
|
||||
|
||||
# eval on pseudo query examples (test example)
|
||||
self.model.eval()
|
||||
with torch.no_grad():
|
||||
logits, e_ls, tmp_eval_loss, _ = self.batch_test(
|
||||
{
|
||||
"input_ids": eval_query[0]["input_ids"],
|
||||
"attention_mask": eval_query[0]["input_mask"],
|
||||
"token_type_ids": eval_query[0]["segment_ids"],
|
||||
"labels": eval_query[0]["label_ids"],
|
||||
"e_mask": eval_query[0]["e_mask"],
|
||||
"e_type_ids": eval_query[0]["e_type_ids"],
|
||||
"e_type_mask": eval_query[0]["e_type_mask"],
|
||||
"entity_types": self.entity_types,
|
||||
}
|
||||
)
|
||||
lss.append(logits)
|
||||
if self.model.train_mode != "type":
|
||||
eval_loss += tmp_eval_loss
|
||||
if self.model.train_mode != "span":
|
||||
type_pred, type_ground = self.eval_typing(
|
||||
e_ls, eval_query[0]["e_type_mask"]
|
||||
)
|
||||
type_preds.append(type_pred)
|
||||
type_g.append(type_ground)
|
||||
else:
|
||||
e_mask, e_type_ids, e_type_mask, result, types = self.decode_span(
|
||||
logits,
|
||||
eval_query[0]["label_ids"],
|
||||
eval_query[0]["types"],
|
||||
eval_query[0]["input_mask"],
|
||||
viterbi_decoder,
|
||||
)
|
||||
targets.extend(eval_query[0]["entities"])
|
||||
spans.extend(result)
|
||||
|
||||
nb_eval_steps += 1
|
||||
|
||||
self.load_weights(names, weights)
|
||||
if item_id % 200 == 0:
|
||||
logger.info(
|
||||
" To sentence {}/{}. Time: {}sec".format(
|
||||
item_id, corpus.n_total, time.time() - t_tmp
|
||||
)
|
||||
)
|
||||
|
||||
logger.info("Begin second Stage.")
|
||||
if self.eval_mode == "two-stage":
|
||||
self.load_model("type")
|
||||
names = self.get_names()
|
||||
params = self.get_params()
|
||||
weights = deepcopy(params)
|
||||
|
||||
if self.train_mode == "add":
|
||||
for item_id in range(corpus.n_total):
|
||||
eval_query, eval_support = corpus.get_batch_meta(
|
||||
batch_size=1, shuffle=False
|
||||
)
|
||||
logits = lss[item_id]
|
||||
|
||||
# train on support examples
|
||||
self.inner_update(eval_support[0], lr_curr=lr, inner_steps=type_steps)
|
||||
|
||||
# eval on pseudo query examples (test example)
|
||||
self.model.eval()
|
||||
with torch.no_grad():
|
||||
e_mask, e_type_ids, e_type_mask, result, types = self.decode_span(
|
||||
logits,
|
||||
eval_query[0]["label_ids"],
|
||||
eval_query[0]["types"],
|
||||
eval_query[0]["input_mask"],
|
||||
viterbi_decoder,
|
||||
)
|
||||
|
||||
_, e_logits, _, tmp_eval_type_loss = self.batch_test(
|
||||
{
|
||||
"input_ids": eval_query[0]["input_ids"],
|
||||
"attention_mask": eval_query[0]["input_mask"],
|
||||
"token_type_ids": eval_query[0]["segment_ids"],
|
||||
"labels": eval_query[0]["label_ids"],
|
||||
"e_mask": e_mask,
|
||||
"e_type_ids": e_type_ids,
|
||||
"e_type_mask": e_type_mask,
|
||||
"entity_types": self.entity_types,
|
||||
}
|
||||
)
|
||||
|
||||
eval_loss += tmp_eval_type_loss
|
||||
|
||||
if self.eval_mode == "two-stage":
|
||||
logits, e_ls, tmp_eval_loss, _ = self.batch_test(
|
||||
{
|
||||
"input_ids": eval_query[0]["input_ids"],
|
||||
"attention_mask": eval_query[0]["input_mask"],
|
||||
"token_type_ids": eval_query[0]["segment_ids"],
|
||||
"labels": eval_query[0]["label_ids"],
|
||||
"e_mask": eval_query[0]["e_mask"],
|
||||
"e_type_ids": eval_query[0]["e_type_ids"],
|
||||
"e_type_mask": eval_query[0]["e_type_mask"],
|
||||
"entity_types": self.entity_types,
|
||||
}
|
||||
)
|
||||
|
||||
type_pred, type_ground = self.eval_typing(
|
||||
e_ls, eval_query[0]["e_type_mask"]
|
||||
)
|
||||
type_preds.append(type_pred)
|
||||
type_g.append(type_ground)
|
||||
taregt, p = self.decode_entity(
|
||||
e_logits, result, types, eval_query[0]["entities"]
|
||||
)
|
||||
predes.extend(p)
|
||||
|
||||
self.load_weights(names, weights)
|
||||
if item_id % 200 == 0:
|
||||
logger.info(
|
||||
" To sentence {}/{}. Time: {}sec".format(
|
||||
item_id, corpus.n_total, time.time() - t_tmp
|
||||
)
|
||||
)
|
||||
|
||||
eval_loss = eval_loss / nb_eval_steps
|
||||
if self.is_debug:
|
||||
joblib.dump([targets, predes, spans], "debug/f1.pkl")
|
||||
store_dir = self.args.model_dir if self.args.model_dir else self.args.result_dir
|
||||
joblib.dump(
|
||||
[targets, predes, spans],
|
||||
"{}/{}_{}_preds.pkl".format(store_dir, "all", set_type),
|
||||
)
|
||||
joblib.dump(
|
||||
[type_g, type_preds],
|
||||
"{}/{}_{}_preds.pkl".format(store_dir, "typing", set_type),
|
||||
)
|
||||
pred = [[jj[:-1] for jj in ii] for ii in predes]
|
||||
p, r, f1 = self.cacl_f1(targets, pred)
|
||||
pred = [
|
||||
[jj[:-1] for jj in ii if jj[-1] > self.args.type_threshold] for ii in predes
|
||||
]
|
||||
p_t, r_t, f1_t = self.cacl_f1(targets, pred)
|
||||
|
||||
span_p, span_r, span_f1 = self.cacl_f1(
|
||||
[[(jj[0], jj[1]) for jj in ii] for ii in targets], spans
|
||||
)
|
||||
type_p, type_r, type_f1 = self.cacl_f1(type_g, type_preds)
|
||||
|
||||
results = {
|
||||
"loss": eval_loss,
|
||||
"precision": p,
|
||||
"recall": r,
|
||||
"f1": f1,
|
||||
"span_p": span_p,
|
||||
"span_r": span_r,
|
||||
"span_f1": span_f1,
|
||||
"type_p": type_p,
|
||||
"type_r": type_r,
|
||||
"type_f1": type_f1,
|
||||
"precision_threshold": p_t,
|
||||
"recall_threshold": r_t,
|
||||
"f1_threshold": f1_t,
|
||||
}
|
||||
|
||||
logger.info("***** Eval results %s-%s *****", mode, set_type)
|
||||
for key in sorted(results.keys()):
|
||||
logger.info(" %s = %s", key, str(results[key]))
|
||||
logger.info(
|
||||
"%.3f,%.3f,%.3f,%.3f,%.3f,%.3f,%.3f,%.3f,%.3f,%.3f,%.3f,%.3f",
|
||||
results["precision"] * 100,
|
||||
results["recall"] * 100,
|
||||
results["f1"] * 100,
|
||||
results["span_p"] * 100,
|
||||
results["span_r"] * 100,
|
||||
results["span_f1"] * 100,
|
||||
results["type_p"] * 100,
|
||||
results["type_r"] * 100,
|
||||
results["type_f1"] * 100,
|
||||
results["precision_threshold"] * 100,
|
||||
results["recall_threshold"] * 100,
|
||||
results["f1_threshold"] * 100,
|
||||
)
|
||||
|
||||
return results, preds
|
||||
|
||||
def save_model(self, result_dir, fn_prefix, max_seq_len, mode: str = "all"):
|
||||
# Save a trained model and the associated configuration
|
||||
model_to_save = (
|
||||
self.model.module if hasattr(self.model, "module") else self.model
|
||||
) # Only save the model it-self
|
||||
output_model_file = os.path.join(
|
||||
result_dir, "{}_{}_{}".format(fn_prefix, mode, WEIGHTS_NAME)
|
||||
)
|
||||
torch.save(model_to_save.state_dict(), output_model_file)
|
||||
output_config_file = os.path.join(result_dir, CONFIG_NAME)
|
||||
with open(output_config_file, "w", encoding="utf-8") as f:
|
||||
f.write(model_to_save.config.to_json_string())
|
||||
label_map = {i: label for i, label in enumerate(self.label_list, 1)}
|
||||
model_config = {
|
||||
"bert_model": self.bert_model,
|
||||
"do_lower": False,
|
||||
"max_seq_length": max_seq_len,
|
||||
"num_labels": len(self.label_list) + 1,
|
||||
"label_map": label_map,
|
||||
}
|
||||
json.dump(
|
||||
model_config,
|
||||
open(
|
||||
os.path.join(result_dir, f"{mode}-model_config.json"),
|
||||
"w",
|
||||
encoding="utf-8",
|
||||
),
|
||||
)
|
||||
if mode == "type":
|
||||
joblib.dump(
|
||||
self.entity_types, os.path.join(result_dir, "type_embedding.pkl")
|
||||
)
|
||||
|
||||
def save_best_model(self, result_dir: str, mode: str):
|
||||
output_model_file = os.path.join(result_dir, "en_tmp_{}".format(WEIGHTS_NAME))
|
||||
config_name = os.path.join(result_dir, "tmp-model_config.json")
|
||||
shutil.copy(output_model_file, output_model_file.replace("tmp", mode))
|
||||
shutil.copy(config_name, config_name.replace("tmp", mode))
|
||||
|
||||
def load_model(self, mode: str = "all"):
|
||||
if not self.model_dir:
|
||||
return
|
||||
model_dir = self.model_dir
|
||||
logger.info(f"********** Loading saved {mode} model **********")
|
||||
output_model_file = os.path.join(
|
||||
model_dir, "en_{}_{}".format(mode, WEIGHTS_NAME)
|
||||
)
|
||||
self.model = BertForTokenClassification_.from_pretrained(
|
||||
self.bert_model, num_labels=len(self.label_list), output_hidden_states=True
|
||||
)
|
||||
self.model.set_config(
|
||||
self.args.use_classify,
|
||||
self.args.distance_mode,
|
||||
self.args.similar_k,
|
||||
self.args.shared_bert,
|
||||
mode,
|
||||
)
|
||||
self.model.to(self.args.device)
|
||||
self.model.load_state_dict(torch.load(output_model_file, map_location="cuda"))
|
||||
self.layer_set()
|
||||
|
||||
def decode_span(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
types,
|
||||
mask: torch.Tensor,
|
||||
viterbi_decoder=None,
|
||||
):
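        # Convert token-level span logits (optionally Viterbi-decoded) into (start, end)
        # span predictions under the BIO/BIOES scheme, and build the e_mask / e_type_ids /
        # e_type_mask tensors consumed by the typing stage.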
|
||||
if self.is_debug:
|
||||
joblib.dump([logits, target, self.label_list], "debug/span.pkl")
|
||||
device = target.device
|
||||
K = max([len(ii) for ii in types])
|
||||
if viterbi_decoder:
|
||||
N = target.shape[0]
|
||||
B = 16
|
||||
result = []
|
||||
for i in range((N - 1) // B + 1):
|
||||
tmp_logits = torch.tensor(logits[i * B : (i + 1) * B]).to(target.device)
|
||||
if len(tmp_logits.shape) == 2:
|
||||
tmp_logits = tmp_logits.unsqueeze(0)
|
||||
tmp_target = target[i * B : (i + 1) * B]
|
||||
log_probs = nn.functional.log_softmax(
|
||||
tmp_logits.detach(), dim=-1
|
||||
) # batch_size x max_seq_len x n_labels
|
||||
pred_labels = viterbi_decoder.forward(
|
||||
log_probs, mask[i * B : (i + 1) * B], tmp_target
|
||||
)
|
||||
|
||||
for ii, jj in zip(pred_labels, tmp_target.detach().cpu().numpy()):
|
||||
left, right, tmp = 0, 0, []
|
||||
while right < len(jj) and jj[right] == self.ignore_token_label_id:
|
||||
tmp.append(-1)
|
||||
right += 1
|
||||
while left < len(ii):
|
||||
tmp.append(ii[left])
|
||||
left += 1
|
||||
right += 1
|
||||
while (
|
||||
right < len(jj) and jj[right] == self.ignore_token_label_id
|
||||
):
|
||||
tmp.append(-1)
|
||||
right += 1
|
||||
result.append(tmp)
|
||||
target = target.detach().cpu().numpy()
|
||||
B, T = target.shape
|
||||
if not viterbi_decoder:
|
||||
logits = logits.detach().cpu().numpy()
|
||||
result = np.argmax(logits, -1)
|
||||
|
||||
if self.label_list == ["O", "B", "I"]:
|
||||
res = []
|
||||
for ii in range(B):
|
||||
tmp, idx = [], 0
|
||||
max_pad = T - 1
|
||||
while (
|
||||
max_pad > 0 and target[ii][max_pad - 1] == self.pad_token_label_id
|
||||
):
|
||||
max_pad -= 1
|
||||
while idx < max_pad:
|
||||
if target[ii][idx] == self.ignore_token_label_id or (
|
||||
result[ii][idx] != 1
|
||||
):
|
||||
idx += 1
|
||||
continue
|
||||
e = idx
|
||||
while e < max_pad - 1 and (
|
||||
target[ii][e + 1] == self.ignore_token_label_id
|
||||
or result[ii][e + 1] in [self.ignore_token_label_id, 2]
|
||||
):
|
||||
e += 1
|
||||
tmp.append((idx, e))
|
||||
idx = e + 1
|
||||
res.append(tmp)
|
||||
elif self.label_list == ["O", "B", "I", "E", "S"]:
|
||||
res = []
|
||||
for ii in range(B):
|
||||
tmp, idx = [], 0
|
||||
max_pad = T - 1
|
||||
while (
|
||||
max_pad > 0 and target[ii][max_pad - 1] == self.pad_token_label_id
|
||||
):
|
||||
max_pad -= 1
|
||||
while idx < max_pad:
|
||||
if target[ii][idx] == self.ignore_token_label_id or (
|
||||
result[ii][idx] not in [1, 4]
|
||||
):
|
||||
idx += 1
|
||||
continue
|
||||
e = idx
|
||||
while (
|
||||
e < max_pad - 1
|
||||
and result[ii][e] not in [3, 4]
|
||||
and (
|
||||
target[ii][e + 1] == self.ignore_token_label_id
|
||||
or result[ii][e + 1] in [self.ignore_token_label_id, 2, 3]
|
||||
)
|
||||
):
|
||||
e += 1
|
||||
if e < max_pad and result[ii][e] in [3, 4]:
|
||||
while (
|
||||
e < max_pad - 1
|
||||
and target[ii][e + 1] == self.ignore_token_label_id
|
||||
):
|
||||
e += 1
|
||||
tmp.append((idx, e))
|
||||
idx = e + 1
|
||||
res.append(tmp)
|
||||
M = max([len(ii) for ii in res])
|
||||
e_mask = np.zeros((B, M, T), np.int8)
|
||||
e_type_mask = np.zeros((B, M, K), np.int8)
|
||||
        e_type_ids = np.zeros((B, M, K), np.int64)
|
||||
for ii in range(B):
|
||||
for idx, (s, e) in enumerate(res[ii]):
|
||||
e_mask[ii][idx][s : e + 1] = 1
|
||||
types_set = types[ii]
|
||||
if len(res[ii]):
|
||||
e_type_ids[ii, : len(res[ii]), : len(types_set)] = [types_set] * len(
|
||||
res[ii]
|
||||
)
|
||||
e_type_mask[ii, : len(res[ii]), : len(types_set)] = np.ones(
|
||||
(len(res[ii]), len(types_set))
|
||||
)
|
||||
return (
|
||||
torch.tensor(e_mask).to(device),
|
||||
torch.tensor(e_type_ids, dtype=torch.long).to(device),
|
||||
torch.tensor(e_type_mask).to(device),
|
||||
res,
|
||||
types,
|
||||
)
|
||||
|
||||
def decode_entity(self, e_logits, result, types, entities):
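        # For each detected span, pick the entity type with the highest typing logit and
        # keep its score so a confidence threshold (--type_threshold) can be applied later.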
|
||||
if self.is_debug:
|
||||
joblib.dump([e_logits, result, types, entities], "debug/e.pkl")
|
||||
target, preds = entities, []
|
||||
B = len(e_logits)
|
||||
logits = e_logits
|
||||
|
||||
for ii in range(B):
|
||||
tmp = []
|
||||
tmp_res = result[ii]
|
||||
tmp_types = types[ii]
|
||||
for jj in range(len(tmp_res)):
|
||||
lg = logits[ii][jj, : len(tmp_types)]
|
||||
tmp.append((*tmp_res[jj], tmp_types[np.argmax(lg)], lg[np.argmax(lg)]))
|
||||
preds.append(tmp)
|
||||
return target, preds
|
||||
|
||||
def cacl_f1(self, targets: list, predes: list):
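        # Micro-averaged precision / recall / F1 over exact prediction tuples; the small
        # epsilon guards against division by zero.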
|
||||
tp, fp, fn = 0, 0, 0
|
||||
for ii, jj in zip(targets, predes):
|
||||
ii, jj = set(ii), set(jj)
|
||||
same = ii - (ii - jj)
|
||||
tp += len(same)
|
||||
fn += len(ii - jj)
|
||||
fp += len(jj - ii)
|
||||
p = tp / (fp + tp + 1e-10)
|
||||
r = tp / (fn + tp + 1e-10)
|
||||
return p, r, 2 * p * r / (p + r + 1e-10)
|
||||
|
||||
def eval_typing(self, e_logits, e_type_mask):
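        # Typing evaluation: take the argmax over type logits for every gold span and
        # compare it against index 0 (the position of the gold type in e_type_ids).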
|
||||
e_logits = e_logits
|
||||
e_type_mask = e_type_mask.detach().cpu().numpy()
|
||||
if self.is_debug:
|
||||
joblib.dump([e_logits, e_type_mask], "debug/typing.pkl")
|
||||
|
||||
N = len(e_logits)
|
||||
B_S = 16
|
||||
res = []
|
||||
for i in range((N - 1) // B_S + 1):
|
||||
tmp_e_logits = np.argmax(e_logits[i * B_S : (i + 1) * B_S], -1)
|
||||
B, M = tmp_e_logits.shape
|
||||
tmp_e_type_mask = e_type_mask[i * B_S : (i + 1) * B_S][:, :M, 0]
|
||||
res.extend(tmp_e_logits[tmp_e_type_mask == 1])
|
||||
ground = [0] * len(res)
|
||||
return enumerate(res), enumerate(ground)
|
|
@ -0,0 +1,661 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
import joblib
|
||||
from learner import Learner
|
||||
from modeling import ViterbiDecoder
|
||||
from preprocessor import Corpus, EntityTypes
|
||||
from utils import set_seed
|
||||
|
||||
|
||||
def get_label_list(args):
|
||||
# prepare dataset
|
||||
if args.tagging_scheme == "BIOES":
|
||||
label_list = ["O", "B", "I", "E", "S"]
|
||||
elif args.tagging_scheme == "BIO":
|
||||
label_list = ["O", "B", "I"]
|
||||
else:
|
||||
label_list = ["O", "I"]
|
||||
return label_list
|
||||
|
||||
|
||||
def get_data_path(args, train_mode: str):
|
||||
assert args.dataset in [
|
||||
"FewNERD",
|
||||
"Domain",
|
||||
"Domain2",
|
||||
], f"Dataset: {args.dataset} Not Support."
|
||||
if args.dataset == "FewNERD":
|
||||
return os.path.join(
|
||||
args.data_path,
|
||||
args.mode,
|
||||
"{}_{}_{}.jsonl".format(train_mode, args.N, args.K),
|
||||
)
|
||||
elif args.dataset == "Domain":
|
||||
if train_mode == "dev":
|
||||
train_mode = "valid"
|
||||
text = "_shot_5" if args.K == 5 else ""
|
||||
replace_text = "-" if args.K == 5 else "_"
|
||||
return os.path.join(
|
||||
"ACL2020data",
|
||||
"xval_ner{}".format(text),
|
||||
"ner_{}_{}{}.json".format(train_mode, args.N, text).replace(
|
||||
"_", replace_text
|
||||
),
|
||||
)
|
||||
elif args.dataset == "Domain2":
|
||||
if train_mode == "train":
|
||||
return os.path.join("domain2", "{}_10_5.json".format(train_mode))
|
||||
return os.path.join(
|
||||
"domain2", "{}_{}_{}.json".format(train_mode, args.mode, args.K)
|
||||
)
|
||||
|
||||
|
||||
def replace_type_embedding(learner, args):
|
||||
logger.info("********** Replace trained type embedding **********")
|
||||
entity_types = joblib.load(os.path.join(args.result_dir, "type_embedding.pkl"))
|
||||
N, H = entity_types.types_embedding.weight.data.shape
|
||||
for ii in range(N):
|
||||
learner.models.embeddings.word_embeddings.weight.data[
|
||||
ii + 1
|
||||
] = entity_types.types_embedding.weight.data[ii]
|
||||
|
||||
|
||||
def train_meta(args):
|
||||
logger.info("********** Scheme: Meta Learning **********")
|
||||
label_list = get_label_list(args)
|
||||
|
||||
valid_data_path = get_data_path(args, "dev")
|
||||
valid_corpus = Corpus(
|
||||
logger,
|
||||
valid_data_path,
|
||||
args.bert_model,
|
||||
args.max_seq_len,
|
||||
label_list,
|
||||
args.entity_types,
|
||||
do_lower_case=True,
|
||||
shuffle=False,
|
||||
viterbi=args.viterbi,
|
||||
tagging=args.tagging_scheme,
|
||||
device=args.device,
|
||||
concat_types=args.concat_types,
|
||||
dataset=args.dataset,
|
||||
)
|
||||
|
||||
if args.debug:
|
||||
test_corpus = valid_corpus
|
||||
train_corpus = valid_corpus
|
||||
else:
|
||||
train_data_path = get_data_path(args, "train")
|
||||
train_corpus = Corpus(
|
||||
logger,
|
||||
train_data_path,
|
||||
args.bert_model,
|
||||
args.max_seq_len,
|
||||
label_list,
|
||||
args.entity_types,
|
||||
do_lower_case=True,
|
||||
shuffle=True,
|
||||
tagging=args.tagging_scheme,
|
||||
device=args.device,
|
||||
concat_types=args.concat_types,
|
||||
dataset=args.dataset,
|
||||
)
|
||||
|
||||
if not args.ignore_eval_test:
|
||||
test_data_path = get_data_path(args, "test")
|
||||
test_corpus = Corpus(
|
||||
logger,
|
||||
test_data_path,
|
||||
args.bert_model,
|
||||
args.max_seq_len,
|
||||
label_list,
|
||||
args.entity_types,
|
||||
do_lower_case=True,
|
||||
shuffle=False,
|
||||
viterbi=args.viterbi,
|
||||
tagging=args.tagging_scheme,
|
||||
device=args.device,
|
||||
concat_types=args.concat_types,
|
||||
dataset=args.dataset,
|
||||
)
|
||||
|
||||
learner = Learner(
|
||||
args.bert_model,
|
||||
label_list,
|
||||
args.freeze_layer,
|
||||
logger,
|
||||
args.lr_meta,
|
||||
args.lr_inner,
|
||||
args.warmup_prop_meta,
|
||||
args.warmup_prop_inner,
|
||||
args.max_meta_steps,
|
||||
py_alias=args.py_alias,
|
||||
args=args,
|
||||
)
|
||||
|
||||
if "embedding" in args.concat_types:
|
||||
replace_type_embedding(learner, args)
|
||||
|
||||
t = time.time()
|
||||
F1_valid_best = {ii: -1.0 for ii in ["all", "type", "span"]}
|
||||
F1_test = -1.0
|
||||
best_step, protect_step = -1.0, 100 if args.train_mode != "type" else 50
|
||||
|
||||
for step in range(args.max_meta_steps):
|
||||
progress = 1.0 * step / args.max_meta_steps
|
||||
|
||||
batch_query, batch_support = train_corpus.get_batch_meta(
|
||||
batch_size=args.inner_size
|
||||
) # (batch_size=32)
|
||||
if args.use_supervise:
|
||||
span_loss, type_loss = learner.forward_supervise(
|
||||
batch_query,
|
||||
batch_support,
|
||||
progress=progress,
|
||||
inner_steps=args.inner_steps,
|
||||
)
|
||||
else:
|
||||
span_loss, type_loss = learner.forward_meta(
|
||||
batch_query,
|
||||
batch_support,
|
||||
progress=progress,
|
||||
inner_steps=args.inner_steps,
|
||||
)
|
||||
|
||||
if step % 20 == 0:
|
||||
logger.info(
|
||||
"Step: {}/{}, span loss = {:.6f}, type loss = {:.6f}, time = {:.2f}s.".format(
|
||||
step, args.max_meta_steps, span_loss, type_loss, time.time() - t
|
||||
)
|
||||
)
|
||||
|
||||
if step % args.eval_every_meta_steps == 0 and step > protect_step:
|
||||
logger.info("********** Scheme: evaluate - [valid] **********")
|
||||
result_valid, predictions_valid = test(args, learner, valid_corpus, "valid")
|
||||
|
||||
F1_valid = result_valid["f1"]
|
||||
is_best = False
|
||||
if F1_valid > F1_valid_best["all"]:
|
||||
logger.info("===> Best Valid F1: {}".format(F1_valid))
|
||||
logger.info(" Saving model...")
|
||||
learner.save_model(args.result_dir, "en", args.max_seq_len, "all")
|
||||
F1_valid_best["all"] = F1_valid
|
||||
best_step = step
|
||||
is_best = True
|
||||
|
||||
if (
|
||||
result_valid["span_f1"] > F1_valid_best["span"]
|
||||
and args.train_mode != "type"
|
||||
):
|
||||
F1_valid_best["span"] = result_valid["span_f1"]
|
||||
learner.save_model(args.result_dir, "en", args.max_seq_len, "span")
|
||||
logger.info("Best Span Store {}".format(step))
|
||||
is_best = True
|
||||
if (
|
||||
result_valid["type_f1"] > F1_valid_best["type"]
|
||||
and args.train_mode != "span"
|
||||
):
|
||||
F1_valid_best["type"] = result_valid["type_f1"]
|
||||
learner.save_model(args.result_dir, "en", args.max_seq_len, "type")
|
||||
logger.info("Best Type Store {}".format(step))
|
||||
is_best = True
|
||||
|
||||
if is_best and not args.ignore_eval_test:
|
||||
logger.info("********** Scheme: evaluate - [test] **********")
|
||||
result_test, predictions_test = test(args, learner, test_corpus, "test")
|
||||
|
||||
F1_test = result_test["f1"]
|
||||
logger.info(
|
||||
"Best Valid F1: {}, Step: {}".format(F1_valid_best, best_step)
|
||||
)
|
||||
logger.info("Test F1: {}".format(F1_test))
|
||||
|
||||
|
||||
def test(args, learner, corpus, types: str):
|
||||
if corpus.viterbi != "none":
|
||||
id2label = corpus.id2label
|
||||
transition_matrix = corpus.transition_matrix
|
||||
if args.viterbi == "soft":
|
||||
label_list = get_label_list(args)
|
||||
train_data_path = get_data_path(args, "train")
|
||||
train_corpus = Corpus(
|
||||
logger,
|
||||
train_data_path,
|
||||
args.bert_model,
|
||||
args.max_seq_len,
|
||||
label_list,
|
||||
args.entity_types,
|
||||
do_lower_case=True,
|
||||
shuffle=True,
|
||||
tagging=args.tagging_scheme,
|
||||
viterbi="soft",
|
||||
device=args.device,
|
||||
concat_types=args.concat_types,
|
||||
dataset=args.dataset,
|
||||
)
|
||||
id2label = train_corpus.id2label
|
||||
transition_matrix = train_corpus.transition_matrix
|
||||
|
||||
viterbi_decoder = ViterbiDecoder(id2label, transition_matrix)
|
||||
else:
|
||||
viterbi_decoder = None
|
||||
result_test, predictions = learner.evaluate_meta_(
|
||||
corpus,
|
||||
logger,
|
||||
lr=args.lr_finetune,
|
||||
steps=args.max_ft_steps,
|
||||
mode=args.mode,
|
||||
set_type=types,
|
||||
type_steps=args.max_type_ft_steps,
|
||||
viterbi_decoder=viterbi_decoder,
|
||||
)
|
||||
return result_test, predictions
|
||||
|
||||
|
||||
def evaluate(args):
|
||||
logger.info("********** Scheme: Meta Test **********")
|
||||
label_list = get_label_list(args)
|
||||
|
||||
valid_data_path = get_data_path(args, "dev")
|
||||
valid_corpus = Corpus(
|
||||
logger,
|
||||
valid_data_path,
|
||||
args.bert_model,
|
||||
args.max_seq_len,
|
||||
label_list,
|
||||
args.entity_types,
|
||||
do_lower_case=True,
|
||||
shuffle=False,
|
||||
tagging=args.tagging_scheme,
|
||||
viterbi=args.viterbi,
|
||||
concat_types=args.concat_types,
|
||||
dataset=args.dataset,
|
||||
)
|
||||
|
||||
test_data_path = get_data_path(args, "test")
|
||||
test_corpus = Corpus(
|
||||
logger,
|
||||
test_data_path,
|
||||
args.bert_model,
|
||||
args.max_seq_len,
|
||||
label_list,
|
||||
args.entity_types,
|
||||
do_lower_case=True,
|
||||
shuffle=False,
|
||||
tagging=args.tagging_scheme,
|
||||
viterbi=args.viterbi,
|
||||
concat_types=args.concat_types,
|
||||
dataset=args.dataset,
|
||||
)
|
||||
|
||||
learner = Learner(
|
||||
args.bert_model,
|
||||
label_list,
|
||||
args.freeze_layer,
|
||||
logger,
|
||||
args.lr_meta,
|
||||
args.lr_inner,
|
||||
args.warmup_prop_meta,
|
||||
args.warmup_prop_inner,
|
||||
args.max_meta_steps,
|
||||
model_dir=args.model_dir,
|
||||
py_alias=args.py_alias,
|
||||
args=args,
|
||||
)
|
||||
|
||||
logger.info("********** Scheme: evaluate - [valid] **********")
|
||||
test(args, learner, valid_corpus, "valid")
|
||||
|
||||
logger.info("********** Scheme: evaluate - [test] **********")
|
||||
test(args, learner, test_corpus, "test")
|
||||
|
||||
|
||||
def convert_bpe(args):
|
||||
def convert_base(train_mode: str):
|
||||
data_path = get_data_path(args, train_mode)
|
||||
corpus = Corpus(
|
||||
logger,
|
||||
data_path,
|
||||
args.bert_model,
|
||||
args.max_seq_len,
|
||||
label_list,
|
||||
args.entity_types,
|
||||
do_lower_case=True,
|
||||
shuffle=False,
|
||||
tagging=args.tagging_scheme,
|
||||
viterbi=args.viterbi,
|
||||
concat_types=args.concat_types,
|
||||
dataset=args.dataset,
|
||||
device=args.device,
|
||||
)
|
||||
|
||||
for seed in [171, 354, 550, 667, 985]:
|
||||
path = os.path.join(
|
||||
args.model_dir,
|
||||
f"all_{train_mode if train_mode == 'test' else 'valid'}_preds.pkl",
|
||||
).replace("171", str(seed))
|
||||
data = joblib.load(path)
|
||||
if len(data) == 3:
|
||||
spans = data[-1]
|
||||
else:
|
||||
spans = [[jj[:-2] for jj in ii] for ii in data[-1]]
|
||||
target = [[jj[:-1] for jj in ii] for ii in data[0]]
|
||||
|
||||
res = corpus._decoder_bpe_index(spans)
|
||||
target = corpus._decoder_bpe_index(target)
|
||||
with open(
|
||||
f"preds/{args.mode}-{args.N}way{args.K}shot-seed{seed}-{train_mode}.jsonl",
|
||||
"w",
|
||||
) as f:
|
||||
json.dump(res, f)
|
||||
if seed != 171:
|
||||
continue
|
||||
with open(
|
||||
f"preds/{args.mode}-{args.N}way{args.K}shot-seed{seed}-{train_mode}_golden.jsonl",
|
||||
"w",
|
||||
) as f:
|
||||
json.dump(target, f)
|
||||
|
||||
logger.info("********** Scheme: Convert BPE **********")
|
||||
os.makedirs("preds", exist_ok=True)
|
||||
label_list = get_label_list(args)
|
||||
convert_base("dev")
|
||||
convert_base("test")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def my_bool(s):
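        # argparse passes strings, and bool("False") would be truthy; only the literal
        # string "False" maps to False here.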
|
||||
return s != "False"
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--dataset", type=str, default="FewNERD", help="FewNERD or Domain"
|
||||
)
|
||||
parser.add_argument("--N", type=int, default=5)
|
||||
parser.add_argument("--K", type=int, default=1)
|
||||
parser.add_argument("--mode", type=str, default="intra")
|
||||
parser.add_argument(
|
||||
"--test_only",
|
||||
action="store_true",
|
||||
help="if true, will load the trained model and run test only",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--convert_bpe",
|
||||
action="store_true",
|
||||
help="if true, will convert the bpe encode result to word level.",
|
||||
)
|
||||
parser.add_argument("--tagging_scheme", type=str, default="BIO", help="BIO or IO")
|
||||
# dataset settings
|
||||
parser.add_argument("--data_path", type=str, default="episode-data")
|
||||
parser.add_argument(
|
||||
"--result_dir", type=str, help="where to save the result.", default="test"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_dir", type=str, help="dir name of a trained model", default=""
|
||||
)
|
||||
|
||||
# meta-test setting
|
||||
parser.add_argument(
|
||||
"--lr_finetune",
|
||||
type=float,
|
||||
help="finetune learning rate, used in [test_meta]. and [k_shot setting]",
|
||||
default=3e-5,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_ft_steps", type=int, help="maximal steps token for fine-tune.", default=3
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_type_ft_steps",
|
||||
type=int,
|
||||
help="maximal steps token for entity type fine-tune.",
|
||||
default=0,
|
||||
)
|
||||
|
||||
# meta-train setting
|
||||
parser.add_argument(
|
||||
"--inner_steps",
|
||||
type=int,
|
||||
help="every ** inner update for one meta-update",
|
||||
default=2,
|
||||
) # ===>
|
||||
parser.add_argument(
|
||||
"--inner_size",
|
||||
type=int,
|
||||
help="[number of tasks] for one meta-update",
|
||||
default=32,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_inner", type=float, help="inner loop learning rate", default=3e-5
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_meta", type=float, help="meta learning rate", default=3e-5
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_meta_steps",
|
||||
type=int,
|
||||
help="maximal steps token for meta training.",
|
||||
default=5001,
|
||||
)
|
||||
parser.add_argument("--eval_every_meta_steps", type=int, default=500)
|
||||
parser.add_argument(
|
||||
"--warmup_prop_inner",
|
||||
type=int,
|
||||
help="warm up proportion for inner update",
|
||||
default=0.1,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--warmup_prop_meta",
|
||||
type=int,
|
||||
help="warm up proportion for meta update",
|
||||
default=0.1,
|
||||
)
|
||||
|
||||
# permanent params
|
||||
parser.add_argument(
|
||||
"--freeze_layer", type=int, help="the layer of mBERT to be frozen", default=0
|
||||
)
|
||||
parser.add_argument("--max_seq_len", type=int, default=128)
|
||||
parser.add_argument(
|
||||
"--bert_model",
|
||||
type=str,
|
||||
default="bert-base-uncased",
|
||||
help="Bert pre-trained model selected in the list: bert-base-uncased, "
|
||||
"bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
|
||||
"bert-base-multilingual-cased, bert-base-chinese.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from s3",
|
||||
default="",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--viterbi", type=str, default="hard", help="hard or soft or None"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--concat_types", type=str, default="past", help="past or before or None"
|
||||
)
|
||||
# expt setting
|
||||
parser.add_argument(
|
||||
"--seed", type=int, help="random seed to reproduce the result.", default=667
|
||||
)
|
||||
parser.add_argument("--gpu_device", type=int, help="GPU device num", default=0)
|
||||
parser.add_argument("--py_alias", type=str, help="python alias", default="python")
|
||||
parser.add_argument(
|
||||
"--types_path",
|
||||
type=str,
|
||||
help="the path of entities types",
|
||||
default="data/entity_types.json",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--negative_types_number",
|
||||
type=int,
|
||||
help="the number of negative types in each batch",
|
||||
default=4,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--negative_mode", type=str, help="the mode of negative types", default="batch"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--types_mode", type=str, help="the embedding mode of type span", default="cls"
|
||||
)
|
||||
parser.add_argument("--name", type=str, help="the name of experiment", default="")
|
||||
parser.add_argument("--debug", help="debug mode", action="store_true")
|
||||
parser.add_argument(
|
||||
"--init_type_embedding_from_bert",
|
||||
action="store_true",
|
||||
help="initialization type embedding from BERT",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use_classify",
|
||||
action="store_true",
|
||||
help="use classifier after entity embedding",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--distance_mode", type=str, help="embedding distance mode", default="cos"
|
||||
)
|
||||
parser.add_argument("--similar_k", type=float, help="cosine similar k", default=10)
|
||||
parser.add_argument("--shared_bert", default=True, type=my_bool, help="shared BERT")
|
||||
parser.add_argument("--train_mode", default="add", type=str, help="add, span, type")
|
||||
parser.add_argument("--eval_mode", default="add", type=str, help="add, two-stage")
|
||||
parser.add_argument(
|
||||
"--type_threshold", default=2.5, type=float, help="typing decoder threshold"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lambda_max_loss", default=2.0, type=float, help="span max loss lambda"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--inner_lambda_max_loss", default=2.0, type=float, help="span max loss lambda"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--inner_similar_k", type=float, help="cosine similar k", default=10
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ignore_eval_test", help="if/not eval in test", action="store_true"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--nouse_inner_ft",
|
||||
action="store_true",
|
||||
help="if true, will convert the bpe encode result to word level.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_supervise",
|
||||
action="store_true",
|
||||
help="if true, will convert the bpe encode result to word level.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
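# In an N-way episode each entity is scored against its gold type plus the
# remaining N - 1 types, so any value passed on the command line is overridden here.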
|
||||
args.negative_types_number = args.N - 1
|
||||
if "Domain" in args.dataset:
|
||||
args.types_path = "data/entity_types_domain.json"
|
||||
|
||||
# setup random seed
|
||||
set_seed(args.seed, args.gpu_device)
|
||||
|
||||
# set up GPU device
|
||||
device = torch.device("cuda")
|
||||
torch.cuda.set_device(args.gpu_device)
|
||||
|
||||
# setup logger settings
|
||||
if args.test_only:
|
||||
top_dir = "models-{}-{}-{}".format(args.N, args.K, args.mode)
|
||||
args.model_dir = "{}-innerSteps_{}-innerSize_{}-lrInner_{}-lrMeta_{}-maxSteps_{}-seed_{}{}".format(
|
||||
args.bert_model,
|
||||
args.inner_steps,
|
||||
args.inner_size,
|
||||
args.lr_inner,
|
||||
args.lr_meta,
|
||||
args.max_meta_steps,
|
||||
args.seed,
|
||||
"-name_{}".format(args.name) if args.name else "",
|
||||
)
|
||||
args.model_dir = os.path.join(top_dir, args.model_dir)
|
||||
if not os.path.exists(args.model_dir):
|
||||
if args.convert_bpe:
|
||||
os.makedirs(args.model_dir)
|
||||
else:
|
||||
raise ValueError("Model directory does not exist!")
|
||||
fh = logging.FileHandler(
|
||||
"{}/log-test-ftLr_{}-ftSteps_{}.txt".format(
|
||||
args.model_dir, args.lr_finetune, args.max_ft_steps
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
top_dir = "models-{}-{}-{}".format(args.N, args.K, args.mode)
|
||||
args.result_dir = "{}-innerSteps_{}-innerSize_{}-lrInner_{}-lrMeta_{}-maxSteps_{}-seed_{}{}".format(
|
||||
args.bert_model,
|
||||
args.inner_steps,
|
||||
args.inner_size,
|
||||
args.lr_inner,
|
||||
args.lr_meta,
|
||||
args.max_meta_steps,
|
||||
args.seed,
|
||||
"-name_{}".format(args.name) if args.name else "",
|
||||
)
|
||||
os.makedirs(top_dir, exist_ok=True)
|
||||
if not os.path.exists("{}/{}".format(top_dir, args.result_dir)):
|
||||
os.mkdir("{}/{}".format(top_dir, args.result_dir))
|
||||
elif args.result_dir != "test":
|
||||
pass
|
||||
|
||||
args.result_dir = "{}/{}".format(top_dir, args.result_dir)
|
||||
fh = logging.FileHandler("{}/log-training.txt".format(args.result_dir))
|
||||
|
||||
# dump args
|
||||
with Path("{}/args-train.json".format(args.result_dir)).open(
|
||||
"w", encoding="utf-8"
|
||||
) as fw:
|
||||
json.dump(vars(args), fw, indent=4, sort_keys=True)
|
||||
|
||||
if args.debug:
|
||||
os.makedirs("debug", exist_ok=True)
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(logging.INFO)
|
||||
formatter = logging.Formatter(
|
||||
"%(asctime)s %(levelname)s: - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
fh.setLevel(logging.INFO)
|
||||
fh.setFormatter(formatter)
|
||||
ch = logging.StreamHandler()
|
||||
ch.setLevel(logging.INFO)
|
||||
ch.setFormatter(formatter)
|
||||
|
||||
logger.addHandler(ch)
|
||||
logger.addHandler(fh)
|
||||
|
||||
args.device = device
|
||||
logger.info(f"Using Device {device}")
|
||||
args.entity_types = EntityTypes(
|
||||
args.types_path, args.negative_types_number, args.negative_mode
|
||||
)
|
||||
args.entity_types.build_types_embedding(
|
||||
args.bert_model,
|
||||
True,
|
||||
args.device,
|
||||
args.types_mode,
|
||||
args.init_type_embedding_from_bert,
|
||||
)
|
||||
|
||||
if args.convert_bpe:
|
||||
convert_bpe(args)
|
||||
elif args.test_only:
|
||||
if args.model_dir == "":
|
||||
raise ValueError("NULL model directory!")
|
||||
evaluate(args)
|
||||
else:
|
||||
if args.model_dir != "":
|
||||
raise ValueError("Model directory should be NULL!")
|
||||
train_meta(args)
|
|
@ -0,0 +1,270 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from transformers import BertForTokenClassification
|
||||
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
|
||||
class BertForTokenClassification_(BertForTokenClassification):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(BertForTokenClassification_, self).__init__(*args, **kwargs)
|
||||
self.input_size = 768
|
||||
self.span_loss = nn.functional.cross_entropy
|
||||
self.type_loss = nn.functional.cross_entropy
|
||||
self.dropout = nn.Dropout(p=0.1)
|
||||
self.log_softmax = nn.functional.log_softmax
|
||||
|
||||
def set_config(
|
||||
self,
|
||||
use_classify: bool = False,
|
||||
distance_mode: str = "cos",
|
||||
similar_k: float = 30,
|
||||
shared_bert: bool = True,
|
||||
train_mode: str = "add",
|
||||
):
|
||||
self.use_classify = use_classify
|
||||
self.distance_mode = distance_mode
|
||||
self.similar_k = similar_k
|
||||
self.shared_bert = shared_bert
|
||||
self.train_mode = train_mode
|
||||
if train_mode == "type":
|
||||
self.classifier = None
|
||||
|
||||
if train_mode != "span":
|
||||
self.ln = nn.LayerNorm(768, 1e-5, True)
|
||||
if use_classify:
|
||||
# self.type_classify = nn.Linear(self.input_size, self.input_size)
|
||||
self.type_classify = nn.Sequential(
|
||||
nn.Linear(self.input_size, self.input_size * 2),
|
||||
nn.GELU(),
|
||||
nn.Linear(self.input_size * 2, self.input_size),
|
||||
)
|
||||
if self.distance_mode != "cos":
|
||||
self.dis_cls = nn.Sequential(
|
||||
nn.Linear(self.input_size * 3, self.input_size),
|
||||
nn.GELU(),
|
||||
nn.Linear(self.input_size, 2),
|
||||
)
|
||||
config = {
|
||||
"use_classify": use_classify,
|
||||
"distance_mode": distance_mode,
|
||||
"similar_k": similar_k,
|
||||
"shared_bert": shared_bert,
|
||||
"train_mode": train_mode,
|
||||
}
|
||||
logger.info(f"Model Setting: {config}")
|
||||
if not shared_bert:
|
||||
self.bert2 = deepcopy(self.bert)
|
||||
|
||||
def forward_wuqh(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
labels=None,
|
||||
e_mask=None,
|
||||
e_type_ids=None,
|
||||
e_type_mask=None,
|
||||
entity_types=None,
|
||||
entity_mode: str = "mean",
|
||||
is_update_type_embedding: bool = False,
|
||||
lambda_max_loss: float = 0.0,
|
||||
sim_k: float = 0,
|
||||
):
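# Trim all inputs to the longest non-padded sequence in the batch before running BERT.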
|
||||
max_len = (attention_mask != 0).max(0)[0].nonzero(as_tuple=False)[-1].item() + 1
|
||||
input_ids = input_ids[:, :max_len]
|
||||
attention_mask = attention_mask[:, :max_len].type(torch.int8)
|
||||
token_type_ids = token_type_ids[:, :max_len]
|
||||
labels = labels[:, :max_len]
|
||||
output = self.bert(
|
||||
input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
|
||||
)
|
||||
sequence_output = self.dropout(output[0])
|
||||
|
||||
if self.train_mode != "type":
|
||||
logits = self.classifier(
|
||||
sequence_output
|
||||
) # batch_size x seq_len x num_labels
|
||||
else:
|
||||
logits = None
|
||||
if not self.shared_bert:
|
||||
output2 = self.bert2(
|
||||
input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
|
||||
)
|
||||
sequence_output2 = self.dropout(output2[0])
|
||||
|
||||
if e_type_ids is not None and self.train_mode != "span":
|
||||
if e_type_mask.sum() != 0:
|
||||
M = (e_type_mask[:, :, 0] != 0).max(0)[0].nonzero(as_tuple=False)[
|
||||
-1
|
||||
].item() + 1
|
||||
else:
|
||||
M = 1
|
||||
e_mask = e_mask[:, :M, :max_len].type(torch.int8)
|
||||
e_type_ids = e_type_ids[:, :M, :]
|
||||
e_type_mask = e_type_mask[:, :M, :].type(torch.int8)
|
||||
B, M, K = e_type_ids.shape
|
||||
e_out = self.get_entity_hidden(
|
||||
sequence_output if self.shared_bert else sequence_output2,
|
||||
e_mask,
|
||||
entity_mode,
|
||||
)
|
||||
if self.use_classify:
|
||||
e_out = self.type_classify(e_out)
|
||||
e_out = self.ln(e_out) # batch_size x max_entity_num x hidden_size
|
||||
if is_update_type_embedding:
|
||||
entity_types.update_type_embedding(e_out, e_type_ids, e_type_mask)
|
||||
e_out = e_out.unsqueeze(2).expand(B, M, K, -1)
|
||||
types = self.get_types_embedding(
|
||||
e_type_ids, entity_types
|
||||
) # batch_size x max_entity_num x K x hidden_size
|
||||
|
||||
if self.distance_mode == "cat":
|
||||
e_types = torch.cat([e_out, types, (e_out - types).abs()], -1)
|
||||
e_types = self.dis_cls(e_types)
|
||||
e_types = e_types[:, :, :0]
|
||||
elif self.distance_mode == "l2":
|
||||
e_types = -(torch.pow(e_out - types, 2)).sum(-1)
|
||||
elif self.distance_mode == "cos":
|
||||
sim_k = sim_k if sim_k else self.similar_k
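# Scaled dot product between the span representation and the type embedding,
# divided by the hidden size (768); similar_k acts as a temperature.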
|
||||
e_types = sim_k * (e_out * types).sum(-1) / 768
|
||||
|
||||
e_logits = e_types
|
||||
if M:
|
||||
em = e_type_mask.clone()
|
||||
em[em.sum(-1) == 0] = 1
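# Rows belonging to padded entity slots get an all-ones mask so the scores stay finite;
# they are excluded from the loss via e_type_mask[:, :, 0] below.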
|
||||
e = e_types * em
|
||||
e_type_label = torch.zeros((B, M)).to(e_types.device)
|
||||
type_loss = self.calc_loss(
|
||||
self.type_loss, e, e_type_label, e_type_mask[:, :, 0]
|
||||
)
|
||||
else:
|
||||
type_loss = torch.tensor(0).to(sequence_output.device)
|
||||
else:
|
||||
e_logits, type_loss = None, None
|
||||
|
||||
if labels is not None and self.train_mode != "type":
|
||||
# Only keep active parts of the loss
|
||||
loss_fct = CrossEntropyLoss(reduction="none")
|
||||
B, M, T = logits.shape
|
||||
if attention_mask is not None:
|
||||
active_loss = attention_mask.view(-1) == 1
|
||||
active_logits = logits.reshape(-1, self.num_labels)[active_loss]
|
||||
active_labels = labels.reshape(-1)[active_loss]
|
||||
base_loss = loss_fct(active_logits, active_labels)
|
||||
loss = torch.mean(base_loss)
|
||||
|
||||
# max-loss
|
||||
if lambda_max_loss > 0:
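# Besides the mean token loss, add the largest per-token loss of every sentence,
# weighted by lambda_max_loss, to focus training on the hardest token of each sentence.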
|
||||
active_loss = active_loss.view(B, M)
|
||||
active_max = []
|
||||
start_id = 0
|
||||
for i in range(B):
|
||||
sent_len = torch.sum(active_loss[i])
|
||||
end_id = start_id + sent_len
|
||||
active_max.append(torch.max(base_loss[start_id:end_id]))
|
||||
start_id = end_id
|
||||
|
||||
loss += lambda_max_loss * torch.mean(torch.stack(active_max))
|
||||
else:
|
||||
raise ValueError("Miss attention mask!")
|
||||
else:
|
||||
loss = None
|
||||
|
||||
return logits, e_logits, loss, type_loss
|
||||
|
||||
def get_entity_hidden(
|
||||
self, hidden: torch.Tensor, e_mask: torch.Tensor, entity_mode: str
|
||||
):
|
||||
B, M, T = e_mask.shape
|
||||
e_out = hidden.unsqueeze(1).expand(B, M, T, -1) * e_mask.unsqueeze(
|
||||
-1
|
||||
) # batch_size x max_entity_num x seq_len x hidden_size
|
||||
if entity_mode == "mean":
|
||||
return e_out.sum(2) / (
|
||||
e_mask.sum(-1).unsqueeze(-1) + 1e-30
|
||||
) # batch_size x max_entity_num x hidden_size
|
||||
|
||||
def get_types_embedding(self, e_type_ids: torch.Tensor, entity_types):
|
||||
return entity_types.get_types_embedding(e_type_ids)
|
||||
|
||||
def calc_loss(self, loss_fn, preds, target, mask=None):
|
||||
target = target.reshape(-1)
|
||||
preds += 1e-10
|
||||
preds = preds.reshape(-1, preds.shape[-1])
|
||||
ce_loss = loss_fn(preds, target.long(), reduction="none")
|
||||
if mask is not None:
|
||||
mask = mask.reshape(-1)
|
||||
ce_loss = ce_loss * mask
|
||||
return ce_loss.sum() / (mask.sum() + 1e-10)
|
||||
return ce_loss.sum() / (target.sum() + 1e-10)
|
||||
|
||||
|
||||
class ViterbiDecoder(object):
|
||||
def __init__(
|
||||
self,
|
||||
id2label,
|
||||
transition_matrix,
|
||||
ignore_token_label_id=torch.nn.CrossEntropyLoss().ignore_index,
|
||||
):
|
||||
self.id2label = id2label
|
||||
self.n_labels = len(id2label)
|
||||
self.transitions = transition_matrix
|
||||
self.ignore_token_label_id = ignore_token_label_id
|
||||
|
||||
def forward(self, logprobs, attention_mask, label_ids):
|
||||
# probs: batch_size x max_seq_len x n_labels
|
||||
batch_size, max_seq_len, n_labels = logprobs.size()
|
||||
attention_mask = attention_mask[:, :max_seq_len]
|
||||
label_ids = label_ids[:, :max_seq_len]
|
||||
|
||||
active_tokens = (attention_mask == 1) & (
|
||||
label_ids != self.ignore_token_label_id
|
||||
)
|
||||
if n_labels != self.n_labels:
|
||||
raise ValueError("Labels do not match!")
|
||||
|
||||
label_seqs = []
|
||||
|
||||
for idx in range(batch_size):
|
||||
logprob_i = logprobs[idx, :, :][
|
||||
active_tokens[idx]
|
||||
] # seq_len(active) x n_labels
|
||||
|
||||
back_pointers = []
|
||||
|
||||
forward_var = logprob_i[0] # n_labels
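# Viterbi recursion: forward_var[k] stores the best log-score of any label path
# that ends in label k at the current position.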
|
||||
|
||||
for j in range(1, len(logprob_i)): # for tag_feat in feat:
|
||||
next_label_var = forward_var + self.transitions # n_labels x n_labels
|
||||
viterbivars_t, bptrs_t = torch.max(next_label_var, dim=1) # n_labels
|
||||
|
||||
logp_j = logprob_i[j] # n_labels
|
||||
forward_var = viterbivars_t + logp_j # n_labels
|
||||
bptrs_t = bptrs_t.cpu().numpy().tolist()
|
||||
back_pointers.append(bptrs_t)
|
||||
|
||||
path_score, best_label_id = torch.max(forward_var, dim=-1)
|
||||
best_label_id = best_label_id.item()
|
||||
best_path = [best_label_id]
|
||||
|
||||
for bptrs_t in reversed(back_pointers):
|
||||
best_label_id = bptrs_t[best_label_id]
|
||||
best_path.append(best_label_id)
|
||||
|
||||
if len(best_path) != len(logprob_i):
|
||||
raise ValueError("Number of labels doesn't match!")
|
||||
|
||||
best_path.reverse()
|
||||
label_seqs.append(best_path)
|
||||
|
||||
return label_seqs
|
|
@ -0,0 +1,830 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import bisect
|
||||
import json
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from transformers import BertModel, BertTokenizer
|
||||
from utils import load_file
|
||||
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
|
||||
class EntityTypes(object):
|
||||
def __init__(self, types_path: str, negative_types_number: int, negative_mode: str):
|
||||
self.types = {}
|
||||
self.types_map = {}
|
||||
self.O_id = 0
|
||||
self.types_embedding = None
|
||||
self.negative_mode = negative_mode
|
||||
self.load_entity_types(types_path)
|
||||
|
||||
def load_entity_types(self, types_path: str):
|
||||
self.types = load_file(types_path, "json")
|
||||
types_list = sorted([jj for ii in self.types.values() for jj in ii])
|
||||
self.types_map = {jj: ii for ii, jj in enumerate(types_list)}
|
||||
self.O_id = self.types_map["O"]
|
||||
self.types_list = types_list
|
||||
logger.info("Load %d entity types from %s.", len(types_list), types_path)
|
||||
|
||||
def build_types_embedding(
|
||||
self,
|
||||
model: str,
|
||||
do_lower_case: bool,
|
||||
device,
|
||||
types_mode: str = "cls",
|
||||
init_type_embedding_from_bert: bool = False,
|
||||
):
|
||||
types_list = sorted([jj for ii in self.types.values() for jj in ii])
|
||||
if init_type_embedding_from_bert:
|
||||
tokenizer = BertTokenizer.from_pretrained(
|
||||
model, do_lower_case=do_lower_case
|
||||
)
|
||||
|
||||
tokens = [
|
||||
[tokenizer.cls_token_id]
|
||||
+ tokenizer.convert_tokens_to_ids(tokenizer.tokenize(ii))
|
||||
+ [tokenizer.sep_token_id]
|
||||
for ii in types_list
|
||||
]
|
||||
token_max = max([len(ii) for ii in tokens])
|
||||
mask = [[1] * len(ii) + [0] * (token_max - len(ii)) for ii in tokens]
|
||||
ids = [
|
||||
ii + [tokenizer.pad_token_id] * (token_max - len(ii)) for ii in tokens
|
||||
]
|
||||
mask = torch.tensor(np.array(mask), dtype=torch.long).to(device)
|
||||
ids = torch.tensor(np.array(ids), dtype=torch.long).to(
|
||||
device
|
||||
) # len(type_list) x token_max
|
||||
|
||||
model = BertModel.from_pretrained(model).to(device)
|
||||
outs = model(ids, mask)
|
||||
else:
|
||||
outs = [0, torch.rand((len(types_list), 768)).to(device)]
|
||||
self.types_embedding = nn.Embedding(*outs[1].shape).to(device)
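# One embedding row per entity type. When types_mode == "cls" the rows are set from
# outs[1] (BERT's pooled output over the tokenized type name if
# init_type_embedding_from_bert, otherwise random vectors).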
|
||||
if types_mode.lower() == "cls":
|
||||
self.types_embedding.weight = nn.Parameter(outs[1])
|
||||
self.types_embedding.requires_grad = False
|
||||
logger.info("Built the types embedding.")
|
||||
|
||||
def generate_negative_types(
|
||||
self, labels: list, types: list, negative_types_number: int
|
||||
):
|
||||
N = len(labels)
|
||||
data = np.zeros((N, 1 + negative_types_number), int)
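# data[i, 0] holds the gold type id of entity i; data[i, 1:] hold the sampled negative type ids.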
|
||||
if self.negative_mode == "batch":
|
||||
batch_labels = set(types)
|
||||
if negative_types_number > len(batch_labels):
|
||||
other = list(
|
||||
set(range(len(self.types_map))) - batch_labels - set([self.O_id])
|
||||
)
|
||||
other_size = negative_types_number - len(batch_labels)
|
||||
else:
|
||||
other, other_size = [], 0
|
||||
b_size = min(len(batch_labels) - 1, negative_types_number)
|
||||
o_set = [self.O_id] if negative_types_number > len(batch_labels) - 1 else []
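# Negatives are drawn first from the other gold types of the episode; the "O" type and
# random outside types are added only when the episode does not provide enough candidates.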
|
||||
for idx, l in enumerate(labels):
|
||||
data[idx][0] = l
|
||||
data[idx][1:] = np.concatenate(
|
||||
[
|
||||
np.random.choice(list(batch_labels - set([l])), b_size, False),
|
||||
o_set,
|
||||
np.random.choice(other, other_size, False),
|
||||
]
|
||||
)
|
||||
return data
|
||||
|
||||
def get_types_embedding(self, labels: torch.Tensor):
|
||||
return self.types_embedding(labels)
|
||||
|
||||
def update_type_embedding(self, e_out, e_type_ids, e_type_mask):
|
||||
labels = e_type_ids[:, :, 0][e_type_mask[:, :, 0] == 1]
|
||||
hiddens = e_out[e_type_mask[:, :, 0] == 1]
|
||||
label_set = set(labels.detach().cpu().numpy())
|
||||
for ii in label_set:
|
||||
self.types_embedding.weight.data[ii] = hiddens[labels == ii].mean(0)
|
||||
|
||||
|
||||
class InputExample(object):
|
||||
def __init__(
|
||||
self, guid: str, words: list, labels: list, types: list, entities: list
|
||||
):
|
||||
self.guid = guid
|
||||
self.words = words
|
||||
self.labels = labels
|
||||
self.types = types
|
||||
self.entities = entities
|
||||
|
||||
|
||||
class InputFeatures(object):
|
||||
def __init__(
|
||||
self,
|
||||
input_ids,
|
||||
input_mask,
|
||||
segment_ids,
|
||||
label_ids,
|
||||
e_mask,
|
||||
e_type_ids,
|
||||
e_type_mask,
|
||||
types,
|
||||
entities,
|
||||
):
|
||||
self.input_ids = input_ids
|
||||
self.input_mask = input_mask
|
||||
self.segment_ids = segment_ids
|
||||
self.label_ids = label_ids
|
||||
self.e_mask = e_mask
|
||||
self.e_type_ids = e_type_ids
|
||||
self.e_type_mask = e_type_mask
|
||||
self.types = types
|
||||
self.entities = entities
|
||||
|
||||
|
||||
class Corpus(object):
|
||||
def __init__(
|
||||
self,
|
||||
logger,
|
||||
data_fn,
|
||||
bert_model,
|
||||
max_seq_length,
|
||||
label_list,
|
||||
entity_types: EntityTypes,
|
||||
do_lower_case=True,
|
||||
shuffle=True,
|
||||
tagging="BIO",
|
||||
viterbi="none",
|
||||
device="cuda",
|
||||
concat_types: str = "None",
|
||||
dataset: str = "FewNERD",
|
||||
negative_types_number: int = -1,
|
||||
):
|
||||
self.logger = logger
|
||||
self.tokenizer = BertTokenizer.from_pretrained(
|
||||
bert_model, do_lower_case=do_lower_case
|
||||
)
|
||||
self.max_seq_length = max_seq_length
|
||||
self.entity_types = entity_types
|
||||
self.label_list = label_list
|
||||
self.label_map = {label: i for i, label in enumerate(label_list)}
|
||||
self.id2label = {i: label for i, label in enumerate(label_list)}
|
||||
self.n_labels = len(label_list)
|
||||
self.tagging_scheme = tagging
|
||||
self.max_len_dict = {"entity": 0, "type": 0, "sentence": 0}
|
||||
self.max_entities_length = 50
|
||||
self.viterbi = viterbi
|
||||
self.dataset = dataset
|
||||
self.negative_types_number = negative_types_number
|
||||
|
||||
logger.info(
|
||||
"Construct the transition matrix via [{}] scheme...".format(viterbi)
|
||||
)
|
||||
# M[ij]: p(j->i)
|
||||
if viterbi == "none":
|
||||
self.transition_matrix = None
|
||||
update_transition_matrix = False
|
||||
elif viterbi == "hard":
|
||||
self.transition_matrix = torch.zeros(
|
||||
[self.n_labels, self.n_labels], device=device
|
||||
) # pij: p(j -> i)
|
||||
if self.n_labels == 3:
|
||||
self.transition_matrix[2][0] = -10000 # p(O -> I) = 0
|
||||
elif self.n_labels == 5:
|
||||
for (i, j) in [
|
||||
(2, 0),
|
||||
(3, 0),
|
||||
(0, 1),
|
||||
(1, 1),
|
||||
(4, 1),
|
||||
(0, 2),
|
||||
(1, 2),
|
||||
(4, 2),
|
||||
(2, 3),
|
||||
(3, 3),
|
||||
(2, 4),
|
||||
(3, 4),
|
||||
]:
|
||||
self.transition_matrix[i][j] = -10000
|
||||
else:
|
||||
raise ValueError()
|
||||
update_transition_matrix = False
|
||||
elif viterbi == "soft":
|
||||
self.transition_matrix = (
|
||||
torch.zeros([self.n_labels, self.n_labels], device=device) + 1e-8
|
||||
)
|
||||
update_transition_matrix = True
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
self.tasks = self.read_tasks_from_file(
|
||||
data_fn, update_transition_matrix, concat_types, dataset
|
||||
)
|
||||
|
||||
self.n_total = len(self.tasks)
|
||||
self.batch_start_idx = 0
|
||||
self.batch_idxs = (
|
||||
np.random.permutation(self.n_total)
|
||||
if shuffle
|
||||
else np.array([i for i in range(self.n_total)])
|
||||
) # for batch sampling in training
|
||||
|
||||
def read_tasks_from_file(
|
||||
self,
|
||||
data_fn,
|
||||
update_transition_matrix=False,
|
||||
concat_types: str = "None",
|
||||
dataset: str = "FewNERD",
|
||||
):
|
||||
"""
|
||||
return: List[task]
|
||||
task['support'] = List[InputExample]
|
||||
task['query'] = List[InputExample]
|
||||
"""
|
||||
|
||||
self.logger.info("Reading tasks from {}...".format(data_fn))
|
||||
self.logger.info(
|
||||
" update_transition_matrix = {}".format(update_transition_matrix)
|
||||
)
|
||||
self.logger.info(" concat_types = {}".format(concat_types))
|
||||
|
||||
with open(data_fn, "r", encoding="utf-8") as json_file:
|
||||
json_list = list(json_file)
|
||||
output_tasks = []
|
||||
all_labels = [] if update_transition_matrix else None
|
||||
if dataset == "Domain":
|
||||
json_list = self._convert_Domain2FewNERD(json_list)
|
||||
|
||||
for task_id, json_str in enumerate(json_list):
|
||||
if task_id % 1000 == 0:
|
||||
self.logger.info("Reading tasks %d of %d", task_id, len(json_list))
|
||||
task = json.loads(json_str) if dataset != "Domain" else json_str
|
||||
|
||||
support = task["support"]
|
||||
types = task["types"]
|
||||
if self.negative_types_number == -1:
|
||||
self.negative_types_number = len(types) - 1
|
||||
tmp_support, entities, tmp_support_tokens, tmp_query_tokens = [], [], [], []
|
||||
self.max_len_dict["type"] = max(self.max_len_dict["type"], len(types))
|
||||
|
||||
if concat_types != "None":
|
||||
types = set()
|
||||
for l_list in support["label"]:
|
||||
types.update(l_list)
|
||||
types.remove("O")
|
||||
tokenized_types = self.__tokenize_types__(types, concat_types)
|
||||
else:
|
||||
tokenized_types = None
|
||||
|
||||
for i, (words, labels) in enumerate(zip(support["word"], support["label"])):
|
||||
entities = self._convert_label_to_entities_(labels)
|
||||
self.max_len_dict["entity"] = max(
|
||||
len(entities), self.max_len_dict["entity"]
|
||||
)
|
||||
if self.tagging_scheme == "BIOES":
|
||||
labels = self._convert_label_to_BIOES_(labels)
|
||||
elif self.tagging_scheme == "BIO":
|
||||
labels = self._convert_label_to_BIO_(labels)
|
||||
elif self.tagging_scheme == "IO":
|
||||
labels = self._convert_label_to_IO_(labels)
|
||||
else:
|
||||
raise ValueError("Invalid tagging scheme!")
|
||||
guid = "task[%s]-%s" % (task_id, i)
|
||||
feature, token_sum = self._convert_example_to_feature_(
|
||||
InputExample(
|
||||
guid=guid,
|
||||
words=words,
|
||||
labels=labels,
|
||||
types=types,
|
||||
entities=entities,
|
||||
),
|
||||
tokenized_types=tokenized_types,
|
||||
concat_types=concat_types,
|
||||
)
|
||||
tmp_support.append(feature)
|
||||
tmp_support_tokens.append(token_sum)
|
||||
if update_transition_matrix:
|
||||
all_labels.append(labels)
|
||||
|
||||
query = task["query"]
|
||||
tmp_query = []
|
||||
for i, (words, labels) in enumerate(zip(query["word"], query["label"])):
|
||||
entities = self._convert_label_to_entities_(labels)
|
||||
self.max_len_dict["entity"] = max(
|
||||
len(entities), self.max_len_dict["entity"]
|
||||
)
|
||||
if self.tagging_scheme == "BIOES":
|
||||
labels = self._convert_label_to_BIOES_(labels)
|
||||
elif self.tagging_scheme == "BIO":
|
||||
labels = self._convert_label_to_BIO_(labels)
|
||||
elif self.tagging_scheme == "IO":
|
||||
labels = self._convert_label_to_IO_(labels)
|
||||
else:
|
||||
raise ValueError("Invalid tagging scheme!")
|
||||
guid = "task[%s]-%s" % (task_id, i)
|
||||
feature, token_sum = self._convert_example_to_feature_(
|
||||
InputExample(
|
||||
guid=guid,
|
||||
words=words,
|
||||
labels=labels,
|
||||
types=types,
|
||||
entities=entities,
|
||||
),
|
||||
tokenized_types=tokenized_types,
|
||||
concat_types=concat_types,
|
||||
)
|
||||
tmp_query.append(feature)
|
||||
tmp_query_tokens.append(token_sum)
|
||||
if update_transition_matrix:
|
||||
all_labels.append(labels)
|
||||
|
||||
output_tasks.append(
|
||||
{
|
||||
"support": tmp_support,
|
||||
"query": tmp_query,
|
||||
"support_token": tmp_support_tokens,
|
||||
"query_token": tmp_query_tokens,
|
||||
}
|
||||
)
|
||||
self.logger.info(
|
||||
"%s Max Entities Lengths: %d, Max batch Types Number: %d, Max sentence Length: %d",
|
||||
data_fn,
|
||||
self.max_len_dict["entity"],
|
||||
self.max_len_dict["type"],
|
||||
self.max_len_dict["sentence"],
|
||||
)
|
||||
if update_transition_matrix:
|
||||
self._count_transition_matrix_(all_labels)
|
||||
return output_tasks
|
||||
|
||||
def _convert_Domain2FewNERD(self, data: list):
|
||||
def decode_batch(batch: dict):
|
||||
word = batch["seq_ins"]
|
||||
label = [
|
||||
[jj.replace("B-", "").replace("I-", "") for jj in ii]
|
||||
for ii in batch["seq_outs"]
|
||||
]
|
||||
return {"word": word, "label": label}
|
||||
|
||||
data = json.loads(data[0])
|
||||
res = []
|
||||
for domain in data.keys():
|
||||
d = data[domain]
|
||||
labels = self.entity_types.types[domain]
|
||||
res.extend(
|
||||
[
|
||||
{
|
||||
"support": decode_batch(ii["support"]),
|
||||
"query": decode_batch(ii["batch"]),
|
||||
"types": labels,
|
||||
}
|
||||
for ii in d
|
||||
]
|
||||
)
|
||||
return res
|
||||
|
||||
def __tokenize_types__(self, types, concat_types: str = "past"):
|
||||
tokens = []
|
||||
for t in types:
|
||||
if "embedding" in concat_types:
|
||||
t_tokens = [f"[unused{self.entity_types.types_map[t]}]"]
|
||||
else:
|
||||
t_tokens = self.tokenizer.tokenize(t)
|
||||
if len(t_tokens) == 0:
|
||||
continue
|
||||
tokens.extend(t_tokens)
|
||||
tokens.append(",") # separate different types with a comma ','.
|
||||
tokens.pop() # pop the last comma
|
||||
return tokens
|
||||
|
||||
def _count_transition_matrix_(self, labels):
|
||||
self.logger.info("Computing transition matrix...")
|
||||
for sent_labels in labels:
|
||||
for i in range(len(sent_labels) - 1):
|
||||
start = self.label_map[sent_labels[i]]
|
||||
end = self.label_map[sent_labels[i + 1]]
|
||||
self.transition_matrix[end][start] += 1
|
||||
self.transition_matrix /= torch.sum(self.transition_matrix, dim=0)
|
||||
self.transition_matrix = torch.log(self.transition_matrix)
|
||||
self.logger.info("Done.")
|
||||
|
||||
def _convert_label_to_entities_(self, label_list: list):
|
||||
N = len(label_list)
|
||||
S = [
|
||||
ii
|
||||
for ii in range(N)
|
||||
if label_list[ii] != "O"
|
||||
and (not ii or label_list[ii] != label_list[ii - 1])
|
||||
]
|
||||
E = [
|
||||
ii
|
||||
for ii in range(N)
|
||||
if label_list[ii] != "O"
|
||||
and (ii == N - 1 or label_list[ii] != label_list[ii + 1])
|
||||
]
|
||||
return [(s, e, label_list[s]) for s, e in zip(S, E)]
|
||||
|
||||
def _convert_label_to_BIOES_(self, label_list):
|
||||
res = []
|
||||
label_list = ["O"] + label_list + ["O"]
|
||||
for i in range(1, len(label_list) - 1):
|
||||
if label_list[i] == "O":
|
||||
res.append("O")
|
||||
continue
|
||||
# for S
|
||||
if (
|
||||
label_list[i] != label_list[i - 1]
|
||||
and label_list[i] != label_list[i + 1]
|
||||
):
|
||||
res.append("S")
|
||||
elif (
|
||||
label_list[i] != label_list[i - 1]
|
||||
and label_list[i] == label_list[i + 1]
|
||||
):
|
||||
res.append("B")
|
||||
elif (
|
||||
label_list[i] == label_list[i - 1]
|
||||
and label_list[i] != label_list[i + 1]
|
||||
):
|
||||
res.append("E")
|
||||
elif (
|
||||
label_list[i] == label_list[i - 1]
|
||||
and label_list[i] == label_list[i + 1]
|
||||
):
|
||||
res.append("I")
|
||||
else:
|
||||
raise ValueError("Some bugs exist in your code!")
|
||||
return res
|
||||
|
||||
def _convert_label_to_BIO_(self, label_list):
|
||||
precursor = ""
|
||||
label_output = []
|
||||
for label in label_list:
|
||||
if label == "O":
|
||||
label_output.append("O")
|
||||
elif label != precursor:
|
||||
label_output.append("B")
|
||||
else:
|
||||
label_output.append("I")
|
||||
precursor = label
|
||||
|
||||
return label_output
|
||||
|
||||
def _convert_label_to_IO_(self, label_list):
|
||||
label_output = []
|
||||
for label in label_list:
|
||||
if label == "O":
|
||||
label_output.append("O")
|
||||
else:
|
||||
label_output.append("I")
|
||||
|
||||
return label_output
|
||||
|
||||
def _convert_example_to_feature_(
|
||||
self,
|
||||
example,
|
||||
cls_token_at_end=False,
|
||||
cls_token="[CLS]",
|
||||
cls_token_segment_id=0,
|
||||
sep_token="[SEP]",
|
||||
sep_token_extra=False,
|
||||
pad_on_left=False,
|
||||
pad_token=0,
|
||||
pad_token_segment_id=0,
|
||||
pad_token_label_id=-1,
|
||||
sequence_a_segment_id=0,
|
||||
mask_padding_with_zero=True,
|
||||
ignore_token_label_id=torch.nn.CrossEntropyLoss().ignore_index,
|
||||
sequence_b_segment_id=1,
|
||||
tokenized_types=None,
|
||||
concat_types: str = "None",
|
||||
):
|
||||
"""
|
||||
`cls_token_at_end` define the location of the CLS token:
|
||||
- False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
|
||||
- True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
|
||||
`cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet)
|
||||
"""
|
||||
tokens, label_ids, token_sum = [], [], [1]
|
||||
if tokenized_types is None:
|
||||
tokenized_types = []
|
||||
if "before" in concat_types:
|
||||
token_sum[-1] += 1 + len(tokenized_types)
|
||||
|
||||
for words, labels in zip(example.words, example.labels):
|
||||
word_tokens = self.tokenizer.tokenize(words)
|
||||
token_sum.append(token_sum[-1] + len(word_tokens))
|
||||
if len(word_tokens) == 0:
|
||||
continue
|
||||
tokens.extend(word_tokens)
|
||||
# Use the real label id for the first token of the word, and padding ids for the remaining tokens
|
||||
label_ids.extend(
|
||||
[self.label_map[labels]]
|
||||
+ [ignore_token_label_id] * (len(word_tokens) - 1)
|
||||
)
|
||||
self.max_len_dict["sentence"] = max(self.max_len_dict["sentence"], len(tokens))
|
||||
e_ids = [(token_sum[s], token_sum[e + 1] - 1) for s, e, _ in example.entities]
|
||||
e_mask = np.zeros((self.max_entities_length, self.max_seq_length), np.int8)
|
||||
e_type_mask = np.zeros(
|
||||
(self.max_entities_length, 1 + self.negative_types_number), np.int8
|
||||
)
|
||||
e_type_mask[: len(e_ids), :] = np.ones(
|
||||
(len(e_ids), 1 + self.negative_types_number), np.int8
|
||||
)
|
||||
for idx, (s, e) in enumerate(e_ids):
|
||||
e_mask[idx][s : e + 1] = 1
|
||||
e_type_ids = [self.entity_types.types_map[t] for _, _, t in example.entities]
|
||||
entities = [(s, e, t) for (s, e), t in zip(e_ids, e_type_ids)]
|
||||
batch_types = [self.entity_types.types_map[ii] for ii in example.types]
|
||||
# e_type_ids[i, 0] is the positive label, while e_type_ids[i, 1:] are negative labels
|
||||
e_type_ids = self.entity_types.generate_negative_types(
|
||||
e_type_ids, batch_types, self.negative_types_number
|
||||
)
|
||||
if len(e_type_ids) < self.max_entities_length:
|
||||
e_type_ids = np.concatenate(
|
||||
[
|
||||
e_type_ids,
|
||||
[[0] * (1 + self.negative_types_number)]
|
||||
* (self.max_entities_length - len(e_type_ids)),
|
||||
]
|
||||
)
|
||||
|
||||
# Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa.
|
||||
special_tokens_count = 3 if sep_token_extra else 2
|
||||
if len(tokens) > self.max_seq_length - special_tokens_count - len(
|
||||
tokenized_types
|
||||
):
|
||||
tokens = tokens[
|
||||
: (self.max_seq_length - special_tokens_count - len(tokenized_types))
|
||||
]
|
||||
label_ids = label_ids[
|
||||
: (self.max_seq_length - special_tokens_count - len(tokenized_types))
|
||||
]
|
||||
|
||||
types = [self.entity_types.types_map[t] for t in example.types]
|
||||
|
||||
if "before" in concat_types:
|
||||
# OPTION 1: Concatenated tokenized types at START
|
||||
len_sentence = len(tokens)
|
||||
tokens = [cls_token] + tokenized_types + [sep_token] + tokens
|
||||
label_ids = [ignore_token_label_id] * (len(tokenized_types) + 2) + label_ids
|
||||
segment_ids = (
|
||||
[cls_token_segment_id]
|
||||
+ [sequence_a_segment_id] * (len(tokenized_types) + 1)
|
||||
+ [sequence_b_segment_id] * len_sentence
|
||||
)
|
||||
else:
|
||||
# OPTION 2: Concatenated tokenized types at END
|
||||
tokens += [sep_token]
|
||||
label_ids += [ignore_token_label_id]
|
||||
if sep_token_extra:
|
||||
raise ValueError("Unexpected path!")
|
||||
# roberta uses an extra separator b/w pairs of sentences
|
||||
tokens += [sep_token]
|
||||
label_ids += [ignore_token_label_id]
|
||||
segment_ids = [sequence_a_segment_id] * len(tokens)
|
||||
|
||||
if cls_token_at_end:
|
||||
raise ValueError("Unexpected path!")
|
||||
tokens += [cls_token]
|
||||
label_ids += [ignore_token_label_id]
|
||||
segment_ids += [cls_token_segment_id]
|
||||
else:
|
||||
tokens = [cls_token] + tokens
|
||||
label_ids = [ignore_token_label_id] + label_ids
|
||||
segment_ids = [cls_token_segment_id] + segment_ids
|
||||
|
||||
if "past" in concat_types:
|
||||
tokens += tokenized_types
|
||||
label_ids += [ignore_token_label_id] * len(tokenized_types)
|
||||
segment_ids += [sequence_b_segment_id] * len(tokenized_types)
|
||||
|
||||
input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
|
||||
|
||||
# The mask has 1 for real tokens and 0 for padding tokens. Only real tokens are attended to.
|
||||
input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
|
||||
|
||||
# Zero-pad up to the sequence length.
|
||||
padding_length = self.max_seq_length - len(input_ids)
|
||||
if pad_on_left:
|
||||
input_ids = ([pad_token] * padding_length) + input_ids
|
||||
input_mask = (
|
||||
[0 if mask_padding_with_zero else 1] * padding_length
|
||||
) + input_mask
|
||||
segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
|
||||
label_ids = ([pad_token_label_id] * padding_length) + label_ids
|
||||
else:
|
||||
input_ids += [pad_token] * padding_length
|
||||
input_mask += [0 if mask_padding_with_zero else 1] * padding_length
|
||||
segment_ids += [pad_token_segment_id] * padding_length
|
||||
label_ids += [pad_token_label_id] * padding_length
|
||||
|
||||
assert len(input_ids) == self.max_seq_length
|
||||
assert len(input_mask) == self.max_seq_length
|
||||
assert len(segment_ids) == self.max_seq_length
|
||||
assert len(label_ids) == self.max_seq_length
|
||||
|
||||
return (
|
||||
InputFeatures(
|
||||
input_ids=input_ids,
|
||||
input_mask=input_mask,
|
||||
segment_ids=segment_ids,
|
||||
label_ids=label_ids,
|
||||
e_mask=e_mask, # max_entities_length x 128
|
||||
e_type_ids=e_type_ids, # max_entities_length x 5 (n_types)
|
||||
e_type_mask=np.array(e_type_mask), # max_entities_length x 5 (n_types)
|
||||
types=np.array(types),
|
||||
entities=entities,
|
||||
),
|
||||
token_sum,
|
||||
)
|
||||
|
||||
def reset_batch_info(self, shuffle=False):
|
||||
self.batch_start_idx = 0
|
||||
self.batch_idxs = (
|
||||
np.random.permutation(self.n_total)
|
||||
if shuffle
|
||||
else np.array([i for i in range(self.n_total)])
|
||||
) # for batch sampling in training
|
||||
|
||||
def get_batch_meta(self, batch_size, device="cuda", shuffle=True):
|
||||
if self.batch_start_idx + batch_size > self.n_total:
|
||||
self.reset_batch_info(shuffle=shuffle)
|
||||
|
||||
query_batch = []
|
||||
support_batch = []
|
||||
start_id = self.batch_start_idx
|
||||
|
||||
for i in range(start_id, start_id + batch_size):
|
||||
idx = self.batch_idxs[i]
|
||||
task_curr = self.tasks[idx]
|
||||
|
||||
query_item = {
|
||||
"input_ids": torch.tensor(
|
||||
[f.input_ids for f in task_curr["query"]], dtype=torch.long
|
||||
).to(
|
||||
device
|
||||
), # 1 x max_seq_len
|
||||
"input_mask": torch.tensor(
|
||||
[f.input_mask for f in task_curr["query"]], dtype=torch.long
|
||||
).to(device),
|
||||
"segment_ids": torch.tensor(
|
||||
[f.segment_ids for f in task_curr["query"]], dtype=torch.long
|
||||
).to(device),
|
||||
"label_ids": torch.tensor(
|
||||
[f.label_ids for f in task_curr["query"]], dtype=torch.long
|
||||
).to(device),
|
||||
"e_mask": torch.tensor(
|
||||
[f.e_mask for f in task_curr["query"]], dtype=torch.int
|
||||
).to(device),
|
||||
"e_type_ids": torch.tensor(
|
||||
[f.e_type_ids for f in task_curr["query"]], dtype=torch.long
|
||||
).to(device),
|
||||
"e_type_mask": torch.tensor(
|
||||
[f.e_type_mask for f in task_curr["query"]], dtype=torch.int
|
||||
).to(device),
|
||||
"types": [f.types for f in task_curr["query"]],
|
||||
"entities": [f.entities for f in task_curr["query"]],
|
||||
"idx": idx,
|
||||
}
|
||||
query_batch.append(query_item)
|
||||
|
||||
support_item = {
|
||||
"input_ids": torch.tensor(
|
||||
[f.input_ids for f in task_curr["support"]], dtype=torch.long
|
||||
).to(device),
|
||||
# 1 x max_seq_len
|
||||
"input_mask": torch.tensor(
|
||||
[f.input_mask for f in task_curr["support"]], dtype=torch.long
|
||||
).to(device),
|
||||
"segment_ids": torch.tensor(
|
||||
[f.segment_ids for f in task_curr["support"]], dtype=torch.long
|
||||
).to(device),
|
||||
"label_ids": torch.tensor(
|
||||
[f.label_ids for f in task_curr["support"]], dtype=torch.long
|
||||
).to(device),
|
||||
"e_mask": torch.tensor(
|
||||
[f.e_mask for f in task_curr["support"]], dtype=torch.int
|
||||
).to(device),
|
||||
"e_type_ids": torch.tensor(
|
||||
[f.e_type_ids for f in task_curr["support"]], dtype=torch.long
|
||||
).to(device),
|
||||
"e_type_mask": torch.tensor(
|
||||
[f.e_type_mask for f in task_curr["support"]], dtype=torch.int
|
||||
).to(device),
|
||||
"types": [f.types for f in task_curr["support"]],
|
||||
"entities": [f.entities for f in task_curr["support"]],
|
||||
"idx": idx,
|
||||
}
|
||||
support_batch.append(support_item)
|
||||
|
||||
self.batch_start_idx += batch_size
|
||||
|
||||
return query_batch, support_batch
|
||||
|
||||
def get_batch_NOmeta(self, batch_size, device="cuda", shuffle=True):
|
||||
if self.batch_start_idx + batch_size >= self.n_total:
|
||||
self.reset_batch_info(shuffle=shuffle)
|
||||
if self.mask_rate >= 0:
|
||||
self.query_features = self.build_query_features_with_mask(
|
||||
mask_rate=self.mask_rate
|
||||
)
|
||||
|
||||
idxs = self.batch_idxs[self.batch_start_idx : self.batch_start_idx + batch_size]
|
||||
batch_features = [self.query_features[idx] for idx in idxs]
|
||||
|
||||
batch = {
|
||||
"input_ids": torch.tensor(
|
||||
[f.input_ids for f in batch_features], dtype=torch.long
|
||||
).to(device),
|
||||
"input_mask": torch.tensor(
|
||||
[f.input_mask for f in batch_features], dtype=torch.long
|
||||
).to(device),
|
||||
"segment_ids": torch.tensor(
|
||||
[f.segment_ids for f in batch_features], dtype=torch.long
|
||||
).to(device),
|
||||
"label_ids": torch.tensor(
|
||||
[f.label_id for f in batch_features], dtype=torch.long
|
||||
).to(device),
|
||||
"e_mask": torch.tensor(
|
||||
[f.e_mask for f in batch_features], dtype=torch.int
|
||||
).to(device),
|
||||
"e_type_ids": torch.tensor(
|
||||
[f.e_type_ids for f in batch_features], dtype=torch.long
|
||||
).to(device),
|
||||
"e_type_mask": torch.tensor(
|
||||
[f.e_type_mask for f in batch_features], dtype=torch.int
|
||||
).to(device),
|
||||
"types": [f.types for f in batch_features],
|
||||
"entities": [f.entities for f in batch_features],
|
||||
}
|
||||
|
||||
self.batch_start_idx += batch_size
|
||||
|
||||
return batch
|
||||
|
||||
def get_batches(self, batch_size, device="cuda", shuffle=False):
|
||||
batches = []
|
||||
|
||||
if shuffle:
|
||||
idxs = np.random.permutation(self.n_total)
|
||||
features = [self.query_features[i] for i in idxs]
|
||||
else:
|
||||
features = self.query_features
|
||||
|
||||
for i in range(0, self.n_total, batch_size):
|
||||
batch_features = features[i : min(self.n_total, i + batch_size)]
|
||||
|
||||
batch = {
|
||||
"input_ids": torch.tensor(
|
||||
[f.input_ids for f in batch_features], dtype=torch.long
|
||||
).to(device),
|
||||
"input_mask": torch.tensor(
|
||||
[f.input_mask for f in batch_features], dtype=torch.long
|
||||
).to(device),
|
||||
"segment_ids": torch.tensor(
|
||||
[f.segment_ids for f in batch_features], dtype=torch.long
|
||||
).to(device),
|
||||
"label_ids": torch.tensor(
|
||||
[f.label_id for f in batch_features], dtype=torch.long
|
||||
).to(device),
|
||||
"e_mask": torch.tensor(
|
||||
[f.e_mask for f in batch_features], dtype=torch.int
|
||||
).to(device),
|
||||
"e_type_ids": torch.tensor(
|
||||
[f.e_type_ids for f in batch_features], dtype=torch.long
|
||||
).to(device),
|
||||
"e_type_mask": torch.tensor(
|
||||
[f.e_type_mask for f in batch_features], dtype=torch.int
|
||||
).to(device),
|
||||
"types": [f.types for f in batch_features],
|
||||
"entities": [f.entities for f in batch_features],
|
||||
}
|
||||
|
||||
batches.append(batch)
|
||||
|
||||
return batches
|
||||
|
||||
def _decoder_bpe_index(self, sentences_spans: list):
|
||||
res = []
|
||||
tokens = [jj for ii in self.tasks for jj in ii["query_token"]]
|
||||
assert len(tokens) == len(
|
||||
sentences_spans
|
||||
), f"tokens size: {len(tokens)}, sentences size: {len(sentences_spans)}"
|
||||
for sentence_idx, spans in enumerate(sentences_spans):
|
||||
token = tokens[sentence_idx]
|
||||
tmp = []
|
||||
for b, e in spans:
|
||||
nb = bisect.bisect_left(token, b)
|
||||
ne = bisect.bisect_left(token, e)
|
||||
tmp.append((nb, ne))
|
||||
res.append(tmp)
|
||||
return res
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
|
@ -0,0 +1,3 @@
|
|||
numpy
|
||||
transformers==4.10.0
|
||||
torch==1.9.0
|
|
@ -0,0 +1,71 @@
|
|||
rem Copyright (c) Microsoft Corporation.
|
||||
rem Licensed under the MIT license.
|
||||
|
||||
set SEEDS=171 354 550 667 985
|
||||
set N=5
|
||||
set K=1
|
||||
set mode=inter
|
||||
|
||||
for %%e in (%SEEDS%) do (
|
||||
python3 main.py ^
|
||||
--gpu_device=1 ^
|
||||
--seed=%%e ^
|
||||
--mode=%mode% ^
|
||||
--N=%N% ^
|
||||
--K=%K% ^
|
||||
--similar_k=10 ^
|
||||
--eval_every_meta_steps=100 ^
|
||||
--name=10-k_100_2_32_3_max_loss_2_5_BIOES ^
|
||||
--train_mode=span ^
|
||||
--inner_steps=2 ^
|
||||
--inner_size=32 ^
|
||||
--max_ft_steps=3 ^
|
||||
--lambda_max_loss=2 ^
|
||||
--inner_lambda_max_loss=5 ^
|
||||
--tagging_scheme=BIOES ^
|
||||
--viterbi=hard ^
|
||||
--concat_types=None ^
|
||||
--ignore_eval_test
|
||||
|
||||
python3 main.py ^
|
||||
--seed=%%e ^
|
||||
--gpu_device=1 ^
|
||||
--lr_inner=1e-4 ^
|
||||
--lr_meta=1e-4 ^
|
||||
--mode=%mode% ^
|
||||
--N=%N% ^
|
||||
--K=%K% ^
|
||||
--similar_k=10 ^
|
||||
--inner_similar_k=10 ^
|
||||
--eval_every_meta_steps=100 ^
|
||||
--name=10-k_100_type_2_32_3_10_10 ^
|
||||
--train_mode=type ^
|
||||
--inner_steps=2 ^
|
||||
--inner_size=32 ^
|
||||
--max_ft_steps=3 ^
|
||||
--concat_types=None ^
|
||||
--lambda_max_loss=2.0
|
||||
|
||||
copy models-%N%-%K%-%mode%\bert-base-uncased-innerSteps_2-innerSize_32-lrInner_0.0001-lrMeta_0.0001-maxSteps_5001-seed_%%e-name_10-k_100_type_2_32_3_10_10\en_type_pytorch_model.bin models-%N%-%K%-%mode%\bert-base-uncased-innerSteps_2-innerSize_32-lrInner_3e-05-lrMeta_3e-05-maxSteps_5001-seed_%%e-name_10-k_100_2_32_3_max_loss_2_5_BIOES
|
||||
|
||||
python3 main.py ^
|
||||
--gpu_device=1 ^
|
||||
--seed=%%e ^
|
||||
--N=%N% ^
|
||||
--K=%K% ^
|
||||
--mode=%mode% ^
|
||||
--similar_k=10 ^
|
||||
--name=10-k_100_2_32_3_max_loss_2_5_BIOES ^
|
||||
--concat_types=None ^
|
||||
--test_only ^
|
||||
--eval_mode=two-stage ^
|
||||
--inner_steps=2 ^
|
||||
--inner_size=32 ^
|
||||
--max_ft_steps=3 ^
|
||||
--max_type_ft_steps=3 ^
|
||||
--lambda_max_loss=2.0 ^
|
||||
--inner_lambda_max_loss=5.0 ^
|
||||
--inner_similar_k=10 ^
|
||||
--viterbi=hard ^
|
||||
--tagging_scheme=BIOES
|
||||
)
|
|
@ -0,0 +1,71 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
SEEDS=(171 354 550 667 985)
|
||||
N=5
|
||||
K=1
|
||||
mode=inter
|
||||
|
||||
for seed in ${SEEDS[@]}; do
|
||||
python3 main.py \
|
||||
--gpu_device=1 \
|
||||
--seed=${seed} \
|
||||
--mode=${mode} \
|
||||
--N=${N} \
|
||||
--K=${K} \
|
||||
--similar_k=10 \
|
||||
--eval_every_meta_steps=100 \
|
||||
--name=10-k_100_2_32_3_max_loss_2_5_BIOES \
|
||||
--train_mode=span \
|
||||
--inner_steps=2 \
|
||||
--inner_size=32 \
|
||||
--max_ft_steps=3 \
|
||||
--lambda_max_loss=2 \
|
||||
--inner_lambda_max_loss=5 \
|
||||
--tagging_scheme=BIOES \
|
||||
--viterbi=hard \
|
||||
--concat_types=None \
|
||||
--ignore_eval_test
|
||||
|
||||
python3 main.py \
|
||||
--seed=${seed} \
|
||||
--gpu_device=1 \
|
||||
--lr_inner=1e-4 \
|
||||
--lr_meta=1e-4 \
|
||||
--mode=${mode} \
|
||||
--N=${N} \
|
||||
--K=${K} \
|
||||
--similar_k=10 \
|
||||
--inner_similar_k=10 \
|
||||
--eval_every_meta_steps=100 \
|
||||
--name=10-k_100_type_2_32_3_10_10 \
|
||||
--train_mode=type \
|
||||
--inner_steps=2 \
|
||||
--inner_size=32 \
|
||||
--max_ft_steps=3 \
|
||||
--concat_types=None \
|
||||
--lambda_max_loss=2.0
|
||||
|
||||
cp models-${N}-${K}-${mode}/bert-base-uncased-innerSteps_2-innerSize_32-lrInner_0.0001-lrMeta_0.0001-maxSteps_5001-seed_${seed}-name_10-k_100_type_2_32_3_10_10/en_type_pytorch_model.bin models-${N}-${K}-${mode}/bert-base-uncased-innerSteps_2-innerSize_32-lrInner_3e-05-lrMeta_3e-05-maxSteps_5001-seed_${seed}-name_10-k_100_2_32_3_max_loss_2_5_BIOES
|
||||
|
||||
python3 main.py \
|
||||
--gpu_device=1 \
|
||||
--seed=${seed} \
|
||||
--N=${N} \
|
||||
--K=${K} \
|
||||
--mode=${mode} \
|
||||
--similar_k=10 \
|
||||
--name=10-k_100_2_32_3_max_loss_2_5_BIOES \
|
||||
--concat_types=None \
|
||||
--test_only \
|
||||
--eval_mode=two-stage \
|
||||
--inner_steps=2 \
|
||||
--inner_size=32 \
|
||||
--max_ft_steps=3 \
|
||||
--max_type_ft_steps=3 \
|
||||
--lambda_max_loss=2.0 \
|
||||
--inner_lambda_max_loss=5.0 \
|
||||
--inner_similar_k=10 \
|
||||
--viterbi=hard \
|
||||
--tagging_scheme=BIOES
|
||||
done
|
|
@ -0,0 +1,30 @@
|
|||
rem Copyright (c) Microsoft Corporation.
|
||||
rem Licensed under the MIT license.
|
||||
|
||||
set SEEDS=171
|
||||
set N=5
|
||||
set K=1
|
||||
set mode=inter
|
||||
|
||||
for %%e in (%SEEDS%) do (
|
||||
python3 main.py ^
|
||||
--gpu_device=1 ^
|
||||
--seed=%%e ^
|
||||
--N=%N% ^
|
||||
--K=%K% ^
|
||||
--mode=%mode% ^
|
||||
--similar_k=10 ^
|
||||
--name=10-k_100_2_32_3_max_loss_2_5_BIOES ^
|
||||
--concat_types=None ^
|
||||
--test_only ^
|
||||
--eval_mode=two-stage ^
|
||||
--inner_steps=2 ^
|
||||
--inner_size=32 ^
|
||||
--max_ft_steps=3 ^
|
||||
--max_type_ft_steps=3 ^
|
||||
--lambda_max_loss=2.0 ^
|
||||
--inner_lambda_max_loss=5.0 ^
|
||||
--inner_similar_k=10 ^
|
||||
--viterbi=hard ^
|
||||
--tagging_scheme=BIOES
|
||||
)
|
|
@ -0,0 +1,30 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
SEEDS=(171)
|
||||
N=5
|
||||
K=1
|
||||
mode=inter
|
||||
|
||||
for seed in ${SEEDS[@]}; do
|
||||
python3 main.py \
|
||||
--gpu_device=1 \
|
||||
--seed=${seed} \
|
||||
--N=${N} \
|
||||
--K=${K} \
|
||||
--mode=${mode} \
|
||||
--similar_k=10 \
|
||||
--name=10-k_100_2_32_3_max_loss_2_5_BIOES \
|
||||
--concat_types=None \
|
||||
--test_only \
|
||||
--eval_mode=two-stage \
|
||||
--inner_steps=2 \
|
||||
--inner_size=32 \
|
||||
--max_ft_steps=3 \
|
||||
--max_type_ft_steps=3 \
|
||||
--lambda_max_loss=2.0 \
|
||||
--inner_lambda_max_loss=5.0 \
|
||||
--inner_similar_k=10 \
|
||||
--viterbi=hard \
|
||||
--tagging_scheme=BIOES
|
||||
done
|
|
@ -0,0 +1,35 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def load_file(path: str, mode: str = "list-strip"):
|
||||
if not os.path.exists(path):
|
||||
return [] if not mode else ""
|
||||
with open(path, "r", encoding="utf-8", newline="\n") as f:
|
||||
if mode == "list-strip":
|
||||
data = [ii.strip() for ii in f.readlines()]
|
||||
elif mode == "str":
|
||||
data = f.read()
|
||||
elif mode == "list":
|
||||
data = list(f.readlines())
|
||||
elif mode == "json":
|
||||
data = json.loads(f.read())
|
||||
elif mode == "json-list":
|
||||
data = [json.loads(ii) for ii in f.readlines()]
|
||||
return data
|
||||
|
||||
|
||||
def set_seed(seed, gpu_device):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if gpu_device > -1:
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
torch.backends.cudnn.deterministic = True
|