Code for the DecomposedMetaNER paper at Findings of the ACL 2022 (#22)

Co-authored-by: Huiqiang Jiang <hjiang@microsoft.com>

Parent: 9c8eedc5af
Commit: dc14a4f1bf

@ -0,0 +1,142 @@
# Decomposed Meta-Learning for Few-Shot Named Entity Recognition

This repository contains the open-sourced official implementation of the paper:

[Decomposed Meta-Learning for Few-Shot Named Entity Recognition](https://arxiv.org/abs/2204.05751) (Findings of the ACL 2022).

_Tingting Ma, Huiqiang Jiang, Qianhui Wu, Tiejun Zhao, and Chin-Yew Lin_

If you find this repo helpful, please cite the following paper:

```bibtex
@inproceedings{ma2022decomposedmetaner,
  title="{D}ecomposed {M}eta-{L}earning for {F}ew-{S}hot {N}amed {E}ntity {R}ecognition",
  author={Tingting Ma and Huiqiang Jiang and Qianhui Wu and Tiejun Zhao and Chin-Yew Lin},
  year={2022},
  month={aug},
  booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics},
  publisher={Association for Computational Linguistics},
  url={https://arxiv.org/abs/2204.05751},
}
```

For any questions/comments, please feel free to open GitHub issues.

## 🎥 Overview

Few-shot named entity recognition (NER) systems aim at recognizing novel-class named entities based on only a few labeled examples. In this paper, we present a decomposed meta-learning approach which addresses the problem of few-shot NER by sequentially tackling few-shot span detection and few-shot entity typing using meta-learning. In particular, we take few-shot span detection as a sequence labeling problem and train the span detector with the model-agnostic meta-learning (MAML) algorithm to find a good model parameter initialization that can quickly adapt to new entity classes. For few-shot entity typing, we propose MAML-ProtoNet, i.e., MAML-enhanced prototypical networks, to find a good embedding space that can better distinguish text span representations from different entity classes. Extensive experiments on various benchmarks show that our approach achieves superior performance over prior methods.

![image](./images/framework_DecomposedMetaNER.png)

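The span-detection stage relies on a MAML-style loop: adapt a copy of the span detector on an episode's support set, evaluate the adapted copy on the query set, and push that gradient back onto the original parameters. The following is a minimal sketch of this idea, for illustration only; the names `span_detector`, `loss_fn`, `support_batch`, and `query_batch` are placeholders rather than this repository's API, and the actual implementation is in `learner.py`.

```python
import copy

import torch


def maml_outer_step(span_detector, loss_fn, support_batch, query_batch,
                    inner_lr=3e-5, inner_steps=2):
    """One first-order MAML-style meta-update (a sketch, not the repo's exact code)."""
    # Inner loop: adapt a throwaway copy of the model on the support set.
    fast_model = copy.deepcopy(span_detector)
    inner_opt = torch.optim.SGD(fast_model.parameters(), lr=inner_lr)
    for _ in range(inner_steps):
        inner_opt.zero_grad()
        loss_fn(fast_model, support_batch).backward()
        inner_opt.step()

    # Outer step: evaluate the adapted copy on the query set and accumulate
    # that gradient onto the original (meta) parameters.
    query_loss = loss_fn(fast_model, query_batch)
    fast_params = [p for p in fast_model.parameters() if p.requires_grad]
    grads = torch.autograd.grad(query_loss, fast_params, allow_unused=True)
    meta_params = [p for p in span_detector.parameters() if p.requires_grad]
    for p, g in zip(meta_params, grads):
        if g is None:
            continue
        p.grad = g.detach() if p.grad is None else p.grad + g.detach()
    return query_loss.item()
```

In this repository, `Learner.inner_update` plays the role of the inner loop on the support set, and `Learner.forward_meta` accumulates the query-set gradients over a batch of sampled tasks before a single optimizer step.
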
## 🎯 Quick Start

### Requirements

- python 3.9
- pytorch 1.9.0+cu111
- [HuggingFace Transformers 4.10.0](https://github.com/huggingface/transformers)
- Few-NERD dataset

Other pip packages are listed in `requirements.txt`.

```bash
pip3 install -r requirements.txt
```

The code may work with other Python and PyTorch versions; however, all experiments were run in the environment above.

### Train and Evaluate

For _Linux_ machines,

```bash
bash scripts/run.sh
```

For _Windows_ machines,

```cmd
call scripts\run.bat
```

If you only want to run prediction with the trained models, you can download them from the [Azure blob](https://kcpapers.blob.core.windows.net/dmn-findings-acl-2022/dmn-all.zip).

Put the `models-{N}-{K}-{mode}` folder in the `papers/DecomposedMetaNER` directory.

For _Linux_ machines,

```bash
bash scripts/run_pred.sh
```

For _Windows_ machines,

```cmd
call scripts\run_pred.bat
```

## 🍯 Datasets

We use the following widely-used benchmark datasets for the experiments:

- Few-NERD [Ding et al., 2021](https://aclanthology.org/2021.acl-long.248) for few-shot NER in the Intra and Inter modes;
- Cross-Dataset [Hou et al., 2020](https://www.aclweb.org/anthology/2020.acl-main.128) for cross-domain few-shot NER across four domains;

The Few-NERD dataset is annotated with 8 coarse-grained and 66 fine-grained entity types, and the Cross-Dataset benchmarks are annotated with 4, 11, 6, and 18 entity types in their respective domains. Each dataset is split into training, dev, and test sets.

All datasets use the N-way K~2K-shot setting and the IO tagging scheme (illustrated below). We do not publish any data in this repo due to the licenses. You can download the datasets from their respective websites: [Few-NERD](https://cloud.tsinghua.edu.cn/f/8483dc1a34da4a34ab58/?dl=1) and [Cross-Dataset](https://atmahou.github.io/attachments/ACL2020data.zip).

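The IO scheme labels every token inside an entity mention with `I` and all other tokens with `O`; there are no `B-`/`E-`/`S-` boundary tags (this corresponds to the `["O", "I"]` label list in the code). A minimal, made-up illustration:

```python
# Illustrative example only (not taken from any of the datasets):
tokens  = ["Tim", "Cook", "visited", "Paris", "."]   # two mentions: "Tim Cook" and "Paris"
io_tags = ["I",   "I",    "O",       "I",     "O"]
```
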
After downloading, place the data in the expected location: `./` (this directory).

## 📋 Results

We report the few-shot NER results of the proposed DecomposedMetaNER in the 1~2-shot and 5~10-shot settings, alongside the results reported by prior state-of-the-art methods.

### Few-NERD ACL Version

Available from [Ding et al., 2021](https://cloud.tsinghua.edu.cn/f/8483dc1a34da4a34ab58/?dl=1).

| | Intra 5-1 | Intra 10-1 | Intra 5-5 | Intra 10-5 | Inter 5-1 | Inter 10-1 | Inter 5-5 | Inter 10-5 |
| -------------------------------------------------------- | ---------------- | ---------------- | ---------------- | ---------------- | ---------------- | ---------------- | ---------------- | ---------------- |
| [ProtoBERT](https://aclanthology.org/2021.acl-long.248) | 23.45 ± 0.92 | 19.76 ± 0.59 | 41.93 ± 0.55 | 34.61 ± 0.59 | 44.44 ± 0.11 | 39.09 ± 0.87 | 58.80 ± 1.42 | 53.97 ± 0.38 |
| [NNShot](https://aclanthology.org/2021.acl-long.248) | 31.01 ± 1.21 | 21.88 ± 0.23 | 35.74 ± 2.36 | 27.67 ± 1.06 | 54.29 ± 0.40 | 46.98 ± 1.96 | 50.56 ± 3.33 | 50.00 ± 0.36 |
| [StructShot](https://aclanthology.org/2021.acl-long.248) | 35.92 ± 0.69 | 25.38 ± 0.84 | 38.83 ± 1.72 | 26.39 ± 2.59 | 57.33 ± 0.53 | 49.46 ± 0.53 | 57.16 ± 2.09 | 49.39 ± 1.77 |
| [CONTAINER](https://arxiv.org/abs/2109.07589) | 40.43 | 33.84 | 53.70 | 47.49 | 55.95 | 48.35 | 61.83 | 57.12 |
| [ESD](https://arxiv.org/abs/2109.13023v1) | 41.44 ± 1.16 | 32.29 ± 1.10 | 50.68 ± 0.94 | 42.92 ± 0.75 | 66.46 ± 0.49 | 59.95 ± 0.69 | **74.14** ± 0.80 | 67.91 ± 1.41 |
| **DecomposedMetaNER** | **52.04** ± 0.44 | **43.50** ± 0.59 | **63.23** ± 0.45 | **56.84** ± 0.14 | **68.77** ± 0.24 | **63.26** ± 0.40 | 71.62 ± 0.16 | **68.32** ± 0.10 |

### Few-NERD Arxiv Version

Available from [Ding et al., 2021](https://cloud.tsinghua.edu.cn/f/0e38bd108d7b49808cc4/?dl=1).

| | Intra 5-1 | Intra 10-1 | Intra 5-5 | Intra 10-5 | Inter 5-1 | Inter 10-1 | Inter 5-5 | Inter 10-5 |
| ---------------------------------------------- | ---------------- | ---------------- | ---------------- | ---------------- | ---------------- | ---------------- | ---------------- | ---------------- |
| [ProtoBERT](https://arxiv.org/abs/2105.07464) | 20.76 ± 0.84 | 15.05 ± 0.44 | 42.54 ± 0.94 | 35.40 ± 0.13 | 38.83 ± 1.49 | 32.45 ± 0.79 | 58.79 ± 0.44 | 52.92 ± 0.37 |
| [NNShot](https://arxiv.org/abs/2105.07464) | 25.78 ± 0.91 | 18.27 ± 0.41 | 36.18 ± 0.79 | 27.38 ± 0.53 | 47.24 ± 1.00 | 38.87 ± 0.21 | 55.64 ± 0.63 | 49.57 ± 2.73 |
| [StructShot](https://arxiv.org/abs/2105.07464) | 30.21 ± 0.90 | 21.03 ± 1.13 | 38.00 ± 1.29 | 26.42 ± 0.60 | 51.88 ± 0.69 | 43.34 ± 0.10 | 57.32 ± 0.63 | 49.57 ± 3.08 |
| [ESD](https://arxiv.org/abs/2109.13023) | 36.08 ± 1.60 | 30.00 ± 0.70 | 52.14 ± 1.50 | 42.15 ± 2.60 | 59.29 ± 1.25 | 52.16 ± 0.79 | 69.06 ± 0.80 | 64.00 ± 0.43 |
| **DecomposedMetaNER** | **49.48** ± 0.85 | **42.84** ± 0.46 | **62.92** ± 0.57 | **57.31** ± 0.25 | **64.75** ± 0.35 | **58.65** ± 0.43 | **71.49** ± 0.47 | **68.11** ± 0.05 |

### Cross-Dataset

| | News 1-shot | Wiki 1-shot | Social 1-shot | Mixed 1-shot | News 5-shot | Wiki 5-shot | Social 5-shot | Mixed 5-shot |
| ---------------------------------------------------------------------- | ---------------- | ---------------- | --------------- | ---------------- | ---------------- | ---------------- | ---------------- | ---------------- |
| [TransferBERT](https://www.aclweb.org/anthology/2020.acl-main.128) | 4.75 ± 1.42 | 0.57 ± 0.32 | 2.71 ± 0.72 | 3.46 ± 0.54 | 15.36 ± 2.81 | 3.62 ± 0.57 | 11.08 ± 0.57 | 35.49 ± 7.60 |
| [SimBERT](https://www.aclweb.org/anthology/2020.acl-main.128) | 19.22 ± 0.00 | 6.91 ± 0.00 | 5.18 ± 0.00 | 13.99 ± 0.00 | 32.01 ± 0.00 | 10.63 ± 0.00 | 8.20 ± 0.00 | 21.14 ± 0.00 |
| [Matching Network](https://www.aclweb.org/anthology/2020.acl-main.128) | 19.50 ± 0.35 | 4.73 ± 0.16 | 17.23 ± 2.75 | 15.06 ± 1.61 | 19.85 ± 0.74 | 5.58 ± 0.23 | 6.61 ± 1.75 | 8.08 ± 0.47 |
| [ProtoBERT](https://www.aclweb.org/anthology/2020.acl-main.128) | 32.49 ± 2.01 | 3.89 ± 0.24 | 10.68 ± 1.40 | 6.67 ± 0.46 | 50.06 ± 1.57 | 9.54 ± 0.44 | 17.26 ± 2.65 | 13.59 ± 1.61 |
| [L-TapNet+CDT](https://www.aclweb.org/anthology/2020.acl-main.128) | 44.30 ± 3.15 | 12.04 ± 0.65 | 20.80 ± 1.06 | 15.17 ± 1.25 | 45.35 ± 2.67 | 11.65 ± 2.34 | 23.30 ± 2.80 | 20.95 ± 2.81 |
| **DecomposedMetaNER** | **46.09** ± 0.44 | **17.54** ± 0.98 | **25.1** ± 0.24 | **34.13** ± 0.92 | **58.18** ± 0.87 | **31.36** ± 0.91 | **31.02** ± 1.28 | **45.55** ± 0.90 |

## Contributing

This project welcomes contributions and suggestions. Most contributions require you to agree to a
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.

When you submit a pull request, a CLA bot will automatically determine whether you need to provide
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
provided by the bot. You will only need to do this once across all repos using our CLA.

This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.

@ -0,0 +1,87 @@
{
  "O": [
    "O"
  ],
  "art": [
    "art-broadcastprogram",
    "art-film",
    "art-music",
    "art-other",
    "art-painting",
    "art-writtenart"
  ],
  "building": [
    "building-airport",
    "building-hospital",
    "building-hotel",
    "building-library",
    "building-other",
    "building-restaurant",
    "building-sportsfacility",
    "building-theater"
  ],
  "event": [
    "event-attack/battle/war/militaryconflict",
    "event-disaster",
    "event-election",
    "event-other",
    "event-protest",
    "event-sportsevent"
  ],
  "location": [
    "location-GPE",
    "location-bodiesofwater",
    "location-island",
    "location-mountain",
    "location-other",
    "location-park",
    "location-road/railway/highway/transit"
  ],
  "organization": [
    "organization-company",
    "organization-education",
    "organization-government/governmentagency",
    "organization-media/newspaper",
    "organization-other",
    "organization-politicalparty",
    "organization-religion",
    "organization-showorganization",
    "organization-sportsleague",
    "organization-sportsteam"
  ],
  "other": [
    "other-astronomything",
    "other-award",
    "other-biologything",
    "other-chemicalthing",
    "other-currency",
    "other-disease",
    "other-educationaldegree",
    "other-god",
    "other-language",
    "other-law",
    "other-livingthing",
    "other-medical"
  ],
  "person": [
    "person-actor",
    "person-artist/author",
    "person-athlete",
    "person-director",
    "person-other",
    "person-politician",
    "person-scholar",
    "person-soldier"
  ],
  "product": [
    "product-airplane",
    "product-car",
    "product-food",
    "product-game",
    "product-other",
    "product-ship",
    "product-software",
    "product-train",
    "product-weapon"
  ]
}

@ -0,0 +1,52 @@
{
  "O": [
    "O"
  ],
  "Wiki": [
    "abstract",
    "animal",
    "event",
    "object",
    "organization",
    "person",
    "place",
    "plant",
    "quantity",
    "substance",
    "time"
  ],
  "SocialMedia": [
    "corporation",
    "creative-work",
    "group",
    "location",
    "person",
    "product"
  ],
  "News": [
    "LOC",
    "MISC",
    "ORG",
    "PER"
  ],
  "OntoNotes": [
    "CARDINAL",
    "DATE",
    "EVENT",
    "FAC",
    "GPE",
    "LANGUAGE",
    "LAW",
    "LOC",
    "MONEY",
    "NORP",
    "ORDINAL",
    "ORG",
    "PERCENT",
    "PERSON",
    "PRODUCT",
    "QUANTITY",
    "TIME",
    "WORK_OF_ART"
  ]
}

Binary file not shown.
After | Width: | Height: | Size: 345 KiB
|
@ -0,0 +1,803 @@
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import time
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
import joblib
|
||||||
|
from modeling import BertForTokenClassification_
|
||||||
|
from transformers import CONFIG_NAME, PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME
|
||||||
|
from transformers import AdamW as BertAdam
|
||||||
|
from transformers import get_linear_schedule_with_warmup
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__file__)
|
||||||
|
|
||||||
|
|
||||||
|
class Learner(nn.Module):
|
||||||
|
ignore_token_label_id = torch.nn.CrossEntropyLoss().ignore_index
|
||||||
|
pad_token_label_id = -1
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
bert_model,
|
||||||
|
label_list,
|
||||||
|
freeze_layer,
|
||||||
|
logger,
|
||||||
|
lr_meta,
|
||||||
|
lr_inner,
|
||||||
|
warmup_prop_meta,
|
||||||
|
warmup_prop_inner,
|
||||||
|
max_meta_steps,
|
||||||
|
model_dir="",
|
||||||
|
cache_dir="",
|
||||||
|
gpu_no=0,
|
||||||
|
py_alias="python",
|
||||||
|
args=None,
|
||||||
|
):
|
||||||
|
|
||||||
|
super(Learner, self).__init__()
|
||||||
|
|
||||||
|
self.lr_meta = lr_meta
|
||||||
|
self.lr_inner = lr_inner
|
||||||
|
self.warmup_prop_meta = warmup_prop_meta
|
||||||
|
self.warmup_prop_inner = warmup_prop_inner
|
||||||
|
self.max_meta_steps = max_meta_steps
|
||||||
|
|
||||||
|
self.bert_model = bert_model
|
||||||
|
self.label_list = label_list
|
||||||
|
self.py_alias = py_alias
|
||||||
|
self.entity_types = args.entity_types
|
||||||
|
self.is_debug = args.debug
|
||||||
|
self.train_mode = args.train_mode
|
||||||
|
self.eval_mode = args.eval_mode
|
||||||
|
self.model_dir = model_dir
|
||||||
|
self.args = args
|
||||||
|
self.freeze_layer = freeze_layer
|
||||||
|
|
||||||
|
num_labels = len(label_list)
|
||||||
|
|
||||||
|
# load model
|
||||||
|
if model_dir != "":
|
||||||
|
if self.eval_mode != "two-stage":
|
||||||
|
self.load_model(self.eval_mode)
|
||||||
|
else:
|
||||||
|
logger.info("********** Loading pre-trained model **********")
|
||||||
|
cache_dir = cache_dir if cache_dir else str(PYTORCH_PRETRAINED_BERT_CACHE)
|
||||||
|
self.model = BertForTokenClassification_.from_pretrained(
|
||||||
|
bert_model,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
num_labels=num_labels,
|
||||||
|
output_hidden_states=True,
|
||||||
|
).to(args.device)
|
||||||
|
|
||||||
|
if self.eval_mode != "two-stage":
|
||||||
|
self.model.set_config(
|
||||||
|
args.use_classify,
|
||||||
|
args.distance_mode,
|
||||||
|
args.similar_k,
|
||||||
|
args.shared_bert,
|
||||||
|
self.train_mode,
|
||||||
|
)
|
||||||
|
self.model.to(args.device)
|
||||||
|
self.layer_set()
|
||||||
|
|
||||||
|
def layer_set(self):
|
||||||
|
# layer freezing
|
||||||
|
no_grad_param_names = ["embeddings", "pooler"] + [
|
||||||
|
"layer.{}.".format(i) for i in range(self.freeze_layer)
|
||||||
|
]
|
||||||
|
logger.info("The frozen parameters are:")
|
||||||
|
for name, param in self.model.named_parameters():
|
||||||
|
if any(no_grad_pn in name for no_grad_pn in no_grad_param_names):
|
||||||
|
param.requires_grad = False
|
||||||
|
logger.info(" {}".format(name))
|
||||||
|
|
||||||
|
self.opt = BertAdam(self.get_optimizer_grouped_parameters(), lr=self.lr_meta)
|
||||||
|
self.scheduler = get_linear_schedule_with_warmup(
|
||||||
|
self.opt,
|
||||||
|
num_warmup_steps=int(self.max_meta_steps * self.warmup_prop_meta),
|
||||||
|
num_training_steps=self.max_meta_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_optimizer_grouped_parameters(self):
|
||||||
|
param_optimizer = list(self.model.named_parameters())
|
||||||
|
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
|
||||||
|
optimizer_grouped_parameters = [
|
||||||
|
{
|
||||||
|
"params": [
|
||||||
|
p
|
||||||
|
for n, p in param_optimizer
|
||||||
|
if not any(nd in n for nd in no_decay) and p.requires_grad
|
||||||
|
],
|
||||||
|
"weight_decay": 0.01,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": [
|
||||||
|
p
|
||||||
|
for n, p in param_optimizer
|
||||||
|
if any(nd in n for nd in no_decay) and p.requires_grad
|
||||||
|
],
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
return optimizer_grouped_parameters
|
||||||
|
|
||||||
|
def get_names(self):
|
||||||
|
names = [n for n, p in self.model.named_parameters() if p.requires_grad]
|
||||||
|
return names
|
||||||
|
|
||||||
|
def get_params(self):
|
||||||
|
params = [p for p in self.model.parameters() if p.requires_grad]
|
||||||
|
return params
|
||||||
|
|
||||||
|
def load_weights(self, names, params):
|
||||||
|
model_params = self.model.state_dict()
|
||||||
|
for n, p in zip(names, params):
|
||||||
|
model_params[n].data.copy_(p.data)
|
||||||
|
|
||||||
|
def load_gradients(self, names, grads):
|
||||||
|
model_params = self.model.state_dict(keep_vars=True)
|
||||||
|
for n, g in zip(names, grads):
|
||||||
|
if model_params[n].grad is None:
|
||||||
|
continue
|
||||||
|
model_params[n].grad.data.add_(g.data) # accumulate
|
||||||
|
|
||||||
|
def get_learning_rate(self, lr, progress, warmup, schedule="linear"):
|
||||||
|
if schedule == "linear":
|
||||||
|
if progress < warmup:
|
||||||
|
lr *= progress / warmup
|
||||||
|
else:
|
||||||
|
lr *= max((progress - 1.0) / (warmup - 1.0), 0.0)
|
||||||
|
return lr
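    # Worked example of the linear schedule above (illustrative numbers):
    # with lr = 3e-5 and warmup = 0.1, progress = 0.05 scales lr by 0.5
    # (warm-up), progress = 0.55 scales it by 0.5 again (linear decay),
    # and progress = 1.0 scales it to 0.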
|
||||||
|
|
||||||
|
def inner_update(self, data_support, lr_curr, inner_steps, no_grad: bool = False):
|
||||||
|
inner_opt = BertAdam(self.get_optimizer_grouped_parameters(), lr=self.lr_inner)
|
||||||
|
self.model.train()
|
||||||
|
|
||||||
|
for i in range(inner_steps):
|
||||||
|
inner_opt.param_groups[0]["lr"] = lr_curr
|
||||||
|
inner_opt.param_groups[1]["lr"] = lr_curr
|
||||||
|
|
||||||
|
inner_opt.zero_grad()
|
||||||
|
_, _, loss, type_loss = self.model.forward_wuqh(
|
||||||
|
input_ids=data_support["input_ids"],
|
||||||
|
attention_mask=data_support["input_mask"],
|
||||||
|
token_type_ids=data_support["segment_ids"],
|
||||||
|
labels=data_support["label_ids"],
|
||||||
|
e_mask=data_support["e_mask"],
|
||||||
|
e_type_ids=data_support["e_type_ids"],
|
||||||
|
e_type_mask=data_support["e_type_mask"],
|
||||||
|
entity_types=self.entity_types,
|
||||||
|
is_update_type_embedding=True,
|
||||||
|
lambda_max_loss=self.args.inner_lambda_max_loss,
|
||||||
|
sim_k=self.args.inner_similar_k,
|
||||||
|
)
|
||||||
|
if loss is None:
|
||||||
|
loss = type_loss
|
||||||
|
elif type_loss is not None:
|
||||||
|
loss = loss + type_loss
|
||||||
|
if no_grad:
|
||||||
|
continue
|
||||||
|
loss.backward()
|
||||||
|
inner_opt.step()
|
||||||
|
|
||||||
|
return loss.item()
|
||||||
|
|
||||||
|
def forward_supervise(self, batch_query, batch_support, progress, inner_steps):
|
||||||
|
span_losses, type_losses = [], []
|
||||||
|
task_num = len(batch_query)
|
||||||
|
|
||||||
|
for task_id in range(task_num):
|
||||||
|
_, _, loss, type_loss = self.model.forward_wuqh(
|
||||||
|
input_ids=batch_query[task_id]["input_ids"],
|
||||||
|
attention_mask=batch_query[task_id]["input_mask"],
|
||||||
|
token_type_ids=batch_query[task_id]["segment_ids"],
|
||||||
|
labels=batch_query[task_id]["label_ids"],
|
||||||
|
e_mask=batch_query[task_id]["e_mask"],
|
||||||
|
e_type_ids=batch_query[task_id]["e_type_ids"],
|
||||||
|
e_type_mask=batch_query[task_id]["e_type_mask"],
|
||||||
|
entity_types=self.entity_types,
|
||||||
|
lambda_max_loss=self.args.lambda_max_loss,
|
||||||
|
)
|
||||||
|
if loss is not None:
|
||||||
|
span_losses.append(loss.item())
|
||||||
|
if type_loss is not None:
|
||||||
|
type_losses.append(type_loss.item())
|
||||||
|
if loss is None:
|
||||||
|
loss = type_loss
|
||||||
|
elif type_loss is not None:
|
||||||
|
loss = loss + type_loss
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
self.opt.step()
|
||||||
|
self.scheduler.step()
|
||||||
|
self.model.zero_grad()
|
||||||
|
|
||||||
|
for task_id in range(task_num):
|
||||||
|
_, _, loss, type_loss = self.model.forward_wuqh(
|
||||||
|
input_ids=batch_support[task_id]["input_ids"],
|
||||||
|
attention_mask=batch_support[task_id]["input_mask"],
|
||||||
|
token_type_ids=batch_support[task_id]["segment_ids"],
|
||||||
|
labels=batch_support[task_id]["label_ids"],
|
||||||
|
e_mask=batch_support[task_id]["e_mask"],
|
||||||
|
e_type_ids=batch_support[task_id]["e_type_ids"],
|
||||||
|
e_type_mask=batch_support[task_id]["e_type_mask"],
|
||||||
|
entity_types=self.entity_types,
|
||||||
|
lambda_max_loss=self.args.lambda_max_loss,
|
||||||
|
)
|
||||||
|
if loss is not None:
|
||||||
|
span_losses.append(loss.item())
|
||||||
|
if type_loss is not None:
|
||||||
|
type_losses.append(type_loss.item())
|
||||||
|
if loss is None:
|
||||||
|
loss = type_loss
|
||||||
|
elif type_loss is not None:
|
||||||
|
loss = loss + type_loss
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
self.opt.step()
|
||||||
|
self.scheduler.step()
|
||||||
|
self.model.zero_grad()
|
||||||
|
|
||||||
|
return (
|
||||||
|
np.mean(span_losses) if span_losses else 0,
|
||||||
|
np.mean(type_losses) if type_losses else 0,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward_meta(self, batch_query, batch_support, progress, inner_steps):
|
||||||
|
names = self.get_names()
|
||||||
|
params = self.get_params()
|
||||||
|
weights = deepcopy(params)
|
||||||
|
|
||||||
|
meta_grad = []
|
||||||
|
span_losses, type_losses = [], []
|
||||||
|
|
||||||
|
task_num = len(batch_query)
|
||||||
|
lr_inner = self.get_learning_rate(
|
||||||
|
self.lr_inner, progress, self.warmup_prop_inner
|
||||||
|
)
|
||||||
|
|
||||||
|
# compute meta_grad of each task
|
||||||
|
for task_id in range(task_num):
|
||||||
|
self.inner_update(batch_support[task_id], lr_inner, inner_steps=inner_steps)
|
||||||
|
_, _, loss, type_loss = self.model.forward_wuqh(
|
||||||
|
input_ids=batch_query[task_id]["input_ids"],
|
||||||
|
attention_mask=batch_query[task_id]["input_mask"],
|
||||||
|
token_type_ids=batch_query[task_id]["segment_ids"],
|
||||||
|
labels=batch_query[task_id]["label_ids"],
|
||||||
|
e_mask=batch_query[task_id]["e_mask"],
|
||||||
|
e_type_ids=batch_query[task_id]["e_type_ids"],
|
||||||
|
e_type_mask=batch_query[task_id]["e_type_mask"],
|
||||||
|
entity_types=self.entity_types,
|
||||||
|
lambda_max_loss=self.args.lambda_max_loss,
|
||||||
|
)
|
||||||
|
if loss is not None:
|
||||||
|
span_losses.append(loss.item())
|
||||||
|
if type_loss is not None:
|
||||||
|
type_losses.append(type_loss.item())
|
||||||
|
if loss is None:
|
||||||
|
loss = type_loss
|
||||||
|
elif type_loss is not None:
|
||||||
|
loss = loss + type_loss
|
||||||
|
grad = torch.autograd.grad(loss, params)
|
||||||
|
meta_grad.append(grad)
|
||||||
|
|
||||||
|
self.load_weights(names, weights)
|
||||||
|
|
||||||
|
# accumulate grads of all tasks to param.grad
|
||||||
|
self.opt.zero_grad()
|
||||||
|
|
||||||
|
# similar to backward()
|
||||||
|
for g in meta_grad:
|
||||||
|
self.load_gradients(names, g)
|
||||||
|
self.opt.step()
|
||||||
|
self.scheduler.step()
|
||||||
|
|
||||||
|
return (
|
||||||
|
np.mean(span_losses) if span_losses else 0,
|
||||||
|
np.mean(type_losses) if type_losses else 0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ---------------------------------------- Evaluation -------------------------------------- #
|
||||||
|
def write_result(self, words, y_true, y_pred, tmp_fn):
|
||||||
|
assert len(y_pred) == len(y_true)
|
||||||
|
with open(tmp_fn, "w", encoding="utf-8") as fw:
|
||||||
|
for i, sent in enumerate(y_true):
|
||||||
|
for j, word in enumerate(sent):
|
||||||
|
fw.write("{} {} {}\n".format(words[i][j], word, y_pred[i][j]))
|
||||||
|
fw.write("\n")
|
||||||
|
|
||||||
|
def batch_test(self, data):
|
||||||
|
N = data["input_ids"].shape[0]
|
||||||
|
B = 16
|
||||||
|
BATCH_KEY = [
|
||||||
|
"input_ids",
|
||||||
|
"attention_mask",
|
||||||
|
"token_type_ids",
|
||||||
|
"labels",
|
||||||
|
"e_mask",
|
||||||
|
"e_type_ids",
|
||||||
|
"e_type_mask",
|
||||||
|
]
|
||||||
|
|
||||||
|
logits, e_logits, loss, type_loss = [], [], 0, 0
|
||||||
|
for i in range((N - 1) // B + 1):
|
||||||
|
tmp = {
|
||||||
|
ii: jj if ii not in BATCH_KEY else jj[i * B : (i + 1) * B]
|
||||||
|
for ii, jj in data.items()
|
||||||
|
}
|
||||||
|
tmp_l, tmp_el, tmp_loss, tmp_eval_type_loss = self.model.forward_wuqh(**tmp)
|
||||||
|
if tmp_l is not None:
|
||||||
|
logits.extend(tmp_l.detach().cpu().numpy())
|
||||||
|
if tmp_el is not None:
|
||||||
|
e_logits.extend(tmp_el.detach().cpu().numpy())
|
||||||
|
if tmp_loss is not None:
|
||||||
|
loss += tmp_loss
|
||||||
|
if tmp_eval_type_loss is not None:
|
||||||
|
type_loss += tmp_eval_type_loss
|
||||||
|
return logits, e_logits, loss, type_loss
|
||||||
|
|
||||||
|
def evaluate_meta_(
|
||||||
|
self,
|
||||||
|
corpus,
|
||||||
|
logger,
|
||||||
|
lr,
|
||||||
|
steps,
|
||||||
|
mode,
|
||||||
|
set_type,
|
||||||
|
type_steps: int = None,
|
||||||
|
viterbi_decoder=None,
|
||||||
|
):
|
||||||
|
if not type_steps:
|
||||||
|
type_steps = steps
|
||||||
|
if self.is_debug:
|
||||||
|
self.save_model(self.args.result_dir, "begin", self.args.max_seq_len, "all")
|
||||||
|
|
||||||
|
logger.info("Begin first Stage.")
|
||||||
|
if self.eval_mode == "two-stage":
|
||||||
|
self.load_model("span")
|
||||||
|
names = self.get_names()
|
||||||
|
params = self.get_params()
|
||||||
|
weights = deepcopy(params)
|
||||||
|
|
||||||
|
eval_loss = 0.0
|
||||||
|
nb_eval_steps = 0
|
||||||
|
preds = None
|
||||||
|
t_tmp = time.time()
|
||||||
|
targets, predes, spans, lss, type_preds, type_g = [], [], [], [], [], []
|
||||||
|
for item_id in range(corpus.n_total):
|
||||||
|
eval_query, eval_support = corpus.get_batch_meta(
|
||||||
|
batch_size=1, shuffle=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# train on support examples
|
||||||
|
if not self.args.nouse_inner_ft:
|
||||||
|
self.inner_update(eval_support[0], lr_curr=lr, inner_steps=steps)
|
||||||
|
|
||||||
|
# eval on pseudo query examples (test example)
|
||||||
|
self.model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
logits, e_ls, tmp_eval_loss, _ = self.batch_test(
|
||||||
|
{
|
||||||
|
"input_ids": eval_query[0]["input_ids"],
|
||||||
|
"attention_mask": eval_query[0]["input_mask"],
|
||||||
|
"token_type_ids": eval_query[0]["segment_ids"],
|
||||||
|
"labels": eval_query[0]["label_ids"],
|
||||||
|
"e_mask": eval_query[0]["e_mask"],
|
||||||
|
"e_type_ids": eval_query[0]["e_type_ids"],
|
||||||
|
"e_type_mask": eval_query[0]["e_type_mask"],
|
||||||
|
"entity_types": self.entity_types,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
lss.append(logits)
|
||||||
|
if self.model.train_mode != "type":
|
||||||
|
eval_loss += tmp_eval_loss
|
||||||
|
if self.model.train_mode != "span":
|
||||||
|
type_pred, type_ground = self.eval_typing(
|
||||||
|
e_ls, eval_query[0]["e_type_mask"]
|
||||||
|
)
|
||||||
|
type_preds.append(type_pred)
|
||||||
|
type_g.append(type_ground)
|
||||||
|
else:
|
||||||
|
e_mask, e_type_ids, e_type_mask, result, types = self.decode_span(
|
||||||
|
logits,
|
||||||
|
eval_query[0]["label_ids"],
|
||||||
|
eval_query[0]["types"],
|
||||||
|
eval_query[0]["input_mask"],
|
||||||
|
viterbi_decoder,
|
||||||
|
)
|
||||||
|
targets.extend(eval_query[0]["entities"])
|
||||||
|
spans.extend(result)
|
||||||
|
|
||||||
|
nb_eval_steps += 1
|
||||||
|
|
||||||
|
self.load_weights(names, weights)
|
||||||
|
if item_id % 200 == 0:
|
||||||
|
logger.info(
|
||||||
|
" To sentence {}/{}. Time: {}sec".format(
|
||||||
|
item_id, corpus.n_total, time.time() - t_tmp
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Begin second Stage.")
|
||||||
|
if self.eval_mode == "two-stage":
|
||||||
|
self.load_model("type")
|
||||||
|
names = self.get_names()
|
||||||
|
params = self.get_params()
|
||||||
|
weights = deepcopy(params)
|
||||||
|
|
||||||
|
if self.train_mode == "add":
|
||||||
|
for item_id in range(corpus.n_total):
|
||||||
|
eval_query, eval_support = corpus.get_batch_meta(
|
||||||
|
batch_size=1, shuffle=False
|
||||||
|
)
|
||||||
|
logits = lss[item_id]
|
||||||
|
|
||||||
|
# train on support examples
|
||||||
|
self.inner_update(eval_support[0], lr_curr=lr, inner_steps=type_steps)
|
||||||
|
|
||||||
|
# eval on pseudo query examples (test example)
|
||||||
|
self.model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
e_mask, e_type_ids, e_type_mask, result, types = self.decode_span(
|
||||||
|
logits,
|
||||||
|
eval_query[0]["label_ids"],
|
||||||
|
eval_query[0]["types"],
|
||||||
|
eval_query[0]["input_mask"],
|
||||||
|
viterbi_decoder,
|
||||||
|
)
|
||||||
|
|
||||||
|
_, e_logits, _, tmp_eval_type_loss = self.batch_test(
|
||||||
|
{
|
||||||
|
"input_ids": eval_query[0]["input_ids"],
|
||||||
|
"attention_mask": eval_query[0]["input_mask"],
|
||||||
|
"token_type_ids": eval_query[0]["segment_ids"],
|
||||||
|
"labels": eval_query[0]["label_ids"],
|
||||||
|
"e_mask": e_mask,
|
||||||
|
"e_type_ids": e_type_ids,
|
||||||
|
"e_type_mask": e_type_mask,
|
||||||
|
"entity_types": self.entity_types,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
eval_loss += tmp_eval_type_loss
|
||||||
|
|
||||||
|
if self.eval_mode == "two-stage":
|
||||||
|
logits, e_ls, tmp_eval_loss, _ = self.batch_test(
|
||||||
|
{
|
||||||
|
"input_ids": eval_query[0]["input_ids"],
|
||||||
|
"attention_mask": eval_query[0]["input_mask"],
|
||||||
|
"token_type_ids": eval_query[0]["segment_ids"],
|
||||||
|
"labels": eval_query[0]["label_ids"],
|
||||||
|
"e_mask": eval_query[0]["e_mask"],
|
||||||
|
"e_type_ids": eval_query[0]["e_type_ids"],
|
||||||
|
"e_type_mask": eval_query[0]["e_type_mask"],
|
||||||
|
"entity_types": self.entity_types,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
type_pred, type_ground = self.eval_typing(
|
||||||
|
e_ls, eval_query[0]["e_type_mask"]
|
||||||
|
)
|
||||||
|
type_preds.append(type_pred)
|
||||||
|
type_g.append(type_ground)
|
||||||
|
taregt, p = self.decode_entity(
|
||||||
|
e_logits, result, types, eval_query[0]["entities"]
|
||||||
|
)
|
||||||
|
predes.extend(p)
|
||||||
|
|
||||||
|
self.load_weights(names, weights)
|
||||||
|
if item_id % 200 == 0:
|
||||||
|
logger.info(
|
||||||
|
" To sentence {}/{}. Time: {}sec".format(
|
||||||
|
item_id, corpus.n_total, time.time() - t_tmp
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
eval_loss = eval_loss / nb_eval_steps
|
||||||
|
if self.is_debug:
|
||||||
|
joblib.dump([targets, predes, spans], "debug/f1.pkl")
|
||||||
|
store_dir = self.args.model_dir if self.args.model_dir else self.args.result_dir
|
||||||
|
joblib.dump(
|
||||||
|
[targets, predes, spans],
|
||||||
|
"{}/{}_{}_preds.pkl".format(store_dir, "all", set_type),
|
||||||
|
)
|
||||||
|
joblib.dump(
|
||||||
|
[type_g, type_preds],
|
||||||
|
"{}/{}_{}_preds.pkl".format(store_dir, "typing", set_type),
|
||||||
|
)
|
||||||
|
pred = [[jj[:-1] for jj in ii] for ii in predes]
|
||||||
|
p, r, f1 = self.cacl_f1(targets, pred)
|
||||||
|
pred = [
|
||||||
|
[jj[:-1] for jj in ii if jj[-1] > self.args.type_threshold] for ii in predes
|
||||||
|
]
|
||||||
|
p_t, r_t, f1_t = self.cacl_f1(targets, pred)
|
||||||
|
|
||||||
|
span_p, span_r, span_f1 = self.cacl_f1(
|
||||||
|
[[(jj[0], jj[1]) for jj in ii] for ii in targets], spans
|
||||||
|
)
|
||||||
|
type_p, type_r, type_f1 = self.cacl_f1(type_g, type_preds)
|
||||||
|
|
||||||
|
results = {
|
||||||
|
"loss": eval_loss,
|
||||||
|
"precision": p,
|
||||||
|
"recall": r,
|
||||||
|
"f1": f1,
|
||||||
|
"span_p": span_p,
|
||||||
|
"span_r": span_r,
|
||||||
|
"span_f1": span_f1,
|
||||||
|
"type_p": type_p,
|
||||||
|
"type_r": type_r,
|
||||||
|
"type_f1": type_f1,
|
||||||
|
"precision_threshold": p_t,
|
||||||
|
"recall_threshold": r_t,
|
||||||
|
"f1_threshold": f1_t,
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info("***** Eval results %s-%s *****", mode, set_type)
|
||||||
|
for key in sorted(results.keys()):
|
||||||
|
logger.info(" %s = %s", key, str(results[key]))
|
||||||
|
logger.info(
|
||||||
|
"%.3f,%.3f,%.3f,%.3f,%.3f,%.3f,%.3f,%.3f,%.3f,%.3f,%.3f,%.3f",
|
||||||
|
results["precision"] * 100,
|
||||||
|
results["recall"] * 100,
|
||||||
|
results["f1"] * 100,
|
||||||
|
results["span_p"] * 100,
|
||||||
|
results["span_r"] * 100,
|
||||||
|
results["span_f1"] * 100,
|
||||||
|
results["type_p"] * 100,
|
||||||
|
results["type_r"] * 100,
|
||||||
|
results["type_f1"] * 100,
|
||||||
|
results["precision_threshold"] * 100,
|
||||||
|
results["recall_threshold"] * 100,
|
||||||
|
results["f1_threshold"] * 100,
|
||||||
|
)
|
||||||
|
|
||||||
|
return results, preds
|
||||||
|
|
||||||
|
def save_model(self, result_dir, fn_prefix, max_seq_len, mode: str = "all"):
|
||||||
|
# Save a trained model and the associated configuration
|
||||||
|
model_to_save = (
|
||||||
|
self.model.module if hasattr(self.model, "module") else self.model
|
||||||
|
) # Only save the model it-self
|
||||||
|
output_model_file = os.path.join(
|
||||||
|
result_dir, "{}_{}_{}".format(fn_prefix, mode, WEIGHTS_NAME)
|
||||||
|
)
|
||||||
|
torch.save(model_to_save.state_dict(), output_model_file)
|
||||||
|
output_config_file = os.path.join(result_dir, CONFIG_NAME)
|
||||||
|
with open(output_config_file, "w", encoding="utf-8") as f:
|
||||||
|
f.write(model_to_save.config.to_json_string())
|
||||||
|
label_map = {i: label for i, label in enumerate(self.label_list, 1)}
|
||||||
|
model_config = {
|
||||||
|
"bert_model": self.bert_model,
|
||||||
|
"do_lower": False,
|
||||||
|
"max_seq_length": max_seq_len,
|
||||||
|
"num_labels": len(self.label_list) + 1,
|
||||||
|
"label_map": label_map,
|
||||||
|
}
|
||||||
|
json.dump(
|
||||||
|
model_config,
|
||||||
|
open(
|
||||||
|
os.path.join(result_dir, f"{mode}-model_config.json"),
|
||||||
|
"w",
|
||||||
|
encoding="utf-8",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if mode == "type":
|
||||||
|
joblib.dump(
|
||||||
|
self.entity_types, os.path.join(result_dir, "type_embedding.pkl")
|
||||||
|
)
|
||||||
|
|
||||||
|
def save_best_model(self, result_dir: str, mode: str):
|
||||||
|
output_model_file = os.path.join(result_dir, "en_tmp_{}".format(WEIGHTS_NAME))
|
||||||
|
config_name = os.path.join(result_dir, "tmp-model_config.json")
|
||||||
|
shutil.copy(output_model_file, output_model_file.replace("tmp", mode))
|
||||||
|
shutil.copy(config_name, config_name.replace("tmp", mode))
|
||||||
|
|
||||||
|
def load_model(self, mode: str = "all"):
|
||||||
|
if not self.model_dir:
|
||||||
|
return
|
||||||
|
model_dir = self.model_dir
|
||||||
|
logger.info(f"********** Loading saved {mode} model **********")
|
||||||
|
output_model_file = os.path.join(
|
||||||
|
model_dir, "en_{}_{}".format(mode, WEIGHTS_NAME)
|
||||||
|
)
|
||||||
|
self.model = BertForTokenClassification_.from_pretrained(
|
||||||
|
self.bert_model, num_labels=len(self.label_list), output_hidden_states=True
|
||||||
|
)
|
||||||
|
self.model.set_config(
|
||||||
|
self.args.use_classify,
|
||||||
|
self.args.distance_mode,
|
||||||
|
self.args.similar_k,
|
||||||
|
self.args.shared_bert,
|
||||||
|
mode,
|
||||||
|
)
|
||||||
|
self.model.to(self.args.device)
|
||||||
|
self.model.load_state_dict(torch.load(output_model_file, map_location="cuda"))
|
||||||
|
self.layer_set()
|
||||||
|
|
||||||
|
def decode_span(
|
||||||
|
self,
|
||||||
|
logits: torch.Tensor,
|
||||||
|
target: torch.Tensor,
|
||||||
|
types,
|
||||||
|
mask: torch.Tensor,
|
||||||
|
viterbi_decoder=None,
|
||||||
|
):
|
||||||
|
if self.is_debug:
|
||||||
|
joblib.dump([logits, target, self.label_list], "debug/span.pkl")
|
||||||
|
device = target.device
|
||||||
|
K = max([len(ii) for ii in types])
|
||||||
|
if viterbi_decoder:
|
||||||
|
N = target.shape[0]
|
||||||
|
B = 16
|
||||||
|
result = []
|
||||||
|
for i in range((N - 1) // B + 1):
|
||||||
|
tmp_logits = torch.tensor(logits[i * B : (i + 1) * B]).to(target.device)
|
||||||
|
if len(tmp_logits.shape) == 2:
|
||||||
|
tmp_logits = tmp_logits.unsqueeze(0)
|
||||||
|
tmp_target = target[i * B : (i + 1) * B]
|
||||||
|
log_probs = nn.functional.log_softmax(
|
||||||
|
tmp_logits.detach(), dim=-1
|
||||||
|
) # batch_size x max_seq_len x n_labels
|
||||||
|
pred_labels = viterbi_decoder.forward(
|
||||||
|
log_probs, mask[i * B : (i + 1) * B], tmp_target
|
||||||
|
)
|
||||||
|
|
||||||
|
for ii, jj in zip(pred_labels, tmp_target.detach().cpu().numpy()):
|
||||||
|
left, right, tmp = 0, 0, []
|
||||||
|
while right < len(jj) and jj[right] == self.ignore_token_label_id:
|
||||||
|
tmp.append(-1)
|
||||||
|
right += 1
|
||||||
|
while left < len(ii):
|
||||||
|
tmp.append(ii[left])
|
||||||
|
left += 1
|
||||||
|
right += 1
|
||||||
|
while (
|
||||||
|
right < len(jj) and jj[right] == self.ignore_token_label_id
|
||||||
|
):
|
||||||
|
tmp.append(-1)
|
||||||
|
right += 1
|
||||||
|
result.append(tmp)
|
||||||
|
target = target.detach().cpu().numpy()
|
||||||
|
B, T = target.shape
|
||||||
|
if not viterbi_decoder:
|
||||||
|
logits = logits.detach().cpu().numpy()
|
||||||
|
result = np.argmax(logits, -1)
|
||||||
|
|
||||||
|
if self.label_list == ["O", "B", "I"]:
|
||||||
|
res = []
|
||||||
|
for ii in range(B):
|
||||||
|
tmp, idx = [], 0
|
||||||
|
max_pad = T - 1
|
||||||
|
while (
|
||||||
|
max_pad > 0 and target[ii][max_pad - 1] == self.pad_token_label_id
|
||||||
|
):
|
||||||
|
max_pad -= 1
|
||||||
|
while idx < max_pad:
|
||||||
|
if target[ii][idx] == self.ignore_token_label_id or (
|
||||||
|
result[ii][idx] != 1
|
||||||
|
):
|
||||||
|
idx += 1
|
||||||
|
continue
|
||||||
|
e = idx
|
||||||
|
while e < max_pad - 1 and (
|
||||||
|
target[ii][e + 1] == self.ignore_token_label_id
|
||||||
|
or result[ii][e + 1] in [self.ignore_token_label_id, 2]
|
||||||
|
):
|
||||||
|
e += 1
|
||||||
|
tmp.append((idx, e))
|
||||||
|
idx = e + 1
|
||||||
|
res.append(tmp)
|
||||||
|
elif self.label_list == ["O", "B", "I", "E", "S"]:
|
||||||
|
res = []
|
||||||
|
for ii in range(B):
|
||||||
|
tmp, idx = [], 0
|
||||||
|
max_pad = T - 1
|
||||||
|
while (
|
||||||
|
max_pad > 0 and target[ii][max_pad - 1] == self.pad_token_label_id
|
||||||
|
):
|
||||||
|
max_pad -= 1
|
||||||
|
while idx < max_pad:
|
||||||
|
if target[ii][idx] == self.ignore_token_label_id or (
|
||||||
|
result[ii][idx] not in [1, 4]
|
||||||
|
):
|
||||||
|
idx += 1
|
||||||
|
continue
|
||||||
|
e = idx
|
||||||
|
while (
|
||||||
|
e < max_pad - 1
|
||||||
|
and result[ii][e] not in [3, 4]
|
||||||
|
and (
|
||||||
|
target[ii][e + 1] == self.ignore_token_label_id
|
||||||
|
or result[ii][e + 1] in [self.ignore_token_label_id, 2, 3]
|
||||||
|
)
|
||||||
|
):
|
||||||
|
e += 1
|
||||||
|
if e < max_pad and result[ii][e] in [3, 4]:
|
||||||
|
while (
|
||||||
|
e < max_pad - 1
|
||||||
|
and target[ii][e + 1] == self.ignore_token_label_id
|
||||||
|
):
|
||||||
|
e += 1
|
||||||
|
tmp.append((idx, e))
|
||||||
|
idx = e + 1
|
||||||
|
res.append(tmp)
|
||||||
|
M = max([len(ii) for ii in res])
|
||||||
|
e_mask = np.zeros((B, M, T), np.int8)
|
||||||
|
e_type_mask = np.zeros((B, M, K), np.int8)
|
||||||
|
        e_type_ids = np.zeros((B, M, K), np.int64)  # np.int was removed in NumPy 1.24; use a concrete dtype
|
||||||
|
for ii in range(B):
|
||||||
|
for idx, (s, e) in enumerate(res[ii]):
|
||||||
|
e_mask[ii][idx][s : e + 1] = 1
|
||||||
|
types_set = types[ii]
|
||||||
|
if len(res[ii]):
|
||||||
|
e_type_ids[ii, : len(res[ii]), : len(types_set)] = [types_set] * len(
|
||||||
|
res[ii]
|
||||||
|
)
|
||||||
|
e_type_mask[ii, : len(res[ii]), : len(types_set)] = np.ones(
|
||||||
|
(len(res[ii]), len(types_set))
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
torch.tensor(e_mask).to(device),
|
||||||
|
torch.tensor(e_type_ids, dtype=torch.long).to(device),
|
||||||
|
torch.tensor(e_type_mask).to(device),
|
||||||
|
res,
|
||||||
|
types,
|
||||||
|
)
|
||||||
|
|
||||||
|
def decode_entity(self, e_logits, result, types, entities):
|
||||||
|
if self.is_debug:
|
||||||
|
joblib.dump([e_logits, result, types, entities], "debug/e.pkl")
|
||||||
|
target, preds = entities, []
|
||||||
|
B = len(e_logits)
|
||||||
|
logits = e_logits
|
||||||
|
|
||||||
|
for ii in range(B):
|
||||||
|
tmp = []
|
||||||
|
tmp_res = result[ii]
|
||||||
|
tmp_types = types[ii]
|
||||||
|
for jj in range(len(tmp_res)):
|
||||||
|
lg = logits[ii][jj, : len(tmp_types)]
|
||||||
|
tmp.append((*tmp_res[jj], tmp_types[np.argmax(lg)], lg[np.argmax(lg)]))
|
||||||
|
preds.append(tmp)
|
||||||
|
return target, preds
|
||||||
|
|
||||||
|
    def cacl_f1(self, targets: list, predes: list):
        # Micro-averaged precision / recall / F1: an item counts as a true
        # positive only if it appears in both the gold set and the predicted
        # set for the same sentence.
        tp, fp, fn = 0, 0, 0
        for ii, jj in zip(targets, predes):
            ii, jj = set(ii), set(jj)
            same = ii & jj  # equivalent to the original ii - (ii - jj)
            tp += len(same)
            fn += len(ii - jj)  # gold items that were missed
            fp += len(jj - ii)  # predicted items that are not in the gold set
        p = tp / (fp + tp + 1e-10)
        r = tp / (fn + tp + 1e-10)
        return p, r, 2 * p * r / (p + r + 1e-10)
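    # A small worked example for the micro-averaged scores above
    # (illustrative values, not results from the paper):
    #   targets = [[(0, 1, "PER"), (3, 3, "LOC")]]
    #   predes  = [[(0, 1, "PER"), (4, 5, "ORG")]]
    # gives tp = 1, fn = 1, fp = 1, hence precision = recall = F1 = 0.5.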
|
||||||
|
|
||||||
|
def eval_typing(self, e_logits, e_type_mask):
|
||||||
|
e_logits = e_logits
|
||||||
|
e_type_mask = e_type_mask.detach().cpu().numpy()
|
||||||
|
if self.is_debug:
|
||||||
|
joblib.dump([e_logits, e_type_mask], "debug/typing.pkl")
|
||||||
|
|
||||||
|
N = len(e_logits)
|
||||||
|
B_S = 16
|
||||||
|
res = []
|
||||||
|
for i in range((N - 1) // B_S + 1):
|
||||||
|
tmp_e_logits = np.argmax(e_logits[i * B_S : (i + 1) * B_S], -1)
|
||||||
|
B, M = tmp_e_logits.shape
|
||||||
|
tmp_e_type_mask = e_type_mask[i * B_S : (i + 1) * B_S][:, :M, 0]
|
||||||
|
res.extend(tmp_e_logits[tmp_e_type_mask == 1])
|
||||||
|
ground = [0] * len(res)
|
||||||
|
return enumerate(res), enumerate(ground)
|
|
@ -0,0 +1,661 @@
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import joblib
|
||||||
|
from learner import Learner
|
||||||
|
from modeling import ViterbiDecoder
|
||||||
|
from preprocessor import Corpus, EntityTypes
|
||||||
|
from utils import set_seed
|
||||||
|
|
||||||
|
|
||||||
|
def get_label_list(args):
|
||||||
|
# prepare dataset
|
||||||
|
if args.tagging_scheme == "BIOES":
|
||||||
|
label_list = ["O", "B", "I", "E", "S"]
|
||||||
|
elif args.tagging_scheme == "BIO":
|
||||||
|
label_list = ["O", "B", "I"]
|
||||||
|
else:
|
||||||
|
label_list = ["O", "I"]
|
||||||
|
return label_list
|
||||||
|
|
||||||
|
|
||||||
|
def get_data_path(args, train_mode: str):
|
||||||
|
assert args.dataset in [
|
||||||
|
"FewNERD",
|
||||||
|
"Domain",
|
||||||
|
"Domain2",
|
||||||
|
], f"Dataset: {args.dataset} Not Support."
|
||||||
|
if args.dataset == "FewNERD":
|
||||||
|
return os.path.join(
|
||||||
|
args.data_path,
|
||||||
|
args.mode,
|
||||||
|
"{}_{}_{}.jsonl".format(train_mode, args.N, args.K),
|
||||||
|
)
|
||||||
|
elif args.dataset == "Domain":
|
||||||
|
if train_mode == "dev":
|
||||||
|
train_mode = "valid"
|
||||||
|
text = "_shot_5" if args.K == 5 else ""
|
||||||
|
replace_text = "-" if args.K == 5 else "_"
|
||||||
|
return os.path.join(
|
||||||
|
"ACL2020data",
|
||||||
|
"xval_ner{}".format(text),
|
||||||
|
"ner_{}_{}{}.json".format(train_mode, args.N, text).replace(
|
||||||
|
"_", replace_text
|
||||||
|
),
|
||||||
|
)
|
||||||
|
elif args.dataset == "Domain2":
|
||||||
|
if train_mode == "train":
|
||||||
|
return os.path.join("domain2", "{}_10_5.json".format(train_mode))
|
||||||
|
return os.path.join(
|
||||||
|
"domain2", "{}_{}_{}.json".format(train_mode, args.mode, args.K)
|
||||||
|
)
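
# With the default CLI arguments (data_path="episode-data", mode="intra", N=5, K=1),
# get_data_path(args, "train") resolves to "episode-data/intra/train_5_1.jsonl" for
# Few-NERD; the Cross-Dataset branches instead read from the unpacked ACL2020data
# (or domain2) directories.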
|
||||||
|
|
||||||
|
|
||||||
|
def replace_type_embedding(learner, args):
|
||||||
|
logger.info("********** Replace trained type embedding **********")
|
||||||
|
entity_types = joblib.load(os.path.join(args.result_dir, "type_embedding.pkl"))
|
||||||
|
N, H = entity_types.types_embedding.weight.data.shape
|
||||||
|
for ii in range(N):
|
||||||
|
learner.models.embeddings.word_embeddings.weight.data[
|
||||||
|
ii + 1
|
||||||
|
] = entity_types.types_embedding.weight.data[ii]
|
||||||
|
|
||||||
|
|
||||||
|
def train_meta(args):
|
||||||
|
logger.info("********** Scheme: Meta Learning **********")
|
||||||
|
label_list = get_label_list(args)
|
||||||
|
|
||||||
|
valid_data_path = get_data_path(args, "dev")
|
||||||
|
valid_corpus = Corpus(
|
||||||
|
logger,
|
||||||
|
valid_data_path,
|
||||||
|
args.bert_model,
|
||||||
|
args.max_seq_len,
|
||||||
|
label_list,
|
||||||
|
args.entity_types,
|
||||||
|
do_lower_case=True,
|
||||||
|
shuffle=False,
|
||||||
|
viterbi=args.viterbi,
|
||||||
|
tagging=args.tagging_scheme,
|
||||||
|
device=args.device,
|
||||||
|
concat_types=args.concat_types,
|
||||||
|
dataset=args.dataset,
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.debug:
|
||||||
|
test_corpus = valid_corpus
|
||||||
|
train_corpus = valid_corpus
|
||||||
|
else:
|
||||||
|
train_data_path = get_data_path(args, "train")
|
||||||
|
train_corpus = Corpus(
|
||||||
|
logger,
|
||||||
|
train_data_path,
|
||||||
|
args.bert_model,
|
||||||
|
args.max_seq_len,
|
||||||
|
label_list,
|
||||||
|
args.entity_types,
|
||||||
|
do_lower_case=True,
|
||||||
|
shuffle=True,
|
||||||
|
tagging=args.tagging_scheme,
|
||||||
|
device=args.device,
|
||||||
|
concat_types=args.concat_types,
|
||||||
|
dataset=args.dataset,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not args.ignore_eval_test:
|
||||||
|
test_data_path = get_data_path(args, "test")
|
||||||
|
test_corpus = Corpus(
|
||||||
|
logger,
|
||||||
|
test_data_path,
|
||||||
|
args.bert_model,
|
||||||
|
args.max_seq_len,
|
||||||
|
label_list,
|
||||||
|
args.entity_types,
|
||||||
|
do_lower_case=True,
|
||||||
|
shuffle=False,
|
||||||
|
viterbi=args.viterbi,
|
||||||
|
tagging=args.tagging_scheme,
|
||||||
|
device=args.device,
|
||||||
|
concat_types=args.concat_types,
|
||||||
|
dataset=args.dataset,
|
||||||
|
)
|
||||||
|
|
||||||
|
learner = Learner(
|
||||||
|
args.bert_model,
|
||||||
|
label_list,
|
||||||
|
args.freeze_layer,
|
||||||
|
logger,
|
||||||
|
args.lr_meta,
|
||||||
|
args.lr_inner,
|
||||||
|
args.warmup_prop_meta,
|
||||||
|
args.warmup_prop_inner,
|
||||||
|
args.max_meta_steps,
|
||||||
|
py_alias=args.py_alias,
|
||||||
|
args=args,
|
||||||
|
)
|
||||||
|
|
||||||
|
if "embedding" in args.concat_types:
|
||||||
|
replace_type_embedding(learner, args)
|
||||||
|
|
||||||
|
t = time.time()
|
||||||
|
F1_valid_best = {ii: -1.0 for ii in ["all", "type", "span"]}
|
||||||
|
F1_test = -1.0
|
||||||
|
best_step, protect_step = -1.0, 100 if args.train_mode != "type" else 50
|
||||||
|
|
||||||
|
for step in range(args.max_meta_steps):
|
||||||
|
progress = 1.0 * step / args.max_meta_steps
|
||||||
|
|
||||||
|
batch_query, batch_support = train_corpus.get_batch_meta(
|
||||||
|
batch_size=args.inner_size
|
||||||
|
) # (batch_size=32)
|
||||||
|
if args.use_supervise:
|
||||||
|
span_loss, type_loss = learner.forward_supervise(
|
||||||
|
batch_query,
|
||||||
|
batch_support,
|
||||||
|
progress=progress,
|
||||||
|
inner_steps=args.inner_steps,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
span_loss, type_loss = learner.forward_meta(
|
||||||
|
batch_query,
|
||||||
|
batch_support,
|
||||||
|
progress=progress,
|
||||||
|
inner_steps=args.inner_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
if step % 20 == 0:
|
||||||
|
logger.info(
|
||||||
|
"Step: {}/{}, span loss = {:.6f}, type loss = {:.6f}, time = {:.2f}s.".format(
|
||||||
|
step, args.max_meta_steps, span_loss, type_loss, time.time() - t
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if step % args.eval_every_meta_steps == 0 and step > protect_step:
|
||||||
|
logger.info("********** Scheme: evaluate - [valid] **********")
|
||||||
|
result_valid, predictions_valid = test(args, learner, valid_corpus, "valid")
|
||||||
|
|
||||||
|
F1_valid = result_valid["f1"]
|
||||||
|
is_best = False
|
||||||
|
if F1_valid > F1_valid_best["all"]:
|
||||||
|
logger.info("===> Best Valid F1: {}".format(F1_valid))
|
||||||
|
logger.info(" Saving model...")
|
||||||
|
learner.save_model(args.result_dir, "en", args.max_seq_len, "all")
|
||||||
|
F1_valid_best["all"] = F1_valid
|
||||||
|
best_step = step
|
||||||
|
is_best = True
|
||||||
|
|
||||||
|
if (
|
||||||
|
result_valid["span_f1"] > F1_valid_best["span"]
|
||||||
|
and args.train_mode != "type"
|
||||||
|
):
|
||||||
|
F1_valid_best["span"] = result_valid["span_f1"]
|
||||||
|
learner.save_model(args.result_dir, "en", args.max_seq_len, "span")
|
||||||
|
logger.info("Best Span Store {}".format(step))
|
||||||
|
is_best = True
|
||||||
|
if (
|
||||||
|
result_valid["type_f1"] > F1_valid_best["type"]
|
||||||
|
and args.train_mode != "span"
|
||||||
|
):
|
||||||
|
F1_valid_best["type"] = result_valid["type_f1"]
|
||||||
|
learner.save_model(args.result_dir, "en", args.max_seq_len, "type")
|
||||||
|
logger.info("Best Type Store {}".format(step))
|
||||||
|
is_best = True
|
||||||
|
|
||||||
|
if is_best and not args.ignore_eval_test:
|
||||||
|
logger.info("********** Scheme: evaluate - [test] **********")
|
||||||
|
result_test, predictions_test = test(args, learner, test_corpus, "test")
|
||||||
|
|
||||||
|
F1_test = result_test["f1"]
|
||||||
|
logger.info(
|
||||||
|
"Best Valid F1: {}, Step: {}".format(F1_valid_best, best_step)
|
||||||
|
)
|
||||||
|
logger.info("Test F1: {}".format(F1_test))
|
||||||
|
|
||||||
|
|
||||||
|
def test(args, learner, corpus, types: str):
|
||||||
|
if corpus.viterbi != "none":
|
||||||
|
id2label = corpus.id2label
|
||||||
|
transition_matrix = corpus.transition_matrix
|
||||||
|
if args.viterbi == "soft":
|
||||||
|
label_list = get_label_list(args)
|
||||||
|
train_data_path = get_data_path(args, "train")
|
||||||
|
train_corpus = Corpus(
|
||||||
|
logger,
|
||||||
|
train_data_path,
|
||||||
|
args.bert_model,
|
||||||
|
args.max_seq_len,
|
||||||
|
label_list,
|
||||||
|
args.entity_types,
|
||||||
|
do_lower_case=True,
|
||||||
|
shuffle=True,
|
||||||
|
tagging=args.tagging_scheme,
|
||||||
|
viterbi="soft",
|
||||||
|
device=args.device,
|
||||||
|
concat_types=args.concat_types,
|
||||||
|
dataset=args.dataset,
|
||||||
|
)
|
||||||
|
id2label = train_corpus.id2label
|
||||||
|
transition_matrix = train_corpus.transition_matrix
|
||||||
|
|
||||||
|
viterbi_decoder = ViterbiDecoder(id2label, transition_matrix)
|
||||||
|
else:
|
||||||
|
viterbi_decoder = None
|
||||||
|
result_test, predictions = learner.evaluate_meta_(
|
||||||
|
corpus,
|
||||||
|
logger,
|
||||||
|
lr=args.lr_finetune,
|
||||||
|
steps=args.max_ft_steps,
|
||||||
|
mode=args.mode,
|
||||||
|
set_type=types,
|
||||||
|
type_steps=args.max_type_ft_steps,
|
||||||
|
viterbi_decoder=viterbi_decoder,
|
||||||
|
)
|
||||||
|
return result_test, predictions
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(args):
|
||||||
|
logger.info("********** Scheme: Meta Test **********")
|
||||||
|
label_list = get_label_list(args)
|
||||||
|
|
||||||
|
valid_data_path = get_data_path(args, "dev")
|
||||||
|
valid_corpus = Corpus(
|
||||||
|
logger,
|
||||||
|
valid_data_path,
|
||||||
|
args.bert_model,
|
||||||
|
args.max_seq_len,
|
||||||
|
label_list,
|
||||||
|
args.entity_types,
|
||||||
|
do_lower_case=True,
|
||||||
|
shuffle=False,
|
||||||
|
tagging=args.tagging_scheme,
|
||||||
|
viterbi=args.viterbi,
|
||||||
|
concat_types=args.concat_types,
|
||||||
|
dataset=args.dataset,
|
||||||
|
)
|
||||||
|
|
||||||
|
test_data_path = get_data_path(args, "test")
|
||||||
|
test_corpus = Corpus(
|
||||||
|
logger,
|
||||||
|
test_data_path,
|
||||||
|
args.bert_model,
|
||||||
|
args.max_seq_len,
|
||||||
|
label_list,
|
||||||
|
args.entity_types,
|
||||||
|
do_lower_case=True,
|
||||||
|
shuffle=False,
|
||||||
|
tagging=args.tagging_scheme,
|
||||||
|
viterbi=args.viterbi,
|
||||||
|
concat_types=args.concat_types,
|
||||||
|
dataset=args.dataset,
|
||||||
|
)
|
||||||
|
|
||||||
|
learner = Learner(
|
||||||
|
args.bert_model,
|
||||||
|
label_list,
|
||||||
|
args.freeze_layer,
|
||||||
|
logger,
|
||||||
|
args.lr_meta,
|
||||||
|
args.lr_inner,
|
||||||
|
args.warmup_prop_meta,
|
||||||
|
args.warmup_prop_inner,
|
||||||
|
args.max_meta_steps,
|
||||||
|
model_dir=args.model_dir,
|
||||||
|
py_alias=args.py_alias,
|
||||||
|
args=args,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("********** Scheme: evaluate - [valid] **********")
|
||||||
|
test(args, learner, valid_corpus, "valid")
|
||||||
|
|
||||||
|
logger.info("********** Scheme: evaluate - [test] **********")
|
||||||
|
test(args, learner, test_corpus, "test")
|
||||||
|
|
||||||
|
|
||||||
|
def convert_bpe(args):
    def convert_base(train_mode: str):
        data_path = get_data_path(args, train_mode)
        corpus = Corpus(
            logger,
            data_path,
            args.bert_model,
            args.max_seq_len,
            label_list,
            args.entity_types,
            do_lower_case=True,
            shuffle=False,
            tagging=args.tagging_scheme,
            viterbi=args.viterbi,
            concat_types=args.concat_types,
            dataset=args.dataset,
            device=args.device,
        )

        for seed in [171, 354, 550, 667, 985]:
            path = os.path.join(
                args.model_dir,
                f"all_{train_mode if train_mode == 'test' else 'valid'}_preds.pkl",
            ).replace("171", str(seed))
            data = joblib.load(path)
            if len(data) == 3:
                spans = data[-1]
            else:
                spans = [[jj[:-2] for jj in ii] for ii in data[-1]]
            target = [[jj[:-1] for jj in ii] for ii in data[0]]

            res = corpus._decoder_bpe_index(spans)
            target = corpus._decoder_bpe_index(target)
            with open(
                f"preds/{args.mode}-{args.N}way{args.K}shot-seed{seed}-{train_mode}.jsonl",
                "w",
            ) as f:
                json.dump(res, f)
            if seed != 171:
                continue
            with open(
                f"preds/{args.mode}-{args.N}way{args.K}shot-seed{seed}-{train_mode}_golden.jsonl",
                "w",
            ) as f:
                json.dump(target, f)

    logger.info("********** Scheme: Convert BPE **********")
    os.makedirs("preds", exist_ok=True)
    label_list = get_label_list(args)
    convert_base("dev")
    convert_base("test")
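
# Note: following the f-string patterns in convert_base above, the default 5-way 1-shot
# "inter" setting would write word-level prediction files such as
# preds/inter-5way1shot-seed171-test.jsonl (plus a *-test_golden.jsonl file for seed 171
# only); the exact names depend on the --mode/--N/--K arguments.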
if __name__ == "__main__":

    def my_bool(s):
        return s != "False"

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset", type=str, default="FewNERD", help="FewNERD or Domain"
    )
    parser.add_argument("--N", type=int, default=5)
    parser.add_argument("--K", type=int, default=1)
    parser.add_argument("--mode", type=str, default="intra")
    parser.add_argument(
        "--test_only",
        action="store_true",
        help="if true, will load the trained model and run test only",
    )
    parser.add_argument(
        "--convert_bpe",
        action="store_true",
        help="if true, will convert the bpe encode result to word level.",
    )
    parser.add_argument("--tagging_scheme", type=str, default="BIO", help="BIO or IO")
    # dataset settings
    parser.add_argument("--data_path", type=str, default="episode-data")
    parser.add_argument(
        "--result_dir", type=str, help="where to save the result.", default="test"
    )
    parser.add_argument(
        "--model_dir", type=str, help="dir name of a trained model", default=""
    )

    # meta-test setting
    parser.add_argument(
        "--lr_finetune",
        type=float,
        help="finetune learning rate, used in [test_meta] and the [k_shot setting]",
        default=3e-5,
    )
    parser.add_argument(
        "--max_ft_steps", type=int, help="maximal steps taken for fine-tuning.", default=3
    )
    parser.add_argument(
        "--max_type_ft_steps",
        type=int,
        help="maximal steps taken for entity-type fine-tuning.",
        default=0,
    )

    # meta-train setting
    parser.add_argument(
        "--inner_steps",
        type=int,
        help="number of inner updates for one meta-update",
        default=2,
    )
    parser.add_argument(
        "--inner_size",
        type=int,
        help="number of tasks for one meta-update",
        default=32,
    )
    parser.add_argument(
        "--lr_inner", type=float, help="inner loop learning rate", default=3e-5
    )
    parser.add_argument(
        "--lr_meta", type=float, help="meta learning rate", default=3e-5
    )
    parser.add_argument(
        "--max_meta_steps",
        type=int,
        help="maximal steps taken for meta training.",
        default=5001,
    )
    parser.add_argument("--eval_every_meta_steps", type=int, default=500)
    parser.add_argument(
        "--warmup_prop_inner",
        type=float,
        help="warm up proportion for inner update",
        default=0.1,
    )
    parser.add_argument(
        "--warmup_prop_meta",
        type=float,
        help="warm up proportion for meta update",
        default=0.1,
    )
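
    # The meta-train arguments above (inner_steps, inner_size, lr_inner, lr_meta) drive a
    # MAML-style loop: each meta-update samples `inner_size` episodes, adapts the model on each
    # episode's support set for `inner_steps` gradient steps with lr_inner, and uses the resulting
    # query losses to update the shared initialization with lr_meta. The minimal first-order
    # sketch below (kept as a comment) only illustrates that control flow on a toy model; it is
    # NOT the Learner class used in this repo, and sample_episode() is a hypothetical stand-in
    # for the episode sampling done by Corpus.get_batch_meta.
    #
    #     import copy, torch
    #
    #     model = torch.nn.Linear(768, 3)
    #     meta_opt = torch.optim.SGD(model.parameters(), lr=3e-5)          # lr_meta
    #     for _ in range(max_meta_steps):
    #         meta_opt.zero_grad()
    #         for _ in range(inner_size):                                  # tasks per meta-update
    #             fast = copy.deepcopy(model)
    #             inner_opt = torch.optim.SGD(fast.parameters(), lr=3e-5)  # lr_inner
    #             for _ in range(inner_steps):                             # inner updates
    #                 sx, sy = sample_episode("support")
    #                 inner_opt.zero_grad()
    #                 torch.nn.functional.cross_entropy(fast(sx), sy).backward()
    #                 inner_opt.step()
    #             qx, qy = sample_episode("query")
    #             torch.nn.functional.cross_entropy(fast(qx), qy).backward()
    #             for p, fp in zip(model.parameters(), fast.parameters()):  # first-order update
    #                 p.grad = fp.grad if p.grad is None else p.grad + fp.grad
    #         meta_opt.step()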
    # permanent params
    parser.add_argument(
        "--freeze_layer", type=int, help="the layer of mBERT to be frozen", default=0
    )
    parser.add_argument("--max_seq_len", type=int, default=128)
    parser.add_argument(
        "--bert_model",
        type=str,
        default="bert-base-uncased",
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
        "bert-base-multilingual-cased, bert-base-chinese.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        help="Where do you want to store the pre-trained models downloaded from s3",
        default="",
    )
    parser.add_argument(
        "--viterbi", type=str, default="hard", help="hard or soft or None"
    )
    parser.add_argument(
        "--concat_types", type=str, default="past", help="past or before or None"
    )
    # expt setting
    parser.add_argument(
        "--seed", type=int, help="random seed to reproduce the result.", default=667
    )
    parser.add_argument("--gpu_device", type=int, help="GPU device num", default=0)
    parser.add_argument("--py_alias", type=str, help="python alias", default="python")
    parser.add_argument(
        "--types_path",
        type=str,
        help="the path of the entity types file",
        default="data/entity_types.json",
    )
    parser.add_argument(
        "--negative_types_number",
        type=int,
        help="the number of negative types in each batch",
        default=4,
    )
    parser.add_argument(
        "--negative_mode", type=str, help="the mode of negative types", default="batch"
    )
    parser.add_argument(
        "--types_mode", type=str, help="the embedding mode of type span", default="cls"
    )
    parser.add_argument("--name", type=str, help="the name of experiment", default="")
    parser.add_argument("--debug", help="debug mode", action="store_true")
    parser.add_argument(
        "--init_type_embedding_from_bert",
        action="store_true",
        help="initialize the type embedding from BERT",
    )
    parser.add_argument(
        "--use_classify",
        action="store_true",
        help="use a classifier head after the entity embedding",
    )
    parser.add_argument(
        "--distance_mode", type=str, help="embedding distance mode", default="cos"
    )
    parser.add_argument("--similar_k", type=float, help="cosine similar k", default=10)
    parser.add_argument("--shared_bert", default=True, type=my_bool, help="shared BERT")
    parser.add_argument("--train_mode", default="add", type=str, help="add, span, type")
    parser.add_argument("--eval_mode", default="add", type=str, help="add, two-stage")
    parser.add_argument(
        "--type_threshold", default=2.5, type=float, help="typing decoder threshold"
    )
    parser.add_argument(
        "--lambda_max_loss", default=2.0, type=float, help="span max loss lambda"
    )
    parser.add_argument(
        "--inner_lambda_max_loss", default=2.0, type=float, help="span max loss lambda (inner loop)"
    )
    parser.add_argument(
        "--inner_similar_k", type=float, help="cosine similar k (inner loop)", default=10
    )
    parser.add_argument(
        "--ignore_eval_test", help="if set, skip evaluation on the test split", action="store_true"
    )
    parser.add_argument(
        "--nouse_inner_ft",
        action="store_true",
        help="if set, skip the inner-loop fine-tuning",
    )
    parser.add_argument(
        "--use_supervise",
        action="store_true",
        help="if set, use supervised (non-episodic) training",
    )
    args = parser.parse_args()
    args.negative_types_number = args.N - 1
    if "Domain" in args.dataset:
        args.types_path = "data/entity_types_domain.json"

    # setup random seed
    set_seed(args.seed, args.gpu_device)

    # set up GPU device
    device = torch.device("cuda")
    torch.cuda.set_device(args.gpu_device)

    # setup logger settings
    if args.test_only:
        top_dir = "models-{}-{}-{}".format(args.N, args.K, args.mode)
        args.model_dir = "{}-innerSteps_{}-innerSize_{}-lrInner_{}-lrMeta_{}-maxSteps_{}-seed_{}{}".format(
            args.bert_model,
            args.inner_steps,
            args.inner_size,
            args.lr_inner,
            args.lr_meta,
            args.max_meta_steps,
            args.seed,
            "-name_{}".format(args.name) if args.name else "",
        )
        args.model_dir = os.path.join(top_dir, args.model_dir)
        if not os.path.exists(args.model_dir):
            if args.convert_bpe:
                os.makedirs(args.model_dir)
            else:
                raise ValueError("Model directory does not exist!")
        fh = logging.FileHandler(
            "{}/log-test-ftLr_{}-ftSteps_{}.txt".format(
                args.model_dir, args.lr_finetune, args.max_ft_steps
            )
        )

    else:
        top_dir = "models-{}-{}-{}".format(args.N, args.K, args.mode)
        args.result_dir = "{}-innerSteps_{}-innerSize_{}-lrInner_{}-lrMeta_{}-maxSteps_{}-seed_{}{}".format(
            args.bert_model,
            args.inner_steps,
            args.inner_size,
            args.lr_inner,
            args.lr_meta,
            args.max_meta_steps,
            args.seed,
            "-name_{}".format(args.name) if args.name else "",
        )
        os.makedirs(top_dir, exist_ok=True)
        if not os.path.exists("{}/{}".format(top_dir, args.result_dir)):
            os.mkdir("{}/{}".format(top_dir, args.result_dir))
        elif args.result_dir != "test":
            pass

        args.result_dir = "{}/{}".format(top_dir, args.result_dir)
        fh = logging.FileHandler("{}/log-training.txt".format(args.result_dir))

        # dump args
        with Path("{}/args-train.json".format(args.result_dir)).open(
            "w", encoding="utf-8"
        ) as fw:
            json.dump(vars(args), fw, indent=4, sort_keys=True)

    if args.debug:
        os.makedirs("debug", exist_ok=True)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        "%(asctime)s %(levelname)s: - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
    )
    fh.setLevel(logging.INFO)
    fh.setFormatter(formatter)
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    ch.setFormatter(formatter)

    logger.addHandler(ch)
    logger.addHandler(fh)

    args.device = device
    logger.info(f"Using Device {device}")
    args.entity_types = EntityTypes(
        args.types_path, args.negative_types_number, args.negative_mode
    )
    args.entity_types.build_types_embedding(
        args.bert_model,
        True,
        args.device,
        args.types_mode,
        args.init_type_embedding_from_bert,
    )

    if args.convert_bpe:
        convert_bpe(args)
    elif args.test_only:
        if args.model_dir == "":
            raise ValueError("NULL model directory!")
        evaluate(args)
    else:
        if args.model_dir != "":
            raise ValueError("Model directory should be NULL!")
        train_meta(args)
@@ -0,0 +1,270 @@

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from copy import deepcopy

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

from transformers import BertForTokenClassification


logger = logging.getLogger(__file__)


class BertForTokenClassification_(BertForTokenClassification):
    def __init__(self, *args, **kwargs):
        super(BertForTokenClassification_, self).__init__(*args, **kwargs)
        self.input_size = 768
        self.span_loss = nn.functional.cross_entropy
        self.type_loss = nn.functional.cross_entropy
        self.dropout = nn.Dropout(p=0.1)
        self.log_softmax = nn.functional.log_softmax

    def set_config(
        self,
        use_classify: bool = False,
        distance_mode: str = "cos",
        similar_k: float = 30,
        shared_bert: bool = True,
        train_mode: str = "add",
    ):
        self.use_classify = use_classify
        self.distance_mode = distance_mode
        self.similar_k = similar_k
        self.shared_bert = shared_bert
        self.train_mode = train_mode
        if train_mode == "type":
            self.classifier = None

        if train_mode != "span":
            self.ln = nn.LayerNorm(768, 1e-5, True)
            if use_classify:
                # self.type_classify = nn.Linear(self.input_size, self.input_size)
                self.type_classify = nn.Sequential(
                    nn.Linear(self.input_size, self.input_size * 2),
                    nn.GELU(),
                    nn.Linear(self.input_size * 2, self.input_size),
                )
            if self.distance_mode != "cos":
                self.dis_cls = nn.Sequential(
                    nn.Linear(self.input_size * 3, self.input_size),
                    nn.GELU(),
                    nn.Linear(self.input_size, 2),
                )
        config = {
            "use_classify": use_classify,
            "distance_mode": distance_mode,
            "similar_k": similar_k,
            "shared_bert": shared_bert,
            "train_mode": train_mode,
        }
        logger.info(f"Model Setting: {config}")
        if not shared_bert:
            self.bert2 = deepcopy(self.bert)
    def forward_wuqh(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        labels=None,
        e_mask=None,
        e_type_ids=None,
        e_type_mask=None,
        entity_types=None,
        entity_mode: str = "mean",
        is_update_type_embedding: bool = False,
        lambda_max_loss: float = 0.0,
        sim_k: float = 0,
    ):
        max_len = (attention_mask != 0).max(0)[0].nonzero(as_tuple=False)[-1].item() + 1
        input_ids = input_ids[:, :max_len]
        attention_mask = attention_mask[:, :max_len].type(torch.int8)
        token_type_ids = token_type_ids[:, :max_len]
        labels = labels[:, :max_len]
        output = self.bert(
            input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
        )
        sequence_output = self.dropout(output[0])

        if self.train_mode != "type":
            logits = self.classifier(
                sequence_output
            )  # batch_size x seq_len x num_labels
        else:
            logits = None
        if not self.shared_bert:
            output2 = self.bert2(
                input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
            )
            sequence_output2 = self.dropout(output2[0])

        if e_type_ids is not None and self.train_mode != "span":
            if e_type_mask.sum() != 0:
                M = (e_type_mask[:, :, 0] != 0).max(0)[0].nonzero(as_tuple=False)[
                    -1
                ].item() + 1
            else:
                M = 1
            e_mask = e_mask[:, :M, :max_len].type(torch.int8)
            e_type_ids = e_type_ids[:, :M, :]
            e_type_mask = e_type_mask[:, :M, :].type(torch.int8)
            B, M, K = e_type_ids.shape
            e_out = self.get_enity_hidden(
                sequence_output if self.shared_bert else sequence_output2,
                e_mask,
                entity_mode,
            )
            if self.use_classify:
                e_out = self.type_classify(e_out)
            e_out = self.ln(e_out)  # batch_size x max_entity_num x hidden_size
            if is_update_type_embedding:
                entity_types.update_type_embedding(e_out, e_type_ids, e_type_mask)
            e_out = e_out.unsqueeze(2).expand(B, M, K, -1)
            types = self.get_types_embedding(
                e_type_ids, entity_types
            )  # batch_size x max_entity_num x K x hidden_size

            if self.distance_mode == "cat":
                e_types = torch.cat([e_out, types, (e_out - types).abs()], -1)
                e_types = self.dis_cls(e_types)
                e_types = e_types[:, :, :0]
            elif self.distance_mode == "l2":
                e_types = -(torch.pow(e_out - types, 2)).sum(-1)
            elif self.distance_mode == "cos":
                sim_k = sim_k if sim_k else self.similar_k
                e_types = sim_k * (e_out * types).sum(-1) / 768

            e_logits = e_types
            if M:
                em = e_type_mask.clone()
                em[em.sum(-1) == 0] = 1
                e = e_types * em
                e_type_label = torch.zeros((B, M)).to(e_types.device)
                type_loss = self.calc_loss(
                    self.type_loss, e, e_type_label, e_type_mask[:, :, 0]
                )
            else:
                type_loss = torch.tensor(0).to(sequence_output.device)
        else:
            e_logits, type_loss = None, None

        if labels is not None and self.train_mode != "type":
            # Only keep active parts of the loss
            loss_fct = CrossEntropyLoss(reduction="none")
            B, M, T = logits.shape
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.reshape(-1, self.num_labels)[active_loss]
                active_labels = labels.reshape(-1)[active_loss]
                base_loss = loss_fct(active_logits, active_labels)
                loss = torch.mean(base_loss)

                # max-loss
                if lambda_max_loss > 0:
                    active_loss = active_loss.view(B, M)
                    active_max = []
                    start_id = 0
                    for i in range(B):
                        sent_len = torch.sum(active_loss[i])
                        end_id = start_id + sent_len
                        active_max.append(torch.max(base_loss[start_id:end_id]))
                        start_id = end_id

                    loss += lambda_max_loss * torch.mean(torch.stack(active_max))
            else:
                raise ValueError("Miss attention mask!")
        else:
            loss = None

        return logits, e_logits, loss, type_loss
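
    # A small, self-contained illustration (not part of the model) of the two scoring pieces in
    # forward_wuqh above: (i) under the "cos" distance mode, the type logit for a span is a
    # scaled dot product between the span representation and the candidate type embedding,
    # e_types = similar_k * <e_out, type> / 768, with the gold type always placed in column 0;
    # (ii) the span loss adds lambda_max_loss times the mean of each sentence's maximum
    # token-level cross entropy. Shapes and values below are made up for the illustration.
    #
    #     import torch
    #
    #     B, M, K, H, similar_k = 2, 3, 5, 768, 10.0
    #     e_out = torch.randn(B, M, K, H)          # span repr, broadcast over K candidate types
    #     types = torch.randn(B, M, K, H)          # candidate type embeddings (gold at index 0)
    #     e_types = similar_k * (e_out * types).sum(-1) / H           # B x M x K logits
    #     type_loss = torch.nn.functional.cross_entropy(
    #         e_types.reshape(-1, K), torch.zeros(B * M, dtype=torch.long)
    #     )
    #
    #     base_loss = torch.rand(7)                # per-token CE of one flattened batch
    #     sent_lens = [3, 4]                       # tokens per sentence
    #     maxes, start = [], 0
    #     for n in sent_lens:
    #         maxes.append(base_loss[start:start + n].max())
    #         start += n
    #     loss = base_loss.mean() + 2.0 * torch.stack(maxes).mean()   # lambda_max_loss = 2.0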
    def get_enity_hidden(
        self, hidden: torch.Tensor, e_mask: torch.Tensor, entity_mode: str
    ):
        B, M, T = e_mask.shape
        e_out = hidden.unsqueeze(1).expand(B, M, T, -1) * e_mask.unsqueeze(
            -1
        )  # batch_size x max_entity_num x seq_len x hidden_size
        if entity_mode == "mean":
            return e_out.sum(2) / (
                e_mask.sum(-1).unsqueeze(-1) + 1e-30
            )  # batch_size x max_entity_num x hidden_size

    def get_types_embedding(self, e_type_ids: torch.Tensor, entity_types):
        return entity_types.get_types_embedding(e_type_ids)

    def calc_loss(self, loss_fn, preds, target, mask=None):
        target = target.reshape(-1)
        preds += 1e-10
        preds = preds.reshape(-1, preds.shape[-1])
        ce_loss = loss_fn(preds, target.long(), reduction="none")
        if mask is not None:
            mask = mask.reshape(-1)
            ce_loss = ce_loss * mask
            return ce_loss.sum() / (mask.sum() + 1e-10)
        return ce_loss.sum() / (target.sum() + 1e-10)
class ViterbiDecoder(object):
    def __init__(
        self,
        id2label,
        transition_matrix,
        ignore_token_label_id=torch.nn.CrossEntropyLoss().ignore_index,
    ):
        self.id2label = id2label
        self.n_labels = len(id2label)
        self.transitions = transition_matrix
        self.ignore_token_label_id = ignore_token_label_id

    def forward(self, logprobs, attention_mask, label_ids):
        # probs: batch_size x max_seq_len x n_labels
        batch_size, max_seq_len, n_labels = logprobs.size()
        attention_mask = attention_mask[:, :max_seq_len]
        label_ids = label_ids[:, :max_seq_len]

        active_tokens = (attention_mask == 1) & (
            label_ids != self.ignore_token_label_id
        )
        if n_labels != self.n_labels:
            raise ValueError("Labels do not match!")

        label_seqs = []

        for idx in range(batch_size):
            logprob_i = logprobs[idx, :, :][
                active_tokens[idx]
            ]  # seq_len(active) x n_labels

            back_pointers = []

            forward_var = logprob_i[0]  # n_labels

            for j in range(1, len(logprob_i)):  # for tag_feat in feat:
                next_label_var = forward_var + self.transitions  # n_labels x n_labels
                viterbivars_t, bptrs_t = torch.max(next_label_var, dim=1)  # n_labels

                logp_j = logprob_i[j]  # n_labels
                forward_var = viterbivars_t + logp_j  # n_labels
                bptrs_t = bptrs_t.cpu().numpy().tolist()
                back_pointers.append(bptrs_t)

            path_score, best_label_id = torch.max(forward_var, dim=-1)
            best_label_id = best_label_id.item()
            best_path = [best_label_id]

            for bptrs_t in reversed(back_pointers):
                best_label_id = bptrs_t[best_label_id]
                best_path.append(best_label_id)

            if len(best_path) != len(logprob_i):
                raise ValueError("Number of labels doesn't match!")

            best_path.reverse()
            label_seqs.append(best_path)

        return label_seqs
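
# A minimal usage sketch of ViterbiDecoder (illustration only, with made-up inputs): a hard
# transition matrix that forbids O -> I forces the decoder to return a path without that
# transition, where greedy argmax would not. The label order {0: "O", 1: "B", 2: "I"} is an
# assumption made here for readability; in the repo the order comes from get_label_list.
#
#     import torch
#
#     id2label = {0: "O", 1: "B", 2: "I"}
#     transitions = torch.zeros(3, 3)
#     transitions[2][0] = -10000        # p(O -> I) = 0, as built in Corpus for n_labels == 3
#     decoder = ViterbiDecoder(id2label, transitions)
#
#     logprobs = torch.log(torch.tensor([[[0.6, 0.1, 0.3],    # "O" most likely
#                                         [0.4, 0.1, 0.5],    # "I" most likely
#                                         [0.2, 0.1, 0.7]]]))
#     attention_mask = torch.ones(1, 3, dtype=torch.long)
#     label_ids = torch.zeros(1, 3, dtype=torch.long)         # anything != ignore_index
#     print(decoder.forward(logprobs, attention_mask, label_ids))
#     # -> [[2, 2, 2]]; greedy argmax would give [0, 2, 2], which contains the banned O -> I step.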
@@ -0,0 +1,830 @@

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import bisect
import json
import logging

import numpy as np
import torch
import torch.nn as nn

from transformers import BertModel, BertTokenizer
from utils import load_file


logger = logging.getLogger(__file__)


class EntityTypes(object):
    def __init__(self, types_path: str, negative_types_number: int, negative_mode: str):
        self.types = {}
        self.types_map = {}
        self.O_id = 0
        self.types_embedding = None
        self.negative_mode = negative_mode
        self.load_entity_types(types_path)

    def load_entity_types(self, types_path: str):
        self.types = load_file(types_path, "json")
        types_list = sorted([jj for ii in self.types.values() for jj in ii])
        self.types_map = {jj: ii for ii, jj in enumerate(types_list)}
        self.O_id = self.types_map["O"]
        self.types_list = types_list
        logger.info("Load %d entity types from %s.", len(types_list), types_path)

    def build_types_embedding(
        self,
        model: str,
        do_lower_case: bool,
        device,
        types_mode: str = "cls",
        init_type_embedding_from_bert: bool = False,
    ):
        types_list = sorted([jj for ii in self.types.values() for jj in ii])
        if init_type_embedding_from_bert:
            tokenizer = BertTokenizer.from_pretrained(
                model, do_lower_case=do_lower_case
            )

            tokens = [
                [tokenizer.cls_token_id]
                + tokenizer.convert_tokens_to_ids(tokenizer.tokenize(ii))
                + [tokenizer.sep_token_id]
                for ii in types_list
            ]
            token_max = max([len(ii) for ii in tokens])
            mask = [[1] * len(ii) + [0] * (token_max - len(ii)) for ii in tokens]
            ids = [
                ii + [tokenizer.pad_token_id] * (token_max - len(ii)) for ii in tokens
            ]
            mask = torch.tensor(np.array(mask), dtype=torch.long).to(device)
            ids = torch.tensor(np.array(ids), dtype=torch.long).to(
                device
            )  # len(type_list) x token_max

            model = BertModel.from_pretrained(model).to(device)
            outs = model(ids, mask)
        else:
            outs = [0, torch.rand((len(types_list), 768)).to(device)]
        self.types_embedding = nn.Embedding(*outs[1].shape).to(device)
        if types_mode.lower() == "cls":
            self.types_embedding.weight = nn.Parameter(outs[1])
        self.types_embedding.requires_grad = False
        logger.info("Built the types embedding.")
    def generate_negative_types(
        self, labels: list, types: list, negative_types_number: int
    ):
        N = len(labels)
        data = np.zeros((N, 1 + negative_types_number), int)
        if self.negative_mode == "batch":
            batch_labels = set(types)
            if negative_types_number > len(batch_labels):
                other = list(
                    set(range(len(self.types_map))) - batch_labels - set([self.O_id])
                )
                other_size = negative_types_number - len(batch_labels)
            else:
                other, other_size = [], 0
            b_size = min(len(batch_labels) - 1, negative_types_number)
            o_set = [self.O_id] if negative_types_number > len(batch_labels) - 1 else []
            for idx, l in enumerate(labels):
                data[idx][0] = l
                data[idx][1:] = np.concatenate(
                    [
                        np.random.choice(list(batch_labels - set([l])), b_size, False),
                        o_set,
                        np.random.choice(other, other_size, False),
                    ]
                )
        return data
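
    # Illustration of the layout produced by generate_negative_types (hypothetical ids): for a
    # span whose gold type id is 7, with batch types {3, 7, 9} and negative_types_number = 2,
    # one row of `data` could be [7, 3, 9] -- column 0 always holds the gold type, and the
    # remaining columns hold negatives sampled first from the other types in the batch, then
    # from "O" and the rest of the inventory if more negatives are needed.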
    def get_types_embedding(self, labels: torch.Tensor):
        return self.types_embedding(labels)

    def update_type_embedding(self, e_out, e_type_ids, e_type_mask):
        labels = e_type_ids[:, :, 0][e_type_mask[:, :, 0] == 1]
        hiddens = e_out[e_type_mask[:, :, 0] == 1]
        label_set = set(labels.detach().cpu().numpy())
        for ii in label_set:
            self.types_embedding.weight.data[ii] = hiddens[labels == ii].mean(0)


class InputExample(object):
    def __init__(
        self, guid: str, words: list, labels: list, types: list, entities: list
    ):
        self.guid = guid
        self.words = words
        self.labels = labels
        self.types = types
        self.entities = entities


class InputFeatures(object):
    def __init__(
        self,
        input_ids,
        input_mask,
        segment_ids,
        label_ids,
        e_mask,
        e_type_ids,
        e_type_mask,
        types,
        entities,
    ):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.label_ids = label_ids
        self.e_mask = e_mask
        self.e_type_ids = e_type_ids
        self.e_type_mask = e_type_mask
        self.types = types
        self.entities = entities
class Corpus(object):
    def __init__(
        self,
        logger,
        data_fn,
        bert_model,
        max_seq_length,
        label_list,
        entity_types: EntityTypes,
        do_lower_case=True,
        shuffle=True,
        tagging="BIO",
        viterbi="none",
        device="cuda",
        concat_types: str = "None",
        dataset: str = "FewNERD",
        negative_types_number: int = -1,
    ):
        self.logger = logger
        self.tokenizer = BertTokenizer.from_pretrained(
            bert_model, do_lower_case=do_lower_case
        )
        self.max_seq_length = max_seq_length
        self.entity_types = entity_types
        self.label_list = label_list
        self.label_map = {label: i for i, label in enumerate(label_list)}
        self.id2label = {i: label for i, label in enumerate(label_list)}
        self.n_labels = len(label_list)
        self.tagging_scheme = tagging
        self.max_len_dict = {"entity": 0, "type": 0, "sentence": 0}
        self.max_entities_length = 50
        self.viterbi = viterbi
        self.dataset = dataset
        self.negative_types_number = negative_types_number

        logger.info(
            "Construct the transition matrix via [{}] scheme...".format(viterbi)
        )
        # M[ij]: p(j->i)
        if viterbi == "none":
            self.transition_matrix = None
            update_transition_matrix = False
        elif viterbi == "hard":
            self.transition_matrix = torch.zeros(
                [self.n_labels, self.n_labels], device=device
            )  # pij: p(j -> i)
            if self.n_labels == 3:
                self.transition_matrix[2][0] = -10000  # p(O -> I) = 0
            elif self.n_labels == 5:
                for (i, j) in [
                    (2, 0),
                    (3, 0),
                    (0, 1),
                    (1, 1),
                    (4, 1),
                    (0, 2),
                    (1, 2),
                    (4, 2),
                    (2, 3),
                    (3, 3),
                    (2, 4),
                    (3, 4),
                ]:
                    self.transition_matrix[i][j] = -10000
            else:
                raise ValueError()
            update_transition_matrix = False
        elif viterbi == "soft":
            self.transition_matrix = (
                torch.zeros([self.n_labels, self.n_labels], device=device) + 1e-8
            )
            update_transition_matrix = True
        else:
            raise ValueError()

        self.tasks = self.read_tasks_from_file(
            data_fn, update_transition_matrix, concat_types, dataset
        )

        self.n_total = len(self.tasks)
        self.batch_start_idx = 0
        self.batch_idxs = (
            np.random.permutation(self.n_total)
            if shuffle
            else np.array([i for i in range(self.n_total)])
        )  # for batch sampling in training
    def read_tasks_from_file(
        self,
        data_fn,
        update_transition_matrix=False,
        concat_types: str = "None",
        dataset: str = "FewNERD",
    ):
        """
        return: List[task]
            task['support'] = List[InputExample]
            task['query'] = List[InputExample]
        """

        self.logger.info("Reading tasks from {}...".format(data_fn))
        self.logger.info(
            "  update_transition_matrix = {}".format(update_transition_matrix)
        )
        self.logger.info("  concat_types = {}".format(concat_types))

        with open(data_fn, "r", encoding="utf-8") as json_file:
            json_list = list(json_file)
        output_tasks = []
        all_labels = [] if update_transition_matrix else None
        if dataset == "Domain":
            json_list = self._convert_Domain2FewNERD(json_list)

        for task_id, json_str in enumerate(json_list):
            if task_id % 1000 == 0:
                self.logger.info("Reading tasks %d of %d", task_id, len(json_list))
            task = json.loads(json_str) if dataset != "Domain" else json_str

            support = task["support"]
            types = task["types"]
            if self.negative_types_number == -1:
                self.negative_types_number = len(types) - 1
            tmp_support, entities, tmp_support_tokens, tmp_query_tokens = [], [], [], []
            self.max_len_dict["type"] = max(self.max_len_dict["type"], len(types))

            if concat_types != "None":
                types = set()
                for l_list in support["label"]:
                    types.update(l_list)
                types.remove("O")
                tokenized_types = self.__tokenize_types__(types, concat_types)
            else:
                tokenized_types = None

            for i, (words, labels) in enumerate(zip(support["word"], support["label"])):
                entities = self._convert_label_to_entities_(labels)
                self.max_len_dict["entity"] = max(
                    len(entities), self.max_len_dict["entity"]
                )
                if self.tagging_scheme == "BIOES":
                    labels = self._convert_label_to_BIOES_(labels)
                elif self.tagging_scheme == "BIO":
                    labels = self._convert_label_to_BIO_(labels)
                elif self.tagging_scheme == "IO":
                    labels = self._convert_label_to_IO_(labels)
                else:
                    raise ValueError("Invalid tagging scheme!")
                guid = "task[%s]-%s" % (task_id, i)
                feature, token_sum = self._convert_example_to_feature_(
                    InputExample(
                        guid=guid,
                        words=words,
                        labels=labels,
                        types=types,
                        entities=entities,
                    ),
                    tokenized_types=tokenized_types,
                    concat_types=concat_types,
                )
                tmp_support.append(feature)
                tmp_support_tokens.append(token_sum)
                if update_transition_matrix:
                    all_labels.append(labels)

            query = task["query"]
            tmp_query = []
            for i, (words, labels) in enumerate(zip(query["word"], query["label"])):
                entities = self._convert_label_to_entities_(labels)
                self.max_len_dict["entity"] = max(
                    len(entities), self.max_len_dict["entity"]
                )
                if self.tagging_scheme == "BIOES":
                    labels = self._convert_label_to_BIOES_(labels)
                elif self.tagging_scheme == "BIO":
                    labels = self._convert_label_to_BIO_(labels)
                elif self.tagging_scheme == "IO":
                    labels = self._convert_label_to_IO_(labels)
                else:
                    raise ValueError("Invalid tagging scheme!")
                guid = "task[%s]-%s" % (task_id, i)
                feature, token_sum = self._convert_example_to_feature_(
                    InputExample(
                        guid=guid,
                        words=words,
                        labels=labels,
                        types=types,
                        entities=entities,
                    ),
                    tokenized_types=tokenized_types,
                    concat_types=concat_types,
                )
                tmp_query.append(feature)
                tmp_query_tokens.append(token_sum)
                if update_transition_matrix:
                    all_labels.append(labels)

            output_tasks.append(
                {
                    "support": tmp_support,
                    "query": tmp_query,
                    "support_token": tmp_support_tokens,
                    "query_token": tmp_query_tokens,
                }
            )
        self.logger.info(
            "%s Max Entities Lengths: %d, Max batch Types Number: %d, Max sentence Length: %d",
            data_fn,
            self.max_len_dict["entity"],
            self.max_len_dict["type"],
            self.max_len_dict["sentence"],
        )
        if update_transition_matrix:
            self._count_transition_matrix_(all_labels)
        return output_tasks
    def _convert_Domain2FewNERD(self, data: list):
        def decode_batch(batch: dict):
            word = batch["seq_ins"]
            label = [
                [jj.replace("B-", "").replace("I-", "") for jj in ii]
                for ii in batch["seq_outs"]
            ]
            return {"word": word, "label": label}

        data = json.loads(data[0])
        res = []
        for domain in data.keys():
            d = data[domain]
            labels = self.entity_types.types[domain]
            res.extend(
                [
                    {
                        "support": decode_batch(ii["support"]),
                        "query": decode_batch(ii["batch"]),
                        "types": labels,
                    }
                    for ii in d
                ]
            )
        return res

    def __tokenize_types__(self, types, concat_types: str = "past"):
        tokens = []
        for t in types:
            if "embedding" in concat_types:
                t_tokens = [f"[unused{self.entity_types.types_map[t]}]"]
            else:
                t_tokens = self.tokenizer.tokenize(t)
            if len(t_tokens) == 0:
                continue
            tokens.extend(t_tokens)
            tokens.append(",")  # separate different types with a comma ','.
        tokens.pop()  # pop the last comma
        return tokens
    def _count_transition_matrix_(self, labels):
        self.logger.info("Computing transition matrix...")
        for sent_labels in labels:
            for i in range(len(sent_labels) - 1):
                start = self.label_map[sent_labels[i]]
                end = self.label_map[sent_labels[i + 1]]
                self.transition_matrix[end][start] += 1
        self.transition_matrix /= torch.sum(self.transition_matrix, dim=0)
        self.transition_matrix = torch.log(self.transition_matrix)
        self.logger.info("Done.")
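
    # In other words, for the "soft" Viterbi setting the entry ends up as
    # transition_matrix[i][j] = log( count(j -> i) / sum_i count(j -> i) ), i.e. the log
    # probability of moving to label i given the current label j, estimated from the labels
    # read above (the division over dim=0 performs the per-column normalisation).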
    def _convert_label_to_entities_(self, label_list: list):
        N = len(label_list)
        S = [
            ii
            for ii in range(N)
            if label_list[ii] != "O"
            and (not ii or label_list[ii] != label_list[ii - 1])
        ]
        E = [
            ii
            for ii in range(N)
            if label_list[ii] != "O"
            and (ii == N - 1 or label_list[ii] != label_list[ii + 1])
        ]
        return [(s, e, label_list[s]) for s, e in zip(S, E)]
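
    # Worked example for _convert_label_to_entities_:
    #     ["O", "person", "person", "O", "location"]  ->  [(1, 2, "person"), (4, 4, "location")]
    # i.e. each tuple is (start index, end index, type) over word positions, with adjacent
    # words of the same type merged into one span.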
    def _convert_label_to_BIOES_(self, label_list):
        res = []
        label_list = ["O"] + label_list + ["O"]
        for i in range(1, len(label_list) - 1):
            if label_list[i] == "O":
                res.append("O")
                continue
            # for S
            if (
                label_list[i] != label_list[i - 1]
                and label_list[i] != label_list[i + 1]
            ):
                res.append("S")
            elif (
                label_list[i] != label_list[i - 1]
                and label_list[i] == label_list[i + 1]
            ):
                res.append("B")
            elif (
                label_list[i] == label_list[i - 1]
                and label_list[i] != label_list[i + 1]
            ):
                res.append("E")
            elif (
                label_list[i] == label_list[i - 1]
                and label_list[i] == label_list[i + 1]
            ):
                res.append("I")
            else:
                raise ValueError("Some bugs exist in your code!")
        return res
    def _convert_label_to_BIO_(self, label_list):
        precursor = ""
        label_output = []
        for label in label_list:
            if label == "O":
                label_output.append("O")
            elif label != precursor:
                label_output.append("B")
            else:
                label_output.append("I")
            precursor = label

        return label_output
    def _convert_label_to_IO_(self, label_list):
        label_output = []
        for label in label_list:
            if label == "O":
                label_output.append("O")
            else:
                label_output.append("I")

        return label_output
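
    # Worked example for the three class-agnostic tagging schemes above, on the label
    # sequence ["person", "person", "O", "location"]:
    #     BIOES: ["B", "E", "O", "S"]
    #     BIO:   ["B", "I", "O", "B"]
    #     IO:    ["I", "I", "O", "I"]
    # Entity-type information is dropped here on purpose: the span detector only needs
    # boundaries, and types are handled separately by the typing model.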
    def _convert_example_to_feature_(
        self,
        example,
        cls_token_at_end=False,
        cls_token="[CLS]",
        cls_token_segment_id=0,
        sep_token="[SEP]",
        sep_token_extra=False,
        pad_on_left=False,
        pad_token=0,
        pad_token_segment_id=0,
        pad_token_label_id=-1,
        sequence_a_segment_id=0,
        mask_padding_with_zero=True,
        ignore_token_label_id=torch.nn.CrossEntropyLoss().ignore_index,
        sequence_b_segment_id=1,
        tokenized_types=None,
        concat_types: str = "None",
    ):
        """
        `cls_token_at_end` defines the location of the CLS token:
            - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
            - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
        `cls_token_segment_id` defines the segment id associated to the CLS token (0 for BERT, 2 for XLNet)
        """
        tokens, label_ids, token_sum = [], [], [1]
        if tokenized_types is None:
            tokenized_types = []
        if "before" in concat_types:
            token_sum[-1] += 1 + len(tokenized_types)

        for words, labels in zip(example.words, example.labels):
            word_tokens = self.tokenizer.tokenize(words)
            token_sum.append(token_sum[-1] + len(word_tokens))
            if len(word_tokens) == 0:
                continue
            tokens.extend(word_tokens)
            # Use the real label id for the first token of the word, and padding ids for the remaining tokens
            label_ids.extend(
                [self.label_map[labels]]
                + [ignore_token_label_id] * (len(word_tokens) - 1)
            )
        self.max_len_dict["sentence"] = max(self.max_len_dict["sentence"], len(tokens))
        e_ids = [(token_sum[s], token_sum[e + 1] - 1) for s, e, _ in example.entities]
        e_mask = np.zeros((self.max_entities_length, self.max_seq_length), np.int8)
        e_type_mask = np.zeros(
            (self.max_entities_length, 1 + self.negative_types_number), np.int8
        )
        e_type_mask[: len(e_ids), :] = np.ones(
            (len(e_ids), 1 + self.negative_types_number), np.int8
        )
        for idx, (s, e) in enumerate(e_ids):
            e_mask[idx][s : e + 1] = 1
        e_type_ids = [self.entity_types.types_map[t] for _, _, t in example.entities]
        entities = [(s, e, t) for (s, e), t in zip(e_ids, e_type_ids)]
        batch_types = [self.entity_types.types_map[ii] for ii in example.types]
        # e_type_ids[i, 0] is the positive label, while e_type_ids[i, 1:] are negative labels
        e_type_ids = self.entity_types.generate_negative_types(
            e_type_ids, batch_types, self.negative_types_number
        )
        if len(e_type_ids) < self.max_entities_length:
            e_type_ids = np.concatenate(
                [
                    e_type_ids,
                    [[0] * (1 + self.negative_types_number)]
                    * (self.max_entities_length - len(e_type_ids)),
                ]
            )

        # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa.
        special_tokens_count = 3 if sep_token_extra else 2
        if len(tokens) > self.max_seq_length - special_tokens_count - len(
            tokenized_types
        ):
            tokens = tokens[
                : (self.max_seq_length - special_tokens_count - len(tokenized_types))
            ]
            label_ids = label_ids[
                : (self.max_seq_length - special_tokens_count - len(tokenized_types))
            ]

        types = [self.entity_types.types_map[t] for t in example.types]

        if "before" in concat_types:
            # OPTION 1: Concatenated tokenized types at START
            len_sentence = len(tokens)
            tokens = [cls_token] + tokenized_types + [sep_token] + tokens
            label_ids = [ignore_token_label_id] * (len(tokenized_types) + 2) + label_ids
            segment_ids = (
                [cls_token_segment_id]
                + [sequence_a_segment_id] * (len(tokenized_types) + 1)
                + [sequence_b_segment_id] * len_sentence
            )
        else:
            # OPTION 2: Concatenated tokenized types at END
            tokens += [sep_token]
            label_ids += [ignore_token_label_id]
            if sep_token_extra:
                raise ValueError("Unexpected path!")
                # roberta uses an extra separator b/w pairs of sentences
                tokens += [sep_token]
                label_ids += [ignore_token_label_id]
            segment_ids = [sequence_a_segment_id] * len(tokens)

            if cls_token_at_end:
                raise ValueError("Unexpected path!")
                tokens += [cls_token]
                label_ids += [ignore_token_label_id]
                segment_ids += [cls_token_segment_id]
            else:
                tokens = [cls_token] + tokens
                label_ids = [ignore_token_label_id] + label_ids
                segment_ids = [cls_token_segment_id] + segment_ids

            if "past" in concat_types:
                tokens += tokenized_types
                label_ids += [ignore_token_label_id] * len(tokenized_types)
                segment_ids += [sequence_b_segment_id] * len(tokenized_types)

        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real tokens are attended to.
        input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

        # Zero-pad up to the sequence length.
        padding_length = self.max_seq_length - len(input_ids)
        if pad_on_left:
            input_ids = ([pad_token] * padding_length) + input_ids
            input_mask = (
                [0 if mask_padding_with_zero else 1] * padding_length
            ) + input_mask
            segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
            label_ids = ([pad_token_label_id] * padding_length) + label_ids
        else:
            input_ids += [pad_token] * padding_length
            input_mask += [0 if mask_padding_with_zero else 1] * padding_length
            segment_ids += [pad_token_segment_id] * padding_length
            label_ids += [pad_token_label_id] * padding_length

        assert len(input_ids) == self.max_seq_length
        assert len(input_mask) == self.max_seq_length
        assert len(segment_ids) == self.max_seq_length
        assert len(label_ids) == self.max_seq_length

        return (
            InputFeatures(
                input_ids=input_ids,
                input_mask=input_mask,
                segment_ids=segment_ids,
                label_ids=label_ids,
                e_mask=e_mask,  # max_entities_length x 128
                e_type_ids=e_type_ids,  # max_entities_length x 5 (n_types)
                e_type_mask=np.array(e_type_mask),  # max_entities_length x 5 (n_types)
                types=np.array(types),
                entities=entities,
            ),
            token_sum,
        )
    def reset_batch_info(self, shuffle=False):
        self.batch_start_idx = 0
        self.batch_idxs = (
            np.random.permutation(self.n_total)
            if shuffle
            else np.array([i for i in range(self.n_total)])
        )  # for batch sampling in training
    def get_batch_meta(self, batch_size, device="cuda", shuffle=True):
        if self.batch_start_idx + batch_size > self.n_total:
            self.reset_batch_info(shuffle=shuffle)

        query_batch = []
        support_batch = []
        start_id = self.batch_start_idx

        for i in range(start_id, start_id + batch_size):
            idx = self.batch_idxs[i]
            task_curr = self.tasks[idx]

            query_item = {
                "input_ids": torch.tensor(
                    [f.input_ids for f in task_curr["query"]], dtype=torch.long
                ).to(
                    device
                ),  # 1 x max_seq_len
                "input_mask": torch.tensor(
                    [f.input_mask for f in task_curr["query"]], dtype=torch.long
                ).to(device),
                "segment_ids": torch.tensor(
                    [f.segment_ids for f in task_curr["query"]], dtype=torch.long
                ).to(device),
                "label_ids": torch.tensor(
                    [f.label_ids for f in task_curr["query"]], dtype=torch.long
                ).to(device),
                "e_mask": torch.tensor(
                    [f.e_mask for f in task_curr["query"]], dtype=torch.int
                ).to(device),
                "e_type_ids": torch.tensor(
                    [f.e_type_ids for f in task_curr["query"]], dtype=torch.long
                ).to(device),
                "e_type_mask": torch.tensor(
                    [f.e_type_mask for f in task_curr["query"]], dtype=torch.int
                ).to(device),
                "types": [f.types for f in task_curr["query"]],
                "entities": [f.entities for f in task_curr["query"]],
                "idx": idx,
            }
            query_batch.append(query_item)

            support_item = {
                "input_ids": torch.tensor(
                    [f.input_ids for f in task_curr["support"]], dtype=torch.long
                ).to(device),
                # 1 x max_seq_len
                "input_mask": torch.tensor(
                    [f.input_mask for f in task_curr["support"]], dtype=torch.long
                ).to(device),
                "segment_ids": torch.tensor(
                    [f.segment_ids for f in task_curr["support"]], dtype=torch.long
                ).to(device),
                "label_ids": torch.tensor(
                    [f.label_ids for f in task_curr["support"]], dtype=torch.long
                ).to(device),
                "e_mask": torch.tensor(
                    [f.e_mask for f in task_curr["support"]], dtype=torch.int
                ).to(device),
                "e_type_ids": torch.tensor(
                    [f.e_type_ids for f in task_curr["support"]], dtype=torch.long
                ).to(device),
                "e_type_mask": torch.tensor(
                    [f.e_type_mask for f in task_curr["support"]], dtype=torch.int
                ).to(device),
                "types": [f.types for f in task_curr["support"]],
                "entities": [f.entities for f in task_curr["support"]],
                "idx": idx,
            }
            support_batch.append(support_item)

        self.batch_start_idx += batch_size

        return query_batch, support_batch
    def get_batch_NOmeta(self, batch_size, device="cuda", shuffle=True):
        if self.batch_start_idx + batch_size >= self.n_total:
            self.reset_batch_info(shuffle=shuffle)
            if self.mask_rate >= 0:
                self.query_features = self.build_query_features_with_mask(
                    mask_rate=self.mask_rate
                )

        idxs = self.batch_idxs[self.batch_start_idx : self.batch_start_idx + batch_size]
        batch_features = [self.query_features[idx] for idx in idxs]

        batch = {
            "input_ids": torch.tensor(
                [f.input_ids for f in batch_features], dtype=torch.long
            ).to(device),
            "input_mask": torch.tensor(
                [f.input_mask for f in batch_features], dtype=torch.long
            ).to(device),
            "segment_ids": torch.tensor(
                [f.segment_ids for f in batch_features], dtype=torch.long
            ).to(device),
            "label_ids": torch.tensor(
                [f.label_id for f in batch_features], dtype=torch.long
            ).to(device),
            "e_mask": torch.tensor(
                [f.e_mask for f in batch_features], dtype=torch.int
            ).to(device),
            "e_type_ids": torch.tensor(
                [f.e_type_ids for f in batch_features], dtype=torch.long
            ).to(device),
            "e_type_mask": torch.tensor(
                [f.e_type_mask for f in batch_features], dtype=torch.int
            ).to(device),
            "types": [f.types for f in batch_features],
            "entities": [f.entities for f in batch_features],
        }

        self.batch_start_idx += batch_size

        return batch
    def get_batches(self, batch_size, device="cuda", shuffle=False):
        batches = []

        if shuffle:
            idxs = np.random.permutation(self.n_total)
            features = [self.query_features[i] for i in idxs]
        else:
            features = self.query_features

        for i in range(0, self.n_total, batch_size):
            batch_features = features[i : min(self.n_total, i + batch_size)]

            batch = {
                "input_ids": torch.tensor(
                    [f.input_ids for f in batch_features], dtype=torch.long
                ).to(device),
                "input_mask": torch.tensor(
                    [f.input_mask for f in batch_features], dtype=torch.long
                ).to(device),
                "segment_ids": torch.tensor(
                    [f.segment_ids for f in batch_features], dtype=torch.long
                ).to(device),
                "label_ids": torch.tensor(
                    [f.label_id for f in batch_features], dtype=torch.long
                ).to(device),
                "e_mask": torch.tensor(
                    [f.e_mask for f in batch_features], dtype=torch.int
                ).to(device),
                "e_type_ids": torch.tensor(
                    [f.e_type_ids for f in batch_features], dtype=torch.long
                ).to(device),
                "e_type_mask": torch.tensor(
                    [f.e_type_mask for f in batch_features], dtype=torch.int
                ).to(device),
                "types": [f.types for f in batch_features],
                "entities": [f.entities for f in batch_features],
            }

            batches.append(batch)

        return batches
    def _decoder_bpe_index(self, sentences_spans: list):
        res = []
        tokens = [jj for ii in self.tasks for jj in ii["query_token"]]
        assert len(tokens) == len(
            sentences_spans
        ), f"tokens size: {len(tokens)}, sentences size: {len(sentences_spans)}"
        for sentence_idx, spans in enumerate(sentences_spans):
            token = tokens[sentence_idx]
            tmp = []
            for b, e in spans:
                nb = bisect.bisect_left(token, b)
                ne = bisect.bisect_left(token, e)
                tmp.append((nb, ne))
            res.append(tmp)
        return res


if __name__ == "__main__":
    pass
@@ -0,0 +1,3 @@

numpy
transformers==4.10.0
torch==1.9.0
@@ -0,0 +1,71 @@

rem Copyright (c) Microsoft Corporation.
rem Licensed under the MIT license.

set SEEDS=171 354 550 667 985
set N=5
set K=1
set mode=inter

for %%e in (%SEEDS%) do (
    python3 main.py ^
        --gpu_device=1 ^
        --seed=%%e ^
        --mode=%mode% ^
        --N=%N% ^
        --K=%K% ^
        --similar_k=10 ^
        --eval_every_meta_steps=100 ^
        --name=10-k_100_2_32_3_max_loss_2_5_BIOES ^
        --train_mode=span ^
        --inner_steps=2 ^
        --inner_size=32 ^
        --max_ft_steps=3 ^
        --lambda_max_loss=2 ^
        --inner_lambda_max_loss=5 ^
        --tagging_scheme=BIOES ^
        --viterbi=hard ^
        --concat_types=None ^
        --ignore_eval_test

    python3 main.py ^
        --seed=%%e ^
        --gpu_device=1 ^
        --lr_inner=1e-4 ^
        --lr_meta=1e-4 ^
        --mode=%mode% ^
        --N=%N% ^
        --K=%K% ^
        --similar_k=10 ^
        --inner_similar_k=10 ^
        --eval_every_meta_steps=100 ^
        --name=10-k_100_type_2_32_3_10_10 ^
        --train_mode=type ^
        --inner_steps=2 ^
        --inner_size=32 ^
        --max_ft_steps=3 ^
        --concat_types=None ^
        --lambda_max_loss=2.0

    cp models-%N%-%K%-%mode%\bert-base-uncased-innerSteps_2-innerSize_32-lrInner_0.0001-lrMeta_0.0001-maxSteps_5001-seed_%%e-name_10-k_100_type_2_32_3_10_10\en_type_pytorch_model.bin models-%N%-%K%-%mode%\bert-base-uncased-innerSteps_2-innerSize_32-lrInner_3e-05-lrMeta_3e-05-maxSteps_5001-seed_%%e-name_10-k_100_2_32_3_max_loss_2_5_BIOES

    python3 main.py ^
        --gpu_device=1 ^
        --seed=%%e ^
        --N=%N% ^
        --K=%K% ^
        --mode=%mode% ^
        --similar_k=10 ^
        --name=10-k_100_2_32_3_max_loss_2_5_BIOES ^
        --concat_types=None ^
        --test_only ^
        --eval_mode=two-stage ^
        --inner_steps=2 ^
        --inner_size=32 ^
        --max_ft_steps=3 ^
        --max_type_ft_steps=3 ^
        --lambda_max_loss=2.0 ^
        --inner_lambda_max_loss=5.0 ^
        --inner_similar_k=10 ^
        --viterbi=hard ^
        --tagging_scheme=BIOES
)
@@ -0,0 +1,71 @@

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

SEEDS=(171 354 550 667 985)
N=5
K=1
mode=inter

for seed in ${SEEDS[@]}; do
    python3 main.py \
        --gpu_device=1 \
        --seed=${seed} \
        --mode=${mode} \
        --N=${N} \
        --K=${K} \
        --similar_k=10 \
        --eval_every_meta_steps=100 \
        --name=10-k_100_2_32_3_max_loss_2_5_BIOES \
        --train_mode=span \
        --inner_steps=2 \
        --inner_size=32 \
        --max_ft_steps=3 \
        --lambda_max_loss=2 \
        --inner_lambda_max_loss=5 \
        --tagging_scheme=BIOES \
        --viterbi=hard \
        --concat_types=None \
        --ignore_eval_test

    python3 main.py \
        --seed=${seed} \
        --gpu_device=1 \
        --lr_inner=1e-4 \
        --lr_meta=1e-4 \
        --mode=${mode} \
        --N=${N} \
        --K=${K} \
        --similar_k=10 \
        --inner_similar_k=10 \
        --eval_every_meta_steps=100 \
        --name=10-k_100_type_2_32_3_10_10 \
        --train_mode=type \
        --inner_steps=2 \
        --inner_size=32 \
        --max_ft_steps=3 \
        --concat_types=None \
        --lambda_max_loss=2.0

    cp models-${N}-${K}-${mode}/bert-base-uncased-innerSteps_2-innerSize_32-lrInner_0.0001-lrMeta_0.0001-maxSteps_5001-seed_${seed}-name_10-k_100_type_2_32_3_10_10/en_type_pytorch_model.bin models-${N}-${K}-${mode}/bert-base-uncased-innerSteps_2-innerSize_32-lrInner_3e-05-lrMeta_3e-05-maxSteps_5001-seed_${seed}-name_10-k_100_2_32_3_max_loss_2_5_BIOES

    python3 main.py \
        --gpu_device=1 \
        --seed=${seed} \
        --N=${N} \
        --K=${K} \
        --mode=${mode} \
        --similar_k=10 \
        --name=10-k_100_2_32_3_max_loss_2_5_BIOES \
        --concat_types=None \
        --test_only \
        --eval_mode=two-stage \
        --inner_steps=2 \
        --inner_size=32 \
        --max_ft_steps=3 \
        --max_type_ft_steps=3 \
        --lambda_max_loss=2.0 \
        --inner_lambda_max_loss=5.0 \
        --inner_similar_k=10 \
        --viterbi=hard \
        --tagging_scheme=BIOES
done
@ -0,0 +1,30 @@
rem Copyright (c) Microsoft Corporation.
rem Licensed under the MIT license.

set SEEDS=171
set N=5
set K=1
set mode=inter

for %%e in (%SEEDS%) do (
    python3 main.py ^
        --gpu_device=1 ^
        --seed=%%e ^
        --N=%N% ^
        --K=%K% ^
        --mode=%mode% ^
        --similar_k=10 ^
        --name=10-k_100_2_32_3_max_loss_2_5_BIOES ^
        --concat_types=None ^
        --test_only ^
        --eval_mode=two-stage ^
        --inner_steps=2 ^
        --inner_size=32 ^
        --max_ft_steps=3 ^
        --max_type_ft_steps=3 ^
        --lambda_max_loss=2.0 ^
        --inner_lambda_max_loss=5.0 ^
        --inner_similar_k=10 ^
        --viterbi=hard ^
        --tagging_scheme=BIOES
)
@ -0,0 +1,30 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

SEEDS=(171)
N=5
K=1
mode=inter

for seed in ${SEEDS[@]}; do
    python3 main.py \
        --gpu_device=1 \
        --seed=${seed} \
        --N=${N} \
        --K=${K} \
        --mode=${mode} \
        --similar_k=10 \
        --name=10-k_100_2_32_3_max_loss_2_5_BIOES \
        --concat_types=None \
        --test_only \
        --eval_mode=two-stage \
        --inner_steps=2 \
        --inner_size=32 \
        --max_ft_steps=3 \
        --max_type_ft_steps=3 \
        --lambda_max_loss=2.0 \
        --inner_lambda_max_loss=5.0 \
        --inner_similar_k=10 \
        --viterbi=hard \
        --tagging_scheme=BIOES
done
@ -0,0 +1,35 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
import os
import random

import numpy as np
import torch


def load_file(path: str, mode: str = "list-strip"):
    # Read a file in one of several formats; missing paths yield an empty result.
    if not os.path.exists(path):
        return [] if not mode else ""
    with open(path, "r", encoding="utf-8", newline="\n") as f:
        if mode == "list-strip":
            data = [ii.strip() for ii in f.readlines()]
        elif mode == "str":
            data = f.read()
        elif mode == "list":
            data = list(f.readlines())
        elif mode == "json":
            data = json.loads(f.read())
        elif mode == "json-list":
            data = [json.loads(ii) for ii in f.readlines()]
    return data


def set_seed(seed, gpu_device):
    # Seed every random number generator used during training for reproducibility.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if gpu_device > -1:
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
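# --- Usage sketch (not part of the original file) ---
# A minimal, hypothetical example of how these helpers could be combined;
# "data/sample.jsonl" is a placeholder path used only for illustration.
if __name__ == "__main__":
    set_seed(171, gpu_device=-1)  # -1 keeps everything on the CPU; pass a GPU id >= 0 otherwise
    rows = load_file("data/sample.jsonl", mode="json-list")  # one JSON object per line
    print(f"loaded {len(rows)} records")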