add DecomposedMetaSL code
This commit is contained in:
Родитель
e8cca59e75
Коммит
e062c845ae
|
@ -1 +1,76 @@
|
||||||
Code will be public once the paper is accepted.
|
# Decomposed Meta-Learning for Few-Shot Sequence Labeling
|
||||||
|
|
||||||
|
This repository contains the open-sourced official implementation of the paper:
|
||||||
|
[Decomposed Meta-Learning for Few-Shot Sequence Labeling](https://ieeexplore.ieee.org/abstract/document/10458261) (TASLP).
|
||||||
|
|
||||||
|
|
||||||
|
_Tingting Ma, Qianhui Wu, Huiqiang Jiang, Jieru Lin, Börje F. Karlsson, Tiejun Zhao, and Chin-Yew Lin_
|
||||||
|
|
||||||
|
If you find this repo helpful, please cite the following paper
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@ARTICLE{ma-etal-2024-decomposedmetasl,
|
||||||
|
author={Ma, Tingting and Wu, Qianhui and Jiang, Huiqiang and Lin, Jieru and Karlsson, Börje F. and Zhao, Tiejun and Lin, Chin-Yew},
|
||||||
|
journal={IEEE/ACM Transactions on Audio, Speech, and Language Processing},
|
||||||
|
title={Decomposed Meta-Learning for Few-Shot Sequence Labeling},
|
||||||
|
year={2024},
|
||||||
|
volume={},
|
||||||
|
number={},
|
||||||
|
pages={1-14},
|
||||||
|
doi={10.1109/TASLP.2024.3372879}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
For any questions/comments, please feel free to open GitHub issues.
|
||||||
|
|
||||||
|
### Requirements
|
||||||
|
|
||||||
|
- python 3.9.17
|
||||||
|
- pytorch 1.9.1+cu111
|
||||||
|
- [HuggingFace Transformers 4.10.0](https://github.com/huggingface/transformers)
|
||||||
|
|
||||||
|
### Input data format
|
||||||
|
|
||||||
|
train/dev/test_N_K_id.jsonl:
|
||||||
|
Each line contains the following fields:
|
||||||
|
|
||||||
|
1. `target_classes`: A list of types (e.g., "event-other," "person-scholar").
|
||||||
|
2. `query_idx`: A list of indexes corresponding to query sentences for the i-th instance in train/dev/test.txt.
|
||||||
|
3. `support_idx`: A list of indexes corresponding to support sentences for the i-th instance in train/dev/test.txt.
|
||||||
|
|
||||||
|
### Train and Evaluate
|
||||||
|
|
||||||
|
For _Seperate_ model,
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash scripts/train_ment.sh
|
||||||
|
bash scripts/train_type.sh
|
||||||
|
bash scripts/eval_sep.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
For _Joint_ model,
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash scripts/train_joint.sh
|
||||||
|
bash scripts/eval_joint.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
For _POS tagging_ task, run
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash scripts/train_pos.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
|
||||||
|
This project welcomes contributions and suggestions. Most contributions require you to agree to a
|
||||||
|
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
|
||||||
|
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
|
||||||
|
|
||||||
|
When you submit a pull request, a CLA bot will automatically determine whether you need to provide
|
||||||
|
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
|
||||||
|
provided by the bot. You will only need to do this once across all repos using our CLA.
|
||||||
|
|
||||||
|
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
||||||
|
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
|
||||||
|
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
|
|
@ -0,0 +1,10 @@
|
||||||
|
{"target_classes": ["event-other", "person-scholar", "organization-showorganization", "building-other", "organization-religion"], "query_idx": [2944, 11433, 13834, 9686, 14970], "support_idx": [15362, 12956, 9225, 6364, 17558]}
|
||||||
|
{"target_classes": ["other-currency", "organization-religion", "building-other", "organization-showorganization", "person-other"], "query_idx": [16913, 16269, 1550, 5031, 6583], "support_idx": [365, 14279, 7493, 5495]}
|
||||||
|
{"target_classes": ["product-train", "art-painting", "person-scholar", "organization-showorganization", "building-other"], "query_idx": [12906, 17326, 2405, 12872, 9081], "support_idx": [9686, 1202, 1422, 3779]}
|
||||||
|
{"target_classes": ["location-park", "building-other", "product-train", "other-currency", "organization-religion"], "query_idx": [18096, 7281, 4038, 11566, 14149], "support_idx": [16798, 17009, 11792, 1476, 67]}
|
||||||
|
{"target_classes": ["product-game", "location-park", "organization-showorganization", "building-other", "other-currency"], "query_idx": [1333, 16185, 15369, 13226, 18582], "support_idx": [4179, 12320, 4168, 5272, 557]}
|
||||||
|
{"target_classes": ["product-train", "art-painting", "event-other", "building-library", "building-other"], "query_idx": [14660, 2874, 17317, 16888, 4034], "support_idx": [149, 5083, 13658, 12898, 1242]}
|
||||||
|
{"target_classes": ["other-chemicalthing", "person-scholar", "other-currency", "building-other", "building-library"], "query_idx": [16000, 12005, 8526, 2955, 3635], "support_idx": [14882, 16916, 2826, 3466, 4354]}
|
||||||
|
{"target_classes": ["building-other", "location-park", "organization-religion", "building-library", "product-train"], "query_idx": [9041, 17985, 11146, 9223, 3569], "support_idx": [15817, 3343, 15093, 3147, 18407]}
|
||||||
|
{"target_classes": ["building-other", "organization-showorganization", "other-currency", "building-library", "product-game"], "query_idx": [14602, 14513, 3689, 16122, 11496], "support_idx": [213, 8006, 11213, 3411, 9197]}
|
||||||
|
{"target_classes": ["organization-religion", "other-chemicalthing", "location-park", "person-scholar", "art-painting"], "query_idx": [3652, 12342, 14141, 7028, 1242], "support_idx": [16594, 2171, 6945, 6765, 12307]}
|
|
@ -0,0 +1,10 @@
|
||||||
|
{"target_classes": ["art-writtenart", "other-livingthing", "person-athlete", "building-theater", "product-car"], "query_idx": [5475, 2456, 7708, 13599, 1070], "support_idx": [12355, 10289, 7443, 3571, 4689]}
|
||||||
|
{"target_classes": ["event-sportsevent", "art-writtenart", "other-livingthing", "location-bodiesofwater", "product-car"], "query_idx": [13726, 10244, 2481, 9620, 5067], "support_idx": [8120, 6841, 6299, 8113, 12566]}
|
||||||
|
{"target_classes": ["other-medical", "other-livingthing", "organization-government/governmentagency", "person-athlete", "product-weapon"], "query_idx": [6730, 7660, 9544, 10325, 752], "support_idx": [3131, 7347, 12463, 7987, 13812]}
|
||||||
|
{"target_classes": ["location-other", "art-writtenart", "other-educationaldegree", "building-theater", "event-election"], "query_idx": [8314, 6694, 7124, 9619, 11697], "support_idx": [11050, 13350, 4019, 6025, 11376]}
|
||||||
|
{"target_classes": ["person-athlete", "organization-politicalparty", "location-bodiesofwater", "event-election", "event-sportsevent"], "query_idx": [7482, 11414, 1001, 13487], "support_idx": [5012, 4715, 13625, 9276, 6870]}
|
||||||
|
{"target_classes": ["art-music", "location-bodiesofwater", "organization-politicalparty", "product-car", "event-sportsevent"], "query_idx": [9002, 7862, 7715, 5282, 12689], "support_idx": [4504, 7822, 11325, 8480, 1235]}
|
||||||
|
{"target_classes": ["other-medical", "location-other", "person-actor", "product-weapon", "location-bodiesofwater"], "query_idx": [2532, 11648, 13626, 1109, 10751], "support_idx": [7290, 413, 4334, 10504, 12808]}
|
||||||
|
{"target_classes": ["person-actor", "product-weapon", "organization-politicalparty", "person-athlete", "organization-government/governmentagency"], "query_idx": [10577, 4080, 13918, 799, 2619], "support_idx": [8103, 12071, 9532, 1554, 2131]}
|
||||||
|
{"target_classes": ["location-other", "organization-politicalparty", "person-actor", "organization-government/governmentagency", "event-election"], "query_idx": [11580, 8955, 8633, 12802, 4687], "support_idx": [7790, 11474, 6822, 11488, 4439]}
|
||||||
|
{"target_classes": ["other-medical", "other-educationaldegree", "person-athlete", "product-car", "event-sportsevent"], "query_idx": [8012, 11377, 10920, 800, 12107], "support_idx": [5617, 13104, 4165, 1768, 10965]}
|
|
@ -0,0 +1,10 @@
|
||||||
|
{"target_classes": ["person-director", "person-artist/author", "product-ship", "building-sportsfacility", "other-astronomything"], "query_idx": [128416, 57936, 120966, 6520, 71667], "support_idx": [59271, 51450, 57931, 126361, 83894]}
|
||||||
|
{"target_classes": ["other-law", "event-protest", "building-airport", "location-road/railway/highway/transit", "organization-sportsteam"], "query_idx": [83479, 31867, 91254, 100102, 108109], "support_idx": [38357, 105279, 29226, 113465, 45224]}
|
||||||
|
{"target_classes": ["other-god", "product-airplane", "product-ship", "event-protest", "other-award"], "query_idx": [76530, 20953, 26875, 21041, 75295], "support_idx": [87042, 121344, 66832, 122153, 40746]}
|
||||||
|
{"target_classes": ["location-GPE", "product-software", "art-film", "other-biologything", "building-airport"], "query_idx": [99546, 31778, 60736, 115175, 45604], "support_idx": [13509, 16341, 82343, 63652, 45731]}
|
||||||
|
{"target_classes": ["product-airplane", "other-language", "person-director", "event-protest", "location-GPE"], "query_idx": [122795, 20509, 76322, 98322, 87870], "support_idx": [116202, 15218, 33921, 102917, 77796]}
|
||||||
|
{"target_classes": ["organization-company", "product-software", "other-law", "organization-media/newspaper", "product-other"], "query_idx": [13360, 24527, 122844, 126093, 105019], "support_idx": [79122, 95407, 61315, 14147, 64059]}
|
||||||
|
{"target_classes": ["event-protest", "product-airplane", "person-soldier", "person-artist/author", "product-ship"], "query_idx": [3298, 82634, 111096, 63988, 39442], "support_idx": [96798, 69080, 69749, 60232, 38841]}
|
||||||
|
{"target_classes": ["other-god", "product-software", "art-broadcastprogram", "other-disease", "organization-sportsteam"], "query_idx": [64957, 90058, 64150, 55510, 82947], "support_idx": [110777, 63235, 63848, 41609, 90769]}
|
||||||
|
{"target_classes": ["building-hotel", "building-restaurant", "person-director", "person-artist/author", "other-law"], "query_idx": [36475, 97996, 81899, 107727, 120948], "support_idx": [57146, 123180, 45636, 69299, 10967]}
|
||||||
|
{"target_classes": ["person-artist/author", "event-attack/battle/war/militaryconflict", "organization-media/newspaper", "location-island", "organization-other"], "query_idx": [78920, 90936, 21417, 105820, 56455], "support_idx": [48831, 76272, 104368, 43682]}
|
|
@ -0,0 +1,40 @@
|
||||||
|
On O
|
||||||
|
June O
|
||||||
|
15 O
|
||||||
|
, O
|
||||||
|
1991 O
|
||||||
|
, O
|
||||||
|
Hoch building-other
|
||||||
|
Auditorium building-other
|
||||||
|
was O
|
||||||
|
struck O
|
||||||
|
by O
|
||||||
|
lightning O
|
||||||
|
. O
|
||||||
|
|
||||||
|
Completed O
|
||||||
|
in O
|
||||||
|
July O
|
||||||
|
2008 O
|
||||||
|
by O
|
||||||
|
Mighty building-other
|
||||||
|
River building-other
|
||||||
|
Power building-other
|
||||||
|
at O
|
||||||
|
a O
|
||||||
|
cost O
|
||||||
|
of O
|
||||||
|
NZ O
|
||||||
|
$ O
|
||||||
|
300 O
|
||||||
|
million O
|
||||||
|
, O
|
||||||
|
the O
|
||||||
|
plant O
|
||||||
|
's O
|
||||||
|
capacity O
|
||||||
|
proved O
|
||||||
|
greater O
|
||||||
|
than O
|
||||||
|
expected O
|
||||||
|
. O
|
|
@ -0,0 +1,32 @@
|
||||||
|
The O
|
||||||
|
stadium O
|
||||||
|
previously O
|
||||||
|
hosted O
|
||||||
|
the O
|
||||||
|
1982 O
|
||||||
|
Commonwealth event-sportsevent
|
||||||
|
Games event-sportsevent
|
||||||
|
and O
|
||||||
|
2001 O
|
||||||
|
Goodwill event-sportsevent
|
||||||
|
Games event-sportsevent
|
||||||
|
. O
|
||||||
|
|
||||||
|
The O
|
||||||
|
San organization-government/governmentagency
|
||||||
|
Diego organization-government/governmentagency
|
||||||
|
Harbor organization-government/governmentagency
|
||||||
|
Police organization-government/governmentagency
|
||||||
|
Department organization-government/governmentagency
|
||||||
|
is O
|
||||||
|
the O
|
||||||
|
law O
|
||||||
|
enforcement O
|
||||||
|
authority O
|
||||||
|
for O
|
||||||
|
the O
|
||||||
|
Port location-bodiesofwater
|
||||||
|
of location-bodiesofwater
|
||||||
|
San location-bodiesofwater
|
||||||
|
Diego location-bodiesofwater
|
||||||
|
. O
|
|
@ -0,0 +1,26 @@
|
||||||
|
When O
|
||||||
|
reconstruction O
|
||||||
|
of O
|
||||||
|
the O
|
||||||
|
building O
|
||||||
|
was O
|
||||||
|
complete O
|
||||||
|
, O
|
||||||
|
the O
|
||||||
|
rear O
|
||||||
|
half O
|
||||||
|
of O
|
||||||
|
the O
|
||||||
|
building O
|
||||||
|
was O
|
||||||
|
named O
|
||||||
|
Budig O
|
||||||
|
Hall O
|
||||||
|
, O
|
||||||
|
for O
|
||||||
|
then O
|
||||||
|
KU organization-education
|
||||||
|
Chancellor O
|
||||||
|
Gene O
|
||||||
|
Budig O
|
||||||
|
. O
|
|
@ -0,0 +1,251 @@
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch
|
||||||
|
from typing import Dict, List
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
B_PREF="B-"
|
||||||
|
I_PREF = "I-"
|
||||||
|
S_PREF = "S-"
|
||||||
|
E_PREF = "E-"
|
||||||
|
O = "O"
|
||||||
|
|
||||||
|
class LinearCRF(nn.Module):
|
||||||
|
def __init__(self, tag_size: int, schema: str, add_constraint: bool = False,
|
||||||
|
label2idx: Dict = None, device: torch.device = None):
|
||||||
|
super(LinearCRF, self).__init__()
|
||||||
|
self.label_size = tag_size + 3
|
||||||
|
self.label2idx = label2idx
|
||||||
|
self.tag_list = list(self.label2idx.keys())
|
||||||
|
self.start_idx = tag_size
|
||||||
|
self.end_idx = tag_size + 1
|
||||||
|
self.pad_idx = tag_size + 2
|
||||||
|
self.schema = schema
|
||||||
|
self.add_constraint = add_constraint
|
||||||
|
self.init_params(device=device)
|
||||||
|
return
|
||||||
|
|
||||||
|
def reset(self, label2idx, device):
|
||||||
|
if len(label2idx) == len(self.label2idx):
|
||||||
|
return
|
||||||
|
tag_size = len(label2idx)
|
||||||
|
self.label_size = tag_size + 3
|
||||||
|
self.label2idx = label2idx
|
||||||
|
self.tag_list = list(self.label2idx.keys())
|
||||||
|
self.start_idx = tag_size
|
||||||
|
self.end_idx = tag_size + 1
|
||||||
|
self.pad_idx = tag_size + 2
|
||||||
|
self.add_constraint = True
|
||||||
|
self.init_params(device)
|
||||||
|
return
|
||||||
|
|
||||||
|
def init_params(self, device=None):
|
||||||
|
if device is None:
|
||||||
|
device = torch.device('cpu')
|
||||||
|
init_transition = torch.zeros(self.label_size, self.label_size, device=device)
|
||||||
|
init_transition[:, self.start_idx] = -10000.0
|
||||||
|
init_transition[self.end_idx, :] = -10000.0
|
||||||
|
init_transition[:, self.pad_idx] = -10000.0
|
||||||
|
init_transition[self.pad_idx, :] = -10000.0
|
||||||
|
if self.add_constraint:
|
||||||
|
if self.schema == "BIO":
|
||||||
|
self.add_constraint_for_bio(init_transition)
|
||||||
|
elif self.schema == "BIOES":
|
||||||
|
self.add_constraint_for_iobes(init_transition)
|
||||||
|
else:
|
||||||
|
print("[ERROR] wrong schema name!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
|
||||||
|
self.transition = nn.Parameter(init_transition, requires_grad=False)
|
||||||
|
return
|
||||||
|
|
||||||
|
def add_constraint_for_bio(self, transition: torch.Tensor):
|
||||||
|
for prev_label in self.tag_list:
|
||||||
|
for next_label in self.tag_list:
|
||||||
|
if prev_label == O and next_label.startswith(I_PREF):
|
||||||
|
transition[self.label2idx[prev_label], self.label2idx[next_label]] = -10000.0
|
||||||
|
if (prev_label.startswith(B_PREF) or prev_label.startswith(I_PREF)) and next_label.startswith(I_PREF):
|
||||||
|
if prev_label[2:] != next_label[2:]:
|
||||||
|
transition[self.label2idx[prev_label], self.label2idx[next_label]] = -10000.0
|
||||||
|
for label in self.tag_list:
|
||||||
|
if label.startswith(I_PREF):
|
||||||
|
transition[self.start_idx, self.label2idx[label]] = -10000.0
|
||||||
|
return
|
||||||
|
|
||||||
|
def add_constraint_for_iobes(self, transition: torch.Tensor):
|
||||||
|
for prev_label in self.tag_list:
|
||||||
|
for next_label in self.tag_list:
|
||||||
|
if prev_label == O and (next_label.startswith(I_PREF) or next_label.startswith(E_PREF)):
|
||||||
|
transition[self.label2idx[prev_label], self.label2idx[next_label]] = -10000.0
|
||||||
|
if prev_label.startswith(B_PREF) or prev_label.startswith(I_PREF):
|
||||||
|
if next_label.startswith(O) or next_label.startswith(B_PREF) or next_label.startswith(S_PREF):
|
||||||
|
transition[self.label2idx[prev_label], self.label2idx[next_label]] = -10000.0
|
||||||
|
elif prev_label[2:] != next_label[2:]:
|
||||||
|
transition[self.label2idx[prev_label], self.label2idx[next_label]] = -10000.0
|
||||||
|
if prev_label.startswith(S_PREF) or prev_label.startswith(E_PREF):
|
||||||
|
if next_label.startswith(I_PREF) or next_label.startswith(E_PREF):
|
||||||
|
transition[self.label2idx[prev_label], self.label2idx[next_label]] = -10000.0
|
||||||
|
for label in self.tag_list:
|
||||||
|
if label.startswith(I_PREF) or label.startswith(E_PREF):
|
||||||
|
transition[self.start_idx, self.label2idx[label]] = -10000.0
|
||||||
|
if label.startswith(I_PREF) or label.startswith(B_PREF):
|
||||||
|
transition[self.label2idx[label], self.end_idx] = -10000.0
|
||||||
|
return
|
||||||
|
|
||||||
|
def forward(self, lstm_scores, word_seq_lens, tags, mask, decode_flag=False):
|
||||||
|
all_scores = self.calculate_all_scores(lstm_scores=lstm_scores)
|
||||||
|
unlabed_score = self.forward_unlabeled(all_scores, word_seq_lens)
|
||||||
|
labeled_score = self.forward_labeled(all_scores, word_seq_lens, tags, mask)
|
||||||
|
per_sent_loss = (unlabed_score - labeled_score).sum() / mask.size(0)
|
||||||
|
if decode_flag:
|
||||||
|
_, decodeIdx = self.viterbi_decode(all_scores, word_seq_lens)
|
||||||
|
return per_sent_loss, decodeIdx
|
||||||
|
return per_sent_loss
|
||||||
|
|
||||||
|
def get_marginal_score(self, lstm_scores: torch.Tensor, word_seq_lens: torch.Tensor) -> torch.Tensor:
|
||||||
|
marginal = self.forward_backward(lstm_scores=lstm_scores, word_seq_lens=word_seq_lens)
|
||||||
|
return marginal
|
||||||
|
|
||||||
|
def forward_unlabeled(self, all_scores: torch.Tensor, word_seq_lens: torch.Tensor) -> torch.Tensor:
|
||||||
|
batch_size = all_scores.size(0)
|
||||||
|
seq_len = all_scores.size(1)
|
||||||
|
dev_num = all_scores.get_device()
|
||||||
|
curr_dev = torch.device(f"cuda:{dev_num}") if dev_num >= 0 else torch.device("cpu")
|
||||||
|
alpha = torch.zeros(batch_size, seq_len, self.label_size, device=curr_dev)
|
||||||
|
|
||||||
|
alpha[:, 0, :] = all_scores[:, 0, self.start_idx, :]
|
||||||
|
|
||||||
|
for word_idx in range(1, seq_len):
|
||||||
|
before_log_sum_exp = alpha[:, word_idx-1, :].view(batch_size, self.label_size, 1).expand(batch_size, self.label_size, self.label_size) + all_scores[:, word_idx, :, :]
|
||||||
|
alpha[:, word_idx, :] = torch.logsumexp(before_log_sum_exp, dim=1)
|
||||||
|
|
||||||
|
last_alpha = torch.gather(alpha, 1, word_seq_lens.view(batch_size, 1, 1).expand(batch_size, 1, self.label_size)-1).view(batch_size, self.label_size)
|
||||||
|
last_alpha += self.transition[:, self.end_idx].view(1, self.label_size).expand(batch_size, self.label_size)
|
||||||
|
last_alpha = torch.logsumexp(last_alpha.view(batch_size, self.label_size, 1), dim=1).view(batch_size) #log Z(x)
|
||||||
|
return last_alpha
|
||||||
|
|
||||||
|
def backward(self, lstm_scores: torch.Tensor, word_seq_lens: torch.Tensor) -> torch.Tensor:
|
||||||
|
batch_size = lstm_scores.size(0)
|
||||||
|
seq_len = lstm_scores.size(1)
|
||||||
|
dev_num = lstm_scores.get_device()
|
||||||
|
curr_dev = torch.device(f"cuda:{dev_num}") if dev_num >= 0 else torch.device("cpu")
|
||||||
|
beta = torch.zeros(batch_size, seq_len, self.label_size, device=curr_dev)
|
||||||
|
|
||||||
|
rev_score = self.transition.transpose(0, 1).view(1, 1, self.label_size, self.label_size).expand(batch_size, seq_len, self.label_size, self.label_size) + \
|
||||||
|
lstm_scores.view(batch_size, seq_len, 1, self.label_size).expand(batch_size, seq_len, self.label_size, self.label_size)
|
||||||
|
|
||||||
|
perm_idx = torch.zeros(batch_size, seq_len, device=curr_dev)
|
||||||
|
for batch_idx in range(batch_size):
|
||||||
|
perm_idx[batch_idx][:word_seq_lens[batch_idx]] = torch.range(word_seq_lens[batch_idx] - 1, 0, -1)
|
||||||
|
perm_idx = perm_idx.long()
|
||||||
|
for i, length in enumerate(word_seq_lens):
|
||||||
|
rev_score[i, :length] = rev_score[i, :length][perm_idx[i, :length]]
|
||||||
|
|
||||||
|
beta[:, 0, :] = rev_score[:, 0, self.end_idx, :]
|
||||||
|
for word_idx in range(1, seq_len):
|
||||||
|
before_log_sum_exp = beta[:, word_idx - 1, :].view(batch_size, self.label_size, 1).expand(batch_size, self.label_size, self.label_size) + rev_score[:, word_idx, :, :]
|
||||||
|
beta[:, word_idx, :] = torch.logsumexp(before_log_sum_exp, dim=1)
|
||||||
|
|
||||||
|
last_beta = torch.gather(beta, 1, word_seq_lens.view(batch_size, 1, 1).expand(batch_size, 1, self.label_size) - 1).view(batch_size, self.label_size)
|
||||||
|
last_beta += self.transition.transpose(0, 1)[:, self.start_idx].view(1, self.label_size).expand(batch_size, self.label_size)
|
||||||
|
last_beta = torch.logsumexp(last_beta, dim=1)
|
||||||
|
|
||||||
|
for i, length in enumerate(word_seq_lens):
|
||||||
|
beta[i, :length] = beta[i, :length][perm_idx[i, :length]]
|
||||||
|
return torch.sum(last_beta)
|
||||||
|
|
||||||
|
def forward_backward(self, lstm_scores: torch.Tensor, word_seq_lens: torch.Tensor) -> torch.Tensor:
|
||||||
|
batch_size = lstm_scores.size(0)
|
||||||
|
seq_len = lstm_scores.size(1)
|
||||||
|
dev_num = lstm_scores.get_device()
|
||||||
|
curr_dev = torch.device(f"cuda:{dev_num}") if dev_num >= 0 else torch.device("cpu")
|
||||||
|
alpha = torch.zeros(batch_size, seq_len, self.label_size, device=curr_dev)
|
||||||
|
beta = torch.zeros(batch_size, seq_len, self.label_size, device=curr_dev)
|
||||||
|
scores = self.transition.view(1, 1, self.label_size, self.label_size).expand(batch_size, seq_len, self.label_size, self.label_size) + \
|
||||||
|
lstm_scores.view(batch_size, seq_len, 1, self.label_size).expand(batch_size, seq_len, self.label_size, self.label_size)
|
||||||
|
rev_score = self.transition.transpose(0, 1).view(1, 1, self.label_size, self.label_size).expand(batch_size, seq_len, self.label_size, self.label_size) + \
|
||||||
|
lstm_scores.view(batch_size, seq_len, 1, self.label_size).expand(batch_size, seq_len, self.label_size, self.label_size)
|
||||||
|
|
||||||
|
perm_idx = torch.zeros(batch_size, seq_len, device=curr_dev)
|
||||||
|
for batch_idx in range(batch_size):
|
||||||
|
perm_idx[batch_idx][:word_seq_lens[batch_idx]] = torch.range(word_seq_lens[batch_idx] - 1, 0, -1)
|
||||||
|
perm_idx = perm_idx.long()
|
||||||
|
for i, length in enumerate(word_seq_lens):
|
||||||
|
rev_score[i, :length] = rev_score[i, :length][perm_idx[i, :length]]
|
||||||
|
alpha[:, 0, :] = scores[:, 0, self.start_idx, :]
|
||||||
|
beta[:, 0, :] = rev_score[:, 0, self.end_idx, :]
|
||||||
|
for word_idx in range(1, seq_len):
|
||||||
|
before_log_sum_exp = alpha[:, word_idx - 1, :].view(batch_size, self.label_size, 1).expand(batch_size, self.label_size, self.label_size) + scores[ :, word_idx, :, :]
|
||||||
|
alpha[:, word_idx, :] = torch.logsumexp(before_log_sum_exp, dim=1)
|
||||||
|
|
||||||
|
before_log_sum_exp = beta[:, word_idx - 1, :].view(batch_size, self.label_size, 1).expand(batch_size, self.label_size, self.label_size) + rev_score[:, word_idx, :, :]
|
||||||
|
beta[:, word_idx, :] = torch.logsumexp(before_log_sum_exp, dim=1)
|
||||||
|
|
||||||
|
last_alpha = torch.gather(alpha, 1, word_seq_lens.view(batch_size, 1, 1).expand(batch_size, 1, self.label_size) - 1).view(batch_size, self.label_size)
|
||||||
|
last_alpha += self.transition[:, self.end_idx].view(1, self.label_size).expand(batch_size, self.label_size)
|
||||||
|
last_alpha = torch.logsumexp(last_alpha.view(batch_size, self.label_size), dim=-1).view(batch_size, 1, 1).expand(batch_size, seq_len, self.label_size)
|
||||||
|
for i, length in enumerate(word_seq_lens):
|
||||||
|
beta[i, :length] = beta[i, :length][perm_idx[i, :length]]
|
||||||
|
return alpha + beta - last_alpha - lstm_scores
|
||||||
|
|
||||||
|
def forward_labeled(self, all_scores: torch.Tensor, word_seq_lens: torch.Tensor, tags: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
|
||||||
|
batchSize = all_scores.shape[0]
|
||||||
|
sentLength = all_scores.shape[1]
|
||||||
|
currentTagScores = torch.gather(all_scores, 3, tags.view(batchSize, sentLength, 1, 1).expand(batchSize, sentLength, self.label_size, 1)).view(batchSize, -1, self.label_size)
|
||||||
|
tagTransScoresMiddle = None
|
||||||
|
if sentLength != 1:
|
||||||
|
tagTransScoresMiddle = torch.gather(currentTagScores[:, 1:, :], 2, tags[:, :sentLength - 1].view(batchSize, sentLength - 1, 1)).view(batchSize, -1)
|
||||||
|
tagTransScoresBegin = currentTagScores[:, 0, self.start_idx]
|
||||||
|
endTagIds = torch.gather(tags, 1, word_seq_lens.view(batchSize, 1) - 1)
|
||||||
|
tagTransScoresEnd = torch.gather(self.transition[:, self.end_idx].view(1, self.label_size).expand(batchSize, self.label_size), 1, endTagIds).view(batchSize)
|
||||||
|
score = tagTransScoresBegin + tagTransScoresEnd
|
||||||
|
masks = masks.type(torch.float32)
|
||||||
|
|
||||||
|
if sentLength != 1:
|
||||||
|
score += torch.sum(tagTransScoresMiddle.mul(masks[:, 1:]), dim=1)
|
||||||
|
return score
|
||||||
|
|
||||||
|
def calculate_all_scores(self, lstm_scores: torch.Tensor) -> torch.Tensor:
|
||||||
|
batch_size = lstm_scores.size(0)
|
||||||
|
seq_len = lstm_scores.size(1)
|
||||||
|
scores = self.transition.view(1, 1, self.label_size, self.label_size).expand(batch_size, seq_len, self.label_size, self.label_size) + \
|
||||||
|
lstm_scores.view(batch_size, seq_len, 1, self.label_size).expand(batch_size, seq_len, self.label_size, self.label_size)
|
||||||
|
return scores
|
||||||
|
|
||||||
|
def decode(self, features, wordSeqLengths, new_label2idx=None) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
if new_label2idx is not None:
|
||||||
|
self.reset(new_label2idx, features.device)
|
||||||
|
all_scores = self.calculate_all_scores(features)
|
||||||
|
bestScores, decodeIdx = self.viterbi_decode(all_scores, wordSeqLengths)
|
||||||
|
return bestScores, decodeIdx
|
||||||
|
|
||||||
|
def viterbi_decode(self, all_scores: torch.Tensor, word_seq_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
batchSize = all_scores.shape[0]
|
||||||
|
sentLength = all_scores.shape[1]
|
||||||
|
dev_num = all_scores.get_device()
|
||||||
|
curr_dev = torch.device(f"cuda:{dev_num}") if dev_num >= 0 else torch.device("cpu")
|
||||||
|
scoresRecord = torch.zeros([batchSize, sentLength, self.label_size], device=curr_dev)
|
||||||
|
idxRecord = torch.zeros([batchSize, sentLength, self.label_size], dtype=torch.int64, device=curr_dev)
|
||||||
|
startIds = torch.full((batchSize, self.label_size), self.start_idx, dtype=torch.int64, device=curr_dev)
|
||||||
|
decodeIdx = torch.LongTensor(batchSize, sentLength).to(curr_dev)
|
||||||
|
|
||||||
|
scores = all_scores
|
||||||
|
scoresRecord[:, 0, :] = scores[:, 0, self.start_idx, :]
|
||||||
|
idxRecord[:, 0, :] = startIds
|
||||||
|
for wordIdx in range(1, sentLength):
|
||||||
|
scoresIdx = scoresRecord[:, wordIdx - 1, :].view(batchSize, self.label_size, 1).expand(batchSize, self.label_size,
|
||||||
|
self.label_size) + scores[:, wordIdx, :, :]
|
||||||
|
idxRecord[:, wordIdx, :] = torch.argmax(scoresIdx, 1)
|
||||||
|
scoresRecord[:, wordIdx, :] = torch.gather(scoresIdx, 1, idxRecord[:, wordIdx, :].view(batchSize, 1, self.label_size)).view(batchSize, self.label_size)
|
||||||
|
lastScores = torch.gather(scoresRecord, 1, word_seq_lens.view(batchSize, 1, 1).expand(batchSize, 1, self.label_size) - 1).view(batchSize, self.label_size) ##select position
|
||||||
|
lastScores += self.transition[:, self.end_idx].view(1, self.label_size).expand(batchSize, self.label_size)
|
||||||
|
decodeIdx[:, 0] = torch.argmax(lastScores, 1)
|
||||||
|
bestScores = torch.gather(lastScores, 1, decodeIdx[:, 0].view(batchSize, 1))
|
||||||
|
|
||||||
|
for distance2Last in range(sentLength - 1):
|
||||||
|
curIdx = torch.clamp(word_seq_lens - distance2Last - 1, min=1).view(batchSize, 1, 1).expand(batchSize, 1, self.label_size)
|
||||||
|
lastNIdxRecord = torch.gather(idxRecord, 1, curIdx).view(batchSize, self.label_size)
|
||||||
|
decodeIdx[:, distance2Last + 1] = torch.gather(lastNIdxRecord, 1, decodeIdx[:, distance2Last].view(batchSize, 1)).view(batchSize)
|
||||||
|
perm_pos = torch.arange(1, sentLength + 1).to(curr_dev)
|
||||||
|
perm_pos = perm_pos.unsqueeze(0).expand(batchSize, sentLength)
|
||||||
|
perm_pos = word_seq_lens.unsqueeze(1).expand(batchSize, sentLength) - perm_pos
|
||||||
|
perm_pos = perm_pos.masked_fill(perm_pos < 0, 0)
|
||||||
|
decodeIdx = torch.gather(decodeIdx, 1, perm_pos)
|
||||||
|
return bestScores, decodeIdx
|
|
@ -0,0 +1,323 @@
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from .crf import LinearCRF
|
||||||
|
from .span_fewshotmodel import FewShotSpanModel
|
||||||
|
from util.span_sample import convert_bio2spans
|
||||||
|
from .loss_model import MaxLoss
|
||||||
|
from copy import deepcopy
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
class SelectedJointModel(FewShotSpanModel):
|
||||||
|
def __init__(self, span_encoder, num_tag, ment_label2idx, schema, use_crf, max_loss, use_oproto, dot=False, normalize="none",
|
||||||
|
temperature=None, use_focal=False, type_lam=1):
|
||||||
|
super(SelectedJointModel, self).__init__(span_encoder)
|
||||||
|
self.dot = dot
|
||||||
|
self.normalize = normalize
|
||||||
|
self.temperature = temperature
|
||||||
|
self.use_oproto = use_oproto
|
||||||
|
self.proto = None
|
||||||
|
self.num_tag = num_tag
|
||||||
|
self.use_crf = use_crf
|
||||||
|
self.ment_label2idx = ment_label2idx
|
||||||
|
self.ment_idx2label = {idx: label for label, idx in self.ment_label2idx.items()}
|
||||||
|
self.schema = schema
|
||||||
|
self.cls = nn.Linear(span_encoder.word_dim, self.num_tag)
|
||||||
|
self.type_lam = type_lam
|
||||||
|
self.proto = None
|
||||||
|
if self.use_crf:
|
||||||
|
self.crf_layer = LinearCRF(self.num_tag, schema=schema, add_constraint=True, label2idx=ment_label2idx)
|
||||||
|
if use_focal:
|
||||||
|
raise ValueError("not support focal loss")
|
||||||
|
self.ment_loss_fct = MaxLoss(gamma=max_loss)
|
||||||
|
self.base_loss_fct = nn.CrossEntropyLoss(reduction='none', ignore_index=-1)
|
||||||
|
print("use cross entropy loss")
|
||||||
|
print("use dot : {} use normalizatioin: {} use temperature: {}".format(self.dot, self.normalize,
|
||||||
|
self.temperature if self.temperature else "none"))
|
||||||
|
self.cached_o_proto = torch.zeros(span_encoder.span_dim, requires_grad=False)
|
||||||
|
self.init_weights()
|
||||||
|
return
|
||||||
|
|
||||||
|
def init_weights(self):
|
||||||
|
self.cls.weight.data.normal_(mean=0.0, std=0.02)
|
||||||
|
if self.cls.bias is not None:
|
||||||
|
self.cls.bias.data.zero_()
|
||||||
|
if self.use_crf:
|
||||||
|
self.crf_layer.init_params()
|
||||||
|
return
|
||||||
|
|
||||||
|
def type_loss_fct(self, logits, targets, inst_weights=None):
|
||||||
|
if inst_weights is None:
|
||||||
|
loss = self.base_loss_fct(logits, targets)
|
||||||
|
loss = loss.mean()
|
||||||
|
else:
|
||||||
|
targets = torch.clamp(targets, min=0)
|
||||||
|
one_hot_targets = torch.zeros(logits.size(), device=logits.device).scatter_(1, targets.unsqueeze(1), 1)
|
||||||
|
soft_labels = inst_weights.unsqueeze(1) * one_hot_targets + (1 - one_hot_targets) * (1 - inst_weights).unsqueeze(1) / (logits.size(1) - 1)
|
||||||
|
logp = F.log_softmax(logits, dim=-1)
|
||||||
|
loss = - (logp * soft_labels).sum(1)
|
||||||
|
loss = loss.mean()
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def __dist__(self, x, y, dim):
|
||||||
|
if self.normalize == 'l2':
|
||||||
|
x = F.normalize(x, p=2, dim=-1)
|
||||||
|
y = F.normalize(y, p=2, dim=-1)
|
||||||
|
if self.dot:
|
||||||
|
sim = (x * y).sum(dim)
|
||||||
|
else:
|
||||||
|
sim = -(torch.pow(x - y, 2)).sum(dim)
|
||||||
|
if self.temperature:
|
||||||
|
sim = sim / self.temperature
|
||||||
|
return sim
|
||||||
|
|
||||||
|
def __batch_dist__(self, S_emb, Q_emb, Q_mask):
|
||||||
|
Q_emb = Q_emb[Q_mask.eq(1), :].view(-1, Q_emb.size(-1))
|
||||||
|
dist = self.__dist__(S_emb.unsqueeze(0), Q_emb.unsqueeze(1), 2)
|
||||||
|
return dist
|
||||||
|
|
||||||
|
def __get_proto__(self, S_emb, S_tag, S_mask):
|
||||||
|
proto = []
|
||||||
|
embedding = S_emb[S_mask.eq(1), :].view(-1, S_emb.size(-1))
|
||||||
|
S_tag = S_tag[S_mask.eq(1)]
|
||||||
|
if self.use_oproto:
|
||||||
|
st_idx = 0
|
||||||
|
else:
|
||||||
|
st_idx = 1
|
||||||
|
proto = [self.cached_o_proto]
|
||||||
|
for label in range(st_idx, torch.max(S_tag) + 1):
|
||||||
|
proto.append(torch.mean(embedding[S_tag.eq(label), :], 0))
|
||||||
|
proto = torch.stack(proto, dim=0)
|
||||||
|
return proto
|
||||||
|
|
||||||
|
def __get_proto_dist__(self, Q_emb, Q_mask):
|
||||||
|
dist = self.__batch_dist__(self.proto, Q_emb, Q_mask)
|
||||||
|
if not self.use_oproto:
|
||||||
|
dist[:, 0] = -1000000
|
||||||
|
return dist
|
||||||
|
|
||||||
|
def forward_type_step(self, query, encoder_mode=None, query_bottom_hiddens=None):
|
||||||
|
if query['span_mask'].sum().item() == 0: # there is no query mentions
|
||||||
|
print("no query mentions")
|
||||||
|
empty_tensor = torch.tensor([], device=query['word'].device)
|
||||||
|
zero_tensor = torch.tensor([0], device=query['word'].device)
|
||||||
|
return empty_tensor, empty_tensor, empty_tensor, zero_tensor
|
||||||
|
query_span_emb = self.word_encoder(query['word'], query['word_mask'], word_to_piece_inds=query['word_to_piece_ind'],
|
||||||
|
word_to_piece_ends=query['word_to_piece_end'], span_indices=query['span_indices'],
|
||||||
|
mode=encoder_mode, bottom_hiddens=query_bottom_hiddens)
|
||||||
|
logits = self.__get_proto_dist__(query_span_emb, query['span_mask'])
|
||||||
|
golds = query["span_tag"][query["span_mask"].eq(1)].view(-1)
|
||||||
|
if query["span_weights"] is not None:
|
||||||
|
query_span_weights = query["span_weights"][query["span_mask"].eq(1)].view(-1)
|
||||||
|
else:
|
||||||
|
query_span_weights = None
|
||||||
|
if self.use_oproto:
|
||||||
|
loss = self.type_loss_fct(logits, golds, inst_weights=query_span_weights)
|
||||||
|
else:
|
||||||
|
loss = self.type_loss_fct(logits[:, 1:], golds - 1, inst_weights=query_span_weights)
|
||||||
|
_, preds = torch.max(logits, dim=-1)
|
||||||
|
return logits, preds, golds, loss
|
||||||
|
|
||||||
|
def init_proto(self, support_data, encoder_mode=None, support_bottom_hiddens=None):
|
||||||
|
self.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
support_span_emb = self.word_encoder(support_data['word'], support_data['word_mask'], word_to_piece_inds=support_data['word_to_piece_ind'],
|
||||||
|
word_to_piece_ends=support_data['word_to_piece_end'], span_indices=support_data['span_indices'],
|
||||||
|
mode=encoder_mode, bottom_hiddens=support_bottom_hiddens)
|
||||||
|
self.cached_o_proto = self.cached_o_proto.to(support_span_emb.device)
|
||||||
|
proto = self.__get_proto__(support_span_emb, support_data['span_tag'], support_data['span_mask'])
|
||||||
|
self.proto = nn.Parameter(proto.data, requires_grad=True)
|
||||||
|
return
|
||||||
|
|
||||||
|
def forward_ment_step(self, batch, crf_mode=True, encoder_mode=None):
|
||||||
|
res = self.word_encoder(batch['word'], batch['word_mask'],
|
||||||
|
batch['word_to_piece_ind'],
|
||||||
|
batch['word_to_piece_end'], mode=encoder_mode,
|
||||||
|
)
|
||||||
|
word_emb = res[0]
|
||||||
|
bottom_hiddens = res[1]
|
||||||
|
|
||||||
|
logits = self.cls(word_emb)
|
||||||
|
gold = batch['ment_labels']
|
||||||
|
tot_loss = self.ment_loss_fct(logits, gold)
|
||||||
|
if self.use_crf and crf_mode:
|
||||||
|
crf_sp_logits = torch.zeros((logits.size(0), logits.size(1), 3), device=logits.device)
|
||||||
|
crf_sp_logits = torch.cat([logits, crf_sp_logits], dim=2)
|
||||||
|
_, pred = self.crf_layer.decode(crf_sp_logits, batch['seq_len'])
|
||||||
|
else:
|
||||||
|
pred = torch.argmax(logits, dim=-1)
|
||||||
|
pred = pred.masked_fill(gold.eq(-1), -1)
|
||||||
|
return logits, pred, gold, tot_loss, bottom_hiddens
|
||||||
|
|
||||||
|
def joint_inner_update(self, support_data, inner_steps, lr_inner):
|
||||||
|
self.init_proto(support_data, encoder_mode="type")
|
||||||
|
parameters_to_optimize = list(self.named_parameters())
|
||||||
|
decay_params = []
|
||||||
|
nodecay_params = []
|
||||||
|
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
|
||||||
|
for n, p in parameters_to_optimize:
|
||||||
|
if p.requires_grad:
|
||||||
|
if ("bert." in n) and (not any(nd in n for nd in no_decay)):
|
||||||
|
decay_params.append(p)
|
||||||
|
else:
|
||||||
|
nodecay_params.append(p)
|
||||||
|
parameters_groups = [
|
||||||
|
{'params': decay_params,
|
||||||
|
'lr': lr_inner, 'weight_decay': 1e-3},
|
||||||
|
{'params': nodecay_params,
|
||||||
|
'lr': lr_inner, 'weight_decay': 0},
|
||||||
|
]
|
||||||
|
inner_opt = torch.optim.AdamW(parameters_groups, lr=lr_inner)
|
||||||
|
self.train()
|
||||||
|
for _ in range(inner_steps):
|
||||||
|
inner_opt.zero_grad()
|
||||||
|
_, _, _, ment_loss, support_bottom_hiddens = self.forward_ment_step(support_data, crf_mode=False, encoder_mode="ment")
|
||||||
|
_, _, _, type_loss = self.forward_type_step(support_data, encoder_mode="type", query_bottom_hiddens=support_bottom_hiddens)
|
||||||
|
|
||||||
|
loss = ment_loss + type_loss * self.type_lam
|
||||||
|
loss.backward()
|
||||||
|
inner_opt.step()
|
||||||
|
return
|
||||||
|
|
||||||
|
def decode_ments(self, episode_preds):
|
||||||
|
episode_preds = episode_preds.detach()
|
||||||
|
span_indices = torch.zeros((episode_preds.size(0), 100, 2), dtype=torch.long)
|
||||||
|
span_masks = torch.zeros(span_indices.size()[:2], dtype=torch.long)
|
||||||
|
span_labels = torch.full_like(span_masks, fill_value=1, dtype=torch.long)
|
||||||
|
max_span_num = 0
|
||||||
|
for i, pred in enumerate(episode_preds):
|
||||||
|
seqs = []
|
||||||
|
for idx in pred:
|
||||||
|
if idx.item() == -1:
|
||||||
|
break
|
||||||
|
seqs.append(self.ment_idx2label[idx.item()])
|
||||||
|
ents = convert_bio2spans(seqs, self.schema)
|
||||||
|
max_span_num = max(max_span_num, len(ents))
|
||||||
|
for j, x in enumerate(ents):
|
||||||
|
span_indices[i, j, 0] = x[1]
|
||||||
|
span_indices[i, j, 1] = x[2]
|
||||||
|
span_masks[i, :len(ents)] = 1
|
||||||
|
return span_indices, span_masks, span_labels, max_span_num
|
||||||
|
|
||||||
|
|
||||||
|
# forward proto maml
|
||||||
|
def forward_joint_meta(self, batch, inner_steps, lr_inner, mode):
|
||||||
|
no_grads = ["proto"]
|
||||||
|
if batch['query']['span_mask'].sum().item() == 0: # there is no query mentions
|
||||||
|
print("no query mentions")
|
||||||
|
no_grads.append("type_adapters")
|
||||||
|
names, params = self.get_named_params(no_grads=no_grads)
|
||||||
|
weights = deepcopy(params)
|
||||||
|
meta_grad = []
|
||||||
|
episode_losses = []
|
||||||
|
episode_ment_losses = []
|
||||||
|
episode_type_losses = []
|
||||||
|
query_ment_logits = []
|
||||||
|
query_ment_preds = []
|
||||||
|
query_ment_golds = []
|
||||||
|
query_type_logits = []
|
||||||
|
query_type_preds = []
|
||||||
|
query_type_golds = []
|
||||||
|
two_stage_query_ments = []
|
||||||
|
two_stage_query_masks = []
|
||||||
|
two_stage_max_snum = 0
|
||||||
|
current_support_num = 0
|
||||||
|
current_query_num = 0
|
||||||
|
support, query = batch["support"], batch["query"]
|
||||||
|
data_keys = ['word', 'word_mask', 'word_to_piece_ind', 'word_to_piece_end', 'seq_len', 'ment_labels', 'span_indices', 'span_mask', 'span_tag', 'span_weights']
|
||||||
|
|
||||||
|
for i, sent_support_num in enumerate(support['sentence_num']):
|
||||||
|
sent_query_num = query['sentence_num'][i]
|
||||||
|
one_support = {
|
||||||
|
k: support[k][current_support_num:current_support_num + sent_support_num] for k in data_keys if k in support
|
||||||
|
}
|
||||||
|
one_query = {
|
||||||
|
k: query[k][current_query_num:current_query_num + sent_query_num] for k in data_keys if k in query
|
||||||
|
}
|
||||||
|
self.zero_grad()
|
||||||
|
self.joint_inner_update(one_support, inner_steps, lr_inner) # inner update parameters on support data
|
||||||
|
if mode == "train":
|
||||||
|
qy_ment_logits, qy_ment_pred, qy_ment_gold, qy_ment_loss, qy_bottom_hiddens = self.forward_ment_step(one_query, crf_mode=False, encoder_mode="ment") # evaluate on query data
|
||||||
|
qy_type_logits, qy_type_pred, qy_type_gold, qy_type_loss = self.forward_type_step(one_query, encoder_mode="type", query_bottom_hiddens=qy_bottom_hiddens) # evaluate on query data
|
||||||
|
qy_loss = qy_ment_loss + self.type_lam * qy_type_loss
|
||||||
|
if one_query['span_mask'].sum().item() == 0:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
grad = torch.autograd.grad(qy_loss, params) # meta-update
|
||||||
|
meta_grad.append(grad)
|
||||||
|
elif mode == "test-onestage":
|
||||||
|
self.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
qy_ment_logits, qy_ment_pred, qy_ment_gold, qy_ment_loss, qy_bottom_hiddens = self.forward_ment_step(one_query, encoder_mode="ment") # evaluate on query data
|
||||||
|
qy_type_logits, qy_type_pred, qy_type_gold, qy_type_loss = self.forward_type_step(one_query, encoder_mode="type", query_bottom_hiddens=qy_bottom_hiddens) # evaluate on query data
|
||||||
|
qy_loss = qy_ment_loss + self.type_lam * qy_type_loss
|
||||||
|
else:
|
||||||
|
assert mode == "test-twostage"
|
||||||
|
self.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
qy_ment_logits, qy_ment_pred, qy_ment_gold, qy_ment_loss, qy_bottom_hiddens = self.forward_ment_step(one_query, encoder_mode="ment") # evaluate on query data
|
||||||
|
span_indices, span_mask, span_tag, max_span_num = self.decode_ments(qy_ment_pred)
|
||||||
|
one_query['span_indices'] = span_indices[:, :max_span_num, :].to(qy_ment_logits.device)
|
||||||
|
one_query['span_mask'] = span_mask[:, :max_span_num].to(qy_ment_logits.device)
|
||||||
|
one_query['span_tag'] = span_tag[:, :max_span_num].to(qy_ment_logits.device)
|
||||||
|
one_query['span_weights'] = None
|
||||||
|
two_stage_max_snum = max(two_stage_max_snum, max_span_num)
|
||||||
|
two_stage_query_masks.append(span_mask)
|
||||||
|
two_stage_query_ments.append(span_indices)
|
||||||
|
qy_type_logits, qy_type_pred, qy_type_gold, qy_type_loss = self.forward_type_step(one_query, encoder_mode="type", query_bottom_hiddens=qy_bottom_hiddens) # evaluate on query data
|
||||||
|
qy_loss = qy_ment_loss + self.type_lam * qy_type_loss
|
||||||
|
|
||||||
|
episode_losses.append(qy_loss.item())
|
||||||
|
episode_ment_losses.append(qy_ment_loss.item())
|
||||||
|
episode_type_losses.append(qy_type_loss.item())
|
||||||
|
query_ment_preds.append(qy_ment_pred)
|
||||||
|
query_ment_golds.append(qy_ment_gold)
|
||||||
|
query_ment_logits.append(qy_ment_logits)
|
||||||
|
query_type_preds.append(qy_type_pred)
|
||||||
|
query_type_golds.append(qy_type_gold)
|
||||||
|
query_type_logits.append(qy_type_logits)
|
||||||
|
self.load_weights(names, weights)
|
||||||
|
|
||||||
|
current_query_num += sent_query_num
|
||||||
|
current_support_num += sent_support_num
|
||||||
|
|
||||||
|
self.zero_grad()
|
||||||
|
|
||||||
|
if mode == "test-twostage":
|
||||||
|
two_stage_query_masks = torch.cat(two_stage_query_masks, dim=0)[:, :two_stage_max_snum]
|
||||||
|
two_stage_query_ments = torch.cat(two_stage_query_ments, dim=0)[:, :two_stage_max_snum, :]
|
||||||
|
return {'loss': np.mean(episode_losses), 'ment_loss': np.mean(episode_ment_losses), 'type_loss': np.mean(episode_type_losses), 'names': names, 'grads': meta_grad,
|
||||||
|
'ment_preds': query_ment_preds, 'ment_golds': query_ment_golds, 'ment_logits': query_ment_logits,
|
||||||
|
'type_preds': query_type_preds, 'type_golds': query_type_golds, 'type_logits': query_type_logits,
|
||||||
|
'pred_spans': two_stage_query_ments, 'pred_masks': two_stage_query_masks
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {'loss': np.mean(episode_losses), 'ment_loss': np.mean(episode_ment_losses), 'type_loss': np.mean(episode_type_losses), 'names': names, 'grads': meta_grad,
|
||||||
|
'ment_preds': query_ment_preds, 'ment_golds': query_ment_golds, 'ment_logits': query_ment_logits,
|
||||||
|
'type_preds': query_type_preds, 'type_golds': query_type_golds, 'type_logits': query_type_logits,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_named_params(self, no_grads=[]):
|
||||||
|
names = []
|
||||||
|
params = []
|
||||||
|
for n, p in self.named_parameters():
|
||||||
|
if any([pn in n for pn in no_grads]):
|
||||||
|
continue
|
||||||
|
if p.requires_grad:
|
||||||
|
names.append(n)
|
||||||
|
params.append(p)
|
||||||
|
return names, params
|
||||||
|
|
||||||
|
def load_weights(self, names, params):
|
||||||
|
model_params = self.state_dict()
|
||||||
|
for n, p in zip(names, params):
|
||||||
|
model_params[n].data.copy_(p.data)
|
||||||
|
return
|
||||||
|
|
||||||
|
def load_gradients(self, names, grads):
|
||||||
|
model_params = self.state_dict(keep_vars=True)
|
||||||
|
for n, g in zip(names, grads):
|
||||||
|
if model_params[n].grad is None:
|
||||||
|
continue
|
||||||
|
model_params[n].grad.data.add_(g.data) # accumulate
|
||||||
|
return
|
|
@ -0,0 +1,41 @@
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
class MaxLoss(nn.Module):
|
||||||
|
def __init__(self, gamma=1.0):
|
||||||
|
super(MaxLoss, self).__init__()
|
||||||
|
self.gamma = gamma
|
||||||
|
self.loss_fct = nn.CrossEntropyLoss(ignore_index=-1, reduction='none')
|
||||||
|
return
|
||||||
|
|
||||||
|
def forward(self, logits, targets):
|
||||||
|
assert logits.dim() == 3 #batch_size, seq_len, label_num
|
||||||
|
batch_size, seq_len, class_num = logits.size()
|
||||||
|
token_loss = self.loss_fct(logits.view(-1, class_num), targets.view(-1)).view(batch_size, seq_len)
|
||||||
|
act_pos = targets.ne(-1)
|
||||||
|
loss = token_loss[act_pos].mean()
|
||||||
|
if self.gamma > 0:
|
||||||
|
max_loss = torch.max(token_loss, dim=1)[0]
|
||||||
|
loss += self.gamma * max_loss.mean()
|
||||||
|
return loss
|
||||||
|
|
||||||
|
class FocalLoss(nn.Module):
|
||||||
|
def __init__(self, gamma=1.0, reduction=False):
|
||||||
|
super(FocalLoss, self).__init__()
|
||||||
|
self.gamma = gamma
|
||||||
|
self.reduction = reduction
|
||||||
|
return
|
||||||
|
|
||||||
|
def forward(self, logits, targets):
|
||||||
|
assert logits.dim() == 2
|
||||||
|
assert torch.min(targets).item() > -1
|
||||||
|
|
||||||
|
logp = F.log_softmax(logits, dim=1)
|
||||||
|
target_logp = logp.gather(1, targets.view(-1, 1)).view(-1)
|
||||||
|
target_p = torch.exp(target_logp)
|
||||||
|
weight = (1 - target_p) ** self.gamma
|
||||||
|
loss = - weight * target_logp
|
||||||
|
if self.reduction:
|
||||||
|
loss = loss.mean()
|
||||||
|
return loss
|
|
@ -0,0 +1,171 @@
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from .span_fewshotmodel import FewShotSpanModel
|
||||||
|
from copy import deepcopy
|
||||||
|
import numpy as np
|
||||||
|
from .crf import LinearCRF
|
||||||
|
from .loss_model import MaxLoss
|
||||||
|
|
||||||
|
class MentSeqtagger(FewShotSpanModel):
|
||||||
|
def __init__(self, span_encoder, num_tag, label2idx, schema, use_crf, max_loss):
|
||||||
|
super(MentSeqtagger, self).__init__(span_encoder)
|
||||||
|
self.num_tag = num_tag
|
||||||
|
self.use_crf = use_crf
|
||||||
|
self.label2idx = label2idx
|
||||||
|
self.idx2label = {idx: label for label, idx in self.label2idx.items()}
|
||||||
|
self.schema = schema
|
||||||
|
self.cls = nn.Linear(span_encoder.word_dim, self.num_tag)
|
||||||
|
if self.use_crf:
|
||||||
|
self.crf_layer = LinearCRF(self.num_tag, schema=schema, add_constraint=True, label2idx=label2idx)
|
||||||
|
self.ment_loss_fct = MaxLoss(gamma=max_loss)
|
||||||
|
self.init_weights()
|
||||||
|
return
|
||||||
|
|
||||||
|
def init_weights(self):
|
||||||
|
self.cls.weight.data.normal_(mean=0.0, std=0.02)
|
||||||
|
if self.cls.bias is not None:
|
||||||
|
self.cls.bias.data.zero_()
|
||||||
|
if self.use_crf:
|
||||||
|
self.crf_layer.init_params()
|
||||||
|
return
|
||||||
|
|
||||||
|
def forward_step(self, batch, crf_mode=True):
|
||||||
|
word_emb = self.word_encoder(batch['word'], batch['word_mask'],
|
||||||
|
batch['word_to_piece_ind'],
|
||||||
|
batch['word_to_piece_end'])
|
||||||
|
logits = self.cls(word_emb)
|
||||||
|
gold = batch['ment_labels']
|
||||||
|
tot_loss = self.ment_loss_fct(logits, gold)
|
||||||
|
if self.use_crf and crf_mode:
|
||||||
|
crf_sp_logits = torch.zeros((logits.size(0), logits.size(1), 3), device=logits.device)
|
||||||
|
crf_sp_logits = torch.cat([logits, crf_sp_logits], dim=2)
|
||||||
|
_, pred = self.crf_layer.decode(crf_sp_logits, batch['seq_len'])
|
||||||
|
else:
|
||||||
|
pred = torch.argmax(logits, dim=-1)
|
||||||
|
pred = pred.masked_fill(gold.eq(-1), -1)
|
||||||
|
return logits, pred, gold, tot_loss
|
||||||
|
|
||||||
|
|
||||||
|
def inner_update(self, train_data, inner_steps, lr_inner):
|
||||||
|
parameters_to_optimize = list(self.named_parameters())
|
||||||
|
decay_params = []
|
||||||
|
nodecay_params = []
|
||||||
|
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
|
||||||
|
for n, p in parameters_to_optimize:
|
||||||
|
if p.requires_grad:
|
||||||
|
if ("bert." in n) and (not any(nd in n for nd in no_decay)):
|
||||||
|
decay_params.append(p)
|
||||||
|
else:
|
||||||
|
nodecay_params.append(p)
|
||||||
|
parameters_groups = [
|
||||||
|
{'params': decay_params,
|
||||||
|
'lr': lr_inner, 'weight_decay': 1e-3},
|
||||||
|
{'params': nodecay_params,
|
||||||
|
'lr': lr_inner, 'weight_decay': 0},
|
||||||
|
]
|
||||||
|
inner_opt = torch.optim.AdamW(parameters_groups, lr=lr_inner)
|
||||||
|
self.train()
|
||||||
|
for _ in range(inner_steps):
|
||||||
|
inner_opt.zero_grad()
|
||||||
|
_, _, _, loss = self.forward_step(train_data, crf_mode=False)
|
||||||
|
loss.backward()
|
||||||
|
inner_opt.step()
|
||||||
|
return
|
||||||
|
|
||||||
|
def forward_sup(self, batch, mode):
|
||||||
|
support, query = batch["support"], batch["query"]
|
||||||
|
query_logits = []
|
||||||
|
query_preds = []
|
||||||
|
query_golds = []
|
||||||
|
all_loss = 0
|
||||||
|
task_num = 0
|
||||||
|
current_support_num = 0
|
||||||
|
current_query_num = 0
|
||||||
|
if mode == "train":
|
||||||
|
crf_mode = False
|
||||||
|
else:
|
||||||
|
assert mode == "test"
|
||||||
|
crf_mode = True
|
||||||
|
self.eval()
|
||||||
|
print('eval mode')
|
||||||
|
data_keys = ['word', 'word_mask', 'word_to_piece_ind', 'word_to_piece_end', 'seq_len', 'ment_labels']
|
||||||
|
for i, sent_support_num in enumerate(support['sentence_num']):
|
||||||
|
sent_query_num = query['sentence_num'][i]
|
||||||
|
one_support = {
|
||||||
|
k: support[k][current_support_num:current_support_num + sent_support_num] for k in data_keys
|
||||||
|
}
|
||||||
|
one_query = {
|
||||||
|
k: query[k][current_query_num:current_query_num + sent_query_num] for k in data_keys
|
||||||
|
}
|
||||||
|
_, _, _, sp_loss = self.forward_step(one_support, False)
|
||||||
|
all_loss += sp_loss
|
||||||
|
qy_logits, qy_pred, qy_gold, qy_loss = self.forward_step(one_query, crf_mode)
|
||||||
|
query_preds.append(qy_pred)
|
||||||
|
query_golds.append(qy_gold)
|
||||||
|
query_logits.append(qy_logits)
|
||||||
|
all_loss += qy_loss
|
||||||
|
task_num += 2
|
||||||
|
current_query_num += sent_query_num
|
||||||
|
current_support_num += sent_support_num
|
||||||
|
return {'loss': all_loss / task_num, 'preds': query_preds, 'golds': query_golds, 'logits': query_logits}
|
||||||
|
|
||||||
|
def forward_meta(self, batch, inner_steps, lr_inner, mode):
|
||||||
|
names, params = self.get_named_params()
|
||||||
|
weights = deepcopy(params)
|
||||||
|
meta_grad = []
|
||||||
|
episode_losses = []
|
||||||
|
query_preds = []
|
||||||
|
query_golds = []
|
||||||
|
query_logits = []
|
||||||
|
support, query = batch["support"], batch["query"]
|
||||||
|
current_support_num = 0
|
||||||
|
current_query_num = 0
|
||||||
|
data_keys = ['word', 'word_mask', 'word_to_piece_ind', 'word_to_piece_end', 'seq_len', 'ment_labels']
|
||||||
|
for i, sent_support_num in enumerate(support['sentence_num']):
|
||||||
|
sent_query_num = query['sentence_num'][i]
|
||||||
|
one_support = {
|
||||||
|
k: support[k][current_support_num:current_support_num + sent_support_num] for k in data_keys
|
||||||
|
}
|
||||||
|
one_query = {
|
||||||
|
k: query[k][current_query_num:current_query_num + sent_query_num] for k in data_keys
|
||||||
|
}
|
||||||
|
self.zero_grad()
|
||||||
|
self.inner_update(one_support, inner_steps, lr_inner)
|
||||||
|
if mode == "train":
|
||||||
|
qy_logits, qy_pred, qy_gold, qy_loss = self.forward_step(one_query, crf_mode=False)
|
||||||
|
grad = torch.autograd.grad(qy_loss, params)
|
||||||
|
meta_grad.append(grad)
|
||||||
|
else:
|
||||||
|
self.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
qy_logits, qy_pred, qy_gold, qy_loss = self.forward_step(one_query)
|
||||||
|
episode_losses.append(qy_loss.item())
|
||||||
|
query_preds.append(qy_pred)
|
||||||
|
query_golds.append(qy_gold)
|
||||||
|
query_logits.append(qy_logits)
|
||||||
|
self.load_weights(names, weights)
|
||||||
|
|
||||||
|
current_query_num += sent_query_num
|
||||||
|
current_support_num += sent_support_num
|
||||||
|
self.zero_grad()
|
||||||
|
return {'loss': np.mean(episode_losses), 'names': names, 'grads': meta_grad, 'preds': query_preds, 'golds': query_golds, 'logits': query_logits}
|
||||||
|
|
||||||
|
def get_named_params(self):
|
||||||
|
names = [n for n, p in self.named_parameters() if p.requires_grad]
|
||||||
|
params = [p for n, p in self.named_parameters() if p.requires_grad]
|
||||||
|
return names, params
|
||||||
|
|
||||||
|
def load_weights(self, names, params):
|
||||||
|
model_params = self.state_dict()
|
||||||
|
for n, p in zip(names, params):
|
||||||
|
model_params[n].data.copy_(p.data)
|
||||||
|
return
|
||||||
|
|
||||||
|
def load_gradients(self, names, grads):
|
||||||
|
model_params = self.state_dict(keep_vars=True)
|
||||||
|
for n, g in zip(names, grads):
|
||||||
|
if model_params[n].grad is None:
|
||||||
|
continue
|
||||||
|
model_params[n].grad.data.add_(g.data) # accumulate
|
||||||
|
return
|
|
@ -0,0 +1,64 @@
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
class FewShotTokenModel(nn.Module):
|
||||||
|
def __init__(self, my_word_encoder):
|
||||||
|
nn.Module.__init__(self)
|
||||||
|
self.word_encoder = nn.DataParallel(my_word_encoder)
|
||||||
|
|
||||||
|
def forward(self, batch):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def span_accuracy(self, pred, gold):
|
||||||
|
return np.mean(pred == gold)
|
||||||
|
|
||||||
|
def metrics_by_pos(self, preds, golds):
|
||||||
|
hit_cnt = 0
|
||||||
|
tot_cnt = 0
|
||||||
|
for i in range(len(preds)):
|
||||||
|
tot_cnt += len(preds[i])
|
||||||
|
for j in range(len(preds[i])):
|
||||||
|
if preds[i][j] == golds[i][j]:
|
||||||
|
hit_cnt += 1
|
||||||
|
return hit_cnt, tot_cnt
|
||||||
|
|
||||||
|
def seq_eval(self, ep_preds, query):
|
||||||
|
query_seq_lens = query["seq_len"].detach().cpu().tolist()
|
||||||
|
subsent_label2tag_ids = []
|
||||||
|
subsent_pred_ids = []
|
||||||
|
for k, batch_preds in enumerate(ep_preds):
|
||||||
|
batch_preds = batch_preds.detach().cpu().tolist()
|
||||||
|
for pred in batch_preds:
|
||||||
|
subsent_pred_ids.append(pred)
|
||||||
|
subsent_label2tag_ids.append(k)
|
||||||
|
|
||||||
|
sent_gold_labels = []
|
||||||
|
sent_pred_labels = []
|
||||||
|
subsent_id = 0
|
||||||
|
query['word_labels'] = query['word_labels'].cpu().tolist()
|
||||||
|
for snum in query['subsentence_num']:
|
||||||
|
whole_sent_gids = []
|
||||||
|
whole_sent_pids = []
|
||||||
|
for k in range(subsent_id, subsent_id + snum):
|
||||||
|
whole_sent_gids += query['word_labels'][k][:query_seq_lens[k]]
|
||||||
|
whole_sent_pids += subsent_pred_ids[k][:query_seq_lens[k]]
|
||||||
|
label2tag = query['label2tag'][subsent_label2tag_ids[subsent_id]]
|
||||||
|
sent_gold_labels.append([label2tag[lid] for lid in whole_sent_gids])
|
||||||
|
sent_pred_labels.append([label2tag[lid] for lid in whole_sent_pids])
|
||||||
|
subsent_id += snum
|
||||||
|
hit_cnt, tot_cnt = self.metrics_by_pos(sent_pred_labels, sent_gold_labels)
|
||||||
|
logs = {
|
||||||
|
"index": query["index"],
|
||||||
|
"seq_len": query_seq_lens,
|
||||||
|
"pred": sent_pred_labels,
|
||||||
|
"gold": sent_gold_labels,
|
||||||
|
"sentence_num": query["sentence_num"],
|
||||||
|
"subsentence_num": query['subsentence_num']
|
||||||
|
}
|
||||||
|
metric_logs = {
|
||||||
|
"gold_cnt": tot_cnt,
|
||||||
|
"hit_cnt": hit_cnt
|
||||||
|
}
|
||||||
|
return metric_logs, logs
|
|
@ -0,0 +1,173 @@
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from .pos_fewshotmodel import FewShotTokenModel
|
||||||
|
from copy import deepcopy
|
||||||
|
import numpy as np
|
||||||
|
from .loss_model import MaxLoss
|
||||||
|
|
||||||
|
class SeqProtoCls(FewShotTokenModel):
|
||||||
|
def __init__(self, span_encoder, max_loss, dot=False, normalize="none", temperature=None, use_focal=False):
|
||||||
|
super(SeqProtoCls, self).__init__(span_encoder)
|
||||||
|
self.dot = dot
|
||||||
|
self.normalize = normalize
|
||||||
|
self.temperature = temperature
|
||||||
|
self.use_focal = use_focal
|
||||||
|
self.loss_fct = MaxLoss(gamma=max_loss)
|
||||||
|
self.proto = None
|
||||||
|
print("use dot : {} use normalizatioin: {} use temperature: {}".format(self.dot, self.normalize,
|
||||||
|
self.temperature if self.temperature else "none"))
|
||||||
|
return
|
||||||
|
|
||||||
|
def __dist__(self, x, y, dim):
|
||||||
|
if self.normalize == 'l2':
|
||||||
|
x = F.normalize(x, p=2, dim=-1)
|
||||||
|
y = F.normalize(y, p=2, dim=-1)
|
||||||
|
if self.dot:
|
||||||
|
sim = (x * y).sum(dim)
|
||||||
|
else:
|
||||||
|
sim = -(torch.pow(x - y, 2)).sum(dim)
|
||||||
|
if self.temperature:
|
||||||
|
sim = sim / self.temperature
|
||||||
|
return sim
|
||||||
|
|
||||||
|
def __batch_dist__(self, S_emb, Q_emb, Q_mask):
|
||||||
|
if Q_mask is None:
|
||||||
|
Q_emb = Q_emb.view(-1, Q_emb.size(-1))
|
||||||
|
else:
|
||||||
|
Q_emb = Q_emb[Q_mask.eq(1), :].view(-1, Q_emb.size(-1))
|
||||||
|
dist = self.__dist__(S_emb.unsqueeze(0), Q_emb.unsqueeze(1), 2)
|
||||||
|
return dist
|
||||||
|
|
||||||
|
def __get_proto__(self, S_emb, S_tag, S_mask, max_tag):
|
||||||
|
proto = []
|
||||||
|
embedding = S_emb[S_mask.eq(1), :].view(-1, S_emb.size(-1))
|
||||||
|
S_tag = S_tag[S_mask.eq(1)]
|
||||||
|
for label in range(max_tag + 1):
|
||||||
|
if S_tag.eq(label).sum().item() == 0:
|
||||||
|
proto.append(torch.zeros(embedding.size(-1), device=embedding.device))
|
||||||
|
else:
|
||||||
|
proto.append(torch.mean(embedding[S_tag.eq(label), :], 0))
|
||||||
|
proto = torch.stack(proto, dim=0)
|
||||||
|
return proto
|
||||||
|
|
||||||
|
def __get_proto_dist__(self, Q_emb, Q_mask):
|
||||||
|
dist = self.__batch_dist__(self.proto, Q_emb, Q_mask)
|
||||||
|
return dist
|
||||||
|
|
||||||
|
def forward_step(self, query):
|
||||||
|
query_word_emb = self.word_encoder(query['word'], query['word_mask'], word_to_piece_inds=query['word_to_piece_ind'],
|
||||||
|
word_to_piece_ends=query['word_to_piece_end'])
|
||||||
|
logits = self.__get_proto_dist__(query_word_emb, None)
|
||||||
|
logits = logits.view(query_word_emb.size(0), query_word_emb.size(1), -1)
|
||||||
|
gold = query['word_labels']
|
||||||
|
tot_loss = self.loss_fct(logits, gold)
|
||||||
|
pred = torch.argmax(logits, dim=-1)
|
||||||
|
pred = pred.masked_fill(query['word_labels'] < 0, -1)
|
||||||
|
return logits, pred, gold, tot_loss
|
||||||
|
|
||||||
|
def init_proto(self, support):
|
||||||
|
self.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
support_word_emb = self.word_encoder(support['word'], support['word_mask'], word_to_piece_inds=support['word_to_piece_ind'],
|
||||||
|
word_to_piece_ends=support['word_to_piece_end'])
|
||||||
|
|
||||||
|
proto = self.__get_proto__(support_word_emb, support['word_labels'], support['word_labels'] > -1, len(support['label2idx']) - 1)
|
||||||
|
self.proto = nn.Parameter(proto.data, requires_grad=True)
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def inner_update(self, support_data, inner_steps, lr_inner):
|
||||||
|
self.init_proto(support_data)
|
||||||
|
parameters_to_optimize = list(self.named_parameters())
|
||||||
|
decay_params = []
|
||||||
|
nodecay_params = []
|
||||||
|
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
|
||||||
|
for n, p in parameters_to_optimize:
|
||||||
|
if p.requires_grad:
|
||||||
|
if ("bert." in n) and (not any(nd in n for nd in no_decay)):
|
||||||
|
decay_params.append(p)
|
||||||
|
else:
|
||||||
|
nodecay_params.append(p)
|
||||||
|
parameters_groups = [
|
||||||
|
{'params': decay_params,
|
||||||
|
'lr': lr_inner, 'weight_decay': 1e-3},
|
||||||
|
{'params': nodecay_params,
|
||||||
|
'lr': lr_inner, 'weight_decay': 0},
|
||||||
|
]
|
||||||
|
inner_opt = torch.optim.AdamW(parameters_groups, lr=lr_inner)
|
||||||
|
self.train()
|
||||||
|
for _ in range(inner_steps):
|
||||||
|
inner_opt.zero_grad()
|
||||||
|
_, _, _, loss = self.forward_step(support_data)
|
||||||
|
loss.backward()
|
||||||
|
inner_opt.step()
|
||||||
|
return
|
||||||
|
|
||||||
|
def forward_meta(self, batch, inner_steps, lr_inner, mode):
|
||||||
|
names, params = self.get_named_params(no_grads=["proto"])
|
||||||
|
weights = deepcopy(params)
|
||||||
|
|
||||||
|
meta_grad = []
|
||||||
|
episode_losses = []
|
||||||
|
query_logits = []
|
||||||
|
query_preds = []
|
||||||
|
query_golds = []
|
||||||
|
current_support_num = 0
|
||||||
|
current_query_num = 0
|
||||||
|
support, query = batch["support"], batch["query"]
|
||||||
|
data_keys = ['word', 'word_mask', 'word_to_piece_ind', 'word_to_piece_end', 'seq_len', 'word_labels']
|
||||||
|
|
||||||
|
for i, sent_support_num in enumerate(support['sentence_num']):
|
||||||
|
sent_query_num = query['sentence_num'][i]
|
||||||
|
label2tag = query['label2tag'][i]
|
||||||
|
one_support = {
|
||||||
|
k: support[k][current_support_num:current_support_num + sent_support_num] for k in data_keys if k in support
|
||||||
|
}
|
||||||
|
one_query = {
|
||||||
|
k: query[k][current_query_num:current_query_num + sent_query_num] for k in data_keys if k in query
|
||||||
|
}
|
||||||
|
one_support['label2idx'] = one_query['label2idx'] = {label:idx for idx, label in label2tag.items()}
|
||||||
|
self.zero_grad()
|
||||||
|
self.inner_update(one_support, inner_steps, lr_inner) # inner update parameters on support data
|
||||||
|
if mode == "train":
|
||||||
|
qy_logits, qy_pred, qy_gold, qy_loss = self.forward_step(one_query) # evaluate on query data
|
||||||
|
grad = torch.autograd.grad(qy_loss, params) # meta-update
|
||||||
|
meta_grad.append(grad)
|
||||||
|
elif mode == "test":
|
||||||
|
self.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
qy_logits, qy_pred, qy_gold, qy_loss = self.forward_step(one_query)
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
episode_losses.append(qy_loss.item())
|
||||||
|
query_preds.append(qy_pred)
|
||||||
|
query_golds.append(qy_gold)
|
||||||
|
query_logits.append(qy_logits)
|
||||||
|
self.load_weights(names, weights)
|
||||||
|
|
||||||
|
current_query_num += sent_query_num
|
||||||
|
current_support_num += sent_support_num
|
||||||
|
self.zero_grad()
|
||||||
|
return {'loss': np.mean(episode_losses), 'names': names, 'grads': meta_grad, 'preds': query_preds, 'golds': query_golds, 'logits': query_logits}
|
||||||
|
|
||||||
|
def get_named_params(self, no_grads=[]):
|
||||||
|
names = [n for n, p in self.named_parameters() if p.requires_grad and (n not in no_grads)]
|
||||||
|
params = [p for n, p in self.named_parameters() if p.requires_grad and (n not in no_grads)]
|
||||||
|
return names, params
|
||||||
|
|
||||||
|
def load_weights(self, names, params):
|
||||||
|
model_params = self.state_dict()
|
||||||
|
for n, p in zip(names, params):
|
||||||
|
assert n in model_params
|
||||||
|
model_params[n].data.copy_(p.data)
|
||||||
|
return
|
||||||
|
|
||||||
|
def load_gradients(self, names, grads):
|
||||||
|
model_params = self.state_dict(keep_vars=True)
|
||||||
|
for n, g in zip(names, grads):
|
||||||
|
if model_params[n].grad is None:
|
||||||
|
continue
|
||||||
|
model_params[n].grad.data.add_(g.data) # accumulate
|
||||||
|
return
|
|
@ -0,0 +1,326 @@
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from util.span_sample import convert_bio2spans
|
||||||
|
|
||||||
|
class FewShotSpanModel(nn.Module):
|
||||||
|
def __init__(self, my_word_encoder):
|
||||||
|
nn.Module.__init__(self)
|
||||||
|
self.word_encoder = nn.DataParallel(my_word_encoder)
|
||||||
|
|
||||||
|
def forward(self, batch):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def span_accuracy(self, pred, gold):
|
||||||
|
return np.mean(pred == gold)
|
||||||
|
|
||||||
|
def merge_ment(self, span_list, seq_lens, subsentence_nums):
|
||||||
|
new_span_list = []
|
||||||
|
subsent_id = 0
|
||||||
|
for snum in subsentence_nums:
|
||||||
|
sent_st_id = 0
|
||||||
|
whole_sent_span_list = []
|
||||||
|
for k in range(subsent_id, subsent_id + snum):
|
||||||
|
tmp = [[x[0] + sent_st_id, x[1] + sent_st_id] for x in span_list[k]]
|
||||||
|
tmp = sorted(tmp, key=lambda x: x[1], reverse=False)
|
||||||
|
if len(whole_sent_span_list) > 0 and len(tmp) > 0:
|
||||||
|
if tmp[0][0] == whole_sent_span_list[-1][1] + 1:
|
||||||
|
whole_sent_span_list[-1][1] = tmp[0][1]
|
||||||
|
tmp = tmp[1:]
|
||||||
|
whole_sent_span_list.extend(tmp)
|
||||||
|
sent_st_id += seq_lens[k]
|
||||||
|
subsent_id += snum
|
||||||
|
new_span_list.append(whole_sent_span_list)
|
||||||
|
assert len(new_span_list) == len(subsentence_nums)
|
||||||
|
return new_span_list
|
||||||
|
|
||||||
|
def get_mention_prob(self, ment_probs, query):
|
||||||
|
span_indices = query["span_indices"].detach().cpu()
|
||||||
|
span_masks = query['span_mask'].detach().cpu()
|
||||||
|
ment_probs = ment_probs.detach().cpu().tolist()
|
||||||
|
cand_indices_list = []
|
||||||
|
cand_prob_list = []
|
||||||
|
cur_span_num = 0
|
||||||
|
for k in range(len(span_indices)):
|
||||||
|
mask = span_masks[k, :]
|
||||||
|
effect_indices = span_indices[k, mask.eq(1)].tolist()
|
||||||
|
effect_probs = ment_probs[cur_span_num: cur_span_num + len(effect_indices)]
|
||||||
|
cand_indices_list.append(effect_indices)
|
||||||
|
cand_prob_list.append(effect_probs)
|
||||||
|
cur_span_num += len(effect_indices)
|
||||||
|
return cand_indices_list, cand_prob_list
|
||||||
|
|
||||||
|
def filter_by_threshold(self, cand_indices_list, cand_prob_list, threshold):
|
||||||
|
final_indices_list = []
|
||||||
|
final_prob_list = []
|
||||||
|
for indices, probs in zip(cand_indices_list, cand_prob_list):
|
||||||
|
final_indices_list.append([])
|
||||||
|
final_prob_list.append([])
|
||||||
|
for x, y in zip(indices, probs):
|
||||||
|
if y > threshold:
|
||||||
|
final_indices_list[-1].append(x)
|
||||||
|
final_prob_list[-1].append(y)
|
||||||
|
return final_indices_list, final_prob_list
|
||||||
|
|
||||||
|
def seqment_eval(self, ep_preds, query, idx2label, schema):
|
||||||
|
query_seq_lens = query["seq_len"].detach().cpu().tolist()
|
||||||
|
ment_indices_list = []
|
||||||
|
sid = 0
|
||||||
|
for batch_preds in ep_preds:
|
||||||
|
for pred in batch_preds.detach().cpu().tolist():
|
||||||
|
seqs = [idx2label[idx] for idx in pred[:query_seq_lens[sid]]]
|
||||||
|
ents = convert_bio2spans(seqs, schema)
|
||||||
|
ment_indices_list.append([[x[1], x[2]] for x in ents])
|
||||||
|
sid += 1
|
||||||
|
pred_ments_list = self.merge_ment(ment_indices_list, query_seq_lens, query['subsentence_num'])
|
||||||
|
gold_ments_list = []
|
||||||
|
subsent_id = 0
|
||||||
|
query['ment_labels'] = query['ment_labels'].cpu().tolist()
|
||||||
|
for snum in query['subsentence_num']:
|
||||||
|
whole_sent_labels = []
|
||||||
|
for k in range(subsent_id, subsent_id + snum):
|
||||||
|
whole_sent_labels += query['ment_labels'][k][:query_seq_lens[k]]
|
||||||
|
ents = convert_bio2spans([idx2label[idx] for idx in whole_sent_labels], schema)
|
||||||
|
gold_ments_list.append([[x[1], x[2]] for x in ents])
|
||||||
|
subsent_id += snum
|
||||||
|
pred_cnt, gold_cnt, hit_cnt = self.metrics_by_entity(pred_ments_list, gold_ments_list)
|
||||||
|
logs = {
|
||||||
|
"index": query["index"],
|
||||||
|
"seq_len": query_seq_lens,
|
||||||
|
"pred": pred_ments_list,
|
||||||
|
"gold": gold_ments_list,
|
||||||
|
"sentence_num": query["sentence_num"],
|
||||||
|
"subsentence_num": query['subsentence_num']
|
||||||
|
}
|
||||||
|
metric_logs = {
|
||||||
|
"ment_pred_cnt": pred_cnt,
|
||||||
|
"ment_gold_cnt": gold_cnt,
|
||||||
|
"ment_hit_cnt": hit_cnt,
|
||||||
|
}
|
||||||
|
return metric_logs, logs
|
||||||
|
|
||||||
|
def ment_eval(self, ep_probs, query, threshold=0.5):
|
||||||
|
if len(ep_probs) == 0:
|
||||||
|
print("no mention")
|
||||||
|
return {}, {}
|
||||||
|
probs = torch.cat(ep_probs, dim=0)
|
||||||
|
cand_indices_list, cand_prob_list = self.get_mention_prob(probs, query)
|
||||||
|
ment_indices_list, ment_prob_list = self.filter_by_threshold(cand_indices_list, cand_prob_list, threshold)
|
||||||
|
gold_ments_list = []
|
||||||
|
for sent_spans in query['spans']:
|
||||||
|
gold_ments_list.append([])
|
||||||
|
for tagid, sp_st, sp_ed in sent_spans:
|
||||||
|
gold_ments_list[-1].append([sp_st, sp_ed])
|
||||||
|
assert len(ment_indices_list) == len(query['seq_len'])
|
||||||
|
assert len(gold_ments_list) == len(query['seq_len'])
|
||||||
|
assert len(ment_indices_list) >= len(query['index'])
|
||||||
|
query_seq_lens = query["seq_len"].detach().cpu().tolist()
|
||||||
|
pred_ments_list = self.merge_ment(ment_indices_list, query_seq_lens, query['subsentence_num'])
|
||||||
|
gold_ments_list = self.merge_ment(gold_ments_list, query_seq_lens, query['subsentence_num'])
|
||||||
|
pred_cnt, gold_cnt, hit_cnt = self.metrics_by_entity(pred_ments_list, gold_ments_list)
|
||||||
|
logs = {
|
||||||
|
"index": query["index"],
|
||||||
|
"seq_len": query_seq_lens,
|
||||||
|
"pred": pred_ments_list,
|
||||||
|
"gold": gold_ments_list,
|
||||||
|
"before_ind": ment_indices_list,
|
||||||
|
"before_prob": ment_prob_list,
|
||||||
|
"sentence_num": query["sentence_num"],
|
||||||
|
"subsentence_num": query['subsentence_num']
|
||||||
|
}
|
||||||
|
metric_logs = {
|
||||||
|
"ment_pred_cnt": pred_cnt,
|
||||||
|
"ment_gold_cnt": gold_cnt,
|
||||||
|
"ment_hit_cnt": hit_cnt,
|
||||||
|
}
|
||||||
|
return metric_logs, logs
|
||||||
|
|
||||||
|
def metrics_by_ment(self, pred_spans_list, gold_spans_list):
|
||||||
|
pred_cnt, gold_cnt, hit_cnt = 0, 0, 0
|
||||||
|
for pred_spans, gold_spans in zip(pred_spans_list, gold_spans_list):
|
||||||
|
pred_spans = set(map(lambda x: (x[1], x[2]), pred_spans))
|
||||||
|
gold_spans = set(map(lambda x: (x[1], x[2]), gold_spans))
|
||||||
|
pred_cnt += len(pred_spans)
|
||||||
|
gold_cnt += len(gold_spans)
|
||||||
|
hit_cnt += len(pred_spans.intersection(gold_spans))
|
||||||
|
return pred_cnt, gold_cnt, hit_cnt
|
||||||
|
|
||||||
|
def get_emission(self, ep_logits, query):
|
||||||
|
assert len(query['label2tag']) == len(query['sentence_num'])
|
||||||
|
span_indices = query["span_indices"]
|
||||||
|
cur_sent_num = 0
|
||||||
|
cand_indices_list = []
|
||||||
|
cand_prob_list = []
|
||||||
|
label2tag_list = []
|
||||||
|
for i, query_sent_num in enumerate(query['sentence_num']):
|
||||||
|
probs = F.softmax(ep_logits[i], dim=-1).detach().cpu()
|
||||||
|
cur_span_num = 0
|
||||||
|
for j in range(query_sent_num):
|
||||||
|
mask = query['span_mask'][cur_sent_num, :]
|
||||||
|
effect_indices = span_indices[cur_sent_num, mask.eq(1)].detach().cpu().tolist()
|
||||||
|
cand_indices_list.append(effect_indices)
|
||||||
|
cand_prob_list.append(probs[cur_span_num:cur_span_num + len(effect_indices)].numpy())
|
||||||
|
label2tag_list.append(query['label2tag'][i])
|
||||||
|
cur_sent_num += 1
|
||||||
|
cur_span_num += len(effect_indices)
|
||||||
|
return cand_indices_list, cand_prob_list, label2tag_list
|
||||||
|
|
||||||
|
def to_triple_score(self, indices_list, prob_list, label2tag):
|
||||||
|
tri_list = []
|
||||||
|
for sp_id, (i, j) in enumerate(indices_list):
|
||||||
|
k = np.argmax(prob_list[sp_id, :])
|
||||||
|
if label2tag[k] == 'O':
|
||||||
|
assert k == 0
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
tri_list.append([label2tag[k], i, j, prob_list[sp_id, k]])
|
||||||
|
return tri_list
|
||||||
|
|
||||||
|
def greedy_search(self, span_indices_list, prob_list, label2tag_list, seq_len_list, overlap=False, threshold=-1):
|
||||||
|
output_spans_list = []
|
||||||
|
for sid in range(len(span_indices_list)):
|
||||||
|
output_spans_list.append([])
|
||||||
|
tri_list = self.to_triple_score(span_indices_list[sid], prob_list[sid], label2tag_list[sid])
|
||||||
|
sorted_tri_list = sorted(tri_list, key=lambda x: x[-1], reverse=True)
|
||||||
|
used_words = np.zeros(seq_len_list[sid])
|
||||||
|
for tag, sp_st, sp_ed, score in sorted_tri_list:
|
||||||
|
if score < threshold:
|
||||||
|
continue
|
||||||
|
if sum(used_words[sp_st: sp_ed + 1]) > 0 and (not overlap):
|
||||||
|
continue
|
||||||
|
used_words[sp_st: sp_ed + 1] = 1
|
||||||
|
assert tag != "O"
|
||||||
|
output_spans_list[sid].append([tag, sp_st, sp_ed])
|
||||||
|
return output_spans_list
|
||||||
|
|
||||||
|
def merge_entity(self, span_list, seq_lens, subsentence_nums):
|
||||||
|
new_span_list = []
|
||||||
|
subsent_id = 0
|
||||||
|
for snum in subsentence_nums:
|
||||||
|
sent_st_id = 0
|
||||||
|
whole_sent_span_list = []
|
||||||
|
for k in range(subsent_id, subsent_id + snum):
|
||||||
|
tmp = [[x[0], x[1] + sent_st_id, x[2] + sent_st_id] for x in span_list[k]]
|
||||||
|
tmp = sorted(tmp, key=lambda x: x[2], reverse=False)
|
||||||
|
if len(whole_sent_span_list) > 0 and len(tmp) > 0:
|
||||||
|
if whole_sent_span_list[-1][0] == tmp[0][0] and tmp[0][1] == whole_sent_span_list[-1][2] + 1:
|
||||||
|
whole_sent_span_list[-1][2] = tmp[0][2]
|
||||||
|
tmp = tmp[1:]
|
||||||
|
whole_sent_span_list.extend(tmp)
|
||||||
|
sent_st_id += seq_lens[k]
|
||||||
|
subsent_id += snum
|
||||||
|
new_span_list.append(whole_sent_span_list)
|
||||||
|
assert len(new_span_list) == len(subsentence_nums)
|
||||||
|
return new_span_list
|
||||||
|
|
||||||
|
def seq_eval(self, ep_preds, query, schema):
|
||||||
|
query_seq_lens = query["seq_len"].detach().cpu().tolist()
|
||||||
|
subsent_label2tag_ids = []
|
||||||
|
subsent_pred_ids = []
|
||||||
|
for k, batch_preds in enumerate(ep_preds):
|
||||||
|
batch_preds = batch_preds.detach().cpu().tolist()
|
||||||
|
for pred in batch_preds:
|
||||||
|
subsent_pred_ids.append(pred)
|
||||||
|
subsent_label2tag_ids.append(k)
|
||||||
|
|
||||||
|
sent_gold_labels = []
|
||||||
|
sent_pred_labels = []
|
||||||
|
pred_spans_list = []
|
||||||
|
gold_spans_list = []
|
||||||
|
subsent_id = 0
|
||||||
|
query['word_labels'] = query['word_labels'].cpu().tolist()
|
||||||
|
for snum in query['subsentence_num']:
|
||||||
|
whole_sent_gids = []
|
||||||
|
whole_sent_pids = []
|
||||||
|
for k in range(subsent_id, subsent_id + snum):
|
||||||
|
whole_sent_gids += query['word_labels'][k][:query_seq_lens[k]]
|
||||||
|
whole_sent_pids += subsent_pred_ids[k][:query_seq_lens[k]]
|
||||||
|
label2tag = query['label2tag'][subsent_label2tag_ids[subsent_id]]
|
||||||
|
sent_gold_labels.append([label2tag[lid] for lid in whole_sent_gids])
|
||||||
|
sent_pred_labels.append([label2tag[lid] for lid in whole_sent_pids])
|
||||||
|
gold_spans_list.append(convert_bio2spans(sent_gold_labels[-1], schema))
|
||||||
|
pred_spans_list.append(convert_bio2spans(sent_pred_labels[-1], schema))
|
||||||
|
subsent_id += snum
|
||||||
|
pred_cnt, gold_cnt, hit_cnt = self.metrics_by_entity(pred_spans_list, gold_spans_list)
|
||||||
|
ment_pred_cnt, ment_gold_cnt, ment_hit_cnt = self.metrics_by_ment(pred_spans_list, gold_spans_list)
|
||||||
|
logs = {
|
||||||
|
"index": query["index"],
|
||||||
|
"seq_len": query_seq_lens,
|
||||||
|
"pred": pred_spans_list,
|
||||||
|
"gold": gold_spans_list,
|
||||||
|
"sentence_num": query["sentence_num"],
|
||||||
|
"subsentence_num": query['subsentence_num']
|
||||||
|
}
|
||||||
|
metric_logs = {
|
||||||
|
"ment_pred_cnt": ment_pred_cnt,
|
||||||
|
"ment_gold_cnt": ment_gold_cnt,
|
||||||
|
"ment_hit_cnt": ment_hit_cnt,
|
||||||
|
"ent_pred_cnt": pred_cnt,
|
||||||
|
"ent_gold_cnt": gold_cnt,
|
||||||
|
"ent_hit_cnt": hit_cnt
|
||||||
|
}
|
||||||
|
return metric_logs, logs
|
||||||
|
|
||||||
|
def greedy_eval(self, ep_logits, query, overlap=False, threshold=-1):
|
||||||
|
if len(ep_logits) == 0:
|
||||||
|
print("no entity")
|
||||||
|
return {}, {}
|
||||||
|
query_seq_lens = query["seq_len"].detach().cpu().tolist()
|
||||||
|
cand_indices_list, cand_prob_list, label2tag_list = self.get_emission(ep_logits, query)
|
||||||
|
pred_spans_list = self.greedy_search(cand_indices_list, cand_prob_list, label2tag_list, query_seq_lens, overlap, threshold)
|
||||||
|
|
||||||
|
gold_spans_list = []
|
||||||
|
subsent_idx = 0
|
||||||
|
for sent_spans, label2tag in zip(query['spans'], label2tag_list):
|
||||||
|
gold_spans_list.append([])
|
||||||
|
for tagid, sp_st, sp_ed in sent_spans:
|
||||||
|
gold_spans_list[-1].append([label2tag[tagid], sp_st, sp_ed])
|
||||||
|
subsent_idx += 1
|
||||||
|
assert len(pred_spans_list) == len(query['seq_len'])
|
||||||
|
assert len(gold_spans_list) == len(query['seq_len'])
|
||||||
|
assert len(pred_spans_list) >= len(query['index'])
|
||||||
|
pred_spans_list = self.merge_entity(pred_spans_list, query_seq_lens, query['subsentence_num'])
|
||||||
|
gold_spans_list = self.merge_entity(gold_spans_list, query_seq_lens, query['subsentence_num'])
|
||||||
|
pred_cnt, gold_cnt, hit_cnt = self.metrics_by_entity(pred_spans_list, gold_spans_list)
|
||||||
|
ment_pred_cnt, ment_gold_cnt, ment_hit_cnt = self.metrics_by_ment(pred_spans_list, gold_spans_list)
|
||||||
|
logs = {
|
||||||
|
"index": query["index"],
|
||||||
|
"seq_len": query_seq_lens,
|
||||||
|
"pred": pred_spans_list,
|
||||||
|
"gold": gold_spans_list,
|
||||||
|
"before_ind": cand_indices_list,
|
||||||
|
"before_prob": cand_prob_list,
|
||||||
|
"label_tag": label2tag_list,
|
||||||
|
"sentence_num": query["sentence_num"],
|
||||||
|
"subsentence_num": query['subsentence_num']
|
||||||
|
}
|
||||||
|
metric_logs = {
|
||||||
|
"ment_pred_cnt": ment_pred_cnt,
|
||||||
|
"ment_gold_cnt": ment_gold_cnt,
|
||||||
|
"ment_hit_cnt": ment_hit_cnt,
|
||||||
|
"ent_pred_cnt": pred_cnt,
|
||||||
|
"ent_gold_cnt": gold_cnt,
|
||||||
|
"ent_hit_cnt": hit_cnt
|
||||||
|
}
|
||||||
|
return metric_logs, logs
|
||||||
|
|
||||||
|
def metrics_by_entity(self, pred_spans_list, gold_spans_list):
|
||||||
|
pred_cnt, gold_cnt, hit_cnt = 0, 0, 0
|
||||||
|
for pred_spans, gold_spans in zip(pred_spans_list, gold_spans_list):
|
||||||
|
pred_spans = set(map(lambda x: tuple(x), pred_spans))
|
||||||
|
gold_spans = set(map(lambda x: tuple(x), gold_spans))
|
||||||
|
pred_cnt += len(pred_spans)
|
||||||
|
gold_cnt += len(gold_spans)
|
||||||
|
hit_cnt += len(pred_spans.intersection(gold_spans))
|
||||||
|
return pred_cnt, gold_cnt, hit_cnt
|
||||||
|
|
||||||
|
def metrics_by_ment(self, pred_spans_list, gold_spans_list):
|
||||||
|
pred_cnt, gold_cnt, hit_cnt = 0, 0, 0
|
||||||
|
for pred_spans, gold_spans in zip(pred_spans_list, gold_spans_list):
|
||||||
|
pred_spans = set(map(lambda x: (x[1], x[2]), pred_spans))
|
||||||
|
gold_spans = set(map(lambda x: (x[1], x[2]), gold_spans))
|
||||||
|
pred_cnt += len(pred_spans)
|
||||||
|
gold_cnt += len(gold_spans)
|
||||||
|
hit_cnt += len(pred_spans.intersection(gold_spans))
|
||||||
|
return pred_cnt, gold_cnt, hit_cnt
|
|
@ -0,0 +1,204 @@
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from .span_fewshotmodel import FewShotSpanModel
|
||||||
|
from copy import deepcopy
|
||||||
|
from .loss_model import FocalLoss
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
class SpanProtoCls(FewShotSpanModel):
|
||||||
|
def __init__(self, span_encoder, use_oproto, dot=False, normalize="none",
|
||||||
|
temperature=None, use_focal=False):
|
||||||
|
super(SpanProtoCls, self).__init__(span_encoder)
|
||||||
|
self.dot = dot
|
||||||
|
self.normalize = normalize
|
||||||
|
self.temperature = temperature
|
||||||
|
self.use_focal = use_focal
|
||||||
|
self.use_oproto = use_oproto
|
||||||
|
self.proto = None
|
||||||
|
if use_focal:
|
||||||
|
self.base_loss_fct = FocalLoss(gamma=1.0)
|
||||||
|
print("use focal loss")
|
||||||
|
else:
|
||||||
|
self.base_loss_fct = nn.CrossEntropyLoss(reduction='none', ignore_index=-1)
|
||||||
|
print("use cross entropy loss")
|
||||||
|
print("use dot : {} use normalizatioin: {} use temperature: {}".format(self.dot, self.normalize,
|
||||||
|
self.temperature if self.temperature else "none"))
|
||||||
|
self.cached_o_proto = torch.zeros(span_encoder.span_dim, requires_grad=False)
|
||||||
|
return
|
||||||
|
|
||||||
|
def loss_fct(self, logits, targets, inst_weights=None):
|
||||||
|
if inst_weights is None:
|
||||||
|
loss = self.base_loss_fct(logits, targets)
|
||||||
|
loss = loss.mean()
|
||||||
|
else:
|
||||||
|
targets = torch.clamp(targets, min=0)
|
||||||
|
one_hot_targets = torch.zeros(logits.size(), device=logits.device).scatter_(1, targets.unsqueeze(1), 1)
|
||||||
|
soft_labels = inst_weights.unsqueeze(1) * one_hot_targets + (1 - one_hot_targets) * (1 - inst_weights).unsqueeze(1) / (logits.size(1) - 1)
|
||||||
|
logp = F.log_softmax(logits, dim=-1)
|
||||||
|
loss = - (logp * soft_labels).sum(1)
|
||||||
|
loss = loss.mean()
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def __dist__(self, x, y, dim):
|
||||||
|
if self.normalize == 'l2':
|
||||||
|
x = F.normalize(x, p=2, dim=-1)
|
||||||
|
y = F.normalize(y, p=2, dim=-1)
|
||||||
|
if self.dot:
|
||||||
|
sim = (x * y).sum(dim)
|
||||||
|
else:
|
||||||
|
sim = -(torch.pow(x - y, 2)).sum(dim)
|
||||||
|
if self.temperature:
|
||||||
|
sim = sim / self.temperature
|
||||||
|
return sim
|
||||||
|
|
||||||
|
def __batch_dist__(self, S_emb, Q_emb, Q_mask):
|
||||||
|
Q_emb = Q_emb[Q_mask.eq(1), :].view(-1, Q_emb.size(-1))
|
||||||
|
dist = self.__dist__(S_emb.unsqueeze(0), Q_emb.unsqueeze(1), 2)
|
||||||
|
return dist
|
||||||
|
|
||||||
|
def __get_proto__(self, S_emb, S_tag, S_mask):
|
||||||
|
proto = []
|
||||||
|
embedding = S_emb[S_mask.eq(1), :].view(-1, S_emb.size(-1))
|
||||||
|
S_tag = S_tag[S_mask.eq(1)]
|
||||||
|
if self.use_oproto:
|
||||||
|
st_idx = 0
|
||||||
|
else:
|
||||||
|
st_idx = 1
|
||||||
|
proto = [self.cached_o_proto]
|
||||||
|
for label in range(st_idx, torch.max(S_tag) + 1):
|
||||||
|
proto.append(torch.mean(embedding[S_tag.eq(label), :], 0))
|
||||||
|
proto = torch.stack(proto, dim=0)
|
||||||
|
return proto
|
||||||
|
|
||||||
|
def __get_proto_dist__(self, Q_emb, Q_mask):
|
||||||
|
dist = self.__batch_dist__(self.proto, Q_emb, Q_mask)
|
||||||
|
if not self.use_oproto:
|
||||||
|
dist[:, 0] = -1000000
|
||||||
|
return dist
|
||||||
|
|
||||||
|
def forward_step(self, query):
|
||||||
|
if query['span_mask'].sum().item() == 0: # there is no query mentions
|
||||||
|
print("no query mentions")
|
||||||
|
empty_tensor = torch.tensor([], device=query['word'].device)
|
||||||
|
zero_tensor = torch.tensor([0], device=query['word'].device)
|
||||||
|
return empty_tensor, empty_tensor, empty_tensor, zero_tensor
|
||||||
|
query_span_emb = self.word_encoder(query['word'], query['word_mask'], word_to_piece_inds=query['word_to_piece_ind'],
|
||||||
|
word_to_piece_ends=query['word_to_piece_end'], span_indices=query['span_indices'])
|
||||||
|
|
||||||
|
logits = self.__get_proto_dist__(query_span_emb, query['span_mask'])
|
||||||
|
golds = query["span_tag"][query["span_mask"].eq(1)].view(-1)
|
||||||
|
query_span_weights = query["span_weights"][query["span_mask"].eq(1)].view(-1)
|
||||||
|
if self.use_oproto:
|
||||||
|
loss = self.loss_fct(logits, golds, inst_weights=query_span_weights)
|
||||||
|
else:
|
||||||
|
loss = self.loss_fct(logits[:, 1:], golds - 1, inst_weights=query_span_weights)
|
||||||
|
_, preds = torch.max(logits, dim=-1)
|
||||||
|
return logits, preds, golds, loss
|
||||||
|
|
||||||
|
def init_proto(self, support_data):
|
||||||
|
self.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
support_span_emb = self.word_encoder(support_data['word'], support_data['word_mask'], word_to_piece_inds=support_data['word_to_piece_ind'],
|
||||||
|
word_to_piece_ends=support_data['word_to_piece_end'], span_indices=support_data['span_indices'])
|
||||||
|
self.cached_o_proto = self.cached_o_proto.to(support_span_emb.device)
|
||||||
|
proto = self.__get_proto__(support_span_emb, support_data['span_tag'], support_data['span_mask'])
|
||||||
|
self.proto = nn.Parameter(proto.data, requires_grad=True)
|
||||||
|
return
|
||||||
|
|
||||||
|
def inner_update(self, support_data, inner_steps, lr_inner):
|
||||||
|
self.init_proto(support_data)
|
||||||
|
parameters_to_optimize = list(self.named_parameters())
|
||||||
|
decay_params = []
|
||||||
|
nodecay_params = []
|
||||||
|
decay_names = []
|
||||||
|
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
|
||||||
|
for n, p in parameters_to_optimize:
|
||||||
|
if p.requires_grad:
|
||||||
|
if ("bert." in n) and (not any(nd in n for nd in no_decay)):
|
||||||
|
decay_params.append(p)
|
||||||
|
decay_names.append(n)
|
||||||
|
else:
|
||||||
|
nodecay_params.append(p)
|
||||||
|
parameters_groups = [
|
||||||
|
{'params': decay_params,
|
||||||
|
'lr': lr_inner, 'weight_decay': 1e-3},
|
||||||
|
{'params': nodecay_params,
|
||||||
|
'lr': lr_inner, 'weight_decay': 0},
|
||||||
|
]
|
||||||
|
inner_opt = torch.optim.AdamW(parameters_groups, lr=lr_inner)
|
||||||
|
self.train()
|
||||||
|
for _ in range(inner_steps):
|
||||||
|
inner_opt.zero_grad()
|
||||||
|
_, _, _, loss = self.forward_step(support_data)
|
||||||
|
if loss.requires_grad:
|
||||||
|
loss.backward()
|
||||||
|
inner_opt.step()
|
||||||
|
return
|
||||||
|
|
||||||
|
def forward_meta(self, batch, inner_steps, lr_inner, mode):
|
||||||
|
names, params = self.get_named_params(no_grads=["proto"])
|
||||||
|
weights = deepcopy(params)
|
||||||
|
|
||||||
|
meta_grad = []
|
||||||
|
episode_losses = []
|
||||||
|
query_logits = []
|
||||||
|
query_preds = []
|
||||||
|
query_golds = []
|
||||||
|
current_support_num = 0
|
||||||
|
current_query_num = 0
|
||||||
|
support, query = batch["support"], batch["query"]
|
||||||
|
data_keys = ['word', 'word_mask', 'word_to_piece_ind', 'word_to_piece_end', 'span_indices', 'span_mask', 'span_tag', 'span_weights']
|
||||||
|
|
||||||
|
for i, sent_support_num in enumerate(support['sentence_num']):
|
||||||
|
sent_query_num = query['sentence_num'][i]
|
||||||
|
one_support = {
|
||||||
|
k: support[k][current_support_num:current_support_num + sent_support_num] for k in data_keys if k in support
|
||||||
|
}
|
||||||
|
one_query = {
|
||||||
|
k: query[k][current_query_num:current_query_num + sent_query_num] for k in data_keys if k in query
|
||||||
|
}
|
||||||
|
self.zero_grad()
|
||||||
|
self.inner_update(one_support, inner_steps, lr_inner)
|
||||||
|
if mode == "train":
|
||||||
|
qy_logits, qy_pred, qy_gold, qy_loss = self.forward_step(one_query)
|
||||||
|
if one_query['span_mask'].sum().item() == 0:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
grad = torch.autograd.grad(qy_loss, params)
|
||||||
|
meta_grad.append(grad)
|
||||||
|
else:
|
||||||
|
self.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
qy_logits, qy_pred, qy_gold, qy_loss = self.forward_step(one_query)
|
||||||
|
|
||||||
|
episode_losses.append(qy_loss.item())
|
||||||
|
query_preds.append(qy_pred)
|
||||||
|
query_golds.append(qy_gold)
|
||||||
|
query_logits.append(qy_logits)
|
||||||
|
self.load_weights(names, weights)
|
||||||
|
|
||||||
|
current_query_num += sent_query_num
|
||||||
|
current_support_num += sent_support_num
|
||||||
|
self.zero_grad()
|
||||||
|
return {'loss': np.mean(episode_losses), 'names': names, 'grads': meta_grad, 'preds': query_preds, 'golds': query_golds, 'logits': query_logits}
|
||||||
|
|
||||||
|
def get_named_params(self, no_grads=[]):
|
||||||
|
names = [n for n, p in self.named_parameters() if p.requires_grad and (n not in no_grads)]
|
||||||
|
params = [p for n, p in self.named_parameters() if p.requires_grad and (n not in no_grads)]
|
||||||
|
return names, params
|
||||||
|
|
||||||
|
def load_weights(self, names, params):
|
||||||
|
model_params = self.state_dict()
|
||||||
|
for n, p in zip(names, params):
|
||||||
|
assert n in model_params
|
||||||
|
model_params[n].data.copy_(p.data)
|
||||||
|
return
|
||||||
|
|
||||||
|
def load_gradients(self, names, grads):
|
||||||
|
model_params = self.state_dict(keep_vars=True)
|
||||||
|
for n, g in zip(names, grads):
|
||||||
|
if model_params[n].grad is None:
|
||||||
|
continue
|
||||||
|
model_params[n].grad.data.add_(g.data) # accumulate
|
|
@ -0,0 +1,135 @@
|
||||||
|
seed=${1:-12}
|
||||||
|
d=${2:-1}
|
||||||
|
K=${3:-1}
|
||||||
|
dataset=${4:-"ner"}
|
||||||
|
mft_steps=${5:-30}
|
||||||
|
tft_steps=${6:-30}
|
||||||
|
schema=${7:-"BIO"}
|
||||||
|
mloss=${8:-2}
|
||||||
|
|
||||||
|
aids=8-9-10-11
|
||||||
|
batch_size=8
|
||||||
|
|
||||||
|
base_dir=data
|
||||||
|
|
||||||
|
if [ $dataset == 'snips' ] || [ $dataset == "ner" ]
|
||||||
|
then
|
||||||
|
root=${base_dir}/${dataset}/${K}shot
|
||||||
|
train_fn=train_domain${d}.txt
|
||||||
|
val_fn=valid_domain${d}.txt
|
||||||
|
test_fn=test_domain${d}.txt
|
||||||
|
ep_dir=$root
|
||||||
|
ep_train_fn=train_domain${d}_id.jsonl
|
||||||
|
ep_val_fn=valid_domain${d}_id.jsonl
|
||||||
|
ep_test_fn=test_domain${d}_id.jsonl
|
||||||
|
bio=True
|
||||||
|
train_iter=800
|
||||||
|
dev_iter=-1
|
||||||
|
val_steps=100
|
||||||
|
elif [ $dataset == 'inter' ] || [ $dataset == "intra" ]
|
||||||
|
then
|
||||||
|
root=${base_dir}/few-nerd/${dataset}
|
||||||
|
ep_dir=${base_dir}/few-nerd/episode-${dataset}
|
||||||
|
train_fn=train.txt
|
||||||
|
val_fn=dev.txt
|
||||||
|
test_fn=test.txt
|
||||||
|
ep_train_fn=train_${d}_${K}_id.jsonl
|
||||||
|
ep_val_fn=dev_${d}_${K}_id.jsonl
|
||||||
|
ep_test_fn=test_${d}_${K}_id.jsonl
|
||||||
|
bio=False
|
||||||
|
train_iter=1000
|
||||||
|
dev_iter=500
|
||||||
|
val_steps=500
|
||||||
|
else
|
||||||
|
root=${base_dir}/fewevent
|
||||||
|
train_fn=train.txt
|
||||||
|
val_fn=dev.txt
|
||||||
|
test_fn=test.txt
|
||||||
|
ep_dir=${root}
|
||||||
|
ep_train_fn=train_${d}_${K}_id.jsonl
|
||||||
|
ep_val_fn=dev_${d}_${K}_id.jsonl
|
||||||
|
ep_test_fn=test_${d}_${K}_id.jsonl
|
||||||
|
bio=True
|
||||||
|
train_iter=1000
|
||||||
|
dev_iter=-1
|
||||||
|
val_steps=500
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $schema == 'IO' ]
|
||||||
|
then
|
||||||
|
crf=False
|
||||||
|
else
|
||||||
|
crf=True
|
||||||
|
fi
|
||||||
|
|
||||||
|
tft_steps=${mft_steps}
|
||||||
|
|
||||||
|
name=joint_m${mft_steps}_max${mloss}_${schema}_t${tft_steps}_bz${batch_size}
|
||||||
|
inner_steps=1
|
||||||
|
|
||||||
|
ckpt_dir=outputs/${dataset}/${d}-${K}shot/joint/${name}/seed${seed}
|
||||||
|
|
||||||
|
|
||||||
|
output_dir=${ckpt_dir}/
|
||||||
|
|
||||||
|
log_name=${dataset}-${d}-${K}-${name}-${seed}-test.log
|
||||||
|
|
||||||
|
|
||||||
|
python train_joint.py --mode test-twostage \
|
||||||
|
--seed ${seed} \
|
||||||
|
--root ${root} \
|
||||||
|
--train ${train_fn} \
|
||||||
|
--val ${val_fn} \
|
||||||
|
--test ${test_fn} \
|
||||||
|
--ep_dir ${ep_dir} \
|
||||||
|
--ep_train ${ep_train_fn} \
|
||||||
|
--ep_val ${ep_val_fn} \
|
||||||
|
--ep_test ${ep_test_fn} \
|
||||||
|
--output_dir ${output_dir} \
|
||||||
|
--N ${d} \
|
||||||
|
--K ${K} \
|
||||||
|
--Q 1 \
|
||||||
|
--bio ${bio} \
|
||||||
|
--max_loss ${mloss} \
|
||||||
|
--use_crf ${crf} \
|
||||||
|
--schema ${schema} \
|
||||||
|
--adapter_layer_ids ${aids} \
|
||||||
|
--encoder_name_or_path bert-base-uncased \
|
||||||
|
--last_n_layer -1 \
|
||||||
|
--max_length 128 \
|
||||||
|
--word_encode_choice first \
|
||||||
|
--span_encode_choice avg \
|
||||||
|
--learning_rate 0.001 \
|
||||||
|
--bert_learning_rate 5e-5 \
|
||||||
|
--bert_weight_decay 0.01 \
|
||||||
|
--train_iter ${train_iter} \
|
||||||
|
--dev_iter ${dev_iter} \
|
||||||
|
--val_steps ${val_steps} \
|
||||||
|
--warmup_step 0 \
|
||||||
|
--log_steps 50 \
|
||||||
|
--type_lam 1 \
|
||||||
|
--use_adapter True \
|
||||||
|
--eval_batch_size 1 \
|
||||||
|
--use_width False \
|
||||||
|
--use_case False \
|
||||||
|
--dropout 0.5 \
|
||||||
|
--max_grad_norm 5 \
|
||||||
|
--dot False \
|
||||||
|
--normalize l2 \
|
||||||
|
--temperature 0.1 \
|
||||||
|
--use_focal False \
|
||||||
|
--use_att False \
|
||||||
|
--att_hidden_dim 100 \
|
||||||
|
--adapter_size 64 \
|
||||||
|
--use_oproto False \
|
||||||
|
--hou_eval_ep ${hou_eval_ep} \
|
||||||
|
--overlap False \
|
||||||
|
--type_threshold 0 \
|
||||||
|
--use_maml True \
|
||||||
|
--train_inner_lr 2e-5 \
|
||||||
|
--train_inner_steps ${inner_steps} \
|
||||||
|
--warmup_prop_inner 0 \
|
||||||
|
--eval_inner_lr 2e-5 \
|
||||||
|
--eval_ment_inner_steps ${mft_steps} \
|
||||||
|
--eval_type_inner_steps ${tft_steps} \
|
||||||
|
--load_ckpt ${ckpt_dir}/model.pth.tar
|
|
@ -0,0 +1,118 @@
|
||||||
|
seed=${1:-12}
|
||||||
|
d=${2:-1}
|
||||||
|
K=${3:-1}
|
||||||
|
dataset=${4:-"ner"}
|
||||||
|
ft_steps=${5:-20}
|
||||||
|
ment_name=${6:-"sqlment_maml30_max2_BIO_bz16"}
|
||||||
|
|
||||||
|
base_dir=data
|
||||||
|
batch_size=16
|
||||||
|
|
||||||
|
|
||||||
|
if [ $dataset == 'snips' ] || [ $dataset == "ner" ]
|
||||||
|
then
|
||||||
|
root=${base_dir}/${dataset}/${K}shot
|
||||||
|
train_fn=train_domain${d}.txt
|
||||||
|
val_fn=valid_domain${d}.txt
|
||||||
|
test_fn=test_domain${d}.txt
|
||||||
|
ep_dir=$root
|
||||||
|
ep_train_fn=train_domain${d}_id.jsonl
|
||||||
|
ep_val_fn=valid_domain${d}_id.jsonl
|
||||||
|
ep_test_fn=test_domain${d}_id.jsonl
|
||||||
|
bio=True
|
||||||
|
train_iter=800
|
||||||
|
dev_iter=-1
|
||||||
|
val_steps=100
|
||||||
|
elif [ $dataset == 'inter' ] || [ $dataset == "intra" ]
|
||||||
|
then
|
||||||
|
root=${base_dir}/few-nerd/${dataset}
|
||||||
|
ep_dir=${base_dir}/few-nerd/episode-${dataset}
|
||||||
|
train_fn=train.txt
|
||||||
|
val_fn=dev.txt
|
||||||
|
test_fn=test.txt
|
||||||
|
ep_train_fn=train_${d}_${K}_id.jsonl
|
||||||
|
ep_val_fn=dev_${d}_${K}_id.jsonl
|
||||||
|
ep_test_fn=test_${d}_${K}_id.jsonl
|
||||||
|
bio=False
|
||||||
|
train_iter=1000
|
||||||
|
dev_iter=500
|
||||||
|
val_steps=500
|
||||||
|
else
|
||||||
|
root=${base_dir}/fewevent
|
||||||
|
train_fn=train.txt
|
||||||
|
val_fn=dev.txt
|
||||||
|
test_fn=test.txt
|
||||||
|
ep_dir=${root}
|
||||||
|
ep_train_fn=train_${d}_${K}_id.jsonl
|
||||||
|
ep_val_fn=dev_${d}_${K}_id.jsonl
|
||||||
|
ep_test_fn=test_${d}_${K}_id.jsonl
|
||||||
|
bio=True
|
||||||
|
train_iter=1000
|
||||||
|
dev_iter=-1
|
||||||
|
val_steps=500
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
name=type_mamlcls${ft_steps}_bz${batch_size}
|
||||||
|
inner_steps=1
|
||||||
|
|
||||||
|
output_dir=outputs/${dataset}/${d}-${K}shot/${ment_name}/seed${seed}/${name}
|
||||||
|
ckpt_dir=outputs/${dataset}/${d}-${K}shot/${name}/seed${seed}
|
||||||
|
mkdir -p ${output_dir}
|
||||||
|
|
||||||
|
val_ment_fn=outputs/${dataset}/${d}-${K}shot/${ment_name}/seed${seed}/dev_ment.json
|
||||||
|
test_ment_fn=outputs/${dataset}/${d}-${K}shot/${ment_name}/seed${seed}/test_ment.json
|
||||||
|
|
||||||
|
python train_type.py --mode test \
|
||||||
|
--seed ${seed} \
|
||||||
|
--root ${root} \
|
||||||
|
--train ${train_fn} \
|
||||||
|
--val ${val_fn} \
|
||||||
|
--test ${test_fn} \
|
||||||
|
--ep_dir ${ep_dir} \
|
||||||
|
--ep_train ${ep_train_fn} \
|
||||||
|
--ep_val ${ep_val_fn} \
|
||||||
|
--ep_test ${ep_test_fn} \
|
||||||
|
--val_ment_fn ${val_ment_fn} \
|
||||||
|
--test_ment_fn ${test_ment_fn} \
|
||||||
|
--output_dir ${output_dir} \
|
||||||
|
--load_ckpt ${ckpt_dir}/model.pth.tar \
|
||||||
|
--N ${d} \
|
||||||
|
--K ${K} \
|
||||||
|
--Q 1 \
|
||||||
|
--bio ${bio} \
|
||||||
|
--encoder_name_or_path bert-base-uncased \
|
||||||
|
--last_n_layer -4 \
|
||||||
|
--max_length 128 \
|
||||||
|
--word_encode_choice first \
|
||||||
|
--span_encode_choice avg \
|
||||||
|
--learning_rate 0.001 \
|
||||||
|
--bert_learning_rate 5e-5 \
|
||||||
|
--bert_weight_decay 0.01 \
|
||||||
|
--train_iter ${train_iter} \
|
||||||
|
--dev_iter ${dev_iter} \
|
||||||
|
--val_steps ${val_steps} \
|
||||||
|
--warmup_step 0 \
|
||||||
|
--log_steps 50 \
|
||||||
|
--train_batch_size ${batch_size} \
|
||||||
|
--eval_batch_size 1 \
|
||||||
|
--use_width False \
|
||||||
|
--use_case False \
|
||||||
|
--dropout 0.5 \
|
||||||
|
--max_grad_norm 5 \
|
||||||
|
--dot False \
|
||||||
|
--normalize l2 \
|
||||||
|
--temperature 0.1 \
|
||||||
|
--use_focal False \
|
||||||
|
--use_att False \
|
||||||
|
--att_hidden_dim 100 \
|
||||||
|
--use_oproto False \
|
||||||
|
--hou_eval_ep ${hou_eval_ep} \
|
||||||
|
--overlap False \
|
||||||
|
--type_threshold 0 \
|
||||||
|
--use_maml True \
|
||||||
|
--train_inner_lr 2e-5 \
|
||||||
|
--train_inner_steps ${inner_steps} \
|
||||||
|
--warmup_prop_inner 0 \
|
||||||
|
--eval_inner_lr 2e-5 \
|
||||||
|
--eval_inner_steps ${ft_steps}
|
|
@ -0,0 +1,130 @@
|
||||||
|
seed=${1:-12}
|
||||||
|
d=${2:-5}
|
||||||
|
K=${3:-1}
|
||||||
|
dataset=${4:-"inter"}
|
||||||
|
mft_steps=${5:-3}
|
||||||
|
tft_steps=${6:-3}
|
||||||
|
schema=${7:-"BIO"}
|
||||||
|
mloss=${8:-2}
|
||||||
|
|
||||||
|
base_dir=data
|
||||||
|
batch_size=8
|
||||||
|
aids=8-9-10-11
|
||||||
|
|
||||||
|
if [ $dataset == 'snips' ] || [ $dataset == "ner" ]
|
||||||
|
then
|
||||||
|
root=${base_dir}/${dataset}/${K}shot
|
||||||
|
train_fn=train_domain${d}.txt
|
||||||
|
val_fn=valid_domain${d}.txt
|
||||||
|
test_fn=test_domain${d}.txt
|
||||||
|
ep_dir=$root
|
||||||
|
ep_train_fn=train_domain${d}_id.jsonl
|
||||||
|
ep_val_fn=valid_domain${d}_id.jsonl
|
||||||
|
ep_test_fn=test_domain${d}_id.jsonl
|
||||||
|
bio=True
|
||||||
|
train_iter=800
|
||||||
|
dev_iter=-1
|
||||||
|
val_steps=100
|
||||||
|
elif [ $dataset == 'inter' ] || [ $dataset == "intra" ]
|
||||||
|
then
|
||||||
|
root=${base_dir}/few-nerd/${dataset}
|
||||||
|
ep_dir=${base_dir}/few-nerd/episode-${dataset}
|
||||||
|
train_fn=train.txt
|
||||||
|
val_fn=dev.txt
|
||||||
|
test_fn=test.txt
|
||||||
|
ep_train_fn=train_${d}_${K}_id.jsonl
|
||||||
|
ep_val_fn=dev_${d}_${K}_id.jsonl
|
||||||
|
ep_test_fn=test_${d}_${K}_id.jsonl
|
||||||
|
bio=False
|
||||||
|
train_iter=1000
|
||||||
|
dev_iter=500
|
||||||
|
val_steps=500
|
||||||
|
else
|
||||||
|
root=${base_dir}/fewevent
|
||||||
|
train_fn=train.txt
|
||||||
|
val_fn=dev.txt
|
||||||
|
test_fn=test.txt
|
||||||
|
ep_dir=${root}
|
||||||
|
ep_train_fn=train_${d}_${K}_id.jsonl
|
||||||
|
ep_val_fn=dev_${d}_${K}_id.jsonl
|
||||||
|
ep_test_fn=test_${d}_${K}_id.jsonl
|
||||||
|
bio=True
|
||||||
|
train_iter=1000
|
||||||
|
dev_iter=-1
|
||||||
|
val_steps=500
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $schema == 'IO' ]
|
||||||
|
then
|
||||||
|
crf=False
|
||||||
|
else
|
||||||
|
crf=True
|
||||||
|
fi
|
||||||
|
|
||||||
|
tft_steps=${mft_steps}
|
||||||
|
name=joint_m${mft_steps}_max${mloss}_${schema}_t${tft_steps}_bz${batch_size}
|
||||||
|
inner_steps=1
|
||||||
|
|
||||||
|
output_dir=outputs/${dataset}/${d}-${K}shot/joint/${name}/seed${seed}
|
||||||
|
mkdir -p ${output_dir}
|
||||||
|
log_name=${dataset}-${d}-${K}-${name}-${seed}-train.log
|
||||||
|
|
||||||
|
|
||||||
|
python train_joint.py --mode train \
|
||||||
|
--seed ${seed} \
|
||||||
|
--root ${root} \
|
||||||
|
--train ${train_fn} \
|
||||||
|
--val ${val_fn} \
|
||||||
|
--test ${test_fn} \
|
||||||
|
--ep_dir ${ep_dir} \
|
||||||
|
--ep_train ${ep_train_fn} \
|
||||||
|
--ep_val ${ep_val_fn} \
|
||||||
|
--ep_test ${ep_test_fn} \
|
||||||
|
--output_dir ${output_dir} \
|
||||||
|
--N ${d} \
|
||||||
|
--K ${K} \
|
||||||
|
--Q 1 \
|
||||||
|
--bio ${bio} \
|
||||||
|
--max_loss ${mloss} \
|
||||||
|
--use_crf ${crf} \
|
||||||
|
--schema ${schema} \
|
||||||
|
--adapter_layer_ids ${aids} \
|
||||||
|
--encoder_name_or_path bert-base-uncased \
|
||||||
|
--last_n_layer -1 \
|
||||||
|
--max_length 128 \
|
||||||
|
--word_encode_choice first \
|
||||||
|
--span_encode_choice avg \
|
||||||
|
--learning_rate 0.001 \
|
||||||
|
--bert_learning_rate 5e-5 \
|
||||||
|
--bert_weight_decay 0.01 \
|
||||||
|
--train_iter ${train_iter} \
|
||||||
|
--dev_iter ${dev_iter} \
|
||||||
|
--val_steps ${val_steps} \
|
||||||
|
--warmup_step 0 \
|
||||||
|
--log_steps 50 \
|
||||||
|
--train_batch_size ${batch_size} \
|
||||||
|
--type_lam 1 \
|
||||||
|
--use_adapter True \
|
||||||
|
--eval_batch_size 1 \
|
||||||
|
--use_width False \
|
||||||
|
--use_case False \
|
||||||
|
--dropout 0.5 \
|
||||||
|
--max_grad_norm 5 \
|
||||||
|
--dot False \
|
||||||
|
--normalize l2 \
|
||||||
|
--temperature 0.1 \
|
||||||
|
--use_focal False \
|
||||||
|
--use_att False \
|
||||||
|
--att_hidden_dim 100 \
|
||||||
|
--adapter_size 64 \
|
||||||
|
--use_oproto False \
|
||||||
|
--hou_eval_ep ${hou_eval_ep} \
|
||||||
|
--overlap False \
|
||||||
|
--type_threshold 0 \
|
||||||
|
--use_maml True \
|
||||||
|
--train_inner_lr 2e-5 \
|
||||||
|
--train_inner_steps ${inner_steps} \
|
||||||
|
--warmup_prop_inner 0 \
|
||||||
|
--eval_inner_lr 2e-5 \
|
||||||
|
--eval_ment_inner_steps ${mft_steps} \
|
||||||
|
--eval_type_inner_steps ${tft_steps}
|
|
@ -0,0 +1,115 @@
|
||||||
|
seed=${1:-12}
|
||||||
|
d=${2:-1}
|
||||||
|
K=${3:-1}
|
||||||
|
dataset=${4:-"ner"}
|
||||||
|
schema=${5:-"BIO"}
|
||||||
|
mloss=${6:-2}
|
||||||
|
ft_steps=${7:-30}
|
||||||
|
|
||||||
|
base_dir=data
|
||||||
|
|
||||||
|
wp_inner=0
|
||||||
|
lr=5e-5
|
||||||
|
batch_size=16
|
||||||
|
|
||||||
|
echo $batch_size
|
||||||
|
|
||||||
|
if [ $dataset == 'snips' ] || [ $dataset == "ner" ]
|
||||||
|
then
|
||||||
|
root=${base_dir}/${dataset}/${K}shot
|
||||||
|
train_fn=train_domain${d}.txt
|
||||||
|
val_fn=valid_domain${d}.txt
|
||||||
|
test_fn=test_domain${d}.txt
|
||||||
|
ep_dir=$root
|
||||||
|
ep_train_fn=train_domain${d}_id.jsonl
|
||||||
|
ep_val_fn=valid_domain${d}_id.jsonl
|
||||||
|
ep_test_fn=test_domain${d}_id.jsonl
|
||||||
|
bio=True
|
||||||
|
train_iter=800
|
||||||
|
dev_iter=-1
|
||||||
|
val_steps=100
|
||||||
|
elif [ $dataset == 'inter' ] || [ $dataset == "intra" ]
|
||||||
|
then
|
||||||
|
root=${base_dir}/few-nerd/${dataset}
|
||||||
|
ep_dir=${base_dir}/few-nerd/episode-${dataset}
|
||||||
|
train_fn=train.txt
|
||||||
|
val_fn=dev.txt
|
||||||
|
test_fn=test.txt
|
||||||
|
ep_train_fn=train_${d}_${K}_id.jsonl
|
||||||
|
ep_val_fn=dev_${d}_${K}_id.jsonl
|
||||||
|
ep_test_fn=test_${d}_${K}_id.jsonl
|
||||||
|
bio=False
|
||||||
|
train_iter=1000
|
||||||
|
dev_iter=500
|
||||||
|
val_steps=500
|
||||||
|
else
|
||||||
|
root=${base_dir}/fewevent
|
||||||
|
train_fn=train.txt
|
||||||
|
val_fn=dev.txt
|
||||||
|
test_fn=test.txt
|
||||||
|
ep_dir=${root}
|
||||||
|
ep_train_fn=train_${d}_${K}_id.jsonl
|
||||||
|
ep_val_fn=dev_${d}_${K}_id.jsonl
|
||||||
|
ep_test_fn=test_${d}_${K}_id.jsonl
|
||||||
|
bio=True
|
||||||
|
train_iter=1000
|
||||||
|
dev_iter=-1
|
||||||
|
val_steps=500
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $schema == 'IO' ]
|
||||||
|
then
|
||||||
|
crf=False
|
||||||
|
else
|
||||||
|
crf=True
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
name=sqlment_maml${ft_steps}_max${mloss}_${schema}_bz${batch_size}
|
||||||
|
inner_steps=2
|
||||||
|
|
||||||
|
output_dir=outputs/${dataset}/${d}-${K}shot/${name}/seed${seed}
|
||||||
|
mkdir -p ${output_dir}
|
||||||
|
log_name=${dataset}-${d}-${K}-${name}-${seed}-train.log
|
||||||
|
|
||||||
|
python train_sql.py --mode train \
|
||||||
|
--seed ${seed} \
|
||||||
|
--root ${root} \
|
||||||
|
--train ${train_fn} \
|
||||||
|
--val ${val_fn} \
|
||||||
|
--test ${test_fn} \
|
||||||
|
--ep_dir ${ep_dir} \
|
||||||
|
--ep_train ${ep_train_fn} \
|
||||||
|
--ep_val ${ep_val_fn} \
|
||||||
|
--ep_test ${ep_test_fn} \
|
||||||
|
--output_dir ${output_dir} \
|
||||||
|
--encoder_name_or_path bert-base-uncased \
|
||||||
|
--last_n_layer -1 \
|
||||||
|
--max_length 128 \
|
||||||
|
--word_encode_choice first \
|
||||||
|
--warmup_step 0 \
|
||||||
|
--learning_rate 0.001 \
|
||||||
|
--bert_learning_rate ${lr} \
|
||||||
|
--bert_weight_decay 0.01 \
|
||||||
|
--log_steps 50 \
|
||||||
|
--train_iter ${train_iter} \
|
||||||
|
--dev_iter ${dev_iter} \
|
||||||
|
--val_steps ${val_steps} \
|
||||||
|
--train_batch_size ${batch_size} \
|
||||||
|
--eval_batch_size 1 \
|
||||||
|
--max_loss ${mloss} \
|
||||||
|
--use_crf ${crf} \
|
||||||
|
--schema ${schema} \
|
||||||
|
--dropout 0.5 \
|
||||||
|
--max_grad_norm 5 \
|
||||||
|
--eval_all_after_train True \
|
||||||
|
--bio ${bio} \
|
||||||
|
--use_maml True \
|
||||||
|
--train_inner_lr 2e-5 \
|
||||||
|
--train_inner_steps ${inner_steps} \
|
||||||
|
--warmup_prop_inner ${wp_inner} \
|
||||||
|
--eval_inner_lr 2e-5 \
|
||||||
|
--eval_inner_steps ${ft_steps} | tee ${log_name}
|
||||||
|
|
||||||
|
mv ${log_name} ${output_dir}/train.log
|
|
@ -0,0 +1,70 @@
|
||||||
|
seed=${1:-12}
|
||||||
|
d=${2:-2}
|
||||||
|
K=${3:-1}
|
||||||
|
mloss=${4:-2}
|
||||||
|
ft_steps=${5:-20}
|
||||||
|
|
||||||
|
base_dir=data
|
||||||
|
batch_size=16
|
||||||
|
dataset=postag
|
||||||
|
|
||||||
|
root=${base_dir}/postag/${d}-${K}shot
|
||||||
|
train_fn=train.txt
|
||||||
|
val_fn=dev.txt
|
||||||
|
test_fn=test.txt
|
||||||
|
ep_dir=$root
|
||||||
|
ep_train_fn=train_id.jsonl
|
||||||
|
ep_val_fn=dev_id.jsonl
|
||||||
|
ep_test_fn=test_id.jsonl
|
||||||
|
train_iter=1000
|
||||||
|
dev_iter=-1
|
||||||
|
val_steps=100
|
||||||
|
|
||||||
|
|
||||||
|
name=pos_mamlcls${ft_steps}_max${mloss}_bz${batch_size}
|
||||||
|
inner_steps=1
|
||||||
|
|
||||||
|
output_dir=outputs/${dataset}/${d}-${K}shot/${name}/seed${seed}
|
||||||
|
mkdir -p ${output_dir}
|
||||||
|
log_name=${dataset}-${d}-${K}-${name}-${seed}-train.log
|
||||||
|
|
||||||
|
python train_pos.py --mode train \
|
||||||
|
--seed ${seed} \
|
||||||
|
--root ${root} \
|
||||||
|
--train ${train_fn} \
|
||||||
|
--val ${val_fn} \
|
||||||
|
--test ${test_fn} \
|
||||||
|
--ep_dir ${ep_dir} \
|
||||||
|
--ep_train ${ep_train_fn} \
|
||||||
|
--ep_val ${ep_val_fn} \
|
||||||
|
--ep_test ${ep_test_fn} \
|
||||||
|
--output_dir ${output_dir} \
|
||||||
|
--N ${d} \
|
||||||
|
--K ${K} \
|
||||||
|
--Q 1 \
|
||||||
|
--encoder_name_or_path bert-base-uncased \
|
||||||
|
--last_n_layer -1 \
|
||||||
|
--max_length 128 \
|
||||||
|
--word_encode_choice first \
|
||||||
|
--learning_rate 0.001 \
|
||||||
|
--bert_learning_rate 5e-5 \
|
||||||
|
--bert_weight_decay 0.01 \
|
||||||
|
--train_iter ${train_iter} \
|
||||||
|
--dev_iter ${dev_iter} \
|
||||||
|
--val_steps ${val_steps} \
|
||||||
|
--warmup_step 0 \
|
||||||
|
--log_steps 50 \
|
||||||
|
--train_batch_size ${batch_size} \
|
||||||
|
--eval_batch_size 1 \
|
||||||
|
--dropout 0.5 \
|
||||||
|
--max_grad_norm 5 \
|
||||||
|
--dot False \
|
||||||
|
--normalize l2 \
|
||||||
|
--temperature 0.1 \
|
||||||
|
--max_loss ${mloss} \
|
||||||
|
--use_maml True \
|
||||||
|
--train_inner_lr 2e-5 \
|
||||||
|
--train_inner_steps ${inner_steps} \
|
||||||
|
--warmup_prop_inner 0.1 \
|
||||||
|
--eval_inner_lr 2e-5 \
|
||||||
|
--eval_inner_steps ${ft_steps}
|
|
@ -0,0 +1,114 @@
|
||||||
|
seed=${1:-12}
|
||||||
|
d=${2:-1}
|
||||||
|
K=${3:-1}
|
||||||
|
dataset=${4:-"ner"}
|
||||||
|
ft_steps=${5:-20}
|
||||||
|
|
||||||
|
base_dir=data
|
||||||
|
batch_size=16
|
||||||
|
|
||||||
|
|
||||||
|
if [ $dataset == 'snips' ] || [ $dataset == "ner" ]
|
||||||
|
then
|
||||||
|
root=${base_dir}/${dataset}/${K}shot
|
||||||
|
train_fn=train_domain${d}.txt
|
||||||
|
val_fn=valid_domain${d}.txt
|
||||||
|
test_fn=test_domain${d}.txt
|
||||||
|
ep_dir=$root
|
||||||
|
ep_train_fn=train_domain${d}_id.jsonl
|
||||||
|
ep_val_fn=valid_domain${d}_id.jsonl
|
||||||
|
ep_test_fn=test_domain${d}_id.jsonl
|
||||||
|
bio=True
|
||||||
|
train_iter=800
|
||||||
|
dev_iter=-1
|
||||||
|
val_steps=100
|
||||||
|
elif [ $dataset == 'inter' ] || [ $dataset == "intra" ]
|
||||||
|
then
|
||||||
|
root=${base_dir}/few-nerd/${dataset}
|
||||||
|
ep_dir=${base_dir}/few-nerd/episode-${dataset}
|
||||||
|
train_fn=train.txt
|
||||||
|
val_fn=dev.txt
|
||||||
|
test_fn=test.txt
|
||||||
|
ep_train_fn=train_${d}_${K}_id.jsonl
|
||||||
|
ep_val_fn=dev_${d}_${K}_id.jsonl
|
||||||
|
ep_test_fn=test_${d}_${K}_id.jsonl
|
||||||
|
bio=False
|
||||||
|
train_iter=1000
|
||||||
|
dev_iter=500
|
||||||
|
val_steps=500
|
||||||
|
else
|
||||||
|
root=${base_dir}/fewevent
|
||||||
|
train_fn=train.txt
|
||||||
|
val_fn=dev.txt
|
||||||
|
test_fn=test.txt
|
||||||
|
ep_dir=${root}
|
||||||
|
ep_train_fn=train_${d}_${K}_id.jsonl
|
||||||
|
ep_val_fn=dev_${d}_${K}_id.jsonl
|
||||||
|
ep_test_fn=test_${d}_${K}_id.jsonl
|
||||||
|
bio=True
|
||||||
|
train_iter=1000
|
||||||
|
dev_iter=-1
|
||||||
|
val_steps=500
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
name=type_mamlcls${ft_steps}_bz${batch_size}
|
||||||
|
inner_steps=1
|
||||||
|
|
||||||
|
output_dir=outputs/${dataset}/${d}-${K}shot/${name}/seed${seed}
|
||||||
|
mkdir -p ${output_dir}
|
||||||
|
log_name=${dataset}-${d}-${K}-${name}-${seed}-train.log
|
||||||
|
|
||||||
|
python train_type.py --mode train \
|
||||||
|
--seed ${seed} \
|
||||||
|
--root ${root} \
|
||||||
|
--train ${train_fn} \
|
||||||
|
--val ${val_fn} \
|
||||||
|
--test ${test_fn} \
|
||||||
|
--ep_dir ${ep_dir} \
|
||||||
|
--ep_train ${ep_train_fn} \
|
||||||
|
--ep_val ${ep_val_fn} \
|
||||||
|
--ep_test ${ep_test_fn} \
|
||||||
|
--output_dir ${output_dir} \
|
||||||
|
--N ${d} \
|
||||||
|
--K ${K} \
|
||||||
|
--Q 1 \
|
||||||
|
--bio ${bio} \
|
||||||
|
--encoder_name_or_path bert-base-uncased \
|
||||||
|
--last_n_layer -4 \
|
||||||
|
--max_length 128 \
|
||||||
|
--word_encode_choice first \
|
||||||
|
--span_encode_choice avg \
|
||||||
|
--learning_rate 0.001 \
|
||||||
|
--bert_learning_rate 5e-5 \
|
||||||
|
--bert_weight_decay 0.01 \
|
||||||
|
--train_iter ${train_iter} \
|
||||||
|
--dev_iter ${dev_iter} \
|
||||||
|
--val_steps ${val_steps} \
|
||||||
|
--warmup_step 0 \
|
||||||
|
--log_steps 50 \
|
||||||
|
--train_batch_size ${batch_size} \
|
||||||
|
--eval_batch_size 1 \
|
||||||
|
--use_width False \
|
||||||
|
--use_case False \
|
||||||
|
--dropout 0.5 \
|
||||||
|
--max_grad_norm 5 \
|
||||||
|
--dot False \
|
||||||
|
--normalize l2 \
|
||||||
|
--temperature 0.1 \
|
||||||
|
--use_focal False \
|
||||||
|
--use_att False \
|
||||||
|
--att_hidden_dim 100 \
|
||||||
|
--use_oproto False \
|
||||||
|
--hou_eval_ep ${hou_eval_ep} \
|
||||||
|
--overlap False \
|
||||||
|
--type_threshold 0 \
|
||||||
|
--use_maml True \
|
||||||
|
--train_inner_lr 2e-5 \
|
||||||
|
--train_inner_steps ${inner_steps} \
|
||||||
|
--warmup_prop_inner 0 \
|
||||||
|
--eval_inner_lr 2e-5 \
|
||||||
|
--eval_inner_steps ${ft_steps} | tee ${log_name}
|
||||||
|
|
||||||
|
mv ${log_name} ${output_dir}/train.log
|
|
@ -0,0 +1,267 @@
|
||||||
|
import sys
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import json
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from util.joint_loader import get_joint_loader
|
||||||
|
from trainer.joint_trainer import JointTrainer
|
||||||
|
from util.span_encoder import BERTSpanEncoder
|
||||||
|
from util.log_utils import eval_ent_log, cal_episode_prf, set_seed, save_json
|
||||||
|
from model.joint_model import SelectedJointModel
|
||||||
|
|
||||||
|
def add_args():
|
||||||
|
def str2bool(arg):
|
||||||
|
if arg.lower() == "true":
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--mode', default='test', type=str,
|
||||||
|
help='train / test')
|
||||||
|
parser.add_argument('--load_ckpt', default=None,
|
||||||
|
help='load ckpt')
|
||||||
|
parser.add_argument('--output_dir', default=None,
|
||||||
|
help='output dir')
|
||||||
|
parser.add_argument('--log_dir', default=None,
|
||||||
|
help='log dir')
|
||||||
|
parser.add_argument('--root', default=None, type=str,
|
||||||
|
help='data root dir')
|
||||||
|
parser.add_argument('--train', default='train.txt',
|
||||||
|
help='train file')
|
||||||
|
parser.add_argument('--val', default='dev.txt',
|
||||||
|
help='val file')
|
||||||
|
parser.add_argument('--test', default='test.txt',
|
||||||
|
help='test file')
|
||||||
|
parser.add_argument('--train_ment_fn', default=None,
|
||||||
|
help='train file')
|
||||||
|
parser.add_argument('--val_ment_fn', default=None,
|
||||||
|
help='val file')
|
||||||
|
parser.add_argument('--test_ment_fn', default=None,
|
||||||
|
help='test file')
|
||||||
|
parser.add_argument('--ep_dir', default=None, type=str)
|
||||||
|
parser.add_argument('--ep_train', default=None, type=str)
|
||||||
|
parser.add_argument('--ep_val', default=None, type=str)
|
||||||
|
parser.add_argument('--ep_test', default=None, type=str)
|
||||||
|
|
||||||
|
parser.add_argument('--N', default=None, type=int,
|
||||||
|
help='N way')
|
||||||
|
parser.add_argument('--K', default=None, type=int,
|
||||||
|
help='K shot')
|
||||||
|
parser.add_argument('--Q', default=None, type=int,
|
||||||
|
help='Num of query per class')
|
||||||
|
|
||||||
|
parser.add_argument('--encoder_name_or_path', default='bert-base-uncased', type=str)
|
||||||
|
parser.add_argument('--word_encode_choice', default='first', type=str)
|
||||||
|
parser.add_argument('--span_encode_choice', default=None, type=str)
|
||||||
|
|
||||||
|
parser.add_argument('--max_length', default=128, type=int,
|
||||||
|
help='max length')
|
||||||
|
parser.add_argument('--max_span_len', default=8, type=int)
|
||||||
|
parser.add_argument('--max_neg_ratio', default=5, type=int)
|
||||||
|
parser.add_argument('--last_n_layer', default=-4, type=int)
|
||||||
|
|
||||||
|
|
||||||
|
parser.add_argument('--dot', type=str2bool, default=False,
|
||||||
|
help='use dot instead of L2 distance for knn')
|
||||||
|
parser.add_argument("--normalize", default='none', type=str, choices=['none', 'l2'])
|
||||||
|
parser.add_argument("--temperature", default=None, type=float)
|
||||||
|
parser.add_argument("--use_width", default=False, type=str2bool)
|
||||||
|
parser.add_argument("--width_dim", default=20, type=int)
|
||||||
|
parser.add_argument("--use_case", default=False, type=str2bool)
|
||||||
|
parser.add_argument("--case_dim", default=20, type=int)
|
||||||
|
|
||||||
|
parser.add_argument('--dropout', default=0.5, type=float,
|
||||||
|
help='dropout rate')
|
||||||
|
|
||||||
|
parser.add_argument('--log_steps', default=2000, type=int,
|
||||||
|
help='val after training how many iters')
|
||||||
|
parser.add_argument('--val_steps', default=2000, type=int,
|
||||||
|
help='val after training how many iters')
|
||||||
|
|
||||||
|
parser.add_argument('--train_batch_size', default=4, type=int,
|
||||||
|
help='batch size')
|
||||||
|
parser.add_argument('--eval_batch_size', default=1, type=int,
|
||||||
|
help='batch size')
|
||||||
|
|
||||||
|
parser.add_argument('--train_iter', default=-1, type=int,
|
||||||
|
help='num of iters in training')
|
||||||
|
parser.add_argument('--dev_iter', default=-1, type=int,
|
||||||
|
help='num of iters in training')
|
||||||
|
parser.add_argument('--test_iter', default=-1, type=int,
|
||||||
|
help='num of iters in training')
|
||||||
|
|
||||||
|
parser.add_argument("--max_grad_norm", default=None, type=float)
|
||||||
|
parser.add_argument("--learning_rate", default=5e-5, type=float, help="lr rate")
|
||||||
|
parser.add_argument("--weight_decay", default=1e-3, type=float, help="weight decay")
|
||||||
|
parser.add_argument("--bert_learning_rate", default=5e-5, type=float, help="lr rate")
|
||||||
|
parser.add_argument("--bert_weight_decay", default=1e-5, type=float, help="weight decay")
|
||||||
|
|
||||||
|
parser.add_argument("--warmup_step", default=0, type=int)
|
||||||
|
|
||||||
|
parser.add_argument('--seed', default=42, type=int,
|
||||||
|
help='seed')
|
||||||
|
parser.add_argument('--fp16', action='store_true', default=False,
|
||||||
|
help='use nvidia apex fp16')
|
||||||
|
parser.add_argument('--use_focal', default=False, type=str2bool)
|
||||||
|
parser.add_argument('--iou_thred', default=None, type=float)
|
||||||
|
parser.add_argument('--use_att', default=False, type=str2bool)
|
||||||
|
parser.add_argument('--att_hidden_dim', default=-1, type=int)
|
||||||
|
parser.add_argument('--label_fn', default=None, type=str)
|
||||||
|
parser.add_argument('--hou_eval_ep', default=False, type=str2bool)
|
||||||
|
|
||||||
|
parser.add_argument('--use_maml', default=False, type=str2bool)
|
||||||
|
parser.add_argument('--warmup_prop_inner', default=0, type=float)
|
||||||
|
parser.add_argument('--train_inner_lr', default=0, type=float)
|
||||||
|
parser.add_argument('--train_inner_steps', default=0, type=int)
|
||||||
|
parser.add_argument('--eval_inner_lr', default=0, type=float)
|
||||||
|
parser.add_argument('--eval_type_inner_steps', default=0, type=int)
|
||||||
|
parser.add_argument('--eval_ment_inner_steps', default=0, type=int)
|
||||||
|
parser.add_argument('--overlap', default=False, type=str2bool)
|
||||||
|
parser.add_argument('--type_lam', default=1, type=float)
|
||||||
|
parser.add_argument('--use_adapter', default=False, type=str2bool)
|
||||||
|
parser.add_argument('--adapter_size', default=64, type=int)
|
||||||
|
parser.add_argument('--type_threshold', default=-1, type=float)
|
||||||
|
parser.add_argument('--use_oproto', default=False, type=str2bool)
|
||||||
|
parser.add_argument('--bio', default=False, type=str2bool)
|
||||||
|
parser.add_argument('--schema', default='IO', type=str, choices=['IO', 'BIO', 'BIOES'])
|
||||||
|
parser.add_argument('--use_crf', type=str2bool, default=False)
|
||||||
|
parser.add_argument('--max_loss', default=0, type=float)
|
||||||
|
parser.add_argument('--adapter_layer_ids', default='9-10-11', type=str)
|
||||||
|
opt = parser.parse_args()
|
||||||
|
return opt
|
||||||
|
|
||||||
|
|
||||||
|
def main(opt):
|
||||||
|
print("Joint Model pipeline {} way {} shot".format(opt.N, opt.K))
|
||||||
|
set_seed(opt.seed)
|
||||||
|
if opt.mode == "train":
|
||||||
|
output_dir = opt.output_dir
|
||||||
|
print("Output dir is ", output_dir)
|
||||||
|
log_dir = os.path.join(
|
||||||
|
output_dir,
|
||||||
|
"logs",
|
||||||
|
)
|
||||||
|
opt.log_dir = log_dir
|
||||||
|
if not os.path.exists(opt.output_dir):
|
||||||
|
os.makedirs(opt.output_dir)
|
||||||
|
save_json(opt.__dict__, os.path.join(opt.output_dir, "train_setting.txt"))
|
||||||
|
else:
|
||||||
|
if not os.path.exists(opt.output_dir):
|
||||||
|
os.makedirs(opt.output_dir)
|
||||||
|
save_json(opt.__dict__, os.path.join(opt.output_dir, "test_setting.txt"))
|
||||||
|
print('loading model and tokenizer...')
|
||||||
|
print("use adapter: ", opt.use_adapter)
|
||||||
|
word_encoder = BERTSpanEncoder(opt.encoder_name_or_path, opt.max_length, last_n_layer=opt.last_n_layer,
|
||||||
|
word_encode_choice=opt.word_encode_choice,
|
||||||
|
span_encode_choice=opt.span_encode_choice,
|
||||||
|
use_width=opt.use_width, width_dim=opt.width_dim, use_case=opt.use_case,
|
||||||
|
case_dim=opt.case_dim,
|
||||||
|
drop_p=opt.dropout, use_att=opt.use_att,
|
||||||
|
att_hidden_dim=opt.att_hidden_dim, use_adapter=opt.use_adapter, adapter_size=opt.adapter_size,
|
||||||
|
adapter_layer_ids=[int(x) for x in opt.adapter_layer_ids.split('-')])
|
||||||
|
print('loading data')
|
||||||
|
if opt.mode == "train":
|
||||||
|
train_loader = get_joint_loader(os.path.join(opt.root, opt.train), "train", word_encoder, N=opt.N,
|
||||||
|
K=opt.K,
|
||||||
|
Q=opt.Q, batch_size=opt.train_batch_size, max_length=opt.max_length,
|
||||||
|
shuffle=True,
|
||||||
|
bio=opt.bio,
|
||||||
|
schema=opt.schema,
|
||||||
|
debug_file=os.path.join(opt.ep_dir, opt.ep_train) if opt.ep_train else None,
|
||||||
|
query_file=opt.train_ment_fn,
|
||||||
|
iou_thred=opt.iou_thred,
|
||||||
|
use_oproto=opt.use_oproto,
|
||||||
|
label_fn=opt.label_fn)
|
||||||
|
print(os.path.join(opt.root, opt.val))
|
||||||
|
dev_loader = get_joint_loader(os.path.join(opt.root, opt.val), "test", word_encoder, N=opt.N, K=opt.K,
|
||||||
|
Q=opt.Q, batch_size=opt.eval_batch_size, max_length=opt.max_length,
|
||||||
|
shuffle=False,
|
||||||
|
bio=opt.bio,
|
||||||
|
schema=opt.schema,
|
||||||
|
debug_file=os.path.join(opt.ep_dir, opt.ep_val) if opt.ep_val else None,
|
||||||
|
query_file=opt.val_ment_fn,
|
||||||
|
iou_thred=opt.iou_thred,
|
||||||
|
use_oproto=opt.use_oproto,
|
||||||
|
label_fn=opt.label_fn)
|
||||||
|
|
||||||
|
test_loader = get_joint_loader(os.path.join(opt.root, opt.test), "test", word_encoder, N=opt.N, K=opt.K,
|
||||||
|
Q=opt.Q, batch_size=opt.eval_batch_size, max_length=opt.max_length,
|
||||||
|
shuffle=False,
|
||||||
|
bio=opt.bio,
|
||||||
|
schema=opt.schema,
|
||||||
|
debug_file=os.path.join(opt.ep_dir, opt.ep_test) if opt.ep_test else None,
|
||||||
|
query_file=opt.test_ment_fn,
|
||||||
|
iou_thred=opt.iou_thred, hidden_query_label=False,
|
||||||
|
use_oproto=opt.use_oproto,
|
||||||
|
label_fn=opt.label_fn)
|
||||||
|
|
||||||
|
print("max_length: {}".format(opt.max_length))
|
||||||
|
print('mode: {}'.format(opt.mode))
|
||||||
|
|
||||||
|
print("{}-way-{}-shot Proto Few-Shot NER".format(opt.N, opt.K))
|
||||||
|
|
||||||
|
model = SelectedJointModel(word_encoder, num_tag=len(test_loader.dataset.ment_label2tag),
|
||||||
|
ment_label2idx=test_loader.dataset.ment_tag2label,
|
||||||
|
schema=opt.schema, use_crf=opt.use_crf, max_loss=opt.max_loss,
|
||||||
|
use_oproto=opt.use_oproto, dot=opt.dot, normalize=opt.normalize,
|
||||||
|
temperature=opt.temperature, use_focal=opt.use_focal, type_lam=opt.type_lam)
|
||||||
|
|
||||||
|
num_params = sum(param.numel() for param in model.parameters())
|
||||||
|
print("total parameter numbers: ", num_params)
|
||||||
|
|
||||||
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
model = model.to(device)
|
||||||
|
|
||||||
|
trainer = JointTrainer()
|
||||||
|
assert opt.eval_batch_size == 1
|
||||||
|
|
||||||
|
if opt.mode == 'train':
|
||||||
|
trainer.train(model, opt, device, train_loader, dev_loader, dev_pred_fn=os.path.join(opt.output_dir, "dev_metrics.json"), dev_log_fn=os.path.join(opt.output_dir, "dev_log.txt"), load_ckpt=opt.load_ckpt)
|
||||||
|
test_loss, test_ment_p, test_ment_r, test_ment_f1, test_p, test_r, test_f1, test_ment_logs, test_logs = trainer.eval(model, device, test_loader, eval_iter=opt.test_iter, load_ckpt=os.path.join(opt.output_dir, "model.pth.tar"),
|
||||||
|
ment_update_iter=opt.eval_ment_inner_steps,
|
||||||
|
type_update_iter=opt.eval_type_inner_steps,
|
||||||
|
learning_rate=opt.eval_inner_lr,
|
||||||
|
overlap=opt.overlap,
|
||||||
|
threshold=opt.type_threshold,
|
||||||
|
eval_mode="test-twostage")
|
||||||
|
if opt.hou_eval_ep:
|
||||||
|
test_p, test_r, test_f1 = cal_episode_prf(test_logs)
|
||||||
|
print("Mention test precision {:.5f} recall {:.5f} f1 {:.5f}".format(test_ment_p, test_ment_r, test_ment_f1))
|
||||||
|
print('[TEST] loss: {0:2.6f} | [Entity] precision: {1:3.4f}, recall: {2:3.4f}, f1: {3:3.4f}'\
|
||||||
|
.format(test_loss, test_p, test_r, test_f1) + '\r')
|
||||||
|
else:
|
||||||
|
|
||||||
|
_, dev_ment_p, dev_ment_r, dev_ment_f1, dev_p, dev_r, dev_f1, dev_ment_logs, dev_logs = trainer.eval(model, device, dev_loader, load_ckpt=opt.load_ckpt,
|
||||||
|
eval_iter=opt.dev_iter,
|
||||||
|
ment_update_iter=opt.eval_ment_inner_steps,
|
||||||
|
type_update_iter=opt.eval_type_inner_steps,
|
||||||
|
learning_rate=opt.eval_inner_lr, overlap=opt.overlap, threshold=opt.type_threshold, eval_mode=opt.mode)
|
||||||
|
|
||||||
|
if opt.hou_eval_ep:
|
||||||
|
dev_p, dev_r, dev_f1 = cal_episode_prf(dev_logs)
|
||||||
|
print("Mention dev precision {:.5f} recall {:.5f} f1 {:.5f}".format(dev_ment_p, dev_ment_r, dev_ment_f1))
|
||||||
|
print("Dev precison {:.5f} recall {:.5f} f1 {:.5f}".format(dev_p, dev_r, dev_f1))
|
||||||
|
with open(os.path.join(opt.output_dir, "dev_metrics.json"), mode="w", encoding="utf-8") as fp:
|
||||||
|
json.dump({"ment_p": dev_ment_p, "ment_r": dev_ment_r, "ment_f1": dev_ment_f1, "precision": dev_p, "recall": dev_r, "f1": dev_f1}, fp)
|
||||||
|
|
||||||
|
test_loss, test_ment_p, test_ment_r, test_ment_f1, test_p, test_r, test_f1, test_ment_logs, test_logs = trainer.eval(model, device, test_loader, eval_iter=opt.test_iter, load_ckpt=opt.load_ckpt,
|
||||||
|
ment_update_iter=opt.eval_ment_inner_steps,
|
||||||
|
type_update_iter=opt.eval_type_inner_steps, learning_rate=opt.eval_inner_lr, overlap=opt.overlap, threshold=opt.type_threshold, eval_mode=opt.mode)
|
||||||
|
if opt.hou_eval_ep:
|
||||||
|
test_p, test_r, test_f1 = cal_episode_prf(test_logs)
|
||||||
|
print("Mention test precision {:.5f} recall {:.5f} f1 {:.5f}".format(test_ment_p, test_ment_r, test_ment_f1))
|
||||||
|
print("Test precison {:.5f} recall {:.5f} f1 {:.5f}".format(test_p, test_r, test_f1))
|
||||||
|
|
||||||
|
|
||||||
|
with open(os.path.join(opt.output_dir, "test_metrics.json"), mode="w", encoding="utf-8") as fp:
|
||||||
|
json.dump({"ment_p": test_ment_p, "ment_r": test_ment_r, "ment_f1": test_ment_f1, "precision": test_p, "recall": test_r, "f1": test_f1}, fp)
|
||||||
|
eval_ent_log(test_loader.dataset.samples, test_logs)
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
opt = add_args()
|
||||||
|
main(opt)
|
||||||
|
|
|
@ -0,0 +1,182 @@
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import json
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from util.pos_loader import get_seq_loader
|
||||||
|
from trainer.pos_trainer import POSTrainer
|
||||||
|
from util.span_encoder import BERTSpanEncoder
|
||||||
|
from model.pos_model import SeqProtoCls
|
||||||
|
from util.log_utils import write_pos_pred_json, save_json, set_seed
|
||||||
|
|
||||||
|
def add_args():
|
||||||
|
def str2bool(arg):
|
||||||
|
if arg.lower() == "true":
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--mode', default='test', type=str,
|
||||||
|
help='train / test / typetest')
|
||||||
|
parser.add_argument('--train_eval_mode', default='test', type=str, choices=['test', 'typetest'])
|
||||||
|
parser.add_argument('--load_ckpt', default=None,
|
||||||
|
help='load ckpt')
|
||||||
|
parser.add_argument('--output_dir', default=None,
|
||||||
|
help='output dir')
|
||||||
|
parser.add_argument('--log_dir', default=None,
|
||||||
|
help='log dir')
|
||||||
|
parser.add_argument('--root', default=None, type=str,
|
||||||
|
help='data root dir')
|
||||||
|
parser.add_argument('--train', default='train.txt',
|
||||||
|
help='train file')
|
||||||
|
parser.add_argument('--val', default='dev.txt',
|
||||||
|
help='val file')
|
||||||
|
parser.add_argument('--test', default='test.txt',
|
||||||
|
help='test file')
|
||||||
|
parser.add_argument('--ep_dir', default=None, type=str)
|
||||||
|
parser.add_argument('--ep_train', default=None, type=str)
|
||||||
|
parser.add_argument('--ep_val', default=None, type=str)
|
||||||
|
parser.add_argument('--ep_test', default=None, type=str)
|
||||||
|
|
||||||
|
parser.add_argument('--N', default=None, type=int,
|
||||||
|
help='N way')
|
||||||
|
parser.add_argument('--K', default=None, type=int,
|
||||||
|
help='K shot')
|
||||||
|
parser.add_argument('--Q', default=None, type=int,
|
||||||
|
help='Num of query per class')
|
||||||
|
|
||||||
|
parser.add_argument('--encoder_name_or_path', default='bert-base-uncased', type=str)
|
||||||
|
parser.add_argument('--word_encode_choice', default='first', type=str)
|
||||||
|
|
||||||
|
parser.add_argument('--max_length', default=128, type=int,
|
||||||
|
help='max length')
|
||||||
|
parser.add_argument('--last_n_layer', default=-4, type=int)
|
||||||
|
parser.add_argument('--max_loss', default=0, type=float)
|
||||||
|
|
||||||
|
parser.add_argument('--dot', type=str2bool, default=False,
|
||||||
|
help='use dot instead of L2 distance for knn')
|
||||||
|
parser.add_argument("--normalize", default='none', type=str, choices=['none', 'l2'])
|
||||||
|
parser.add_argument("--temperature", default=None, type=float)
|
||||||
|
|
||||||
|
parser.add_argument('--dropout', default=0.5, type=float,
|
||||||
|
help='dropout rate')
|
||||||
|
|
||||||
|
parser.add_argument('--log_steps', default=2000, type=int,
|
||||||
|
help='val after training how many iters')
|
||||||
|
parser.add_argument('--val_steps', default=2000, type=int,
|
||||||
|
help='val after training how many iters')
|
||||||
|
|
||||||
|
parser.add_argument('--train_batch_size', default=4, type=int,
|
||||||
|
help='batch size')
|
||||||
|
parser.add_argument('--eval_batch_size', default=1, type=int,
|
||||||
|
help='batch size')
|
||||||
|
|
||||||
|
parser.add_argument('--train_iter', default=-1, type=int,
|
||||||
|
help='num of iters in training')
|
||||||
|
parser.add_argument('--dev_iter', default=-1, type=int,
|
||||||
|
help='num of iters in training')
|
||||||
|
parser.add_argument('--test_iter', default=-1, type=int,
|
||||||
|
help='num of iters in training')
|
||||||
|
|
||||||
|
parser.add_argument("--max_grad_norm", default=None, type=float)
|
||||||
|
parser.add_argument("--learning_rate", default=1e-3, type=float, help="lr rate")
|
||||||
|
parser.add_argument("--weight_decay", default=1e-3, type=float, help="weight decay")
|
||||||
|
parser.add_argument("--bert_learning_rate", default=5e-5, type=float, help="lr rate")
|
||||||
|
parser.add_argument("--bert_weight_decay", default=1e-5, type=float, help="weight decay")
|
||||||
|
|
||||||
|
parser.add_argument("--warmup_step", default=0, type=int)
|
||||||
|
|
||||||
|
parser.add_argument('--seed', default=42, type=int,
|
||||||
|
help='seed')
|
||||||
|
parser.add_argument('--fp16', action='store_true', default=False,
|
||||||
|
help='use nvidia apex fp16')
|
||||||
|
|
||||||
|
parser.add_argument('--use_maml', default=False, type=str2bool)
|
||||||
|
parser.add_argument('--warmup_prop_inner', default=0, type=float)
|
||||||
|
parser.add_argument('--train_inner_lr', default=0, type=float)
|
||||||
|
parser.add_argument('--train_inner_steps', default=0, type=int)
|
||||||
|
parser.add_argument('--eval_inner_lr', default=0, type=float)
|
||||||
|
parser.add_argument('--eval_inner_steps', default=0, type=int)
|
||||||
|
opt = parser.parse_args()
|
||||||
|
return opt
|
||||||
|
|
||||||
|
def main(opt):
|
||||||
|
set_seed(opt.seed)
|
||||||
|
if not os.path.exists(opt.output_dir):
|
||||||
|
os.makedirs(opt.output_dir)
|
||||||
|
if opt.mode == "train":
|
||||||
|
output_dir = opt.output_dir
|
||||||
|
print("Output dir is ", output_dir)
|
||||||
|
log_dir = os.path.join(
|
||||||
|
output_dir,
|
||||||
|
"logs",
|
||||||
|
)
|
||||||
|
opt.log_dir = log_dir
|
||||||
|
save_json(opt.__dict__, os.path.join(opt.output_dir, "train_setting.txt"))
|
||||||
|
else:
|
||||||
|
save_json(opt.__dict__, os.path.join(opt.output_dir, "test_setting.txt"))
|
||||||
|
|
||||||
|
print('loading model and tokenizer...')
|
||||||
|
word_encoder = BERTSpanEncoder(opt.encoder_name_or_path, opt.max_length, last_n_layer=opt.last_n_layer,
|
||||||
|
word_encode_choice=opt.word_encode_choice, drop_p=opt.dropout)
|
||||||
|
print('loading data')
|
||||||
|
if opt.mode == "train":
|
||||||
|
train_loader = get_seq_loader(os.path.join(opt.root, opt.train), "train", word_encoder, batch_size=opt.train_batch_size, max_length=opt.max_length,
|
||||||
|
shuffle=True,
|
||||||
|
debug_file=os.path.join(opt.ep_dir, opt.ep_train) if opt.ep_train else None)
|
||||||
|
else:
|
||||||
|
train_loader = get_seq_loader(os.path.join(opt.root, opt.train), "test", word_encoder, batch_size=opt.eval_batch_size, max_length=opt.max_length,
|
||||||
|
shuffle=False,
|
||||||
|
debug_file=os.path.join(opt.ep_dir, opt.ep_train) if opt.ep_train else None)
|
||||||
|
dev_loader = get_seq_loader(os.path.join(opt.root, opt.val), "test", word_encoder, batch_size=opt.eval_batch_size, max_length=opt.max_length,
|
||||||
|
shuffle=False,
|
||||||
|
debug_file=os.path.join(opt.ep_dir, opt.ep_val) if opt.ep_val else None)
|
||||||
|
test_loader = get_seq_loader(os.path.join(opt.root, opt.test), "test", word_encoder, batch_size=opt.eval_batch_size, max_length=opt.max_length,
|
||||||
|
shuffle=False,
|
||||||
|
debug_file=os.path.join(opt.ep_dir, opt.ep_test) if opt.ep_test else None)
|
||||||
|
|
||||||
|
print("max_length: {}".format(opt.max_length))
|
||||||
|
print('mode: {}'.format(opt.mode))
|
||||||
|
|
||||||
|
print("{}-way-{}-shot Token MAML-Proto Few-Shot NER".format(opt.N, opt.K))
|
||||||
|
print("this mode can only used for maml training !!!!!!!!!!!!!!!!!")
|
||||||
|
model = SeqProtoCls(word_encoder, opt.max_loss, dot=opt.dot, normalize=opt.normalize,
|
||||||
|
temperature=opt.temperature)
|
||||||
|
|
||||||
|
|
||||||
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
model = model.to(device)
|
||||||
|
|
||||||
|
|
||||||
|
trainer = POSTrainer()
|
||||||
|
|
||||||
|
if opt.mode == 'train':
|
||||||
|
print("==================start training==================")
|
||||||
|
trainer.train(model, opt, device, train_loader, dev_loader, dev_pred_fn=os.path.join(opt.output_dir, "dev_pred.json"))
|
||||||
|
test_loss, test_acc, test_logs = trainer.eval(model, device, test_loader, eval_iter=opt.test_iter, load_ckpt=os.path.join(opt.output_dir, "model.pth.tar"),
|
||||||
|
update_iter=opt.eval_inner_steps, learning_rate=opt.eval_inner_lr, eval_mode=opt.train_eval_mode)
|
||||||
|
|
||||||
|
print('[TEST] loss: {0:2.6f} | [POS] acc: {1:3.4f}'\
|
||||||
|
.format(test_loss, test_acc) + '\r')
|
||||||
|
else:
|
||||||
|
test_loss, test_acc, test_logs = trainer.eval(model, device, test_loader, load_ckpt=opt.load_ckpt,
|
||||||
|
eval_iter=opt.test_iter, update_iter=opt.eval_inner_steps, learning_rate=opt.eval_inner_lr, eval_mode=opt.mode)
|
||||||
|
|
||||||
|
print('[TEST] loss: {0:2.6f} | [POS] acc: {1:3.4f}'\
|
||||||
|
.format(test_loss, test_acc) + '\r')
|
||||||
|
name = "test_metrics.json"
|
||||||
|
if opt.mode != "test":
|
||||||
|
name = f"{opt.mode}_test_metrics.json"
|
||||||
|
with open(os.path.join(opt.output_dir, name), mode="w", encoding="utf-8") as fp:
|
||||||
|
res_mp = {"test_acc": test_acc}
|
||||||
|
json.dump(res_mp, fp)
|
||||||
|
|
||||||
|
write_pos_pred_json(test_loader.dataset.samples, test_logs, os.path.join(opt.output_dir, "test_pred.json"))
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
opt = add_args()
|
||||||
|
main(opt)
|
|
@ -0,0 +1,180 @@
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
from util.seq_loader import get_seq_loader as get_ment_loader
|
||||||
|
from trainer.ment_trainer import MentTrainer
|
||||||
|
from util.log_utils import eval_ment_log, write_ep_ment_log_json, save_json, set_seed
|
||||||
|
from util.span_encoder import BERTSpanEncoder
|
||||||
|
from model.ment_model import MentSeqtagger
|
||||||
|
|
||||||
|
def add_args():
|
||||||
|
def str2bool(arg):
|
||||||
|
if arg.lower() == "true":
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--mode', default='test', type=str,
|
||||||
|
help='train / test')
|
||||||
|
parser.add_argument("--use_episode", default=False, type=str2bool)
|
||||||
|
parser.add_argument("--bio", default=False, type=str2bool)
|
||||||
|
parser.add_argument('--load_ckpt', default=None,
|
||||||
|
help='load ckpt')
|
||||||
|
parser.add_argument('--output_dir', default=None,
|
||||||
|
help='output dir')
|
||||||
|
parser.add_argument('--log_dir', default=None,
|
||||||
|
help='log dir')
|
||||||
|
parser.add_argument('--root', default=None, type=str,
|
||||||
|
help='data root dir')
|
||||||
|
parser.add_argument('--train', default='train.txt',
|
||||||
|
help='train file')
|
||||||
|
parser.add_argument('--val', default='dev.txt',
|
||||||
|
help='val file')
|
||||||
|
parser.add_argument('--test', default='test.txt',
|
||||||
|
help='test file')
|
||||||
|
parser.add_argument('--encoder_name_or_path', default='bert-base-uncased', type=str)
|
||||||
|
parser.add_argument('--word_encode_choice', default='first', type=str)
|
||||||
|
|
||||||
|
parser.add_argument('--max_length', default=128, type=int,
|
||||||
|
help='max length')
|
||||||
|
parser.add_argument('--last_n_layer', default=-4, type=int)
|
||||||
|
parser.add_argument('--schema', default='IO', type=str, choices=['IO', 'BIO', 'BIOES'])
|
||||||
|
parser.add_argument('--use_crf', type=str2bool, default=False)
|
||||||
|
parser.add_argument('--max_loss', default=0, type=float)
|
||||||
|
|
||||||
|
parser.add_argument('--dropout', default=0.5, type=float,
|
||||||
|
help='dropout rate')
|
||||||
|
|
||||||
|
parser.add_argument('--log_steps', default=2000, type=int,
|
||||||
|
help='val after training how many iters')
|
||||||
|
parser.add_argument('--val_steps', default=2000, type=int,
|
||||||
|
help='val after training how many iters')
|
||||||
|
|
||||||
|
parser.add_argument('--train_batch_size', default=16, type=int,
|
||||||
|
help='batch size')
|
||||||
|
parser.add_argument('--eval_batch_size', default=16, type=int,
|
||||||
|
help='batch size')
|
||||||
|
|
||||||
|
parser.add_argument('--train_iter', default=-1, type=int,
|
||||||
|
help='num of iters in training')
|
||||||
|
parser.add_argument('--dev_iter', default=-1, type=int,
|
||||||
|
help='num of iters in training')
|
||||||
|
parser.add_argument('--test_iter', default=-1, type=int,
|
||||||
|
help='num of iters in training')
|
||||||
|
|
||||||
|
parser.add_argument("--max_grad_norm", default=None, type=float)
|
||||||
|
parser.add_argument("--learning_rate", default=1e-3, type=float, help="lr rate")
|
||||||
|
parser.add_argument("--weight_decay", default=1e-3, type=float, help="weight decay")
|
||||||
|
parser.add_argument("--bert_learning_rate", default=5e-5, type=float, help="lr rate")
|
||||||
|
parser.add_argument("--bert_weight_decay", default=1e-5, type=float, help="weight decay")
|
||||||
|
|
||||||
|
parser.add_argument("--warmup_step", default=0, type=int)
|
||||||
|
|
||||||
|
parser.add_argument('--seed', default=42, type=int,
|
||||||
|
help='seed')
|
||||||
|
parser.add_argument('--fp16', action='store_true', default=False,
|
||||||
|
help='use nvidia apex fp16')
|
||||||
|
|
||||||
|
parser.add_argument('--ep_dir', default=None, type=str)
|
||||||
|
parser.add_argument('--ep_train', default=None, type=str)
|
||||||
|
parser.add_argument('--ep_val', default=None, type=str)
|
||||||
|
parser.add_argument('--ep_test', default=None, type=str)
|
||||||
|
parser.add_argument('--eval_all_after_train', default=False, type=str2bool)
|
||||||
|
parser.add_argument('--use_maml', default=False, type=str2bool)
|
||||||
|
parser.add_argument('--warmup_prop_inner', default=0, type=float)
|
||||||
|
parser.add_argument('--train_inner_lr', default=0, type=float)
|
||||||
|
parser.add_argument('--train_inner_steps', default=0, type=int)
|
||||||
|
parser.add_argument('--eval_inner_lr', default=0, type=float)
|
||||||
|
parser.add_argument('--eval_inner_steps', default=0, type=int)
|
||||||
|
opt = parser.parse_args()
|
||||||
|
return opt
|
||||||
|
|
||||||
|
def main(opt):
|
||||||
|
set_seed(opt.seed)
|
||||||
|
if not os.path.exists(opt.output_dir):
|
||||||
|
os.makedirs(opt.output_dir)
|
||||||
|
if opt.mode == "train":
|
||||||
|
output_dir = opt.output_dir
|
||||||
|
print("Output dir is ", output_dir)
|
||||||
|
log_dir = os.path.join(
|
||||||
|
output_dir,
|
||||||
|
"logs",
|
||||||
|
)
|
||||||
|
opt.log_dir = log_dir
|
||||||
|
save_json(opt.__dict__, os.path.join(opt.output_dir, "train_setting.txt"))
|
||||||
|
else:
|
||||||
|
save_json(opt.__dict__, os.path.join(opt.output_dir, "test_setting.txt"))
|
||||||
|
|
||||||
|
print('loading model and tokenizer...')
|
||||||
|
word_encoder = BERTSpanEncoder(opt.encoder_name_or_path, opt.max_length, last_n_layer=opt.last_n_layer,
|
||||||
|
word_encode_choice=opt.word_encode_choice, drop_p=opt.dropout)
|
||||||
|
print('loading data')
|
||||||
|
if opt.mode == "train":
|
||||||
|
train_loader = get_ment_loader(os.path.join(opt.root, opt.train), "train", word_encoder, batch_size=opt.train_batch_size, max_length=opt.max_length,
|
||||||
|
schema=opt.schema,
|
||||||
|
shuffle=True,
|
||||||
|
bio=opt.bio, debug_file=os.path.join(opt.ep_dir, opt.ep_train) if opt.ep_train else None)
|
||||||
|
else:
|
||||||
|
train_loader = get_ment_loader(os.path.join(opt.root, opt.train), "test", word_encoder, batch_size=opt.eval_batch_size, max_length=opt.max_length,
|
||||||
|
schema=opt.schema,
|
||||||
|
shuffle=False,
|
||||||
|
bio=opt.bio, debug_file=os.path.join(opt.ep_dir, opt.ep_train) if opt.ep_train else None)
|
||||||
|
dev_loader = get_ment_loader(os.path.join(opt.root, opt.val), "test", word_encoder, batch_size=opt.eval_batch_size, max_length=opt.max_length,
|
||||||
|
schema=opt.schema,
|
||||||
|
shuffle=False,
|
||||||
|
bio=opt.bio, debug_file=os.path.join(opt.ep_dir, opt.ep_val) if opt.ep_val else None)
|
||||||
|
test_loader = get_ment_loader(os.path.join(opt.root, opt.test), "test", word_encoder, batch_size=opt.eval_batch_size, max_length=opt.max_length,
|
||||||
|
schema=opt.schema,
|
||||||
|
shuffle=False,
|
||||||
|
bio=opt.bio, debug_file=os.path.join(opt.ep_dir, opt.ep_test) if opt.ep_test else None)
|
||||||
|
|
||||||
|
print("max_length: {}".format(opt.max_length))
|
||||||
|
print('mode: {}'.format(opt.mode))
|
||||||
|
|
||||||
|
print("mention detection sequence labeling with max loss {}".format(opt.max_loss))
|
||||||
|
model = MentSeqtagger(word_encoder, len(test_loader.dataset.ment_label2tag), test_loader.dataset.ment_tag2label, opt.schema, opt.use_crf, opt.max_loss)
|
||||||
|
|
||||||
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
print(device)
|
||||||
|
model = model.to(device)
|
||||||
|
|
||||||
|
trainer = MentTrainer()
|
||||||
|
assert opt.eval_batch_size == 1
|
||||||
|
|
||||||
|
if opt.mode == 'train':
|
||||||
|
print("==================start training==================")
|
||||||
|
trainer.train(model, opt, device, train_loader, dev_loader, eval_log_fn=os.path.join(opt.output_dir, "dev_ment_log.txt"))
|
||||||
|
if opt.eval_all_after_train:
|
||||||
|
train_loader = get_ment_loader(os.path.join(opt.root, opt.train), "test", word_encoder, batch_size=opt.eval_batch_size, max_length=opt.max_length,
|
||||||
|
schema=opt.schema,
|
||||||
|
shuffle=False,
|
||||||
|
bio=opt.bio, debug_file=os.path.join(opt.ep_dir, opt.ep_train) if opt.ep_train else None)
|
||||||
|
load_ckpt = os.path.join(opt.output_dir, "model.pth.tar")
|
||||||
|
|
||||||
|
_, dev_p, dev_r, dev_f1, _, _, _, dev_logs = trainer.eval(model, device, dev_loader, load_ckpt=load_ckpt, update_iter=opt.eval_inner_steps, learning_rate=opt.eval_inner_lr, eval_iter=opt.dev_iter)
|
||||||
|
print("Dev precison {:.5f} recall {:.5f} f1 {:.5f}".format(dev_p, dev_r, dev_f1))
|
||||||
|
eval_ment_log(dev_loader.dataset.samples, dev_logs)
|
||||||
|
write_ep_ment_log_json(dev_loader.dataset.samples, dev_logs, os.path.join(opt.output_dir, "dev_ment.json"))
|
||||||
|
|
||||||
|
_, test_p, test_r, test_f1, _, _, _, test_logs = trainer.eval(model, device, test_loader, load_ckpt=load_ckpt, update_iter=opt.eval_inner_steps, learning_rate=opt.eval_inner_lr, eval_iter=opt.test_iter)
|
||||||
|
print("Test precison {:.5f} recall {:.5f} f1 {:.5f}".format(test_p, test_r, test_f1))
|
||||||
|
eval_ment_log(test_loader.dataset.samples, test_logs)
|
||||||
|
write_ep_ment_log_json(test_loader.dataset.samples, test_logs, os.path.join(opt.output_dir, "test_ment.json"))
|
||||||
|
else:
|
||||||
|
_, dev_p, dev_r, dev_f1, _, _, _, dev_logs = trainer.eval(model, device, dev_loader, load_ckpt=opt.load_ckpt, update_iter=opt.eval_inner_steps, learning_rate=opt.eval_inner_lr, eval_iter=opt.dev_iter)
|
||||||
|
print("Dev precison {:.5f} recall {:.5f} f1 {:.5f}".format(dev_p, dev_r, dev_f1))
|
||||||
|
eval_ment_log(dev_loader.dataset.samples, dev_logs)
|
||||||
|
write_ep_ment_log_json(dev_loader.dataset.samples, dev_logs, os.path.join(opt.output_dir, "dev_ment.json"))
|
||||||
|
|
||||||
|
_, test_p, test_r, test_f1, _, _, _, test_logs = trainer.eval(model, device, test_loader, load_ckpt=opt.load_ckpt, update_iter=opt.eval_inner_steps, learning_rate=opt.eval_inner_lr, eval_iter=opt.test_iter)
|
||||||
|
model.timer.avg()
|
||||||
|
print("Test precison {:.5f} recall {:.5f} f1 {:.5f}".format(test_p, test_r, test_f1))
|
||||||
|
eval_ment_log(test_loader.dataset.samples, test_logs)
|
||||||
|
write_ep_ment_log_json(test_loader.dataset.samples, test_logs, os.path.join(opt.output_dir, "test_ment.json"))
|
||||||
|
res_mp = {"dev_p": dev_p, "dev_r": dev_r, "dev_f1": dev_f1, "test_p": test_p, "test_r": test_r, "test_f1": test_f1}
|
||||||
|
save_json(res_mp, os.path.join(opt.output_dir, "metrics.json"))
|
||||||
|
return
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
opt = add_args()
|
||||||
|
main(opt)
|
|
@ -0,0 +1,243 @@
|
||||||
|
import sys
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import json
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
from trainer.span_trainer import SpanTrainer
|
||||||
|
from util.span_encoder import BERTSpanEncoder
|
||||||
|
from util.span_loader import get_query_loader
|
||||||
|
from util.log_utils import eval_ent_log, cal_episode_prf, write_ent_pred_json, set_seed, save_json
|
||||||
|
from model.type_model import SpanProtoCls
|
||||||
|
|
||||||
|
|
||||||
|
def add_args():
|
||||||
|
def str2bool(arg):
|
||||||
|
if arg.lower() == "true":
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--mode', default='test', type=str,
|
||||||
|
help='train / test')
|
||||||
|
parser.add_argument('--load_ckpt', default=None,
|
||||||
|
help='load ckpt')
|
||||||
|
parser.add_argument('--output_dir', default=None,
|
||||||
|
help='output dir')
|
||||||
|
parser.add_argument('--log_dir', default=None,
|
||||||
|
help='log dir')
|
||||||
|
parser.add_argument('--root', default=None, type=str,
|
||||||
|
help='data root dir')
|
||||||
|
parser.add_argument('--train', default='train.txt',
|
||||||
|
help='train file')
|
||||||
|
parser.add_argument('--val', default='dev.txt',
|
||||||
|
help='val file')
|
||||||
|
parser.add_argument('--test', default='test.txt',
|
||||||
|
help='test file')
|
||||||
|
parser.add_argument('--train_ment_fn', default=None,
|
||||||
|
help='train file')
|
||||||
|
parser.add_argument('--val_ment_fn', default=None,
|
||||||
|
help='val file')
|
||||||
|
parser.add_argument('--test_ment_fn', default=None,
|
||||||
|
help='test file')
|
||||||
|
parser.add_argument('--ep_dir', default=None, type=str)
|
||||||
|
parser.add_argument('--ep_train', default=None, type=str)
|
||||||
|
parser.add_argument('--ep_val', default=None, type=str)
|
||||||
|
parser.add_argument('--ep_test', default=None, type=str)
|
||||||
|
|
||||||
|
parser.add_argument('--N', default=None, type=int,
|
||||||
|
help='N way')
|
||||||
|
parser.add_argument('--K', default=None, type=int,
|
||||||
|
help='K shot')
|
||||||
|
parser.add_argument('--Q', default=None, type=int,
|
||||||
|
help='Num of query per class')
|
||||||
|
|
||||||
|
parser.add_argument('--encoder_name_or_path', default='bert-base-uncased', type=str)
|
||||||
|
parser.add_argument('--word_encode_choice', default='first', type=str)
|
||||||
|
parser.add_argument('--span_encode_choice', default=None, type=str)
|
||||||
|
|
||||||
|
parser.add_argument('--max_length', default=128, type=int,
|
||||||
|
help='max length')
|
||||||
|
parser.add_argument('--max_span_len', default=8, type=int)
|
||||||
|
parser.add_argument('--max_neg_ratio', default=5, type=int)
|
||||||
|
parser.add_argument('--last_n_layer', default=-4, type=int)
|
||||||
|
|
||||||
|
|
||||||
|
parser.add_argument('--dot', type=str2bool, default=False,
|
||||||
|
help='use dot instead of L2 distance for knn')
|
||||||
|
parser.add_argument("--normalize", default='none', type=str, choices=['none', 'l2'])
|
||||||
|
parser.add_argument("--temperature", default=None, type=float)
|
||||||
|
parser.add_argument("--use_width", default=False, type=str2bool)
|
||||||
|
parser.add_argument("--width_dim", default=20, type=int)
|
||||||
|
parser.add_argument("--use_case", default=False, type=str2bool)
|
||||||
|
parser.add_argument("--case_dim", default=20, type=int)
|
||||||
|
|
||||||
|
parser.add_argument('--dropout', default=0.5, type=float,
|
||||||
|
help='dropout rate')
|
||||||
|
|
||||||
|
parser.add_argument('--log_steps', default=2000, type=int,
|
||||||
|
help='val after training how many iters')
|
||||||
|
parser.add_argument('--val_steps', default=2000, type=int,
|
||||||
|
help='val after training how many iters')
|
||||||
|
|
||||||
|
parser.add_argument('--train_batch_size', default=4, type=int,
|
||||||
|
help='batch size')
|
||||||
|
parser.add_argument('--eval_batch_size', default=1, type=int,
|
||||||
|
help='batch size')
|
||||||
|
|
||||||
|
parser.add_argument('--train_iter', default=-1, type=int,
|
||||||
|
help='num of iters in training')
|
||||||
|
parser.add_argument('--dev_iter', default=-1, type=int,
|
||||||
|
help='num of iters in training')
|
||||||
|
parser.add_argument('--test_iter', default=-1, type=int,
|
||||||
|
help='num of iters in training')
|
||||||
|
|
||||||
|
parser.add_argument("--max_grad_norm", default=None, type=float)
|
||||||
|
parser.add_argument("--learning_rate", default=1e-3, type=float, help="lr rate")
|
||||||
|
parser.add_argument("--weight_decay", default=1e-3, type=float, help="weight decay")
|
||||||
|
parser.add_argument("--bert_learning_rate", default=5e-5, type=float, help="lr rate")
|
||||||
|
parser.add_argument("--bert_weight_decay", default=1e-5, type=float, help="weight decay")
|
||||||
|
|
||||||
|
parser.add_argument("--warmup_step", default=0, type=int)
|
||||||
|
|
||||||
|
parser.add_argument('--seed', default=42, type=int,
|
||||||
|
help='seed')
|
||||||
|
parser.add_argument('--fp16', action='store_true', default=False,
|
||||||
|
help='use nvidia apex fp16')
|
||||||
|
parser.add_argument('--use_focal', default=False, type=str2bool)
|
||||||
|
parser.add_argument('--iou_thred', default=None, type=float)
|
||||||
|
parser.add_argument('--use_att', default=False, type=str2bool)
|
||||||
|
parser.add_argument('--att_hidden_dim', default=-1, type=int)
|
||||||
|
parser.add_argument('--label_fn', default=None, type=str)
|
||||||
|
parser.add_argument('--hou_eval_ep', default=False, type=str2bool)
|
||||||
|
|
||||||
|
parser.add_argument('--use_maml', default=False, type=str2bool)
|
||||||
|
parser.add_argument('--warmup_prop_inner', default=0, type=float)
|
||||||
|
parser.add_argument('--train_inner_lr', default=0, type=float)
|
||||||
|
parser.add_argument('--train_inner_steps', default=0, type=int)
|
||||||
|
parser.add_argument('--eval_inner_lr', default=0, type=float)
|
||||||
|
parser.add_argument('--eval_inner_steps', default=0, type=int)
|
||||||
|
parser.add_argument('--overlap', default=False, type=str2bool)
|
||||||
|
parser.add_argument('--type_threshold', default=-1, type=float)
|
||||||
|
parser.add_argument('--use_oproto', default=False, type=str2bool)
|
||||||
|
parser.add_argument('--bio', default=False, type=str2bool)
|
||||||
|
opt = parser.parse_args()
|
||||||
|
return opt
|
||||||
|
|
||||||
|
|
||||||
|
def main(opt):
|
||||||
|
print("Span based proto pipeline {} way {} shot".format(opt.N, opt.K))
|
||||||
|
set_seed(opt.seed)
|
||||||
|
if opt.mode == "train":
|
||||||
|
output_dir = opt.output_dir
|
||||||
|
print("Output dir is ", output_dir)
|
||||||
|
log_dir = os.path.join(
|
||||||
|
output_dir,
|
||||||
|
"logs",
|
||||||
|
)
|
||||||
|
opt.log_dir = log_dir
|
||||||
|
if not os.path.exists(opt.output_dir):
|
||||||
|
os.makedirs(opt.output_dir)
|
||||||
|
save_json(opt.__dict__, os.path.join(opt.output_dir, "train_setting.txt"))
|
||||||
|
else:
|
||||||
|
if not os.path.exists(opt.output_dir):
|
||||||
|
os.makedirs(opt.output_dir)
|
||||||
|
save_json(opt.__dict__, os.path.join(opt.output_dir, "test_setting.txt"))
|
||||||
|
print('loading model and tokenizer...')
|
||||||
|
word_encoder = BERTSpanEncoder(opt.encoder_name_or_path, opt.max_length, last_n_layer=opt.last_n_layer,
|
||||||
|
word_encode_choice=opt.word_encode_choice,
|
||||||
|
span_encode_choice=opt.span_encode_choice,
|
||||||
|
use_width=opt.use_width, width_dim=opt.width_dim, use_case=opt.use_case,
|
||||||
|
case_dim=opt.case_dim,
|
||||||
|
drop_p=opt.dropout, use_att=opt.use_att,
|
||||||
|
att_hidden_dim=opt.att_hidden_dim)
|
||||||
|
print('loading data')
|
||||||
|
if opt.mode == "train":
|
||||||
|
train_loader = get_query_loader(os.path.join(opt.root, opt.train), "train", word_encoder, N=opt.N,
|
||||||
|
K=opt.K,
|
||||||
|
Q=opt.Q, batch_size=opt.train_batch_size, max_length=opt.max_length,
|
||||||
|
shuffle=True,
|
||||||
|
bio=opt.bio,
|
||||||
|
debug_file=os.path.join(opt.ep_dir, opt.ep_train) if opt.ep_train else None,
|
||||||
|
query_file=opt.train_ment_fn,
|
||||||
|
iou_thred=opt.iou_thred,
|
||||||
|
use_oproto=opt.use_oproto,
|
||||||
|
label_fn=opt.label_fn)
|
||||||
|
|
||||||
|
dev_loader = get_query_loader(os.path.join(opt.root, opt.val), "test", word_encoder, N=opt.N, K=opt.K,
|
||||||
|
Q=opt.Q, batch_size=opt.eval_batch_size, max_length=opt.max_length,
|
||||||
|
shuffle=False,
|
||||||
|
bio=opt.bio,
|
||||||
|
debug_file=os.path.join(opt.ep_dir, opt.ep_val) if opt.ep_val else None,
|
||||||
|
query_file=opt.val_ment_fn,
|
||||||
|
iou_thred=opt.iou_thred,
|
||||||
|
use_oproto=opt.use_oproto,
|
||||||
|
label_fn=opt.label_fn)
|
||||||
|
|
||||||
|
test_loader = get_query_loader(os.path.join(opt.root, opt.test), "test", word_encoder, N=opt.N, K=opt.K,
|
||||||
|
Q=opt.Q, batch_size=opt.eval_batch_size, max_length=opt.max_length,
|
||||||
|
shuffle=False,
|
||||||
|
bio=opt.bio,
|
||||||
|
debug_file=os.path.join(opt.ep_dir, opt.ep_test) if opt.ep_test else None,
|
||||||
|
query_file=opt.test_ment_fn,
|
||||||
|
iou_thred=opt.iou_thred, hidden_query_label=False,
|
||||||
|
use_oproto=opt.use_oproto,
|
||||||
|
label_fn=opt.label_fn)
|
||||||
|
|
||||||
|
print("max_length: {}".format(opt.max_length))
|
||||||
|
print('mode: {}'.format(opt.mode))
|
||||||
|
|
||||||
|
print("{}-way-{}-shot Proto Few-Shot NER".format(opt.N, opt.K))
|
||||||
|
if opt.train_inner_steps > 0:
|
||||||
|
print("this mode can only used for maml training !!!!!!!!!!!!!!!!!")
|
||||||
|
model = SpanProtoCls(word_encoder, use_oproto=opt.use_oproto, dot=opt.dot, normalize=opt.normalize,
|
||||||
|
temperature=opt.temperature, use_focal=opt.use_focal)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
model = model.to(device)
|
||||||
|
|
||||||
|
trainer = SpanTrainer()
|
||||||
|
assert opt.eval_batch_size == 1
|
||||||
|
|
||||||
|
if opt.mode == 'train':
|
||||||
|
trainer.train(model, opt, device, train_loader, dev_loader, dev_pred_fn=os.path.join(opt.output_dir, "dev_pred.json"), dev_log_fn=os.path.join(opt.output_dir, "dev_log.txt"))
|
||||||
|
test_loss, test_ment_p, test_ment_r, test_ment_f1, test_p, test_r, test_f1, test_logs = trainer.eval(model, device, test_loader, eval_iter=opt.test_iter, load_ckpt=os.path.join(opt.output_dir, "model.pth.tar"),
|
||||||
|
update_iter=opt.eval_inner_steps, learning_rate=opt.eval_inner_lr, overlap=opt.overlap, threshold=opt.type_threshold)
|
||||||
|
if opt.hou_eval_ep:
|
||||||
|
test_p, test_r, test_f1 = cal_episode_prf(test_logs)
|
||||||
|
print("Mention test precision {:.5f} recall {:.5f} f1 {:.5f}".format(test_ment_p, test_ment_r, test_ment_f1))
|
||||||
|
print('[TEST] loss: {0:2.6f} | [Entity] precision: {1:3.4f}, recall: {2:3.4f}, f1: {3:3.4f}'\
|
||||||
|
.format(test_loss, test_p, test_r, test_f1) + '\r')
|
||||||
|
else:
|
||||||
|
|
||||||
|
_, dev_ment_p, dev_ment_r, dev_ment_f1, dev_p, dev_r, dev_f1, dev_logs = trainer.eval(model, device, dev_loader, load_ckpt=opt.load_ckpt,
|
||||||
|
eval_iter=opt.dev_iter, update_iter=opt.eval_inner_steps, learning_rate=opt.eval_inner_lr, overlap=opt.overlap, threshold=opt.type_threshold)
|
||||||
|
|
||||||
|
if opt.hou_eval_ep:
|
||||||
|
dev_p, dev_r, dev_f1 = cal_episode_prf(dev_logs)
|
||||||
|
print("Mention dev precision {:.5f} recall {:.5f} f1 {:.5f}".format(dev_ment_p, dev_ment_r, dev_ment_f1))
|
||||||
|
print("Dev precison {:.5f} recall {:.5f} f1 {:.5f}".format(dev_p, dev_r, dev_f1))
|
||||||
|
with open(os.path.join(opt.output_dir, "dev_metrics.json"), mode="w", encoding="utf-8") as fp:
|
||||||
|
json.dump({"precision": dev_p, "recall": dev_r, "f1": dev_f1}, fp)
|
||||||
|
|
||||||
|
_, test_ment_p, test_ment_r, test_ment_f1, test_p, test_r, test_f1, test_logs = trainer.eval(model, device, test_loader, load_ckpt=opt.load_ckpt,
|
||||||
|
eval_iter=opt.test_iter, update_iter=opt.eval_inner_steps, learning_rate=opt.eval_inner_lr, overlap=opt.overlap, threshold=opt.type_threshold)
|
||||||
|
if opt.hou_eval_ep:
|
||||||
|
test_p, test_r, test_f1 = cal_episode_prf(test_logs)
|
||||||
|
print("Mention test precision {:.5f} recall {:.5f} f1 {:.5f}".format(test_ment_p, test_ment_r, test_ment_f1))
|
||||||
|
print("Test precison {:.5f} recall {:.5f} f1 {:.5f}".format(test_p, test_r, test_f1))
|
||||||
|
|
||||||
|
with open(os.path.join(opt.output_dir, "test_metrics.json"), mode="w", encoding="utf-8") as fp:
|
||||||
|
json.dump({"precision": test_p, "recall": test_r, "f1": test_f1}, fp)
|
||||||
|
eval_ent_log(test_loader.dataset.samples, test_logs)
|
||||||
|
write_ent_pred_json(test_loader.dataset.samples, test_logs, os.path.join(opt.output_dir, "test_pred.json"))
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
opt = add_args()
|
||||||
|
main(opt)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,217 @@
|
||||||
|
import json
|
||||||
|
import os, sys
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers import get_linear_schedule_with_warmup
|
||||||
|
from collections import defaultdict
|
||||||
|
from util.log_utils import write_ep_ment_log_json, write_ment_log, eval_ment_log
|
||||||
|
from util.log_utils import write_ent_log, write_ent_pred_json, eval_ent_log, cal_prf
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
class JointTrainer:
|
||||||
|
def __init__(self):
|
||||||
|
return
|
||||||
|
|
||||||
|
def __load_model__(self, ckpt):
|
||||||
|
if os.path.isfile(ckpt):
|
||||||
|
checkpoint = torch.load(ckpt)
|
||||||
|
print("Successfully loaded checkpoint '%s'" % ckpt)
|
||||||
|
return checkpoint
|
||||||
|
else:
|
||||||
|
raise Exception("No checkpoint found at '%s'" % ckpt)
|
||||||
|
|
||||||
|
def get_learning_rate(self, lr, progress, warmup, schedule="linear"):
|
||||||
|
if schedule == "linear":
|
||||||
|
if progress < warmup:
|
||||||
|
lr *= progress / warmup
|
||||||
|
else:
|
||||||
|
lr *= max((progress - 1.0) / (warmup - 1.0), 0.0)
|
||||||
|
return lr
|
||||||
|
|
||||||
|
def check_input_ment(self, query):
|
||||||
|
pred_cnt, gold_cnt, hit_cnt = 0, 0, 0
|
||||||
|
for i in range(len(query['spans'])):
|
||||||
|
pred_ments = query['span_indices'][i][query['span_mask'][i].eq(1)].detach().cpu().tolist()
|
||||||
|
gold_ments = [x[1:] for x in query['spans'][i]]
|
||||||
|
pred_cnt += len(pred_ments)
|
||||||
|
gold_cnt += len(gold_ments)
|
||||||
|
for x in gold_ments:
|
||||||
|
if x in pred_ments:
|
||||||
|
hit_cnt += 1
|
||||||
|
return pred_cnt, gold_cnt, hit_cnt
|
||||||
|
|
||||||
|
def eval(self, model, device, dataloader, load_ckpt=None, ment_update_iter=0, type_update_iter=0, learning_rate=3e-5, eval_iter=-1, eval_mode="test-twostage", overlap=False, threshold=-1):
|
||||||
|
if load_ckpt:
|
||||||
|
state_dict = self.__load_model__(load_ckpt)['state_dict']
|
||||||
|
own_state = model.state_dict()
|
||||||
|
for name, param in state_dict.items():
|
||||||
|
if name not in own_state:
|
||||||
|
print('[ERROR] Ignore {}'.format(name))
|
||||||
|
continue
|
||||||
|
own_state[name].copy_(param)
|
||||||
|
model.eval()
|
||||||
|
eval_loss = 0
|
||||||
|
tot_seq_cnt = 0
|
||||||
|
tot_ment_logs = {}
|
||||||
|
tot_type_logs = {}
|
||||||
|
eval_batchs = iter(dataloader)
|
||||||
|
tot_ment_metric_logs = defaultdict(int)
|
||||||
|
tot_type_metric_logs = defaultdict(int)
|
||||||
|
eval_loss = 0
|
||||||
|
if eval_iter <= 0:
|
||||||
|
eval_iter = len(dataloader.dataset.sampler)
|
||||||
|
tot_seq_cnt = 0
|
||||||
|
print("[eval] update {} steps | total {} episode".format(ment_update_iter, eval_iter))
|
||||||
|
for batch_id in tqdm(range(eval_iter)):
|
||||||
|
batch = next(eval_batchs)
|
||||||
|
batch['support'] = dataloader.dataset.batch_to_device(batch['support'], device)
|
||||||
|
batch['query'] = dataloader.dataset.batch_to_device(batch['query'], device)
|
||||||
|
|
||||||
|
res = model.forward_joint_meta(batch, ment_update_iter, learning_rate, eval_mode)
|
||||||
|
eval_loss += res['loss']
|
||||||
|
|
||||||
|
if eval_mode == 'test-twostage':
|
||||||
|
batch["query"]["span_indices"] = res['pred_spans']
|
||||||
|
batch["query"]["span_mask"] = res['pred_masks']
|
||||||
|
ment_metric_logs, ment_logs = model.seqment_eval(res["ment_preds"], batch["query"], model.ment_idx2label, model.schema)
|
||||||
|
type_metric_logs, type_logs = model.greedy_eval(res['type_logits'], batch['query'], overlap=overlap, threshold=threshold)
|
||||||
|
|
||||||
|
tot_seq_cnt += batch['query']['word'].size(0)
|
||||||
|
for k, v in ment_logs.items():
|
||||||
|
if k not in tot_ment_logs:
|
||||||
|
tot_ment_logs[k] = []
|
||||||
|
tot_ment_logs[k] += v
|
||||||
|
for k, v in type_logs.items():
|
||||||
|
if k not in tot_type_logs:
|
||||||
|
tot_type_logs[k] = []
|
||||||
|
tot_type_logs[k] += v
|
||||||
|
|
||||||
|
for k, v in ment_metric_logs.items():
|
||||||
|
tot_ment_metric_logs[k] += v
|
||||||
|
for k, v in type_metric_logs.items():
|
||||||
|
tot_type_metric_logs[k] += v
|
||||||
|
|
||||||
|
ment_p, ment_r, ment_f1 = cal_prf(tot_ment_metric_logs["ment_hit_cnt"], tot_ment_metric_logs["ment_pred_cnt"], tot_ment_metric_logs["ment_gold_cnt"])
|
||||||
|
print("seq num:", tot_seq_cnt, "hit cnt:", tot_ment_metric_logs["ment_hit_cnt"], "pred cnt:", tot_ment_metric_logs["ment_pred_cnt"], "gold cnt:", tot_ment_metric_logs["ment_gold_cnt"])
|
||||||
|
print("avg hit:", tot_ment_metric_logs["ment_hit_cnt"] / tot_seq_cnt, "avg pred:", tot_ment_metric_logs["ment_pred_cnt"] / tot_seq_cnt, "avg gold:", tot_ment_metric_logs["ment_gold_cnt"] / tot_seq_cnt)
|
||||||
|
ent_p, ent_r, ent_f1 = cal_prf(tot_type_metric_logs["ent_hit_cnt"], tot_type_metric_logs["ent_pred_cnt"], tot_type_metric_logs["ent_gold_cnt"])
|
||||||
|
model.train()
|
||||||
|
return eval_loss / eval_iter, ment_p, ment_r, ment_f1, ent_p, ent_r, ent_f1, tot_ment_logs, tot_type_logs
|
||||||
|
|
||||||
|
|
||||||
|
def train(self, model, training_args, device, trainloader, devloader, load_ckpt=None, ignore_log=True, dev_pred_fn=None, dev_log_fn=None):
|
||||||
|
if load_ckpt is not None:
|
||||||
|
state_dict = self.__load_model__(load_ckpt)['state_dict']
|
||||||
|
own_state = model.state_dict()
|
||||||
|
for name, param in state_dict.items():
|
||||||
|
if name not in own_state:
|
||||||
|
print('[ERROR] Ignore {}'.format(name))
|
||||||
|
continue
|
||||||
|
own_state[name].copy_(param)
|
||||||
|
print("load ckpt from {}".format(load_ckpt))
|
||||||
|
# Init optimizer
|
||||||
|
print('Use bert optim!')
|
||||||
|
parameters_to_optimize = list(filter(lambda x: x[1].requires_grad, model.named_parameters()))
|
||||||
|
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
|
||||||
|
|
||||||
|
parameters_groups = [
|
||||||
|
{'params': [p for n, p in parameters_to_optimize if ("bert." in n) and (not any(nd in n for nd in no_decay))],
|
||||||
|
'lr': training_args.bert_learning_rate, 'weight_decay': training_args.bert_weight_decay},
|
||||||
|
{'params': [p for n, p in parameters_to_optimize if ("bert." in n) and any(nd in n for nd in no_decay)],
|
||||||
|
'lr': training_args.bert_learning_rate, 'weight_decay': 0},
|
||||||
|
{'params': [p for n, p in parameters_to_optimize if "bert." not in n],
|
||||||
|
'lr': training_args.learning_rate, 'weight_decay': training_args.weight_decay}
|
||||||
|
]
|
||||||
|
|
||||||
|
optimizer = torch.optim.AdamW(parameters_groups)
|
||||||
|
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=training_args.warmup_step,
|
||||||
|
num_training_steps=training_args.train_iter)
|
||||||
|
model.train()
|
||||||
|
model.zero_grad()
|
||||||
|
|
||||||
|
best_f1 = 0.0
|
||||||
|
train_loss = 0.0
|
||||||
|
train_ment_loss = 0.0
|
||||||
|
train_type_loss = 0.0
|
||||||
|
train_type_acc = 0
|
||||||
|
train_ment_acc = 0
|
||||||
|
iter_sample = 0
|
||||||
|
it = 0
|
||||||
|
train_batchs = iter(trainloader)
|
||||||
|
for _ in range(training_args.train_iter):
|
||||||
|
it += 1
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
model.train()
|
||||||
|
batch = next(train_batchs)
|
||||||
|
batch['support'] = trainloader.dataset.batch_to_device(batch['support'], device)
|
||||||
|
batch['query'] = trainloader.dataset.batch_to_device(batch['query'], device)
|
||||||
|
|
||||||
|
if training_args.use_maml:
|
||||||
|
progress = 1.0 * (it - 1) / training_args.train_iter
|
||||||
|
lr_inner = self.get_learning_rate(
|
||||||
|
training_args.train_inner_lr, progress, training_args.warmup_prop_inner
|
||||||
|
)
|
||||||
|
res = model.forward_joint_meta(batch, training_args.train_inner_steps, lr_inner, "train")
|
||||||
|
|
||||||
|
for g in res['grads']:
|
||||||
|
model.load_gradients(res['names'], g)
|
||||||
|
train_loss += res['loss']
|
||||||
|
train_ment_loss += res['ment_loss']
|
||||||
|
train_type_loss += res['type_loss']
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
if training_args.max_grad_norm is not None:
|
||||||
|
torch.nn.utils.clip_grad_norm_(model.parameters(), training_args.max_grad_norm)
|
||||||
|
optimizer.step()
|
||||||
|
scheduler.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
type_pred = torch.cat(res['type_preds'], dim=0).detach().cpu().numpy()
|
||||||
|
type_gold = torch.cat(res['type_golds'], dim=0).detach().cpu().numpy()
|
||||||
|
ment_pred = torch.cat(res['ment_preds'], dim=0).detach().cpu().numpy()
|
||||||
|
ment_gold = torch.cat(res['ment_golds'], dim=0).detach().cpu().numpy()
|
||||||
|
type_acc = model.span_accuracy(type_pred, type_gold)
|
||||||
|
ment_acc = model.span_accuracy(ment_pred, ment_gold)
|
||||||
|
train_type_acc += type_acc
|
||||||
|
train_ment_acc += ment_acc
|
||||||
|
|
||||||
|
iter_sample += 1
|
||||||
|
if not ignore_log:
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
if it % 100 == 0 or it % training_args.log_steps == 0:
|
||||||
|
if not ignore_log:
|
||||||
|
raise ValueError
|
||||||
|
else:
|
||||||
|
print('step: {0:4} | loss: {1:2.6f} | ment loss: {2:2.6f} | type loss: {3:2.6f} | ment acc {4:.5f} | type acc {5:.5f}'
|
||||||
|
.format(it, train_loss / iter_sample, train_ment_loss / iter_sample, train_type_loss / iter_sample, train_ment_acc / iter_sample, train_type_acc / iter_sample) + '\r')
|
||||||
|
train_loss = 0
|
||||||
|
train_ment_loss = 0
|
||||||
|
train_type_loss = 0
|
||||||
|
train_ment_acc = 0
|
||||||
|
train_type_acc = 0
|
||||||
|
iter_sample = 0
|
||||||
|
|
||||||
|
if it % training_args.val_steps == 0:
|
||||||
|
eval_loss, eval_ment_p, eval_ment_r, eval_ment_f1, eval_p, eval_r, eval_f1, eval_ment_logs, eval_logs = self.eval(model, device, devloader, ment_update_iter=training_args.eval_ment_inner_steps, type_update_iter=training_args.eval_type_inner_steps, learning_rate=training_args.eval_inner_lr, eval_iter=training_args.dev_iter, eval_mode="test-twostage")
|
||||||
|
print('[EVAL] step: {0:4} | loss: {1:2.6f} | [MENTION] precision: {2:3.4f}, recall: {3:3.4f}, f1: {4:3.4f} [ENTITY] precision: {5:3.4f}, recall: {6:3.4f}, f1: {7:3.4f}'\
|
||||||
|
.format(it, eval_loss, eval_ment_p, eval_ment_r, eval_ment_f1, eval_p, eval_r, eval_f1) + '\r')
|
||||||
|
if eval_f1 > best_f1:
|
||||||
|
print('Best checkpoint')
|
||||||
|
torch.save({'state_dict': model.state_dict()},
|
||||||
|
os.path.join(training_args.output_dir, "model.pth.tar"))
|
||||||
|
best_f1 = eval_f1
|
||||||
|
if dev_pred_fn is not None:
|
||||||
|
with open(dev_pred_fn, mode="w", encoding="utf-8") as fp:
|
||||||
|
json.dump({"ment_p": eval_ment_p, "ment_r": eval_ment_r, "ment_f1": eval_ment_f1, "precision": eval_p, "recall": eval_r, "f1": eval_f1}, fp)
|
||||||
|
|
||||||
|
eval_ent_log(devloader.dataset.samples, eval_logs)
|
||||||
|
if training_args.train_iter == it:
|
||||||
|
break
|
||||||
|
print("\n####################\n")
|
||||||
|
print("Finish training ")
|
||||||
|
return
|
|
@ -0,0 +1,192 @@
|
||||||
|
import json
|
||||||
|
import os, sys
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers import get_linear_schedule_with_warmup
|
||||||
|
from collections import defaultdict
|
||||||
|
from util.log_utils import write_ep_ment_log_json, write_ment_log, eval_ment_log, cal_prf
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
class MentTrainer:
|
||||||
|
def __init__(self):
|
||||||
|
return
|
||||||
|
|
||||||
|
def __load_model__(self, ckpt):
|
||||||
|
if os.path.isfile(ckpt):
|
||||||
|
checkpoint = torch.load(ckpt)
|
||||||
|
print("Successfully loaded checkpoint '%s'" % ckpt)
|
||||||
|
return checkpoint
|
||||||
|
else:
|
||||||
|
raise Exception("No checkpoint found at '%s'" % ckpt)
|
||||||
|
|
||||||
|
def get_learning_rate(self, lr, progress, warmup, schedule="linear"):
|
||||||
|
if schedule == "linear":
|
||||||
|
if progress < warmup:
|
||||||
|
lr *= progress / warmup
|
||||||
|
else:
|
||||||
|
lr *= max((progress - 1.0) / (warmup - 1.0), 0.0)
|
||||||
|
return lr
|
||||||
|
|
||||||
|
def eval(self, model, device, dataloader, load_ckpt=None, update_iter=0, learning_rate=3e-5, eval_iter=-1):
|
||||||
|
if load_ckpt:
|
||||||
|
state_dict = self.__load_model__(load_ckpt)['state_dict']
|
||||||
|
own_state = model.state_dict()
|
||||||
|
for name, param in state_dict.items():
|
||||||
|
if name not in own_state:
|
||||||
|
print('[ERROR] Ignore {}'.format(name))
|
||||||
|
continue
|
||||||
|
own_state[name].copy_(param)
|
||||||
|
model.eval()
|
||||||
|
tot_metric_logs = defaultdict(int)
|
||||||
|
eval_loss = 0
|
||||||
|
tot_seq_cnt = 0
|
||||||
|
tot_logs = {}
|
||||||
|
eval_batchs = iter(dataloader)
|
||||||
|
tot_metric_logs = defaultdict(int)
|
||||||
|
eval_loss = 0
|
||||||
|
if eval_iter <= 0:
|
||||||
|
eval_iter = len(dataloader.dataset.sampler)
|
||||||
|
tot_seq_cnt = 0
|
||||||
|
tot_logs = {}
|
||||||
|
print("[eval] update {} steps | total {} episode".format(update_iter, eval_iter))
|
||||||
|
for batch_id in range(eval_iter):
|
||||||
|
batch = next(eval_batchs)
|
||||||
|
batch['support'] = dataloader.dataset.batch_to_device(batch['support'], device)
|
||||||
|
batch['query'] = dataloader.dataset.batch_to_device(batch['query'], device)
|
||||||
|
if update_iter > 0:
|
||||||
|
res = model.forward_meta(batch, update_iter, learning_rate, "test")
|
||||||
|
eval_loss += res['loss']
|
||||||
|
else:
|
||||||
|
res = model.forward_sup(batch, "test")
|
||||||
|
eval_loss += res['loss'].item()
|
||||||
|
metric_logs, logs = model.seqment_eval(res["preds"], batch["query"], model.idx2label, model.schema)
|
||||||
|
tot_seq_cnt += batch['query']['word'].size(0)
|
||||||
|
for k, v in logs.items():
|
||||||
|
if k not in tot_logs:
|
||||||
|
tot_logs[k] = []
|
||||||
|
tot_logs[k] += v
|
||||||
|
for k, v in metric_logs.items():
|
||||||
|
tot_metric_logs[k] += v
|
||||||
|
|
||||||
|
ment_p, ment_r, ment_f1 = cal_prf(tot_metric_logs["ment_hit_cnt"], tot_metric_logs["ment_pred_cnt"], tot_metric_logs["ment_gold_cnt"])
|
||||||
|
print("seq num:", tot_seq_cnt, "hit cnt:", tot_metric_logs["ment_hit_cnt"], "pred cnt:", tot_metric_logs["ment_pred_cnt"], "gold cnt:", tot_metric_logs["ment_gold_cnt"])
|
||||||
|
print("avg hit:", tot_metric_logs["ment_hit_cnt"] / tot_seq_cnt, "avg pred:", tot_metric_logs["ment_pred_cnt"] / tot_seq_cnt, "avg gold:", tot_metric_logs["ment_gold_cnt"] / tot_seq_cnt)
|
||||||
|
ent_p, ent_r, ent_f1 = 0, 0, 0
|
||||||
|
model.train()
|
||||||
|
return eval_loss / eval_iter, ment_p, ment_r, ment_f1, ent_p, ent_r, ent_f1, tot_logs
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def train(self, model, training_args, device, trainloader, devloader, eval_log_fn=None, load_ckpt=None, ignore_log=True):
|
||||||
|
if load_ckpt is not None:
|
||||||
|
state_dict = self.__load_model__(load_ckpt)['state_dict']
|
||||||
|
own_state = model.state_dict()
|
||||||
|
for name, param in state_dict.items():
|
||||||
|
if name not in own_state:
|
||||||
|
print('[ERROR] Ignore {}'.format(name))
|
||||||
|
continue
|
||||||
|
own_state[name].copy_(param)
|
||||||
|
print("load ckpt from {}".format(load_ckpt))
|
||||||
|
# Init optimizer
|
||||||
|
print('Use bert optim!')
|
||||||
|
parameters_to_optimize = list(model.named_parameters())
|
||||||
|
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
|
||||||
|
|
||||||
|
parameters_groups = [
|
||||||
|
{'params': [p for n, p in parameters_to_optimize if ("bert." in n) and (not any(nd in n for nd in no_decay))],
|
||||||
|
'lr': training_args.bert_learning_rate, 'weight_decay': training_args.bert_weight_decay},
|
||||||
|
{'params': [p for n, p in parameters_to_optimize if ("bert." in n) and any(nd in n for nd in no_decay)],
|
||||||
|
'lr': training_args.bert_learning_rate, 'weight_decay': 0},
|
||||||
|
{'params': [p for n, p in parameters_to_optimize if "bert." not in n],
|
||||||
|
'lr': training_args.learning_rate, 'weight_decay': training_args.weight_decay}
|
||||||
|
]
|
||||||
|
optimizer = torch.optim.AdamW(parameters_groups)
|
||||||
|
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=training_args.warmup_step,
|
||||||
|
num_training_steps=training_args.train_iter)
|
||||||
|
model.train()
|
||||||
|
model.zero_grad()
|
||||||
|
|
||||||
|
best_f1 = 0.0
|
||||||
|
train_loss = 0.0
|
||||||
|
train_acc = 0
|
||||||
|
iter_sample = 0
|
||||||
|
tot_metric_logs = defaultdict(int)
|
||||||
|
it = 0
|
||||||
|
train_batchs = iter(trainloader)
|
||||||
|
for _ in range(training_args.train_iter):
|
||||||
|
it += 1
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
model.train()
|
||||||
|
batch = next(train_batchs)
|
||||||
|
batch['support'] = trainloader.dataset.batch_to_device(batch['support'], device)
|
||||||
|
batch['query'] = trainloader.dataset.batch_to_device(batch['query'], device)
|
||||||
|
|
||||||
|
if training_args.use_maml:
|
||||||
|
progress = 1.0 * (it - 1) / training_args.train_iter
|
||||||
|
lr_inner = self.get_learning_rate(
|
||||||
|
training_args.train_inner_lr, progress, training_args.warmup_prop_inner
|
||||||
|
)
|
||||||
|
res = model.forward_meta(batch, training_args.train_inner_steps, lr_inner, "train")
|
||||||
|
for g in res['grads']:
|
||||||
|
model.load_gradients(res['names'], g)
|
||||||
|
train_loss += res['loss']
|
||||||
|
else:
|
||||||
|
res = model.forward_sup(batch, "train")
|
||||||
|
loss = res['loss']
|
||||||
|
loss.backward()
|
||||||
|
train_loss += loss.item()
|
||||||
|
if training_args.max_grad_norm is not None:
|
||||||
|
torch.nn.utils.clip_grad_norm_(model.parameters(), training_args.max_grad_norm)
|
||||||
|
optimizer.step()
|
||||||
|
scheduler.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
pred = torch.cat(res['preds'], dim=0).detach().cpu().numpy()
|
||||||
|
gold = torch.cat(res['golds'], dim=0).detach().cpu().numpy()
|
||||||
|
|
||||||
|
acc = model.span_accuracy(pred, gold)
|
||||||
|
train_acc += acc
|
||||||
|
|
||||||
|
iter_sample += 1
|
||||||
|
if not ignore_log:
|
||||||
|
metric_logs, logs = model.seqment_eval(res["preds"], batch["query"], model.idx2label, model.schema)
|
||||||
|
|
||||||
|
for k, v in metric_logs.items():
|
||||||
|
tot_metric_logs[k] += v
|
||||||
|
|
||||||
|
|
||||||
|
if it % 100 == 0 or it % training_args.log_steps == 0:
|
||||||
|
if not ignore_log:
|
||||||
|
precision, recall, f1 = cal_prf(tot_metric_logs["ment_hit_cnt"], tot_metric_logs["ment_pred_cnt"],
|
||||||
|
tot_metric_logs["ment_gold_cnt"])
|
||||||
|
|
||||||
|
print('step: {0:4} | loss: {1:2.6f} | span acc {2:.5f} [ENTITY] precision: {3:3.4f}, recall: {4:3.4f}, f1: {5:3.4f}'\
|
||||||
|
.format(it, train_loss / iter_sample, train_acc / iter_sample, precision, recall, f1) + '\r')
|
||||||
|
else:
|
||||||
|
print('step: {0:4} | loss: {1:2.6f} | span acc {2:.5f}'
|
||||||
|
.format(it, train_loss / iter_sample, train_acc / iter_sample) + '\r')
|
||||||
|
train_loss = 0
|
||||||
|
train_acc = 0
|
||||||
|
iter_sample = 0
|
||||||
|
tot_metric_logs = defaultdict(int)
|
||||||
|
|
||||||
|
if it % training_args.val_steps == 0:
|
||||||
|
eval_loss, eval_ment_p, eval_ment_r, eval_ment_f1, eval_p, eval_r, eval_f1, eval_logs = self.eval(model, device, devloader, update_iter=training_args.eval_inner_steps, learning_rate=training_args.eval_inner_lr, eval_iter=training_args.dev_iter)
|
||||||
|
print('[EVAL] step: {0:4} | loss: {1:2.6f} | [MENTION] precision: {2:3.4f}, recall: {3:3.4f}, f1: {4:3.4f}'\
|
||||||
|
.format(it, eval_loss, eval_ment_p, eval_ment_r, eval_ment_f1) + '\r')
|
||||||
|
if eval_ment_f1 > best_f1:
|
||||||
|
print('Best checkpoint')
|
||||||
|
torch.save({'state_dict': model.state_dict()},
|
||||||
|
os.path.join(training_args.output_dir, "model.pth.tar"))
|
||||||
|
best_f1 = eval_ment_f1
|
||||||
|
if eval_log_fn is not None:
|
||||||
|
write_ment_log(devloader.dataset.samples, eval_logs, eval_log_fn)
|
||||||
|
if training_args.train_iter == it:
|
||||||
|
break
|
||||||
|
print("\n####################\n")
|
||||||
|
print("Finish training ")
|
||||||
|
return
|
|
@ -0,0 +1,185 @@
|
||||||
|
import json
|
||||||
|
import os, sys
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers import get_linear_schedule_with_warmup
|
||||||
|
from collections import defaultdict
|
||||||
|
from util.log_utils import write_pos_pred_json
|
||||||
|
|
||||||
|
class POSTrainer:
|
||||||
|
def __init__(self):
|
||||||
|
return
|
||||||
|
|
||||||
|
def __load_model__(self, ckpt):
|
||||||
|
if os.path.isfile(ckpt):
|
||||||
|
checkpoint = torch.load(ckpt)
|
||||||
|
print("Successfully loaded checkpoint '%s'" % ckpt)
|
||||||
|
return checkpoint
|
||||||
|
else:
|
||||||
|
raise Exception("No checkpoint found at '%s'" % ckpt)
|
||||||
|
|
||||||
|
def get_learning_rate(self, lr, progress, warmup, schedule="linear"):
|
||||||
|
if schedule == "linear":
|
||||||
|
if progress < warmup:
|
||||||
|
lr *= progress / warmup
|
||||||
|
else:
|
||||||
|
lr *= max((progress - 1.0) / (warmup - 1.0), 0.0)
|
||||||
|
return lr
|
||||||
|
|
||||||
|
def eval(self, model, device, dataloader, load_ckpt=None, update_iter=0, learning_rate=3e-5, eval_iter=-1, eval_mode="test"):
|
||||||
|
if load_ckpt:
|
||||||
|
state_dict = self.__load_model__(load_ckpt)['state_dict']
|
||||||
|
own_state = model.state_dict()
|
||||||
|
for name, param in state_dict.items():
|
||||||
|
if name not in own_state:
|
||||||
|
print('[ERROR] Ignore {}'.format(name))
|
||||||
|
continue
|
||||||
|
own_state[name].copy_(param)
|
||||||
|
model.eval()
|
||||||
|
eval_batchs = iter(dataloader)
|
||||||
|
tot_metric_logs = defaultdict(int)
|
||||||
|
eval_loss = 0
|
||||||
|
if eval_iter <= 0:
|
||||||
|
eval_iter = len(dataloader.dataset.sampler)
|
||||||
|
tot_seq_cnt = 0
|
||||||
|
tot_logs = {}
|
||||||
|
for batch_id in tqdm(range(eval_iter)):
|
||||||
|
batch = next(eval_batchs)
|
||||||
|
batch['support'] = dataloader.dataset.batch_to_device(batch['support'], device)
|
||||||
|
batch['query'] = dataloader.dataset.batch_to_device(batch['query'], device)
|
||||||
|
if update_iter > 0:
|
||||||
|
res = model.forward_meta(batch, update_iter, learning_rate, eval_mode)
|
||||||
|
eval_loss += res['loss']
|
||||||
|
else:
|
||||||
|
res = model.forward_proto(batch, eval_mode)
|
||||||
|
eval_loss += res['loss'].item()
|
||||||
|
metric_logs, logs = model.seq_eval(res['preds'], batch['query'])
|
||||||
|
tot_seq_cnt += batch["query"]["seq_len"].size(0)
|
||||||
|
logs["support_index"] = batch["support"]["index"]
|
||||||
|
logs["support_sentence_num"] = batch["support"]["sentence_num"]
|
||||||
|
logs["support_subsentence_num"] = batch["support"]["subsentence_num"]
|
||||||
|
|
||||||
|
for k, v in logs.items():
|
||||||
|
if k not in tot_logs:
|
||||||
|
tot_logs[k] = []
|
||||||
|
tot_logs[k] += v
|
||||||
|
for k, v in metric_logs.items():
|
||||||
|
tot_metric_logs[k] += v
|
||||||
|
token_acc = tot_metric_logs["hit_cnt"] / tot_metric_logs["gold_cnt"]
|
||||||
|
print("seq num:", tot_seq_cnt, "hit cnt:", tot_metric_logs["hit_cnt"], "gold cnt:", tot_metric_logs["gold_cnt"], "acc:", token_acc)
|
||||||
|
model.train()
|
||||||
|
return eval_loss / eval_iter, token_acc, tot_logs
|
||||||
|
|
||||||
|
def train(self, model, training_args, device, trainloader, devloader, load_ckpt=None, dev_pred_fn=None, dev_log_fn=None, ignore_log=True):
|
||||||
|
if load_ckpt is not None:
|
||||||
|
state_dict = self.__load_model__(load_ckpt)['state_dict']
|
||||||
|
own_state = model.state_dict()
|
||||||
|
for name, param in state_dict.items():
|
||||||
|
if name not in own_state:
|
||||||
|
print('[ERROR] Ignore {}'.format(name))
|
||||||
|
continue
|
||||||
|
own_state[name].copy_(param)
|
||||||
|
print("load ckpt from {}".format(load_ckpt))
|
||||||
|
|
||||||
|
# Init optimizer
|
||||||
|
print('Use bert optim!')
|
||||||
|
parameters_to_optimize = list(model.named_parameters())
|
||||||
|
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
|
||||||
|
|
||||||
|
parameters_groups = [
|
||||||
|
{'params': [p for n, p in parameters_to_optimize if ("bert." in n) and (not any(nd in n for nd in no_decay))],
|
||||||
|
'lr': training_args.bert_learning_rate, 'weight_decay': training_args.bert_weight_decay},
|
||||||
|
{'params': [p for n, p in parameters_to_optimize if ("bert." in n) and any(nd in n for nd in no_decay)],
|
||||||
|
'lr': training_args.bert_learning_rate, 'weight_decay': 0},
|
||||||
|
{'params': [p for n, p in parameters_to_optimize if "bert." not in n],
|
||||||
|
'lr': training_args.learning_rate, 'weight_decay': training_args.weight_decay}
|
||||||
|
]
|
||||||
|
optimizer = torch.optim.AdamW(parameters_groups)
|
||||||
|
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=training_args.warmup_step,
|
||||||
|
num_training_steps=training_args.train_iter)
|
||||||
|
model.train()
|
||||||
|
model.zero_grad()
|
||||||
|
|
||||||
|
best_f1 = -1
|
||||||
|
train_loss = 0.0
|
||||||
|
train_acc = 0
|
||||||
|
iter_sample = 0
|
||||||
|
tot_metric_logs = defaultdict(int)
|
||||||
|
it = 0
|
||||||
|
train_batchs = iter(trainloader)
|
||||||
|
for _ in range(training_args.train_iter):
|
||||||
|
it += 1
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
model.train()
|
||||||
|
batch = next(train_batchs)
|
||||||
|
batch['support'] = trainloader.dataset.batch_to_device(batch['support'], device)
|
||||||
|
batch['query'] = trainloader.dataset.batch_to_device(batch['query'], device)
|
||||||
|
|
||||||
|
if training_args.use_maml:
|
||||||
|
progress = 1.0 * (it - 1) / training_args.train_iter
|
||||||
|
lr_inner = self.get_learning_rate(
|
||||||
|
training_args.train_inner_lr, progress, training_args.warmup_prop_inner
|
||||||
|
)
|
||||||
|
res = model.forward_meta(batch, training_args.train_inner_steps, lr_inner, "train")
|
||||||
|
for g in res['grads']:
|
||||||
|
model.load_gradients(res['names'], g) # loss backward
|
||||||
|
train_loss += res['loss']
|
||||||
|
else:
|
||||||
|
res = model.forward_proto(batch, "train")
|
||||||
|
loss = res['loss']
|
||||||
|
loss.backward()
|
||||||
|
train_loss += loss.item()
|
||||||
|
|
||||||
|
if training_args.max_grad_norm is not None:
|
||||||
|
torch.nn.utils.clip_grad_norm_(model.parameters(), training_args.max_grad_norm)
|
||||||
|
|
||||||
|
optimizer.step()
|
||||||
|
scheduler.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
pred = torch.cat(res['preds'], dim=0).detach().cpu().numpy()
|
||||||
|
gold = torch.cat(res['golds'], dim=0).detach().cpu().numpy()
|
||||||
|
acc = model.span_accuracy(pred, gold)
|
||||||
|
train_acc += acc
|
||||||
|
iter_sample += 1
|
||||||
|
|
||||||
|
if not ignore_log:
|
||||||
|
metric_logs, logs = model.seq_eval(res["preds"], batch["query"], model.schema)
|
||||||
|
for k, v in metric_logs.items():
|
||||||
|
tot_metric_logs[k] += v
|
||||||
|
|
||||||
|
if it % 100 == 0 or it % training_args.log_steps == 0:
|
||||||
|
if not ignore_log:
|
||||||
|
acc = tot_metric_logs["hit_cnt"] / tot_metric_logs["gold_cnt"]
|
||||||
|
print('step: {0:4} | loss: {1:2.6f} | span acc {2:.5f} [Token] acc: {3:3.4f}'\
|
||||||
|
.format(it, train_loss / iter_sample, train_acc / iter_sample, acc) + '\r')
|
||||||
|
else:
|
||||||
|
print('step: {0:4} | loss: {1:2.6f} | span acc {2:.5f}'\
|
||||||
|
.format(it, train_loss / iter_sample, train_acc / iter_sample) + '\r')
|
||||||
|
train_loss = 0
|
||||||
|
train_acc = 0
|
||||||
|
iter_sample = 0
|
||||||
|
tot_metric_logs = defaultdict(int)
|
||||||
|
|
||||||
|
if it % training_args.val_steps == 0:
|
||||||
|
eval_loss, eval_acc, eval_logs = self.eval(model, device, devloader, eval_iter=training_args.dev_iter,
|
||||||
|
update_iter=training_args.eval_inner_steps, learning_rate=training_args.eval_inner_lr, eval_mode=training_args.train_eval_mode)
|
||||||
|
print('[EVAL] step: {0:4} | loss: {1:2.6f} | F1: {2:3.4f}'\
|
||||||
|
.format(it, eval_loss, eval_acc) + '\r')
|
||||||
|
|
||||||
|
if eval_acc > best_f1:
|
||||||
|
print('Best checkpoint')
|
||||||
|
torch.save({'state_dict': model.state_dict()},
|
||||||
|
os.path.join(training_args.output_dir, "model.pth.tar"))
|
||||||
|
best_f1 = eval_acc
|
||||||
|
if dev_pred_fn is not None:
|
||||||
|
write_pos_pred_json(devloader.dataset.samples, eval_logs, dev_pred_fn)
|
||||||
|
print("Best acc: {}".format(best_f1))
|
||||||
|
|
||||||
|
if it > 1003:
|
||||||
|
break
|
||||||
|
print("\n####################\n")
|
||||||
|
print("Finish training ")
|
||||||
|
return
|
|
@ -0,0 +1,204 @@
|
||||||
|
import json
|
||||||
|
import os, sys
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers import get_linear_schedule_with_warmup
|
||||||
|
from collections import defaultdict
|
||||||
|
from util.log_utils import write_ent_log, write_ent_pred_json, eval_ent_log, cal_prf
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
class SpanTrainer:
|
||||||
|
def __init__(self):
|
||||||
|
return
|
||||||
|
|
||||||
|
def __load_model__(self, ckpt):
|
||||||
|
if os.path.isfile(ckpt):
|
||||||
|
checkpoint = torch.load(ckpt)
|
||||||
|
print("Successfully loaded checkpoint '%s'" % ckpt)
|
||||||
|
return checkpoint
|
||||||
|
else:
|
||||||
|
raise Exception("No checkpoint found at '%s'" % ckpt)
|
||||||
|
|
||||||
|
def check_input_ment(self, query):
|
||||||
|
pred_cnt, gold_cnt, hit_cnt = 0, 0, 0
|
||||||
|
for i in range(len(query['spans'])):
|
||||||
|
pred_ments = query['span_indices'][i][query['span_mask'][i].eq(1)].detach().cpu().tolist()
|
||||||
|
gold_ments = [x[1:] for x in query['spans'][i]]
|
||||||
|
pred_cnt += len(pred_ments)
|
||||||
|
gold_cnt += len(gold_ments)
|
||||||
|
for x in gold_ments:
|
||||||
|
if x in pred_ments:
|
||||||
|
hit_cnt += 1
|
||||||
|
return pred_cnt, gold_cnt, hit_cnt
|
||||||
|
|
||||||
|
def get_learning_rate(self, lr, progress, warmup, schedule="linear"):
|
||||||
|
if schedule == "linear":
|
||||||
|
if progress < warmup:
|
||||||
|
lr *= progress / warmup
|
||||||
|
else:
|
||||||
|
lr *= max((progress - 1.0) / (warmup - 1.0), 0.0)
|
||||||
|
return lr
|
||||||
|
|
||||||
|
def eval(self, model, device, dataloader, load_ckpt=None, update_iter=0, learning_rate=3e-5, eval_iter=-1, overlap=False, threshold=-1):
|
||||||
|
if load_ckpt:
|
||||||
|
state_dict = self.__load_model__(load_ckpt)['state_dict']
|
||||||
|
own_state = model.state_dict()
|
||||||
|
for name, param in state_dict.items():
|
||||||
|
if name not in own_state:
|
||||||
|
print('[ERROR] Ignore {}'.format(name))
|
||||||
|
continue
|
||||||
|
own_state[name].copy_(param)
|
||||||
|
model.eval()
|
||||||
|
eval_batchs = iter(dataloader)
|
||||||
|
tot_metric_logs = defaultdict(int)
|
||||||
|
eval_loss = 0
|
||||||
|
if eval_iter <= 0:
|
||||||
|
eval_iter = len(dataloader.dataset.sampler)
|
||||||
|
tot_seq_cnt = 0
|
||||||
|
tot_logs = {}
|
||||||
|
for batch_id in range(eval_iter):
|
||||||
|
batch = next(eval_batchs)
|
||||||
|
|
||||||
|
input_pred_ment_cnt, input_gold_ment_cnt, input_hit_ment_cnt = self.check_input_ment(batch['query'])
|
||||||
|
tot_metric_logs['episode_query_ment_hit_cnt'] += input_hit_ment_cnt
|
||||||
|
tot_metric_logs['episode_query_ment_pred_cnt'] += input_pred_ment_cnt
|
||||||
|
tot_metric_logs['episode_query_ment_gold_cnt'] += input_gold_ment_cnt
|
||||||
|
|
||||||
|
batch['support'] = dataloader.dataset.batch_to_device(batch['support'], device)
|
||||||
|
batch['query'] = dataloader.dataset.batch_to_device(batch['query'], device)
|
||||||
|
if update_iter > 0:
|
||||||
|
res = model.forward_meta(batch, update_iter, learning_rate, "test")
|
||||||
|
eval_loss += res['loss']
|
||||||
|
else:
|
||||||
|
res = model.forward_proto(batch)
|
||||||
|
eval_loss += res['loss'].item()
|
||||||
|
metric_logs, logs = model.greedy_eval(res['logits'], batch['query'], overlap=overlap, threshold=threshold)
|
||||||
|
tot_seq_cnt += batch["query"]["seq_len"].size(0)
|
||||||
|
logs["support_index"] = batch["support"]["index"]
|
||||||
|
logs["support_sentence_num"] = batch["support"]["sentence_num"]
|
||||||
|
logs["support_subsentence_num"] = batch["support"]["subsentence_num"]
|
||||||
|
|
||||||
|
for k, v in logs.items():
|
||||||
|
if k not in tot_logs:
|
||||||
|
tot_logs[k] = []
|
||||||
|
tot_logs[k] += v
|
||||||
|
for k, v in metric_logs.items():
|
||||||
|
tot_metric_logs[k] += v
|
||||||
|
|
||||||
|
ment_p, ment_r, ment_f1 = cal_prf(tot_metric_logs["ment_hit_cnt"], tot_metric_logs["ment_pred_cnt"], tot_metric_logs["ment_gold_cnt"])
|
||||||
|
print("seq num:", tot_seq_cnt, "hit cnt:", tot_metric_logs["ent_hit_cnt"], "pred cnt:", tot_metric_logs["ent_pred_cnt"], "gold cnt:", tot_metric_logs["ent_gold_cnt"])
|
||||||
|
print(tot_metric_logs["ent_hit_cnt"] / tot_seq_cnt, tot_metric_logs["ent_pred_cnt"] / tot_seq_cnt, tot_metric_logs["ent_gold_cnt"] / tot_seq_cnt)
|
||||||
|
ent_p, ent_r, ent_f1 = cal_prf(tot_metric_logs["ent_hit_cnt"], tot_metric_logs["ent_pred_cnt"], tot_metric_logs["ent_gold_cnt"])
|
||||||
|
input_ment_p, input_ment_r, input_ment_f1 = cal_prf(tot_metric_logs["episode_query_ment_hit_cnt"], tot_metric_logs["episode_query_ment_pred_cnt"], tot_metric_logs["episode_query_ment_gold_cnt"])
|
||||||
|
print("episode based input mention precision {:.5f} recall {:.5f} f1 {:.5f}".format(input_ment_p, input_ment_r, input_ment_f1))
|
||||||
|
model.train()
|
||||||
|
return eval_loss / eval_iter, ment_p, ment_r, ment_f1, ent_p, ent_r, ent_f1, tot_logs
|
||||||
|
|
||||||
|
def train(self, model, training_args, device, trainloader, devloader, load_ckpt=None, dev_pred_fn=None, dev_log_fn=None):
|
||||||
|
if load_ckpt is not None:
|
||||||
|
state_dict = self.__load_model__(load_ckpt)['state_dict']
|
||||||
|
own_state = model.state_dict()
|
||||||
|
for name, param in state_dict.items():
|
||||||
|
if name not in own_state:
|
||||||
|
print('[ERROR] Ignore {}'.format(name))
|
||||||
|
continue
|
||||||
|
own_state[name].copy_(param)
|
||||||
|
print("load ckpt from {}".format(load_ckpt))
|
||||||
|
# Init optimizer
|
||||||
|
print('Use bert optim!')
|
||||||
|
parameters_to_optimize = list(model.named_parameters())
|
||||||
|
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
|
||||||
|
|
||||||
|
parameters_groups = [
|
||||||
|
{'params': [p for n, p in parameters_to_optimize if ("bert." in n) and (not any(nd in n for nd in no_decay))],
|
||||||
|
'lr': training_args.bert_learning_rate, 'weight_decay': training_args.bert_weight_decay},
|
||||||
|
{'params': [p for n, p in parameters_to_optimize if ("bert." in n) and any(nd in n for nd in no_decay)],
|
||||||
|
'lr': training_args.bert_learning_rate, 'weight_decay': 0},
|
||||||
|
{'params': [p for n, p in parameters_to_optimize if "bert." not in n],
|
||||||
|
'lr': training_args.learning_rate, 'weight_decay': training_args.weight_decay}
|
||||||
|
]
|
||||||
|
optimizer = torch.optim.AdamW(parameters_groups)
|
||||||
|
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=training_args.warmup_step,
|
||||||
|
num_training_steps=training_args.train_iter)
|
||||||
|
model.train()
|
||||||
|
model.zero_grad()
|
||||||
|
|
||||||
|
best_f1 = -1
|
||||||
|
train_loss = 0.0
|
||||||
|
train_acc = 0
|
||||||
|
iter_sample = 0
|
||||||
|
tot_metric_logs = defaultdict(int)
|
||||||
|
it = 0
|
||||||
|
train_batchs = iter(trainloader)
|
||||||
|
for _ in range(training_args.train_iter):
|
||||||
|
it += 1
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
model.train()
|
||||||
|
batch = next(train_batchs)
|
||||||
|
batch['support'] = trainloader.dataset.batch_to_device(batch['support'], device)
|
||||||
|
batch['query'] = trainloader.dataset.batch_to_device(batch['query'], device)
|
||||||
|
|
||||||
|
if training_args.use_maml:
|
||||||
|
progress = 1.0 * (it - 1) / training_args.train_iter
|
||||||
|
lr_inner = self.get_learning_rate(
|
||||||
|
training_args.train_inner_lr, progress, training_args.warmup_prop_inner
|
||||||
|
)
|
||||||
|
res = model.forward_meta(batch, training_args.train_inner_steps, lr_inner, "train")
|
||||||
|
for g in res['grads']:
|
||||||
|
model.load_gradients(res['names'], g) # loss backward
|
||||||
|
train_loss += res['loss']
|
||||||
|
else:
|
||||||
|
res = model.forward_proto(batch)
|
||||||
|
loss = res['loss']
|
||||||
|
loss.backward()
|
||||||
|
train_loss += loss.item()
|
||||||
|
|
||||||
|
if training_args.max_grad_norm is not None:
|
||||||
|
torch.nn.utils.clip_grad_norm_(model.parameters(), training_args.max_grad_norm)
|
||||||
|
|
||||||
|
optimizer.step()
|
||||||
|
scheduler.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
pred = torch.cat(res['preds'], dim=0).detach().cpu().numpy()
|
||||||
|
gold = torch.cat(res['golds'], dim=0).detach().cpu().numpy()
|
||||||
|
acc = model.span_accuracy(pred, gold)
|
||||||
|
train_acc += acc
|
||||||
|
|
||||||
|
iter_sample += 1
|
||||||
|
metric_logs, logs = model.greedy_eval(res["logits"], batch["query"], overlap=training_args.overlap, threshold=training_args.type_threshold)
|
||||||
|
|
||||||
|
for k, v in metric_logs.items():
|
||||||
|
tot_metric_logs[k] += v
|
||||||
|
|
||||||
|
if it % 100 == 0 or it % training_args.log_steps == 0:
|
||||||
|
precision, recall, f1 = cal_prf(tot_metric_logs["ent_hit_cnt"], tot_metric_logs["ent_pred_cnt"],
|
||||||
|
tot_metric_logs["ent_gold_cnt"])
|
||||||
|
print('step: {0:4} | loss: {1:2.6f} | span acc {2:.5f} [ENTITY] precision: {3:3.4f}, recall: {4:3.4f}, f1: {5:3.4f}'\
|
||||||
|
.format(it, train_loss / iter_sample, train_acc / iter_sample, precision, recall, f1) + '\r')
|
||||||
|
train_loss = 0
|
||||||
|
train_acc = 0
|
||||||
|
iter_sample = 0
|
||||||
|
tot_metric_logs = defaultdict(int)
|
||||||
|
|
||||||
|
if it % training_args.val_steps == 0:
|
||||||
|
eval_loss, eval_ment_p, eval_ment_r, eval_ment_f1, eval_p, eval_r, eval_f1, eval_logs = self.eval(model, device, devloader, eval_iter=training_args.dev_iter,
|
||||||
|
update_iter=training_args.eval_inner_steps, learning_rate=training_args.eval_inner_lr, overlap=training_args.overlap, threshold=training_args.type_threshold)
|
||||||
|
print('[EVAL] step: {0:4} | loss: {1:2.6f} | [MENTION] precision: {2:3.4f}, recall: {3:3.4f}, f1: {4:3.4f} [ENTITY] precision: {5:3.4f}, recall: {6:3.4f}, f1: {7:3.4f}'\
|
||||||
|
.format(it, eval_loss, eval_ment_p, eval_ment_r, eval_ment_f1, eval_p, eval_r, eval_f1) + '\r')
|
||||||
|
if eval_f1 > best_f1:
|
||||||
|
print('Best checkpoint')
|
||||||
|
torch.save({'state_dict': model.state_dict()},
|
||||||
|
os.path.join(training_args.output_dir, "model.pth.tar"))
|
||||||
|
best_f1 = eval_f1
|
||||||
|
if dev_pred_fn is not None:
|
||||||
|
write_ent_pred_json(devloader.dataset.samples, eval_logs, dev_pred_fn)
|
||||||
|
if dev_log_fn is not None:
|
||||||
|
write_ent_log(devloader.dataset.samples, eval_logs, dev_log_fn)
|
||||||
|
eval_ent_log(devloader.dataset.samples, eval_logs)
|
||||||
|
|
||||||
|
print("\n####################\n")
|
||||||
|
print("Finish training ")
|
||||||
|
return
|
|
@ -0,0 +1,30 @@
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
class ProjAdapter(nn.Module):
|
||||||
|
def __init__(self, in_size, hidden_size):
|
||||||
|
super().__init__()
|
||||||
|
self.w0 = nn.Linear(in_size, hidden_size, bias=True)
|
||||||
|
self.w1 = nn.Linear(hidden_size, in_size, bias=True)
|
||||||
|
self.act = nn.GELU()
|
||||||
|
return
|
||||||
|
|
||||||
|
def forward(self, input_hidden_states):
|
||||||
|
hidden_states = self.act(self.w0(input_hidden_states))
|
||||||
|
hidden_states = self.w1(hidden_states)
|
||||||
|
return hidden_states + input_hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class AdapterStack(nn.Module):
|
||||||
|
def __init__(self, adapter_hidden_size, num_hidden_layers=12):
|
||||||
|
super().__init__()
|
||||||
|
self.slf_list = nn.ModuleList([ProjAdapter(768, adapter_hidden_size) for _ in range(num_hidden_layers)])
|
||||||
|
self.ffn_list = nn.ModuleList([ProjAdapter(768, adapter_hidden_size) for _ in range(num_hidden_layers)])
|
||||||
|
return
|
||||||
|
|
||||||
|
def forward(self, inputs, layer_id, name):
|
||||||
|
if name == "slf":
|
||||||
|
outs = self.slf_list[layer_id](inputs)
|
||||||
|
else:
|
||||||
|
assert name == "ffn"
|
||||||
|
outs = self.ffn_list[layer_id](inputs)
|
||||||
|
return outs
|
|
@ -0,0 +1,783 @@
|
||||||
|
'''
|
||||||
|
The code is adapted from https://github.com/huggingface/transformers/blob/v4.2.1/src/transformers/models/bert/modeling_bert.py
|
||||||
|
'''
|
||||||
|
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions, BaseModelOutputWithPastAndCrossAttentions
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import warnings
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
from packaging import version
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
|
|
||||||
|
from transformers import BertPreTrainedModel
|
||||||
|
from transformers.modeling_utils import PreTrainedModel
|
||||||
|
|
||||||
|
#####################for 4.2.1##########################
|
||||||
|
from transformers.modeling_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
||||||
|
# from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
||||||
|
|
||||||
|
from transformers.file_utils import (
|
||||||
|
ModelOutput,
|
||||||
|
add_code_sample_docstrings,
|
||||||
|
add_start_docstrings,
|
||||||
|
add_start_docstrings_to_model_forward,
|
||||||
|
logging,
|
||||||
|
replace_return_docstrings,
|
||||||
|
)
|
||||||
|
from transformers import BertConfig
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
_CHECKPOINT_FOR_DOC = "bert-base-uncased"
|
||||||
|
_CONFIG_FOR_DOC = "BertConfig"
|
||||||
|
_TOKENIZER_FOR_DOC = "BertTokenizer"
|
||||||
|
|
||||||
|
# TokenClassification docstring
|
||||||
|
_CHECKPOINT_FOR_TOKEN_CLASSIFICATION = "dbmdz/bert-large-cased-finetuned-conll03-english"
|
||||||
|
_TOKEN_CLASS_EXPECTED_OUTPUT = (
|
||||||
|
"['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC'] "
|
||||||
|
)
|
||||||
|
_TOKEN_CLASS_EXPECTED_LOSS = 0.01
|
||||||
|
|
||||||
|
# QuestionAnswering docstring
|
||||||
|
_CHECKPOINT_FOR_QA = "deepset/bert-base-cased-squad2"
|
||||||
|
_QA_EXPECTED_OUTPUT = "'a nice puppet'"
|
||||||
|
_QA_EXPECTED_LOSS = 7.41
|
||||||
|
_QA_TARGET_START_INDEX = 14
|
||||||
|
_QA_TARGET_END_INDEX = 15
|
||||||
|
|
||||||
|
# SequenceClassification docstring
|
||||||
|
_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "textattack/bert-base-uncased-yelp-polarity"
|
||||||
|
_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"
|
||||||
|
_SEQ_CLASS_EXPECTED_LOSS = 0.01
|
||||||
|
|
||||||
|
|
||||||
|
BERT_START_DOCSTRING = r"""
|
||||||
|
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||||
|
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||||
|
etc.)
|
||||||
|
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||||
|
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||||||
|
and behavior.
|
||||||
|
Parameters:
|
||||||
|
config ([`BertConfig`]): Model configuration class with all the parameters of the model.
|
||||||
|
Initializing with a config file does not load the weights associated with the model, only the
|
||||||
|
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||||
|
"""
|
||||||
|
|
||||||
|
BERT_INPUTS_DOCSTRING = r"""
|
||||||
|
Args:
|
||||||
|
input_ids (`torch.LongTensor` of shape `({0})`):
|
||||||
|
Indices of input sequence tokens in the vocabulary.
|
||||||
|
Indices can be obtained using [`BertTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||||
|
[`PreTrainedTokenizer.__call__`] for details.
|
||||||
|
[What are input IDs?](../glossary#input-ids)
|
||||||
|
attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
|
||||||
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
[What are attention masks?](../glossary#attention-mask)
|
||||||
|
token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
|
||||||
|
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
|
||||||
|
1]`:
|
||||||
|
- 0 corresponds to a *sentence A* token,
|
||||||
|
- 1 corresponds to a *sentence B* token.
|
||||||
|
[What are token type IDs?](../glossary#token-type-ids)
|
||||||
|
position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
|
||||||
|
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||||||
|
config.max_position_embeddings - 1]`.
|
||||||
|
[What are position IDs?](../glossary#position-ids)
|
||||||
|
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
||||||
|
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
|
||||||
|
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
||||||
|
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
||||||
|
model's internal embedding lookup matrix.
|
||||||
|
output_attentions (`bool`, *optional*):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||||
|
tensors for more detail.
|
||||||
|
output_hidden_states (`bool`, *optional*):
|
||||||
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||||
|
more detail.
|
||||||
|
return_dict (`bool`, *optional*):
|
||||||
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class BertEmbeddings(nn.Module):
|
||||||
|
"""Construct the embeddings from word, position and token_type embeddings."""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
||||||
|
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
||||||
|
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
|
||||||
|
|
||||||
|
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
||||||
|
# any TensorFlow checkpoint file
|
||||||
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
|
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||||
|
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||||
|
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
|
||||||
|
):
|
||||||
|
if input_ids is not None:
|
||||||
|
input_shape = input_ids.size()
|
||||||
|
else:
|
||||||
|
input_shape = inputs_embeds.size()[:-1]
|
||||||
|
|
||||||
|
seq_length = input_shape[1]
|
||||||
|
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
||||||
|
|
||||||
|
if token_type_ids is None:
|
||||||
|
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.word_embeddings(input_ids)
|
||||||
|
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
||||||
|
|
||||||
|
embeddings = inputs_embeds + token_type_embeddings
|
||||||
|
if self.position_embedding_type == "absolute":
|
||||||
|
position_embeddings = self.position_embeddings(position_ids)
|
||||||
|
embeddings += position_embeddings
|
||||||
|
embeddings = self.LayerNorm(embeddings)
|
||||||
|
embeddings = self.dropout(embeddings)
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
|
class BertPooler(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
|
self.activation = nn.Tanh()
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
# We "pool" the model by simply taking the hidden state corresponding
|
||||||
|
# to the first token.
|
||||||
|
first_token_tensor = hidden_states[:, 0]
|
||||||
|
pooled_output = self.dense(first_token_tensor)
|
||||||
|
pooled_output = self.activation(pooled_output)
|
||||||
|
return pooled_output
|
||||||
|
|
||||||
|
class BertSelfAttention(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||||
|
raise ValueError(
|
||||||
|
"The hidden size (%d) is not a multiple of the number of attention "
|
||||||
|
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.num_attention_heads = config.num_attention_heads
|
||||||
|
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||||
|
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||||
|
|
||||||
|
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
||||||
|
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
||||||
|
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||||
|
|
||||||
|
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||||
|
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||||
|
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||||
|
self.max_position_embeddings = config.max_position_embeddings
|
||||||
|
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||||
|
|
||||||
|
self.is_decoder = config.is_decoder
|
||||||
|
|
||||||
|
def transpose_for_scores(self, x):
|
||||||
|
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||||
|
x = x.view(*new_x_shape)
|
||||||
|
return x.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
past_key_value=None,
|
||||||
|
output_attentions=False,
|
||||||
|
):
|
||||||
|
mixed_query_layer = self.query(hidden_states)
|
||||||
|
|
||||||
|
# If this is instantiated as a cross-attention module, the keys
|
||||||
|
# and values come from an encoder; the attention mask needs to be
|
||||||
|
# such that the encoder's padding tokens are not attended to.
|
||||||
|
is_cross_attention = encoder_hidden_states is not None
|
||||||
|
|
||||||
|
if is_cross_attention and past_key_value is not None:
|
||||||
|
# reuse k,v, cross_attentions
|
||||||
|
key_layer = past_key_value[0]
|
||||||
|
value_layer = past_key_value[1]
|
||||||
|
attention_mask = encoder_attention_mask
|
||||||
|
elif is_cross_attention:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
||||||
|
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
||||||
|
attention_mask = encoder_attention_mask
|
||||||
|
elif past_key_value is not None:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||||
|
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||||
|
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
||||||
|
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
||||||
|
else:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||||
|
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||||
|
|
||||||
|
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||||
|
|
||||||
|
if self.is_decoder:
|
||||||
|
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||||
|
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||||
|
# key/value_states (first "if" case)
|
||||||
|
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||||
|
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||||
|
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||||
|
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||||
|
past_key_value = (key_layer, value_layer)
|
||||||
|
|
||||||
|
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||||
|
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||||
|
|
||||||
|
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||||
|
seq_length = hidden_states.size()[1]
|
||||||
|
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
||||||
|
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
||||||
|
distance = position_ids_l - position_ids_r
|
||||||
|
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
||||||
|
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
||||||
|
|
||||||
|
if self.position_embedding_type == "relative_key":
|
||||||
|
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||||
|
attention_scores = attention_scores + relative_position_scores
|
||||||
|
elif self.position_embedding_type == "relative_key_query":
|
||||||
|
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||||
|
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
||||||
|
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
||||||
|
|
||||||
|
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||||
|
if attention_mask is not None:
|
||||||
|
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
||||||
|
attention_scores = attention_scores + attention_mask
|
||||||
|
|
||||||
|
# Normalize the attention scores to probabilities.
|
||||||
|
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
||||||
|
|
||||||
|
# This is actually dropping out entire tokens to attend to, which might
|
||||||
|
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||||
|
attention_probs = self.dropout(attention_probs)
|
||||||
|
|
||||||
|
# Mask heads if we want to
|
||||||
|
if head_mask is not None:
|
||||||
|
attention_probs = attention_probs * head_mask
|
||||||
|
|
||||||
|
context_layer = torch.matmul(attention_probs, value_layer)
|
||||||
|
|
||||||
|
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||||
|
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||||
|
context_layer = context_layer.view(*new_context_layer_shape)
|
||||||
|
|
||||||
|
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||||
|
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = outputs + (past_key_value,)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
class BertSelfOutput(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, adapters=None, adapter_id=None) -> torch.Tensor:
|
||||||
|
hidden_states = self.dense(hidden_states)
|
||||||
|
hidden_states = self.dropout(hidden_states)
|
||||||
|
if (adapters is not None):
|
||||||
|
hidden_states = adapters(hidden_states, layer_id=adapter_id, name="slf")
|
||||||
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class BertAttention(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.self = BertSelfAttention(config)
|
||||||
|
self.output = BertSelfOutput(config)
|
||||||
|
self.pruned_heads = set()
|
||||||
|
|
||||||
|
def prune_heads(self, heads):
|
||||||
|
if len(heads) == 0:
|
||||||
|
return
|
||||||
|
heads, index = find_pruneable_heads_and_indices(
|
||||||
|
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prune linear layers
|
||||||
|
self.self.query = prune_linear_layer(self.self.query, index)
|
||||||
|
self.self.key = prune_linear_layer(self.self.key, index)
|
||||||
|
self.self.value = prune_linear_layer(self.self.value, index)
|
||||||
|
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
||||||
|
|
||||||
|
# Update hyper params and store pruned heads
|
||||||
|
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
||||||
|
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||||
|
self.pruned_heads = self.pruned_heads.union(heads)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
past_key_value=None,
|
||||||
|
output_attentions=False,
|
||||||
|
adapters=None, adapter_id=None
|
||||||
|
) -> Tuple[torch.Tensor]:
|
||||||
|
self_outputs = self.self(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
head_mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
past_key_value,
|
||||||
|
output_attentions,
|
||||||
|
)
|
||||||
|
attention_output = self.output(self_outputs[0], hidden_states, adapters=adapters, adapter_id=adapter_id)
|
||||||
|
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
class BertIntermediate(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||||
|
if isinstance(config.hidden_act, str):
|
||||||
|
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
||||||
|
else:
|
||||||
|
self.intermediate_act_fn = config.hidden_act
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
hidden_states = self.dense(hidden_states)
|
||||||
|
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
class BertOutput(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||||
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, adapters=None, adapter_id=None) -> torch.Tensor:
|
||||||
|
hidden_states = self.dense(hidden_states)
|
||||||
|
hidden_states = self.dropout(hidden_states)
|
||||||
|
if (adapters is not None):
|
||||||
|
hidden_states = adapters(hidden_states, layer_id=adapter_id, name="ffn")
|
||||||
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
class BertLayer(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||||
|
self.seq_len_dim = 1
|
||||||
|
self.attention = BertAttention(config)
|
||||||
|
self.is_decoder = config.is_decoder
|
||||||
|
self.add_cross_attention = config.add_cross_attention
|
||||||
|
if self.add_cross_attention:
|
||||||
|
assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
|
||||||
|
self.crossattention = BertAttention(config)
|
||||||
|
self.intermediate = BertIntermediate(config)
|
||||||
|
self.output = BertOutput(config)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
past_key_value=None,
|
||||||
|
output_attentions=False,
|
||||||
|
adapters=None,
|
||||||
|
adapter_id=None,
|
||||||
|
) -> Tuple[torch.Tensor]:
|
||||||
|
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||||
|
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||||
|
self_attention_outputs = self.attention(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
head_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
past_key_value=self_attn_past_key_value,
|
||||||
|
adapters=adapters,
|
||||||
|
adapter_id=adapter_id
|
||||||
|
)
|
||||||
|
attention_output = self_attention_outputs[0]
|
||||||
|
# if decoder, the last output is tuple of self-attn cache
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = self_attention_outputs[1:-1]
|
||||||
|
present_key_value = self_attention_outputs[-1]
|
||||||
|
else:
|
||||||
|
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||||
|
|
||||||
|
cross_attn_present_key_value = None
|
||||||
|
if self.is_decoder and encoder_hidden_states is not None:
|
||||||
|
assert hasattr(
|
||||||
|
self, "crossattention"
|
||||||
|
), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
|
||||||
|
|
||||||
|
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||||
|
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||||
|
cross_attention_outputs = self.crossattention(
|
||||||
|
attention_output,
|
||||||
|
attention_mask,
|
||||||
|
head_mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
cross_attn_past_key_value,
|
||||||
|
output_attentions,
|
||||||
|
)
|
||||||
|
attention_output = cross_attention_outputs[0]
|
||||||
|
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||||
|
|
||||||
|
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||||
|
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||||
|
present_key_value = present_key_value + cross_attn_present_key_value
|
||||||
|
layer_output = apply_chunking_to_forward(
|
||||||
|
lambda x: self.feed_forward_chunk(x, adapters, adapter_id), self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
||||||
|
)
|
||||||
|
outputs = (layer_output,) + outputs
|
||||||
|
|
||||||
|
# if decoder, return the attn key/values as the last output
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = outputs + (present_key_value,)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def feed_forward_chunk(self, attention_output, adapters=None, adapter_id=None):
|
||||||
|
intermediate_output = self.intermediate(attention_output)
|
||||||
|
layer_output = self.output(intermediate_output, attention_output, adapters=adapters, adapter_id=adapter_id)
|
||||||
|
return layer_output
|
||||||
|
|
||||||
|
|
||||||
|
class BertEncoder(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
|
||||||
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = False,
|
||||||
|
output_hidden_states: Optional[bool] = False,
|
||||||
|
return_dict: Optional[bool] = True,
|
||||||
|
adapters=None,
|
||||||
|
adapter_layer_ids=None,
|
||||||
|
input_bottom_hiddens=None
|
||||||
|
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
all_self_attentions = () if output_attentions else None
|
||||||
|
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||||
|
|
||||||
|
if input_bottom_hiddens is None:
|
||||||
|
start_layer_i = 0
|
||||||
|
else:
|
||||||
|
start_layer_i = adapter_layer_ids[0]
|
||||||
|
hidden_states = input_bottom_hiddens
|
||||||
|
|
||||||
|
next_decoder_cache = () if use_cache else None
|
||||||
|
for i, layer_module in enumerate(self.layer):
|
||||||
|
if i < start_layer_i:
|
||||||
|
continue
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
if adapter_layer_ids is None:
|
||||||
|
cur_adapters = adapters
|
||||||
|
cur_i = i
|
||||||
|
elif i not in adapter_layer_ids:
|
||||||
|
cur_adapters = None
|
||||||
|
cur_i = i
|
||||||
|
else:
|
||||||
|
cur_adapters = adapters
|
||||||
|
cur_i = adapter_layer_ids.index(i)
|
||||||
|
|
||||||
|
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||||
|
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
logger.warning(
|
||||||
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||||
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
return module(*inputs, past_key_value, output_attentions)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(layer_module),
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
cur_adapters,
|
||||||
|
cur_i
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
layer_outputs = layer_module(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
past_key_value,
|
||||||
|
output_attentions,
|
||||||
|
cur_adapters,
|
||||||
|
cur_i
|
||||||
|
)
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
if use_cache:
|
||||||
|
next_decoder_cache += (layer_outputs[-1],)
|
||||||
|
if output_attentions:
|
||||||
|
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||||
|
if self.config.add_cross_attention:
|
||||||
|
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
||||||
|
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(
|
||||||
|
v
|
||||||
|
for v in [
|
||||||
|
hidden_states,
|
||||||
|
next_decoder_cache,
|
||||||
|
all_hidden_states,
|
||||||
|
all_self_attentions,
|
||||||
|
all_cross_attentions,
|
||||||
|
]
|
||||||
|
if v is not None
|
||||||
|
)
|
||||||
|
return BaseModelOutputWithPastAndCrossAttentions(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=next_decoder_cache,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attentions,
|
||||||
|
cross_attentions=all_cross_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
|
||||||
|
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
|
||||||
|
BERT_START_DOCSTRING,
|
||||||
|
)
|
||||||
|
class BertModel(BertPreTrainedModel):
|
||||||
|
"""
|
||||||
|
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
||||||
|
cross-attention is added between the self-attention layers, following the architecture described in [Attention is
|
||||||
|
all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
|
||||||
|
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
||||||
|
To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
|
||||||
|
to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
|
||||||
|
`add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config, add_pooling_layer=True):
|
||||||
|
super().__init__(config)
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
self.embeddings = BertEmbeddings(config)
|
||||||
|
self.encoder = BertEncoder(config)
|
||||||
|
|
||||||
|
self.pooler = BertPooler(config) if add_pooling_layer else None
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.init_weights()
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.embeddings.word_embeddings
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.embeddings.word_embeddings = value
|
||||||
|
|
||||||
|
def _prune_heads(self, heads_to_prune):
|
||||||
|
"""
|
||||||
|
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
||||||
|
class PreTrainedModel
|
||||||
|
"""
|
||||||
|
for layer, heads in heads_to_prune.items():
|
||||||
|
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
|
@add_code_sample_docstrings(
|
||||||
|
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||||
|
checkpoint="bert-base-uncased",
|
||||||
|
output_type=BaseModelOutputWithPoolingAndCrossAttentions,
|
||||||
|
config_class=_CONFIG_FOR_DOC,
|
||||||
|
)
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
past_key_values=None,
|
||||||
|
use_cache=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
return_dict=None,
|
||||||
|
adapters=None,
|
||||||
|
adapter_layer_ids=None,
|
||||||
|
input_bottom_hiddens=None,
|
||||||
|
):
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
if self.config.is_decoder:
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
else:
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
|
elif input_ids is not None:
|
||||||
|
input_shape = input_ids.size()
|
||||||
|
batch_size, seq_length = input_shape
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
input_shape = inputs_embeds.size()[:-1]
|
||||||
|
batch_size, seq_length = input_shape
|
||||||
|
elif input_bottom_hiddens is not None:
|
||||||
|
input_shape = input_bottom_hiddens.size()[:-1]
|
||||||
|
batch_size, seq_length = input_shape
|
||||||
|
else:
|
||||||
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
|
||||||
|
# past_key_values_length
|
||||||
|
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||||
|
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
||||||
|
if token_type_ids is None:
|
||||||
|
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||||
|
|
||||||
|
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
||||||
|
|
||||||
|
if self.config.is_decoder and encoder_hidden_states is not None:
|
||||||
|
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||||
|
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||||
|
if encoder_attention_mask is None:
|
||||||
|
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||||
|
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||||
|
else:
|
||||||
|
encoder_extended_attention_mask = None
|
||||||
|
|
||||||
|
|
||||||
|
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||||
|
|
||||||
|
if input_bottom_hiddens is None:
|
||||||
|
embedding_output = self.embeddings(
|
||||||
|
input_ids=input_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
past_key_values_length=past_key_values_length,
|
||||||
|
)
|
||||||
|
encoder_outputs = self.encoder(
|
||||||
|
embedding_output,
|
||||||
|
attention_mask=extended_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_extended_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
adapters=adapters,
|
||||||
|
adapter_layer_ids=adapter_layer_ids,
|
||||||
|
)
|
||||||
|
sequence_output = encoder_outputs[0]
|
||||||
|
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||||
|
|
||||||
|
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||||
|
last_hidden_state=sequence_output,
|
||||||
|
pooler_output=pooled_output,
|
||||||
|
past_key_values=encoder_outputs.past_key_values,
|
||||||
|
hidden_states=encoder_outputs.hidden_states,
|
||||||
|
attentions=encoder_outputs.attentions,
|
||||||
|
cross_attentions=encoder_outputs.cross_attentions,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
encoder_outputs = self.encoder(
|
||||||
|
None,
|
||||||
|
attention_mask=extended_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_extended_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
adapters=adapters,
|
||||||
|
adapter_layer_ids=adapter_layer_ids,
|
||||||
|
input_bottom_hiddens=input_bottom_hiddens
|
||||||
|
)
|
||||||
|
sequence_output = encoder_outputs[0]
|
||||||
|
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||||
|
|
||||||
|
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||||
|
last_hidden_state=sequence_output,
|
||||||
|
pooler_output=pooled_output,
|
||||||
|
past_key_values=encoder_outputs.past_key_values,
|
||||||
|
hidden_states=encoder_outputs.hidden_states,
|
||||||
|
attentions=encoder_outputs.attentions,
|
||||||
|
cross_attentions=encoder_outputs.cross_attentions,
|
||||||
|
)
|
|
@ -0,0 +1,161 @@
|
||||||
|
|
||||||
|
'''
|
||||||
|
The code is adapted from https://github.com/thunlp/Few-NERD/blob/main/util/fewshotsampler.py
|
||||||
|
'''
|
||||||
|
import random
|
||||||
|
import numpy as np
|
||||||
|
import json
|
||||||
|
|
||||||
|
class FewshotSampleBase:
|
||||||
|
'''
|
||||||
|
Abstract Class
|
||||||
|
DO NOT USE
|
||||||
|
Build your own Sample class and inherit from this class
|
||||||
|
'''
|
||||||
|
def __init__(self):
|
||||||
|
self.class_count = {}
|
||||||
|
|
||||||
|
def get_class_count(self):
|
||||||
|
'''
|
||||||
|
return a dictionary of {class_name:count} in format {any : int}
|
||||||
|
'''
|
||||||
|
return self.class_count
|
||||||
|
|
||||||
|
|
||||||
|
class FewshotSampler:
|
||||||
|
'''
|
||||||
|
sample one support set and one query set
|
||||||
|
'''
|
||||||
|
def __init__(self, N, K, Q, samples, classes=None, random_state=0):
|
||||||
|
'''
|
||||||
|
N: int, how many types in each set
|
||||||
|
K: int, how many instances for each type in support set
|
||||||
|
Q: int, how many instances for each type in query set
|
||||||
|
samples: List[Sample], Sample class must have `get_class_count` attribute
|
||||||
|
classes[Optional]: List[any], all unique classes in samples. If not given, the classes will be got from samples.get_class_count()
|
||||||
|
random_state[Optional]: int, the random seed
|
||||||
|
'''
|
||||||
|
self.K = K
|
||||||
|
self.N = N
|
||||||
|
self.Q = Q
|
||||||
|
self.samples = samples
|
||||||
|
self.__check__() # check if samples have correct types
|
||||||
|
if classes:
|
||||||
|
self.classes = classes
|
||||||
|
else:
|
||||||
|
self.classes = self.__get_all_classes__()
|
||||||
|
random.seed(random_state)
|
||||||
|
|
||||||
|
def __get_all_classes__(self):
|
||||||
|
classes = []
|
||||||
|
for sample in self.samples:
|
||||||
|
classes += list(sample.get_class_count().keys())
|
||||||
|
return list(set(classes))
|
||||||
|
|
||||||
|
def __check__(self):
|
||||||
|
for idx, sample in enumerate(self.samples):
|
||||||
|
if not hasattr(sample,'get_class_count'):
|
||||||
|
print('[ERROR] samples in self.samples expected to have `get_class_count` attribute, but self.samples[{idx}] does not')
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
def __additem__(self, index, set_class):
|
||||||
|
class_count = self.samples[index].get_class_count()
|
||||||
|
for class_name in class_count:
|
||||||
|
if class_name in set_class:
|
||||||
|
set_class[class_name] += class_count[class_name]
|
||||||
|
else:
|
||||||
|
set_class[class_name] = class_count[class_name]
|
||||||
|
|
||||||
|
def __valid_sample__(self, sample, set_class, target_classes):
|
||||||
|
threshold = 2 * set_class['k']
|
||||||
|
class_count = sample.get_class_count()
|
||||||
|
if not class_count:
|
||||||
|
return False
|
||||||
|
isvalid = False
|
||||||
|
for class_name in class_count:
|
||||||
|
if class_name not in target_classes:
|
||||||
|
return False
|
||||||
|
if class_count[class_name] + set_class.get(class_name, 0) > threshold:
|
||||||
|
return False
|
||||||
|
if set_class.get(class_name, 0) < set_class['k']:
|
||||||
|
isvalid = True
|
||||||
|
return isvalid
|
||||||
|
|
||||||
|
def __finish__(self, set_class):
|
||||||
|
if len(set_class) < self.N+1:
|
||||||
|
return False
|
||||||
|
for k in set_class:
|
||||||
|
if set_class[k] < set_class['k']:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def __get_candidates__(self, target_classes):
|
||||||
|
return [idx for idx, sample in enumerate(self.samples) if sample.valid(target_classes)]
|
||||||
|
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
'''
|
||||||
|
randomly sample one support set and one query set
|
||||||
|
return:
|
||||||
|
target_classes: List[any]
|
||||||
|
support_idx: List[int], sample index in support set in samples list
|
||||||
|
support_idx: List[int], sample index in query set in samples list
|
||||||
|
'''
|
||||||
|
support_class = {'k':self.K}
|
||||||
|
support_idx = []
|
||||||
|
query_class = {'k':self.Q}
|
||||||
|
query_idx = []
|
||||||
|
target_classes = random.sample(self.classes, self.N)
|
||||||
|
candidates = self.__get_candidates__(target_classes)
|
||||||
|
while not candidates:
|
||||||
|
target_classes = random.sample(self.classes, self.N)
|
||||||
|
candidates = self.__get_candidates__(target_classes)
|
||||||
|
|
||||||
|
# greedy search for support set
|
||||||
|
while not self.__finish__(support_class):
|
||||||
|
index = random.choice(candidates)
|
||||||
|
if index not in support_idx:
|
||||||
|
if self.__valid_sample__(self.samples[index], support_class, target_classes):
|
||||||
|
self.__additem__(index, support_class)
|
||||||
|
support_idx.append(index)
|
||||||
|
# same for query set
|
||||||
|
while not self.__finish__(query_class):
|
||||||
|
index = random.choice(candidates)
|
||||||
|
if index not in query_idx and index not in support_idx:
|
||||||
|
if self.__valid_sample__(self.samples[index], query_class, target_classes):
|
||||||
|
self.__additem__(index, query_class)
|
||||||
|
query_idx.append(index)
|
||||||
|
return target_classes, support_idx, query_idx
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
class DebugSampler:
|
||||||
|
def __init__(self, filepath, querypath=None):
|
||||||
|
data = []
|
||||||
|
with open(filepath, mode="r", encoding="utf-8") as fp:
|
||||||
|
for line in fp:
|
||||||
|
data.append(json.loads(line))
|
||||||
|
self.data = data
|
||||||
|
if querypath is not None:
|
||||||
|
with open(querypath, mode="r", encoding="utf-8") as fp:
|
||||||
|
self.ment_data = [json.loads(line) for line in fp]
|
||||||
|
if len(self.ment_data) != len(self.data):
|
||||||
|
print("the mention data len is different with input episode number!!!!!!!!!!!!!!!")
|
||||||
|
else:
|
||||||
|
self.ment_data = None
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def __next__(self, idx):
|
||||||
|
idx = idx % len(self.data)
|
||||||
|
target_classes = self.data[idx]["target_classes"]
|
||||||
|
support_idx = self.data[idx]["support_idx"]
|
||||||
|
query_idx = self.data[idx]["query_idx"]
|
||||||
|
if self.ment_data is None:
|
||||||
|
return target_classes, support_idx, query_idx
|
||||||
|
return target_classes, support_idx, query_idx, self.ment_data[idx]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.data)
|
||||||
|
|
|
@ -0,0 +1,351 @@
|
||||||
|
import json
|
||||||
|
import torch
|
||||||
|
import torch.utils.data as data
|
||||||
|
import os
|
||||||
|
from .fewshotsampler import DebugSampler, FewshotSampleBase
|
||||||
|
import numpy as np
|
||||||
|
import random
|
||||||
|
from collections import defaultdict
|
||||||
|
from .span_sample import SpanSample
|
||||||
|
from .span_loader import QuerySpanBatcher
|
||||||
|
|
||||||
|
|
||||||
|
class JointBatcher(QuerySpanBatcher):
|
||||||
|
def __init__(self, iou_thred, use_oproto):
|
||||||
|
super().__init__(iou_thred, use_oproto)
|
||||||
|
return
|
||||||
|
|
||||||
|
def batchnize_episode(self, data, mode):
|
||||||
|
support_sets, query_sets = zip(*data)
|
||||||
|
if mode == "train":
|
||||||
|
is_train = True
|
||||||
|
else:
|
||||||
|
is_train = False
|
||||||
|
batch_support = self.batchnize_sent(support_sets, "support", is_train)
|
||||||
|
batch_query = self.batchnize_sent(query_sets, "query", is_train)
|
||||||
|
batch_query["label2tag"] = []
|
||||||
|
for i in range(len(query_sets)):
|
||||||
|
batch_query["label2tag"].append(query_sets[i]["label2tag"])
|
||||||
|
return {"support": batch_support, "query": batch_query}
|
||||||
|
|
||||||
|
def batchnize_sent(self, data, mode, is_train):
|
||||||
|
batch = {"index": [], "word": [], "word_mask": [], "word_to_piece_ind": [], "word_to_piece_end": [],
|
||||||
|
"seq_len": [],
|
||||||
|
"spans": [], "sentence_num": [], "query_ments": [], "query_probs": [], "subsentence_num": [],
|
||||||
|
'split_words': [],
|
||||||
|
"ment_labels": []}
|
||||||
|
for i in range(len(data)):
|
||||||
|
for k in batch.keys():
|
||||||
|
if k == 'sentence_num':
|
||||||
|
batch[k].append(data[i][k])
|
||||||
|
else:
|
||||||
|
batch[k] += data[i][k]
|
||||||
|
|
||||||
|
max_n_piece = max([sum(x) for x in batch['word_mask']])
|
||||||
|
|
||||||
|
max_n_words = max(batch["seq_len"])
|
||||||
|
word_to_piece_ind = np.zeros(shape=(len(batch["seq_len"]), max_n_words))
|
||||||
|
word_to_piece_end = np.zeros(shape=(len(batch["seq_len"]), max_n_words))
|
||||||
|
ment_labels = np.full(shape=(len(batch['seq_len']), max_n_words), fill_value=-1, dtype='int')
|
||||||
|
for k, slen in enumerate(batch['seq_len']):
|
||||||
|
assert len(batch['word_to_piece_ind'][k]) == slen
|
||||||
|
assert len(batch['word_to_piece_end'][k]) == slen
|
||||||
|
word_to_piece_ind[k, :slen] = batch['word_to_piece_ind'][k]
|
||||||
|
word_to_piece_end[k, :slen] = batch['word_to_piece_end'][k]
|
||||||
|
ment_labels[k, :slen] = batch['ment_labels'][k]
|
||||||
|
batch['word_to_piece_ind'] = word_to_piece_ind
|
||||||
|
batch['word_to_piece_end'] = word_to_piece_end
|
||||||
|
batch['ment_labels'] = ment_labels
|
||||||
|
if mode == "support":
|
||||||
|
batch = self.make_support_batch(batch)
|
||||||
|
else:
|
||||||
|
batch = self.make_query_batch(batch, is_train)
|
||||||
|
for k, v in batch.items():
|
||||||
|
if k not in ['spans', 'sentence_num', 'label2tag', 'index', "query_ments", "query_probs", "subsentence_num",
|
||||||
|
"split_words"]:
|
||||||
|
v = np.array(v)
|
||||||
|
if k == "span_weights":
|
||||||
|
batch[k] = torch.tensor(v).float()
|
||||||
|
else:
|
||||||
|
batch[k] = torch.tensor(v).long()
|
||||||
|
batch['word'] = batch['word'][:, :max_n_piece]
|
||||||
|
batch['word_mask'] = batch['word_mask'][:, :max_n_piece]
|
||||||
|
return batch
|
||||||
|
|
||||||
|
def __call__(self, batch, mode):
|
||||||
|
return self.batchnize_episode(batch, mode)
|
||||||
|
|
||||||
|
|
||||||
|
class JointNERDataset(data.Dataset):
|
||||||
|
"""
|
||||||
|
Fewshot NER Dataset
|
||||||
|
"""
|
||||||
|
def __init__(self, filepath, encoder, N, K, Q, max_length, schema, \
|
||||||
|
bio=True, debug_file=None, query_file=None, hidden_query_label=False, labelname_fn=None):
|
||||||
|
if not os.path.exists(filepath):
|
||||||
|
print("[ERROR] Data file does not exist!")
|
||||||
|
assert (0)
|
||||||
|
self.class2sampleid = {}
|
||||||
|
self.N = N
|
||||||
|
self.K = K
|
||||||
|
self.Q = Q
|
||||||
|
self.encoder = encoder
|
||||||
|
self.schema = schema # this means the meta-train/test schema
|
||||||
|
if self.schema == 'BIO':
|
||||||
|
self.ment_tag2label = {"O": 0, "B-X": 1, "I-X": 2}
|
||||||
|
elif self.schema == 'IO':
|
||||||
|
self.ment_tag2label = {"O": 0, "I-X": 1}
|
||||||
|
elif self.schema == 'BIOES':
|
||||||
|
self.ment_tag2label = {"O": 0, "B-X": 1, "I-X": 2, "E-X": 3, "S-X": 4}
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
self.ment_label2tag = {lidx: tag for tag, lidx in self.ment_tag2label.items()}
|
||||||
|
self.label2tag = None
|
||||||
|
self.tag2label = None
|
||||||
|
self.sql_label2tag = None
|
||||||
|
self.sql_tag2label = None
|
||||||
|
self.samples, self.classes = self.__load_data_from_file__(filepath, bio)
|
||||||
|
if debug_file:
|
||||||
|
if query_file is None:
|
||||||
|
print("use golden mention for typing !!! input_fn: {}".format(filepath))
|
||||||
|
self.sampler = DebugSampler(debug_file, query_file)
|
||||||
|
self.max_length = max_length
|
||||||
|
return
|
||||||
|
|
||||||
|
def __insert_sample__(self, index, sample_classes):
|
||||||
|
for item in sample_classes:
|
||||||
|
if item in self.class2sampleid:
|
||||||
|
self.class2sampleid[item].append(index)
|
||||||
|
else:
|
||||||
|
self.class2sampleid[item] = [index]
|
||||||
|
return
|
||||||
|
|
||||||
|
def __load_data_from_file__(self, filepath, bio):
|
||||||
|
samples = []
|
||||||
|
classes = []
|
||||||
|
with open(filepath, 'r', encoding='utf-8')as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
samplelines = []
|
||||||
|
index = 0
|
||||||
|
for line in lines:
|
||||||
|
line = line.strip("\n")
|
||||||
|
if len(line):
|
||||||
|
samplelines.append(line)
|
||||||
|
else:
|
||||||
|
sample = SpanSample(index, samplelines, bio)
|
||||||
|
samples.append(sample)
|
||||||
|
sample_classes = sample.get_tag_class()
|
||||||
|
self.__insert_sample__(index, sample_classes)
|
||||||
|
classes += sample_classes
|
||||||
|
samplelines = []
|
||||||
|
index += 1
|
||||||
|
if len(samplelines):
|
||||||
|
sample = SpanSample(index, samplelines, bio)
|
||||||
|
samples.append(sample)
|
||||||
|
sample_classes = sample.get_tag_class()
|
||||||
|
self.__insert_sample__(index, sample_classes)
|
||||||
|
classes += sample_classes
|
||||||
|
classes = list(set(classes))
|
||||||
|
max_span_len = -1
|
||||||
|
long_ent_num = 0
|
||||||
|
tot_ent_num = 0
|
||||||
|
tot_tok_num = 0
|
||||||
|
for eid, sample in enumerate(samples):
|
||||||
|
max_span_len = max(max_span_len, sample.get_max_ent_len())
|
||||||
|
long_ent_num += sample.get_num_of_long_ent(10)
|
||||||
|
tot_ent_num += len(sample.spans)
|
||||||
|
tot_tok_num += len(sample.words)
|
||||||
|
# convert seq labels to target schema
|
||||||
|
new_tags = ['O' for _ in range(len(sample.words))]
|
||||||
|
for sp in sample.spans:
|
||||||
|
stype = sp[0]
|
||||||
|
sp_st = sp[1]
|
||||||
|
sp_ed = sp[2]
|
||||||
|
assert stype != "O"
|
||||||
|
if self.schema == 'IO':
|
||||||
|
for k in range(sp_st, sp_ed + 1):
|
||||||
|
new_tags[k] = "I-" + stype
|
||||||
|
elif self.schema == 'BIO':
|
||||||
|
new_tags[sp_st] = "B-" + stype
|
||||||
|
for k in range(sp_st + 1, sp_ed + 1):
|
||||||
|
new_tags[k] = "I-" + stype
|
||||||
|
elif self.schema == 'BIOES':
|
||||||
|
if sp_st == sp_ed:
|
||||||
|
new_tags[sp_st] = "S-" + stype
|
||||||
|
else:
|
||||||
|
new_tags[sp_st] = "B-" + stype
|
||||||
|
new_tags[sp_ed] = "E-" + stype
|
||||||
|
for k in range(sp_st + 1, sp_ed):
|
||||||
|
new_tags[k] = "I-" + stype
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
assert len(new_tags) == len(samples[eid].tags)
|
||||||
|
samples[eid].tags = new_tags
|
||||||
|
print("Sentence num {}, token num {}, entity num {} in file {}".format(len(samples), tot_tok_num, tot_ent_num,
|
||||||
|
filepath))
|
||||||
|
print("Total classes {}: {}".format(len(classes), str(classes)))
|
||||||
|
print("The max golden entity len in the dataset is ", max_span_len)
|
||||||
|
print("The max golden entity len in the dataset is greater than 10", long_ent_num)
|
||||||
|
print("The total coverage of spans: {:.5f}".format(1 - long_ent_num / tot_ent_num))
|
||||||
|
return samples, classes
|
||||||
|
|
||||||
|
def get_ment_word_tag(self, wtag):
|
||||||
|
if wtag == "O":
|
||||||
|
return wtag
|
||||||
|
return wtag[:2] + "X"
|
||||||
|
|
||||||
|
def __getraw__(self, sample, add_split):
|
||||||
|
word, mask, word_to_piece_ind, word_to_piece_end, word_shape_ids, seq_lens = self.encoder.tokenize(sample.words)
|
||||||
|
sent_st_id = 0
|
||||||
|
split_seqs = []
|
||||||
|
ment_labels = []
|
||||||
|
word_labels = []
|
||||||
|
|
||||||
|
split_spans = []
|
||||||
|
split_querys = []
|
||||||
|
split_probs = []
|
||||||
|
|
||||||
|
split_words = []
|
||||||
|
cur_wid = 0
|
||||||
|
for k in range(len(seq_lens)):
|
||||||
|
split_words.append(sample.words[cur_wid: cur_wid + seq_lens[k]])
|
||||||
|
cur_wid += seq_lens[k]
|
||||||
|
for cur_len in seq_lens:
|
||||||
|
sent_ed_id = sent_st_id + cur_len
|
||||||
|
split_seqs.append(sample.tags[sent_st_id: sent_ed_id])
|
||||||
|
cur_ment_seqs = []
|
||||||
|
cur_word_seqs = []
|
||||||
|
split_spans.append([])
|
||||||
|
for tag, span_st, span_ed in sample.spans:
|
||||||
|
if (span_st >= sent_ed_id) or (span_ed < sent_st_id): # span totally not in subsent
|
||||||
|
continue
|
||||||
|
if (span_st >= sent_st_id) and (span_ed < sent_ed_id): # span totally in subsent
|
||||||
|
split_spans[-1].append([self.tag2label[tag], span_st - sent_st_id, span_ed - sent_st_id])
|
||||||
|
elif add_split:
|
||||||
|
if span_st >= sent_st_id:
|
||||||
|
split_spans[-1].append([self.tag2label[tag], span_st - sent_st_id, sent_ed_id - 1 - sent_st_id])
|
||||||
|
else:
|
||||||
|
split_spans[-1].append([self.tag2label[tag], 0, span_ed - sent_st_id])
|
||||||
|
split_querys.append([])
|
||||||
|
split_probs.append([])
|
||||||
|
for [span_st, span_ed], span_prob in zip(sample.query_ments, sample.query_probs):
|
||||||
|
if (span_st >= sent_ed_id) or (span_ed < sent_st_id): # span totally not in subsent
|
||||||
|
continue
|
||||||
|
if (span_st >= sent_st_id) and (span_ed < sent_ed_id): # span totally in subsent
|
||||||
|
split_querys[-1].append([span_st - sent_st_id, span_ed - sent_st_id])
|
||||||
|
split_probs[-1].append(span_prob)
|
||||||
|
elif add_split:
|
||||||
|
if span_st >= sent_st_id:
|
||||||
|
split_querys[-1].append([span_st - sent_st_id, sent_ed_id - 1 - sent_st_id])
|
||||||
|
split_probs[-1].append(span_prob)
|
||||||
|
else:
|
||||||
|
split_querys[-1].append([0, span_ed - sent_st_id])
|
||||||
|
split_probs[-1].append(span_prob)
|
||||||
|
|
||||||
|
for wtag in split_seqs[-1]:
|
||||||
|
cur_ment_seqs.append(self.ment_tag2label[self.get_ment_word_tag(wtag)])
|
||||||
|
cur_word_seqs.append(-1) # not used
|
||||||
|
ment_labels.append(cur_ment_seqs)
|
||||||
|
word_labels.append(cur_word_seqs)
|
||||||
|
sent_st_id += cur_len
|
||||||
|
item = {"word": word, "word_mask": mask, "word_to_piece_ind": word_to_piece_ind, "word_to_piece_end": word_to_piece_end,
|
||||||
|
"seq_len": seq_lens, "word_shape_ids": word_shape_ids, "ment_labels": ment_labels, "word_labels": word_labels,
|
||||||
|
"subsentence_num": len(seq_lens), "spans": split_spans, "query_ments": split_querys, "query_probs": split_probs,
|
||||||
|
"split_words": split_words}
|
||||||
|
return item
|
||||||
|
|
||||||
|
def __additem__(self, index, d, item):
|
||||||
|
d['index'].append(index)
|
||||||
|
d['word'] += item['word']
|
||||||
|
d['word_mask'] += item['word_mask']
|
||||||
|
d['seq_len'] += item['seq_len']
|
||||||
|
d['word_to_piece_ind'] += item['word_to_piece_ind']
|
||||||
|
d['word_to_piece_end'] += item['word_to_piece_end']
|
||||||
|
d['word_shape_ids'] += item['word_shape_ids']
|
||||||
|
d['spans'] += item['spans']
|
||||||
|
d['query_ments'] += item['query_ments']
|
||||||
|
d['query_probs'] += item['query_probs']
|
||||||
|
d['ment_labels'] += item['ment_labels']
|
||||||
|
d['word_labels'] += item['word_labels']
|
||||||
|
d['subsentence_num'].append(item['subsentence_num'])
|
||||||
|
d['split_words'] += item['split_words']
|
||||||
|
return
|
||||||
|
|
||||||
|
def __populate__(self, idx_list, query_ment_mp=None, savelabeldic=False, add_split=False):
|
||||||
|
dataset = {'index': [], 'word': [], 'word_mask': [], 'ment_labels': [], 'word_labels': [], 'word_to_piece_ind': [],
|
||||||
|
"word_to_piece_end": [], "seq_len": [], "word_shape_ids": [], "subsentence_num": [], 'spans': [], "query_ments": [], "query_probs": [], 'split_words': []}
|
||||||
|
for idx in idx_list:
|
||||||
|
if query_ment_mp is not None:
|
||||||
|
self.samples[idx].query_ments = [x[0] for x in query_ment_mp[str(self.samples[idx].index)]]
|
||||||
|
self.samples[idx].query_probs = [x[1] for x in query_ment_mp[str(self.samples[idx].index)]]
|
||||||
|
else:
|
||||||
|
self.samples[idx].query_ments = [[x[1], x[2]] for x in self.samples[idx].spans]
|
||||||
|
self.samples[idx].query_probs = [1 for x in self.samples[idx].spans]
|
||||||
|
item = self.__getraw__(self.samples[idx], add_split)
|
||||||
|
self.__additem__(idx, dataset, item)
|
||||||
|
if savelabeldic:
|
||||||
|
dataset['label2tag'] = self.label2tag
|
||||||
|
dataset['sql_label2tag'] = self.sql_label2tag
|
||||||
|
dataset['sentence_num'] = len(dataset["seq_len"])
|
||||||
|
assert len(dataset['word']) == len(dataset['seq_len'])
|
||||||
|
assert len(dataset['word_to_piece_ind']) == len(dataset['seq_len'])
|
||||||
|
assert len(dataset['word_to_piece_end']) == len(dataset['seq_len'])
|
||||||
|
assert len(dataset['ment_labels']) == len(dataset['seq_len'])
|
||||||
|
assert len(dataset['word_labels']) == len(dataset['seq_len'])
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
tmp = self.sampler.__next__(index)
|
||||||
|
if len(tmp) == 3:
|
||||||
|
target_classes, support_idx, query_idx = tmp
|
||||||
|
query_ment_mp = None
|
||||||
|
else:
|
||||||
|
target_classes, support_idx, query_idx, query_ment_mp = tmp
|
||||||
|
target_tags = ['O']
|
||||||
|
for cname in target_classes:
|
||||||
|
if self.schema == 'IO':
|
||||||
|
target_tags.append(f"I-{cname}")
|
||||||
|
elif self.schema == 'BIO':
|
||||||
|
target_tags.append(f"B-{cname}")
|
||||||
|
target_tags.append(f"I-{cname}")
|
||||||
|
elif self.schema == 'BIOES':
|
||||||
|
target_tags.append(f"B-{cname}")
|
||||||
|
target_tags.append(f"I-{cname}")
|
||||||
|
target_tags.append(f"E-{cname}")
|
||||||
|
target_tags.append(f"S-{cname}")
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
self.sql_tag2label = {tag: idx for idx, tag in enumerate(target_tags)}
|
||||||
|
self.sql_label2tag = {idx: tag for idx, tag in enumerate(target_tags)}
|
||||||
|
distinct_tags = ['O'] + target_classes
|
||||||
|
self.tag2label = {tag: idx for idx, tag in enumerate(distinct_tags)}
|
||||||
|
self.label2tag = {idx: tag for idx, tag in enumerate(distinct_tags)}
|
||||||
|
support_set = self.__populate__(support_idx, add_split=True, savelabeldic=False)
|
||||||
|
query_set = self.__populate__(query_idx, query_ment_mp=query_ment_mp, add_split=True, savelabeldic=True)
|
||||||
|
return support_set, query_set
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return 1000000
|
||||||
|
|
||||||
|
def batch_to_device(self, batch, device):
|
||||||
|
for k, v in batch.items():
|
||||||
|
if k in ['sentence_num', 'label2tag', 'spans', 'index', 'query_ments', 'query_probs', "subsentence_num", 'split_words']:
|
||||||
|
continue
|
||||||
|
batch[k] = v.to(device)
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
|
def get_joint_loader(filepath, mode, encoder, N, K, Q, batch_size, max_length, schema,
|
||||||
|
bio, shuffle, num_workers=8, debug_file=None, query_file=None,
|
||||||
|
iou_thred=None, hidden_query_label=False, label_fn=None, use_oproto=False):
|
||||||
|
batcher = JointBatcher(iou_thred=iou_thred, use_oproto=use_oproto)
|
||||||
|
dataset = JointNERDataset(filepath, encoder, N, K, Q, max_length, bio=bio, schema=schema,
|
||||||
|
debug_file=debug_file, query_file=query_file,
|
||||||
|
hidden_query_label=hidden_query_label, labelname_fn=label_fn)
|
||||||
|
dataloader = data.DataLoader(dataset=dataset,
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=shuffle,
|
||||||
|
pin_memory=True,
|
||||||
|
num_workers=num_workers,
|
||||||
|
collate_fn=lambda x: batcher(x, mode))
|
||||||
|
return dataloader
|
|
@ -0,0 +1,416 @@
|
||||||
|
import json
|
||||||
|
from collections import defaultdict
|
||||||
|
import random
|
||||||
|
import torch
|
||||||
|
import prettytable as pt
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def cal_prf(tot_hit_cnt, tot_pred_cnt, tot_gold_cnt):
|
||||||
|
precision = tot_hit_cnt / tot_pred_cnt if tot_hit_cnt > 0 else 0
|
||||||
|
recall = tot_hit_cnt / tot_gold_cnt if tot_hit_cnt > 0 else 0
|
||||||
|
f1 = precision * recall * 2 / (precision + recall) if (tot_hit_cnt > 0) else 0
|
||||||
|
return precision, recall, f1
|
||||||
|
|
||||||
|
|
||||||
|
def cal_episode_prf(logs):
|
||||||
|
pred_list = logs["pred"]
|
||||||
|
gold_list = logs["gold"]
|
||||||
|
indexes = logs["index"]
|
||||||
|
|
||||||
|
k = 0
|
||||||
|
query_episode_start_sent_idlist = []
|
||||||
|
for ep_subsent_num in logs["sentence_num"]:
|
||||||
|
query_episode_start_sent_idlist.append(k)
|
||||||
|
cur_subsent_num = 0
|
||||||
|
while cur_subsent_num < ep_subsent_num:
|
||||||
|
cur_subsent_num += logs["subsentence_num"][k]
|
||||||
|
k += 1
|
||||||
|
query_episode_start_sent_idlist.append(k)
|
||||||
|
|
||||||
|
subsent_id = 0
|
||||||
|
sent_id = 0
|
||||||
|
ep_prf = {}
|
||||||
|
for cur_preds, cur_golds, cur_index, snum in zip(pred_list, gold_list, indexes, logs["subsentence_num"]):
|
||||||
|
if sent_id in query_episode_start_sent_idlist:
|
||||||
|
ep_id = query_episode_start_sent_idlist.index(sent_id)
|
||||||
|
ep_prf[ep_id] = {"hit": 0, "pred": 0, "gold": 0}
|
||||||
|
|
||||||
|
pred_spans = set(map(lambda x: tuple(x), cur_preds))
|
||||||
|
gold_spans = set(map(lambda x: tuple(x), cur_golds))
|
||||||
|
ep_prf[ep_id]["pred"] += len(pred_spans)
|
||||||
|
ep_prf[ep_id]["gold"] += len(gold_spans)
|
||||||
|
ep_prf[ep_id]["hit"] += len(pred_spans.intersection(gold_spans))
|
||||||
|
subsent_id += snum
|
||||||
|
sent_id += 1
|
||||||
|
ep_p = 0
|
||||||
|
ep_r = 0
|
||||||
|
ep_f1 = 0
|
||||||
|
for ep in ep_prf.values():
|
||||||
|
p, r, f1 = cal_prf(ep["hit"], ep["pred"], ep["gold"])
|
||||||
|
ep_p += p
|
||||||
|
ep_r += r
|
||||||
|
ep_f1 += f1
|
||||||
|
ep_p /= len(ep_prf)
|
||||||
|
ep_r /= len(ep_prf)
|
||||||
|
ep_f1 /= len(ep_prf)
|
||||||
|
return ep_p, ep_r, ep_f1
|
||||||
|
|
||||||
|
def eval_ment_log(samples, logs):
|
||||||
|
indexes = logs["index"]
|
||||||
|
pred_list = logs["pred"]
|
||||||
|
gold_list = logs["gold"]
|
||||||
|
gold_cnt_mp = defaultdict(int)
|
||||||
|
hit_cnt_mp = defaultdict(int)
|
||||||
|
for cur_idx, cur_preds, cur_golds in zip(indexes, pred_list, gold_list):
|
||||||
|
tagged_spans = samples[cur_idx].spans
|
||||||
|
for r, x, y in tagged_spans:
|
||||||
|
gold_cnt_mp[r] += 1
|
||||||
|
if [x, y] in cur_preds:
|
||||||
|
hit_cnt_mp[r] += 1
|
||||||
|
tot_miss = 0
|
||||||
|
tb = pt.PrettyTable(["Type", "recall", "miss_span", "tot_span"])
|
||||||
|
for gtype in sorted(gold_cnt_mp.keys()):
|
||||||
|
rscore = hit_cnt_mp[gtype] / gold_cnt_mp[gtype]
|
||||||
|
miss_cnt = gold_cnt_mp[gtype] - hit_cnt_mp[gtype]
|
||||||
|
tb.add_row([
|
||||||
|
gtype, "{:.4f}".format(rscore), miss_cnt, gold_cnt_mp[gtype]
|
||||||
|
])
|
||||||
|
tot_miss += miss_cnt
|
||||||
|
tb.add_row(["Total", "{:.4f}".format(sum(hit_cnt_mp.values()) / sum(gold_cnt_mp.values())), tot_miss, sum(gold_cnt_mp.values())])
|
||||||
|
print(tb)
|
||||||
|
return
|
||||||
|
|
||||||
|
def write_ep_ment_log_json(samples, logs, output_fn):
|
||||||
|
k = 0
|
||||||
|
subsent_id = 0
|
||||||
|
query_episode_end_subsent_idlist = []
|
||||||
|
for ep_subsent_num in logs["sentence_num"]:
|
||||||
|
cur_subsent_num = 0
|
||||||
|
while cur_subsent_num < ep_subsent_num:
|
||||||
|
cur_subsent_num += logs["subsentence_num"][k]
|
||||||
|
subsent_id += logs["subsentence_num"][k]
|
||||||
|
k += 1
|
||||||
|
query_episode_end_subsent_idlist.append(subsent_id)
|
||||||
|
indexes = logs["index"]
|
||||||
|
if "before_prob" in logs:
|
||||||
|
split_ind_list = logs["before_ind"]
|
||||||
|
split_prob_list = logs["before_prob"]
|
||||||
|
else:
|
||||||
|
split_ind_list = None
|
||||||
|
split_prob_list = None
|
||||||
|
sent_pred_list = logs["pred"]
|
||||||
|
split_slen_list = logs["seq_len"]
|
||||||
|
subsent_num_list = logs["subsentence_num"]
|
||||||
|
log_lines = []
|
||||||
|
cur_query_res = {}
|
||||||
|
subsent_id = 0
|
||||||
|
sent_id = 0
|
||||||
|
for snum, cur_index in zip(subsent_num_list, indexes):
|
||||||
|
if subsent_id in query_episode_end_subsent_idlist:
|
||||||
|
log_lines.append(cur_query_res)
|
||||||
|
cur_query_res = {}
|
||||||
|
cur_probs = []
|
||||||
|
if split_prob_list is not None:
|
||||||
|
cur_sent_st = 0
|
||||||
|
for k in range(subsent_id, subsent_id + snum):
|
||||||
|
for x, y in zip(split_ind_list[k], split_prob_list[k]):
|
||||||
|
cur_probs.append(([x[0] + cur_sent_st, x[1] + cur_sent_st], y))
|
||||||
|
cur_sent_st += split_slen_list[k]
|
||||||
|
else:
|
||||||
|
for x in sent_pred_list[sent_id]:
|
||||||
|
cur_probs.append(([x[0], x[1]], 1))
|
||||||
|
cur_query_res[samples[cur_index].index] = cur_probs
|
||||||
|
subsent_id += snum
|
||||||
|
sent_id += 1
|
||||||
|
assert subsent_id in query_episode_end_subsent_idlist
|
||||||
|
log_lines.append(cur_query_res)
|
||||||
|
cur_query_res = {}
|
||||||
|
with open(output_fn, mode="w", encoding="utf-8") as fp:
|
||||||
|
output_lines = []
|
||||||
|
for line in log_lines:
|
||||||
|
output_lines.append(json.dumps(line) + "\n")
|
||||||
|
fp.writelines(output_lines)
|
||||||
|
return
|
||||||
|
|
||||||
|
def write_ment_log(samples, logs, output_fn):
|
||||||
|
indexes = logs["index"]
|
||||||
|
pred_list = logs["pred"]
|
||||||
|
gold_list = logs["gold"]
|
||||||
|
subsent_num_list = logs["subsentence_num"]
|
||||||
|
split_slen_list = logs["seq_len"]
|
||||||
|
if "before_prob" in logs:
|
||||||
|
split_ind_list = logs["before_ind"]
|
||||||
|
split_prob_list = logs["before_prob"]
|
||||||
|
else:
|
||||||
|
split_ind_list = None
|
||||||
|
split_prob_list = None
|
||||||
|
log_lines = []
|
||||||
|
assert len(pred_list) == len(indexes)
|
||||||
|
assert len(gold_list) == len(indexes)
|
||||||
|
subsent_id = 0
|
||||||
|
for cur_preds, cur_golds, cur_index, snum in zip(pred_list, gold_list, indexes, subsent_num_list):
|
||||||
|
cur_sample = samples[cur_index]
|
||||||
|
cur_probs = []
|
||||||
|
if split_prob_list is not None:
|
||||||
|
cur_sent_st = 0
|
||||||
|
for k in range(subsent_id, subsent_id + snum):
|
||||||
|
for x, y in zip(split_ind_list[k], split_prob_list[k]):
|
||||||
|
cur_probs.append(([x[0] + cur_sent_st, x[1] + cur_sent_st], y))
|
||||||
|
cur_sent_st += split_slen_list[k]
|
||||||
|
subsent_id += snum
|
||||||
|
log_lines.append("index:{}\n".format(cur_sample.index))
|
||||||
|
log_lines.append("pred:\n")
|
||||||
|
for x in cur_preds:
|
||||||
|
log_lines.append(" ".join(cur_sample.words[x[0]: x[1] + 1]) + " " + str(x) + "\n")
|
||||||
|
log_lines.append("gold:\n")
|
||||||
|
for x in cur_golds:
|
||||||
|
log_lines.append(" ".join(cur_sample.words[x[0]: x[1] + 1]) + " " + str(x) + "\n")
|
||||||
|
log_lines.append("log:\n")
|
||||||
|
log_lines.append(json.dumps(cur_probs) + "\n")
|
||||||
|
log_lines.append("\n")
|
||||||
|
with open(output_fn, mode="w", encoding="utf-8") as fp:
|
||||||
|
fp.writelines(log_lines)
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def eval_ent_log(samples, logs):
|
||||||
|
indexes = logs["index"]
|
||||||
|
pred_list = logs["pred"]
|
||||||
|
gold_list = logs["gold"]
|
||||||
|
assert len(indexes) == len(pred_list)
|
||||||
|
gold_cnt_mp = defaultdict(int)
|
||||||
|
hit_cnt_mp = defaultdict(int)
|
||||||
|
pred_cnt_mp = defaultdict(int)
|
||||||
|
hit_mentcnt_mp = defaultdict(int)
|
||||||
|
fp_cnt_mp = defaultdict(int)
|
||||||
|
for cur_idx, cur_preds, cur_golds in zip(indexes, pred_list, gold_list):
|
||||||
|
used_token = np.array([0 for i in range(len(samples[cur_idx].words))])
|
||||||
|
cur_ments = [[x[1], x[2]] for x in cur_preds]
|
||||||
|
tagged_spans = samples[cur_idx].spans
|
||||||
|
for r, x, y in tagged_spans:
|
||||||
|
gold_cnt_mp[r] += 1
|
||||||
|
used_token[x: y + 1] = 1
|
||||||
|
if [r, x, y] in cur_preds:
|
||||||
|
hit_cnt_mp[r] += 1
|
||||||
|
if [x, y] in cur_ments:
|
||||||
|
hit_mentcnt_mp[r] += 1
|
||||||
|
for r, x, y in cur_preds:
|
||||||
|
pred_cnt_mp[r] += 1
|
||||||
|
if sum(used_token[x: y + 1]) == 0:
|
||||||
|
fp_cnt_mp[r] += 1
|
||||||
|
tb = pt.PrettyTable(["Type", "precision", "recall", "f1", "overall_ment_recall", "error_miss_ment", "false_span", "span_type_error", "correct"])
|
||||||
|
for gtype in sorted(gold_cnt_mp.keys()):
|
||||||
|
pscore = hit_cnt_mp[gtype] / pred_cnt_mp[gtype] if hit_cnt_mp[gtype] > 0 else 0
|
||||||
|
rscore = hit_cnt_mp[gtype] / gold_cnt_mp[gtype] if hit_cnt_mp[gtype] > 0 else 0
|
||||||
|
fscore = pscore * rscore * 2 / (pscore + rscore) if hit_cnt_mp[gtype] > 0 else 0
|
||||||
|
ment_rscore = hit_mentcnt_mp[gtype] / gold_cnt_mp[gtype]
|
||||||
|
miss_ment_cnt = gold_cnt_mp[gtype] - hit_mentcnt_mp[gtype]
|
||||||
|
fp_cnt = fp_cnt_mp[gtype]
|
||||||
|
type_error_cnt = hit_mentcnt_mp[gtype] - hit_cnt_mp[gtype]
|
||||||
|
correct_cnt = hit_cnt_mp[gtype]
|
||||||
|
tb.add_row([
|
||||||
|
gtype, "{:.4f}".format(pscore), "{:.4f}".format(rscore), "{:.4f}".format(fscore), "{:.4f}".format(ment_rscore), miss_ment_cnt, fp_cnt, type_error_cnt, correct_cnt
|
||||||
|
])
|
||||||
|
pscore = sum(hit_cnt_mp.values()) / sum(pred_cnt_mp.values()) if sum(hit_cnt_mp.values()) > 0 else 0
|
||||||
|
rscore = sum(hit_cnt_mp.values()) / sum(gold_cnt_mp.values()) if sum(hit_cnt_mp.values()) > 0 else 0
|
||||||
|
fscore = pscore * rscore * 2 / (pscore + rscore) if sum(hit_cnt_mp.values()) > 0 else 0
|
||||||
|
ment_rscore = sum(hit_mentcnt_mp.values()) / sum(gold_cnt_mp.values())
|
||||||
|
miss_ment_cnt = sum(gold_cnt_mp.values()) - sum(hit_mentcnt_mp.values())
|
||||||
|
type_error_cnt = sum(hit_mentcnt_mp.values()) - sum(hit_cnt_mp.values())
|
||||||
|
tb.add_row([
|
||||||
|
"Overall", "{:.4f}".format(pscore), "{:.4f}".format(rscore), "{:.4f}".format(fscore), "{:.4f}".format(ment_rscore), miss_ment_cnt, sum(fp_cnt_mp.values()),type_error_cnt, sum(hit_cnt_mp.values())
|
||||||
|
])
|
||||||
|
print(tb)
|
||||||
|
return
|
||||||
|
|
||||||
|
def write_ent_log(samples, logs, output_fn):
|
||||||
|
k = 0
|
||||||
|
support_episode_start_sent_idlist = []
|
||||||
|
for ep_subsent_num in logs["support_sentence_num"]:
|
||||||
|
support_episode_start_sent_idlist.append(k)
|
||||||
|
cur_subsent_num = 0
|
||||||
|
while cur_subsent_num < ep_subsent_num:
|
||||||
|
cur_subsent_num += logs["support_subsentence_num"][k]
|
||||||
|
k += 1
|
||||||
|
support_episode_start_sent_idlist.append(k)
|
||||||
|
k = 0
|
||||||
|
query_episode_start_sent_idlist = []
|
||||||
|
for ep_subsent_num in logs["sentence_num"]:
|
||||||
|
query_episode_start_sent_idlist.append(k)
|
||||||
|
cur_subsent_num = 0
|
||||||
|
while cur_subsent_num < ep_subsent_num:
|
||||||
|
cur_subsent_num += logs["subsentence_num"][k]
|
||||||
|
k += 1
|
||||||
|
query_episode_start_sent_idlist.append(k)
|
||||||
|
assert len(support_episode_start_sent_idlist) == len(query_episode_start_sent_idlist)
|
||||||
|
indexes = logs["index"]
|
||||||
|
pred_list = logs["pred"]
|
||||||
|
gold_list = logs["gold"]
|
||||||
|
split_ind_list = logs["before_ind"] if "before_ind" in logs else None
|
||||||
|
split_prob_list = logs["before_prob"] if "before_prob" in logs else None
|
||||||
|
split_slen_list = logs["seq_len"]
|
||||||
|
split_label_tag_list = logs["label_tag"] if "label_tag" in logs else None
|
||||||
|
|
||||||
|
log_lines = []
|
||||||
|
subsent_id = 0
|
||||||
|
sent_id = 0
|
||||||
|
for cur_preds, cur_golds, cur_index, snum in zip(pred_list, gold_list, indexes, logs["subsentence_num"]):
|
||||||
|
if sent_id in query_episode_start_sent_idlist:
|
||||||
|
ep_id = query_episode_start_sent_idlist.index(sent_id)
|
||||||
|
support_sent_id_1 = support_episode_start_sent_idlist[ep_id]
|
||||||
|
support_sent_id_2 = support_episode_start_sent_idlist[ep_id + 1]
|
||||||
|
log_lines.append("="*20+"\n")
|
||||||
|
support_indexes = logs["support_index"][support_sent_id_1:support_sent_id_2]
|
||||||
|
log_lines.append("support:{}\n".format(str(support_indexes)))
|
||||||
|
for x in support_indexes:
|
||||||
|
log_lines.append(str(samples[x].words) + "\n")
|
||||||
|
log_lines.append(str([sp[0] + ":" + " ".join(samples[x].words[sp[1]: sp[2] + 1]) for sp in samples[x].spans]) + "\n")
|
||||||
|
if split_label_tag_list is not None:
|
||||||
|
log_lines.append(json.dumps(split_label_tag_list[subsent_id]) + "\n")
|
||||||
|
log_lines.append("\n")
|
||||||
|
cur_sample = samples[cur_index]
|
||||||
|
cur_probs = []
|
||||||
|
if split_ind_list is not None:
|
||||||
|
cur_sent_st = 0
|
||||||
|
for k in range(subsent_id, subsent_id + snum):
|
||||||
|
for x, y_list in zip(split_ind_list[k], split_prob_list[k]):
|
||||||
|
cur_probs.append(str([x[0] + cur_sent_st, x[1] + cur_sent_st])
|
||||||
|
+ ": " + ",".join(["{:.5f}".format(y) for y in y_list])
|
||||||
|
+ ", "
|
||||||
|
+ " ".join(cur_sample.words[x[0] + cur_sent_st: x[1] + cur_sent_st + 1]) + "\n")
|
||||||
|
cur_sent_st += split_slen_list[k]
|
||||||
|
subsent_id += snum
|
||||||
|
|
||||||
|
log_lines.append("index:{}\n".format(cur_index))
|
||||||
|
log_lines.append(str(cur_sample.words) + "\n")
|
||||||
|
log_lines.append("pred:\n")
|
||||||
|
for x in cur_preds:
|
||||||
|
log_lines.append(" ".join(cur_sample.words[x[1]: x[2] + 1]) + " " + str(x) + "\n")
|
||||||
|
log_lines.append("gold:\n")
|
||||||
|
for x in cur_golds:
|
||||||
|
log_lines.append(" ".join(cur_sample.words[x[1]: x[2] + 1]) + " " + str(x) + "\n")
|
||||||
|
log_lines.append("log:\n")
|
||||||
|
log_lines.extend(cur_probs)
|
||||||
|
log_lines.append("\n")
|
||||||
|
|
||||||
|
sent_id += 1
|
||||||
|
with open(output_fn, mode="w", encoding="utf-8") as fp:
|
||||||
|
fp.writelines(log_lines)
|
||||||
|
return log_lines
|
||||||
|
|
||||||
|
|
||||||
|
def write_ent_pred_json(samples, logs, output_fn):
|
||||||
|
k = 0
|
||||||
|
support_episode_start_sent_idlist = []
|
||||||
|
for ep_subsent_num in logs["support_sentence_num"]:
|
||||||
|
support_episode_start_sent_idlist.append(k)
|
||||||
|
cur_subsent_num = 0
|
||||||
|
while cur_subsent_num < ep_subsent_num:
|
||||||
|
cur_subsent_num += logs["support_subsentence_num"][k]
|
||||||
|
k += 1
|
||||||
|
support_episode_start_sent_idlist.append(k)
|
||||||
|
k = 0
|
||||||
|
query_episode_start_sent_idlist = []
|
||||||
|
for ep_subsent_num in logs["sentence_num"]:
|
||||||
|
query_episode_start_sent_idlist.append(k)
|
||||||
|
cur_subsent_num = 0
|
||||||
|
while cur_subsent_num < ep_subsent_num:
|
||||||
|
cur_subsent_num += logs["subsentence_num"][k]
|
||||||
|
k += 1
|
||||||
|
query_episode_start_sent_idlist.append(k)
|
||||||
|
assert len(support_episode_start_sent_idlist) == len(query_episode_start_sent_idlist)
|
||||||
|
indexes = logs["index"]
|
||||||
|
pred_list = logs["pred"]
|
||||||
|
log_lines = []
|
||||||
|
cur_query_res = None
|
||||||
|
support_indexes = None
|
||||||
|
subsent_id = 0
|
||||||
|
sent_id = 0
|
||||||
|
for cur_preds, cur_index, snum in zip(pred_list, indexes, logs["subsentence_num"]):
|
||||||
|
if sent_id in query_episode_start_sent_idlist:
|
||||||
|
if cur_query_res is not None:
|
||||||
|
log_lines.append({"support": support_indexes, "query": cur_query_res})
|
||||||
|
ep_id = query_episode_start_sent_idlist.index(sent_id)
|
||||||
|
support_sent_id_1 = support_episode_start_sent_idlist[ep_id]
|
||||||
|
support_sent_id_2 = support_episode_start_sent_idlist[ep_id + 1]
|
||||||
|
support_indexes = logs["support_index"][support_sent_id_1:support_sent_id_2]
|
||||||
|
cur_query_res = []
|
||||||
|
|
||||||
|
cur_sample = samples[cur_index]
|
||||||
|
subsent_id += snum
|
||||||
|
cur_query_res.append({"index":cur_index, "pred":cur_preds, "gold": cur_sample.spans})
|
||||||
|
sent_id += 1
|
||||||
|
if len(cur_query_res) > 0:
|
||||||
|
log_lines.append({"support": support_indexes, "query": cur_query_res})
|
||||||
|
with open(output_fn, mode="w", encoding="utf-8") as fp:
|
||||||
|
output_lines = []
|
||||||
|
for line in log_lines:
|
||||||
|
output_lines.append(json.dumps(line) + "\n")
|
||||||
|
fp.writelines(output_lines)
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def write_pos_pred_json(samples, logs, output_fn):
|
||||||
|
k = 0
|
||||||
|
support_episode_start_sent_idlist = []
|
||||||
|
for ep_subsent_num in logs["support_sentence_num"]:
|
||||||
|
support_episode_start_sent_idlist.append(k)
|
||||||
|
cur_subsent_num = 0
|
||||||
|
while cur_subsent_num < ep_subsent_num:
|
||||||
|
cur_subsent_num += logs["support_subsentence_num"][k]
|
||||||
|
k += 1
|
||||||
|
support_episode_start_sent_idlist.append(k)
|
||||||
|
k = 0
|
||||||
|
query_episode_start_sent_idlist = []
|
||||||
|
for ep_subsent_num in logs["sentence_num"]:
|
||||||
|
query_episode_start_sent_idlist.append(k)
|
||||||
|
cur_subsent_num = 0
|
||||||
|
while cur_subsent_num < ep_subsent_num:
|
||||||
|
cur_subsent_num += logs["subsentence_num"][k]
|
||||||
|
k += 1
|
||||||
|
query_episode_start_sent_idlist.append(k)
|
||||||
|
assert len(support_episode_start_sent_idlist) == len(query_episode_start_sent_idlist)
|
||||||
|
indexes = logs["index"]
|
||||||
|
pred_list = logs["pred"]
|
||||||
|
log_lines = []
|
||||||
|
cur_query_res = None
|
||||||
|
support_indexes = None
|
||||||
|
subsent_id = 0
|
||||||
|
sent_id = 0
|
||||||
|
for cur_preds, cur_index, snum in zip(pred_list, indexes, logs["subsentence_num"]):
|
||||||
|
if sent_id in query_episode_start_sent_idlist:
|
||||||
|
if cur_query_res is not None:
|
||||||
|
log_lines.append({"support": support_indexes, "query": cur_query_res})
|
||||||
|
ep_id = query_episode_start_sent_idlist.index(sent_id)
|
||||||
|
support_sent_id_1 = support_episode_start_sent_idlist[ep_id]
|
||||||
|
support_sent_id_2 = support_episode_start_sent_idlist[ep_id + 1]
|
||||||
|
support_indexes = logs["support_index"][support_sent_id_1:support_sent_id_2]
|
||||||
|
cur_query_res = []
|
||||||
|
|
||||||
|
cur_sample = samples[cur_index]
|
||||||
|
subsent_id += snum
|
||||||
|
cur_query_res.append({"index":cur_index, "pred":cur_preds, "gold": cur_sample.tags})
|
||||||
|
sent_id += 1
|
||||||
|
if len(cur_query_res) > 0:
|
||||||
|
log_lines.append({"support": support_indexes, "query": cur_query_res})
|
||||||
|
with open(output_fn, mode="w", encoding="utf-8") as fp:
|
||||||
|
output_lines = []
|
||||||
|
for line in log_lines:
|
||||||
|
output_lines.append(json.dumps(line) + "\n")
|
||||||
|
fp.writelines(output_lines)
|
||||||
|
return
|
||||||
|
|
||||||
|
def save_json(content, path, indent=4, **json_dump_kwargs):
|
||||||
|
with open(path, "w") as f:
|
||||||
|
json.dump(content, f, indent=indent, **json_dump_kwargs)
|
||||||
|
return
|
||||||
|
|
||||||
|
def set_seed(seed):
|
||||||
|
random.seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
return
|
|
@ -0,0 +1,244 @@
|
||||||
|
import json
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.data as data
|
||||||
|
import os
|
||||||
|
from .fewshotsampler import DebugSampler, FewshotSampleBase
|
||||||
|
import numpy as np
|
||||||
|
import random
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
class TokenSample(FewshotSampleBase):
|
||||||
|
def __init__(self, idx, filelines):
|
||||||
|
super(TokenSample, self).__init__()
|
||||||
|
self.index = idx
|
||||||
|
filelines = [line.split('\t') for line in filelines]
|
||||||
|
if len(filelines[0]) == 2:
|
||||||
|
self.words, self.tags = zip(*filelines)
|
||||||
|
else:
|
||||||
|
self.words, self.postags, self.tags = zip(*filelines)
|
||||||
|
return
|
||||||
|
|
||||||
|
def __count_entities__(self):
|
||||||
|
self.class_count = {}
|
||||||
|
for tag in self.tags:
|
||||||
|
if tag in self.class_count:
|
||||||
|
self.class_count[tag] += 1
|
||||||
|
else:
|
||||||
|
self.class_count[tag] = 1
|
||||||
|
return
|
||||||
|
|
||||||
|
def get_class_count(self):
|
||||||
|
if self.class_count:
|
||||||
|
return self.class_count
|
||||||
|
else:
|
||||||
|
self.__count_entities__()
|
||||||
|
return self.class_count
|
||||||
|
|
||||||
|
def get_tag_class(self):
|
||||||
|
return list(set(self.tags))
|
||||||
|
|
||||||
|
def valid(self, target_classes):
|
||||||
|
return (set(self.get_class_count().keys()).intersection(set(target_classes))) and \
|
||||||
|
not (set(self.get_class_count().keys()).difference(set(target_classes)))
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
class SeqBatcher:
|
||||||
|
def __init__(self):
|
||||||
|
return
|
||||||
|
|
||||||
|
def batchnize_sent(self, data):
|
||||||
|
batch = {"index": [], "word": [], "word_mask": [], "word_to_piece_ind": [], "word_to_piece_end": [],
|
||||||
|
"word_shape_ids": [], "word_labels": [],
|
||||||
|
"seq_len": [], "sentence_num": [], "subsentence_num": []}
|
||||||
|
for i in range(len(data)):
|
||||||
|
for k in batch.keys():
|
||||||
|
if k == 'sentence_num':
|
||||||
|
batch[k].append(data[i][k])
|
||||||
|
else:
|
||||||
|
batch[k] += data[i][k]
|
||||||
|
|
||||||
|
max_n_piece = max([sum(x) for x in batch['word_mask']])
|
||||||
|
max_n_words = max(batch["seq_len"])
|
||||||
|
word_to_piece_ind = np.zeros(shape=(len(batch["seq_len"]), max_n_words))
|
||||||
|
word_to_piece_end = np.zeros(shape=(len(batch["seq_len"]), max_n_words))
|
||||||
|
word_shape_ids = np.zeros(shape=(len(batch["seq_len"]), max_n_words))
|
||||||
|
word_labels = np.full(shape=(len(batch['seq_len']), max_n_words), fill_value=-1, dtype='int')
|
||||||
|
for k, slen in enumerate(batch['seq_len']):
|
||||||
|
assert len(batch['word_to_piece_ind'][k]) == slen
|
||||||
|
assert len(batch['word_to_piece_end'][k]) == slen
|
||||||
|
assert len(batch['word_shape_ids'][k]) == slen
|
||||||
|
word_to_piece_ind[k, :slen] = batch['word_to_piece_ind'][k]
|
||||||
|
word_to_piece_end[k, :slen] = batch['word_to_piece_end'][k]
|
||||||
|
word_shape_ids[k, :slen] = batch['word_shape_ids'][k]
|
||||||
|
word_labels[k, :slen] = batch['word_labels'][k]
|
||||||
|
batch['word_to_piece_ind'] = word_to_piece_ind
|
||||||
|
batch['word_to_piece_end'] = word_to_piece_end
|
||||||
|
batch['word_shape_ids'] = word_shape_ids
|
||||||
|
batch['word_labels'] = word_labels
|
||||||
|
for k, v in batch.items():
|
||||||
|
if k not in ['sentence_num', 'index', 'subsentence_num', 'label2tag']:
|
||||||
|
v = np.array(v, dtype=np.int32)
|
||||||
|
batch[k] = torch.tensor(v, dtype=torch.long)
|
||||||
|
batch['word'] = batch['word'][:, :max_n_piece]
|
||||||
|
batch['word_mask'] = batch['word_mask'][:, :max_n_piece]
|
||||||
|
return batch
|
||||||
|
|
||||||
|
def __call__(self, batch, use_episode):
|
||||||
|
if use_episode:
|
||||||
|
support_sets, query_sets = zip(*batch)
|
||||||
|
batch_support = self.batchnize_sent(support_sets)
|
||||||
|
batch_query = self.batchnize_sent(query_sets)
|
||||||
|
batch_query["label2tag"] = []
|
||||||
|
for i in range(len(query_sets)):
|
||||||
|
batch_query["label2tag"].append(query_sets[i]["label2tag"])
|
||||||
|
return {"support": batch_support, "query": batch_query}
|
||||||
|
else:
|
||||||
|
batch_query = self.batchnize_sent(batch)
|
||||||
|
batch_query["label2tag"] = []
|
||||||
|
for i in range(len(batch)):
|
||||||
|
batch_query["label2tag"].append(batch[i]["label2tag"])
|
||||||
|
return batch_query
|
||||||
|
|
||||||
|
class SeqNERDataset(data.Dataset):
|
||||||
|
"""
|
||||||
|
Fewshot NER Dataset
|
||||||
|
"""
|
||||||
|
def __init__(self, filepath, encoder, max_length, debug_file=None, use_episode=True):
|
||||||
|
if not os.path.exists(filepath):
|
||||||
|
print("[ERROR] Data file does not exist!")
|
||||||
|
assert (0)
|
||||||
|
self.class2sampleid = {}
|
||||||
|
self.encoder = encoder
|
||||||
|
self.label2tag = None
|
||||||
|
self.tag2label = None
|
||||||
|
self.samples, self.classes = self.__load_data_from_file__(filepath)
|
||||||
|
if debug_file is not None:
|
||||||
|
self.sampler = DebugSampler(debug_file)
|
||||||
|
self.max_length = max_length
|
||||||
|
self.use_episode = use_episode
|
||||||
|
return
|
||||||
|
|
||||||
|
def __insert_sample__(self, index, sample_classes):
|
||||||
|
for item in sample_classes:
|
||||||
|
if item in self.class2sampleid:
|
||||||
|
self.class2sampleid[item].append(index)
|
||||||
|
else:
|
||||||
|
self.class2sampleid[item] = [index]
|
||||||
|
return
|
||||||
|
|
||||||
|
def __load_data_from_file__(self, filepath):
|
||||||
|
samples = []
|
||||||
|
classes = []
|
||||||
|
with open(filepath, 'r', encoding='utf-8')as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
samplelines = []
|
||||||
|
index = 0
|
||||||
|
for line in lines:
|
||||||
|
line = line.strip("\n")
|
||||||
|
if len(line):
|
||||||
|
samplelines.append(line)
|
||||||
|
else:
|
||||||
|
sample = TokenSample(index, samplelines)
|
||||||
|
samples.append(sample)
|
||||||
|
sample_classes = sample.get_tag_class()
|
||||||
|
self.__insert_sample__(index, sample_classes)
|
||||||
|
classes += sample_classes
|
||||||
|
samplelines = []
|
||||||
|
index += 1
|
||||||
|
|
||||||
|
if len(samplelines):
|
||||||
|
sample = TokenSample(index, samplelines)
|
||||||
|
samples.append(sample)
|
||||||
|
sample_classes = sample.get_tag_class()
|
||||||
|
self.__insert_sample__(index, sample_classes)
|
||||||
|
classes += sample_classes
|
||||||
|
classes = list(set(classes))
|
||||||
|
return samples, classes
|
||||||
|
|
||||||
|
def __getraw__(self, sample):
|
||||||
|
word, mask, word_to_piece_ind, word_to_piece_end, word_shape_ids, seq_lens = self.encoder.tokenize(sample.words)
|
||||||
|
sent_st_id = 0
|
||||||
|
split_seqs = []
|
||||||
|
word_labels = []
|
||||||
|
for cur_len in seq_lens:
|
||||||
|
sent_ed_id = sent_st_id + cur_len
|
||||||
|
split_seqs.append(sample.tags[sent_st_id: sent_ed_id])
|
||||||
|
cur_word_seqs = []
|
||||||
|
for wtag in split_seqs[-1]:
|
||||||
|
if wtag not in self.tag2label:
|
||||||
|
print(wtag, self.tag2label)
|
||||||
|
cur_word_seqs.append(self.tag2label[wtag])
|
||||||
|
word_labels.append(cur_word_seqs)
|
||||||
|
sent_st_id += cur_len
|
||||||
|
item = {"word": word, "word_mask": mask, "word_to_piece_ind": word_to_piece_ind, "word_to_piece_end": word_to_piece_end,
|
||||||
|
"seq_len": seq_lens, "word_shape_ids": word_shape_ids, "word_labels": word_labels,
|
||||||
|
"subsentence_num": len(seq_lens)}
|
||||||
|
return item
|
||||||
|
|
||||||
|
def __additem__(self, index, d, item):
|
||||||
|
d['index'].append(index)
|
||||||
|
d['word'] += item['word']
|
||||||
|
d['word_mask'] += item['word_mask']
|
||||||
|
d['seq_len'] += item['seq_len']
|
||||||
|
d['word_to_piece_ind'] += item['word_to_piece_ind']
|
||||||
|
d['word_to_piece_end'] += item['word_to_piece_end']
|
||||||
|
d['word_shape_ids'] += item['word_shape_ids']
|
||||||
|
d['word_labels'] += item['word_labels']
|
||||||
|
d['subsentence_num'].append(item['subsentence_num'])
|
||||||
|
return
|
||||||
|
|
||||||
|
def __populate__(self, idx_list, savelabeldic=False):
|
||||||
|
dataset = {'index': [], 'word': [], 'word_mask': [], 'word_labels': [], 'word_to_piece_ind': [],
|
||||||
|
"word_to_piece_end": [], "seq_len": [], "word_shape_ids": [], "subsentence_num": []}
|
||||||
|
for idx in idx_list:
|
||||||
|
item = self.__getraw__(self.samples[idx])
|
||||||
|
self.__additem__(idx, dataset, item)
|
||||||
|
if savelabeldic:
|
||||||
|
dataset['label2tag'] = self.label2tag
|
||||||
|
dataset['sentence_num'] = len(dataset["seq_len"])
|
||||||
|
assert len(dataset['word']) == len(dataset['seq_len'])
|
||||||
|
assert len(dataset['word_to_piece_ind']) == len(dataset['seq_len'])
|
||||||
|
assert len(dataset['word_to_piece_end']) == len(dataset['seq_len'])
|
||||||
|
assert len(dataset['word_labels']) == len(dataset['seq_len'])
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
target_tags, support_idx, query_idx = self.sampler.__next__(index)
|
||||||
|
if self.use_episode:
|
||||||
|
self.tag2label = {tag: idx for idx, tag in enumerate(target_tags)}
|
||||||
|
self.label2tag = {idx: tag for idx, tag in enumerate(target_tags)}
|
||||||
|
support_set = self.__populate__(support_idx, savelabeldic=False)
|
||||||
|
query_set = self.__populate__(query_idx, savelabeldic=True)
|
||||||
|
return support_set, query_set
|
||||||
|
else:
|
||||||
|
if self.tag2label is None:
|
||||||
|
self.tag2label = {tag: idx for idx, tag in enumerate(target_tags)}
|
||||||
|
self.label2tag = {idx: tag for idx, tag in enumerate(target_tags)}
|
||||||
|
return self.__populate__(support_idx + query_idx, savelabeldic=True)
|
||||||
|
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return 1000000
|
||||||
|
|
||||||
|
def batch_to_device(self, batch, device):
|
||||||
|
for k, v in batch.items():
|
||||||
|
if k == 'sentence_num' or k == 'index' or k == 'subsentence_num' or k == 'label2tag':
|
||||||
|
continue
|
||||||
|
batch[k] = v.to(device)
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
|
def get_seq_loader(filepath, mode, encoder, batch_size, max_length, shuffle, debug_file=None, num_workers=8, use_episode=True):
|
||||||
|
batcher = SeqBatcher()
|
||||||
|
dataset = SeqNERDataset(filepath, encoder, max_length, debug_file=debug_file, use_episode=use_episode)
|
||||||
|
dataloader = data.DataLoader(dataset=dataset,
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=shuffle,
|
||||||
|
pin_memory=True,
|
||||||
|
num_workers=num_workers,
|
||||||
|
collate_fn=lambda x: batcher(x, use_episode))
|
||||||
|
return dataloader
|
|
@ -0,0 +1,270 @@
|
||||||
|
import json
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.data as data
|
||||||
|
import os
|
||||||
|
from .fewshotsampler import DebugSampler, FewshotSampleBase
|
||||||
|
import numpy as np
|
||||||
|
import random
|
||||||
|
from collections import defaultdict
|
||||||
|
from .span_sample import SpanSample
|
||||||
|
|
||||||
|
class SeqBatcher:
|
||||||
|
def __init__(self):
|
||||||
|
return
|
||||||
|
|
||||||
|
def batchnize_sent(self, data):
|
||||||
|
batch = {"index": [], "word": [], "word_mask": [], "word_to_piece_ind": [], "word_to_piece_end": [],
|
||||||
|
"word_shape_ids": [], "word_labels": [],
|
||||||
|
"seq_len": [], "ment_labels": [], "sentence_num": [], "subsentence_num": []}
|
||||||
|
for i in range(len(data)):
|
||||||
|
for k in batch.keys():
|
||||||
|
if k == 'sentence_num':
|
||||||
|
batch[k].append(data[i][k])
|
||||||
|
else:
|
||||||
|
batch[k] += data[i][k]
|
||||||
|
|
||||||
|
max_n_piece = max([sum(x) for x in batch['word_mask']])
|
||||||
|
max_n_words = max(batch["seq_len"])
|
||||||
|
word_to_piece_ind = np.zeros(shape=(len(batch["seq_len"]), max_n_words))
|
||||||
|
word_to_piece_end = np.zeros(shape=(len(batch["seq_len"]), max_n_words))
|
||||||
|
word_shape_ids = np.zeros(shape=(len(batch["seq_len"]), max_n_words))
|
||||||
|
ment_labels = np.full(shape=(len(batch['seq_len']), max_n_words), fill_value=-1, dtype='int')
|
||||||
|
word_labels = np.full(shape=(len(batch['seq_len']), max_n_words), fill_value=-1, dtype='int')
|
||||||
|
for k, slen in enumerate(batch['seq_len']):
|
||||||
|
assert len(batch['word_to_piece_ind'][k]) == slen
|
||||||
|
assert len(batch['word_to_piece_end'][k]) == slen
|
||||||
|
assert len(batch['word_shape_ids'][k]) == slen
|
||||||
|
word_to_piece_ind[k, :slen] = batch['word_to_piece_ind'][k]
|
||||||
|
word_to_piece_end[k, :slen] = batch['word_to_piece_end'][k]
|
||||||
|
word_shape_ids[k, :slen] = batch['word_shape_ids'][k]
|
||||||
|
ment_labels[k, :slen] = batch['ment_labels'][k]
|
||||||
|
word_labels[k, :slen] = batch['word_labels'][k]
|
||||||
|
batch['word_to_piece_ind'] = word_to_piece_ind
|
||||||
|
batch['word_to_piece_end'] = word_to_piece_end
|
||||||
|
batch['word_shape_ids'] = word_shape_ids
|
||||||
|
batch['ment_labels'] = ment_labels
|
||||||
|
batch['word_labels'] = word_labels
|
||||||
|
for k, v in batch.items():
|
||||||
|
if k not in ['sentence_num', 'index', 'subsentence_num', 'label2tag']:
|
||||||
|
v = np.array(v, dtype=np.int32)
|
||||||
|
batch[k] = torch.tensor(v, dtype=torch.long)
|
||||||
|
batch['word'] = batch['word'][:, :max_n_piece]
|
||||||
|
batch['word_mask'] = batch['word_mask'][:, :max_n_piece]
|
||||||
|
return batch
|
||||||
|
|
||||||
|
def __call__(self, batch):
|
||||||
|
support_sets, query_sets = zip(*batch)
|
||||||
|
batch_support = self.batchnize_sent(support_sets)
|
||||||
|
batch_query = self.batchnize_sent(query_sets)
|
||||||
|
batch_query["label2tag"] = []
|
||||||
|
for i in range(len(query_sets)):
|
||||||
|
batch_query["label2tag"].append(query_sets[i]["label2tag"])
|
||||||
|
return {"support": batch_support, "query": batch_query}
|
||||||
|
|
||||||
|
class SeqNERDataset(data.Dataset):
|
||||||
|
"""
|
||||||
|
Fewshot NER Dataset
|
||||||
|
"""
|
||||||
|
def __init__(self, filepath, encoder, max_length, schema, bio=True, debug_file=None):
|
||||||
|
if not os.path.exists(filepath):
|
||||||
|
print("[ERROR] Data file does not exist!")
|
||||||
|
assert (0)
|
||||||
|
self.class2sampleid = {}
|
||||||
|
self.encoder = encoder
|
||||||
|
self.schema = schema # this means the meta-train/test schema
|
||||||
|
if self.schema == 'BIO':
|
||||||
|
self.ment_tag2label = {"O": 0, "B-X": 1, "I-X": 2}
|
||||||
|
elif self.schema == 'IO':
|
||||||
|
self.ment_tag2label = {"O": 0, "I-X": 1}
|
||||||
|
elif self.schema == 'BIOES':
|
||||||
|
self.ment_tag2label = {"O": 0, "B-X": 1, "I-X": 2, "E-X": 3, "S-X": 4}
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
self.ment_label2tag = {lidx: tag for tag, lidx in self.ment_tag2label.items()}
|
||||||
|
self.label2tag = None
|
||||||
|
self.tag2label = None
|
||||||
|
self.samples, self.classes = self.__load_data_from_file__(filepath, bio)
|
||||||
|
if debug_file is not None:
|
||||||
|
self.sampler = DebugSampler(debug_file)
|
||||||
|
self.max_length = max_length
|
||||||
|
return
|
||||||
|
|
||||||
|
def __insert_sample__(self, index, sample_classes):
|
||||||
|
for item in sample_classes:
|
||||||
|
if item in self.class2sampleid:
|
||||||
|
self.class2sampleid[item].append(index)
|
||||||
|
else:
|
||||||
|
self.class2sampleid[item] = [index]
|
||||||
|
return
|
||||||
|
|
||||||
|
def __load_data_from_file__(self, filepath, bio):
|
||||||
|
samples = []
|
||||||
|
classes = []
|
||||||
|
with open(filepath, 'r', encoding='utf-8')as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
samplelines = []
|
||||||
|
index = 0
|
||||||
|
for line in lines:
|
||||||
|
line = line.strip("\n")
|
||||||
|
if len(line):
|
||||||
|
samplelines.append(line)
|
||||||
|
else:
|
||||||
|
sample = SpanSample(index, samplelines, bio)
|
||||||
|
samples.append(sample)
|
||||||
|
sample_classes = sample.get_tag_class()
|
||||||
|
self.__insert_sample__(index, sample_classes)
|
||||||
|
classes += sample_classes
|
||||||
|
samplelines = []
|
||||||
|
index += 1
|
||||||
|
if len(samplelines):
|
||||||
|
sample = SpanSample(index, samplelines, bio)
|
||||||
|
samples.append(sample)
|
||||||
|
sample_classes = sample.get_tag_class()
|
||||||
|
self.__insert_sample__(index, sample_classes)
|
||||||
|
classes += sample_classes
|
||||||
|
classes = list(set(classes))
|
||||||
|
max_span_len = -1
|
||||||
|
long_ent_num = 0
|
||||||
|
tot_ent_num = 0
|
||||||
|
tot_tok_num = 0
|
||||||
|
for eid, sample in enumerate(samples):
|
||||||
|
max_span_len = max(max_span_len, sample.get_max_ent_len())
|
||||||
|
long_ent_num += sample.get_num_of_long_ent(10)
|
||||||
|
tot_ent_num += len(sample.spans)
|
||||||
|
tot_tok_num += len(sample.words)
|
||||||
|
# convert seq labels to target schema
|
||||||
|
new_tags = ['O' for _ in range(len(sample.words))]
|
||||||
|
for sp in sample.spans:
|
||||||
|
stype = sp[0]
|
||||||
|
sp_st = sp[1]
|
||||||
|
sp_ed = sp[2]
|
||||||
|
assert stype != "O"
|
||||||
|
if self.schema == 'IO':
|
||||||
|
for k in range(sp_st, sp_ed + 1):
|
||||||
|
new_tags[k] = "I-" + stype
|
||||||
|
elif self.schema == 'BIO':
|
||||||
|
new_tags[sp_st] = "B-" + stype
|
||||||
|
for k in range(sp_st + 1, sp_ed + 1):
|
||||||
|
new_tags[k] = "I-" + stype
|
||||||
|
elif self.schema == 'BIOES':
|
||||||
|
if sp_st == sp_ed:
|
||||||
|
new_tags[sp_st] = "S-" + stype
|
||||||
|
else:
|
||||||
|
new_tags[sp_st] = "B-" + stype
|
||||||
|
new_tags[sp_ed] = "E-" + stype
|
||||||
|
for k in range(sp_st + 1, sp_ed):
|
||||||
|
new_tags[k] = "I-" + stype
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
assert len(new_tags) == len(samples[eid].tags)
|
||||||
|
samples[eid].tags = new_tags
|
||||||
|
print("Sentence num {}, token num {}, entity num {} in file {}".format(len(samples), tot_tok_num, tot_ent_num,
|
||||||
|
filepath))
|
||||||
|
print("Total classes {}: {}".format(len(classes), str(classes)))
|
||||||
|
print("The max golden entity len in the dataset is ", max_span_len)
|
||||||
|
print("The max golden entity len in the dataset is greater than 10", long_ent_num)
|
||||||
|
print("The total coverage of spans: {:.5f}".format(1 - long_ent_num / tot_ent_num))
|
||||||
|
return samples, classes
|
||||||
|
|
||||||
|
def get_ment_word_tag(self, wtag):
|
||||||
|
if wtag == "O":
|
||||||
|
return wtag
|
||||||
|
return wtag[:2] + "X"
|
||||||
|
|
||||||
|
def __getraw__(self, sample):
|
||||||
|
word, mask, word_to_piece_ind, word_to_piece_end, word_shape_ids, seq_lens = self.encoder.tokenize(sample.words)
|
||||||
|
sent_st_id = 0
|
||||||
|
split_seqs = []
|
||||||
|
ment_labels = []
|
||||||
|
word_labels = []
|
||||||
|
for cur_len in seq_lens:
|
||||||
|
sent_ed_id = sent_st_id + cur_len
|
||||||
|
split_seqs.append(sample.tags[sent_st_id: sent_ed_id])
|
||||||
|
cur_ment_seqs = []
|
||||||
|
cur_word_seqs = []
|
||||||
|
for wtag in split_seqs[-1]:
|
||||||
|
cur_ment_seqs.append(self.ment_tag2label[self.get_ment_word_tag(wtag)])
|
||||||
|
if wtag not in self.tag2label:
|
||||||
|
print(wtag, self.tag2label)
|
||||||
|
cur_word_seqs.append(self.tag2label[wtag])
|
||||||
|
ment_labels.append(cur_ment_seqs)
|
||||||
|
word_labels.append(cur_word_seqs)
|
||||||
|
sent_st_id += cur_len
|
||||||
|
item = {"word": word, "word_mask": mask, "word_to_piece_ind": word_to_piece_ind, "word_to_piece_end": word_to_piece_end,
|
||||||
|
"seq_len": seq_lens, "word_shape_ids": word_shape_ids, "ment_labels": ment_labels, "word_labels": word_labels,
|
||||||
|
"subsentence_num": len(seq_lens)}
|
||||||
|
return item
|
||||||
|
|
||||||
|
def __additem__(self, index, d, item):
|
||||||
|
d['index'].append(index)
|
||||||
|
d['word'] += item['word']
|
||||||
|
d['word_mask'] += item['word_mask']
|
||||||
|
d['seq_len'] += item['seq_len']
|
||||||
|
d['word_to_piece_ind'] += item['word_to_piece_ind']
|
||||||
|
d['word_to_piece_end'] += item['word_to_piece_end']
|
||||||
|
d['word_shape_ids'] += item['word_shape_ids']
|
||||||
|
d['ment_labels'] += item['ment_labels']
|
||||||
|
d['word_labels'] += item['word_labels']
|
||||||
|
d['subsentence_num'].append(item['subsentence_num'])
|
||||||
|
return
|
||||||
|
|
||||||
|
def __populate__(self, idx_list, savelabeldic=False):
|
||||||
|
dataset = {'index': [], 'word': [], 'word_mask': [], 'ment_labels': [], 'word_labels': [], 'word_to_piece_ind': [],
|
||||||
|
"word_to_piece_end": [], "seq_len": [], "word_shape_ids": [], "subsentence_num": []}
|
||||||
|
for idx in idx_list:
|
||||||
|
item = self.__getraw__(self.samples[idx])
|
||||||
|
self.__additem__(idx, dataset, item)
|
||||||
|
if savelabeldic:
|
||||||
|
dataset['label2tag'] = self.label2tag
|
||||||
|
dataset['sentence_num'] = len(dataset["seq_len"])
|
||||||
|
assert len(dataset['word']) == len(dataset['seq_len'])
|
||||||
|
assert len(dataset['word_to_piece_ind']) == len(dataset['seq_len'])
|
||||||
|
assert len(dataset['word_to_piece_end']) == len(dataset['seq_len'])
|
||||||
|
assert len(dataset['ment_labels']) == len(dataset['seq_len'])
|
||||||
|
assert len(dataset['word_labels']) == len(dataset['seq_len'])
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
target_classes, support_idx, query_idx = self.sampler.__next__(index)
|
||||||
|
target_tags = ['O']
|
||||||
|
for cname in target_classes:
|
||||||
|
if self.schema == 'IO':
|
||||||
|
target_tags.append(f"I-{cname}")
|
||||||
|
elif self.schema == 'BIO':
|
||||||
|
target_tags.append(f"B-{cname}")
|
||||||
|
target_tags.append(f"I-{cname}")
|
||||||
|
elif self.schema == 'BIOES':
|
||||||
|
target_tags.append(f"B-{cname}")
|
||||||
|
target_tags.append(f"I-{cname}")
|
||||||
|
target_tags.append(f"E-{cname}")
|
||||||
|
target_tags.append(f"S-{cname}")
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
self.tag2label = {tag: idx for idx, tag in enumerate(target_tags)}
|
||||||
|
self.label2tag = {idx: tag for idx, tag in enumerate(target_tags)}
|
||||||
|
support_set = self.__populate__(support_idx, savelabeldic=False)
|
||||||
|
query_set = self.__populate__(query_idx, savelabeldic=True)
|
||||||
|
return support_set, query_set
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return 1000000
|
||||||
|
|
||||||
|
def batch_to_device(self, batch, device):
|
||||||
|
for k, v in batch.items():
|
||||||
|
if k == 'sentence_num' or k == 'index' or k == 'subsentence_num' or k == 'label2tag':
|
||||||
|
continue
|
||||||
|
batch[k] = v.to(device)
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
|
def get_seq_loader(filepath, mode, encoder, batch_size, max_length, schema, bio, shuffle, debug_file=None, num_workers=8):
|
||||||
|
batcher = SeqBatcher()
|
||||||
|
dataset = SeqNERDataset(filepath, encoder, max_length, schema, bio, debug_file=debug_file)
|
||||||
|
dataloader = data.DataLoader(dataset=dataset,
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=shuffle,
|
||||||
|
pin_memory=True,
|
||||||
|
num_workers=num_workers,
|
||||||
|
collate_fn=batcher)
|
||||||
|
return dataloader
|
|
@ -0,0 +1,306 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
import string
|
||||||
|
from transformers import AutoModel, AutoTokenizer
|
||||||
|
from .adapter_layer import AdapterStack
|
||||||
|
|
||||||
|
class BERTSpanEncoder(nn.Module):
|
||||||
|
def __init__(self, pretrain_path, max_length, last_n_layer=-4, word_encode_choice='first', span_encode_choice=None,
|
||||||
|
use_width=False, width_dim=20, use_case=False, case_dim=20, drop_p=0.1, use_att=False, att_hidden_dim=100,
|
||||||
|
use_adapter=False, adapter_size=64, adapter_layer_ids=None):
|
||||||
|
nn.Module.__init__(self)
|
||||||
|
self.use_adapter = use_adapter
|
||||||
|
self.adapter_layer_ids = adapter_layer_ids
|
||||||
|
self.ada_layer_num = 12
|
||||||
|
if self.adapter_layer_ids is not None:
|
||||||
|
self.ada_layer_num = len(self.adapter_layer_ids)
|
||||||
|
if self.use_adapter:
|
||||||
|
from .bert_adapter import BertModel
|
||||||
|
self.bert = BertModel.from_pretrained(pretrain_path)
|
||||||
|
self.ment_adapters = AdapterStack(adapter_size, num_hidden_layers=self.ada_layer_num)
|
||||||
|
self.type_adapters = AdapterStack(adapter_size, num_hidden_layers=self.ada_layer_num)
|
||||||
|
print("use task specific adapter !!!!!!!!!")
|
||||||
|
else:
|
||||||
|
self.bert = AutoModel.from_pretrained(pretrain_path)
|
||||||
|
for n, p in self.bert.named_parameters():
|
||||||
|
if "pooler" in n:
|
||||||
|
p.requires_grad = False
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(pretrain_path)
|
||||||
|
self.max_length = max_length
|
||||||
|
self.last_n_layer = last_n_layer
|
||||||
|
self.word_encode_choice = word_encode_choice
|
||||||
|
self.span_encode_choice = span_encode_choice
|
||||||
|
self.drop = nn.Dropout(drop_p)
|
||||||
|
self.word_dim = self.bert.config.hidden_size
|
||||||
|
|
||||||
|
self.use_att = use_att
|
||||||
|
self.att_hidden_dim = att_hidden_dim
|
||||||
|
if use_att:
|
||||||
|
self.att_layer = nn.Sequential(
|
||||||
|
nn.Linear(self.word_dim, self.att_hidden_dim),
|
||||||
|
nn.Tanh(),
|
||||||
|
nn.Linear(self.att_hidden_dim, 1)
|
||||||
|
)
|
||||||
|
if span_encode_choice is None:
|
||||||
|
span_dim = self.word_dim * 2
|
||||||
|
print("span representation is [head; tail]")
|
||||||
|
else:
|
||||||
|
span_dim = self.word_dim
|
||||||
|
print("span representation is ", self.span_encode_choice)
|
||||||
|
|
||||||
|
self.use_width = use_width
|
||||||
|
if self.use_width:
|
||||||
|
self.width_dim = width_dim
|
||||||
|
self.width_mat = nn.Embedding(50, width_dim)
|
||||||
|
span_dim = span_dim + width_dim
|
||||||
|
print("use width embedding")
|
||||||
|
self.use_case = use_case
|
||||||
|
if self.use_case:
|
||||||
|
self.case_dim = case_dim
|
||||||
|
self.case_mat = nn.Embedding(10, case_dim)
|
||||||
|
span_dim = span_dim + case_dim
|
||||||
|
print("use case embedding")
|
||||||
|
self.span_dim = span_dim
|
||||||
|
print("word dim is {}, span dim is {}".format(self.word_dim, self.span_dim))
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def combine(self, seq_hidden, start_inds, end_inds, pooling='avg'):
|
||||||
|
batch_size, max_seq_len, hidden_dim = seq_hidden.size()
|
||||||
|
max_span_num = start_inds.size(1)
|
||||||
|
if pooling == 'first':
|
||||||
|
embeddings = torch.gather(seq_hidden, 1,
|
||||||
|
start_inds.unsqueeze(-1).expand(batch_size, max_span_num, hidden_dim))
|
||||||
|
elif pooling == 'avg':
|
||||||
|
device = seq_hidden.device
|
||||||
|
span_len = end_inds - start_inds + 1
|
||||||
|
max_span_len = torch.max(span_len).item()
|
||||||
|
subtoken_offset = torch.arange(0, max_span_len).to(device).view(1, 1, max_span_len).expand(batch_size,
|
||||||
|
max_span_num,
|
||||||
|
max_span_len)
|
||||||
|
subtoken_pos = subtoken_offset + start_inds.unsqueeze(-1).expand(batch_size, max_span_num, max_span_len)
|
||||||
|
subtoken_mask = subtoken_pos.le(end_inds.view(batch_size, max_span_num, 1))
|
||||||
|
subtoken_pos = subtoken_pos.masked_fill(~subtoken_mask, 0).view(batch_size, max_span_num * max_span_len,
|
||||||
|
1).expand(batch_size,
|
||||||
|
max_span_num * max_span_len,
|
||||||
|
hidden_dim)
|
||||||
|
embeddings = torch.gather(seq_hidden, 1, subtoken_pos).view(batch_size, max_span_num, max_span_len,
|
||||||
|
hidden_dim)
|
||||||
|
embeddings = embeddings * subtoken_mask.unsqueeze(-1)
|
||||||
|
embeddings = torch.div(embeddings.sum(2), span_len.unsqueeze(-1).float())
|
||||||
|
elif pooling == 'attavg':
|
||||||
|
global_weights = self.att_layer(seq_hidden.view(-1, hidden_dim)).view(batch_size, max_seq_len, 1)
|
||||||
|
seq_hidden_w = torch.cat([seq_hidden, global_weights], dim=2)
|
||||||
|
|
||||||
|
device = seq_hidden.device
|
||||||
|
span_len = end_inds - start_inds + 1
|
||||||
|
max_span_len = torch.max(span_len).item()
|
||||||
|
subtoken_offset = torch.arange(0, max_span_len).to(device).view(1, 1, max_span_len).expand(batch_size,
|
||||||
|
max_span_num,
|
||||||
|
max_span_len)
|
||||||
|
subtoken_pos = subtoken_offset + start_inds.unsqueeze(-1).expand(batch_size, max_span_num, max_span_len)
|
||||||
|
subtoken_mask = subtoken_pos.le(end_inds.view(batch_size, max_span_num, 1))
|
||||||
|
subtoken_pos = subtoken_pos.masked_fill(~subtoken_mask, 0).view(batch_size, max_span_num * max_span_len,
|
||||||
|
1).expand(batch_size,
|
||||||
|
max_span_num * max_span_len,
|
||||||
|
hidden_dim + 1)
|
||||||
|
span_w_embeddings = torch.gather(seq_hidden_w, 1, subtoken_pos).view(batch_size, max_span_num, max_span_len,
|
||||||
|
hidden_dim + 1)
|
||||||
|
span_w_embeddings = span_w_embeddings * subtoken_mask.unsqueeze(-1)
|
||||||
|
word_embeddings = span_w_embeddings[:, :, :, :-1]
|
||||||
|
word_weights = span_w_embeddings[:, :, :, -1].masked_fill(~subtoken_mask, -1e8)
|
||||||
|
word_weights = F.softmax(word_weights, dim=2).unsqueeze(3)
|
||||||
|
embeddings = (word_weights * word_embeddings).sum(2)
|
||||||
|
elif pooling == 'max':
|
||||||
|
device = seq_hidden.device
|
||||||
|
span_len = end_inds - start_inds + 1
|
||||||
|
max_span_len = torch.max(span_len).item()
|
||||||
|
subtoken_offset = torch.arange(0, max_span_len).to(device).view(1, 1, max_span_len).expand(batch_size,
|
||||||
|
max_span_num,
|
||||||
|
max_span_len)
|
||||||
|
subtoken_pos = subtoken_offset + start_inds.unsqueeze(-1).expand(batch_size, max_span_num, max_span_len)
|
||||||
|
subtoken_mask = subtoken_pos.le(
|
||||||
|
end_inds.view(batch_size, max_span_num, 1)) # batch_size, max_span_num, max_span_len
|
||||||
|
subtoken_pos = subtoken_pos.masked_fill(~subtoken_mask, 0).view(batch_size, max_span_num * max_span_len,
|
||||||
|
1).expand(batch_size,
|
||||||
|
max_span_num * max_span_len,
|
||||||
|
hidden_dim)
|
||||||
|
embeddings = torch.gather(seq_hidden, 1, subtoken_pos).view(batch_size, max_span_num, max_span_len,
|
||||||
|
hidden_dim)
|
||||||
|
embeddings = embeddings.masked_fill(
|
||||||
|
(~subtoken_mask).unsqueeze(-1).expand(batch_size, max_span_num, max_span_len, hidden_dim), -1e8).max(
|
||||||
|
dim=2)[0]
|
||||||
|
else:
|
||||||
|
raise ValueError('encode choice not in first / avg / max')
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, words, word_masks, word_to_piece_inds=None, word_to_piece_ends=None, span_indices=None, word_shape_ids=None, mode=None, bottom_hiddens=None):
|
||||||
|
assert word_to_piece_inds.size(0) == words.size(0)
|
||||||
|
assert word_to_piece_inds.size(1) <= words.size(1)
|
||||||
|
assert word_to_piece_ends.size() == word_to_piece_inds.size()
|
||||||
|
output_bottom_hiddens = None
|
||||||
|
if (mode is not None) and self.use_adapter:
|
||||||
|
if mode == 'ment':
|
||||||
|
outputs = self.bert(words, attention_mask=word_masks, output_hidden_states=True, return_dict=True, adapters=self.ment_adapters, adapter_layer_ids=self.adapter_layer_ids, input_bottom_hiddens=bottom_hiddens)
|
||||||
|
else:
|
||||||
|
outputs = self.bert(words, attention_mask=word_masks, output_hidden_states=True, return_dict=True, adapters=self.type_adapters, adapter_layer_ids=self.adapter_layer_ids, input_bottom_hiddens=bottom_hiddens)
|
||||||
|
if mode == 'ment':
|
||||||
|
output_bottom_hiddens = outputs['hidden_states'][self.adapter_layer_ids[0]]
|
||||||
|
else:
|
||||||
|
outputs = self.bert(words, attention_mask=word_masks, output_hidden_states=True, return_dict=True)
|
||||||
|
|
||||||
|
# use the sum of the last 4 layers
|
||||||
|
last_four_hidden_states = torch.cat(
|
||||||
|
[hidden_state.unsqueeze(0) for hidden_state in outputs['hidden_states'][self.last_n_layer:]], 0)
|
||||||
|
del outputs
|
||||||
|
piece_embeddings = torch.sum(last_four_hidden_states, 0) # [num_sent, number_of_tokens, 768]
|
||||||
|
_, piece_len, hidden_dim = piece_embeddings.size()
|
||||||
|
|
||||||
|
word_embeddings = self.combine(piece_embeddings, word_to_piece_inds, word_to_piece_ends,
|
||||||
|
self.word_encode_choice)
|
||||||
|
if span_indices is None:
|
||||||
|
word_embeddings = self.drop(word_embeddings)
|
||||||
|
if mode == 'ment':
|
||||||
|
return word_embeddings, output_bottom_hiddens
|
||||||
|
return word_embeddings
|
||||||
|
|
||||||
|
if word_shape_ids is not None:
|
||||||
|
assert word_shape_ids.size() == word_to_piece_inds.size()
|
||||||
|
assert torch.sum(span_indices[:, :, 1].lt(span_indices[:, :, 0])).item() == 0
|
||||||
|
|
||||||
|
embeds = []
|
||||||
|
if self.span_encode_choice is None:
|
||||||
|
start_word_embeddings = self.combine(word_embeddings, span_indices[:, :, 0], None, 'first')
|
||||||
|
start_word_embeddings = self.drop(start_word_embeddings)
|
||||||
|
embeds.append(start_word_embeddings)
|
||||||
|
end_word_embeddings = self.combine(word_embeddings, span_indices[:, :, 1], None, 'first')
|
||||||
|
end_word_embeddings = self.drop(end_word_embeddings)
|
||||||
|
embeds.append(end_word_embeddings)
|
||||||
|
else:
|
||||||
|
pool_embeddings = self.combine(word_embeddings, span_indices[:, :, 0], span_indices[:, :, 1],
|
||||||
|
self.span_encode_choice)
|
||||||
|
pool_embeddings = self.drop(pool_embeddings)
|
||||||
|
embeds.append(pool_embeddings)
|
||||||
|
|
||||||
|
if self.use_width:
|
||||||
|
width_embeddings = self.width_mat(span_indices[:, :, 1] - span_indices[:, :, 0])
|
||||||
|
width_embeddings = self.drop(width_embeddings)
|
||||||
|
embeds.append(width_embeddings)
|
||||||
|
|
||||||
|
if self.use_case:
|
||||||
|
case_wemb = self.case_mat(word_shape_ids)
|
||||||
|
case_embeddings = self.combine(case_wemb, span_indices[:, :, 0], span_indices[:, :, 1], 'avg')
|
||||||
|
case_embeddings = self.drop(case_embeddings)
|
||||||
|
embeds.append(case_embeddings)
|
||||||
|
|
||||||
|
span_embeddings = torch.cat(embeds, dim=-1)
|
||||||
|
span_embeddings = span_embeddings.view(span_indices.size(0), span_indices.size(1), self.span_dim)
|
||||||
|
return span_embeddings
|
||||||
|
|
||||||
|
def get_word_case(self, token):
|
||||||
|
if token.isdigit():
|
||||||
|
tfeat = 0
|
||||||
|
elif token in string.punctuation:
|
||||||
|
tfeat = 1
|
||||||
|
elif token.isupper():
|
||||||
|
tfeat = 2
|
||||||
|
elif token[0].isupper():
|
||||||
|
tfeat = 3
|
||||||
|
elif token.islower():
|
||||||
|
tfeat = 4
|
||||||
|
else:
|
||||||
|
tfeat = 5
|
||||||
|
return tfeat
|
||||||
|
|
||||||
|
def tokenize_label(self, label_name):
|
||||||
|
token_ids = self.tokenizer.encode(label_name, add_special_tokens=True, max_length=self.max_length, truncation=True)
|
||||||
|
mask = np.zeros((self.max_length), dtype=np.int32)
|
||||||
|
mask[:len(token_ids)] = 1
|
||||||
|
# padding
|
||||||
|
while len(token_ids) < self.max_length:
|
||||||
|
token_ids.append(0)
|
||||||
|
return token_ids, mask
|
||||||
|
|
||||||
|
def tokenize(self, input_tokens, true_token_flags=None):
|
||||||
|
cur_tokens = ['[CLS]']
|
||||||
|
cur_word_to_piece_ind = []
|
||||||
|
cur_word_to_piece_end = []
|
||||||
|
cur_word_shape_ind = []
|
||||||
|
raw_tokens_list = []
|
||||||
|
word_mask_list = []
|
||||||
|
indexed_tokens_list = []
|
||||||
|
word_to_piece_ind_list = []
|
||||||
|
word_to_piece_end_list = []
|
||||||
|
word_shape_ind_list = []
|
||||||
|
seq_len = []
|
||||||
|
word_flag = True
|
||||||
|
for i in range(len(input_tokens)):
|
||||||
|
word = input_tokens[i]
|
||||||
|
if true_token_flags is None:
|
||||||
|
word_flag = True
|
||||||
|
else:
|
||||||
|
word_flag = true_token_flags[i]
|
||||||
|
word_tokens = self.tokenizer.tokenize(word)
|
||||||
|
if len(word_tokens) == 0:
|
||||||
|
word_tokens = ['[UNK]']
|
||||||
|
if len(cur_tokens) + len(word_tokens) + 2 > self.max_length:
|
||||||
|
raw_tokens_list.append(cur_tokens + ['[SEP]'])
|
||||||
|
word_to_piece_ind_list.append(cur_word_to_piece_ind)
|
||||||
|
word_to_piece_end_list.append(cur_word_to_piece_end)
|
||||||
|
word_shape_ind_list.append(cur_word_shape_ind)
|
||||||
|
seq_len.append(len(cur_word_to_piece_ind))
|
||||||
|
cur_tokens = ['[CLS]'] + word_tokens
|
||||||
|
if word_flag:
|
||||||
|
cur_word_to_piece_ind = [1]
|
||||||
|
cur_word_to_piece_end = [len(cur_tokens) - 1]
|
||||||
|
cur_word_shape_ind = [self.get_word_case(word)]
|
||||||
|
else:
|
||||||
|
cur_word_to_piece_ind = []
|
||||||
|
cur_word_to_piece_end = []
|
||||||
|
cur_word_shape_ind = []
|
||||||
|
else:
|
||||||
|
if word_flag:
|
||||||
|
cur_word_to_piece_ind.append(len(cur_tokens))
|
||||||
|
cur_tokens.extend(word_tokens)
|
||||||
|
cur_word_to_piece_end.append(len(cur_tokens) - 1)
|
||||||
|
cur_word_shape_ind.append(self.get_word_case(word))
|
||||||
|
else:
|
||||||
|
cur_tokens.extend(word_tokens)
|
||||||
|
if len(cur_tokens):
|
||||||
|
assert len(cur_tokens) < self.max_length
|
||||||
|
raw_tokens_list.append(cur_tokens + ['[SEP]'])
|
||||||
|
word_to_piece_ind_list.append(cur_word_to_piece_ind)
|
||||||
|
word_to_piece_end_list.append(cur_word_to_piece_end)
|
||||||
|
word_shape_ind_list.append(cur_word_shape_ind)
|
||||||
|
seq_len.append(len(cur_word_to_piece_ind))
|
||||||
|
assert seq_len == [len(x) for x in word_to_piece_ind_list]
|
||||||
|
if true_token_flags is None:
|
||||||
|
assert sum(seq_len) == len(input_tokens)
|
||||||
|
assert len(raw_tokens_list) == len(word_to_piece_ind_list)
|
||||||
|
assert len(raw_tokens_list) == len(word_to_piece_end_list)
|
||||||
|
|
||||||
|
for raw_tokens in raw_tokens_list:
|
||||||
|
indexed_tokens = self.tokenizer.convert_tokens_to_ids(raw_tokens)
|
||||||
|
# padding
|
||||||
|
while len(indexed_tokens) < self.max_length:
|
||||||
|
indexed_tokens.append(0)
|
||||||
|
indexed_tokens_list.append(indexed_tokens)
|
||||||
|
if len(indexed_tokens) != self.max_length:
|
||||||
|
print(input_tokens)
|
||||||
|
print(raw_tokens)
|
||||||
|
raise ValueError
|
||||||
|
assert len(indexed_tokens) == self.max_length
|
||||||
|
# mask
|
||||||
|
mask = np.zeros((self.max_length), dtype=np.int32)
|
||||||
|
mask[:len(raw_tokens)] = 1
|
||||||
|
word_mask_list.append(mask)
|
||||||
|
sent_num = len(indexed_tokens_list)
|
||||||
|
assert sent_num == len(word_mask_list)
|
||||||
|
assert sent_num == len(word_to_piece_ind_list)
|
||||||
|
assert sent_num == len(seq_len)
|
||||||
|
return indexed_tokens_list, word_mask_list, word_to_piece_ind_list, word_to_piece_end_list, word_shape_ind_list, seq_len
|
|
@ -0,0 +1,477 @@
|
||||||
|
import json
|
||||||
|
import torch
|
||||||
|
import torch.utils.data as data
|
||||||
|
import os
|
||||||
|
from .fewshotsampler import FewshotSampler, DebugSampler
|
||||||
|
from .span_sample import SpanSample
|
||||||
|
import numpy as np
|
||||||
|
import random
|
||||||
|
from collections import defaultdict
|
||||||
|
import logging, pickle, gzip
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
class QuerySpanSample(SpanSample):
|
||||||
|
def __init__(self, index, filelines, ment_prob_list, bio=True):
|
||||||
|
super(QuerySpanSample, self).__init__(index, filelines, bio)
|
||||||
|
self.query_ments = []
|
||||||
|
self.query_probs = []
|
||||||
|
for i in range(len(ment_prob_list)):
|
||||||
|
self.query_ments.append(ment_prob_list[i][0])
|
||||||
|
self.query_probs.append(ment_prob_list[i][1])
|
||||||
|
return
|
||||||
|
|
||||||
|
class QuerySpanBatcher:
|
||||||
|
def __init__(self, iou_thred, use_oproto):
|
||||||
|
self.iou_thred = iou_thred
|
||||||
|
self.use_oproto = use_oproto
|
||||||
|
return
|
||||||
|
|
||||||
|
def _get_span_weights(self, batch_triples, seq_len, span_indices, span_tags, thred=0.6, alpha=1):
|
||||||
|
span_weights = []
|
||||||
|
for k in range(len(span_indices)):
|
||||||
|
seq_tags = np.zeros(seq_len[k], dtype=np.int)
|
||||||
|
span_st_inds = np.zeros(seq_len[k], dtype=np.int)
|
||||||
|
span_ed_inds = np.zeros(seq_len[k], dtype=np.int)
|
||||||
|
span_weights.append([1] * len(span_indices[k]))
|
||||||
|
for [r, i, j] in batch_triples[k]:
|
||||||
|
seq_tags[i:j + 1] = r
|
||||||
|
span_st_inds[i:j + 1] = i
|
||||||
|
span_ed_inds[i:j + 1] = j
|
||||||
|
for sp_idx in range(len(span_indices[k])):
|
||||||
|
sp_st = span_indices[k][sp_idx][0]
|
||||||
|
sp_ed = span_indices[k][sp_idx][1]
|
||||||
|
sp_tag = span_tags[k][sp_idx]
|
||||||
|
if sp_tag != 0:
|
||||||
|
continue
|
||||||
|
cur_token_tags = list(seq_tags[sp_st: sp_ed + 1])
|
||||||
|
max_tag = max(cur_token_tags, key=cur_token_tags.count)
|
||||||
|
anchor_idx = cur_token_tags.index(max_tag) + sp_st
|
||||||
|
if max_tag == 0:
|
||||||
|
continue
|
||||||
|
cur_ids = set(range(sp_st, sp_ed + 1))
|
||||||
|
ref_ids = set(range(span_st_inds[anchor_idx], span_ed_inds[anchor_idx] + 1))
|
||||||
|
tag_percent = len(cur_ids & ref_ids) / len(cur_ids | ref_ids)
|
||||||
|
if tag_percent >= thred:
|
||||||
|
span_tags[k][sp_idx] = max_tag
|
||||||
|
span_weights[k][sp_idx] = tag_percent ** alpha
|
||||||
|
else:
|
||||||
|
span_weights[k][sp_idx] = (1 - tag_percent) ** alpha
|
||||||
|
return span_tags, span_weights
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _make_golden_batch(batch_triples, seq_len):
|
||||||
|
span_indices = []
|
||||||
|
span_tags = []
|
||||||
|
for k in range(len(batch_triples)):
|
||||||
|
per_sp_indices = []
|
||||||
|
per_sp_tags = []
|
||||||
|
for (r, i, j) in batch_triples[k]:
|
||||||
|
if j < seq_len[k]:
|
||||||
|
per_sp_indices.append([i, j])
|
||||||
|
per_sp_tags.append(r)
|
||||||
|
else:
|
||||||
|
print("something error")
|
||||||
|
span_indices.append(per_sp_indices)
|
||||||
|
span_tags.append(per_sp_tags)
|
||||||
|
return span_indices, span_tags
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _make_train_query_batch(query_ments, query_probs, golden_triples, seq_len):
|
||||||
|
span_indices = []
|
||||||
|
ment_probs = []
|
||||||
|
span_tags = []
|
||||||
|
for k in range(len(query_ments)):
|
||||||
|
per_sp_indices = []
|
||||||
|
per_sp_tags = []
|
||||||
|
per_ment_probs = []
|
||||||
|
per_tag_mp = {(i, j): tag for tag, i, j in golden_triples[k]}
|
||||||
|
for [i, j], prob in zip(query_ments[k], query_probs[k]):
|
||||||
|
if j < seq_len[k]:
|
||||||
|
per_sp_indices.append([i, j])
|
||||||
|
per_sp_tags.append(per_tag_mp.get((i, j), 0))
|
||||||
|
per_ment_probs.append(prob)
|
||||||
|
else:
|
||||||
|
print("something error")
|
||||||
|
#add golden mentions into query batch
|
||||||
|
for [r, i, j] in golden_triples[k]:
|
||||||
|
if [i, j] not in per_sp_indices:
|
||||||
|
per_sp_indices.append([i, j])
|
||||||
|
per_sp_tags.append(r)
|
||||||
|
per_ment_probs.append(0)
|
||||||
|
span_indices.append(per_sp_indices)
|
||||||
|
span_tags.append(per_sp_tags)
|
||||||
|
ment_probs.append(per_ment_probs)
|
||||||
|
return span_indices, span_tags, ment_probs
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _make_test_query_batch(query_ments, query_probs, golden_triples, seq_len):
|
||||||
|
span_indices = []
|
||||||
|
ment_probs = []
|
||||||
|
span_tags = []
|
||||||
|
for k in range(len(query_ments)):
|
||||||
|
per_sp_indices = []
|
||||||
|
per_sp_tags = []
|
||||||
|
per_ment_probs = []
|
||||||
|
per_tag_mp = {(i, j): tag for tag, i, j in golden_triples[k]}
|
||||||
|
for [i, j], prob in zip(query_ments[k], query_probs[k]):
|
||||||
|
if j < seq_len[k]:
|
||||||
|
per_sp_indices.append([i, j])
|
||||||
|
per_sp_tags.append(per_tag_mp.get((i, j), 0))
|
||||||
|
per_ment_probs.append(prob)
|
||||||
|
else:
|
||||||
|
print("something error")
|
||||||
|
span_indices.append(per_sp_indices)
|
||||||
|
span_tags.append(per_sp_tags)
|
||||||
|
ment_probs.append(per_ment_probs)
|
||||||
|
return span_indices, span_tags, ment_probs
|
||||||
|
|
||||||
|
def make_support_batch(self, batch):
|
||||||
|
span_indices, span_tags = self._make_golden_batch(batch["spans"], batch["seq_len"])
|
||||||
|
if self.use_oproto:
|
||||||
|
used_token_flag = []
|
||||||
|
for k in range(len(batch["spans"])):
|
||||||
|
used_token_flag.append(np.zeros(batch["seq_len"][k]))
|
||||||
|
for [r, sp_st, sp_ed] in batch["spans"][k]:
|
||||||
|
used_token_flag[k][sp_st: sp_ed + 1] = 1
|
||||||
|
for k in range(len(batch["seq_len"])):
|
||||||
|
for j in range(batch["seq_len"][k]):
|
||||||
|
if used_token_flag[k][j] == 1:
|
||||||
|
continue
|
||||||
|
if [j, j] not in span_indices[k]:
|
||||||
|
span_indices[k].append([j, j])
|
||||||
|
span_tags[k].append(0)
|
||||||
|
else:
|
||||||
|
print("something error")
|
||||||
|
span_nums = [len(x) for x in span_indices]
|
||||||
|
max_n_spans = max(span_nums)
|
||||||
|
span_masks = np.zeros(shape=(len(span_indices), max_n_spans), dtype='int')
|
||||||
|
for k in range(len(span_indices)):
|
||||||
|
span_masks[k, :span_nums[k]] = 1
|
||||||
|
while len(span_tags[k]) < max_n_spans:
|
||||||
|
span_indices[k].append([0, 0])
|
||||||
|
span_tags[k].append(-1)
|
||||||
|
batch["span_indices"] = span_indices
|
||||||
|
batch["span_mask"] = span_masks
|
||||||
|
batch["span_tag"] = span_tags
|
||||||
|
# all example with equal weights
|
||||||
|
batch["span_weights"] = np.ones(shape=span_masks.shape, dtype='float')
|
||||||
|
return batch
|
||||||
|
|
||||||
|
def make_query_batch(self, batch, is_train):
|
||||||
|
if is_train:
|
||||||
|
span_indices, span_tags, ment_probs = self._make_train_query_batch(batch["query_ments"],
|
||||||
|
batch["query_probs"], batch["spans"],
|
||||||
|
batch["seq_len"])
|
||||||
|
if self.iou_thred is None:
|
||||||
|
span_weights = None
|
||||||
|
else:
|
||||||
|
span_tags, span_weights = self._get_span_weights(batch["spans"], batch["seq_len"], span_indices,
|
||||||
|
span_tags, thred=self.iou_thred)
|
||||||
|
|
||||||
|
else:
|
||||||
|
span_indices, span_tags, ment_probs = self._make_test_query_batch(batch["query_ments"],
|
||||||
|
batch["query_probs"], batch["spans"],
|
||||||
|
batch["seq_len"])
|
||||||
|
span_weights = None
|
||||||
|
|
||||||
|
span_nums = [len(x) for x in span_indices]
|
||||||
|
max_n_spans = max(span_nums)
|
||||||
|
span_masks = np.zeros(shape=(len(span_indices), max_n_spans), dtype='int')
|
||||||
|
new_span_weights = np.ones(shape=span_masks.shape, dtype='float')
|
||||||
|
for k in range(len(span_indices)):
|
||||||
|
span_masks[k, :span_nums[k]] = 1
|
||||||
|
if span_weights is not None:
|
||||||
|
new_span_weights[k, :span_nums[k]] = span_weights[k]
|
||||||
|
while len(span_tags[k]) < max_n_spans:
|
||||||
|
span_indices[k].append([0, 0])
|
||||||
|
span_tags[k].append(-1)
|
||||||
|
ment_probs[k].append(-100)
|
||||||
|
batch["span_indices"] = span_indices
|
||||||
|
batch["span_mask"] = span_masks
|
||||||
|
batch["span_tag"] = span_tags
|
||||||
|
batch["span_probs"] = ment_probs
|
||||||
|
batch["span_weights"] = new_span_weights
|
||||||
|
return batch
|
||||||
|
|
||||||
|
def batchnize_episode(self, data, mode):
|
||||||
|
support_sets, query_sets = zip(*data)
|
||||||
|
if mode == "train":
|
||||||
|
is_train = True
|
||||||
|
else:
|
||||||
|
is_train = False
|
||||||
|
batch_support = self.batchnize_sent(support_sets, "support", is_train)
|
||||||
|
batch_query = self.batchnize_sent(query_sets, "query", is_train)
|
||||||
|
batch_query["label2tag"] = []
|
||||||
|
for i in range(len(query_sets)):
|
||||||
|
batch_query["label2tag"].append(query_sets[i]["label2tag"])
|
||||||
|
return {"support": batch_support, "query": batch_query}
|
||||||
|
|
||||||
|
def batchnize_sent(self, data, mode, is_train):
|
||||||
|
batch = {"index": [], "word": [], "word_mask": [], "word_to_piece_ind": [], "word_to_piece_end": [],
|
||||||
|
"seq_len": [],
|
||||||
|
"spans": [], "sentence_num": [], "query_ments": [], "query_probs": [], "subsentence_num": [],
|
||||||
|
'split_words': []}
|
||||||
|
for i in range(len(data)):
|
||||||
|
for k in batch.keys():
|
||||||
|
if k == 'sentence_num':
|
||||||
|
batch[k].append(data[i][k])
|
||||||
|
else:
|
||||||
|
batch[k] += data[i][k]
|
||||||
|
|
||||||
|
max_n_piece = max([sum(x) for x in batch['word_mask']])
|
||||||
|
|
||||||
|
max_n_words = max(batch["seq_len"])
|
||||||
|
word_to_piece_ind = np.zeros(shape=(len(batch["seq_len"]), max_n_words))
|
||||||
|
word_to_piece_end = np.zeros(shape=(len(batch["seq_len"]), max_n_words))
|
||||||
|
for k, slen in enumerate(batch['seq_len']):
|
||||||
|
assert len(batch['word_to_piece_ind'][k]) == slen
|
||||||
|
assert len(batch['word_to_piece_end'][k]) == slen
|
||||||
|
word_to_piece_ind[k, :slen] = batch['word_to_piece_ind'][k]
|
||||||
|
word_to_piece_end[k, :slen] = batch['word_to_piece_end'][k]
|
||||||
|
batch['word_to_piece_ind'] = word_to_piece_ind
|
||||||
|
batch['word_to_piece_end'] = word_to_piece_end
|
||||||
|
if mode == "support":
|
||||||
|
batch = self.make_support_batch(batch)
|
||||||
|
else:
|
||||||
|
batch = self.make_query_batch(batch, is_train)
|
||||||
|
for k, v in batch.items():
|
||||||
|
if k not in ['spans', 'sentence_num', 'label2tag', 'index', "query_ments", "query_probs", "subsentence_num",
|
||||||
|
"split_words"]:
|
||||||
|
v = np.array(v)
|
||||||
|
if k == "span_weights":
|
||||||
|
batch[k] = torch.tensor(v).float()
|
||||||
|
else:
|
||||||
|
batch[k] = torch.tensor(v).long()
|
||||||
|
batch['word'] = batch['word'][:, :max_n_piece]
|
||||||
|
batch['word_mask'] = batch['word_mask'][:, :max_n_piece]
|
||||||
|
return batch
|
||||||
|
|
||||||
|
def __call__(self, batch, mode):
|
||||||
|
return self.batchnize_episode(batch, mode)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class QuerySpanNERDataset(data.Dataset):
|
||||||
|
"""
|
||||||
|
Fewshot NER Dataset
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, filepath, encoder, N, K, Q, max_length, \
|
||||||
|
bio=True, debug_file=None, query_file=None, hidden_query_label=False, labelname_fn=None):
|
||||||
|
if not os.path.exists(filepath):
|
||||||
|
print("[ERROR] Data file does not exist!")
|
||||||
|
assert (0)
|
||||||
|
self.class2sampleid = {}
|
||||||
|
self.N = N
|
||||||
|
self.K = K
|
||||||
|
self.Q = Q
|
||||||
|
self.encoder = encoder
|
||||||
|
self.cur_t = 0
|
||||||
|
self.samples, self.classes = self.__load_data_from_file__(filepath, bio)
|
||||||
|
if labelname_fn:
|
||||||
|
self.class2name = self.__load_names__(labelname_fn)
|
||||||
|
else:
|
||||||
|
self.class2name = None
|
||||||
|
self.sampler = FewshotSampler(N, K, Q, self.samples, classes=self.classes)
|
||||||
|
if debug_file:
|
||||||
|
if query_file is None:
|
||||||
|
print("use golden mention for typing !!! input_fn: {}".format(filepath))
|
||||||
|
self.sampler = DebugSampler(debug_file, query_file)
|
||||||
|
self.max_length = max_length
|
||||||
|
self.tag2label = None
|
||||||
|
self.label2tag = None
|
||||||
|
self.hidden_query_label = hidden_query_label
|
||||||
|
|
||||||
|
if self.hidden_query_label:
|
||||||
|
print("Attention! This dataset hidden query set labels !")
|
||||||
|
|
||||||
|
def __insert_sample__(self, index, sample_classes):
|
||||||
|
for item in sample_classes:
|
||||||
|
if item in self.class2sampleid:
|
||||||
|
self.class2sampleid[item].append(index)
|
||||||
|
else:
|
||||||
|
self.class2sampleid[item] = [index]
|
||||||
|
return
|
||||||
|
|
||||||
|
def __load_names__(self, label_mp_fn):
|
||||||
|
class2str = {}
|
||||||
|
with open(label_mp_fn, mode="r", encoding="utf-8") as fp:
|
||||||
|
for line in fp:
|
||||||
|
label, lstr = line.strip("\n").split(":")
|
||||||
|
class2str[label.strip()] = lstr.strip()
|
||||||
|
return class2str
|
||||||
|
|
||||||
|
def __load_data_from_file__(self, filepath, bio):
|
||||||
|
print("load input text from {}".format(filepath))
|
||||||
|
samples = []
|
||||||
|
classes = []
|
||||||
|
with open(filepath, 'r', encoding='utf-8')as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
samplelines = []
|
||||||
|
index = 0
|
||||||
|
for line in lines:
|
||||||
|
line = line.strip("\n")
|
||||||
|
if len(line):
|
||||||
|
samplelines.append(line)
|
||||||
|
else:
|
||||||
|
cur_ments = []
|
||||||
|
sample = QuerySpanSample(index, samplelines, cur_ments, bio)
|
||||||
|
samples.append(sample)
|
||||||
|
sample_classes = sample.get_tag_class()
|
||||||
|
self.__insert_sample__(index, sample_classes)
|
||||||
|
classes += sample_classes
|
||||||
|
samplelines = []
|
||||||
|
index += 1
|
||||||
|
if len(samplelines):
|
||||||
|
cur_ments = []
|
||||||
|
sample = QuerySpanSample(index, samplelines, cur_ments, bio)
|
||||||
|
samples.append(sample)
|
||||||
|
sample_classes = sample.get_tag_class()
|
||||||
|
self.__insert_sample__(index, sample_classes)
|
||||||
|
classes += sample_classes
|
||||||
|
classes = list(set(classes))
|
||||||
|
max_span_len = -1
|
||||||
|
long_ent_num = 0
|
||||||
|
tot_ent_num = 0
|
||||||
|
tot_tok_num = 0
|
||||||
|
|
||||||
|
for sample in samples:
|
||||||
|
max_span_len = max(max_span_len, sample.get_max_ent_len())
|
||||||
|
long_ent_num += sample.get_num_of_long_ent(10)
|
||||||
|
tot_ent_num += len(sample.spans)
|
||||||
|
tot_tok_num += len(sample.words)
|
||||||
|
print("Sentence num {}, token num {}, entity num {} in file {}".format(len(samples), tot_tok_num, tot_ent_num,
|
||||||
|
filepath))
|
||||||
|
print("Total classes {}: {}".format(len(classes), str(classes)))
|
||||||
|
print("The max golden entity len in the dataset is ", max_span_len)
|
||||||
|
print("The max golden entity len in the dataset is greater than 10", long_ent_num)
|
||||||
|
print("The total coverage of spans: {:.5f}".format(1 - long_ent_num / tot_ent_num))
|
||||||
|
return samples, classes
|
||||||
|
|
||||||
|
|
||||||
|
def __getraw__(self, sample, add_split):
|
||||||
|
word, mask, word_to_piece_ind, word_to_piece_end, word_shape_inds, seq_lens = self.encoder.tokenize(sample.words, true_token_flags=None)
|
||||||
|
sent_st_id = 0
|
||||||
|
split_spans = []
|
||||||
|
split_querys = []
|
||||||
|
split_probs = []
|
||||||
|
|
||||||
|
split_words = []
|
||||||
|
cur_wid = 0
|
||||||
|
for k in range(len(seq_lens)):
|
||||||
|
split_words.append(sample.words[cur_wid: cur_wid + seq_lens[k]])
|
||||||
|
cur_wid += seq_lens[k]
|
||||||
|
|
||||||
|
for cur_len in seq_lens:
|
||||||
|
sent_ed_id = sent_st_id + cur_len
|
||||||
|
split_spans.append([])
|
||||||
|
for tag, span_st, span_ed in sample.spans:
|
||||||
|
if (span_st >= sent_ed_id) or (span_ed < sent_st_id): # span totally not in subsent
|
||||||
|
continue
|
||||||
|
if (span_st >= sent_st_id) and (span_ed < sent_ed_id): # span totally in subsent
|
||||||
|
split_spans[-1].append([self.tag2label[tag], span_st - sent_st_id, span_ed - sent_st_id])
|
||||||
|
elif add_split:
|
||||||
|
if span_st >= sent_st_id:
|
||||||
|
split_spans[-1].append([self.tag2label[tag], span_st - sent_st_id, sent_ed_id - 1 - sent_st_id])
|
||||||
|
else:
|
||||||
|
split_spans[-1].append([self.tag2label[tag], 0, span_ed - sent_st_id])
|
||||||
|
split_querys.append([])
|
||||||
|
split_probs.append([])
|
||||||
|
for [span_st, span_ed], span_prob in zip(sample.query_ments, sample.query_probs):
|
||||||
|
if (span_st >= sent_ed_id) or (span_ed < sent_st_id): # span totally not in subsent
|
||||||
|
continue
|
||||||
|
if (span_st >= sent_st_id) and (span_ed < sent_ed_id): # span totally in subsent
|
||||||
|
split_querys[-1].append([span_st - sent_st_id, span_ed - sent_st_id])
|
||||||
|
split_probs[-1].append(span_prob)
|
||||||
|
elif add_split:
|
||||||
|
if span_st >= sent_st_id:
|
||||||
|
split_querys[-1].append([span_st - sent_st_id, sent_ed_id - 1 - sent_st_id])
|
||||||
|
split_probs[-1].append(span_prob)
|
||||||
|
else:
|
||||||
|
split_querys[-1].append([0, span_ed - sent_st_id])
|
||||||
|
split_probs[-1].append(span_prob)
|
||||||
|
sent_st_id += cur_len
|
||||||
|
|
||||||
|
item = {"word": word, "word_mask": mask, "word_to_piece_ind": word_to_piece_ind,
|
||||||
|
"word_to_piece_end": word_to_piece_end, "spans": split_spans, "query_ments": split_querys, "query_probs": split_probs, "seq_len": seq_lens,
|
||||||
|
"subsentence_num": len(seq_lens), "split_words": split_words}
|
||||||
|
return item
|
||||||
|
|
||||||
|
def __additem__(self, index, d, item):
|
||||||
|
d['index'].append(index)
|
||||||
|
d['word'] += item['word']
|
||||||
|
d['word_mask'] += item['word_mask']
|
||||||
|
d['seq_len'] += item['seq_len']
|
||||||
|
d['word_to_piece_ind'] += item['word_to_piece_ind']
|
||||||
|
d['word_to_piece_end'] += item['word_to_piece_end']
|
||||||
|
d['spans'] += item['spans']
|
||||||
|
d['query_ments'] += item['query_ments']
|
||||||
|
d['query_probs'] += item['query_probs']
|
||||||
|
d['subsentence_num'].append(item['subsentence_num'])
|
||||||
|
d['split_words'] += item['split_words']
|
||||||
|
return
|
||||||
|
|
||||||
|
def __populate__(self, idx_list, query_ment_mp=None, savelabeldic=False, add_split=False):
|
||||||
|
dataset = {'index': [], 'word': [], 'word_mask': [], 'spans': [], 'word_to_piece_ind': [],
|
||||||
|
"word_to_piece_end": [], "seq_len": [], "query_ments": [], "query_probs": [], "subsentence_num": [],
|
||||||
|
'split_words': []}
|
||||||
|
for idx in idx_list:
|
||||||
|
if query_ment_mp is not None:
|
||||||
|
self.samples[idx].query_ments = [x[0] for x in query_ment_mp[str(self.samples[idx].index)]]
|
||||||
|
self.samples[idx].query_probs = [x[1] for x in query_ment_mp[str(self.samples[idx].index)]]
|
||||||
|
else:
|
||||||
|
self.samples[idx].query_ments = [[x[1], x[2]] for x in self.samples[idx].spans]
|
||||||
|
self.samples[idx].query_probs = [1 for x in self.samples[idx].spans]
|
||||||
|
item = self.__getraw__(self.samples[idx], add_split)
|
||||||
|
self.__additem__(idx, dataset, item)
|
||||||
|
dataset['sentence_num'] = len(dataset["seq_len"])
|
||||||
|
assert len(dataset['word']) == len(dataset['seq_len'])
|
||||||
|
assert len(dataset['word_to_piece_ind']) == len(dataset['seq_len'])
|
||||||
|
assert len(dataset['word_to_piece_end']) == len(dataset['seq_len'])
|
||||||
|
assert len(dataset['spans']) == len(dataset['seq_len'])
|
||||||
|
if savelabeldic:
|
||||||
|
dataset['label2tag'] = self.label2tag
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
tmp = self.sampler.__next__(index)
|
||||||
|
if len(tmp) == 3:
|
||||||
|
target_classes, support_idx, query_idx = tmp
|
||||||
|
query_ment_mp = None
|
||||||
|
else:
|
||||||
|
target_classes, support_idx, query_idx, query_ment_mp = tmp
|
||||||
|
target_classes = sorted(target_classes)
|
||||||
|
distinct_tags = ['O'] + target_classes
|
||||||
|
self.tag2label = {tag: idx for idx, tag in enumerate(distinct_tags)}
|
||||||
|
self.label2tag = {idx: tag for idx, tag in enumerate(distinct_tags)}
|
||||||
|
support_set = self.__populate__(support_idx, add_split=True)
|
||||||
|
query_set = self.__populate__(query_idx, query_ment_mp=query_ment_mp, add_split=True, savelabeldic=True)
|
||||||
|
if self.hidden_query_label:
|
||||||
|
query_set['spans'] = [[] for i in range(query_set['sentence_num'])]
|
||||||
|
return support_set, query_set
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return 1000000
|
||||||
|
|
||||||
|
def batch_to_device(self, batch, device):
|
||||||
|
for k, v in batch.items():
|
||||||
|
if k in ['sentence_num', 'label2tag', 'spans', 'index', 'query_ments', 'query_probs', "subsentence_num", 'split_words']:
|
||||||
|
continue
|
||||||
|
batch[k] = v.to(device)
|
||||||
|
return batch
|
||||||
|
|
||||||
|
def get_query_loader(filepath, mode, encoder, N, K, Q, batch_size, max_length,
|
||||||
|
bio, shuffle, num_workers=8, debug_file=None, query_file=None,
|
||||||
|
iou_thred=None, hidden_query_label=False, label_fn=None, use_oproto=False):
|
||||||
|
batcher = QuerySpanBatcher(iou_thred, use_oproto)
|
||||||
|
dataset = QuerySpanNERDataset(filepath, encoder, N, K, Q, max_length, bio,
|
||||||
|
debug_file=debug_file, query_file=query_file,
|
||||||
|
hidden_query_label=hidden_query_label, labelname_fn=label_fn)
|
||||||
|
dataloader = data.DataLoader(dataset=dataset,
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=shuffle,
|
||||||
|
pin_memory=True,
|
||||||
|
num_workers=num_workers,
|
||||||
|
collate_fn=lambda x: batcher(x, mode))
|
||||||
|
return dataloader
|
|
@ -0,0 +1,145 @@
|
||||||
|
from .fewshotsampler import FewshotSampleBase
|
||||||
|
|
||||||
|
def get_class_name(rawtag):
|
||||||
|
if rawtag.startswith('B-') or rawtag.startswith('I-'):
|
||||||
|
return rawtag[2:]
|
||||||
|
else:
|
||||||
|
return rawtag
|
||||||
|
|
||||||
|
def convert_bio2spans(tags, schema):
|
||||||
|
spans = []
|
||||||
|
cur_span = []
|
||||||
|
err_cnt = 0
|
||||||
|
for i in range(len(tags)):
|
||||||
|
if schema == 'BIO':
|
||||||
|
if tags[i].startswith("B-") or tags[i] == 'O':
|
||||||
|
if len(cur_span):
|
||||||
|
cur_span.append(i - 1)
|
||||||
|
spans.append(cur_span)
|
||||||
|
cur_span = []
|
||||||
|
if tags[i].startswith("B-"):
|
||||||
|
cur_span.append(tags[i][2:])
|
||||||
|
cur_span.append(i)
|
||||||
|
elif tags[i].startswith("I-"):
|
||||||
|
if len(cur_span) == 0:
|
||||||
|
cur_span = [tags[i][2:], i]
|
||||||
|
err_cnt += 1
|
||||||
|
if cur_span[0] != tags[i][2:]:
|
||||||
|
cur_span.append(i - 1)
|
||||||
|
spans.append(cur_span)
|
||||||
|
cur_span = [tags[i][2:], i]
|
||||||
|
err_cnt += 1
|
||||||
|
assert cur_span[0] == tags[i][2:]
|
||||||
|
else:
|
||||||
|
assert tags[i] == "O"
|
||||||
|
elif schema == 'IO':
|
||||||
|
if tags[i] == "O":
|
||||||
|
if len(cur_span):
|
||||||
|
cur_span.append(i - 1)
|
||||||
|
spans.append(cur_span)
|
||||||
|
cur_span = []
|
||||||
|
elif (i == 0) or (tags[i] != tags[i - 1]):
|
||||||
|
if len(cur_span):
|
||||||
|
cur_span.append(i - 1)
|
||||||
|
spans.append(cur_span)
|
||||||
|
cur_span = []
|
||||||
|
cur_span.append(tags[i].strip("I-"))
|
||||||
|
cur_span.append(i)
|
||||||
|
else:
|
||||||
|
assert cur_span[0] == tags[i].strip("I-")
|
||||||
|
elif schema == "BIOES":
|
||||||
|
if tags[i] == "O":
|
||||||
|
if len(cur_span):
|
||||||
|
cur_span.append(i - 1)
|
||||||
|
spans.append(cur_span)
|
||||||
|
cur_span = []
|
||||||
|
elif tags[i][0] == "S":
|
||||||
|
if len(cur_span):
|
||||||
|
cur_span.append(i - 1)
|
||||||
|
spans.append(cur_span)
|
||||||
|
cur_span = []
|
||||||
|
spans.append([tags[i][2:], i, i])
|
||||||
|
elif tags[i][0] == "E":
|
||||||
|
if len(cur_span) == 0:
|
||||||
|
err_cnt += 1
|
||||||
|
continue
|
||||||
|
cur_span.append(i)
|
||||||
|
spans.append(cur_span)
|
||||||
|
cur_span = []
|
||||||
|
elif tags[i][0] == "B":
|
||||||
|
if len(cur_span):
|
||||||
|
cur_span.append(i - 1)
|
||||||
|
spans.append(cur_span)
|
||||||
|
cur_span = []
|
||||||
|
cur_span = [tags[i][2:], i]
|
||||||
|
else:
|
||||||
|
if len(cur_span) == 0:
|
||||||
|
cur_span = [tags[i][2:], i]
|
||||||
|
err_cnt += 1
|
||||||
|
continue
|
||||||
|
if cur_span[0] != tags[i][2:]:
|
||||||
|
cur_span.append(i - 1)
|
||||||
|
spans.append(cur_span)
|
||||||
|
cur_span = [tags[i][2:], i]
|
||||||
|
err_cnt += 1
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
if len(cur_span):
|
||||||
|
cur_span.append(len(tags) - 1)
|
||||||
|
spans.append(cur_span)
|
||||||
|
return spans
|
||||||
|
|
||||||
|
class SpanSample(FewshotSampleBase):
|
||||||
|
def __init__(self, idx, filelines, bio=True):
|
||||||
|
super(SpanSample, self).__init__()
|
||||||
|
self.index = idx
|
||||||
|
filelines = [line.split('\t') for line in filelines]
|
||||||
|
if len(filelines[0]) == 2:
|
||||||
|
self.words, self.tags = zip(*filelines)
|
||||||
|
else:
|
||||||
|
self.words, self.postags, self.tags = zip(*filelines)
|
||||||
|
self.spans = convert_bio2spans(self.tags, "BIO" if bio else "IO")
|
||||||
|
return
|
||||||
|
|
||||||
|
def get_max_ent_len(self):
|
||||||
|
max_len = -1
|
||||||
|
for sp in self.spans:
|
||||||
|
max_len = max(max_len, sp[2] - sp[1] + 1)
|
||||||
|
return max_len
|
||||||
|
|
||||||
|
def get_num_of_long_ent(self, max_span_len):
|
||||||
|
cnt = 0
|
||||||
|
for sp in self.spans:
|
||||||
|
if sp[2] - sp[1] + 1 > max_span_len:
|
||||||
|
cnt += 1
|
||||||
|
return cnt
|
||||||
|
|
||||||
|
def __count_entities__(self):
|
||||||
|
self.class_count = {}
|
||||||
|
for tag, i, j in self.spans:
|
||||||
|
if tag in self.class_count:
|
||||||
|
self.class_count[tag] += 1
|
||||||
|
else:
|
||||||
|
self.class_count[tag] = 1
|
||||||
|
return
|
||||||
|
|
||||||
|
def get_class_count(self):
|
||||||
|
if self.class_count:
|
||||||
|
return self.class_count
|
||||||
|
else:
|
||||||
|
self.__count_entities__()
|
||||||
|
return self.class_count
|
||||||
|
|
||||||
|
def get_tag_class(self):
|
||||||
|
tag_class = list(set(map(lambda x: x[2:] if x[:2] in ['B-', 'I-', 'E-', 'S-'] else x, self.tags)))
|
||||||
|
if 'O' in tag_class:
|
||||||
|
tag_class.remove('O')
|
||||||
|
return tag_class
|
||||||
|
|
||||||
|
def valid(self, target_classes):
|
||||||
|
return (set(self.get_class_count().keys()).intersection(set(target_classes))) and \
|
||||||
|
not (set(self.get_class_count().keys()).difference(set(target_classes)))
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return
|
Загрузка…
Ссылка в новой задаче