Code for the DecomposedMetaNER paper at Findings of the ACL 2022 (#22)

Co-authored-by: Huiqiang Jiang <hjiang@microsoft.com>
This commit is contained in:
Huiqiang Jiang 2022-04-29 17:32:54 +08:00 committed by GitHub
Parent 9c8eedc5af
Commit dc14a4f1bf
No key matching this signature was found
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 3085 additions and 0 deletions


@ -0,0 +1,142 @@
# Decomposed Meta-Learning for Few-Shot Named Entity Recognition
This repository contains the open-sourced official implementation of the paper:
[Decomposed Meta-Learning for Few-Shot Named Entity Recognition](https://arxiv.org/abs/2204.05751) (Findings of the ACL 2022).
_Tingting Ma, Huiqiang Jiang, Qianhui Wu, Tiejun Zhao, and Chin-Yew Lin_
If you find this repo helpful, please cite the following paper:
```bibtex
@inproceedings{ma2022decomposedmetaner,
title="{D}ecomposed {M}eta-{L}earning for {F}ew-{S}hot {N}amed {E}ntity {R}ecognition",
author={Tingting Ma and Huiqiang Jiang and Qianhui Wu and Tiejun Zhao and Chin-Yew Lin},
year={2022},
month={aug},
booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics},
publisher={Association for Computational Linguistics},
url={https://arxiv.org/abs/2204.05751},
}
```
For any questions/comments, please feel free to open GitHub issues.
## 🎥 Overview
Few-shot named entity recognition (NER) systems aim at recognizing novel-class named entities based on only a few labeled examples. In this paper, we present a decomposed meta-learning approach which addresses the problem of few-shot NER by sequentially tackling few-shot span detection and few-shot entity typing using meta-learning. In particular, we take few-shot span detection as a sequence labeling problem and train the span detector by introducing the model-agnostic meta-learning (MAML) algorithm to find a good model parameter initialization that can quickly adapt to new entity classes. For few-shot entity typing, we propose MAML-ProtoNet, i.e., MAML-enhanced prototypical networks, to find a good embedding space that can better distinguish text span representations from different entity classes. Extensive experiments on various benchmarks show that our approach achieves superior performance over prior methods.
![image](./images/framework_DecomposedMetaNER.png)
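To make the typing stage concrete, below is a minimal, self-contained sketch of the prototype idea: class prototypes are mean support-span embeddings, and candidate types are scored by a scaled dot product. It is an illustration only — the function name, shapes, and the `scale` value are hypothetical, and it omits the MAML adaptation of the encoder performed by the full model (see `learner.py` and `modeling.py`).
```python
import torch

def prototype_typing_logits(query_spans, support_spans, support_labels, num_types, scale=10.0):
    """Hypothetical sketch: score candidate entity types for query spans by
    similarity to class prototypes built from the support set."""
    hidden = support_spans.size(-1)
    prototypes = torch.zeros(num_types, hidden)
    for t in range(num_types):
        mask = support_labels == t
        if mask.any():
            # Prototype of type t = mean embedding of its support spans.
            prototypes[t] = support_spans[mask].mean(dim=0)
    # Scaled dot-product similarity between each query span and every prototype.
    return scale * query_spans @ prototypes.t() / hidden

# Toy usage with random "span embeddings" (all shapes are made up).
support_spans = torch.randn(12, 768)          # 12 support-set entity spans
support_labels = torch.randint(0, 5, (12,))   # 5 entity types in this episode
query_spans = torch.randn(3, 768)             # 3 spans found by the span detector
logits = prototype_typing_logits(query_spans, support_spans, support_labels, num_types=5)
predicted_types = logits.argmax(dim=-1)
```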
## 🎯 Quick Start
### Requirements
- python 3.9
- pytorch 1.9.0+cu111
- [HuggingFace Transformers 4.10.0](https://github.com/huggingface/transformers)
- Few-NERD dataset
Other required pip packages are listed in `requirements.txt`:
```bash
pip3 install -r requirements.txt
```
The code may work with other Python and PyTorch versions; however, all experiments were run in the environment above.
### Train and Evaluate
For _Linux_ machines,
```bash
bash scripts/run.sh
```
For _Windows_ machines,
```cmd
call scripts\run.bat
```
If you only want to run prediction with a trained model, you can download the model from [Azure blob](https://kcpapers.blob.core.windows.net/dmn-findings-acl-2022/dmn-all.zip) and put the `models-{N}-{K}-{mode}` folder in the `papers/DecomposedMetaNER` directory.
For _Linux_ machines,
```bash
bash scripts/run_pred.sh
```
For _Windows_ machines,
```cmd
call scripts\run_pred.bat
```
## 🍯 Datasets
We use the following widely-used benchmark datasets for the experiments:
- Few-NERD [Ding et al., 2021](https://aclanthology.org/2021.acl-long.248) for few-shot NER in the Intra and Inter modes;
- Cross-Dataset [Hou et al., 2020](https://www.aclweb.org/anthology/2020.acl-main.128) for cross-domain few-shot NER over four domains.
The Few-NERD dataset is annotated with 8 coarse-grained and 66 fine-grained entity types, while the Cross-Dataset benchmark is annotated with 4, 11, 6, and 18 entity types in its different domains. Each dataset is split into training, dev, and test sets.
All datasets use the N-way K~2K-shot setting and the IO tagging scheme. We do not redistribute any data in this repo due to licensing restrictions. You can download the datasets from their respective websites, [Few-NERD](https://cloud.tsinghua.edu.cn/f/8483dc1a34da4a34ab58/?dl=1) and [Cross-Dataset](https://atmahou.github.io/attachments/ACL2020data.zip),
and place them in the correct location: `./`.
## 📋 Results
We report the few-shot NER results of the proposed DecomposedMetaNER in the 1~2-shot and 5~10-shot settings, alongside those reported by prior state-of-the-art methods.
### Few-NERD ACL Version
Available from [Ding et al., 2021](https://cloud.tsinghua.edu.cn/f/8483dc1a34da4a34ab58/?dl=1).
| | Intra 5-1 | Intra 10-1 | Intra 5-5 | Intra 10-5 | Inter 5-1 | Inter 10-1 | Inter 5-5 | Inter 10-5 |
| -------------------------------------------------------- | ---------------- | ---------------- | ---------------- | ---------------- | ---------------- | ---------------- | ---------------- | ---------------- |
| [ProtoBERT](https://aclanthology.org/2021.acl-long.248) | 23.45 ± 0.92 | 19.76 ± 0.59 | 41.93 ± 0.55 | 34.61 ± 0.59 | 44.44 ± 0.11 | 39.09 ± 0.87 | 58.80 ± 1.42 | 53.97 ± 0.38 |
| [NNShot](https://aclanthology.org/2021.acl-long.248) | 31.01 ± 1.21 | 21.88 ± 0.23 | 35.74 ± 2.36 | 27.67 ± 1.06 | 54.29 ± 0.40 | 46.98 ± 1.96 | 50.56 ± 3.33 | 50.00 ± 0.36 |
| [StructShot](https://aclanthology.org/2021.acl-long.248) | 35.92 ± 0.69 | 25.38 ± 0.84 | 38.83 ± 1.72 | 26.39 ± 2.59 | 57.33 ± 0.53 | 49.46 ± 0.53 | 57.16 ± 2.09 | 49.39 ± 1.77 |
| [CONTAINER](https://arxiv.org/abs/2109.07589) | 40.43 | 33.84 | 53.70 | 47.49 | 55.95 | 48.35 | 61.83 | 57.12 |
| [ESD](https://arxiv.org/abs/2109.13023v1) | 41.44 ± 1.16 | 32.29 ± 1.10 | 50.68 ± 0.94 | 42.92 ± 0.75 | 66.46 ± 0.49 | 59.95 ± 0.69 | **74.14** ± 0.80 | 67.91 ± 1.41 |
| **DecomposedMetaNER** | **52.04** ± 0.44 | **43.50** ± 0.59 | **63.23** ± 0.45 | **56.84** ± 0.14 | **68.77** ± 0.24 | **63.26** ± 0.40 | 71.62 ± 0.16 | **68.32** ± 0.10 |
### Few-NERD Arxiv Version
Available from [Ding et al., 2021](https://cloud.tsinghua.edu.cn/f/0e38bd108d7b49808cc4/?dl=1).
| | Intra 5-1 | Intra 10-1 | Intra 5-5 | Intra 10-5 | Inter 5-1 | Inter 10-1 | Inter 5-5 | Inter 10-5 |
| ---------------------------------------------- | ---------------- | ---------------- | ---------------- | ---------------- | ---------------- | ---------------- | ---------------- | ---------------- |
| [ProtoBERT](https://arxiv.org/abs/2105.07464) | 20.76 ± 0.84 | 15.05 ± 0.44 | 42.54 ± 0.94 | 35.40 ± 0.13 | 38.83 ± 1.49 | 32.45 ± 0.79 | 58.79 ± 0.44 | 52.92 ± 0.37 |
| [NNShot](https://arxiv.org/abs/2105.07464) | 25.78 ± 0.91 | 18.27 ± 0.41 | 36.18 ± 0.79 | 27.38 ± 0.53 | 47.24 ± 1.00 | 38.87 ± 0.21 | 55.64 ± 0.63 | 49.57 ± 2.73 |
| [StructShot](https://arxiv.org/abs/2105.07464) | 30.21 ± 0.90 | 21.03 ± 1.13 | 38.00 ± 1.29 | 26.42 ± 0.60 | 51.88 ± 0.69 | 43.34 ± 0.10 | 57.32 ± 0.63 | 49.57 ± 3.08 |
| [ESD](https://arxiv.org/abs/2109.13023) | 36.08 ± 1.60 | 30.00 ± 0.70 | 52.14 ± 1.50 | 42.15 ± 2.60 | 59.29 ± 1.25 | 52.16 ± 0.79 | 69.06 ± 0.80 | 64.00 ± 0.43 |
| **DecomposedMetaNER** | **49.48** ± 0.85 | **42.84** ± 0.46 | **62.92** ± 0.57 | **57.31** ± 0.25 | **64.75** ± 0.35 | **58.65** ± 0.43 | **71.49** ± 0.47 | **68.11** ± 0.05 |
### Cross-Dataset
| | News 1-shot | Wiki 1-shot | Social 1-shot | Mixed 1-shot | News 5-shot | Wiki 5-shot | Social 5-shot | Mixed 5-shot |
| ---------------------------------------------------------------------- | ---------------- | ---------------- | --------------- | ---------------- | ---------------- | ---------------- | ---------------- | ---------------- |
| [TransferBERT](https://www.aclweb.org/anthology/2020.acl-main.128) | 4.75 ± 1.42 | 0.57 ± 0.32 | 2.71 ± 0.72 | 3.46 ± 0.54 | 15.36 ± 2.81 | 3.62 ± 0.57 | 11.08 ± 0.57 | 35.49 ± 7.60 |
| [SimBERT](https://www.aclweb.org/anthology/2020.acl-main.128) | 19.22 ± 0.00 | 6.91 ± 0.00 | 5.18 ± 0.00 | 13.99 ± 0.00 | 32.01 ± 0.00 | 10.63 ± 0.00 | 8.20 ± 0.00 | 21.14 ± 0.00 |
| [Matching Network](https://www.aclweb.org/anthology/2020.acl-main.128) | 19.50 ± 0.35 | 4.73 ± 0.16 | 17.23 ± 2.75 | 15.06 ± 1.61 | 19.85 ± 0.74 | 5.58 ± 0.23 | 6.61 ± 1.75 | 8.08 ± 0.47 |
| [ProtoBERT](https://www.aclweb.org/anthology/2020.acl-main.128) | 32.49 ± 2.01 | 3.89 ± 0.24 | 10.68 ± 1.40 | 6.67 ± 0.46 | 50.06 ± 1.57 | 9.54 ± 0.44 | 17.26 ± 2.65 | 13.59 ± 1.61 |
| [L-TapNet+CDT](https://www.aclweb.org/anthology/2020.acl-main.128) | 44.30 ± 3.15 | 12.04 ± 0.65 | 20.80 ± 1.06 | 15.17 ± 1.25 | 45.35 ± 2.67 | 11.65 ± 2.34 | 23.30 ± 2.80 | 20.95 ± 2.81 |
| **DecomposedMetaNER** | **46.09** ± 0.44 | **17.54** ± 0.98 | **25.1** ± 0.24 | **34.13** ± 0.92 | **58.18** ± 0.87 | **31.36** ± 0.91 | **31.02** ± 1.28 | **45.55** ± 0.90 |
## Contributing
This project welcomes contributions and suggestions. Most contributions require you to agree to a
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
When you submit a pull request, a CLA bot will automatically determine whether you need to provide
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
provided by the bot. You will only need to do this once across all repos using our CLA.
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.


@ -0,0 +1,87 @@
{
"O": [
"O"
],
"art": [
"art-broadcastprogram",
"art-film",
"art-music",
"art-other",
"art-painting",
"art-writtenart"
],
"building": [
"building-airport",
"building-hospital",
"building-hotel",
"building-library",
"building-other",
"building-restaurant",
"building-sportsfacility",
"building-theater"
],
"event": [
"event-attack/battle/war/militaryconflict",
"event-disaster",
"event-election",
"event-other",
"event-protest",
"event-sportsevent"
],
"location": [
"location-GPE",
"location-bodiesofwater",
"location-island",
"location-mountain",
"location-other",
"location-park",
"location-road/railway/highway/transit"
],
"organization": [
"organization-company",
"organization-education",
"organization-government/governmentagency",
"organization-media/newspaper",
"organization-other",
"organization-politicalparty",
"organization-religion",
"organization-showorganization",
"organization-sportsleague",
"organization-sportsteam"
],
"other": [
"other-astronomything",
"other-award",
"other-biologything",
"other-chemicalthing",
"other-currency",
"other-disease",
"other-educationaldegree",
"other-god",
"other-language",
"other-law",
"other-livingthing",
"other-medical"
],
"person": [
"person-actor",
"person-artist/author",
"person-athlete",
"person-director",
"person-other",
"person-politician",
"person-scholar",
"person-soldier"
],
"product": [
"product-airplane",
"product-car",
"product-food",
"product-game",
"product-other",
"product-ship",
"product-software",
"product-train",
"product-weapon"
]
}


@ -0,0 +1,52 @@
{
"O": [
"O"
],
"Wiki": [
"abstract",
"animal",
"event",
"object",
"organization",
"person",
"place",
"plant",
"quantity",
"substance",
"time"
],
"SocialMedia": [
"corporation",
"creative-work",
"group",
"location",
"person",
"product"
],
"News": [
"LOC",
"MISC",
"ORG",
"PER"
],
"OntoNotes": [
"CARDINAL",
"DATE",
"EVENT",
"FAC",
"GPE",
"LANGUAGE",
"LAW",
"LOC",
"MONEY",
"NORP",
"ORDINAL",
"ORG",
"PERCENT",
"PERSON",
"PRODUCT",
"QUANTITY",
"TIME",
"WORK_OF_ART"
]
}

Binary data
papers/DecomposedMetaNER/images/framework_DecomposedMetaNER.png Normal file

Binary file not shown.

Size: 345 KiB


@ -0,0 +1,803 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import logging
import os
import shutil
import time
from copy import deepcopy
import numpy as np
import torch
from torch import nn
import joblib
from modeling import BertForTokenClassification_
from transformers import CONFIG_NAME, PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME
from transformers import AdamW as BertAdam
from transformers import get_linear_schedule_with_warmup
logger = logging.getLogger(__file__)
class Learner(nn.Module):
ignore_token_label_id = torch.nn.CrossEntropyLoss().ignore_index
pad_token_label_id = -1
def __init__(
self,
bert_model,
label_list,
freeze_layer,
logger,
lr_meta,
lr_inner,
warmup_prop_meta,
warmup_prop_inner,
max_meta_steps,
model_dir="",
cache_dir="",
gpu_no=0,
py_alias="python",
args=None,
):
super(Learner, self).__init__()
self.lr_meta = lr_meta
self.lr_inner = lr_inner
self.warmup_prop_meta = warmup_prop_meta
self.warmup_prop_inner = warmup_prop_inner
self.max_meta_steps = max_meta_steps
self.bert_model = bert_model
self.label_list = label_list
self.py_alias = py_alias
self.entity_types = args.entity_types
self.is_debug = args.debug
self.train_mode = args.train_mode
self.eval_mode = args.eval_mode
self.model_dir = model_dir
self.args = args
self.freeze_layer = freeze_layer
num_labels = len(label_list)
# load model
if model_dir != "":
if self.eval_mode != "two-stage":
self.load_model(self.eval_mode)
else:
logger.info("********** Loading pre-trained model **********")
cache_dir = cache_dir if cache_dir else str(PYTORCH_PRETRAINED_BERT_CACHE)
self.model = BertForTokenClassification_.from_pretrained(
bert_model,
cache_dir=cache_dir,
num_labels=num_labels,
output_hidden_states=True,
).to(args.device)
if self.eval_mode != "two-stage":
self.model.set_config(
args.use_classify,
args.distance_mode,
args.similar_k,
args.shared_bert,
self.train_mode,
)
self.model.to(args.device)
self.layer_set()
def layer_set(self):
# layer freezing
no_grad_param_names = ["embeddings", "pooler"] + [
"layer.{}.".format(i) for i in range(self.freeze_layer)
]
logger.info("The frozen parameters are:")
for name, param in self.model.named_parameters():
if any(no_grad_pn in name for no_grad_pn in no_grad_param_names):
param.requires_grad = False
logger.info(" {}".format(name))
self.opt = BertAdam(self.get_optimizer_grouped_parameters(), lr=self.lr_meta)
self.scheduler = get_linear_schedule_with_warmup(
self.opt,
num_warmup_steps=int(self.max_meta_steps * self.warmup_prop_meta),
num_training_steps=self.max_meta_steps,
)
def get_optimizer_grouped_parameters(self):
param_optimizer = list(self.model.named_parameters())
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [
p
for n, p in param_optimizer
if not any(nd in n for nd in no_decay) and p.requires_grad
],
"weight_decay": 0.01,
},
{
"params": [
p
for n, p in param_optimizer
if any(nd in n for nd in no_decay) and p.requires_grad
],
"weight_decay": 0.0,
},
]
return optimizer_grouped_parameters
def get_names(self):
names = [n for n, p in self.model.named_parameters() if p.requires_grad]
return names
def get_params(self):
params = [p for p in self.model.parameters() if p.requires_grad]
return params
def load_weights(self, names, params):
model_params = self.model.state_dict()
for n, p in zip(names, params):
model_params[n].data.copy_(p.data)
def load_gradients(self, names, grads):
model_params = self.model.state_dict(keep_vars=True)
for n, g in zip(names, grads):
if model_params[n].grad is None:
continue
model_params[n].grad.data.add_(g.data) # accumulate
def get_learning_rate(self, lr, progress, warmup, schedule="linear"):
if schedule == "linear":
if progress < warmup:
lr *= progress / warmup
else:
lr *= max((progress - 1.0) / (warmup - 1.0), 0.0)
return lr
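# MAML inner loop: adapt the model in place on the support set for `inner_steps` AdamW updates
# at learning rate `lr_curr`; callers that need the original initialization restore it afterwards via load_weights.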
def inner_update(self, data_support, lr_curr, inner_steps, no_grad: bool = False):
inner_opt = BertAdam(self.get_optimizer_grouped_parameters(), lr=self.lr_inner)
self.model.train()
for i in range(inner_steps):
inner_opt.param_groups[0]["lr"] = lr_curr
inner_opt.param_groups[1]["lr"] = lr_curr
inner_opt.zero_grad()
_, _, loss, type_loss = self.model.forward_wuqh(
input_ids=data_support["input_ids"],
attention_mask=data_support["input_mask"],
token_type_ids=data_support["segment_ids"],
labels=data_support["label_ids"],
e_mask=data_support["e_mask"],
e_type_ids=data_support["e_type_ids"],
e_type_mask=data_support["e_type_mask"],
entity_types=self.entity_types,
is_update_type_embedding=True,
lambda_max_loss=self.args.inner_lambda_max_loss,
sim_k=self.args.inner_similar_k,
)
if loss is None:
loss = type_loss
elif type_loss is not None:
loss = loss + type_loss
if no_grad:
continue
loss.backward()
inner_opt.step()
return loss.item()
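# Supervised (non-meta) training: take standard gradient steps on each task's query batch and then on
# its support batch, with no inner-loop adaptation or weight restoration.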
def forward_supervise(self, batch_query, batch_support, progress, inner_steps):
span_losses, type_losses = [], []
task_num = len(batch_query)
for task_id in range(task_num):
_, _, loss, type_loss = self.model.forward_wuqh(
input_ids=batch_query[task_id]["input_ids"],
attention_mask=batch_query[task_id]["input_mask"],
token_type_ids=batch_query[task_id]["segment_ids"],
labels=batch_query[task_id]["label_ids"],
e_mask=batch_query[task_id]["e_mask"],
e_type_ids=batch_query[task_id]["e_type_ids"],
e_type_mask=batch_query[task_id]["e_type_mask"],
entity_types=self.entity_types,
lambda_max_loss=self.args.lambda_max_loss,
)
if loss is not None:
span_losses.append(loss.item())
if type_loss is not None:
type_losses.append(type_loss.item())
if loss is None:
loss = type_loss
elif type_loss is not None:
loss = loss + type_loss
loss.backward()
self.opt.step()
self.scheduler.step()
self.model.zero_grad()
for task_id in range(task_num):
_, _, loss, type_loss = self.model.forward_wuqh(
input_ids=batch_support[task_id]["input_ids"],
attention_mask=batch_support[task_id]["input_mask"],
token_type_ids=batch_support[task_id]["segment_ids"],
labels=batch_support[task_id]["label_ids"],
e_mask=batch_support[task_id]["e_mask"],
e_type_ids=batch_support[task_id]["e_type_ids"],
e_type_mask=batch_support[task_id]["e_type_mask"],
entity_types=self.entity_types,
lambda_max_loss=self.args.lambda_max_loss,
)
if loss is not None:
span_losses.append(loss.item())
if type_loss is not None:
type_losses.append(type_loss.item())
if loss is None:
loss = type_loss
elif type_loss is not None:
loss = loss + type_loss
loss.backward()
self.opt.step()
self.scheduler.step()
self.model.zero_grad()
return (
np.mean(span_losses) if span_losses else 0,
np.mean(type_losses) if type_losses else 0,
)
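# First-order meta-update: for every task, adapt on the support set, compute the query loss at the adapted
# parameters and record its gradient, restore the original weights, then apply the accumulated per-task
# gradients in a single meta-optimizer step.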
def forward_meta(self, batch_query, batch_support, progress, inner_steps):
names = self.get_names()
params = self.get_params()
weights = deepcopy(params)
meta_grad = []
span_losses, type_losses = [], []
task_num = len(batch_query)
lr_inner = self.get_learning_rate(
self.lr_inner, progress, self.warmup_prop_inner
)
# compute meta_grad of each task
for task_id in range(task_num):
self.inner_update(batch_support[task_id], lr_inner, inner_steps=inner_steps)
_, _, loss, type_loss = self.model.forward_wuqh(
input_ids=batch_query[task_id]["input_ids"],
attention_mask=batch_query[task_id]["input_mask"],
token_type_ids=batch_query[task_id]["segment_ids"],
labels=batch_query[task_id]["label_ids"],
e_mask=batch_query[task_id]["e_mask"],
e_type_ids=batch_query[task_id]["e_type_ids"],
e_type_mask=batch_query[task_id]["e_type_mask"],
entity_types=self.entity_types,
lambda_max_loss=self.args.lambda_max_loss,
)
if loss is not None:
span_losses.append(loss.item())
if type_loss is not None:
type_losses.append(type_loss.item())
if loss is None:
loss = type_loss
elif type_loss is not None:
loss = loss + type_loss
grad = torch.autograd.grad(loss, params)
meta_grad.append(grad)
self.load_weights(names, weights)
# accumulate grads of all tasks to param.grad
self.opt.zero_grad()
# similar to backward()
for g in meta_grad:
self.load_gradients(names, g)
self.opt.step()
self.scheduler.step()
return (
np.mean(span_losses) if span_losses else 0,
np.mean(type_losses) if type_losses else 0,
)
# ---------------------------------------- Evaluation -------------------------------------- #
def write_result(self, words, y_true, y_pred, tmp_fn):
assert len(y_pred) == len(y_true)
with open(tmp_fn, "w", encoding="utf-8") as fw:
for i, sent in enumerate(y_true):
for j, word in enumerate(sent):
fw.write("{} {} {}\n".format(words[i][j], word, y_pred[i][j]))
fw.write("\n")
def batch_test(self, data):
N = data["input_ids"].shape[0]
B = 16
BATCH_KEY = [
"input_ids",
"attention_mask",
"token_type_ids",
"labels",
"e_mask",
"e_type_ids",
"e_type_mask",
]
logits, e_logits, loss, type_loss = [], [], 0, 0
for i in range((N - 1) // B + 1):
tmp = {
ii: jj if ii not in BATCH_KEY else jj[i * B : (i + 1) * B]
for ii, jj in data.items()
}
tmp_l, tmp_el, tmp_loss, tmp_eval_type_loss = self.model.forward_wuqh(**tmp)
if tmp_l is not None:
logits.extend(tmp_l.detach().cpu().numpy())
if tmp_el is not None:
e_logits.extend(tmp_el.detach().cpu().numpy())
if tmp_loss is not None:
loss += tmp_loss
if tmp_eval_type_loss is not None:
type_loss += tmp_eval_type_loss
return logits, e_logits, loss, type_loss
def evaluate_meta_(
self,
corpus,
logger,
lr,
steps,
mode,
set_type,
type_steps: int = None,
viterbi_decoder=None,
):
if not type_steps:
type_steps = steps
if self.is_debug:
self.save_model(self.args.result_dir, "begin", self.args.max_seq_len, "all")
logger.info("Begin first Stage.")
if self.eval_mode == "two-stage":
self.load_model("span")
names = self.get_names()
params = self.get_params()
weights = deepcopy(params)
eval_loss = 0.0
nb_eval_steps = 0
preds = None
t_tmp = time.time()
targets, predes, spans, lss, type_preds, type_g = [], [], [], [], [], []
for item_id in range(corpus.n_total):
eval_query, eval_support = corpus.get_batch_meta(
batch_size=1, shuffle=False
)
# train on support examples
if not self.args.nouse_inner_ft:
self.inner_update(eval_support[0], lr_curr=lr, inner_steps=steps)
# eval on pseudo query examples (test example)
self.model.eval()
with torch.no_grad():
logits, e_ls, tmp_eval_loss, _ = self.batch_test(
{
"input_ids": eval_query[0]["input_ids"],
"attention_mask": eval_query[0]["input_mask"],
"token_type_ids": eval_query[0]["segment_ids"],
"labels": eval_query[0]["label_ids"],
"e_mask": eval_query[0]["e_mask"],
"e_type_ids": eval_query[0]["e_type_ids"],
"e_type_mask": eval_query[0]["e_type_mask"],
"entity_types": self.entity_types,
}
)
lss.append(logits)
if self.model.train_mode != "type":
eval_loss += tmp_eval_loss
if self.model.train_mode != "span":
type_pred, type_ground = self.eval_typing(
e_ls, eval_query[0]["e_type_mask"]
)
type_preds.append(type_pred)
type_g.append(type_ground)
else:
e_mask, e_type_ids, e_type_mask, result, types = self.decode_span(
logits,
eval_query[0]["label_ids"],
eval_query[0]["types"],
eval_query[0]["input_mask"],
viterbi_decoder,
)
targets.extend(eval_query[0]["entities"])
spans.extend(result)
nb_eval_steps += 1
self.load_weights(names, weights)
if item_id % 200 == 0:
logger.info(
" To sentence {}/{}. Time: {}sec".format(
item_id, corpus.n_total, time.time() - t_tmp
)
)
logger.info("Begin second Stage.")
if self.eval_mode == "two-stage":
self.load_model("type")
names = self.get_names()
params = self.get_params()
weights = deepcopy(params)
if self.train_mode == "add":
for item_id in range(corpus.n_total):
eval_query, eval_support = corpus.get_batch_meta(
batch_size=1, shuffle=False
)
logits = lss[item_id]
# train on support examples
self.inner_update(eval_support[0], lr_curr=lr, inner_steps=type_steps)
# eval on pseudo query examples (test example)
self.model.eval()
with torch.no_grad():
e_mask, e_type_ids, e_type_mask, result, types = self.decode_span(
logits,
eval_query[0]["label_ids"],
eval_query[0]["types"],
eval_query[0]["input_mask"],
viterbi_decoder,
)
_, e_logits, _, tmp_eval_type_loss = self.batch_test(
{
"input_ids": eval_query[0]["input_ids"],
"attention_mask": eval_query[0]["input_mask"],
"token_type_ids": eval_query[0]["segment_ids"],
"labels": eval_query[0]["label_ids"],
"e_mask": e_mask,
"e_type_ids": e_type_ids,
"e_type_mask": e_type_mask,
"entity_types": self.entity_types,
}
)
eval_loss += tmp_eval_type_loss
if self.eval_mode == "two-stage":
logits, e_ls, tmp_eval_loss, _ = self.batch_test(
{
"input_ids": eval_query[0]["input_ids"],
"attention_mask": eval_query[0]["input_mask"],
"token_type_ids": eval_query[0]["segment_ids"],
"labels": eval_query[0]["label_ids"],
"e_mask": eval_query[0]["e_mask"],
"e_type_ids": eval_query[0]["e_type_ids"],
"e_type_mask": eval_query[0]["e_type_mask"],
"entity_types": self.entity_types,
}
)
type_pred, type_ground = self.eval_typing(
e_ls, eval_query[0]["e_type_mask"]
)
type_preds.append(type_pred)
type_g.append(type_ground)
target, p = self.decode_entity(
e_logits, result, types, eval_query[0]["entities"]
)
predes.extend(p)
self.load_weights(names, weights)
if item_id % 200 == 0:
logger.info(
" To sentence {}/{}. Time: {}sec".format(
item_id, corpus.n_total, time.time() - t_tmp
)
)
eval_loss = eval_loss / nb_eval_steps
if self.is_debug:
joblib.dump([targets, predes, spans], "debug/f1.pkl")
store_dir = self.args.model_dir if self.args.model_dir else self.args.result_dir
joblib.dump(
[targets, predes, spans],
"{}/{}_{}_preds.pkl".format(store_dir, "all", set_type),
)
joblib.dump(
[type_g, type_preds],
"{}/{}_{}_preds.pkl".format(store_dir, "typing", set_type),
)
pred = [[jj[:-1] for jj in ii] for ii in predes]
p, r, f1 = self.cacl_f1(targets, pred)
pred = [
[jj[:-1] for jj in ii if jj[-1] > self.args.type_threshold] for ii in predes
]
p_t, r_t, f1_t = self.cacl_f1(targets, pred)
span_p, span_r, span_f1 = self.cacl_f1(
[[(jj[0], jj[1]) for jj in ii] for ii in targets], spans
)
type_p, type_r, type_f1 = self.cacl_f1(type_g, type_preds)
results = {
"loss": eval_loss,
"precision": p,
"recall": r,
"f1": f1,
"span_p": span_p,
"span_r": span_r,
"span_f1": span_f1,
"type_p": type_p,
"type_r": type_r,
"type_f1": type_f1,
"precision_threshold": p_t,
"recall_threshold": r_t,
"f1_threshold": f1_t,
}
logger.info("***** Eval results %s-%s *****", mode, set_type)
for key in sorted(results.keys()):
logger.info(" %s = %s", key, str(results[key]))
logger.info(
"%.3f,%.3f,%.3f,%.3f,%.3f,%.3f,%.3f,%.3f,%.3f,%.3f,%.3f,%.3f",
results["precision"] * 100,
results["recall"] * 100,
results["f1"] * 100,
results["span_p"] * 100,
results["span_r"] * 100,
results["span_f1"] * 100,
results["type_p"] * 100,
results["type_r"] * 100,
results["type_f1"] * 100,
results["precision_threshold"] * 100,
results["recall_threshold"] * 100,
results["f1_threshold"] * 100,
)
return results, preds
def save_model(self, result_dir, fn_prefix, max_seq_len, mode: str = "all"):
# Save a trained model and the associated configuration
model_to_save = (
self.model.module if hasattr(self.model, "module") else self.model
) # Only save the model it-self
output_model_file = os.path.join(
result_dir, "{}_{}_{}".format(fn_prefix, mode, WEIGHTS_NAME)
)
torch.save(model_to_save.state_dict(), output_model_file)
output_config_file = os.path.join(result_dir, CONFIG_NAME)
with open(output_config_file, "w", encoding="utf-8") as f:
f.write(model_to_save.config.to_json_string())
label_map = {i: label for i, label in enumerate(self.label_list, 1)}
model_config = {
"bert_model": self.bert_model,
"do_lower": False,
"max_seq_length": max_seq_len,
"num_labels": len(self.label_list) + 1,
"label_map": label_map,
}
json.dump(
model_config,
open(
os.path.join(result_dir, f"{mode}-model_config.json"),
"w",
encoding="utf-8",
),
)
if mode == "type":
joblib.dump(
self.entity_types, os.path.join(result_dir, "type_embedding.pkl")
)
def save_best_model(self, result_dir: str, mode: str):
output_model_file = os.path.join(result_dir, "en_tmp_{}".format(WEIGHTS_NAME))
config_name = os.path.join(result_dir, "tmp-model_config.json")
shutil.copy(output_model_file, output_model_file.replace("tmp", mode))
shutil.copy(config_name, config_name.replace("tmp", mode))
def load_model(self, mode: str = "all"):
if not self.model_dir:
return
model_dir = self.model_dir
logger.info(f"********** Loading saved {mode} model **********")
output_model_file = os.path.join(
model_dir, "en_{}_{}".format(mode, WEIGHTS_NAME)
)
self.model = BertForTokenClassification_.from_pretrained(
self.bert_model, num_labels=len(self.label_list), output_hidden_states=True
)
self.model.set_config(
self.args.use_classify,
self.args.distance_mode,
self.args.similar_k,
self.args.shared_bert,
mode,
)
self.model.to(self.args.device)
self.model.load_state_dict(torch.load(output_model_file, map_location="cuda"))
self.layer_set()
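# Decode token-level span-detector outputs (optionally smoothed with Viterbi) into (start, end) offsets
# under the BIO/BIOES scheme, and build the entity mask and candidate-type tensors consumed by the typing stage.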
def decode_span(
self,
logits: torch.Tensor,
target: torch.Tensor,
types,
mask: torch.Tensor,
viterbi_decoder=None,
):
if self.is_debug:
joblib.dump([logits, target, self.label_list], "debug/span.pkl")
device = target.device
K = max([len(ii) for ii in types])
if viterbi_decoder:
N = target.shape[0]
B = 16
result = []
for i in range((N - 1) // B + 1):
tmp_logits = torch.tensor(logits[i * B : (i + 1) * B]).to(target.device)
if len(tmp_logits.shape) == 2:
tmp_logits = tmp_logits.unsqueeze(0)
tmp_target = target[i * B : (i + 1) * B]
log_probs = nn.functional.log_softmax(
tmp_logits.detach(), dim=-1
) # batch_size x max_seq_len x n_labels
pred_labels = viterbi_decoder.forward(
log_probs, mask[i * B : (i + 1) * B], tmp_target
)
for ii, jj in zip(pred_labels, tmp_target.detach().cpu().numpy()):
left, right, tmp = 0, 0, []
while right < len(jj) and jj[right] == self.ignore_token_label_id:
tmp.append(-1)
right += 1
while left < len(ii):
tmp.append(ii[left])
left += 1
right += 1
while (
right < len(jj) and jj[right] == self.ignore_token_label_id
):
tmp.append(-1)
right += 1
result.append(tmp)
target = target.detach().cpu().numpy()
B, T = target.shape
if not viterbi_decoder:
logits = logits.detach().cpu().numpy()
result = np.argmax(logits, -1)
if self.label_list == ["O", "B", "I"]:
res = []
for ii in range(B):
tmp, idx = [], 0
max_pad = T - 1
while (
max_pad > 0 and target[ii][max_pad - 1] == self.pad_token_label_id
):
max_pad -= 1
while idx < max_pad:
if target[ii][idx] == self.ignore_token_label_id or (
result[ii][idx] != 1
):
idx += 1
continue
e = idx
while e < max_pad - 1 and (
target[ii][e + 1] == self.ignore_token_label_id
or result[ii][e + 1] in [self.ignore_token_label_id, 2]
):
e += 1
tmp.append((idx, e))
idx = e + 1
res.append(tmp)
elif self.label_list == ["O", "B", "I", "E", "S"]:
res = []
for ii in range(B):
tmp, idx = [], 0
max_pad = T - 1
while (
max_pad > 0 and target[ii][max_pad - 1] == self.pad_token_label_id
):
max_pad -= 1
while idx < max_pad:
if target[ii][idx] == self.ignore_token_label_id or (
result[ii][idx] not in [1, 4]
):
idx += 1
continue
e = idx
while (
e < max_pad - 1
and result[ii][e] not in [3, 4]
and (
target[ii][e + 1] == self.ignore_token_label_id
or result[ii][e + 1] in [self.ignore_token_label_id, 2, 3]
)
):
e += 1
if e < max_pad and result[ii][e] in [3, 4]:
while (
e < max_pad - 1
and target[ii][e + 1] == self.ignore_token_label_id
):
e += 1
tmp.append((idx, e))
idx = e + 1
res.append(tmp)
M = max([len(ii) for ii in res])
e_mask = np.zeros((B, M, T), np.int8)
e_type_mask = np.zeros((B, M, K), np.int8)
e_type_ids = np.zeros((B, M, K), np.int64)
for ii in range(B):
for idx, (s, e) in enumerate(res[ii]):
e_mask[ii][idx][s : e + 1] = 1
types_set = types[ii]
if len(res[ii]):
e_type_ids[ii, : len(res[ii]), : len(types_set)] = [types_set] * len(
res[ii]
)
e_type_mask[ii, : len(res[ii]), : len(types_set)] = np.ones(
(len(res[ii]), len(types_set))
)
return (
torch.tensor(e_mask).to(device),
torch.tensor(e_type_ids, dtype=torch.long).to(device),
torch.tensor(e_type_mask).to(device),
res,
types,
)
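# For each detected span, pick the candidate type with the highest typing logit and keep its score
# so predictions can later be filtered by --type_threshold.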
def decode_entity(self, e_logits, result, types, entities):
if self.is_debug:
joblib.dump([e_logits, result, types, entities], "debug/e.pkl")
target, preds = entities, []
B = len(e_logits)
logits = e_logits
for ii in range(B):
tmp = []
tmp_res = result[ii]
tmp_types = types[ii]
for jj in range(len(tmp_res)):
lg = logits[ii][jj, : len(tmp_types)]
tmp.append((*tmp_res[jj], tmp_types[np.argmax(lg)], lg[np.argmax(lg)]))
preds.append(tmp)
return target, preds
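# Micro-averaged precision/recall/F1 over per-sentence sets of predicted vs. gold items (spans or typed entities).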
def cacl_f1(self, targets: list, predes: list):
tp, fp, fn = 0, 0, 0
for ii, jj in zip(targets, predes):
ii, jj = set(ii), set(jj)
same = ii - (ii - jj)
tp += len(same)
fn += len(ii - jj)
fp += len(jj - ii)
p = tp / (fp + tp + 1e-10)
r = tp / (fn + tp + 1e-10)
return p, r, 2 * p * r / (p + r + 1e-10)
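# Entity-typing evaluation on gold spans: the gold type always sits at candidate index 0, so the ground
# truth is all zeros and a span is typed correctly iff the argmax over its candidates is 0.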
def eval_typing(self, e_logits, e_type_mask):
e_logits = e_logits
e_type_mask = e_type_mask.detach().cpu().numpy()
if self.is_debug:
joblib.dump([e_logits, e_type_mask], "debug/typing.pkl")
N = len(e_logits)
B_S = 16
res = []
for i in range((N - 1) // B_S + 1):
tmp_e_logits = np.argmax(e_logits[i * B_S : (i + 1) * B_S], -1)
B, M = tmp_e_logits.shape
tmp_e_type_mask = e_type_mask[i * B_S : (i + 1) * B_S][:, :M, 0]
res.extend(tmp_e_logits[tmp_e_type_mask == 1])
ground = [0] * len(res)
return enumerate(res), enumerate(ground)


@ -0,0 +1,661 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import argparse
import json
import logging
import os
import time
from pathlib import Path
import torch
import joblib
from learner import Learner
from modeling import ViterbiDecoder
from preprocessor import Corpus, EntityTypes
from utils import set_seed
def get_label_list(args):
# prepare dataset
if args.tagging_scheme == "BIOES":
label_list = ["O", "B", "I", "E", "S"]
elif args.tagging_scheme == "BIO":
label_list = ["O", "B", "I"]
else:
label_list = ["O", "I"]
return label_list
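# Map the dataset, split, N, K, and mode to the on-disk episode file, following each benchmark's naming convention.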
def get_data_path(args, train_mode: str):
assert args.dataset in [
"FewNERD",
"Domain",
"Domain2",
], f"Dataset: {args.dataset} Not Support."
if args.dataset == "FewNERD":
return os.path.join(
args.data_path,
args.mode,
"{}_{}_{}.jsonl".format(train_mode, args.N, args.K),
)
elif args.dataset == "Domain":
if train_mode == "dev":
train_mode = "valid"
text = "_shot_5" if args.K == 5 else ""
replace_text = "-" if args.K == 5 else "_"
return os.path.join(
"ACL2020data",
"xval_ner{}".format(text),
"ner_{}_{}{}.json".format(train_mode, args.N, text).replace(
"_", replace_text
),
)
elif args.dataset == "Domain2":
if train_mode == "train":
return os.path.join("domain2", "{}_10_5.json".format(train_mode))
return os.path.join(
"domain2", "{}_{}_{}.json".format(train_mode, args.mode, args.K)
)
def replace_type_embedding(learner, args):
logger.info("********** Replace trained type embedding **********")
entity_types = joblib.load(os.path.join(args.result_dir, "type_embedding.pkl"))
N, H = entity_types.types_embedding.weight.data.shape
for ii in range(N):
learner.models.embeddings.word_embeddings.weight.data[
ii + 1
] = entity_types.types_embedding.weight.data[ii]
def train_meta(args):
logger.info("********** Scheme: Meta Learning **********")
label_list = get_label_list(args)
valid_data_path = get_data_path(args, "dev")
valid_corpus = Corpus(
logger,
valid_data_path,
args.bert_model,
args.max_seq_len,
label_list,
args.entity_types,
do_lower_case=True,
shuffle=False,
viterbi=args.viterbi,
tagging=args.tagging_scheme,
device=args.device,
concat_types=args.concat_types,
dataset=args.dataset,
)
if args.debug:
test_corpus = valid_corpus
train_corpus = valid_corpus
else:
train_data_path = get_data_path(args, "train")
train_corpus = Corpus(
logger,
train_data_path,
args.bert_model,
args.max_seq_len,
label_list,
args.entity_types,
do_lower_case=True,
shuffle=True,
tagging=args.tagging_scheme,
device=args.device,
concat_types=args.concat_types,
dataset=args.dataset,
)
if not args.ignore_eval_test:
test_data_path = get_data_path(args, "test")
test_corpus = Corpus(
logger,
test_data_path,
args.bert_model,
args.max_seq_len,
label_list,
args.entity_types,
do_lower_case=True,
shuffle=False,
viterbi=args.viterbi,
tagging=args.tagging_scheme,
device=args.device,
concat_types=args.concat_types,
dataset=args.dataset,
)
learner = Learner(
args.bert_model,
label_list,
args.freeze_layer,
logger,
args.lr_meta,
args.lr_inner,
args.warmup_prop_meta,
args.warmup_prop_inner,
args.max_meta_steps,
py_alias=args.py_alias,
args=args,
)
if "embedding" in args.concat_types:
replace_type_embedding(learner, args)
t = time.time()
F1_valid_best = {ii: -1.0 for ii in ["all", "type", "span"]}
F1_test = -1.0
best_step, protect_step = -1.0, 100 if args.train_mode != "type" else 50
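# Meta-training loop: each step samples `inner_size` support/query episodes and performs one meta-update
# (or a plain supervised update when --use_supervise is set); the validation episodes are evaluated every
# `eval_every_meta_steps` steps and the best "all"/"span"/"type" checkpoints are saved.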
for step in range(args.max_meta_steps):
progress = 1.0 * step / args.max_meta_steps
batch_query, batch_support = train_corpus.get_batch_meta(
batch_size=args.inner_size
) # (batch_size=32)
if args.use_supervise:
span_loss, type_loss = learner.forward_supervise(
batch_query,
batch_support,
progress=progress,
inner_steps=args.inner_steps,
)
else:
span_loss, type_loss = learner.forward_meta(
batch_query,
batch_support,
progress=progress,
inner_steps=args.inner_steps,
)
if step % 20 == 0:
logger.info(
"Step: {}/{}, span loss = {:.6f}, type loss = {:.6f}, time = {:.2f}s.".format(
step, args.max_meta_steps, span_loss, type_loss, time.time() - t
)
)
if step % args.eval_every_meta_steps == 0 and step > protect_step:
logger.info("********** Scheme: evaluate - [valid] **********")
result_valid, predictions_valid = test(args, learner, valid_corpus, "valid")
F1_valid = result_valid["f1"]
is_best = False
if F1_valid > F1_valid_best["all"]:
logger.info("===> Best Valid F1: {}".format(F1_valid))
logger.info(" Saving model...")
learner.save_model(args.result_dir, "en", args.max_seq_len, "all")
F1_valid_best["all"] = F1_valid
best_step = step
is_best = True
if (
result_valid["span_f1"] > F1_valid_best["span"]
and args.train_mode != "type"
):
F1_valid_best["span"] = result_valid["span_f1"]
learner.save_model(args.result_dir, "en", args.max_seq_len, "span")
logger.info("Best Span Store {}".format(step))
is_best = True
if (
result_valid["type_f1"] > F1_valid_best["type"]
and args.train_mode != "span"
):
F1_valid_best["type"] = result_valid["type_f1"]
learner.save_model(args.result_dir, "en", args.max_seq_len, "type")
logger.info("Best Type Store {}".format(step))
is_best = True
if is_best and not args.ignore_eval_test:
logger.info("********** Scheme: evaluate - [test] **********")
result_test, predictions_test = test(args, learner, test_corpus, "test")
F1_test = result_test["f1"]
logger.info(
"Best Valid F1: {}, Step: {}".format(F1_valid_best, best_step)
)
logger.info("Test F1: {}".format(F1_test))
def test(args, learner, corpus, types: str):
if corpus.viterbi != "none":
id2label = corpus.id2label
transition_matrix = corpus.transition_matrix
if args.viterbi == "soft":
label_list = get_label_list(args)
train_data_path = get_data_path(args, "train")
train_corpus = Corpus(
logger,
train_data_path,
args.bert_model,
args.max_seq_len,
label_list,
args.entity_types,
do_lower_case=True,
shuffle=True,
tagging=args.tagging_scheme,
viterbi="soft",
device=args.device,
concat_types=args.concat_types,
dataset=args.dataset,
)
id2label = train_corpus.id2label
transition_matrix = train_corpus.transition_matrix
viterbi_decoder = ViterbiDecoder(id2label, transition_matrix)
else:
viterbi_decoder = None
result_test, predictions = learner.evaluate_meta_(
corpus,
logger,
lr=args.lr_finetune,
steps=args.max_ft_steps,
mode=args.mode,
set_type=types,
type_steps=args.max_type_ft_steps,
viterbi_decoder=viterbi_decoder,
)
return result_test, predictions
def evaluate(args):
logger.info("********** Scheme: Meta Test **********")
label_list = get_label_list(args)
valid_data_path = get_data_path(args, "dev")
valid_corpus = Corpus(
logger,
valid_data_path,
args.bert_model,
args.max_seq_len,
label_list,
args.entity_types,
do_lower_case=True,
shuffle=False,
tagging=args.tagging_scheme,
viterbi=args.viterbi,
concat_types=args.concat_types,
dataset=args.dataset,
)
test_data_path = get_data_path(args, "test")
test_corpus = Corpus(
logger,
test_data_path,
args.bert_model,
args.max_seq_len,
label_list,
args.entity_types,
do_lower_case=True,
shuffle=False,
tagging=args.tagging_scheme,
viterbi=args.viterbi,
concat_types=args.concat_types,
dataset=args.dataset,
)
learner = Learner(
args.bert_model,
label_list,
args.freeze_layer,
logger,
args.lr_meta,
args.lr_inner,
args.warmup_prop_meta,
args.warmup_prop_inner,
args.max_meta_steps,
model_dir=args.model_dir,
py_alias=args.py_alias,
args=args,
)
logger.info("********** Scheme: evaluate - [valid] **********")
test(args, learner, valid_corpus, "valid")
logger.info("********** Scheme: evaluate - [test] **********")
test(args, learner, test_corpus, "test")
def convert_bpe(args):
def convert_base(train_mode: str):
data_path = get_data_path(args, train_mode)
corpus = Corpus(
logger,
data_path,
args.bert_model,
args.max_seq_len,
label_list,
args.entity_types,
do_lower_case=True,
shuffle=False,
tagging=args.tagging_scheme,
viterbi=args.viterbi,
concat_types=args.concat_types,
dataset=args.dataset,
device=args.device,
)
for seed in [171, 354, 550, 667, 985]:
path = os.path.join(
args.model_dir,
f"all_{train_mode if train_mode == 'test' else 'valid'}_preds.pkl",
).replace("171", str(seed))
data = joblib.load(path)
if len(data) == 3:
spans = data[-1]
else:
spans = [[jj[:-2] for jj in ii] for ii in data[-1]]
target = [[jj[:-1] for jj in ii] for ii in data[0]]
res = corpus._decoder_bpe_index(spans)
target = corpus._decoder_bpe_index(target)
with open(
f"preds/{args.mode}-{args.N}way{args.K}shot-seed{seed}-{train_mode}.jsonl",
"w",
) as f:
json.dump(res, f)
if seed != 171:
continue
with open(
f"preds/{args.mode}-{args.N}way{args.K}shot-seed{seed}-{train_mode}_golden.jsonl",
"w",
) as f:
json.dump(target, f)
logger.info("********** Scheme: Convert BPE **********")
os.makedirs("preds", exist_ok=True)
label_list = get_label_list(args)
convert_base("dev")
convert_base("test")
if __name__ == "__main__":
def my_bool(s):
return s != "False"
parser = argparse.ArgumentParser()
parser.add_argument(
"--dataset", type=str, default="FewNERD", help="FewNERD or Domain"
)
parser.add_argument("--N", type=int, default=5)
parser.add_argument("--K", type=int, default=1)
parser.add_argument("--mode", type=str, default="intra")
parser.add_argument(
"--test_only",
action="store_true",
help="if true, will load the trained model and run test only",
)
parser.add_argument(
"--convert_bpe",
action="store_true",
help="if true, will convert the bpe encode result to word level.",
)
parser.add_argument("--tagging_scheme", type=str, default="BIO", help="BIO or IO")
# dataset settings
parser.add_argument("--data_path", type=str, default="episode-data")
parser.add_argument(
"--result_dir", type=str, help="where to save the result.", default="test"
)
parser.add_argument(
"--model_dir", type=str, help="dir name of a trained model", default=""
)
# meta-test setting
parser.add_argument(
"--lr_finetune",
type=float,
help="finetune learning rate, used in [test_meta]. and [k_shot setting]",
default=3e-5,
)
parser.add_argument(
"--max_ft_steps", type=int, help="maximal steps token for fine-tune.", default=3
)
parser.add_argument(
"--max_type_ft_steps",
type=int,
help="maximal steps token for entity type fine-tune.",
default=0,
)
# meta-train setting
parser.add_argument(
"--inner_steps",
type=int,
help="every ** inner update for one meta-update",
default=2,
) # ===>
parser.add_argument(
"--inner_size",
type=int,
help="[number of tasks] for one meta-update",
default=32,
)
parser.add_argument(
"--lr_inner", type=float, help="inner loop learning rate", default=3e-5
)
parser.add_argument(
"--lr_meta", type=float, help="meta learning rate", default=3e-5
)
parser.add_argument(
"--max_meta_steps",
type=int,
help="maximal steps token for meta training.",
default=5001,
)
parser.add_argument("--eval_every_meta_steps", type=int, default=500)
parser.add_argument(
"--warmup_prop_inner",
type=int,
help="warm up proportion for inner update",
default=0.1,
)
parser.add_argument(
"--warmup_prop_meta",
type=int,
help="warm up proportion for meta update",
default=0.1,
)
# permanent params
parser.add_argument(
"--freeze_layer", type=int, help="the layer of mBERT to be frozen", default=0
)
parser.add_argument("--max_seq_len", type=int, default=128)
parser.add_argument(
"--bert_model",
type=str,
default="bert-base-uncased",
help="Bert pre-trained model selected in the list: bert-base-uncased, "
"bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
"bert-base-multilingual-cased, bert-base-chinese.",
)
parser.add_argument(
"--cache_dir",
type=str,
help="Where do you want to store the pre-trained models downloaded from s3",
default="",
)
parser.add_argument(
"--viterbi", type=str, default="hard", help="hard or soft or None"
)
parser.add_argument(
"--concat_types", type=str, default="past", help="past or before or None"
)
# expt setting
parser.add_argument(
"--seed", type=int, help="random seed to reproduce the result.", default=667
)
parser.add_argument("--gpu_device", type=int, help="GPU device num", default=0)
parser.add_argument("--py_alias", type=str, help="python alias", default="python")
parser.add_argument(
"--types_path",
type=str,
help="the path of entities types",
default="data/entity_types.json",
)
parser.add_argument(
"--negative_types_number",
type=int,
help="the number of negative types in each batch",
default=4,
)
parser.add_argument(
"--negative_mode", type=str, help="the mode of negative types", default="batch"
)
parser.add_argument(
"--types_mode", type=str, help="the embedding mode of type span", default="cls"
)
parser.add_argument("--name", type=str, help="the name of experiment", default="")
parser.add_argument("--debug", help="debug mode", action="store_true")
parser.add_argument(
"--init_type_embedding_from_bert",
action="store_true",
help="initialization type embedding from BERT",
)
parser.add_argument(
"--use_classify",
action="store_true",
help="use classifier after entity embedding",
)
parser.add_argument(
"--distance_mode", type=str, help="embedding distance mode", default="cos"
)
parser.add_argument("--similar_k", type=float, help="cosine similar k", default=10)
parser.add_argument("--shared_bert", default=True, type=my_bool, help="shared BERT")
parser.add_argument("--train_mode", default="add", type=str, help="add, span, type")
parser.add_argument("--eval_mode", default="add", type=str, help="add, two-stage")
parser.add_argument(
"--type_threshold", default=2.5, type=float, help="typing decoder threshold"
)
parser.add_argument(
"--lambda_max_loss", default=2.0, type=float, help="span max loss lambda"
)
parser.add_argument(
"--inner_lambda_max_loss", default=2.0, type=float, help="span max loss lambda"
)
parser.add_argument(
"--inner_similar_k", type=float, help="cosine similar k", default=10
)
parser.add_argument(
"--ignore_eval_test", help="if/not eval in test", action="store_true"
)
parser.add_argument(
"--nouse_inner_ft",
action="store_true",
help="if true, will convert the bpe encode result to word level.",
)
parser.add_argument(
"--use_supervise",
action="store_true",
help="if true, will convert the bpe encode result to word level.",
)
args = parser.parse_args()
args.negative_types_number = args.N - 1
if "Domain" in args.dataset:
args.types_path = "data/entity_types_domain.json"
# setup random seed
set_seed(args.seed, args.gpu_device)
# set up GPU device
device = torch.device("cuda")
torch.cuda.set_device(args.gpu_device)
# setup logger settings
if args.test_only:
top_dir = "models-{}-{}-{}".format(args.N, args.K, args.mode)
args.model_dir = "{}-innerSteps_{}-innerSize_{}-lrInner_{}-lrMeta_{}-maxSteps_{}-seed_{}{}".format(
args.bert_model,
args.inner_steps,
args.inner_size,
args.lr_inner,
args.lr_meta,
args.max_meta_steps,
args.seed,
"-name_{}".format(args.name) if args.name else "",
)
args.model_dir = os.path.join(top_dir, args.model_dir)
if not os.path.exists(args.model_dir):
if args.convert_bpe:
os.makedirs(args.model_dir)
else:
raise ValueError("Model directory does not exist!")
fh = logging.FileHandler(
"{}/log-test-ftLr_{}-ftSteps_{}.txt".format(
args.model_dir, args.lr_finetune, args.max_ft_steps
)
)
else:
top_dir = "models-{}-{}-{}".format(args.N, args.K, args.mode)
args.result_dir = "{}-innerSteps_{}-innerSize_{}-lrInner_{}-lrMeta_{}-maxSteps_{}-seed_{}{}".format(
args.bert_model,
args.inner_steps,
args.inner_size,
args.lr_inner,
args.lr_meta,
args.max_meta_steps,
args.seed,
"-name_{}".format(args.name) if args.name else "",
)
os.makedirs(top_dir, exist_ok=True)
if not os.path.exists("{}/{}".format(top_dir, args.result_dir)):
os.mkdir("{}/{}".format(top_dir, args.result_dir))
elif args.result_dir != "test":
pass
args.result_dir = "{}/{}".format(top_dir, args.result_dir)
fh = logging.FileHandler("{}/log-training.txt".format(args.result_dir))
# dump args
with Path("{}/args-train.json".format(args.result_dir)).open(
"w", encoding="utf-8"
) as fw:
json.dump(vars(args), fw, indent=4, sort_keys=True)
if args.debug:
os.makedirs("debug", exist_ok=True)
logger = logging.getLogger()
logger.setLevel(logging.INFO)
formatter = logging.Formatter(
"%(asctime)s %(levelname)s: - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
fh.setLevel(logging.INFO)
fh.setFormatter(formatter)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)
logger.addHandler(ch)
logger.addHandler(fh)
args.device = device
logger.info(f"Using Device {device}")
args.entity_types = EntityTypes(
args.types_path, args.negative_types_number, args.negative_mode
)
args.entity_types.build_types_embedding(
args.bert_model,
True,
args.device,
args.types_mode,
args.init_type_embedding_from_bert,
)
if args.convert_bpe:
convert_bpe(args)
elif args.test_only:
if args.model_dir == "":
raise ValueError("NULL model directory!")
evaluate(args)
else:
if args.model_dir != "":
raise ValueError("Model directory should be NULL!")
train_meta(args)


@ -0,0 +1,270 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from copy import deepcopy
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import BertForTokenClassification
logger = logging.getLogger(__file__)
class BertForTokenClassification_(BertForTokenClassification):
def __init__(self, *args, **kwargs):
super(BertForTokenClassification_, self).__init__(*args, **kwargs)
self.input_size = 768
self.span_loss = nn.functional.cross_entropy
self.type_loss = nn.functional.cross_entropy
self.dropout = nn.Dropout(p=0.1)
self.log_softmax = nn.functional.log_softmax
def set_config(
self,
use_classify: bool = False,
distance_mode: str = "cos",
similar_k: float = 30,
shared_bert: bool = True,
train_mode: str = "add",
):
self.use_classify = use_classify
self.distance_mode = distance_mode
self.similar_k = similar_k
self.shared_bert = shared_bert
self.train_mode = train_mode
if train_mode == "type":
self.classifier = None
if train_mode != "span":
self.ln = nn.LayerNorm(768, 1e-5, True)
if use_classify:
# self.type_classify = nn.Linear(self.input_size, self.input_size)
self.type_classify = nn.Sequential(
nn.Linear(self.input_size, self.input_size * 2),
nn.GELU(),
nn.Linear(self.input_size * 2, self.input_size),
)
if self.distance_mode != "cos":
self.dis_cls = nn.Sequential(
nn.Linear(self.input_size * 3, self.input_size),
nn.GELU(),
nn.Linear(self.input_size, 2),
)
config = {
"use_classify": use_classify,
"distance_mode": distance_mode,
"similar_k": similar_k,
"shared_bert": shared_bert,
"train_mode": train_mode,
}
logger.info(f"Model Setting: {config}")
if not shared_bert:
self.bert2 = deepcopy(self.bert)
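# Joint forward pass: token-level span logits from the classification head (span detection) plus typing
# logits computed as similarities between pooled span representations and entity-type embeddings;
# returns (span_logits, type_logits, span_loss, type_loss).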
def forward_wuqh(
self,
input_ids,
attention_mask=None,
token_type_ids=None,
labels=None,
e_mask=None,
e_type_ids=None,
e_type_mask=None,
entity_types=None,
entity_mode: str = "mean",
is_update_type_embedding: bool = False,
lambda_max_loss: float = 0.0,
sim_k: float = 0,
):
max_len = (attention_mask != 0).max(0)[0].nonzero(as_tuple=False)[-1].item() + 1
input_ids = input_ids[:, :max_len]
attention_mask = attention_mask[:, :max_len].type(torch.int8)
token_type_ids = token_type_ids[:, :max_len]
labels = labels[:, :max_len]
output = self.bert(
input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
)
sequence_output = self.dropout(output[0])
if self.train_mode != "type":
logits = self.classifier(
sequence_output
) # batch_size x seq_len x num_labels
else:
logits = None
if not self.shared_bert:
output2 = self.bert2(
input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
)
sequence_output2 = self.dropout(output2[0])
if e_type_ids is not None and self.train_mode != "span":
if e_type_mask.sum() != 0:
M = (e_type_mask[:, :, 0] != 0).max(0)[0].nonzero(as_tuple=False)[
-1
].item() + 1
else:
M = 1
e_mask = e_mask[:, :M, :max_len].type(torch.int8)
e_type_ids = e_type_ids[:, :M, :]
e_type_mask = e_type_mask[:, :M, :].type(torch.int8)
B, M, K = e_type_ids.shape
e_out = self.get_enity_hidden(
sequence_output if self.shared_bert else sequence_output2,
e_mask,
entity_mode,
)
if self.use_classify:
e_out = self.type_classify(e_out)
e_out = self.ln(e_out) # batch_size x max_entity_num x hidden_size
if is_update_type_embedding:
entity_types.update_type_embedding(e_out, e_type_ids, e_type_mask)
e_out = e_out.unsqueeze(2).expand(B, M, K, -1)
types = self.get_types_embedding(
e_type_ids, entity_types
) # batch_size x max_entity_num x K x hidden_size
if self.distance_mode == "cat":
e_types = torch.cat([e_out, types, (e_out - types).abs()], -1)
e_types = self.dis_cls(e_types)
e_types = e_types[:, :, :0]
elif self.distance_mode == "l2":
e_types = -(torch.pow(e_out - types, 2)).sum(-1)
elif self.distance_mode == "cos":
sim_k = sim_k if sim_k else self.similar_k
e_types = sim_k * (e_out * types).sum(-1) / 768
e_logits = e_types
if M:
em = e_type_mask.clone()
em[em.sum(-1) == 0] = 1
e = e_types * em
e_type_label = torch.zeros((B, M)).to(e_types.device)
type_loss = self.calc_loss(
self.type_loss, e, e_type_label, e_type_mask[:, :, 0]
)
else:
type_loss = torch.tensor(0).to(sequence_output.device)
else:
e_logits, type_loss = None, None
if labels is not None and self.train_mode != "type":
# Only keep active parts of the loss
loss_fct = CrossEntropyLoss(reduction="none")
B, M, T = logits.shape
if attention_mask is not None:
active_loss = attention_mask.view(-1) == 1
active_logits = logits.reshape(-1, self.num_labels)[active_loss]
active_labels = labels.reshape(-1)[active_loss]
base_loss = loss_fct(active_logits, active_labels)
loss = torch.mean(base_loss)
# max-loss
if lambda_max_loss > 0:
active_loss = active_loss.view(B, M)
active_max = []
start_id = 0
for i in range(B):
sent_len = torch.sum(active_loss[i])
end_id = start_id + sent_len
active_max.append(torch.max(base_loss[start_id:end_id]))
start_id = end_id
loss += lambda_max_loss * torch.mean(torch.stack(active_max))
else:
raise ValueError("Miss attention mask!")
else:
loss = None
return logits, e_logits, loss, type_loss
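# Mean-pool the token hidden states inside each candidate span to obtain one representation per entity.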
def get_enity_hidden(
self, hidden: torch.Tensor, e_mask: torch.Tensor, entity_mode: str
):
B, M, T = e_mask.shape
e_out = hidden.unsqueeze(1).expand(B, M, T, -1) * e_mask.unsqueeze(
-1
) # batch_size x max_entity_num x seq_len x hidden_size
if entity_mode == "mean":
return e_out.sum(2) / (
e_mask.sum(-1).unsqueeze(-1) + 1e-30
) # batch_size x max_entity_num x hidden_size
def get_types_embedding(self, e_type_ids: torch.Tensor, entity_types):
return entity_types.get_types_embedding(e_type_ids)
def calc_loss(self, loss_fn, preds, target, mask=None):
target = target.reshape(-1)
preds += 1e-10
preds = preds.reshape(-1, preds.shape[-1])
ce_loss = loss_fn(preds, target.long(), reduction="none")
if mask is not None:
mask = mask.reshape(-1)
ce_loss = ce_loss * mask
return ce_loss.sum() / (mask.sum() + 1e-10)
return ce_loss.sum() / (target.sum() + 1e-10)
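# Viterbi decoding over per-token label log-probabilities with a fixed label-transition matrix,
# used at inference to filter out invalid BIO/BIOES transitions.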
class ViterbiDecoder(object):
def __init__(
self,
id2label,
transition_matrix,
ignore_token_label_id=torch.nn.CrossEntropyLoss().ignore_index,
):
self.id2label = id2label
self.n_labels = len(id2label)
self.transitions = transition_matrix
self.ignore_token_label_id = ignore_token_label_id
def forward(self, logprobs, attention_mask, label_ids):
# probs: batch_size x max_seq_len x n_labels
batch_size, max_seq_len, n_labels = logprobs.size()
attention_mask = attention_mask[:, :max_seq_len]
label_ids = label_ids[:, :max_seq_len]
active_tokens = (attention_mask == 1) & (
label_ids != self.ignore_token_label_id
)
if n_labels != self.n_labels:
raise ValueError("Labels do not match!")
label_seqs = []
for idx in range(batch_size):
logprob_i = logprobs[idx, :, :][
active_tokens[idx]
] # seq_len(active) x n_labels
back_pointers = []
forward_var = logprob_i[0] # n_labels
for j in range(1, len(logprob_i)): # for tag_feat in feat:
next_label_var = forward_var + self.transitions # n_labels x n_labels
viterbivars_t, bptrs_t = torch.max(next_label_var, dim=1) # n_labels
logp_j = logprob_i[j] # n_labels
forward_var = viterbivars_t + logp_j # n_labels
bptrs_t = bptrs_t.cpu().numpy().tolist()
back_pointers.append(bptrs_t)
path_score, best_label_id = torch.max(forward_var, dim=-1)
best_label_id = best_label_id.item()
best_path = [best_label_id]
for bptrs_t in reversed(back_pointers):
best_label_id = bptrs_t[best_label_id]
best_path.append(best_label_id)
if len(best_path) != len(logprob_i):
raise ValueError("Number of labels doesn't match!")
best_path.reverse()
label_seqs.append(best_path)
return label_seqs


@ -0,0 +1,830 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import bisect
import json
import logging
import numpy as np
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer
from utils import load_file
logger = logging.getLogger(__file__)
class EntityTypes(object):
def __init__(self, types_path: str, negative_types_number: int, negative_mode: str):
self.types = {}
self.types_map = {}
self.O_id = 0
self.types_embedding = None
self.negative_mode = negative_mode
self.load_entity_types(types_path)
def load_entity_types(self, types_path: str):
self.types = load_file(types_path, "json")
types_list = sorted([jj for ii in self.types.values() for jj in ii])
self.types_map = {jj: ii for ii, jj in enumerate(types_list)}
self.O_id = self.types_map["O"]
self.types_list = types_list
logger.info("Load %d entity types from %s.", len(types_list), types_path)
def build_types_embedding(
self,
model: str,
do_lower_case: bool,
device,
types_mode: str = "cls",
init_type_embedding_from_bert: bool = False,
):
types_list = sorted([jj for ii in self.types.values() for jj in ii])
if init_type_embedding_from_bert:
tokenizer = BertTokenizer.from_pretrained(
model, do_lower_case=do_lower_case
)
tokens = [
[tokenizer.cls_token_id]
+ tokenizer.convert_tokens_to_ids(tokenizer.tokenize(ii))
+ [tokenizer.sep_token_id]
for ii in types_list
]
token_max = max([len(ii) for ii in tokens])
mask = [[1] * len(ii) + [0] * (token_max - len(ii)) for ii in tokens]
ids = [
ii + [tokenizer.pad_token_id] * (token_max - len(ii)) for ii in tokens
]
mask = torch.tensor(np.array(mask), dtype=torch.long).to(device)
ids = torch.tensor(np.array(ids), dtype=torch.long).to(
device
) # len(type_list) x token_max
model = BertModel.from_pretrained(model).to(device)
outs = model(ids, mask)
else:
outs = [0, torch.rand((len(types_list), 768)).to(device)]
self.types_embedding = nn.Embedding(*outs[1].shape).to(device)
if types_mode.lower() == "cls":
self.types_embedding.weight = nn.Parameter(outs[1])
self.types_embedding.requires_grad = False
logger.info("Built the types embedding.")
def generate_negative_types(
self, labels: list, types: list, negative_types_number: int
):
N = len(labels)
        data = np.zeros((N, 1 + negative_types_number), int)  # builtin int (np.int is deprecated)
if self.negative_mode == "batch":
batch_labels = set(types)
if negative_types_number > len(batch_labels):
other = list(
set(range(len(self.types_map))) - batch_labels - set([self.O_id])
)
other_size = negative_types_number - len(batch_labels)
else:
other, other_size = [], 0
b_size = min(len(batch_labels) - 1, negative_types_number)
o_set = [self.O_id] if negative_types_number > len(batch_labels) - 1 else []
for idx, l in enumerate(labels):
data[idx][0] = l
data[idx][1:] = np.concatenate(
[
np.random.choice(list(batch_labels - set([l])), b_size, False),
o_set,
np.random.choice(other, other_size, False),
]
)
return data
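    # Each returned row holds the gold type id in column 0 followed by
    # negative_types_number negatives, sampled first from the other types of the
    # current episode, then "O", then types outside the episode when the episode
    # alone does not provide enough candidates.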
def get_types_embedding(self, labels: torch.Tensor):
return self.types_embedding(labels)
def update_type_embedding(self, e_out, e_type_ids, e_type_mask):
labels = e_type_ids[:, :, 0][e_type_mask[:, :, 0] == 1]
hiddens = e_out[e_type_mask[:, :, 0] == 1]
label_set = set(labels.detach().cpu().numpy())
for ii in label_set:
self.types_embedding.weight.data[ii] = hiddens[labels == ii].mean(0)
class InputExample(object):
def __init__(
self, guid: str, words: list, labels: list, types: list, entities: list
):
self.guid = guid
self.words = words
self.labels = labels
self.types = types
self.entities = entities
class InputFeatures(object):
def __init__(
self,
input_ids,
input_mask,
segment_ids,
label_ids,
e_mask,
e_type_ids,
e_type_mask,
types,
entities,
):
self.input_ids = input_ids
self.input_mask = input_mask
self.segment_ids = segment_ids
self.label_ids = label_ids
self.e_mask = e_mask
self.e_type_ids = e_type_ids
self.e_type_mask = e_type_mask
self.types = types
self.entities = entities
class Corpus(object):
def __init__(
self,
logger,
data_fn,
bert_model,
max_seq_length,
label_list,
entity_types: EntityTypes,
do_lower_case=True,
shuffle=True,
tagging="BIO",
viterbi="none",
device="cuda",
concat_types: str = "None",
dataset: str = "FewNERD",
negative_types_number: int = -1,
):
self.logger = logger
self.tokenizer = BertTokenizer.from_pretrained(
bert_model, do_lower_case=do_lower_case
)
self.max_seq_length = max_seq_length
self.entity_types = entity_types
self.label_list = label_list
self.label_map = {label: i for i, label in enumerate(label_list)}
self.id2label = {i: label for i, label in enumerate(label_list)}
self.n_labels = len(label_list)
self.tagging_scheme = tagging
self.max_len_dict = {"entity": 0, "type": 0, "sentence": 0}
self.max_entities_length = 50
self.viterbi = viterbi
self.dataset = dataset
self.negative_types_number = negative_types_number
logger.info(
"Construct the transition matrix via [{}] scheme...".format(viterbi)
)
# M[ij]: p(j->i)
if viterbi == "none":
self.transition_matrix = None
update_transition_matrix = False
elif viterbi == "hard":
self.transition_matrix = torch.zeros(
[self.n_labels, self.n_labels], device=device
) # pij: p(j -> i)
if self.n_labels == 3:
self.transition_matrix[2][0] = -10000 # p(O -> I) = 0
elif self.n_labels == 5:
for (i, j) in [
(2, 0),
(3, 0),
(0, 1),
(1, 1),
(4, 1),
(0, 2),
(1, 2),
(4, 2),
(2, 3),
(3, 3),
(2, 4),
(3, 4),
]:
self.transition_matrix[i][j] = -10000
else:
raise ValueError()
update_transition_matrix = False
elif viterbi == "soft":
self.transition_matrix = (
torch.zeros([self.n_labels, self.n_labels], device=device) + 1e-8
)
update_transition_matrix = True
else:
raise ValueError()
self.tasks = self.read_tasks_from_file(
data_fn, update_transition_matrix, concat_types, dataset
)
self.n_total = len(self.tasks)
self.batch_start_idx = 0
self.batch_idxs = (
np.random.permutation(self.n_total)
if shuffle
else np.array([i for i in range(self.n_total)])
) # for batch sampling in training
def read_tasks_from_file(
self,
data_fn,
update_transition_matrix=False,
concat_types: str = "None",
dataset: str = "FewNERD",
):
"""
return: List[task]
task['support'] = List[InputExample]
task['query'] = List[InputExample]
"""
self.logger.info("Reading tasks from {}...".format(data_fn))
self.logger.info(
" update_transition_matrix = {}".format(update_transition_matrix)
)
self.logger.info(" concat_types = {}".format(concat_types))
with open(data_fn, "r", encoding="utf-8") as json_file:
json_list = list(json_file)
output_tasks = []
all_labels = [] if update_transition_matrix else None
if dataset == "Domain":
json_list = self._convert_Domain2FewNERD(json_list)
for task_id, json_str in enumerate(json_list):
if task_id % 1000 == 0:
self.logger.info("Reading tasks %d of %d", task_id, len(json_list))
task = json.loads(json_str) if dataset != "Domain" else json_str
support = task["support"]
types = task["types"]
if self.negative_types_number == -1:
self.negative_types_number = len(types) - 1
tmp_support, entities, tmp_support_tokens, tmp_query_tokens = [], [], [], []
self.max_len_dict["type"] = max(self.max_len_dict["type"], len(types))
if concat_types != "None":
types = set()
for l_list in support["label"]:
types.update(l_list)
types.remove("O")
tokenized_types = self.__tokenize_types__(types, concat_types)
else:
tokenized_types = None
for i, (words, labels) in enumerate(zip(support["word"], support["label"])):
entities = self._convert_label_to_entities_(labels)
self.max_len_dict["entity"] = max(
len(entities), self.max_len_dict["entity"]
)
if self.tagging_scheme == "BIOES":
labels = self._convert_label_to_BIOES_(labels)
elif self.tagging_scheme == "BIO":
labels = self._convert_label_to_BIO_(labels)
elif self.tagging_scheme == "IO":
labels = self._convert_label_to_IO_(labels)
else:
raise ValueError("Invalid tagging scheme!")
guid = "task[%s]-%s" % (task_id, i)
feature, token_sum = self._convert_example_to_feature_(
InputExample(
guid=guid,
words=words,
labels=labels,
types=types,
entities=entities,
),
tokenized_types=tokenized_types,
concat_types=concat_types,
)
tmp_support.append(feature)
tmp_support_tokens.append(token_sum)
if update_transition_matrix:
all_labels.append(labels)
query = task["query"]
tmp_query = []
for i, (words, labels) in enumerate(zip(query["word"], query["label"])):
entities = self._convert_label_to_entities_(labels)
self.max_len_dict["entity"] = max(
len(entities), self.max_len_dict["entity"]
)
if self.tagging_scheme == "BIOES":
labels = self._convert_label_to_BIOES_(labels)
elif self.tagging_scheme == "BIO":
labels = self._convert_label_to_BIO_(labels)
elif self.tagging_scheme == "IO":
labels = self._convert_label_to_IO_(labels)
else:
raise ValueError("Invalid tagging scheme!")
guid = "task[%s]-%s" % (task_id, i)
feature, token_sum = self._convert_example_to_feature_(
InputExample(
guid=guid,
words=words,
labels=labels,
types=types,
entities=entities,
),
tokenized_types=tokenized_types,
concat_types=concat_types,
)
tmp_query.append(feature)
tmp_query_tokens.append(token_sum)
if update_transition_matrix:
all_labels.append(labels)
output_tasks.append(
{
"support": tmp_support,
"query": tmp_query,
"support_token": tmp_support_tokens,
"query_token": tmp_query_tokens,
}
)
self.logger.info(
"%s Max Entities Lengths: %d, Max batch Types Number: %d, Max sentence Length: %d",
data_fn,
self.max_len_dict["entity"],
self.max_len_dict["type"],
self.max_len_dict["sentence"],
)
if update_transition_matrix:
self._count_transition_matrix_(all_labels)
return output_tasks
def _convert_Domain2FewNERD(self, data: list):
def decode_batch(batch: dict):
word = batch["seq_ins"]
label = [
[jj.replace("B-", "").replace("I-", "") for jj in ii]
for ii in batch["seq_outs"]
]
return {"word": word, "label": label}
data = json.loads(data[0])
res = []
for domain in data.keys():
d = data[domain]
labels = self.entity_types.types[domain]
res.extend(
[
{
"support": decode_batch(ii["support"]),
"query": decode_batch(ii["batch"]),
"types": labels,
}
for ii in d
]
)
return res
def __tokenize_types__(self, types, concat_types: str = "past"):
tokens = []
for t in types:
if "embedding" in concat_types:
t_tokens = [f"[unused{self.entity_types.types_map[t]}]"]
else:
t_tokens = self.tokenizer.tokenize(t)
if len(t_tokens) == 0:
continue
tokens.extend(t_tokens)
tokens.append(",") # separate different types with a comma ','.
tokens.pop() # pop the last comma
return tokens
def _count_transition_matrix_(self, labels):
self.logger.info("Computing transition matrix...")
for sent_labels in labels:
for i in range(len(sent_labels) - 1):
start = self.label_map[sent_labels[i]]
end = self.label_map[sent_labels[i + 1]]
self.transition_matrix[end][start] += 1
self.transition_matrix /= torch.sum(self.transition_matrix, dim=0)
self.transition_matrix = torch.log(self.transition_matrix)
self.logger.info("Done.")
def _convert_label_to_entities_(self, label_list: list):
N = len(label_list)
S = [
ii
for ii in range(N)
if label_list[ii] != "O"
and (not ii or label_list[ii] != label_list[ii - 1])
]
E = [
ii
for ii in range(N)
if label_list[ii] != "O"
and (ii == N - 1 or label_list[ii] != label_list[ii + 1])
]
return [(s, e, label_list[s]) for s, e in zip(S, E)]
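    # Example: ["O", "person", "person", "O", "location"] yields
    # [(1, 2, "person"), (4, 4, "location")], i.e. inclusive word-level
    # (start, end, type) spans for each maximal run of a non-"O" label.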
def _convert_label_to_BIOES_(self, label_list):
res = []
label_list = ["O"] + label_list + ["O"]
for i in range(1, len(label_list) - 1):
if label_list[i] == "O":
res.append("O")
continue
# for S
if (
label_list[i] != label_list[i - 1]
and label_list[i] != label_list[i + 1]
):
res.append("S")
elif (
label_list[i] != label_list[i - 1]
and label_list[i] == label_list[i + 1]
):
res.append("B")
elif (
label_list[i] == label_list[i - 1]
and label_list[i] != label_list[i + 1]
):
res.append("E")
elif (
label_list[i] == label_list[i - 1]
and label_list[i] == label_list[i + 1]
):
res.append("I")
else:
raise ValueError("Some bugs exist in your code!")
return res
def _convert_label_to_BIO_(self, label_list):
precursor = ""
label_output = []
for label in label_list:
if label == "O":
label_output.append("O")
elif label != precursor:
label_output.append("B")
else:
label_output.append("I")
precursor = label
return label_output
def _convert_label_to_IO_(self, label_list):
label_output = []
for label in label_list:
if label == "O":
label_output.append("O")
else:
label_output.append("I")
return label_output
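    # For labels = ["O", "person", "person", "O", "location"], the three schemes
    # above produce:
    #   BIOES: ["O", "B", "E", "O", "S"]
    #   BIO:   ["O", "B", "I", "O", "B"]
    #   IO:    ["O", "I", "I", "O", "I"]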
def _convert_example_to_feature_(
self,
example,
cls_token_at_end=False,
cls_token="[CLS]",
cls_token_segment_id=0,
sep_token="[SEP]",
sep_token_extra=False,
pad_on_left=False,
pad_token=0,
pad_token_segment_id=0,
pad_token_label_id=-1,
sequence_a_segment_id=0,
mask_padding_with_zero=True,
ignore_token_label_id=torch.nn.CrossEntropyLoss().ignore_index,
sequence_b_segment_id=1,
tokenized_types=None,
concat_types: str = "None",
):
"""
`cls_token_at_end` define the location of the CLS token:
- False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
- True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
`cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet)
"""
tokens, label_ids, token_sum = [], [], [1]
if tokenized_types is None:
tokenized_types = []
if "before" in concat_types:
token_sum[-1] += 1 + len(tokenized_types)
for words, labels in zip(example.words, example.labels):
word_tokens = self.tokenizer.tokenize(words)
token_sum.append(token_sum[-1] + len(word_tokens))
if len(word_tokens) == 0:
continue
tokens.extend(word_tokens)
# Use the real label id for the first token of the word, and padding ids for the remaining tokens
label_ids.extend(
[self.label_map[labels]]
+ [ignore_token_label_id] * (len(word_tokens) - 1)
)
self.max_len_dict["sentence"] = max(self.max_len_dict["sentence"], len(tokens))
e_ids = [(token_sum[s], token_sum[e + 1] - 1) for s, e, _ in example.entities]
e_mask = np.zeros((self.max_entities_length, self.max_seq_length), np.int8)
e_type_mask = np.zeros(
(self.max_entities_length, 1 + self.negative_types_number), np.int8
)
e_type_mask[: len(e_ids), :] = np.ones(
(len(e_ids), 1 + self.negative_types_number), np.int8
)
for idx, (s, e) in enumerate(e_ids):
e_mask[idx][s : e + 1] = 1
e_type_ids = [self.entity_types.types_map[t] for _, _, t in example.entities]
entities = [(s, e, t) for (s, e), t in zip(e_ids, e_type_ids)]
batch_types = [self.entity_types.types_map[ii] for ii in example.types]
        # e_type_ids[i, 0] is the positive label, while e_type_ids[i, 1:] are negative labels
e_type_ids = self.entity_types.generate_negative_types(
e_type_ids, batch_types, self.negative_types_number
)
if len(e_type_ids) < self.max_entities_length:
e_type_ids = np.concatenate(
[
e_type_ids,
[[0] * (1 + self.negative_types_number)]
* (self.max_entities_length - len(e_type_ids)),
]
)
# Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa.
special_tokens_count = 3 if sep_token_extra else 2
if len(tokens) > self.max_seq_length - special_tokens_count - len(
tokenized_types
):
tokens = tokens[
: (self.max_seq_length - special_tokens_count - len(tokenized_types))
]
label_ids = label_ids[
: (self.max_seq_length - special_tokens_count - len(tokenized_types))
]
types = [self.entity_types.types_map[t] for t in example.types]
if "before" in concat_types:
# OPTION 1: Concatenated tokenized types at START
len_sentence = len(tokens)
tokens = [cls_token] + tokenized_types + [sep_token] + tokens
label_ids = [ignore_token_label_id] * (len(tokenized_types) + 2) + label_ids
segment_ids = (
[cls_token_segment_id]
+ [sequence_a_segment_id] * (len(tokenized_types) + 1)
+ [sequence_b_segment_id] * len_sentence
)
else:
# OPTION 2: Concatenated tokenized types at END
tokens += [sep_token]
label_ids += [ignore_token_label_id]
if sep_token_extra:
raise ValueError("Unexpected path!")
# roberta uses an extra separator b/w pairs of sentences
tokens += [sep_token]
label_ids += [ignore_token_label_id]
segment_ids = [sequence_a_segment_id] * len(tokens)
if cls_token_at_end:
raise ValueError("Unexpected path!")
tokens += [cls_token]
label_ids += [ignore_token_label_id]
segment_ids += [cls_token_segment_id]
else:
tokens = [cls_token] + tokens
label_ids = [ignore_token_label_id] + label_ids
segment_ids = [cls_token_segment_id] + segment_ids
if "past" in concat_types:
tokens += tokenized_types
label_ids += [ignore_token_label_id] * len(tokenized_types)
segment_ids += [sequence_b_segment_id] * len(tokenized_types)
input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
# The mask has 1 for real tokens and 0 for padding tokens. Only real tokens are attended to.
input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
# Zero-pad up to the sequence length.
padding_length = self.max_seq_length - len(input_ids)
if pad_on_left:
input_ids = ([pad_token] * padding_length) + input_ids
input_mask = (
[0 if mask_padding_with_zero else 1] * padding_length
) + input_mask
segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
label_ids = ([pad_token_label_id] * padding_length) + label_ids
else:
input_ids += [pad_token] * padding_length
input_mask += [0 if mask_padding_with_zero else 1] * padding_length
segment_ids += [pad_token_segment_id] * padding_length
label_ids += [pad_token_label_id] * padding_length
assert len(input_ids) == self.max_seq_length
assert len(input_mask) == self.max_seq_length
assert len(segment_ids) == self.max_seq_length
assert len(label_ids) == self.max_seq_length
return (
InputFeatures(
input_ids=input_ids,
input_mask=input_mask,
segment_ids=segment_ids,
label_ids=label_ids,
e_mask=e_mask, # max_entities_length x 128
e_type_ids=e_type_ids, # max_entities_length x 5 (n_types)
e_type_mask=np.array(e_type_mask), # max_entities_length x 5 (n_types)
types=np.array(types),
entities=entities,
),
token_sum,
)
def reset_batch_info(self, shuffle=False):
self.batch_start_idx = 0
self.batch_idxs = (
np.random.permutation(self.n_total)
if shuffle
else np.array([i for i in range(self.n_total)])
) # for batch sampling in training
def get_batch_meta(self, batch_size, device="cuda", shuffle=True):
if self.batch_start_idx + batch_size > self.n_total:
self.reset_batch_info(shuffle=shuffle)
query_batch = []
support_batch = []
start_id = self.batch_start_idx
for i in range(start_id, start_id + batch_size):
idx = self.batch_idxs[i]
task_curr = self.tasks[idx]
query_item = {
"input_ids": torch.tensor(
[f.input_ids for f in task_curr["query"]], dtype=torch.long
).to(
device
), # 1 x max_seq_len
"input_mask": torch.tensor(
[f.input_mask for f in task_curr["query"]], dtype=torch.long
).to(device),
"segment_ids": torch.tensor(
[f.segment_ids for f in task_curr["query"]], dtype=torch.long
).to(device),
"label_ids": torch.tensor(
[f.label_ids for f in task_curr["query"]], dtype=torch.long
).to(device),
"e_mask": torch.tensor(
[f.e_mask for f in task_curr["query"]], dtype=torch.int
).to(device),
"e_type_ids": torch.tensor(
[f.e_type_ids for f in task_curr["query"]], dtype=torch.long
).to(device),
"e_type_mask": torch.tensor(
[f.e_type_mask for f in task_curr["query"]], dtype=torch.int
).to(device),
"types": [f.types for f in task_curr["query"]],
"entities": [f.entities for f in task_curr["query"]],
"idx": idx,
}
query_batch.append(query_item)
support_item = {
"input_ids": torch.tensor(
[f.input_ids for f in task_curr["support"]], dtype=torch.long
).to(device),
# 1 x max_seq_len
"input_mask": torch.tensor(
[f.input_mask for f in task_curr["support"]], dtype=torch.long
).to(device),
"segment_ids": torch.tensor(
[f.segment_ids for f in task_curr["support"]], dtype=torch.long
).to(device),
"label_ids": torch.tensor(
[f.label_ids for f in task_curr["support"]], dtype=torch.long
).to(device),
"e_mask": torch.tensor(
[f.e_mask for f in task_curr["support"]], dtype=torch.int
).to(device),
"e_type_ids": torch.tensor(
[f.e_type_ids for f in task_curr["support"]], dtype=torch.long
).to(device),
"e_type_mask": torch.tensor(
[f.e_type_mask for f in task_curr["support"]], dtype=torch.int
).to(device),
"types": [f.types for f in task_curr["support"]],
"entities": [f.entities for f in task_curr["support"]],
"idx": idx,
}
support_batch.append(support_item)
self.batch_start_idx += batch_size
return query_batch, support_batch
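    # get_batch_meta returns one (query, support) pair per sampled episode; each
    # item is a dict of padded tensors (input ids and masks, token labels, span
    # masks and candidate type ids) already moved to the target device.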
def get_batch_NOmeta(self, batch_size, device="cuda", shuffle=True):
if self.batch_start_idx + batch_size >= self.n_total:
self.reset_batch_info(shuffle=shuffle)
if self.mask_rate >= 0:
self.query_features = self.build_query_features_with_mask(
mask_rate=self.mask_rate
)
idxs = self.batch_idxs[self.batch_start_idx : self.batch_start_idx + batch_size]
batch_features = [self.query_features[idx] for idx in idxs]
batch = {
"input_ids": torch.tensor(
[f.input_ids for f in batch_features], dtype=torch.long
).to(device),
"input_mask": torch.tensor(
[f.input_mask for f in batch_features], dtype=torch.long
).to(device),
"segment_ids": torch.tensor(
[f.segment_ids for f in batch_features], dtype=torch.long
).to(device),
"label_ids": torch.tensor(
[f.label_id for f in batch_features], dtype=torch.long
).to(device),
"e_mask": torch.tensor(
[f.e_mask for f in batch_features], dtype=torch.int
).to(device),
"e_type_ids": torch.tensor(
[f.e_type_ids for f in batch_features], dtype=torch.long
).to(device),
"e_type_mask": torch.tensor(
[f.e_type_mask for f in batch_features], dtype=torch.int
).to(device),
"types": [f.types for f in batch_features],
"entities": [f.entities for f in batch_features],
}
self.batch_start_idx += batch_size
return batch
def get_batches(self, batch_size, device="cuda", shuffle=False):
batches = []
if shuffle:
idxs = np.random.permutation(self.n_total)
features = [self.query_features[i] for i in idxs]
else:
features = self.query_features
for i in range(0, self.n_total, batch_size):
batch_features = features[i : min(self.n_total, i + batch_size)]
batch = {
"input_ids": torch.tensor(
[f.input_ids for f in batch_features], dtype=torch.long
).to(device),
"input_mask": torch.tensor(
[f.input_mask for f in batch_features], dtype=torch.long
).to(device),
"segment_ids": torch.tensor(
[f.segment_ids for f in batch_features], dtype=torch.long
).to(device),
"label_ids": torch.tensor(
[f.label_id for f in batch_features], dtype=torch.long
).to(device),
"e_mask": torch.tensor(
[f.e_mask for f in batch_features], dtype=torch.int
).to(device),
"e_type_ids": torch.tensor(
[f.e_type_ids for f in batch_features], dtype=torch.long
).to(device),
"e_type_mask": torch.tensor(
[f.e_type_mask for f in batch_features], dtype=torch.int
).to(device),
"types": [f.types for f in batch_features],
"entities": [f.entities for f in batch_features],
}
batches.append(batch)
return batches
def _decoder_bpe_index(self, sentences_spans: list):
res = []
tokens = [jj for ii in self.tasks for jj in ii["query_token"]]
assert len(tokens) == len(
sentences_spans
), f"tokens size: {len(tokens)}, sentences size: {len(sentences_spans)}"
for sentence_idx, spans in enumerate(sentences_spans):
token = tokens[sentence_idx]
tmp = []
for b, e in spans:
nb = bisect.bisect_left(token, b)
ne = bisect.bisect_left(token, e)
tmp.append((nb, ne))
res.append(tmp)
return res
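    # token_sum[k] records the word-piece index at which the k-th word starts,
    # so bisect_left maps predicted word-piece spans back to word-level
    # (start, end) indices for comparison with the original annotations.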
if __name__ == "__main__":
pass


@ -0,0 +1,3 @@
numpy
transformers==4.10.0
torch==1.9.0


@ -0,0 +1,71 @@
rem Copyright (c) Microsoft Corporation.
rem Licensed under the MIT license.
set SEEDS=171 354 550 667 985
set N=5
set K=1
set mode=inter
for %%e in (%SEEDS%) do (
python3 main.py ^
--gpu_device=1 ^
--seed=%%e ^
--mode=%mode% ^
--N=%N% ^
--K=%K% ^
--similar_k=10 ^
--eval_every_meta_steps=100 ^
--name=10-k_100_2_32_3_max_loss_2_5_BIOES ^
--train_mode=span ^
--inner_steps=2 ^
--inner_size=32 ^
--max_ft_steps=3 ^
--lambda_max_loss=2 ^
--inner_lambda_max_loss=5 ^
--tagging_scheme=BIOES ^
--viterbi=hard ^
--concat_types=None ^
--ignore_eval_test
python3 main.py ^
--seed=%%e ^
--gpu_device=1 ^
--lr_inner=1e-4 ^
--lr_meta=1e-4 ^
--mode=%mode% ^
--N=%N% ^
--K=%K% ^
--similar_k=10 ^
--inner_similar_k=10 ^
--eval_every_meta_steps=100 ^
--name=10-k_100_type_2_32_3_10_10 ^
--train_mode=type ^
--inner_steps=2 ^
--inner_size=32 ^
--max_ft_steps=3 ^
--concat_types=None ^
--lambda_max_loss=2.0
cp models-%N%-%K%-%mode%\bert-base-uncased-innerSteps_2-innerSize_32-lrInner_0.0001-lrMeta_0.0001-maxSteps_5001-seed_%%e-name_10-k_100_type_2_32_3_10_10\en_type_pytorch_model.bin models-%N%-%K%-%mode%\bert-base-uncased-innerSteps_2-innerSize_32-lrInner_3e-05-lrMeta_3e-05-maxSteps_5001-seed_%%e-name_10-k_100_2_32_3_max_loss_2_5_BIOES
python3 main.py ^
--gpu_device=1 ^
--seed=%%e ^
--N=%N% ^
--K=%K% ^
--mode=%mode% ^
--similar_k=10 ^
--name=10-k_100_2_32_3_max_loss_2_5_BIOES ^
--concat_types=None ^
--test_only ^
--eval_mode=two-stage ^
--inner_steps=2 ^
--inner_size=32 ^
--max_ft_steps=3 ^
--max_type_ft_steps=3 ^
--lambda_max_loss=2.0 ^
--inner_lambda_max_loss=5.0 ^
--inner_similar_k=10 ^
--viterbi=hard ^
--tagging_scheme=BIOES
)


@ -0,0 +1,71 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
SEEDS=(171 354 550 667 985)
N=5
K=1
mode=inter
for seed in ${SEEDS[@]}; do
python3 main.py \
--gpu_device=1 \
--seed=${seed} \
--mode=${mode} \
--N=${N} \
--K=${K} \
--similar_k=10 \
--eval_every_meta_steps=100 \
--name=10-k_100_2_32_3_max_loss_2_5_BIOES \
--train_mode=span \
--inner_steps=2 \
--inner_size=32 \
--max_ft_steps=3 \
--lambda_max_loss=2 \
--inner_lambda_max_loss=5 \
--tagging_scheme=BIOES \
--viterbi=hard \
--concat_types=None \
--ignore_eval_test
python3 main.py \
--seed=${seed} \
--gpu_device=1 \
--lr_inner=1e-4 \
--lr_meta=1e-4 \
--mode=${mode} \
--N=${N} \
--K=${K} \
--similar_k=10 \
--inner_similar_k=10 \
--eval_every_meta_steps=100 \
--name=10-k_100_type_2_32_3_10_10 \
--train_mode=type \
--inner_steps=2 \
--inner_size=32 \
--max_ft_steps=3 \
--concat_types=None \
--lambda_max_loss=2.0
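    # Copy the typing checkpoint into the span detector's output directory so the
    # two-stage evaluation below can find both models in one place.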
cp models-${N}-${K}-${mode}/bert-base-uncased-innerSteps_2-innerSize_32-lrInner_0.0001-lrMeta_0.0001-maxSteps_5001-seed_${seed}-name_10-k_100_type_2_32_3_10_10/en_type_pytorch_model.bin models-${N}-${K}-${mode}/bert-base-uncased-innerSteps_2-innerSize_32-lrInner_3e-05-lrMeta_3e-05-maxSteps_5001-seed_${seed}-name_10-k_100_2_32_3_max_loss_2_5_BIOES
python3 main.py \
--gpu_device=1 \
--seed=${seed} \
--N=${N} \
--K=${K} \
--mode=${mode} \
--similar_k=10 \
--name=10-k_100_2_32_3_max_loss_2_5_BIOES \
--concat_types=None \
--test_only \
--eval_mode=two-stage \
--inner_steps=2 \
--inner_size=32 \
--max_ft_steps=3 \
--max_type_ft_steps=3 \
--lambda_max_loss=2.0 \
--inner_lambda_max_loss=5.0 \
--inner_similar_k=10 \
--viterbi=hard \
--tagging_scheme=BIOES
done


@ -0,0 +1,30 @@
rem Copyright (c) Microsoft Corporation.
rem Licensed under the MIT license.
set SEEDS=171
set N=5
set K=1
set mode=inter
for %%e in (%SEEDS%) do (
python3 main.py ^
--gpu_device=1 ^
--seed=%%e ^
--N=%N% ^
--K=%K% ^
--mode=%mode% ^
--similar_k=10 ^
--name=10-k_100_2_32_3_max_loss_2_5_BIOES ^
--concat_types=None ^
--test_only ^
--eval_mode=two-stage ^
--inner_steps=2 ^
--inner_size=32 ^
--max_ft_steps=3 ^
--max_type_ft_steps=3 ^
--lambda_max_loss=2.0 ^
--inner_lambda_max_loss=5.0 ^
--inner_similar_k=10 ^
--viterbi=hard ^
--tagging_scheme=BIOES
)


@ -0,0 +1,30 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
SEEDS=(171)
N=5
K=1
mode=inter
for seed in ${SEEDS[@]}; do
python3 main.py \
--gpu_device=1 \
--seed=${seed} \
--N=${N} \
--K=${K} \
--mode=${mode} \
--similar_k=10 \
--name=10-k_100_2_32_3_max_loss_2_5_BIOES \
--concat_types=None \
--test_only \
--eval_mode=two-stage \
--inner_steps=2 \
--inner_size=32 \
--max_ft_steps=3 \
--max_type_ft_steps=3 \
--lambda_max_loss=2.0 \
--inner_lambda_max_loss=5.0 \
--inner_similar_k=10 \
--viterbi=hard \
--tagging_scheme=BIOES
done


@ -0,0 +1,35 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import os
import random
import numpy as np
import torch
def load_file(path: str, mode: str = "list-strip"):
if not os.path.exists(path):
return [] if not mode else ""
with open(path, "r", encoding="utf-8", newline="\n") as f:
if mode == "list-strip":
data = [ii.strip() for ii in f.readlines()]
elif mode == "str":
data = f.read()
elif mode == "list":
data = list(f.readlines())
elif mode == "json":
data = json.loads(f.read())
elif mode == "json-list":
data = [json.loads(ii) for ii in f.readlines()]
return data
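# Illustrative usage (the file names are assumptions, not taken from this repo):
#   episodes = load_file("train_5_1.jsonl", "json-list")  # one JSON episode per line
#   types = load_file("entity_types.json", "json")        # dict: group -> list of type names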
def set_seed(seed, gpu_device):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if gpu_device > -1:
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True