This commit is contained in:
t-tingtingma_microsoft 2024-03-15 14:44:27 +00:00
Родитель e8cca59e75
Коммит e062c845ae
39 изменённых файлов: 7292 добавлений и 1 удалений

Просмотреть файл

@ -1 +1,76 @@
Code will be public once the paper is accepted.
# Decomposed Meta-Learning for Few-Shot Sequence Labeling
This repository contains the open-sourced official implementation of the paper:
[Decomposed Meta-Learning for Few-Shot Sequence Labeling](https://ieeexplore.ieee.org/abstract/document/10458261) (TASLP).
_Tingting Ma, Qianhui Wu, Huiqiang Jiang, Jieru Lin, Börje F. Karlsson, Tiejun Zhao, and Chin-Yew Lin_
If you find this repo helpful, please cite the following paper
```bibtex
@ARTICLE{ma-etal-2024-decomposedmetasl,
author={Ma, Tingting and Wu, Qianhui and Jiang, Huiqiang and Lin, Jieru and Karlsson, Börje F. and Zhao, Tiejun and Lin, Chin-Yew},
journal={IEEE/ACM Transactions on Audio, Speech, and Language Processing},
title={Decomposed Meta-Learning for Few-Shot Sequence Labeling},
year={2024},
volume={},
number={},
pages={1-14},
doi={10.1109/TASLP.2024.3372879}
}
```
For any questions/comments, please feel free to open GitHub issues.
### Requirements
- python 3.9.17
- pytorch 1.9.1+cu111
- [HuggingFace Transformers 4.10.0](https://github.com/huggingface/transformers)
### Input data format
train/dev/test_N_K_id.jsonl:
Each line contains the following fields:
1. `target_classes`: A list of types (e.g., "event-other," "person-scholar").
2. `query_idx`: A list of indexes corresponding to query sentences for the i-th instance in train/dev/test.txt.
3. `support_idx`: A list of indexes corresponding to support sentences for the i-th instance in train/dev/test.txt.
### Train and Evaluate
For _Seperate_ model,
```bash
bash scripts/train_ment.sh
bash scripts/train_type.sh
bash scripts/eval_sep.sh
```
For _Joint_ model,
```bash
bash scripts/train_joint.sh
bash scripts/eval_joint.sh
```
For _POS tagging_ task, run
```bash
bash scripts/train_pos.sh
```
## Contributing
This project welcomes contributions and suggestions. Most contributions require you to agree to a
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
When you submit a pull request, a CLA bot will automatically determine whether you need to provide
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
provided by the bot. You will only need to do this once across all repos using our CLA.
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.

Просмотреть файл

@ -0,0 +1,10 @@
{"target_classes": ["event-other", "person-scholar", "organization-showorganization", "building-other", "organization-religion"], "query_idx": [2944, 11433, 13834, 9686, 14970], "support_idx": [15362, 12956, 9225, 6364, 17558]}
{"target_classes": ["other-currency", "organization-religion", "building-other", "organization-showorganization", "person-other"], "query_idx": [16913, 16269, 1550, 5031, 6583], "support_idx": [365, 14279, 7493, 5495]}
{"target_classes": ["product-train", "art-painting", "person-scholar", "organization-showorganization", "building-other"], "query_idx": [12906, 17326, 2405, 12872, 9081], "support_idx": [9686, 1202, 1422, 3779]}
{"target_classes": ["location-park", "building-other", "product-train", "other-currency", "organization-religion"], "query_idx": [18096, 7281, 4038, 11566, 14149], "support_idx": [16798, 17009, 11792, 1476, 67]}
{"target_classes": ["product-game", "location-park", "organization-showorganization", "building-other", "other-currency"], "query_idx": [1333, 16185, 15369, 13226, 18582], "support_idx": [4179, 12320, 4168, 5272, 557]}
{"target_classes": ["product-train", "art-painting", "event-other", "building-library", "building-other"], "query_idx": [14660, 2874, 17317, 16888, 4034], "support_idx": [149, 5083, 13658, 12898, 1242]}
{"target_classes": ["other-chemicalthing", "person-scholar", "other-currency", "building-other", "building-library"], "query_idx": [16000, 12005, 8526, 2955, 3635], "support_idx": [14882, 16916, 2826, 3466, 4354]}
{"target_classes": ["building-other", "location-park", "organization-religion", "building-library", "product-train"], "query_idx": [9041, 17985, 11146, 9223, 3569], "support_idx": [15817, 3343, 15093, 3147, 18407]}
{"target_classes": ["building-other", "organization-showorganization", "other-currency", "building-library", "product-game"], "query_idx": [14602, 14513, 3689, 16122, 11496], "support_idx": [213, 8006, 11213, 3411, 9197]}
{"target_classes": ["organization-religion", "other-chemicalthing", "location-park", "person-scholar", "art-painting"], "query_idx": [3652, 12342, 14141, 7028, 1242], "support_idx": [16594, 2171, 6945, 6765, 12307]}

Просмотреть файл

@ -0,0 +1,10 @@
{"target_classes": ["art-writtenart", "other-livingthing", "person-athlete", "building-theater", "product-car"], "query_idx": [5475, 2456, 7708, 13599, 1070], "support_idx": [12355, 10289, 7443, 3571, 4689]}
{"target_classes": ["event-sportsevent", "art-writtenart", "other-livingthing", "location-bodiesofwater", "product-car"], "query_idx": [13726, 10244, 2481, 9620, 5067], "support_idx": [8120, 6841, 6299, 8113, 12566]}
{"target_classes": ["other-medical", "other-livingthing", "organization-government/governmentagency", "person-athlete", "product-weapon"], "query_idx": [6730, 7660, 9544, 10325, 752], "support_idx": [3131, 7347, 12463, 7987, 13812]}
{"target_classes": ["location-other", "art-writtenart", "other-educationaldegree", "building-theater", "event-election"], "query_idx": [8314, 6694, 7124, 9619, 11697], "support_idx": [11050, 13350, 4019, 6025, 11376]}
{"target_classes": ["person-athlete", "organization-politicalparty", "location-bodiesofwater", "event-election", "event-sportsevent"], "query_idx": [7482, 11414, 1001, 13487], "support_idx": [5012, 4715, 13625, 9276, 6870]}
{"target_classes": ["art-music", "location-bodiesofwater", "organization-politicalparty", "product-car", "event-sportsevent"], "query_idx": [9002, 7862, 7715, 5282, 12689], "support_idx": [4504, 7822, 11325, 8480, 1235]}
{"target_classes": ["other-medical", "location-other", "person-actor", "product-weapon", "location-bodiesofwater"], "query_idx": [2532, 11648, 13626, 1109, 10751], "support_idx": [7290, 413, 4334, 10504, 12808]}
{"target_classes": ["person-actor", "product-weapon", "organization-politicalparty", "person-athlete", "organization-government/governmentagency"], "query_idx": [10577, 4080, 13918, 799, 2619], "support_idx": [8103, 12071, 9532, 1554, 2131]}
{"target_classes": ["location-other", "organization-politicalparty", "person-actor", "organization-government/governmentagency", "event-election"], "query_idx": [11580, 8955, 8633, 12802, 4687], "support_idx": [7790, 11474, 6822, 11488, 4439]}
{"target_classes": ["other-medical", "other-educationaldegree", "person-athlete", "product-car", "event-sportsevent"], "query_idx": [8012, 11377, 10920, 800, 12107], "support_idx": [5617, 13104, 4165, 1768, 10965]}

Просмотреть файл

@ -0,0 +1,10 @@
{"target_classes": ["person-director", "person-artist/author", "product-ship", "building-sportsfacility", "other-astronomything"], "query_idx": [128416, 57936, 120966, 6520, 71667], "support_idx": [59271, 51450, 57931, 126361, 83894]}
{"target_classes": ["other-law", "event-protest", "building-airport", "location-road/railway/highway/transit", "organization-sportsteam"], "query_idx": [83479, 31867, 91254, 100102, 108109], "support_idx": [38357, 105279, 29226, 113465, 45224]}
{"target_classes": ["other-god", "product-airplane", "product-ship", "event-protest", "other-award"], "query_idx": [76530, 20953, 26875, 21041, 75295], "support_idx": [87042, 121344, 66832, 122153, 40746]}
{"target_classes": ["location-GPE", "product-software", "art-film", "other-biologything", "building-airport"], "query_idx": [99546, 31778, 60736, 115175, 45604], "support_idx": [13509, 16341, 82343, 63652, 45731]}
{"target_classes": ["product-airplane", "other-language", "person-director", "event-protest", "location-GPE"], "query_idx": [122795, 20509, 76322, 98322, 87870], "support_idx": [116202, 15218, 33921, 102917, 77796]}
{"target_classes": ["organization-company", "product-software", "other-law", "organization-media/newspaper", "product-other"], "query_idx": [13360, 24527, 122844, 126093, 105019], "support_idx": [79122, 95407, 61315, 14147, 64059]}
{"target_classes": ["event-protest", "product-airplane", "person-soldier", "person-artist/author", "product-ship"], "query_idx": [3298, 82634, 111096, 63988, 39442], "support_idx": [96798, 69080, 69749, 60232, 38841]}
{"target_classes": ["other-god", "product-software", "art-broadcastprogram", "other-disease", "organization-sportsteam"], "query_idx": [64957, 90058, 64150, 55510, 82947], "support_idx": [110777, 63235, 63848, 41609, 90769]}
{"target_classes": ["building-hotel", "building-restaurant", "person-director", "person-artist/author", "other-law"], "query_idx": [36475, 97996, 81899, 107727, 120948], "support_idx": [57146, 123180, 45636, 69299, 10967]}
{"target_classes": ["person-artist/author", "event-attack/battle/war/militaryconflict", "organization-media/newspaper", "location-island", "organization-other"], "query_idx": [78920, 90936, 21417, 105820, 56455], "support_idx": [48831, 76272, 104368, 43682]}

Просмотреть файл

@ -0,0 +1,40 @@
On O
June O
15 O
, O
1991 O
, O
Hoch building-other
Auditorium building-other
was O
struck O
by O
lightning O
. O
Completed O
in O
July O
2008 O
by O
Mighty building-other
River building-other
Power building-other
at O
a O
cost O
of O
NZ O
$ O
300 O
million O
, O
the O
plant O
's O
capacity O
proved O
greater O
than O
expected O
. O

Просмотреть файл

@ -0,0 +1,32 @@
The O
stadium O
previously O
hosted O
the O
1982 O
Commonwealth event-sportsevent
Games event-sportsevent
and O
2001 O
Goodwill event-sportsevent
Games event-sportsevent
. O
The O
San organization-government/governmentagency
Diego organization-government/governmentagency
Harbor organization-government/governmentagency
Police organization-government/governmentagency
Department organization-government/governmentagency
is O
the O
law O
enforcement O
authority O
for O
the O
Port location-bodiesofwater
of location-bodiesofwater
San location-bodiesofwater
Diego location-bodiesofwater
. O

Просмотреть файл

@ -0,0 +1,26 @@
When O
reconstruction O
of O
the O
building O
was O
complete O
, O
the O
rear O
half O
of O
the O
building O
was O
named O
Budig O
Hall O
, O
for O
then O
KU organization-education
Chancellor O
Gene O
Budig O
. O

Просмотреть файл

@ -0,0 +1,251 @@
import torch.nn as nn
import torch
from typing import Dict, List
from typing import Tuple
B_PREF="B-"
I_PREF = "I-"
S_PREF = "S-"
E_PREF = "E-"
O = "O"
class LinearCRF(nn.Module):
def __init__(self, tag_size: int, schema: str, add_constraint: bool = False,
label2idx: Dict = None, device: torch.device = None):
super(LinearCRF, self).__init__()
self.label_size = tag_size + 3
self.label2idx = label2idx
self.tag_list = list(self.label2idx.keys())
self.start_idx = tag_size
self.end_idx = tag_size + 1
self.pad_idx = tag_size + 2
self.schema = schema
self.add_constraint = add_constraint
self.init_params(device=device)
return
def reset(self, label2idx, device):
if len(label2idx) == len(self.label2idx):
return
tag_size = len(label2idx)
self.label_size = tag_size + 3
self.label2idx = label2idx
self.tag_list = list(self.label2idx.keys())
self.start_idx = tag_size
self.end_idx = tag_size + 1
self.pad_idx = tag_size + 2
self.add_constraint = True
self.init_params(device)
return
def init_params(self, device=None):
if device is None:
device = torch.device('cpu')
init_transition = torch.zeros(self.label_size, self.label_size, device=device)
init_transition[:, self.start_idx] = -10000.0
init_transition[self.end_idx, :] = -10000.0
init_transition[:, self.pad_idx] = -10000.0
init_transition[self.pad_idx, :] = -10000.0
if self.add_constraint:
if self.schema == "BIO":
self.add_constraint_for_bio(init_transition)
elif self.schema == "BIOES":
self.add_constraint_for_iobes(init_transition)
else:
print("[ERROR] wrong schema name!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
self.transition = nn.Parameter(init_transition, requires_grad=False)
return
def add_constraint_for_bio(self, transition: torch.Tensor):
for prev_label in self.tag_list:
for next_label in self.tag_list:
if prev_label == O and next_label.startswith(I_PREF):
transition[self.label2idx[prev_label], self.label2idx[next_label]] = -10000.0
if (prev_label.startswith(B_PREF) or prev_label.startswith(I_PREF)) and next_label.startswith(I_PREF):
if prev_label[2:] != next_label[2:]:
transition[self.label2idx[prev_label], self.label2idx[next_label]] = -10000.0
for label in self.tag_list:
if label.startswith(I_PREF):
transition[self.start_idx, self.label2idx[label]] = -10000.0
return
def add_constraint_for_iobes(self, transition: torch.Tensor):
for prev_label in self.tag_list:
for next_label in self.tag_list:
if prev_label == O and (next_label.startswith(I_PREF) or next_label.startswith(E_PREF)):
transition[self.label2idx[prev_label], self.label2idx[next_label]] = -10000.0
if prev_label.startswith(B_PREF) or prev_label.startswith(I_PREF):
if next_label.startswith(O) or next_label.startswith(B_PREF) or next_label.startswith(S_PREF):
transition[self.label2idx[prev_label], self.label2idx[next_label]] = -10000.0
elif prev_label[2:] != next_label[2:]:
transition[self.label2idx[prev_label], self.label2idx[next_label]] = -10000.0
if prev_label.startswith(S_PREF) or prev_label.startswith(E_PREF):
if next_label.startswith(I_PREF) or next_label.startswith(E_PREF):
transition[self.label2idx[prev_label], self.label2idx[next_label]] = -10000.0
for label in self.tag_list:
if label.startswith(I_PREF) or label.startswith(E_PREF):
transition[self.start_idx, self.label2idx[label]] = -10000.0
if label.startswith(I_PREF) or label.startswith(B_PREF):
transition[self.label2idx[label], self.end_idx] = -10000.0
return
def forward(self, lstm_scores, word_seq_lens, tags, mask, decode_flag=False):
all_scores = self.calculate_all_scores(lstm_scores=lstm_scores)
unlabed_score = self.forward_unlabeled(all_scores, word_seq_lens)
labeled_score = self.forward_labeled(all_scores, word_seq_lens, tags, mask)
per_sent_loss = (unlabed_score - labeled_score).sum() / mask.size(0)
if decode_flag:
_, decodeIdx = self.viterbi_decode(all_scores, word_seq_lens)
return per_sent_loss, decodeIdx
return per_sent_loss
def get_marginal_score(self, lstm_scores: torch.Tensor, word_seq_lens: torch.Tensor) -> torch.Tensor:
marginal = self.forward_backward(lstm_scores=lstm_scores, word_seq_lens=word_seq_lens)
return marginal
def forward_unlabeled(self, all_scores: torch.Tensor, word_seq_lens: torch.Tensor) -> torch.Tensor:
batch_size = all_scores.size(0)
seq_len = all_scores.size(1)
dev_num = all_scores.get_device()
curr_dev = torch.device(f"cuda:{dev_num}") if dev_num >= 0 else torch.device("cpu")
alpha = torch.zeros(batch_size, seq_len, self.label_size, device=curr_dev)
alpha[:, 0, :] = all_scores[:, 0, self.start_idx, :]
for word_idx in range(1, seq_len):
before_log_sum_exp = alpha[:, word_idx-1, :].view(batch_size, self.label_size, 1).expand(batch_size, self.label_size, self.label_size) + all_scores[:, word_idx, :, :]
alpha[:, word_idx, :] = torch.logsumexp(before_log_sum_exp, dim=1)
last_alpha = torch.gather(alpha, 1, word_seq_lens.view(batch_size, 1, 1).expand(batch_size, 1, self.label_size)-1).view(batch_size, self.label_size)
last_alpha += self.transition[:, self.end_idx].view(1, self.label_size).expand(batch_size, self.label_size)
last_alpha = torch.logsumexp(last_alpha.view(batch_size, self.label_size, 1), dim=1).view(batch_size) #log Z(x)
return last_alpha
def backward(self, lstm_scores: torch.Tensor, word_seq_lens: torch.Tensor) -> torch.Tensor:
batch_size = lstm_scores.size(0)
seq_len = lstm_scores.size(1)
dev_num = lstm_scores.get_device()
curr_dev = torch.device(f"cuda:{dev_num}") if dev_num >= 0 else torch.device("cpu")
beta = torch.zeros(batch_size, seq_len, self.label_size, device=curr_dev)
rev_score = self.transition.transpose(0, 1).view(1, 1, self.label_size, self.label_size).expand(batch_size, seq_len, self.label_size, self.label_size) + \
lstm_scores.view(batch_size, seq_len, 1, self.label_size).expand(batch_size, seq_len, self.label_size, self.label_size)
perm_idx = torch.zeros(batch_size, seq_len, device=curr_dev)
for batch_idx in range(batch_size):
perm_idx[batch_idx][:word_seq_lens[batch_idx]] = torch.range(word_seq_lens[batch_idx] - 1, 0, -1)
perm_idx = perm_idx.long()
for i, length in enumerate(word_seq_lens):
rev_score[i, :length] = rev_score[i, :length][perm_idx[i, :length]]
beta[:, 0, :] = rev_score[:, 0, self.end_idx, :]
for word_idx in range(1, seq_len):
before_log_sum_exp = beta[:, word_idx - 1, :].view(batch_size, self.label_size, 1).expand(batch_size, self.label_size, self.label_size) + rev_score[:, word_idx, :, :]
beta[:, word_idx, :] = torch.logsumexp(before_log_sum_exp, dim=1)
last_beta = torch.gather(beta, 1, word_seq_lens.view(batch_size, 1, 1).expand(batch_size, 1, self.label_size) - 1).view(batch_size, self.label_size)
last_beta += self.transition.transpose(0, 1)[:, self.start_idx].view(1, self.label_size).expand(batch_size, self.label_size)
last_beta = torch.logsumexp(last_beta, dim=1)
for i, length in enumerate(word_seq_lens):
beta[i, :length] = beta[i, :length][perm_idx[i, :length]]
return torch.sum(last_beta)
def forward_backward(self, lstm_scores: torch.Tensor, word_seq_lens: torch.Tensor) -> torch.Tensor:
batch_size = lstm_scores.size(0)
seq_len = lstm_scores.size(1)
dev_num = lstm_scores.get_device()
curr_dev = torch.device(f"cuda:{dev_num}") if dev_num >= 0 else torch.device("cpu")
alpha = torch.zeros(batch_size, seq_len, self.label_size, device=curr_dev)
beta = torch.zeros(batch_size, seq_len, self.label_size, device=curr_dev)
scores = self.transition.view(1, 1, self.label_size, self.label_size).expand(batch_size, seq_len, self.label_size, self.label_size) + \
lstm_scores.view(batch_size, seq_len, 1, self.label_size).expand(batch_size, seq_len, self.label_size, self.label_size)
rev_score = self.transition.transpose(0, 1).view(1, 1, self.label_size, self.label_size).expand(batch_size, seq_len, self.label_size, self.label_size) + \
lstm_scores.view(batch_size, seq_len, 1, self.label_size).expand(batch_size, seq_len, self.label_size, self.label_size)
perm_idx = torch.zeros(batch_size, seq_len, device=curr_dev)
for batch_idx in range(batch_size):
perm_idx[batch_idx][:word_seq_lens[batch_idx]] = torch.range(word_seq_lens[batch_idx] - 1, 0, -1)
perm_idx = perm_idx.long()
for i, length in enumerate(word_seq_lens):
rev_score[i, :length] = rev_score[i, :length][perm_idx[i, :length]]
alpha[:, 0, :] = scores[:, 0, self.start_idx, :]
beta[:, 0, :] = rev_score[:, 0, self.end_idx, :]
for word_idx in range(1, seq_len):
before_log_sum_exp = alpha[:, word_idx - 1, :].view(batch_size, self.label_size, 1).expand(batch_size, self.label_size, self.label_size) + scores[ :, word_idx, :, :]
alpha[:, word_idx, :] = torch.logsumexp(before_log_sum_exp, dim=1)
before_log_sum_exp = beta[:, word_idx - 1, :].view(batch_size, self.label_size, 1).expand(batch_size, self.label_size, self.label_size) + rev_score[:, word_idx, :, :]
beta[:, word_idx, :] = torch.logsumexp(before_log_sum_exp, dim=1)
last_alpha = torch.gather(alpha, 1, word_seq_lens.view(batch_size, 1, 1).expand(batch_size, 1, self.label_size) - 1).view(batch_size, self.label_size)
last_alpha += self.transition[:, self.end_idx].view(1, self.label_size).expand(batch_size, self.label_size)
last_alpha = torch.logsumexp(last_alpha.view(batch_size, self.label_size), dim=-1).view(batch_size, 1, 1).expand(batch_size, seq_len, self.label_size)
for i, length in enumerate(word_seq_lens):
beta[i, :length] = beta[i, :length][perm_idx[i, :length]]
return alpha + beta - last_alpha - lstm_scores
def forward_labeled(self, all_scores: torch.Tensor, word_seq_lens: torch.Tensor, tags: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
batchSize = all_scores.shape[0]
sentLength = all_scores.shape[1]
currentTagScores = torch.gather(all_scores, 3, tags.view(batchSize, sentLength, 1, 1).expand(batchSize, sentLength, self.label_size, 1)).view(batchSize, -1, self.label_size)
tagTransScoresMiddle = None
if sentLength != 1:
tagTransScoresMiddle = torch.gather(currentTagScores[:, 1:, :], 2, tags[:, :sentLength - 1].view(batchSize, sentLength - 1, 1)).view(batchSize, -1)
tagTransScoresBegin = currentTagScores[:, 0, self.start_idx]
endTagIds = torch.gather(tags, 1, word_seq_lens.view(batchSize, 1) - 1)
tagTransScoresEnd = torch.gather(self.transition[:, self.end_idx].view(1, self.label_size).expand(batchSize, self.label_size), 1, endTagIds).view(batchSize)
score = tagTransScoresBegin + tagTransScoresEnd
masks = masks.type(torch.float32)
if sentLength != 1:
score += torch.sum(tagTransScoresMiddle.mul(masks[:, 1:]), dim=1)
return score
def calculate_all_scores(self, lstm_scores: torch.Tensor) -> torch.Tensor:
batch_size = lstm_scores.size(0)
seq_len = lstm_scores.size(1)
scores = self.transition.view(1, 1, self.label_size, self.label_size).expand(batch_size, seq_len, self.label_size, self.label_size) + \
lstm_scores.view(batch_size, seq_len, 1, self.label_size).expand(batch_size, seq_len, self.label_size, self.label_size)
return scores
def decode(self, features, wordSeqLengths, new_label2idx=None) -> Tuple[torch.Tensor, torch.Tensor]:
if new_label2idx is not None:
self.reset(new_label2idx, features.device)
all_scores = self.calculate_all_scores(features)
bestScores, decodeIdx = self.viterbi_decode(all_scores, wordSeqLengths)
return bestScores, decodeIdx
def viterbi_decode(self, all_scores: torch.Tensor, word_seq_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
batchSize = all_scores.shape[0]
sentLength = all_scores.shape[1]
dev_num = all_scores.get_device()
curr_dev = torch.device(f"cuda:{dev_num}") if dev_num >= 0 else torch.device("cpu")
scoresRecord = torch.zeros([batchSize, sentLength, self.label_size], device=curr_dev)
idxRecord = torch.zeros([batchSize, sentLength, self.label_size], dtype=torch.int64, device=curr_dev)
startIds = torch.full((batchSize, self.label_size), self.start_idx, dtype=torch.int64, device=curr_dev)
decodeIdx = torch.LongTensor(batchSize, sentLength).to(curr_dev)
scores = all_scores
scoresRecord[:, 0, :] = scores[:, 0, self.start_idx, :]
idxRecord[:, 0, :] = startIds
for wordIdx in range(1, sentLength):
scoresIdx = scoresRecord[:, wordIdx - 1, :].view(batchSize, self.label_size, 1).expand(batchSize, self.label_size,
self.label_size) + scores[:, wordIdx, :, :]
idxRecord[:, wordIdx, :] = torch.argmax(scoresIdx, 1)
scoresRecord[:, wordIdx, :] = torch.gather(scoresIdx, 1, idxRecord[:, wordIdx, :].view(batchSize, 1, self.label_size)).view(batchSize, self.label_size)
lastScores = torch.gather(scoresRecord, 1, word_seq_lens.view(batchSize, 1, 1).expand(batchSize, 1, self.label_size) - 1).view(batchSize, self.label_size) ##select position
lastScores += self.transition[:, self.end_idx].view(1, self.label_size).expand(batchSize, self.label_size)
decodeIdx[:, 0] = torch.argmax(lastScores, 1)
bestScores = torch.gather(lastScores, 1, decodeIdx[:, 0].view(batchSize, 1))
for distance2Last in range(sentLength - 1):
curIdx = torch.clamp(word_seq_lens - distance2Last - 1, min=1).view(batchSize, 1, 1).expand(batchSize, 1, self.label_size)
lastNIdxRecord = torch.gather(idxRecord, 1, curIdx).view(batchSize, self.label_size)
decodeIdx[:, distance2Last + 1] = torch.gather(lastNIdxRecord, 1, decodeIdx[:, distance2Last].view(batchSize, 1)).view(batchSize)
perm_pos = torch.arange(1, sentLength + 1).to(curr_dev)
perm_pos = perm_pos.unsqueeze(0).expand(batchSize, sentLength)
perm_pos = word_seq_lens.unsqueeze(1).expand(batchSize, sentLength) - perm_pos
perm_pos = perm_pos.masked_fill(perm_pos < 0, 0)
decodeIdx = torch.gather(decodeIdx, 1, perm_pos)
return bestScores, decodeIdx

Просмотреть файл

@ -0,0 +1,323 @@
import torch
from torch import nn
import torch.nn.functional as F
from .crf import LinearCRF
from .span_fewshotmodel import FewShotSpanModel
from util.span_sample import convert_bio2spans
from .loss_model import MaxLoss
from copy import deepcopy
import numpy as np
class SelectedJointModel(FewShotSpanModel):
def __init__(self, span_encoder, num_tag, ment_label2idx, schema, use_crf, max_loss, use_oproto, dot=False, normalize="none",
temperature=None, use_focal=False, type_lam=1):
super(SelectedJointModel, self).__init__(span_encoder)
self.dot = dot
self.normalize = normalize
self.temperature = temperature
self.use_oproto = use_oproto
self.proto = None
self.num_tag = num_tag
self.use_crf = use_crf
self.ment_label2idx = ment_label2idx
self.ment_idx2label = {idx: label for label, idx in self.ment_label2idx.items()}
self.schema = schema
self.cls = nn.Linear(span_encoder.word_dim, self.num_tag)
self.type_lam = type_lam
self.proto = None
if self.use_crf:
self.crf_layer = LinearCRF(self.num_tag, schema=schema, add_constraint=True, label2idx=ment_label2idx)
if use_focal:
raise ValueError("not support focal loss")
self.ment_loss_fct = MaxLoss(gamma=max_loss)
self.base_loss_fct = nn.CrossEntropyLoss(reduction='none', ignore_index=-1)
print("use cross entropy loss")
print("use dot : {} use normalizatioin: {} use temperature: {}".format(self.dot, self.normalize,
self.temperature if self.temperature else "none"))
self.cached_o_proto = torch.zeros(span_encoder.span_dim, requires_grad=False)
self.init_weights()
return
def init_weights(self):
self.cls.weight.data.normal_(mean=0.0, std=0.02)
if self.cls.bias is not None:
self.cls.bias.data.zero_()
if self.use_crf:
self.crf_layer.init_params()
return
def type_loss_fct(self, logits, targets, inst_weights=None):
if inst_weights is None:
loss = self.base_loss_fct(logits, targets)
loss = loss.mean()
else:
targets = torch.clamp(targets, min=0)
one_hot_targets = torch.zeros(logits.size(), device=logits.device).scatter_(1, targets.unsqueeze(1), 1)
soft_labels = inst_weights.unsqueeze(1) * one_hot_targets + (1 - one_hot_targets) * (1 - inst_weights).unsqueeze(1) / (logits.size(1) - 1)
logp = F.log_softmax(logits, dim=-1)
loss = - (logp * soft_labels).sum(1)
loss = loss.mean()
return loss
def __dist__(self, x, y, dim):
if self.normalize == 'l2':
x = F.normalize(x, p=2, dim=-1)
y = F.normalize(y, p=2, dim=-1)
if self.dot:
sim = (x * y).sum(dim)
else:
sim = -(torch.pow(x - y, 2)).sum(dim)
if self.temperature:
sim = sim / self.temperature
return sim
def __batch_dist__(self, S_emb, Q_emb, Q_mask):
Q_emb = Q_emb[Q_mask.eq(1), :].view(-1, Q_emb.size(-1))
dist = self.__dist__(S_emb.unsqueeze(0), Q_emb.unsqueeze(1), 2)
return dist
def __get_proto__(self, S_emb, S_tag, S_mask):
proto = []
embedding = S_emb[S_mask.eq(1), :].view(-1, S_emb.size(-1))
S_tag = S_tag[S_mask.eq(1)]
if self.use_oproto:
st_idx = 0
else:
st_idx = 1
proto = [self.cached_o_proto]
for label in range(st_idx, torch.max(S_tag) + 1):
proto.append(torch.mean(embedding[S_tag.eq(label), :], 0))
proto = torch.stack(proto, dim=0)
return proto
def __get_proto_dist__(self, Q_emb, Q_mask):
dist = self.__batch_dist__(self.proto, Q_emb, Q_mask)
if not self.use_oproto:
dist[:, 0] = -1000000
return dist
def forward_type_step(self, query, encoder_mode=None, query_bottom_hiddens=None):
if query['span_mask'].sum().item() == 0: # there is no query mentions
print("no query mentions")
empty_tensor = torch.tensor([], device=query['word'].device)
zero_tensor = torch.tensor([0], device=query['word'].device)
return empty_tensor, empty_tensor, empty_tensor, zero_tensor
query_span_emb = self.word_encoder(query['word'], query['word_mask'], word_to_piece_inds=query['word_to_piece_ind'],
word_to_piece_ends=query['word_to_piece_end'], span_indices=query['span_indices'],
mode=encoder_mode, bottom_hiddens=query_bottom_hiddens)
logits = self.__get_proto_dist__(query_span_emb, query['span_mask'])
golds = query["span_tag"][query["span_mask"].eq(1)].view(-1)
if query["span_weights"] is not None:
query_span_weights = query["span_weights"][query["span_mask"].eq(1)].view(-1)
else:
query_span_weights = None
if self.use_oproto:
loss = self.type_loss_fct(logits, golds, inst_weights=query_span_weights)
else:
loss = self.type_loss_fct(logits[:, 1:], golds - 1, inst_weights=query_span_weights)
_, preds = torch.max(logits, dim=-1)
return logits, preds, golds, loss
def init_proto(self, support_data, encoder_mode=None, support_bottom_hiddens=None):
self.eval()
with torch.no_grad():
support_span_emb = self.word_encoder(support_data['word'], support_data['word_mask'], word_to_piece_inds=support_data['word_to_piece_ind'],
word_to_piece_ends=support_data['word_to_piece_end'], span_indices=support_data['span_indices'],
mode=encoder_mode, bottom_hiddens=support_bottom_hiddens)
self.cached_o_proto = self.cached_o_proto.to(support_span_emb.device)
proto = self.__get_proto__(support_span_emb, support_data['span_tag'], support_data['span_mask'])
self.proto = nn.Parameter(proto.data, requires_grad=True)
return
def forward_ment_step(self, batch, crf_mode=True, encoder_mode=None):
res = self.word_encoder(batch['word'], batch['word_mask'],
batch['word_to_piece_ind'],
batch['word_to_piece_end'], mode=encoder_mode,
)
word_emb = res[0]
bottom_hiddens = res[1]
logits = self.cls(word_emb)
gold = batch['ment_labels']
tot_loss = self.ment_loss_fct(logits, gold)
if self.use_crf and crf_mode:
crf_sp_logits = torch.zeros((logits.size(0), logits.size(1), 3), device=logits.device)
crf_sp_logits = torch.cat([logits, crf_sp_logits], dim=2)
_, pred = self.crf_layer.decode(crf_sp_logits, batch['seq_len'])
else:
pred = torch.argmax(logits, dim=-1)
pred = pred.masked_fill(gold.eq(-1), -1)
return logits, pred, gold, tot_loss, bottom_hiddens
def joint_inner_update(self, support_data, inner_steps, lr_inner):
self.init_proto(support_data, encoder_mode="type")
parameters_to_optimize = list(self.named_parameters())
decay_params = []
nodecay_params = []
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
for n, p in parameters_to_optimize:
if p.requires_grad:
if ("bert." in n) and (not any(nd in n for nd in no_decay)):
decay_params.append(p)
else:
nodecay_params.append(p)
parameters_groups = [
{'params': decay_params,
'lr': lr_inner, 'weight_decay': 1e-3},
{'params': nodecay_params,
'lr': lr_inner, 'weight_decay': 0},
]
inner_opt = torch.optim.AdamW(parameters_groups, lr=lr_inner)
self.train()
for _ in range(inner_steps):
inner_opt.zero_grad()
_, _, _, ment_loss, support_bottom_hiddens = self.forward_ment_step(support_data, crf_mode=False, encoder_mode="ment")
_, _, _, type_loss = self.forward_type_step(support_data, encoder_mode="type", query_bottom_hiddens=support_bottom_hiddens)
loss = ment_loss + type_loss * self.type_lam
loss.backward()
inner_opt.step()
return
def decode_ments(self, episode_preds):
episode_preds = episode_preds.detach()
span_indices = torch.zeros((episode_preds.size(0), 100, 2), dtype=torch.long)
span_masks = torch.zeros(span_indices.size()[:2], dtype=torch.long)
span_labels = torch.full_like(span_masks, fill_value=1, dtype=torch.long)
max_span_num = 0
for i, pred in enumerate(episode_preds):
seqs = []
for idx in pred:
if idx.item() == -1:
break
seqs.append(self.ment_idx2label[idx.item()])
ents = convert_bio2spans(seqs, self.schema)
max_span_num = max(max_span_num, len(ents))
for j, x in enumerate(ents):
span_indices[i, j, 0] = x[1]
span_indices[i, j, 1] = x[2]
span_masks[i, :len(ents)] = 1
return span_indices, span_masks, span_labels, max_span_num
# forward proto maml
def forward_joint_meta(self, batch, inner_steps, lr_inner, mode):
no_grads = ["proto"]
if batch['query']['span_mask'].sum().item() == 0: # there is no query mentions
print("no query mentions")
no_grads.append("type_adapters")
names, params = self.get_named_params(no_grads=no_grads)
weights = deepcopy(params)
meta_grad = []
episode_losses = []
episode_ment_losses = []
episode_type_losses = []
query_ment_logits = []
query_ment_preds = []
query_ment_golds = []
query_type_logits = []
query_type_preds = []
query_type_golds = []
two_stage_query_ments = []
two_stage_query_masks = []
two_stage_max_snum = 0
current_support_num = 0
current_query_num = 0
support, query = batch["support"], batch["query"]
data_keys = ['word', 'word_mask', 'word_to_piece_ind', 'word_to_piece_end', 'seq_len', 'ment_labels', 'span_indices', 'span_mask', 'span_tag', 'span_weights']
for i, sent_support_num in enumerate(support['sentence_num']):
sent_query_num = query['sentence_num'][i]
one_support = {
k: support[k][current_support_num:current_support_num + sent_support_num] for k in data_keys if k in support
}
one_query = {
k: query[k][current_query_num:current_query_num + sent_query_num] for k in data_keys if k in query
}
self.zero_grad()
self.joint_inner_update(one_support, inner_steps, lr_inner) # inner update parameters on support data
if mode == "train":
qy_ment_logits, qy_ment_pred, qy_ment_gold, qy_ment_loss, qy_bottom_hiddens = self.forward_ment_step(one_query, crf_mode=False, encoder_mode="ment") # evaluate on query data
qy_type_logits, qy_type_pred, qy_type_gold, qy_type_loss = self.forward_type_step(one_query, encoder_mode="type", query_bottom_hiddens=qy_bottom_hiddens) # evaluate on query data
qy_loss = qy_ment_loss + self.type_lam * qy_type_loss
if one_query['span_mask'].sum().item() == 0:
pass
else:
grad = torch.autograd.grad(qy_loss, params) # meta-update
meta_grad.append(grad)
elif mode == "test-onestage":
self.eval()
with torch.no_grad():
qy_ment_logits, qy_ment_pred, qy_ment_gold, qy_ment_loss, qy_bottom_hiddens = self.forward_ment_step(one_query, encoder_mode="ment") # evaluate on query data
qy_type_logits, qy_type_pred, qy_type_gold, qy_type_loss = self.forward_type_step(one_query, encoder_mode="type", query_bottom_hiddens=qy_bottom_hiddens) # evaluate on query data
qy_loss = qy_ment_loss + self.type_lam * qy_type_loss
else:
assert mode == "test-twostage"
self.eval()
with torch.no_grad():
qy_ment_logits, qy_ment_pred, qy_ment_gold, qy_ment_loss, qy_bottom_hiddens = self.forward_ment_step(one_query, encoder_mode="ment") # evaluate on query data
span_indices, span_mask, span_tag, max_span_num = self.decode_ments(qy_ment_pred)
one_query['span_indices'] = span_indices[:, :max_span_num, :].to(qy_ment_logits.device)
one_query['span_mask'] = span_mask[:, :max_span_num].to(qy_ment_logits.device)
one_query['span_tag'] = span_tag[:, :max_span_num].to(qy_ment_logits.device)
one_query['span_weights'] = None
two_stage_max_snum = max(two_stage_max_snum, max_span_num)
two_stage_query_masks.append(span_mask)
two_stage_query_ments.append(span_indices)
qy_type_logits, qy_type_pred, qy_type_gold, qy_type_loss = self.forward_type_step(one_query, encoder_mode="type", query_bottom_hiddens=qy_bottom_hiddens) # evaluate on query data
qy_loss = qy_ment_loss + self.type_lam * qy_type_loss
episode_losses.append(qy_loss.item())
episode_ment_losses.append(qy_ment_loss.item())
episode_type_losses.append(qy_type_loss.item())
query_ment_preds.append(qy_ment_pred)
query_ment_golds.append(qy_ment_gold)
query_ment_logits.append(qy_ment_logits)
query_type_preds.append(qy_type_pred)
query_type_golds.append(qy_type_gold)
query_type_logits.append(qy_type_logits)
self.load_weights(names, weights)
current_query_num += sent_query_num
current_support_num += sent_support_num
self.zero_grad()
if mode == "test-twostage":
two_stage_query_masks = torch.cat(two_stage_query_masks, dim=0)[:, :two_stage_max_snum]
two_stage_query_ments = torch.cat(two_stage_query_ments, dim=0)[:, :two_stage_max_snum, :]
return {'loss': np.mean(episode_losses), 'ment_loss': np.mean(episode_ment_losses), 'type_loss': np.mean(episode_type_losses), 'names': names, 'grads': meta_grad,
'ment_preds': query_ment_preds, 'ment_golds': query_ment_golds, 'ment_logits': query_ment_logits,
'type_preds': query_type_preds, 'type_golds': query_type_golds, 'type_logits': query_type_logits,
'pred_spans': two_stage_query_ments, 'pred_masks': two_stage_query_masks
}
else:
return {'loss': np.mean(episode_losses), 'ment_loss': np.mean(episode_ment_losses), 'type_loss': np.mean(episode_type_losses), 'names': names, 'grads': meta_grad,
'ment_preds': query_ment_preds, 'ment_golds': query_ment_golds, 'ment_logits': query_ment_logits,
'type_preds': query_type_preds, 'type_golds': query_type_golds, 'type_logits': query_type_logits,
}
def get_named_params(self, no_grads=[]):
names = []
params = []
for n, p in self.named_parameters():
if any([pn in n for pn in no_grads]):
continue
if p.requires_grad:
names.append(n)
params.append(p)
return names, params
def load_weights(self, names, params):
model_params = self.state_dict()
for n, p in zip(names, params):
model_params[n].data.copy_(p.data)
return
def load_gradients(self, names, grads):
model_params = self.state_dict(keep_vars=True)
for n, g in zip(names, grads):
if model_params[n].grad is None:
continue
model_params[n].grad.data.add_(g.data) # accumulate
return

Просмотреть файл

@ -0,0 +1,41 @@
import torch
from torch import nn
import torch.nn.functional as F
class MaxLoss(nn.Module):
def __init__(self, gamma=1.0):
super(MaxLoss, self).__init__()
self.gamma = gamma
self.loss_fct = nn.CrossEntropyLoss(ignore_index=-1, reduction='none')
return
def forward(self, logits, targets):
assert logits.dim() == 3 #batch_size, seq_len, label_num
batch_size, seq_len, class_num = logits.size()
token_loss = self.loss_fct(logits.view(-1, class_num), targets.view(-1)).view(batch_size, seq_len)
act_pos = targets.ne(-1)
loss = token_loss[act_pos].mean()
if self.gamma > 0:
max_loss = torch.max(token_loss, dim=1)[0]
loss += self.gamma * max_loss.mean()
return loss
class FocalLoss(nn.Module):
def __init__(self, gamma=1.0, reduction=False):
super(FocalLoss, self).__init__()
self.gamma = gamma
self.reduction = reduction
return
def forward(self, logits, targets):
assert logits.dim() == 2
assert torch.min(targets).item() > -1
logp = F.log_softmax(logits, dim=1)
target_logp = logp.gather(1, targets.view(-1, 1)).view(-1)
target_p = torch.exp(target_logp)
weight = (1 - target_p) ** self.gamma
loss = - weight * target_logp
if self.reduction:
loss = loss.mean()
return loss

Просмотреть файл

@ -0,0 +1,171 @@
import torch
from torch import nn
import torch.nn.functional as F
from .span_fewshotmodel import FewShotSpanModel
from copy import deepcopy
import numpy as np
from .crf import LinearCRF
from .loss_model import MaxLoss
class MentSeqtagger(FewShotSpanModel):
def __init__(self, span_encoder, num_tag, label2idx, schema, use_crf, max_loss):
super(MentSeqtagger, self).__init__(span_encoder)
self.num_tag = num_tag
self.use_crf = use_crf
self.label2idx = label2idx
self.idx2label = {idx: label for label, idx in self.label2idx.items()}
self.schema = schema
self.cls = nn.Linear(span_encoder.word_dim, self.num_tag)
if self.use_crf:
self.crf_layer = LinearCRF(self.num_tag, schema=schema, add_constraint=True, label2idx=label2idx)
self.ment_loss_fct = MaxLoss(gamma=max_loss)
self.init_weights()
return
def init_weights(self):
self.cls.weight.data.normal_(mean=0.0, std=0.02)
if self.cls.bias is not None:
self.cls.bias.data.zero_()
if self.use_crf:
self.crf_layer.init_params()
return
def forward_step(self, batch, crf_mode=True):
word_emb = self.word_encoder(batch['word'], batch['word_mask'],
batch['word_to_piece_ind'],
batch['word_to_piece_end'])
logits = self.cls(word_emb)
gold = batch['ment_labels']
tot_loss = self.ment_loss_fct(logits, gold)
if self.use_crf and crf_mode:
crf_sp_logits = torch.zeros((logits.size(0), logits.size(1), 3), device=logits.device)
crf_sp_logits = torch.cat([logits, crf_sp_logits], dim=2)
_, pred = self.crf_layer.decode(crf_sp_logits, batch['seq_len'])
else:
pred = torch.argmax(logits, dim=-1)
pred = pred.masked_fill(gold.eq(-1), -1)
return logits, pred, gold, tot_loss
def inner_update(self, train_data, inner_steps, lr_inner):
parameters_to_optimize = list(self.named_parameters())
decay_params = []
nodecay_params = []
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
for n, p in parameters_to_optimize:
if p.requires_grad:
if ("bert." in n) and (not any(nd in n for nd in no_decay)):
decay_params.append(p)
else:
nodecay_params.append(p)
parameters_groups = [
{'params': decay_params,
'lr': lr_inner, 'weight_decay': 1e-3},
{'params': nodecay_params,
'lr': lr_inner, 'weight_decay': 0},
]
inner_opt = torch.optim.AdamW(parameters_groups, lr=lr_inner)
self.train()
for _ in range(inner_steps):
inner_opt.zero_grad()
_, _, _, loss = self.forward_step(train_data, crf_mode=False)
loss.backward()
inner_opt.step()
return
def forward_sup(self, batch, mode):
support, query = batch["support"], batch["query"]
query_logits = []
query_preds = []
query_golds = []
all_loss = 0
task_num = 0
current_support_num = 0
current_query_num = 0
if mode == "train":
crf_mode = False
else:
assert mode == "test"
crf_mode = True
self.eval()
print('eval mode')
data_keys = ['word', 'word_mask', 'word_to_piece_ind', 'word_to_piece_end', 'seq_len', 'ment_labels']
for i, sent_support_num in enumerate(support['sentence_num']):
sent_query_num = query['sentence_num'][i]
one_support = {
k: support[k][current_support_num:current_support_num + sent_support_num] for k in data_keys
}
one_query = {
k: query[k][current_query_num:current_query_num + sent_query_num] for k in data_keys
}
_, _, _, sp_loss = self.forward_step(one_support, False)
all_loss += sp_loss
qy_logits, qy_pred, qy_gold, qy_loss = self.forward_step(one_query, crf_mode)
query_preds.append(qy_pred)
query_golds.append(qy_gold)
query_logits.append(qy_logits)
all_loss += qy_loss
task_num += 2
current_query_num += sent_query_num
current_support_num += sent_support_num
return {'loss': all_loss / task_num, 'preds': query_preds, 'golds': query_golds, 'logits': query_logits}
def forward_meta(self, batch, inner_steps, lr_inner, mode):
names, params = self.get_named_params()
weights = deepcopy(params)
meta_grad = []
episode_losses = []
query_preds = []
query_golds = []
query_logits = []
support, query = batch["support"], batch["query"]
current_support_num = 0
current_query_num = 0
data_keys = ['word', 'word_mask', 'word_to_piece_ind', 'word_to_piece_end', 'seq_len', 'ment_labels']
for i, sent_support_num in enumerate(support['sentence_num']):
sent_query_num = query['sentence_num'][i]
one_support = {
k: support[k][current_support_num:current_support_num + sent_support_num] for k in data_keys
}
one_query = {
k: query[k][current_query_num:current_query_num + sent_query_num] for k in data_keys
}
self.zero_grad()
self.inner_update(one_support, inner_steps, lr_inner)
if mode == "train":
qy_logits, qy_pred, qy_gold, qy_loss = self.forward_step(one_query, crf_mode=False)
grad = torch.autograd.grad(qy_loss, params)
meta_grad.append(grad)
else:
self.eval()
with torch.no_grad():
qy_logits, qy_pred, qy_gold, qy_loss = self.forward_step(one_query)
episode_losses.append(qy_loss.item())
query_preds.append(qy_pred)
query_golds.append(qy_gold)
query_logits.append(qy_logits)
self.load_weights(names, weights)
current_query_num += sent_query_num
current_support_num += sent_support_num
self.zero_grad()
return {'loss': np.mean(episode_losses), 'names': names, 'grads': meta_grad, 'preds': query_preds, 'golds': query_golds, 'logits': query_logits}
def get_named_params(self):
names = [n for n, p in self.named_parameters() if p.requires_grad]
params = [p for n, p in self.named_parameters() if p.requires_grad]
return names, params
def load_weights(self, names, params):
model_params = self.state_dict()
for n, p in zip(names, params):
model_params[n].data.copy_(p.data)
return
def load_gradients(self, names, grads):
model_params = self.state_dict(keep_vars=True)
for n, g in zip(names, grads):
if model_params[n].grad is None:
continue
model_params[n].grad.data.add_(g.data) # accumulate
return

Просмотреть файл

@ -0,0 +1,64 @@
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
class FewShotTokenModel(nn.Module):
def __init__(self, my_word_encoder):
nn.Module.__init__(self)
self.word_encoder = nn.DataParallel(my_word_encoder)
def forward(self, batch):
raise NotImplementedError
def span_accuracy(self, pred, gold):
return np.mean(pred == gold)
def metrics_by_pos(self, preds, golds):
hit_cnt = 0
tot_cnt = 0
for i in range(len(preds)):
tot_cnt += len(preds[i])
for j in range(len(preds[i])):
if preds[i][j] == golds[i][j]:
hit_cnt += 1
return hit_cnt, tot_cnt
def seq_eval(self, ep_preds, query):
query_seq_lens = query["seq_len"].detach().cpu().tolist()
subsent_label2tag_ids = []
subsent_pred_ids = []
for k, batch_preds in enumerate(ep_preds):
batch_preds = batch_preds.detach().cpu().tolist()
for pred in batch_preds:
subsent_pred_ids.append(pred)
subsent_label2tag_ids.append(k)
sent_gold_labels = []
sent_pred_labels = []
subsent_id = 0
query['word_labels'] = query['word_labels'].cpu().tolist()
for snum in query['subsentence_num']:
whole_sent_gids = []
whole_sent_pids = []
for k in range(subsent_id, subsent_id + snum):
whole_sent_gids += query['word_labels'][k][:query_seq_lens[k]]
whole_sent_pids += subsent_pred_ids[k][:query_seq_lens[k]]
label2tag = query['label2tag'][subsent_label2tag_ids[subsent_id]]
sent_gold_labels.append([label2tag[lid] for lid in whole_sent_gids])
sent_pred_labels.append([label2tag[lid] for lid in whole_sent_pids])
subsent_id += snum
hit_cnt, tot_cnt = self.metrics_by_pos(sent_pred_labels, sent_gold_labels)
logs = {
"index": query["index"],
"seq_len": query_seq_lens,
"pred": sent_pred_labels,
"gold": sent_gold_labels,
"sentence_num": query["sentence_num"],
"subsentence_num": query['subsentence_num']
}
metric_logs = {
"gold_cnt": tot_cnt,
"hit_cnt": hit_cnt
}
return metric_logs, logs

Просмотреть файл

@ -0,0 +1,173 @@
import torch
from torch import nn
import torch.nn.functional as F
from .pos_fewshotmodel import FewShotTokenModel
from copy import deepcopy
import numpy as np
from .loss_model import MaxLoss
class SeqProtoCls(FewShotTokenModel):
def __init__(self, span_encoder, max_loss, dot=False, normalize="none", temperature=None, use_focal=False):
super(SeqProtoCls, self).__init__(span_encoder)
self.dot = dot
self.normalize = normalize
self.temperature = temperature
self.use_focal = use_focal
self.loss_fct = MaxLoss(gamma=max_loss)
self.proto = None
print("use dot : {} use normalizatioin: {} use temperature: {}".format(self.dot, self.normalize,
self.temperature if self.temperature else "none"))
return
def __dist__(self, x, y, dim):
if self.normalize == 'l2':
x = F.normalize(x, p=2, dim=-1)
y = F.normalize(y, p=2, dim=-1)
if self.dot:
sim = (x * y).sum(dim)
else:
sim = -(torch.pow(x - y, 2)).sum(dim)
if self.temperature:
sim = sim / self.temperature
return sim
def __batch_dist__(self, S_emb, Q_emb, Q_mask):
if Q_mask is None:
Q_emb = Q_emb.view(-1, Q_emb.size(-1))
else:
Q_emb = Q_emb[Q_mask.eq(1), :].view(-1, Q_emb.size(-1))
dist = self.__dist__(S_emb.unsqueeze(0), Q_emb.unsqueeze(1), 2)
return dist
def __get_proto__(self, S_emb, S_tag, S_mask, max_tag):
proto = []
embedding = S_emb[S_mask.eq(1), :].view(-1, S_emb.size(-1))
S_tag = S_tag[S_mask.eq(1)]
for label in range(max_tag + 1):
if S_tag.eq(label).sum().item() == 0:
proto.append(torch.zeros(embedding.size(-1), device=embedding.device))
else:
proto.append(torch.mean(embedding[S_tag.eq(label), :], 0))
proto = torch.stack(proto, dim=0)
return proto
def __get_proto_dist__(self, Q_emb, Q_mask):
dist = self.__batch_dist__(self.proto, Q_emb, Q_mask)
return dist
def forward_step(self, query):
query_word_emb = self.word_encoder(query['word'], query['word_mask'], word_to_piece_inds=query['word_to_piece_ind'],
word_to_piece_ends=query['word_to_piece_end'])
logits = self.__get_proto_dist__(query_word_emb, None)
logits = logits.view(query_word_emb.size(0), query_word_emb.size(1), -1)
gold = query['word_labels']
tot_loss = self.loss_fct(logits, gold)
pred = torch.argmax(logits, dim=-1)
pred = pred.masked_fill(query['word_labels'] < 0, -1)
return logits, pred, gold, tot_loss
def init_proto(self, support):
self.eval()
with torch.no_grad():
support_word_emb = self.word_encoder(support['word'], support['word_mask'], word_to_piece_inds=support['word_to_piece_ind'],
word_to_piece_ends=support['word_to_piece_end'])
proto = self.__get_proto__(support_word_emb, support['word_labels'], support['word_labels'] > -1, len(support['label2idx']) - 1)
self.proto = nn.Parameter(proto.data, requires_grad=True)
return
def inner_update(self, support_data, inner_steps, lr_inner):
self.init_proto(support_data)
parameters_to_optimize = list(self.named_parameters())
decay_params = []
nodecay_params = []
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
for n, p in parameters_to_optimize:
if p.requires_grad:
if ("bert." in n) and (not any(nd in n for nd in no_decay)):
decay_params.append(p)
else:
nodecay_params.append(p)
parameters_groups = [
{'params': decay_params,
'lr': lr_inner, 'weight_decay': 1e-3},
{'params': nodecay_params,
'lr': lr_inner, 'weight_decay': 0},
]
inner_opt = torch.optim.AdamW(parameters_groups, lr=lr_inner)
self.train()
for _ in range(inner_steps):
inner_opt.zero_grad()
_, _, _, loss = self.forward_step(support_data)
loss.backward()
inner_opt.step()
return
def forward_meta(self, batch, inner_steps, lr_inner, mode):
names, params = self.get_named_params(no_grads=["proto"])
weights = deepcopy(params)
meta_grad = []
episode_losses = []
query_logits = []
query_preds = []
query_golds = []
current_support_num = 0
current_query_num = 0
support, query = batch["support"], batch["query"]
data_keys = ['word', 'word_mask', 'word_to_piece_ind', 'word_to_piece_end', 'seq_len', 'word_labels']
for i, sent_support_num in enumerate(support['sentence_num']):
sent_query_num = query['sentence_num'][i]
label2tag = query['label2tag'][i]
one_support = {
k: support[k][current_support_num:current_support_num + sent_support_num] for k in data_keys if k in support
}
one_query = {
k: query[k][current_query_num:current_query_num + sent_query_num] for k in data_keys if k in query
}
one_support['label2idx'] = one_query['label2idx'] = {label:idx for idx, label in label2tag.items()}
self.zero_grad()
self.inner_update(one_support, inner_steps, lr_inner) # inner update parameters on support data
if mode == "train":
qy_logits, qy_pred, qy_gold, qy_loss = self.forward_step(one_query) # evaluate on query data
grad = torch.autograd.grad(qy_loss, params) # meta-update
meta_grad.append(grad)
elif mode == "test":
self.eval()
with torch.no_grad():
qy_logits, qy_pred, qy_gold, qy_loss = self.forward_step(one_query)
else:
raise ValueError
episode_losses.append(qy_loss.item())
query_preds.append(qy_pred)
query_golds.append(qy_gold)
query_logits.append(qy_logits)
self.load_weights(names, weights)
current_query_num += sent_query_num
current_support_num += sent_support_num
self.zero_grad()
return {'loss': np.mean(episode_losses), 'names': names, 'grads': meta_grad, 'preds': query_preds, 'golds': query_golds, 'logits': query_logits}
def get_named_params(self, no_grads=[]):
names = [n for n, p in self.named_parameters() if p.requires_grad and (n not in no_grads)]
params = [p for n, p in self.named_parameters() if p.requires_grad and (n not in no_grads)]
return names, params
def load_weights(self, names, params):
model_params = self.state_dict()
for n, p in zip(names, params):
assert n in model_params
model_params[n].data.copy_(p.data)
return
def load_gradients(self, names, grads):
model_params = self.state_dict(keep_vars=True)
for n, g in zip(names, grads):
if model_params[n].grad is None:
continue
model_params[n].grad.data.add_(g.data) # accumulate
return

Просмотреть файл

@ -0,0 +1,326 @@
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from util.span_sample import convert_bio2spans
class FewShotSpanModel(nn.Module):
def __init__(self, my_word_encoder):
nn.Module.__init__(self)
self.word_encoder = nn.DataParallel(my_word_encoder)
def forward(self, batch):
raise NotImplementedError
def span_accuracy(self, pred, gold):
return np.mean(pred == gold)
def merge_ment(self, span_list, seq_lens, subsentence_nums):
new_span_list = []
subsent_id = 0
for snum in subsentence_nums:
sent_st_id = 0
whole_sent_span_list = []
for k in range(subsent_id, subsent_id + snum):
tmp = [[x[0] + sent_st_id, x[1] + sent_st_id] for x in span_list[k]]
tmp = sorted(tmp, key=lambda x: x[1], reverse=False)
if len(whole_sent_span_list) > 0 and len(tmp) > 0:
if tmp[0][0] == whole_sent_span_list[-1][1] + 1:
whole_sent_span_list[-1][1] = tmp[0][1]
tmp = tmp[1:]
whole_sent_span_list.extend(tmp)
sent_st_id += seq_lens[k]
subsent_id += snum
new_span_list.append(whole_sent_span_list)
assert len(new_span_list) == len(subsentence_nums)
return new_span_list
def get_mention_prob(self, ment_probs, query):
span_indices = query["span_indices"].detach().cpu()
span_masks = query['span_mask'].detach().cpu()
ment_probs = ment_probs.detach().cpu().tolist()
cand_indices_list = []
cand_prob_list = []
cur_span_num = 0
for k in range(len(span_indices)):
mask = span_masks[k, :]
effect_indices = span_indices[k, mask.eq(1)].tolist()
effect_probs = ment_probs[cur_span_num: cur_span_num + len(effect_indices)]
cand_indices_list.append(effect_indices)
cand_prob_list.append(effect_probs)
cur_span_num += len(effect_indices)
return cand_indices_list, cand_prob_list
def filter_by_threshold(self, cand_indices_list, cand_prob_list, threshold):
final_indices_list = []
final_prob_list = []
for indices, probs in zip(cand_indices_list, cand_prob_list):
final_indices_list.append([])
final_prob_list.append([])
for x, y in zip(indices, probs):
if y > threshold:
final_indices_list[-1].append(x)
final_prob_list[-1].append(y)
return final_indices_list, final_prob_list
def seqment_eval(self, ep_preds, query, idx2label, schema):
query_seq_lens = query["seq_len"].detach().cpu().tolist()
ment_indices_list = []
sid = 0
for batch_preds in ep_preds:
for pred in batch_preds.detach().cpu().tolist():
seqs = [idx2label[idx] for idx in pred[:query_seq_lens[sid]]]
ents = convert_bio2spans(seqs, schema)
ment_indices_list.append([[x[1], x[2]] for x in ents])
sid += 1
pred_ments_list = self.merge_ment(ment_indices_list, query_seq_lens, query['subsentence_num'])
gold_ments_list = []
subsent_id = 0
query['ment_labels'] = query['ment_labels'].cpu().tolist()
for snum in query['subsentence_num']:
whole_sent_labels = []
for k in range(subsent_id, subsent_id + snum):
whole_sent_labels += query['ment_labels'][k][:query_seq_lens[k]]
ents = convert_bio2spans([idx2label[idx] for idx in whole_sent_labels], schema)
gold_ments_list.append([[x[1], x[2]] for x in ents])
subsent_id += snum
pred_cnt, gold_cnt, hit_cnt = self.metrics_by_entity(pred_ments_list, gold_ments_list)
logs = {
"index": query["index"],
"seq_len": query_seq_lens,
"pred": pred_ments_list,
"gold": gold_ments_list,
"sentence_num": query["sentence_num"],
"subsentence_num": query['subsentence_num']
}
metric_logs = {
"ment_pred_cnt": pred_cnt,
"ment_gold_cnt": gold_cnt,
"ment_hit_cnt": hit_cnt,
}
return metric_logs, logs
def ment_eval(self, ep_probs, query, threshold=0.5):
if len(ep_probs) == 0:
print("no mention")
return {}, {}
probs = torch.cat(ep_probs, dim=0)
cand_indices_list, cand_prob_list = self.get_mention_prob(probs, query)
ment_indices_list, ment_prob_list = self.filter_by_threshold(cand_indices_list, cand_prob_list, threshold)
gold_ments_list = []
for sent_spans in query['spans']:
gold_ments_list.append([])
for tagid, sp_st, sp_ed in sent_spans:
gold_ments_list[-1].append([sp_st, sp_ed])
assert len(ment_indices_list) == len(query['seq_len'])
assert len(gold_ments_list) == len(query['seq_len'])
assert len(ment_indices_list) >= len(query['index'])
query_seq_lens = query["seq_len"].detach().cpu().tolist()
pred_ments_list = self.merge_ment(ment_indices_list, query_seq_lens, query['subsentence_num'])
gold_ments_list = self.merge_ment(gold_ments_list, query_seq_lens, query['subsentence_num'])
pred_cnt, gold_cnt, hit_cnt = self.metrics_by_entity(pred_ments_list, gold_ments_list)
logs = {
"index": query["index"],
"seq_len": query_seq_lens,
"pred": pred_ments_list,
"gold": gold_ments_list,
"before_ind": ment_indices_list,
"before_prob": ment_prob_list,
"sentence_num": query["sentence_num"],
"subsentence_num": query['subsentence_num']
}
metric_logs = {
"ment_pred_cnt": pred_cnt,
"ment_gold_cnt": gold_cnt,
"ment_hit_cnt": hit_cnt,
}
return metric_logs, logs
def metrics_by_ment(self, pred_spans_list, gold_spans_list):
pred_cnt, gold_cnt, hit_cnt = 0, 0, 0
for pred_spans, gold_spans in zip(pred_spans_list, gold_spans_list):
pred_spans = set(map(lambda x: (x[1], x[2]), pred_spans))
gold_spans = set(map(lambda x: (x[1], x[2]), gold_spans))
pred_cnt += len(pred_spans)
gold_cnt += len(gold_spans)
hit_cnt += len(pred_spans.intersection(gold_spans))
return pred_cnt, gold_cnt, hit_cnt
def get_emission(self, ep_logits, query):
assert len(query['label2tag']) == len(query['sentence_num'])
span_indices = query["span_indices"]
cur_sent_num = 0
cand_indices_list = []
cand_prob_list = []
label2tag_list = []
for i, query_sent_num in enumerate(query['sentence_num']):
probs = F.softmax(ep_logits[i], dim=-1).detach().cpu()
cur_span_num = 0
for j in range(query_sent_num):
mask = query['span_mask'][cur_sent_num, :]
effect_indices = span_indices[cur_sent_num, mask.eq(1)].detach().cpu().tolist()
cand_indices_list.append(effect_indices)
cand_prob_list.append(probs[cur_span_num:cur_span_num + len(effect_indices)].numpy())
label2tag_list.append(query['label2tag'][i])
cur_sent_num += 1
cur_span_num += len(effect_indices)
return cand_indices_list, cand_prob_list, label2tag_list
def to_triple_score(self, indices_list, prob_list, label2tag):
tri_list = []
for sp_id, (i, j) in enumerate(indices_list):
k = np.argmax(prob_list[sp_id, :])
if label2tag[k] == 'O':
assert k == 0
continue
else:
tri_list.append([label2tag[k], i, j, prob_list[sp_id, k]])
return tri_list
def greedy_search(self, span_indices_list, prob_list, label2tag_list, seq_len_list, overlap=False, threshold=-1):
output_spans_list = []
for sid in range(len(span_indices_list)):
output_spans_list.append([])
tri_list = self.to_triple_score(span_indices_list[sid], prob_list[sid], label2tag_list[sid])
sorted_tri_list = sorted(tri_list, key=lambda x: x[-1], reverse=True)
used_words = np.zeros(seq_len_list[sid])
for tag, sp_st, sp_ed, score in sorted_tri_list:
if score < threshold:
continue
if sum(used_words[sp_st: sp_ed + 1]) > 0 and (not overlap):
continue
used_words[sp_st: sp_ed + 1] = 1
assert tag != "O"
output_spans_list[sid].append([tag, sp_st, sp_ed])
return output_spans_list
def merge_entity(self, span_list, seq_lens, subsentence_nums):
new_span_list = []
subsent_id = 0
for snum in subsentence_nums:
sent_st_id = 0
whole_sent_span_list = []
for k in range(subsent_id, subsent_id + snum):
tmp = [[x[0], x[1] + sent_st_id, x[2] + sent_st_id] for x in span_list[k]]
tmp = sorted(tmp, key=lambda x: x[2], reverse=False)
if len(whole_sent_span_list) > 0 and len(tmp) > 0:
if whole_sent_span_list[-1][0] == tmp[0][0] and tmp[0][1] == whole_sent_span_list[-1][2] + 1:
whole_sent_span_list[-1][2] = tmp[0][2]
tmp = tmp[1:]
whole_sent_span_list.extend(tmp)
sent_st_id += seq_lens[k]
subsent_id += snum
new_span_list.append(whole_sent_span_list)
assert len(new_span_list) == len(subsentence_nums)
return new_span_list
def seq_eval(self, ep_preds, query, schema):
query_seq_lens = query["seq_len"].detach().cpu().tolist()
subsent_label2tag_ids = []
subsent_pred_ids = []
for k, batch_preds in enumerate(ep_preds):
batch_preds = batch_preds.detach().cpu().tolist()
for pred in batch_preds:
subsent_pred_ids.append(pred)
subsent_label2tag_ids.append(k)
sent_gold_labels = []
sent_pred_labels = []
pred_spans_list = []
gold_spans_list = []
subsent_id = 0
query['word_labels'] = query['word_labels'].cpu().tolist()
for snum in query['subsentence_num']:
whole_sent_gids = []
whole_sent_pids = []
for k in range(subsent_id, subsent_id + snum):
whole_sent_gids += query['word_labels'][k][:query_seq_lens[k]]
whole_sent_pids += subsent_pred_ids[k][:query_seq_lens[k]]
label2tag = query['label2tag'][subsent_label2tag_ids[subsent_id]]
sent_gold_labels.append([label2tag[lid] for lid in whole_sent_gids])
sent_pred_labels.append([label2tag[lid] for lid in whole_sent_pids])
gold_spans_list.append(convert_bio2spans(sent_gold_labels[-1], schema))
pred_spans_list.append(convert_bio2spans(sent_pred_labels[-1], schema))
subsent_id += snum
pred_cnt, gold_cnt, hit_cnt = self.metrics_by_entity(pred_spans_list, gold_spans_list)
ment_pred_cnt, ment_gold_cnt, ment_hit_cnt = self.metrics_by_ment(pred_spans_list, gold_spans_list)
logs = {
"index": query["index"],
"seq_len": query_seq_lens,
"pred": pred_spans_list,
"gold": gold_spans_list,
"sentence_num": query["sentence_num"],
"subsentence_num": query['subsentence_num']
}
metric_logs = {
"ment_pred_cnt": ment_pred_cnt,
"ment_gold_cnt": ment_gold_cnt,
"ment_hit_cnt": ment_hit_cnt,
"ent_pred_cnt": pred_cnt,
"ent_gold_cnt": gold_cnt,
"ent_hit_cnt": hit_cnt
}
return metric_logs, logs
def greedy_eval(self, ep_logits, query, overlap=False, threshold=-1):
if len(ep_logits) == 0:
print("no entity")
return {}, {}
query_seq_lens = query["seq_len"].detach().cpu().tolist()
cand_indices_list, cand_prob_list, label2tag_list = self.get_emission(ep_logits, query)
pred_spans_list = self.greedy_search(cand_indices_list, cand_prob_list, label2tag_list, query_seq_lens, overlap, threshold)
gold_spans_list = []
subsent_idx = 0
for sent_spans, label2tag in zip(query['spans'], label2tag_list):
gold_spans_list.append([])
for tagid, sp_st, sp_ed in sent_spans:
gold_spans_list[-1].append([label2tag[tagid], sp_st, sp_ed])
subsent_idx += 1
assert len(pred_spans_list) == len(query['seq_len'])
assert len(gold_spans_list) == len(query['seq_len'])
assert len(pred_spans_list) >= len(query['index'])
pred_spans_list = self.merge_entity(pred_spans_list, query_seq_lens, query['subsentence_num'])
gold_spans_list = self.merge_entity(gold_spans_list, query_seq_lens, query['subsentence_num'])
pred_cnt, gold_cnt, hit_cnt = self.metrics_by_entity(pred_spans_list, gold_spans_list)
ment_pred_cnt, ment_gold_cnt, ment_hit_cnt = self.metrics_by_ment(pred_spans_list, gold_spans_list)
logs = {
"index": query["index"],
"seq_len": query_seq_lens,
"pred": pred_spans_list,
"gold": gold_spans_list,
"before_ind": cand_indices_list,
"before_prob": cand_prob_list,
"label_tag": label2tag_list,
"sentence_num": query["sentence_num"],
"subsentence_num": query['subsentence_num']
}
metric_logs = {
"ment_pred_cnt": ment_pred_cnt,
"ment_gold_cnt": ment_gold_cnt,
"ment_hit_cnt": ment_hit_cnt,
"ent_pred_cnt": pred_cnt,
"ent_gold_cnt": gold_cnt,
"ent_hit_cnt": hit_cnt
}
return metric_logs, logs
def metrics_by_entity(self, pred_spans_list, gold_spans_list):
pred_cnt, gold_cnt, hit_cnt = 0, 0, 0
for pred_spans, gold_spans in zip(pred_spans_list, gold_spans_list):
pred_spans = set(map(lambda x: tuple(x), pred_spans))
gold_spans = set(map(lambda x: tuple(x), gold_spans))
pred_cnt += len(pred_spans)
gold_cnt += len(gold_spans)
hit_cnt += len(pred_spans.intersection(gold_spans))
return pred_cnt, gold_cnt, hit_cnt
def metrics_by_ment(self, pred_spans_list, gold_spans_list):
pred_cnt, gold_cnt, hit_cnt = 0, 0, 0
for pred_spans, gold_spans in zip(pred_spans_list, gold_spans_list):
pred_spans = set(map(lambda x: (x[1], x[2]), pred_spans))
gold_spans = set(map(lambda x: (x[1], x[2]), gold_spans))
pred_cnt += len(pred_spans)
gold_cnt += len(gold_spans)
hit_cnt += len(pred_spans.intersection(gold_spans))
return pred_cnt, gold_cnt, hit_cnt

Просмотреть файл

@ -0,0 +1,204 @@
import torch
from torch import nn
import torch.nn.functional as F
from .span_fewshotmodel import FewShotSpanModel
from copy import deepcopy
from .loss_model import FocalLoss
import numpy as np
class SpanProtoCls(FewShotSpanModel):
def __init__(self, span_encoder, use_oproto, dot=False, normalize="none",
temperature=None, use_focal=False):
super(SpanProtoCls, self).__init__(span_encoder)
self.dot = dot
self.normalize = normalize
self.temperature = temperature
self.use_focal = use_focal
self.use_oproto = use_oproto
self.proto = None
if use_focal:
self.base_loss_fct = FocalLoss(gamma=1.0)
print("use focal loss")
else:
self.base_loss_fct = nn.CrossEntropyLoss(reduction='none', ignore_index=-1)
print("use cross entropy loss")
print("use dot : {} use normalizatioin: {} use temperature: {}".format(self.dot, self.normalize,
self.temperature if self.temperature else "none"))
self.cached_o_proto = torch.zeros(span_encoder.span_dim, requires_grad=False)
return
def loss_fct(self, logits, targets, inst_weights=None):
if inst_weights is None:
loss = self.base_loss_fct(logits, targets)
loss = loss.mean()
else:
targets = torch.clamp(targets, min=0)
one_hot_targets = torch.zeros(logits.size(), device=logits.device).scatter_(1, targets.unsqueeze(1), 1)
soft_labels = inst_weights.unsqueeze(1) * one_hot_targets + (1 - one_hot_targets) * (1 - inst_weights).unsqueeze(1) / (logits.size(1) - 1)
logp = F.log_softmax(logits, dim=-1)
loss = - (logp * soft_labels).sum(1)
loss = loss.mean()
return loss
def __dist__(self, x, y, dim):
if self.normalize == 'l2':
x = F.normalize(x, p=2, dim=-1)
y = F.normalize(y, p=2, dim=-1)
if self.dot:
sim = (x * y).sum(dim)
else:
sim = -(torch.pow(x - y, 2)).sum(dim)
if self.temperature:
sim = sim / self.temperature
return sim
def __batch_dist__(self, S_emb, Q_emb, Q_mask):
Q_emb = Q_emb[Q_mask.eq(1), :].view(-1, Q_emb.size(-1))
dist = self.__dist__(S_emb.unsqueeze(0), Q_emb.unsqueeze(1), 2)
return dist
def __get_proto__(self, S_emb, S_tag, S_mask):
proto = []
embedding = S_emb[S_mask.eq(1), :].view(-1, S_emb.size(-1))
S_tag = S_tag[S_mask.eq(1)]
if self.use_oproto:
st_idx = 0
else:
st_idx = 1
proto = [self.cached_o_proto]
for label in range(st_idx, torch.max(S_tag) + 1):
proto.append(torch.mean(embedding[S_tag.eq(label), :], 0))
proto = torch.stack(proto, dim=0)
return proto
def __get_proto_dist__(self, Q_emb, Q_mask):
dist = self.__batch_dist__(self.proto, Q_emb, Q_mask)
if not self.use_oproto:
dist[:, 0] = -1000000
return dist
def forward_step(self, query):
if query['span_mask'].sum().item() == 0: # there is no query mentions
print("no query mentions")
empty_tensor = torch.tensor([], device=query['word'].device)
zero_tensor = torch.tensor([0], device=query['word'].device)
return empty_tensor, empty_tensor, empty_tensor, zero_tensor
query_span_emb = self.word_encoder(query['word'], query['word_mask'], word_to_piece_inds=query['word_to_piece_ind'],
word_to_piece_ends=query['word_to_piece_end'], span_indices=query['span_indices'])
logits = self.__get_proto_dist__(query_span_emb, query['span_mask'])
golds = query["span_tag"][query["span_mask"].eq(1)].view(-1)
query_span_weights = query["span_weights"][query["span_mask"].eq(1)].view(-1)
if self.use_oproto:
loss = self.loss_fct(logits, golds, inst_weights=query_span_weights)
else:
loss = self.loss_fct(logits[:, 1:], golds - 1, inst_weights=query_span_weights)
_, preds = torch.max(logits, dim=-1)
return logits, preds, golds, loss
def init_proto(self, support_data):
self.eval()
with torch.no_grad():
support_span_emb = self.word_encoder(support_data['word'], support_data['word_mask'], word_to_piece_inds=support_data['word_to_piece_ind'],
word_to_piece_ends=support_data['word_to_piece_end'], span_indices=support_data['span_indices'])
self.cached_o_proto = self.cached_o_proto.to(support_span_emb.device)
proto = self.__get_proto__(support_span_emb, support_data['span_tag'], support_data['span_mask'])
self.proto = nn.Parameter(proto.data, requires_grad=True)
return
def inner_update(self, support_data, inner_steps, lr_inner):
self.init_proto(support_data)
parameters_to_optimize = list(self.named_parameters())
decay_params = []
nodecay_params = []
decay_names = []
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
for n, p in parameters_to_optimize:
if p.requires_grad:
if ("bert." in n) and (not any(nd in n for nd in no_decay)):
decay_params.append(p)
decay_names.append(n)
else:
nodecay_params.append(p)
parameters_groups = [
{'params': decay_params,
'lr': lr_inner, 'weight_decay': 1e-3},
{'params': nodecay_params,
'lr': lr_inner, 'weight_decay': 0},
]
inner_opt = torch.optim.AdamW(parameters_groups, lr=lr_inner)
self.train()
for _ in range(inner_steps):
inner_opt.zero_grad()
_, _, _, loss = self.forward_step(support_data)
if loss.requires_grad:
loss.backward()
inner_opt.step()
return
def forward_meta(self, batch, inner_steps, lr_inner, mode):
names, params = self.get_named_params(no_grads=["proto"])
weights = deepcopy(params)
meta_grad = []
episode_losses = []
query_logits = []
query_preds = []
query_golds = []
current_support_num = 0
current_query_num = 0
support, query = batch["support"], batch["query"]
data_keys = ['word', 'word_mask', 'word_to_piece_ind', 'word_to_piece_end', 'span_indices', 'span_mask', 'span_tag', 'span_weights']
for i, sent_support_num in enumerate(support['sentence_num']):
sent_query_num = query['sentence_num'][i]
one_support = {
k: support[k][current_support_num:current_support_num + sent_support_num] for k in data_keys if k in support
}
one_query = {
k: query[k][current_query_num:current_query_num + sent_query_num] for k in data_keys if k in query
}
self.zero_grad()
self.inner_update(one_support, inner_steps, lr_inner)
if mode == "train":
qy_logits, qy_pred, qy_gold, qy_loss = self.forward_step(one_query)
if one_query['span_mask'].sum().item() == 0:
pass
else:
grad = torch.autograd.grad(qy_loss, params)
meta_grad.append(grad)
else:
self.eval()
with torch.no_grad():
qy_logits, qy_pred, qy_gold, qy_loss = self.forward_step(one_query)
episode_losses.append(qy_loss.item())
query_preds.append(qy_pred)
query_golds.append(qy_gold)
query_logits.append(qy_logits)
self.load_weights(names, weights)
current_query_num += sent_query_num
current_support_num += sent_support_num
self.zero_grad()
return {'loss': np.mean(episode_losses), 'names': names, 'grads': meta_grad, 'preds': query_preds, 'golds': query_golds, 'logits': query_logits}
def get_named_params(self, no_grads=[]):
names = [n for n, p in self.named_parameters() if p.requires_grad and (n not in no_grads)]
params = [p for n, p in self.named_parameters() if p.requires_grad and (n not in no_grads)]
return names, params
def load_weights(self, names, params):
model_params = self.state_dict()
for n, p in zip(names, params):
assert n in model_params
model_params[n].data.copy_(p.data)
return
def load_gradients(self, names, grads):
model_params = self.state_dict(keep_vars=True)
for n, g in zip(names, grads):
if model_params[n].grad is None:
continue
model_params[n].grad.data.add_(g.data) # accumulate

Просмотреть файл

@ -0,0 +1,135 @@
seed=${1:-12}
d=${2:-1}
K=${3:-1}
dataset=${4:-"ner"}
mft_steps=${5:-30}
tft_steps=${6:-30}
schema=${7:-"BIO"}
mloss=${8:-2}
aids=8-9-10-11
batch_size=8
base_dir=data
if [ $dataset == 'snips' ] || [ $dataset == "ner" ]
then
root=${base_dir}/${dataset}/${K}shot
train_fn=train_domain${d}.txt
val_fn=valid_domain${d}.txt
test_fn=test_domain${d}.txt
ep_dir=$root
ep_train_fn=train_domain${d}_id.jsonl
ep_val_fn=valid_domain${d}_id.jsonl
ep_test_fn=test_domain${d}_id.jsonl
bio=True
train_iter=800
dev_iter=-1
val_steps=100
elif [ $dataset == 'inter' ] || [ $dataset == "intra" ]
then
root=${base_dir}/few-nerd/${dataset}
ep_dir=${base_dir}/few-nerd/episode-${dataset}
train_fn=train.txt
val_fn=dev.txt
test_fn=test.txt
ep_train_fn=train_${d}_${K}_id.jsonl
ep_val_fn=dev_${d}_${K}_id.jsonl
ep_test_fn=test_${d}_${K}_id.jsonl
bio=False
train_iter=1000
dev_iter=500
val_steps=500
else
root=${base_dir}/fewevent
train_fn=train.txt
val_fn=dev.txt
test_fn=test.txt
ep_dir=${root}
ep_train_fn=train_${d}_${K}_id.jsonl
ep_val_fn=dev_${d}_${K}_id.jsonl
ep_test_fn=test_${d}_${K}_id.jsonl
bio=True
train_iter=1000
dev_iter=-1
val_steps=500
fi
if [ $schema == 'IO' ]
then
crf=False
else
crf=True
fi
tft_steps=${mft_steps}
name=joint_m${mft_steps}_max${mloss}_${schema}_t${tft_steps}_bz${batch_size}
inner_steps=1
ckpt_dir=outputs/${dataset}/${d}-${K}shot/joint/${name}/seed${seed}
output_dir=${ckpt_dir}/
log_name=${dataset}-${d}-${K}-${name}-${seed}-test.log
python train_joint.py --mode test-twostage \
--seed ${seed} \
--root ${root} \
--train ${train_fn} \
--val ${val_fn} \
--test ${test_fn} \
--ep_dir ${ep_dir} \
--ep_train ${ep_train_fn} \
--ep_val ${ep_val_fn} \
--ep_test ${ep_test_fn} \
--output_dir ${output_dir} \
--N ${d} \
--K ${K} \
--Q 1 \
--bio ${bio} \
--max_loss ${mloss} \
--use_crf ${crf} \
--schema ${schema} \
--adapter_layer_ids ${aids} \
--encoder_name_or_path bert-base-uncased \
--last_n_layer -1 \
--max_length 128 \
--word_encode_choice first \
--span_encode_choice avg \
--learning_rate 0.001 \
--bert_learning_rate 5e-5 \
--bert_weight_decay 0.01 \
--train_iter ${train_iter} \
--dev_iter ${dev_iter} \
--val_steps ${val_steps} \
--warmup_step 0 \
--log_steps 50 \
--type_lam 1 \
--use_adapter True \
--eval_batch_size 1 \
--use_width False \
--use_case False \
--dropout 0.5 \
--max_grad_norm 5 \
--dot False \
--normalize l2 \
--temperature 0.1 \
--use_focal False \
--use_att False \
--att_hidden_dim 100 \
--adapter_size 64 \
--use_oproto False \
--hou_eval_ep ${hou_eval_ep} \
--overlap False \
--type_threshold 0 \
--use_maml True \
--train_inner_lr 2e-5 \
--train_inner_steps ${inner_steps} \
--warmup_prop_inner 0 \
--eval_inner_lr 2e-5 \
--eval_ment_inner_steps ${mft_steps} \
--eval_type_inner_steps ${tft_steps} \
--load_ckpt ${ckpt_dir}/model.pth.tar

Просмотреть файл

@ -0,0 +1,118 @@
seed=${1:-12}
d=${2:-1}
K=${3:-1}
dataset=${4:-"ner"}
ft_steps=${5:-20}
ment_name=${6:-"sqlment_maml30_max2_BIO_bz16"}
base_dir=data
batch_size=16
if [ $dataset == 'snips' ] || [ $dataset == "ner" ]
then
root=${base_dir}/${dataset}/${K}shot
train_fn=train_domain${d}.txt
val_fn=valid_domain${d}.txt
test_fn=test_domain${d}.txt
ep_dir=$root
ep_train_fn=train_domain${d}_id.jsonl
ep_val_fn=valid_domain${d}_id.jsonl
ep_test_fn=test_domain${d}_id.jsonl
bio=True
train_iter=800
dev_iter=-1
val_steps=100
elif [ $dataset == 'inter' ] || [ $dataset == "intra" ]
then
root=${base_dir}/few-nerd/${dataset}
ep_dir=${base_dir}/few-nerd/episode-${dataset}
train_fn=train.txt
val_fn=dev.txt
test_fn=test.txt
ep_train_fn=train_${d}_${K}_id.jsonl
ep_val_fn=dev_${d}_${K}_id.jsonl
ep_test_fn=test_${d}_${K}_id.jsonl
bio=False
train_iter=1000
dev_iter=500
val_steps=500
else
root=${base_dir}/fewevent
train_fn=train.txt
val_fn=dev.txt
test_fn=test.txt
ep_dir=${root}
ep_train_fn=train_${d}_${K}_id.jsonl
ep_val_fn=dev_${d}_${K}_id.jsonl
ep_test_fn=test_${d}_${K}_id.jsonl
bio=True
train_iter=1000
dev_iter=-1
val_steps=500
fi
name=type_mamlcls${ft_steps}_bz${batch_size}
inner_steps=1
output_dir=outputs/${dataset}/${d}-${K}shot/${ment_name}/seed${seed}/${name}
ckpt_dir=outputs/${dataset}/${d}-${K}shot/${name}/seed${seed}
mkdir -p ${output_dir}
val_ment_fn=outputs/${dataset}/${d}-${K}shot/${ment_name}/seed${seed}/dev_ment.json
test_ment_fn=outputs/${dataset}/${d}-${K}shot/${ment_name}/seed${seed}/test_ment.json
python train_type.py --mode test \
--seed ${seed} \
--root ${root} \
--train ${train_fn} \
--val ${val_fn} \
--test ${test_fn} \
--ep_dir ${ep_dir} \
--ep_train ${ep_train_fn} \
--ep_val ${ep_val_fn} \
--ep_test ${ep_test_fn} \
--val_ment_fn ${val_ment_fn} \
--test_ment_fn ${test_ment_fn} \
--output_dir ${output_dir} \
--load_ckpt ${ckpt_dir}/model.pth.tar \
--N ${d} \
--K ${K} \
--Q 1 \
--bio ${bio} \
--encoder_name_or_path bert-base-uncased \
--last_n_layer -4 \
--max_length 128 \
--word_encode_choice first \
--span_encode_choice avg \
--learning_rate 0.001 \
--bert_learning_rate 5e-5 \
--bert_weight_decay 0.01 \
--train_iter ${train_iter} \
--dev_iter ${dev_iter} \
--val_steps ${val_steps} \
--warmup_step 0 \
--log_steps 50 \
--train_batch_size ${batch_size} \
--eval_batch_size 1 \
--use_width False \
--use_case False \
--dropout 0.5 \
--max_grad_norm 5 \
--dot False \
--normalize l2 \
--temperature 0.1 \
--use_focal False \
--use_att False \
--att_hidden_dim 100 \
--use_oproto False \
--hou_eval_ep ${hou_eval_ep} \
--overlap False \
--type_threshold 0 \
--use_maml True \
--train_inner_lr 2e-5 \
--train_inner_steps ${inner_steps} \
--warmup_prop_inner 0 \
--eval_inner_lr 2e-5 \
--eval_inner_steps ${ft_steps}

Просмотреть файл

@ -0,0 +1,130 @@
seed=${1:-12}
d=${2:-5}
K=${3:-1}
dataset=${4:-"inter"}
mft_steps=${5:-3}
tft_steps=${6:-3}
schema=${7:-"BIO"}
mloss=${8:-2}
base_dir=data
batch_size=8
aids=8-9-10-11
if [ $dataset == 'snips' ] || [ $dataset == "ner" ]
then
root=${base_dir}/${dataset}/${K}shot
train_fn=train_domain${d}.txt
val_fn=valid_domain${d}.txt
test_fn=test_domain${d}.txt
ep_dir=$root
ep_train_fn=train_domain${d}_id.jsonl
ep_val_fn=valid_domain${d}_id.jsonl
ep_test_fn=test_domain${d}_id.jsonl
bio=True
train_iter=800
dev_iter=-1
val_steps=100
elif [ $dataset == 'inter' ] || [ $dataset == "intra" ]
then
root=${base_dir}/few-nerd/${dataset}
ep_dir=${base_dir}/few-nerd/episode-${dataset}
train_fn=train.txt
val_fn=dev.txt
test_fn=test.txt
ep_train_fn=train_${d}_${K}_id.jsonl
ep_val_fn=dev_${d}_${K}_id.jsonl
ep_test_fn=test_${d}_${K}_id.jsonl
bio=False
train_iter=1000
dev_iter=500
val_steps=500
else
root=${base_dir}/fewevent
train_fn=train.txt
val_fn=dev.txt
test_fn=test.txt
ep_dir=${root}
ep_train_fn=train_${d}_${K}_id.jsonl
ep_val_fn=dev_${d}_${K}_id.jsonl
ep_test_fn=test_${d}_${K}_id.jsonl
bio=True
train_iter=1000
dev_iter=-1
val_steps=500
fi
if [ $schema == 'IO' ]
then
crf=False
else
crf=True
fi
tft_steps=${mft_steps}
name=joint_m${mft_steps}_max${mloss}_${schema}_t${tft_steps}_bz${batch_size}
inner_steps=1
output_dir=outputs/${dataset}/${d}-${K}shot/joint/${name}/seed${seed}
mkdir -p ${output_dir}
log_name=${dataset}-${d}-${K}-${name}-${seed}-train.log
python train_joint.py --mode train \
--seed ${seed} \
--root ${root} \
--train ${train_fn} \
--val ${val_fn} \
--test ${test_fn} \
--ep_dir ${ep_dir} \
--ep_train ${ep_train_fn} \
--ep_val ${ep_val_fn} \
--ep_test ${ep_test_fn} \
--output_dir ${output_dir} \
--N ${d} \
--K ${K} \
--Q 1 \
--bio ${bio} \
--max_loss ${mloss} \
--use_crf ${crf} \
--schema ${schema} \
--adapter_layer_ids ${aids} \
--encoder_name_or_path bert-base-uncased \
--last_n_layer -1 \
--max_length 128 \
--word_encode_choice first \
--span_encode_choice avg \
--learning_rate 0.001 \
--bert_learning_rate 5e-5 \
--bert_weight_decay 0.01 \
--train_iter ${train_iter} \
--dev_iter ${dev_iter} \
--val_steps ${val_steps} \
--warmup_step 0 \
--log_steps 50 \
--train_batch_size ${batch_size} \
--type_lam 1 \
--use_adapter True \
--eval_batch_size 1 \
--use_width False \
--use_case False \
--dropout 0.5 \
--max_grad_norm 5 \
--dot False \
--normalize l2 \
--temperature 0.1 \
--use_focal False \
--use_att False \
--att_hidden_dim 100 \
--adapter_size 64 \
--use_oproto False \
--hou_eval_ep ${hou_eval_ep} \
--overlap False \
--type_threshold 0 \
--use_maml True \
--train_inner_lr 2e-5 \
--train_inner_steps ${inner_steps} \
--warmup_prop_inner 0 \
--eval_inner_lr 2e-5 \
--eval_ment_inner_steps ${mft_steps} \
--eval_type_inner_steps ${tft_steps}

Просмотреть файл

@ -0,0 +1,115 @@
seed=${1:-12}
d=${2:-1}
K=${3:-1}
dataset=${4:-"ner"}
schema=${5:-"BIO"}
mloss=${6:-2}
ft_steps=${7:-30}
base_dir=data
wp_inner=0
lr=5e-5
batch_size=16
echo $batch_size
if [ $dataset == 'snips' ] || [ $dataset == "ner" ]
then
root=${base_dir}/${dataset}/${K}shot
train_fn=train_domain${d}.txt
val_fn=valid_domain${d}.txt
test_fn=test_domain${d}.txt
ep_dir=$root
ep_train_fn=train_domain${d}_id.jsonl
ep_val_fn=valid_domain${d}_id.jsonl
ep_test_fn=test_domain${d}_id.jsonl
bio=True
train_iter=800
dev_iter=-1
val_steps=100
elif [ $dataset == 'inter' ] || [ $dataset == "intra" ]
then
root=${base_dir}/few-nerd/${dataset}
ep_dir=${base_dir}/few-nerd/episode-${dataset}
train_fn=train.txt
val_fn=dev.txt
test_fn=test.txt
ep_train_fn=train_${d}_${K}_id.jsonl
ep_val_fn=dev_${d}_${K}_id.jsonl
ep_test_fn=test_${d}_${K}_id.jsonl
bio=False
train_iter=1000
dev_iter=500
val_steps=500
else
root=${base_dir}/fewevent
train_fn=train.txt
val_fn=dev.txt
test_fn=test.txt
ep_dir=${root}
ep_train_fn=train_${d}_${K}_id.jsonl
ep_val_fn=dev_${d}_${K}_id.jsonl
ep_test_fn=test_${d}_${K}_id.jsonl
bio=True
train_iter=1000
dev_iter=-1
val_steps=500
fi
if [ $schema == 'IO' ]
then
crf=False
else
crf=True
fi
name=sqlment_maml${ft_steps}_max${mloss}_${schema}_bz${batch_size}
inner_steps=2
output_dir=outputs/${dataset}/${d}-${K}shot/${name}/seed${seed}
mkdir -p ${output_dir}
log_name=${dataset}-${d}-${K}-${name}-${seed}-train.log
python train_sql.py --mode train \
--seed ${seed} \
--root ${root} \
--train ${train_fn} \
--val ${val_fn} \
--test ${test_fn} \
--ep_dir ${ep_dir} \
--ep_train ${ep_train_fn} \
--ep_val ${ep_val_fn} \
--ep_test ${ep_test_fn} \
--output_dir ${output_dir} \
--encoder_name_or_path bert-base-uncased \
--last_n_layer -1 \
--max_length 128 \
--word_encode_choice first \
--warmup_step 0 \
--learning_rate 0.001 \
--bert_learning_rate ${lr} \
--bert_weight_decay 0.01 \
--log_steps 50 \
--train_iter ${train_iter} \
--dev_iter ${dev_iter} \
--val_steps ${val_steps} \
--train_batch_size ${batch_size} \
--eval_batch_size 1 \
--max_loss ${mloss} \
--use_crf ${crf} \
--schema ${schema} \
--dropout 0.5 \
--max_grad_norm 5 \
--eval_all_after_train True \
--bio ${bio} \
--use_maml True \
--train_inner_lr 2e-5 \
--train_inner_steps ${inner_steps} \
--warmup_prop_inner ${wp_inner} \
--eval_inner_lr 2e-5 \
--eval_inner_steps ${ft_steps} | tee ${log_name}
mv ${log_name} ${output_dir}/train.log

Просмотреть файл

@ -0,0 +1,70 @@
seed=${1:-12}
d=${2:-2}
K=${3:-1}
mloss=${4:-2}
ft_steps=${5:-20}
base_dir=data
batch_size=16
dataset=postag
root=${base_dir}/postag/${d}-${K}shot
train_fn=train.txt
val_fn=dev.txt
test_fn=test.txt
ep_dir=$root
ep_train_fn=train_id.jsonl
ep_val_fn=dev_id.jsonl
ep_test_fn=test_id.jsonl
train_iter=1000
dev_iter=-1
val_steps=100
name=pos_mamlcls${ft_steps}_max${mloss}_bz${batch_size}
inner_steps=1
output_dir=outputs/${dataset}/${d}-${K}shot/${name}/seed${seed}
mkdir -p ${output_dir}
log_name=${dataset}-${d}-${K}-${name}-${seed}-train.log
python train_pos.py --mode train \
--seed ${seed} \
--root ${root} \
--train ${train_fn} \
--val ${val_fn} \
--test ${test_fn} \
--ep_dir ${ep_dir} \
--ep_train ${ep_train_fn} \
--ep_val ${ep_val_fn} \
--ep_test ${ep_test_fn} \
--output_dir ${output_dir} \
--N ${d} \
--K ${K} \
--Q 1 \
--encoder_name_or_path bert-base-uncased \
--last_n_layer -1 \
--max_length 128 \
--word_encode_choice first \
--learning_rate 0.001 \
--bert_learning_rate 5e-5 \
--bert_weight_decay 0.01 \
--train_iter ${train_iter} \
--dev_iter ${dev_iter} \
--val_steps ${val_steps} \
--warmup_step 0 \
--log_steps 50 \
--train_batch_size ${batch_size} \
--eval_batch_size 1 \
--dropout 0.5 \
--max_grad_norm 5 \
--dot False \
--normalize l2 \
--temperature 0.1 \
--max_loss ${mloss} \
--use_maml True \
--train_inner_lr 2e-5 \
--train_inner_steps ${inner_steps} \
--warmup_prop_inner 0.1 \
--eval_inner_lr 2e-5 \
--eval_inner_steps ${ft_steps}

Просмотреть файл

@ -0,0 +1,114 @@
seed=${1:-12}
d=${2:-1}
K=${3:-1}
dataset=${4:-"ner"}
ft_steps=${5:-20}
base_dir=data
batch_size=16
if [ $dataset == 'snips' ] || [ $dataset == "ner" ]
then
root=${base_dir}/${dataset}/${K}shot
train_fn=train_domain${d}.txt
val_fn=valid_domain${d}.txt
test_fn=test_domain${d}.txt
ep_dir=$root
ep_train_fn=train_domain${d}_id.jsonl
ep_val_fn=valid_domain${d}_id.jsonl
ep_test_fn=test_domain${d}_id.jsonl
bio=True
train_iter=800
dev_iter=-1
val_steps=100
elif [ $dataset == 'inter' ] || [ $dataset == "intra" ]
then
root=${base_dir}/few-nerd/${dataset}
ep_dir=${base_dir}/few-nerd/episode-${dataset}
train_fn=train.txt
val_fn=dev.txt
test_fn=test.txt
ep_train_fn=train_${d}_${K}_id.jsonl
ep_val_fn=dev_${d}_${K}_id.jsonl
ep_test_fn=test_${d}_${K}_id.jsonl
bio=False
train_iter=1000
dev_iter=500
val_steps=500
else
root=${base_dir}/fewevent
train_fn=train.txt
val_fn=dev.txt
test_fn=test.txt
ep_dir=${root}
ep_train_fn=train_${d}_${K}_id.jsonl
ep_val_fn=dev_${d}_${K}_id.jsonl
ep_test_fn=test_${d}_${K}_id.jsonl
bio=True
train_iter=1000
dev_iter=-1
val_steps=500
fi
name=type_mamlcls${ft_steps}_bz${batch_size}
inner_steps=1
output_dir=outputs/${dataset}/${d}-${K}shot/${name}/seed${seed}
mkdir -p ${output_dir}
log_name=${dataset}-${d}-${K}-${name}-${seed}-train.log
python train_type.py --mode train \
--seed ${seed} \
--root ${root} \
--train ${train_fn} \
--val ${val_fn} \
--test ${test_fn} \
--ep_dir ${ep_dir} \
--ep_train ${ep_train_fn} \
--ep_val ${ep_val_fn} \
--ep_test ${ep_test_fn} \
--output_dir ${output_dir} \
--N ${d} \
--K ${K} \
--Q 1 \
--bio ${bio} \
--encoder_name_or_path bert-base-uncased \
--last_n_layer -4 \
--max_length 128 \
--word_encode_choice first \
--span_encode_choice avg \
--learning_rate 0.001 \
--bert_learning_rate 5e-5 \
--bert_weight_decay 0.01 \
--train_iter ${train_iter} \
--dev_iter ${dev_iter} \
--val_steps ${val_steps} \
--warmup_step 0 \
--log_steps 50 \
--train_batch_size ${batch_size} \
--eval_batch_size 1 \
--use_width False \
--use_case False \
--dropout 0.5 \
--max_grad_norm 5 \
--dot False \
--normalize l2 \
--temperature 0.1 \
--use_focal False \
--use_att False \
--att_hidden_dim 100 \
--use_oproto False \
--hou_eval_ep ${hou_eval_ep} \
--overlap False \
--type_threshold 0 \
--use_maml True \
--train_inner_lr 2e-5 \
--train_inner_steps ${inner_steps} \
--warmup_prop_inner 0 \
--eval_inner_lr 2e-5 \
--eval_inner_steps ${ft_steps} | tee ${log_name}
mv ${log_name} ${output_dir}/train.log

Просмотреть файл

@ -0,0 +1,267 @@
import sys
import torch
import numpy as np
import json
import argparse
import os
import random
from util.joint_loader import get_joint_loader
from trainer.joint_trainer import JointTrainer
from util.span_encoder import BERTSpanEncoder
from util.log_utils import eval_ent_log, cal_episode_prf, set_seed, save_json
from model.joint_model import SelectedJointModel
def add_args():
def str2bool(arg):
if arg.lower() == "true":
return True
return False
parser = argparse.ArgumentParser()
parser.add_argument('--mode', default='test', type=str,
help='train / test')
parser.add_argument('--load_ckpt', default=None,
help='load ckpt')
parser.add_argument('--output_dir', default=None,
help='output dir')
parser.add_argument('--log_dir', default=None,
help='log dir')
parser.add_argument('--root', default=None, type=str,
help='data root dir')
parser.add_argument('--train', default='train.txt',
help='train file')
parser.add_argument('--val', default='dev.txt',
help='val file')
parser.add_argument('--test', default='test.txt',
help='test file')
parser.add_argument('--train_ment_fn', default=None,
help='train file')
parser.add_argument('--val_ment_fn', default=None,
help='val file')
parser.add_argument('--test_ment_fn', default=None,
help='test file')
parser.add_argument('--ep_dir', default=None, type=str)
parser.add_argument('--ep_train', default=None, type=str)
parser.add_argument('--ep_val', default=None, type=str)
parser.add_argument('--ep_test', default=None, type=str)
parser.add_argument('--N', default=None, type=int,
help='N way')
parser.add_argument('--K', default=None, type=int,
help='K shot')
parser.add_argument('--Q', default=None, type=int,
help='Num of query per class')
parser.add_argument('--encoder_name_or_path', default='bert-base-uncased', type=str)
parser.add_argument('--word_encode_choice', default='first', type=str)
parser.add_argument('--span_encode_choice', default=None, type=str)
parser.add_argument('--max_length', default=128, type=int,
help='max length')
parser.add_argument('--max_span_len', default=8, type=int)
parser.add_argument('--max_neg_ratio', default=5, type=int)
parser.add_argument('--last_n_layer', default=-4, type=int)
parser.add_argument('--dot', type=str2bool, default=False,
help='use dot instead of L2 distance for knn')
parser.add_argument("--normalize", default='none', type=str, choices=['none', 'l2'])
parser.add_argument("--temperature", default=None, type=float)
parser.add_argument("--use_width", default=False, type=str2bool)
parser.add_argument("--width_dim", default=20, type=int)
parser.add_argument("--use_case", default=False, type=str2bool)
parser.add_argument("--case_dim", default=20, type=int)
parser.add_argument('--dropout', default=0.5, type=float,
help='dropout rate')
parser.add_argument('--log_steps', default=2000, type=int,
help='val after training how many iters')
parser.add_argument('--val_steps', default=2000, type=int,
help='val after training how many iters')
parser.add_argument('--train_batch_size', default=4, type=int,
help='batch size')
parser.add_argument('--eval_batch_size', default=1, type=int,
help='batch size')
parser.add_argument('--train_iter', default=-1, type=int,
help='num of iters in training')
parser.add_argument('--dev_iter', default=-1, type=int,
help='num of iters in training')
parser.add_argument('--test_iter', default=-1, type=int,
help='num of iters in training')
parser.add_argument("--max_grad_norm", default=None, type=float)
parser.add_argument("--learning_rate", default=5e-5, type=float, help="lr rate")
parser.add_argument("--weight_decay", default=1e-3, type=float, help="weight decay")
parser.add_argument("--bert_learning_rate", default=5e-5, type=float, help="lr rate")
parser.add_argument("--bert_weight_decay", default=1e-5, type=float, help="weight decay")
parser.add_argument("--warmup_step", default=0, type=int)
parser.add_argument('--seed', default=42, type=int,
help='seed')
parser.add_argument('--fp16', action='store_true', default=False,
help='use nvidia apex fp16')
parser.add_argument('--use_focal', default=False, type=str2bool)
parser.add_argument('--iou_thred', default=None, type=float)
parser.add_argument('--use_att', default=False, type=str2bool)
parser.add_argument('--att_hidden_dim', default=-1, type=int)
parser.add_argument('--label_fn', default=None, type=str)
parser.add_argument('--hou_eval_ep', default=False, type=str2bool)
parser.add_argument('--use_maml', default=False, type=str2bool)
parser.add_argument('--warmup_prop_inner', default=0, type=float)
parser.add_argument('--train_inner_lr', default=0, type=float)
parser.add_argument('--train_inner_steps', default=0, type=int)
parser.add_argument('--eval_inner_lr', default=0, type=float)
parser.add_argument('--eval_type_inner_steps', default=0, type=int)
parser.add_argument('--eval_ment_inner_steps', default=0, type=int)
parser.add_argument('--overlap', default=False, type=str2bool)
parser.add_argument('--type_lam', default=1, type=float)
parser.add_argument('--use_adapter', default=False, type=str2bool)
parser.add_argument('--adapter_size', default=64, type=int)
parser.add_argument('--type_threshold', default=-1, type=float)
parser.add_argument('--use_oproto', default=False, type=str2bool)
parser.add_argument('--bio', default=False, type=str2bool)
parser.add_argument('--schema', default='IO', type=str, choices=['IO', 'BIO', 'BIOES'])
parser.add_argument('--use_crf', type=str2bool, default=False)
parser.add_argument('--max_loss', default=0, type=float)
parser.add_argument('--adapter_layer_ids', default='9-10-11', type=str)
opt = parser.parse_args()
return opt
def main(opt):
print("Joint Model pipeline {} way {} shot".format(opt.N, opt.K))
set_seed(opt.seed)
if opt.mode == "train":
output_dir = opt.output_dir
print("Output dir is ", output_dir)
log_dir = os.path.join(
output_dir,
"logs",
)
opt.log_dir = log_dir
if not os.path.exists(opt.output_dir):
os.makedirs(opt.output_dir)
save_json(opt.__dict__, os.path.join(opt.output_dir, "train_setting.txt"))
else:
if not os.path.exists(opt.output_dir):
os.makedirs(opt.output_dir)
save_json(opt.__dict__, os.path.join(opt.output_dir, "test_setting.txt"))
print('loading model and tokenizer...')
print("use adapter: ", opt.use_adapter)
word_encoder = BERTSpanEncoder(opt.encoder_name_or_path, opt.max_length, last_n_layer=opt.last_n_layer,
word_encode_choice=opt.word_encode_choice,
span_encode_choice=opt.span_encode_choice,
use_width=opt.use_width, width_dim=opt.width_dim, use_case=opt.use_case,
case_dim=opt.case_dim,
drop_p=opt.dropout, use_att=opt.use_att,
att_hidden_dim=opt.att_hidden_dim, use_adapter=opt.use_adapter, adapter_size=opt.adapter_size,
adapter_layer_ids=[int(x) for x in opt.adapter_layer_ids.split('-')])
print('loading data')
if opt.mode == "train":
train_loader = get_joint_loader(os.path.join(opt.root, opt.train), "train", word_encoder, N=opt.N,
K=opt.K,
Q=opt.Q, batch_size=opt.train_batch_size, max_length=opt.max_length,
shuffle=True,
bio=opt.bio,
schema=opt.schema,
debug_file=os.path.join(opt.ep_dir, opt.ep_train) if opt.ep_train else None,
query_file=opt.train_ment_fn,
iou_thred=opt.iou_thred,
use_oproto=opt.use_oproto,
label_fn=opt.label_fn)
print(os.path.join(opt.root, opt.val))
dev_loader = get_joint_loader(os.path.join(opt.root, opt.val), "test", word_encoder, N=opt.N, K=opt.K,
Q=opt.Q, batch_size=opt.eval_batch_size, max_length=opt.max_length,
shuffle=False,
bio=opt.bio,
schema=opt.schema,
debug_file=os.path.join(opt.ep_dir, opt.ep_val) if opt.ep_val else None,
query_file=opt.val_ment_fn,
iou_thred=opt.iou_thred,
use_oproto=opt.use_oproto,
label_fn=opt.label_fn)
test_loader = get_joint_loader(os.path.join(opt.root, opt.test), "test", word_encoder, N=opt.N, K=opt.K,
Q=opt.Q, batch_size=opt.eval_batch_size, max_length=opt.max_length,
shuffle=False,
bio=opt.bio,
schema=opt.schema,
debug_file=os.path.join(opt.ep_dir, opt.ep_test) if opt.ep_test else None,
query_file=opt.test_ment_fn,
iou_thred=opt.iou_thred, hidden_query_label=False,
use_oproto=opt.use_oproto,
label_fn=opt.label_fn)
print("max_length: {}".format(opt.max_length))
print('mode: {}'.format(opt.mode))
print("{}-way-{}-shot Proto Few-Shot NER".format(opt.N, opt.K))
model = SelectedJointModel(word_encoder, num_tag=len(test_loader.dataset.ment_label2tag),
ment_label2idx=test_loader.dataset.ment_tag2label,
schema=opt.schema, use_crf=opt.use_crf, max_loss=opt.max_loss,
use_oproto=opt.use_oproto, dot=opt.dot, normalize=opt.normalize,
temperature=opt.temperature, use_focal=opt.use_focal, type_lam=opt.type_lam)
num_params = sum(param.numel() for param in model.parameters())
print("total parameter numbers: ", num_params)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
trainer = JointTrainer()
assert opt.eval_batch_size == 1
if opt.mode == 'train':
trainer.train(model, opt, device, train_loader, dev_loader, dev_pred_fn=os.path.join(opt.output_dir, "dev_metrics.json"), dev_log_fn=os.path.join(opt.output_dir, "dev_log.txt"), load_ckpt=opt.load_ckpt)
test_loss, test_ment_p, test_ment_r, test_ment_f1, test_p, test_r, test_f1, test_ment_logs, test_logs = trainer.eval(model, device, test_loader, eval_iter=opt.test_iter, load_ckpt=os.path.join(opt.output_dir, "model.pth.tar"),
ment_update_iter=opt.eval_ment_inner_steps,
type_update_iter=opt.eval_type_inner_steps,
learning_rate=opt.eval_inner_lr,
overlap=opt.overlap,
threshold=opt.type_threshold,
eval_mode="test-twostage")
if opt.hou_eval_ep:
test_p, test_r, test_f1 = cal_episode_prf(test_logs)
print("Mention test precision {:.5f} recall {:.5f} f1 {:.5f}".format(test_ment_p, test_ment_r, test_ment_f1))
print('[TEST] loss: {0:2.6f} | [Entity] precision: {1:3.4f}, recall: {2:3.4f}, f1: {3:3.4f}'\
.format(test_loss, test_p, test_r, test_f1) + '\r')
else:
_, dev_ment_p, dev_ment_r, dev_ment_f1, dev_p, dev_r, dev_f1, dev_ment_logs, dev_logs = trainer.eval(model, device, dev_loader, load_ckpt=opt.load_ckpt,
eval_iter=opt.dev_iter,
ment_update_iter=opt.eval_ment_inner_steps,
type_update_iter=opt.eval_type_inner_steps,
learning_rate=opt.eval_inner_lr, overlap=opt.overlap, threshold=opt.type_threshold, eval_mode=opt.mode)
if opt.hou_eval_ep:
dev_p, dev_r, dev_f1 = cal_episode_prf(dev_logs)
print("Mention dev precision {:.5f} recall {:.5f} f1 {:.5f}".format(dev_ment_p, dev_ment_r, dev_ment_f1))
print("Dev precison {:.5f} recall {:.5f} f1 {:.5f}".format(dev_p, dev_r, dev_f1))
with open(os.path.join(opt.output_dir, "dev_metrics.json"), mode="w", encoding="utf-8") as fp:
json.dump({"ment_p": dev_ment_p, "ment_r": dev_ment_r, "ment_f1": dev_ment_f1, "precision": dev_p, "recall": dev_r, "f1": dev_f1}, fp)
test_loss, test_ment_p, test_ment_r, test_ment_f1, test_p, test_r, test_f1, test_ment_logs, test_logs = trainer.eval(model, device, test_loader, eval_iter=opt.test_iter, load_ckpt=opt.load_ckpt,
ment_update_iter=opt.eval_ment_inner_steps,
type_update_iter=opt.eval_type_inner_steps, learning_rate=opt.eval_inner_lr, overlap=opt.overlap, threshold=opt.type_threshold, eval_mode=opt.mode)
if opt.hou_eval_ep:
test_p, test_r, test_f1 = cal_episode_prf(test_logs)
print("Mention test precision {:.5f} recall {:.5f} f1 {:.5f}".format(test_ment_p, test_ment_r, test_ment_f1))
print("Test precison {:.5f} recall {:.5f} f1 {:.5f}".format(test_p, test_r, test_f1))
with open(os.path.join(opt.output_dir, "test_metrics.json"), mode="w", encoding="utf-8") as fp:
json.dump({"ment_p": test_ment_p, "ment_r": test_ment_r, "ment_f1": test_ment_f1, "precision": test_p, "recall": test_r, "f1": test_f1}, fp)
eval_ent_log(test_loader.dataset.samples, test_logs)
return
if __name__ == "__main__":
opt = add_args()
main(opt)

Просмотреть файл

@ -0,0 +1,182 @@
import torch
import numpy as np
import json
import argparse
import os
import random
from util.pos_loader import get_seq_loader
from trainer.pos_trainer import POSTrainer
from util.span_encoder import BERTSpanEncoder
from model.pos_model import SeqProtoCls
from util.log_utils import write_pos_pred_json, save_json, set_seed
def add_args():
def str2bool(arg):
if arg.lower() == "true":
return True
return False
parser = argparse.ArgumentParser()
parser.add_argument('--mode', default='test', type=str,
help='train / test / typetest')
parser.add_argument('--train_eval_mode', default='test', type=str, choices=['test', 'typetest'])
parser.add_argument('--load_ckpt', default=None,
help='load ckpt')
parser.add_argument('--output_dir', default=None,
help='output dir')
parser.add_argument('--log_dir', default=None,
help='log dir')
parser.add_argument('--root', default=None, type=str,
help='data root dir')
parser.add_argument('--train', default='train.txt',
help='train file')
parser.add_argument('--val', default='dev.txt',
help='val file')
parser.add_argument('--test', default='test.txt',
help='test file')
parser.add_argument('--ep_dir', default=None, type=str)
parser.add_argument('--ep_train', default=None, type=str)
parser.add_argument('--ep_val', default=None, type=str)
parser.add_argument('--ep_test', default=None, type=str)
parser.add_argument('--N', default=None, type=int,
help='N way')
parser.add_argument('--K', default=None, type=int,
help='K shot')
parser.add_argument('--Q', default=None, type=int,
help='Num of query per class')
parser.add_argument('--encoder_name_or_path', default='bert-base-uncased', type=str)
parser.add_argument('--word_encode_choice', default='first', type=str)
parser.add_argument('--max_length', default=128, type=int,
help='max length')
parser.add_argument('--last_n_layer', default=-4, type=int)
parser.add_argument('--max_loss', default=0, type=float)
parser.add_argument('--dot', type=str2bool, default=False,
help='use dot instead of L2 distance for knn')
parser.add_argument("--normalize", default='none', type=str, choices=['none', 'l2'])
parser.add_argument("--temperature", default=None, type=float)
parser.add_argument('--dropout', default=0.5, type=float,
help='dropout rate')
parser.add_argument('--log_steps', default=2000, type=int,
help='val after training how many iters')
parser.add_argument('--val_steps', default=2000, type=int,
help='val after training how many iters')
parser.add_argument('--train_batch_size', default=4, type=int,
help='batch size')
parser.add_argument('--eval_batch_size', default=1, type=int,
help='batch size')
parser.add_argument('--train_iter', default=-1, type=int,
help='num of iters in training')
parser.add_argument('--dev_iter', default=-1, type=int,
help='num of iters in training')
parser.add_argument('--test_iter', default=-1, type=int,
help='num of iters in training')
parser.add_argument("--max_grad_norm", default=None, type=float)
parser.add_argument("--learning_rate", default=1e-3, type=float, help="lr rate")
parser.add_argument("--weight_decay", default=1e-3, type=float, help="weight decay")
parser.add_argument("--bert_learning_rate", default=5e-5, type=float, help="lr rate")
parser.add_argument("--bert_weight_decay", default=1e-5, type=float, help="weight decay")
parser.add_argument("--warmup_step", default=0, type=int)
parser.add_argument('--seed', default=42, type=int,
help='seed')
parser.add_argument('--fp16', action='store_true', default=False,
help='use nvidia apex fp16')
parser.add_argument('--use_maml', default=False, type=str2bool)
parser.add_argument('--warmup_prop_inner', default=0, type=float)
parser.add_argument('--train_inner_lr', default=0, type=float)
parser.add_argument('--train_inner_steps', default=0, type=int)
parser.add_argument('--eval_inner_lr', default=0, type=float)
parser.add_argument('--eval_inner_steps', default=0, type=int)
opt = parser.parse_args()
return opt
def main(opt):
set_seed(opt.seed)
if not os.path.exists(opt.output_dir):
os.makedirs(opt.output_dir)
if opt.mode == "train":
output_dir = opt.output_dir
print("Output dir is ", output_dir)
log_dir = os.path.join(
output_dir,
"logs",
)
opt.log_dir = log_dir
save_json(opt.__dict__, os.path.join(opt.output_dir, "train_setting.txt"))
else:
save_json(opt.__dict__, os.path.join(opt.output_dir, "test_setting.txt"))
print('loading model and tokenizer...')
word_encoder = BERTSpanEncoder(opt.encoder_name_or_path, opt.max_length, last_n_layer=opt.last_n_layer,
word_encode_choice=opt.word_encode_choice, drop_p=opt.dropout)
print('loading data')
if opt.mode == "train":
train_loader = get_seq_loader(os.path.join(opt.root, opt.train), "train", word_encoder, batch_size=opt.train_batch_size, max_length=opt.max_length,
shuffle=True,
debug_file=os.path.join(opt.ep_dir, opt.ep_train) if opt.ep_train else None)
else:
train_loader = get_seq_loader(os.path.join(opt.root, opt.train), "test", word_encoder, batch_size=opt.eval_batch_size, max_length=opt.max_length,
shuffle=False,
debug_file=os.path.join(opt.ep_dir, opt.ep_train) if opt.ep_train else None)
dev_loader = get_seq_loader(os.path.join(opt.root, opt.val), "test", word_encoder, batch_size=opt.eval_batch_size, max_length=opt.max_length,
shuffle=False,
debug_file=os.path.join(opt.ep_dir, opt.ep_val) if opt.ep_val else None)
test_loader = get_seq_loader(os.path.join(opt.root, opt.test), "test", word_encoder, batch_size=opt.eval_batch_size, max_length=opt.max_length,
shuffle=False,
debug_file=os.path.join(opt.ep_dir, opt.ep_test) if opt.ep_test else None)
print("max_length: {}".format(opt.max_length))
print('mode: {}'.format(opt.mode))
print("{}-way-{}-shot Token MAML-Proto Few-Shot NER".format(opt.N, opt.K))
print("this mode can only used for maml training !!!!!!!!!!!!!!!!!")
model = SeqProtoCls(word_encoder, opt.max_loss, dot=opt.dot, normalize=opt.normalize,
temperature=opt.temperature)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
trainer = POSTrainer()
if opt.mode == 'train':
print("==================start training==================")
trainer.train(model, opt, device, train_loader, dev_loader, dev_pred_fn=os.path.join(opt.output_dir, "dev_pred.json"))
test_loss, test_acc, test_logs = trainer.eval(model, device, test_loader, eval_iter=opt.test_iter, load_ckpt=os.path.join(opt.output_dir, "model.pth.tar"),
update_iter=opt.eval_inner_steps, learning_rate=opt.eval_inner_lr, eval_mode=opt.train_eval_mode)
print('[TEST] loss: {0:2.6f} | [POS] acc: {1:3.4f}'\
.format(test_loss, test_acc) + '\r')
else:
test_loss, test_acc, test_logs = trainer.eval(model, device, test_loader, load_ckpt=opt.load_ckpt,
eval_iter=opt.test_iter, update_iter=opt.eval_inner_steps, learning_rate=opt.eval_inner_lr, eval_mode=opt.mode)
print('[TEST] loss: {0:2.6f} | [POS] acc: {1:3.4f}'\
.format(test_loss, test_acc) + '\r')
name = "test_metrics.json"
if opt.mode != "test":
name = f"{opt.mode}_test_metrics.json"
with open(os.path.join(opt.output_dir, name), mode="w", encoding="utf-8") as fp:
res_mp = {"test_acc": test_acc}
json.dump(res_mp, fp)
write_pos_pred_json(test_loader.dataset.samples, test_logs, os.path.join(opt.output_dir, "test_pred.json"))
return
if __name__ == "__main__":
opt = add_args()
main(opt)

Просмотреть файл

@ -0,0 +1,180 @@
import torch
import numpy as np
import argparse
import os
from util.seq_loader import get_seq_loader as get_ment_loader
from trainer.ment_trainer import MentTrainer
from util.log_utils import eval_ment_log, write_ep_ment_log_json, save_json, set_seed
from util.span_encoder import BERTSpanEncoder
from model.ment_model import MentSeqtagger
def add_args():
def str2bool(arg):
if arg.lower() == "true":
return True
return False
parser = argparse.ArgumentParser()
parser.add_argument('--mode', default='test', type=str,
help='train / test')
parser.add_argument("--use_episode", default=False, type=str2bool)
parser.add_argument("--bio", default=False, type=str2bool)
parser.add_argument('--load_ckpt', default=None,
help='load ckpt')
parser.add_argument('--output_dir', default=None,
help='output dir')
parser.add_argument('--log_dir', default=None,
help='log dir')
parser.add_argument('--root', default=None, type=str,
help='data root dir')
parser.add_argument('--train', default='train.txt',
help='train file')
parser.add_argument('--val', default='dev.txt',
help='val file')
parser.add_argument('--test', default='test.txt',
help='test file')
parser.add_argument('--encoder_name_or_path', default='bert-base-uncased', type=str)
parser.add_argument('--word_encode_choice', default='first', type=str)
parser.add_argument('--max_length', default=128, type=int,
help='max length')
parser.add_argument('--last_n_layer', default=-4, type=int)
parser.add_argument('--schema', default='IO', type=str, choices=['IO', 'BIO', 'BIOES'])
parser.add_argument('--use_crf', type=str2bool, default=False)
parser.add_argument('--max_loss', default=0, type=float)
parser.add_argument('--dropout', default=0.5, type=float,
help='dropout rate')
parser.add_argument('--log_steps', default=2000, type=int,
help='val after training how many iters')
parser.add_argument('--val_steps', default=2000, type=int,
help='val after training how many iters')
parser.add_argument('--train_batch_size', default=16, type=int,
help='batch size')
parser.add_argument('--eval_batch_size', default=16, type=int,
help='batch size')
parser.add_argument('--train_iter', default=-1, type=int,
help='num of iters in training')
parser.add_argument('--dev_iter', default=-1, type=int,
help='num of iters in training')
parser.add_argument('--test_iter', default=-1, type=int,
help='num of iters in training')
parser.add_argument("--max_grad_norm", default=None, type=float)
parser.add_argument("--learning_rate", default=1e-3, type=float, help="lr rate")
parser.add_argument("--weight_decay", default=1e-3, type=float, help="weight decay")
parser.add_argument("--bert_learning_rate", default=5e-5, type=float, help="lr rate")
parser.add_argument("--bert_weight_decay", default=1e-5, type=float, help="weight decay")
parser.add_argument("--warmup_step", default=0, type=int)
parser.add_argument('--seed', default=42, type=int,
help='seed')
parser.add_argument('--fp16', action='store_true', default=False,
help='use nvidia apex fp16')
parser.add_argument('--ep_dir', default=None, type=str)
parser.add_argument('--ep_train', default=None, type=str)
parser.add_argument('--ep_val', default=None, type=str)
parser.add_argument('--ep_test', default=None, type=str)
parser.add_argument('--eval_all_after_train', default=False, type=str2bool)
parser.add_argument('--use_maml', default=False, type=str2bool)
parser.add_argument('--warmup_prop_inner', default=0, type=float)
parser.add_argument('--train_inner_lr', default=0, type=float)
parser.add_argument('--train_inner_steps', default=0, type=int)
parser.add_argument('--eval_inner_lr', default=0, type=float)
parser.add_argument('--eval_inner_steps', default=0, type=int)
opt = parser.parse_args()
return opt
def main(opt):
set_seed(opt.seed)
if not os.path.exists(opt.output_dir):
os.makedirs(opt.output_dir)
if opt.mode == "train":
output_dir = opt.output_dir
print("Output dir is ", output_dir)
log_dir = os.path.join(
output_dir,
"logs",
)
opt.log_dir = log_dir
save_json(opt.__dict__, os.path.join(opt.output_dir, "train_setting.txt"))
else:
save_json(opt.__dict__, os.path.join(opt.output_dir, "test_setting.txt"))
print('loading model and tokenizer...')
word_encoder = BERTSpanEncoder(opt.encoder_name_or_path, opt.max_length, last_n_layer=opt.last_n_layer,
word_encode_choice=opt.word_encode_choice, drop_p=opt.dropout)
print('loading data')
if opt.mode == "train":
train_loader = get_ment_loader(os.path.join(opt.root, opt.train), "train", word_encoder, batch_size=opt.train_batch_size, max_length=opt.max_length,
schema=opt.schema,
shuffle=True,
bio=opt.bio, debug_file=os.path.join(opt.ep_dir, opt.ep_train) if opt.ep_train else None)
else:
train_loader = get_ment_loader(os.path.join(opt.root, opt.train), "test", word_encoder, batch_size=opt.eval_batch_size, max_length=opt.max_length,
schema=opt.schema,
shuffle=False,
bio=opt.bio, debug_file=os.path.join(opt.ep_dir, opt.ep_train) if opt.ep_train else None)
dev_loader = get_ment_loader(os.path.join(opt.root, opt.val), "test", word_encoder, batch_size=opt.eval_batch_size, max_length=opt.max_length,
schema=opt.schema,
shuffle=False,
bio=opt.bio, debug_file=os.path.join(opt.ep_dir, opt.ep_val) if opt.ep_val else None)
test_loader = get_ment_loader(os.path.join(opt.root, opt.test), "test", word_encoder, batch_size=opt.eval_batch_size, max_length=opt.max_length,
schema=opt.schema,
shuffle=False,
bio=opt.bio, debug_file=os.path.join(opt.ep_dir, opt.ep_test) if opt.ep_test else None)
print("max_length: {}".format(opt.max_length))
print('mode: {}'.format(opt.mode))
print("mention detection sequence labeling with max loss {}".format(opt.max_loss))
model = MentSeqtagger(word_encoder, len(test_loader.dataset.ment_label2tag), test_loader.dataset.ment_tag2label, opt.schema, opt.use_crf, opt.max_loss)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
model = model.to(device)
trainer = MentTrainer()
assert opt.eval_batch_size == 1
if opt.mode == 'train':
print("==================start training==================")
trainer.train(model, opt, device, train_loader, dev_loader, eval_log_fn=os.path.join(opt.output_dir, "dev_ment_log.txt"))
if opt.eval_all_after_train:
train_loader = get_ment_loader(os.path.join(opt.root, opt.train), "test", word_encoder, batch_size=opt.eval_batch_size, max_length=opt.max_length,
schema=opt.schema,
shuffle=False,
bio=opt.bio, debug_file=os.path.join(opt.ep_dir, opt.ep_train) if opt.ep_train else None)
load_ckpt = os.path.join(opt.output_dir, "model.pth.tar")
_, dev_p, dev_r, dev_f1, _, _, _, dev_logs = trainer.eval(model, device, dev_loader, load_ckpt=load_ckpt, update_iter=opt.eval_inner_steps, learning_rate=opt.eval_inner_lr, eval_iter=opt.dev_iter)
print("Dev precison {:.5f} recall {:.5f} f1 {:.5f}".format(dev_p, dev_r, dev_f1))
eval_ment_log(dev_loader.dataset.samples, dev_logs)
write_ep_ment_log_json(dev_loader.dataset.samples, dev_logs, os.path.join(opt.output_dir, "dev_ment.json"))
_, test_p, test_r, test_f1, _, _, _, test_logs = trainer.eval(model, device, test_loader, load_ckpt=load_ckpt, update_iter=opt.eval_inner_steps, learning_rate=opt.eval_inner_lr, eval_iter=opt.test_iter)
print("Test precison {:.5f} recall {:.5f} f1 {:.5f}".format(test_p, test_r, test_f1))
eval_ment_log(test_loader.dataset.samples, test_logs)
write_ep_ment_log_json(test_loader.dataset.samples, test_logs, os.path.join(opt.output_dir, "test_ment.json"))
else:
_, dev_p, dev_r, dev_f1, _, _, _, dev_logs = trainer.eval(model, device, dev_loader, load_ckpt=opt.load_ckpt, update_iter=opt.eval_inner_steps, learning_rate=opt.eval_inner_lr, eval_iter=opt.dev_iter)
print("Dev precison {:.5f} recall {:.5f} f1 {:.5f}".format(dev_p, dev_r, dev_f1))
eval_ment_log(dev_loader.dataset.samples, dev_logs)
write_ep_ment_log_json(dev_loader.dataset.samples, dev_logs, os.path.join(opt.output_dir, "dev_ment.json"))
_, test_p, test_r, test_f1, _, _, _, test_logs = trainer.eval(model, device, test_loader, load_ckpt=opt.load_ckpt, update_iter=opt.eval_inner_steps, learning_rate=opt.eval_inner_lr, eval_iter=opt.test_iter)
model.timer.avg()
print("Test precison {:.5f} recall {:.5f} f1 {:.5f}".format(test_p, test_r, test_f1))
eval_ment_log(test_loader.dataset.samples, test_logs)
write_ep_ment_log_json(test_loader.dataset.samples, test_logs, os.path.join(opt.output_dir, "test_ment.json"))
res_mp = {"dev_p": dev_p, "dev_r": dev_r, "dev_f1": dev_f1, "test_p": test_p, "test_r": test_r, "test_f1": test_f1}
save_json(res_mp, os.path.join(opt.output_dir, "metrics.json"))
return
if __name__ == "__main__":
opt = add_args()
main(opt)

Просмотреть файл

@ -0,0 +1,243 @@
import sys
import torch
import numpy as np
import json
import argparse
import os
from trainer.span_trainer import SpanTrainer
from util.span_encoder import BERTSpanEncoder
from util.span_loader import get_query_loader
from util.log_utils import eval_ent_log, cal_episode_prf, write_ent_pred_json, set_seed, save_json
from model.type_model import SpanProtoCls
def add_args():
def str2bool(arg):
if arg.lower() == "true":
return True
return False
parser = argparse.ArgumentParser()
parser.add_argument('--mode', default='test', type=str,
help='train / test')
parser.add_argument('--load_ckpt', default=None,
help='load ckpt')
parser.add_argument('--output_dir', default=None,
help='output dir')
parser.add_argument('--log_dir', default=None,
help='log dir')
parser.add_argument('--root', default=None, type=str,
help='data root dir')
parser.add_argument('--train', default='train.txt',
help='train file')
parser.add_argument('--val', default='dev.txt',
help='val file')
parser.add_argument('--test', default='test.txt',
help='test file')
parser.add_argument('--train_ment_fn', default=None,
help='train file')
parser.add_argument('--val_ment_fn', default=None,
help='val file')
parser.add_argument('--test_ment_fn', default=None,
help='test file')
parser.add_argument('--ep_dir', default=None, type=str)
parser.add_argument('--ep_train', default=None, type=str)
parser.add_argument('--ep_val', default=None, type=str)
parser.add_argument('--ep_test', default=None, type=str)
parser.add_argument('--N', default=None, type=int,
help='N way')
parser.add_argument('--K', default=None, type=int,
help='K shot')
parser.add_argument('--Q', default=None, type=int,
help='Num of query per class')
parser.add_argument('--encoder_name_or_path', default='bert-base-uncased', type=str)
parser.add_argument('--word_encode_choice', default='first', type=str)
parser.add_argument('--span_encode_choice', default=None, type=str)
parser.add_argument('--max_length', default=128, type=int,
help='max length')
parser.add_argument('--max_span_len', default=8, type=int)
parser.add_argument('--max_neg_ratio', default=5, type=int)
parser.add_argument('--last_n_layer', default=-4, type=int)
parser.add_argument('--dot', type=str2bool, default=False,
help='use dot instead of L2 distance for knn')
parser.add_argument("--normalize", default='none', type=str, choices=['none', 'l2'])
parser.add_argument("--temperature", default=None, type=float)
parser.add_argument("--use_width", default=False, type=str2bool)
parser.add_argument("--width_dim", default=20, type=int)
parser.add_argument("--use_case", default=False, type=str2bool)
parser.add_argument("--case_dim", default=20, type=int)
parser.add_argument('--dropout', default=0.5, type=float,
help='dropout rate')
parser.add_argument('--log_steps', default=2000, type=int,
help='val after training how many iters')
parser.add_argument('--val_steps', default=2000, type=int,
help='val after training how many iters')
parser.add_argument('--train_batch_size', default=4, type=int,
help='batch size')
parser.add_argument('--eval_batch_size', default=1, type=int,
help='batch size')
parser.add_argument('--train_iter', default=-1, type=int,
help='num of iters in training')
parser.add_argument('--dev_iter', default=-1, type=int,
help='num of iters in training')
parser.add_argument('--test_iter', default=-1, type=int,
help='num of iters in training')
parser.add_argument("--max_grad_norm", default=None, type=float)
parser.add_argument("--learning_rate", default=1e-3, type=float, help="lr rate")
parser.add_argument("--weight_decay", default=1e-3, type=float, help="weight decay")
parser.add_argument("--bert_learning_rate", default=5e-5, type=float, help="lr rate")
parser.add_argument("--bert_weight_decay", default=1e-5, type=float, help="weight decay")
parser.add_argument("--warmup_step", default=0, type=int)
parser.add_argument('--seed', default=42, type=int,
help='seed')
parser.add_argument('--fp16', action='store_true', default=False,
help='use nvidia apex fp16')
parser.add_argument('--use_focal', default=False, type=str2bool)
parser.add_argument('--iou_thred', default=None, type=float)
parser.add_argument('--use_att', default=False, type=str2bool)
parser.add_argument('--att_hidden_dim', default=-1, type=int)
parser.add_argument('--label_fn', default=None, type=str)
parser.add_argument('--hou_eval_ep', default=False, type=str2bool)
parser.add_argument('--use_maml', default=False, type=str2bool)
parser.add_argument('--warmup_prop_inner', default=0, type=float)
parser.add_argument('--train_inner_lr', default=0, type=float)
parser.add_argument('--train_inner_steps', default=0, type=int)
parser.add_argument('--eval_inner_lr', default=0, type=float)
parser.add_argument('--eval_inner_steps', default=0, type=int)
parser.add_argument('--overlap', default=False, type=str2bool)
parser.add_argument('--type_threshold', default=-1, type=float)
parser.add_argument('--use_oproto', default=False, type=str2bool)
parser.add_argument('--bio', default=False, type=str2bool)
opt = parser.parse_args()
return opt
def main(opt):
print("Span based proto pipeline {} way {} shot".format(opt.N, opt.K))
set_seed(opt.seed)
if opt.mode == "train":
output_dir = opt.output_dir
print("Output dir is ", output_dir)
log_dir = os.path.join(
output_dir,
"logs",
)
opt.log_dir = log_dir
if not os.path.exists(opt.output_dir):
os.makedirs(opt.output_dir)
save_json(opt.__dict__, os.path.join(opt.output_dir, "train_setting.txt"))
else:
if not os.path.exists(opt.output_dir):
os.makedirs(opt.output_dir)
save_json(opt.__dict__, os.path.join(opt.output_dir, "test_setting.txt"))
print('loading model and tokenizer...')
word_encoder = BERTSpanEncoder(opt.encoder_name_or_path, opt.max_length, last_n_layer=opt.last_n_layer,
word_encode_choice=opt.word_encode_choice,
span_encode_choice=opt.span_encode_choice,
use_width=opt.use_width, width_dim=opt.width_dim, use_case=opt.use_case,
case_dim=opt.case_dim,
drop_p=opt.dropout, use_att=opt.use_att,
att_hidden_dim=opt.att_hidden_dim)
print('loading data')
if opt.mode == "train":
train_loader = get_query_loader(os.path.join(opt.root, opt.train), "train", word_encoder, N=opt.N,
K=opt.K,
Q=opt.Q, batch_size=opt.train_batch_size, max_length=opt.max_length,
shuffle=True,
bio=opt.bio,
debug_file=os.path.join(opt.ep_dir, opt.ep_train) if opt.ep_train else None,
query_file=opt.train_ment_fn,
iou_thred=opt.iou_thred,
use_oproto=opt.use_oproto,
label_fn=opt.label_fn)
dev_loader = get_query_loader(os.path.join(opt.root, opt.val), "test", word_encoder, N=opt.N, K=opt.K,
Q=opt.Q, batch_size=opt.eval_batch_size, max_length=opt.max_length,
shuffle=False,
bio=opt.bio,
debug_file=os.path.join(opt.ep_dir, opt.ep_val) if opt.ep_val else None,
query_file=opt.val_ment_fn,
iou_thred=opt.iou_thred,
use_oproto=opt.use_oproto,
label_fn=opt.label_fn)
test_loader = get_query_loader(os.path.join(opt.root, opt.test), "test", word_encoder, N=opt.N, K=opt.K,
Q=opt.Q, batch_size=opt.eval_batch_size, max_length=opt.max_length,
shuffle=False,
bio=opt.bio,
debug_file=os.path.join(opt.ep_dir, opt.ep_test) if opt.ep_test else None,
query_file=opt.test_ment_fn,
iou_thred=opt.iou_thred, hidden_query_label=False,
use_oproto=opt.use_oproto,
label_fn=opt.label_fn)
print("max_length: {}".format(opt.max_length))
print('mode: {}'.format(opt.mode))
print("{}-way-{}-shot Proto Few-Shot NER".format(opt.N, opt.K))
if opt.train_inner_steps > 0:
print("this mode can only used for maml training !!!!!!!!!!!!!!!!!")
model = SpanProtoCls(word_encoder, use_oproto=opt.use_oproto, dot=opt.dot, normalize=opt.normalize,
temperature=opt.temperature, use_focal=opt.use_focal)
else:
raise NotImplementedError
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
trainer = SpanTrainer()
assert opt.eval_batch_size == 1
if opt.mode == 'train':
trainer.train(model, opt, device, train_loader, dev_loader, dev_pred_fn=os.path.join(opt.output_dir, "dev_pred.json"), dev_log_fn=os.path.join(opt.output_dir, "dev_log.txt"))
test_loss, test_ment_p, test_ment_r, test_ment_f1, test_p, test_r, test_f1, test_logs = trainer.eval(model, device, test_loader, eval_iter=opt.test_iter, load_ckpt=os.path.join(opt.output_dir, "model.pth.tar"),
update_iter=opt.eval_inner_steps, learning_rate=opt.eval_inner_lr, overlap=opt.overlap, threshold=opt.type_threshold)
if opt.hou_eval_ep:
test_p, test_r, test_f1 = cal_episode_prf(test_logs)
print("Mention test precision {:.5f} recall {:.5f} f1 {:.5f}".format(test_ment_p, test_ment_r, test_ment_f1))
print('[TEST] loss: {0:2.6f} | [Entity] precision: {1:3.4f}, recall: {2:3.4f}, f1: {3:3.4f}'\
.format(test_loss, test_p, test_r, test_f1) + '\r')
else:
_, dev_ment_p, dev_ment_r, dev_ment_f1, dev_p, dev_r, dev_f1, dev_logs = trainer.eval(model, device, dev_loader, load_ckpt=opt.load_ckpt,
eval_iter=opt.dev_iter, update_iter=opt.eval_inner_steps, learning_rate=opt.eval_inner_lr, overlap=opt.overlap, threshold=opt.type_threshold)
if opt.hou_eval_ep:
dev_p, dev_r, dev_f1 = cal_episode_prf(dev_logs)
print("Mention dev precision {:.5f} recall {:.5f} f1 {:.5f}".format(dev_ment_p, dev_ment_r, dev_ment_f1))
print("Dev precison {:.5f} recall {:.5f} f1 {:.5f}".format(dev_p, dev_r, dev_f1))
with open(os.path.join(opt.output_dir, "dev_metrics.json"), mode="w", encoding="utf-8") as fp:
json.dump({"precision": dev_p, "recall": dev_r, "f1": dev_f1}, fp)
_, test_ment_p, test_ment_r, test_ment_f1, test_p, test_r, test_f1, test_logs = trainer.eval(model, device, test_loader, load_ckpt=opt.load_ckpt,
eval_iter=opt.test_iter, update_iter=opt.eval_inner_steps, learning_rate=opt.eval_inner_lr, overlap=opt.overlap, threshold=opt.type_threshold)
if opt.hou_eval_ep:
test_p, test_r, test_f1 = cal_episode_prf(test_logs)
print("Mention test precision {:.5f} recall {:.5f} f1 {:.5f}".format(test_ment_p, test_ment_r, test_ment_f1))
print("Test precison {:.5f} recall {:.5f} f1 {:.5f}".format(test_p, test_r, test_f1))
with open(os.path.join(opt.output_dir, "test_metrics.json"), mode="w", encoding="utf-8") as fp:
json.dump({"precision": test_p, "recall": test_r, "f1": test_f1}, fp)
eval_ent_log(test_loader.dataset.samples, test_logs)
write_ent_pred_json(test_loader.dataset.samples, test_logs, os.path.join(opt.output_dir, "test_pred.json"))
return
if __name__ == "__main__":
opt = add_args()
main(opt)

Просмотреть файл

@ -0,0 +1,217 @@
import json
import os, sys
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup
from collections import defaultdict
from util.log_utils import write_ep_ment_log_json, write_ment_log, eval_ment_log
from util.log_utils import write_ent_log, write_ent_pred_json, eval_ent_log, cal_prf
from tqdm import tqdm
class JointTrainer:
def __init__(self):
return
def __load_model__(self, ckpt):
if os.path.isfile(ckpt):
checkpoint = torch.load(ckpt)
print("Successfully loaded checkpoint '%s'" % ckpt)
return checkpoint
else:
raise Exception("No checkpoint found at '%s'" % ckpt)
def get_learning_rate(self, lr, progress, warmup, schedule="linear"):
if schedule == "linear":
if progress < warmup:
lr *= progress / warmup
else:
lr *= max((progress - 1.0) / (warmup - 1.0), 0.0)
return lr
def check_input_ment(self, query):
pred_cnt, gold_cnt, hit_cnt = 0, 0, 0
for i in range(len(query['spans'])):
pred_ments = query['span_indices'][i][query['span_mask'][i].eq(1)].detach().cpu().tolist()
gold_ments = [x[1:] for x in query['spans'][i]]
pred_cnt += len(pred_ments)
gold_cnt += len(gold_ments)
for x in gold_ments:
if x in pred_ments:
hit_cnt += 1
return pred_cnt, gold_cnt, hit_cnt
def eval(self, model, device, dataloader, load_ckpt=None, ment_update_iter=0, type_update_iter=0, learning_rate=3e-5, eval_iter=-1, eval_mode="test-twostage", overlap=False, threshold=-1):
if load_ckpt:
state_dict = self.__load_model__(load_ckpt)['state_dict']
own_state = model.state_dict()
for name, param in state_dict.items():
if name not in own_state:
print('[ERROR] Ignore {}'.format(name))
continue
own_state[name].copy_(param)
model.eval()
eval_loss = 0
tot_seq_cnt = 0
tot_ment_logs = {}
tot_type_logs = {}
eval_batchs = iter(dataloader)
tot_ment_metric_logs = defaultdict(int)
tot_type_metric_logs = defaultdict(int)
eval_loss = 0
if eval_iter <= 0:
eval_iter = len(dataloader.dataset.sampler)
tot_seq_cnt = 0
print("[eval] update {} steps | total {} episode".format(ment_update_iter, eval_iter))
for batch_id in tqdm(range(eval_iter)):
batch = next(eval_batchs)
batch['support'] = dataloader.dataset.batch_to_device(batch['support'], device)
batch['query'] = dataloader.dataset.batch_to_device(batch['query'], device)
res = model.forward_joint_meta(batch, ment_update_iter, learning_rate, eval_mode)
eval_loss += res['loss']
if eval_mode == 'test-twostage':
batch["query"]["span_indices"] = res['pred_spans']
batch["query"]["span_mask"] = res['pred_masks']
ment_metric_logs, ment_logs = model.seqment_eval(res["ment_preds"], batch["query"], model.ment_idx2label, model.schema)
type_metric_logs, type_logs = model.greedy_eval(res['type_logits'], batch['query'], overlap=overlap, threshold=threshold)
tot_seq_cnt += batch['query']['word'].size(0)
for k, v in ment_logs.items():
if k not in tot_ment_logs:
tot_ment_logs[k] = []
tot_ment_logs[k] += v
for k, v in type_logs.items():
if k not in tot_type_logs:
tot_type_logs[k] = []
tot_type_logs[k] += v
for k, v in ment_metric_logs.items():
tot_ment_metric_logs[k] += v
for k, v in type_metric_logs.items():
tot_type_metric_logs[k] += v
ment_p, ment_r, ment_f1 = cal_prf(tot_ment_metric_logs["ment_hit_cnt"], tot_ment_metric_logs["ment_pred_cnt"], tot_ment_metric_logs["ment_gold_cnt"])
print("seq num:", tot_seq_cnt, "hit cnt:", tot_ment_metric_logs["ment_hit_cnt"], "pred cnt:", tot_ment_metric_logs["ment_pred_cnt"], "gold cnt:", tot_ment_metric_logs["ment_gold_cnt"])
print("avg hit:", tot_ment_metric_logs["ment_hit_cnt"] / tot_seq_cnt, "avg pred:", tot_ment_metric_logs["ment_pred_cnt"] / tot_seq_cnt, "avg gold:", tot_ment_metric_logs["ment_gold_cnt"] / tot_seq_cnt)
ent_p, ent_r, ent_f1 = cal_prf(tot_type_metric_logs["ent_hit_cnt"], tot_type_metric_logs["ent_pred_cnt"], tot_type_metric_logs["ent_gold_cnt"])
model.train()
return eval_loss / eval_iter, ment_p, ment_r, ment_f1, ent_p, ent_r, ent_f1, tot_ment_logs, tot_type_logs
def train(self, model, training_args, device, trainloader, devloader, load_ckpt=None, ignore_log=True, dev_pred_fn=None, dev_log_fn=None):
if load_ckpt is not None:
state_dict = self.__load_model__(load_ckpt)['state_dict']
own_state = model.state_dict()
for name, param in state_dict.items():
if name not in own_state:
print('[ERROR] Ignore {}'.format(name))
continue
own_state[name].copy_(param)
print("load ckpt from {}".format(load_ckpt))
# Init optimizer
print('Use bert optim!')
parameters_to_optimize = list(filter(lambda x: x[1].requires_grad, model.named_parameters()))
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
parameters_groups = [
{'params': [p for n, p in parameters_to_optimize if ("bert." in n) and (not any(nd in n for nd in no_decay))],
'lr': training_args.bert_learning_rate, 'weight_decay': training_args.bert_weight_decay},
{'params': [p for n, p in parameters_to_optimize if ("bert." in n) and any(nd in n for nd in no_decay)],
'lr': training_args.bert_learning_rate, 'weight_decay': 0},
{'params': [p for n, p in parameters_to_optimize if "bert." not in n],
'lr': training_args.learning_rate, 'weight_decay': training_args.weight_decay}
]
optimizer = torch.optim.AdamW(parameters_groups)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=training_args.warmup_step,
num_training_steps=training_args.train_iter)
model.train()
model.zero_grad()
best_f1 = 0.0
train_loss = 0.0
train_ment_loss = 0.0
train_type_loss = 0.0
train_type_acc = 0
train_ment_acc = 0
iter_sample = 0
it = 0
train_batchs = iter(trainloader)
for _ in range(training_args.train_iter):
it += 1
torch.cuda.empty_cache()
model.train()
batch = next(train_batchs)
batch['support'] = trainloader.dataset.batch_to_device(batch['support'], device)
batch['query'] = trainloader.dataset.batch_to_device(batch['query'], device)
if training_args.use_maml:
progress = 1.0 * (it - 1) / training_args.train_iter
lr_inner = self.get_learning_rate(
training_args.train_inner_lr, progress, training_args.warmup_prop_inner
)
res = model.forward_joint_meta(batch, training_args.train_inner_steps, lr_inner, "train")
for g in res['grads']:
model.load_gradients(res['names'], g)
train_loss += res['loss']
train_ment_loss += res['ment_loss']
train_type_loss += res['type_loss']
else:
raise ValueError
if training_args.max_grad_norm is not None:
torch.nn.utils.clip_grad_norm_(model.parameters(), training_args.max_grad_norm)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
type_pred = torch.cat(res['type_preds'], dim=0).detach().cpu().numpy()
type_gold = torch.cat(res['type_golds'], dim=0).detach().cpu().numpy()
ment_pred = torch.cat(res['ment_preds'], dim=0).detach().cpu().numpy()
ment_gold = torch.cat(res['ment_golds'], dim=0).detach().cpu().numpy()
type_acc = model.span_accuracy(type_pred, type_gold)
ment_acc = model.span_accuracy(ment_pred, ment_gold)
train_type_acc += type_acc
train_ment_acc += ment_acc
iter_sample += 1
if not ignore_log:
raise ValueError
if it % 100 == 0 or it % training_args.log_steps == 0:
if not ignore_log:
raise ValueError
else:
print('step: {0:4} | loss: {1:2.6f} | ment loss: {2:2.6f} | type loss: {3:2.6f} | ment acc {4:.5f} | type acc {5:.5f}'
.format(it, train_loss / iter_sample, train_ment_loss / iter_sample, train_type_loss / iter_sample, train_ment_acc / iter_sample, train_type_acc / iter_sample) + '\r')
train_loss = 0
train_ment_loss = 0
train_type_loss = 0
train_ment_acc = 0
train_type_acc = 0
iter_sample = 0
if it % training_args.val_steps == 0:
eval_loss, eval_ment_p, eval_ment_r, eval_ment_f1, eval_p, eval_r, eval_f1, eval_ment_logs, eval_logs = self.eval(model, device, devloader, ment_update_iter=training_args.eval_ment_inner_steps, type_update_iter=training_args.eval_type_inner_steps, learning_rate=training_args.eval_inner_lr, eval_iter=training_args.dev_iter, eval_mode="test-twostage")
print('[EVAL] step: {0:4} | loss: {1:2.6f} | [MENTION] precision: {2:3.4f}, recall: {3:3.4f}, f1: {4:3.4f} [ENTITY] precision: {5:3.4f}, recall: {6:3.4f}, f1: {7:3.4f}'\
.format(it, eval_loss, eval_ment_p, eval_ment_r, eval_ment_f1, eval_p, eval_r, eval_f1) + '\r')
if eval_f1 > best_f1:
print('Best checkpoint')
torch.save({'state_dict': model.state_dict()},
os.path.join(training_args.output_dir, "model.pth.tar"))
best_f1 = eval_f1
if dev_pred_fn is not None:
with open(dev_pred_fn, mode="w", encoding="utf-8") as fp:
json.dump({"ment_p": eval_ment_p, "ment_r": eval_ment_r, "ment_f1": eval_ment_f1, "precision": eval_p, "recall": eval_r, "f1": eval_f1}, fp)
eval_ent_log(devloader.dataset.samples, eval_logs)
if training_args.train_iter == it:
break
print("\n####################\n")
print("Finish training ")
return

Просмотреть файл

@ -0,0 +1,192 @@
import json
import os, sys
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup
from collections import defaultdict
from util.log_utils import write_ep_ment_log_json, write_ment_log, eval_ment_log, cal_prf
from tqdm import tqdm
class MentTrainer:
def __init__(self):
return
def __load_model__(self, ckpt):
if os.path.isfile(ckpt):
checkpoint = torch.load(ckpt)
print("Successfully loaded checkpoint '%s'" % ckpt)
return checkpoint
else:
raise Exception("No checkpoint found at '%s'" % ckpt)
def get_learning_rate(self, lr, progress, warmup, schedule="linear"):
if schedule == "linear":
if progress < warmup:
lr *= progress / warmup
else:
lr *= max((progress - 1.0) / (warmup - 1.0), 0.0)
return lr
def eval(self, model, device, dataloader, load_ckpt=None, update_iter=0, learning_rate=3e-5, eval_iter=-1):
if load_ckpt:
state_dict = self.__load_model__(load_ckpt)['state_dict']
own_state = model.state_dict()
for name, param in state_dict.items():
if name not in own_state:
print('[ERROR] Ignore {}'.format(name))
continue
own_state[name].copy_(param)
model.eval()
tot_metric_logs = defaultdict(int)
eval_loss = 0
tot_seq_cnt = 0
tot_logs = {}
eval_batchs = iter(dataloader)
tot_metric_logs = defaultdict(int)
eval_loss = 0
if eval_iter <= 0:
eval_iter = len(dataloader.dataset.sampler)
tot_seq_cnt = 0
tot_logs = {}
print("[eval] update {} steps | total {} episode".format(update_iter, eval_iter))
for batch_id in range(eval_iter):
batch = next(eval_batchs)
batch['support'] = dataloader.dataset.batch_to_device(batch['support'], device)
batch['query'] = dataloader.dataset.batch_to_device(batch['query'], device)
if update_iter > 0:
res = model.forward_meta(batch, update_iter, learning_rate, "test")
eval_loss += res['loss']
else:
res = model.forward_sup(batch, "test")
eval_loss += res['loss'].item()
metric_logs, logs = model.seqment_eval(res["preds"], batch["query"], model.idx2label, model.schema)
tot_seq_cnt += batch['query']['word'].size(0)
for k, v in logs.items():
if k not in tot_logs:
tot_logs[k] = []
tot_logs[k] += v
for k, v in metric_logs.items():
tot_metric_logs[k] += v
ment_p, ment_r, ment_f1 = cal_prf(tot_metric_logs["ment_hit_cnt"], tot_metric_logs["ment_pred_cnt"], tot_metric_logs["ment_gold_cnt"])
print("seq num:", tot_seq_cnt, "hit cnt:", tot_metric_logs["ment_hit_cnt"], "pred cnt:", tot_metric_logs["ment_pred_cnt"], "gold cnt:", tot_metric_logs["ment_gold_cnt"])
print("avg hit:", tot_metric_logs["ment_hit_cnt"] / tot_seq_cnt, "avg pred:", tot_metric_logs["ment_pred_cnt"] / tot_seq_cnt, "avg gold:", tot_metric_logs["ment_gold_cnt"] / tot_seq_cnt)
ent_p, ent_r, ent_f1 = 0, 0, 0
model.train()
return eval_loss / eval_iter, ment_p, ment_r, ment_f1, ent_p, ent_r, ent_f1, tot_logs
def train(self, model, training_args, device, trainloader, devloader, eval_log_fn=None, load_ckpt=None, ignore_log=True):
if load_ckpt is not None:
state_dict = self.__load_model__(load_ckpt)['state_dict']
own_state = model.state_dict()
for name, param in state_dict.items():
if name not in own_state:
print('[ERROR] Ignore {}'.format(name))
continue
own_state[name].copy_(param)
print("load ckpt from {}".format(load_ckpt))
# Init optimizer
print('Use bert optim!')
parameters_to_optimize = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
parameters_groups = [
{'params': [p for n, p in parameters_to_optimize if ("bert." in n) and (not any(nd in n for nd in no_decay))],
'lr': training_args.bert_learning_rate, 'weight_decay': training_args.bert_weight_decay},
{'params': [p for n, p in parameters_to_optimize if ("bert." in n) and any(nd in n for nd in no_decay)],
'lr': training_args.bert_learning_rate, 'weight_decay': 0},
{'params': [p for n, p in parameters_to_optimize if "bert." not in n],
'lr': training_args.learning_rate, 'weight_decay': training_args.weight_decay}
]
optimizer = torch.optim.AdamW(parameters_groups)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=training_args.warmup_step,
num_training_steps=training_args.train_iter)
model.train()
model.zero_grad()
best_f1 = 0.0
train_loss = 0.0
train_acc = 0
iter_sample = 0
tot_metric_logs = defaultdict(int)
it = 0
train_batchs = iter(trainloader)
for _ in range(training_args.train_iter):
it += 1
torch.cuda.empty_cache()
model.train()
batch = next(train_batchs)
batch['support'] = trainloader.dataset.batch_to_device(batch['support'], device)
batch['query'] = trainloader.dataset.batch_to_device(batch['query'], device)
if training_args.use_maml:
progress = 1.0 * (it - 1) / training_args.train_iter
lr_inner = self.get_learning_rate(
training_args.train_inner_lr, progress, training_args.warmup_prop_inner
)
res = model.forward_meta(batch, training_args.train_inner_steps, lr_inner, "train")
for g in res['grads']:
model.load_gradients(res['names'], g)
train_loss += res['loss']
else:
res = model.forward_sup(batch, "train")
loss = res['loss']
loss.backward()
train_loss += loss.item()
if training_args.max_grad_norm is not None:
torch.nn.utils.clip_grad_norm_(model.parameters(), training_args.max_grad_norm)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
pred = torch.cat(res['preds'], dim=0).detach().cpu().numpy()
gold = torch.cat(res['golds'], dim=0).detach().cpu().numpy()
acc = model.span_accuracy(pred, gold)
train_acc += acc
iter_sample += 1
if not ignore_log:
metric_logs, logs = model.seqment_eval(res["preds"], batch["query"], model.idx2label, model.schema)
for k, v in metric_logs.items():
tot_metric_logs[k] += v
if it % 100 == 0 or it % training_args.log_steps == 0:
if not ignore_log:
precision, recall, f1 = cal_prf(tot_metric_logs["ment_hit_cnt"], tot_metric_logs["ment_pred_cnt"],
tot_metric_logs["ment_gold_cnt"])
print('step: {0:4} | loss: {1:2.6f} | span acc {2:.5f} [ENTITY] precision: {3:3.4f}, recall: {4:3.4f}, f1: {5:3.4f}'\
.format(it, train_loss / iter_sample, train_acc / iter_sample, precision, recall, f1) + '\r')
else:
print('step: {0:4} | loss: {1:2.6f} | span acc {2:.5f}'
.format(it, train_loss / iter_sample, train_acc / iter_sample) + '\r')
train_loss = 0
train_acc = 0
iter_sample = 0
tot_metric_logs = defaultdict(int)
if it % training_args.val_steps == 0:
eval_loss, eval_ment_p, eval_ment_r, eval_ment_f1, eval_p, eval_r, eval_f1, eval_logs = self.eval(model, device, devloader, update_iter=training_args.eval_inner_steps, learning_rate=training_args.eval_inner_lr, eval_iter=training_args.dev_iter)
print('[EVAL] step: {0:4} | loss: {1:2.6f} | [MENTION] precision: {2:3.4f}, recall: {3:3.4f}, f1: {4:3.4f}'\
.format(it, eval_loss, eval_ment_p, eval_ment_r, eval_ment_f1) + '\r')
if eval_ment_f1 > best_f1:
print('Best checkpoint')
torch.save({'state_dict': model.state_dict()},
os.path.join(training_args.output_dir, "model.pth.tar"))
best_f1 = eval_ment_f1
if eval_log_fn is not None:
write_ment_log(devloader.dataset.samples, eval_logs, eval_log_fn)
if training_args.train_iter == it:
break
print("\n####################\n")
print("Finish training ")
return

Просмотреть файл

@ -0,0 +1,185 @@
import json
import os, sys
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup
from collections import defaultdict
from util.log_utils import write_pos_pred_json
class POSTrainer:
def __init__(self):
return
def __load_model__(self, ckpt):
if os.path.isfile(ckpt):
checkpoint = torch.load(ckpt)
print("Successfully loaded checkpoint '%s'" % ckpt)
return checkpoint
else:
raise Exception("No checkpoint found at '%s'" % ckpt)
def get_learning_rate(self, lr, progress, warmup, schedule="linear"):
if schedule == "linear":
if progress < warmup:
lr *= progress / warmup
else:
lr *= max((progress - 1.0) / (warmup - 1.0), 0.0)
return lr
def eval(self, model, device, dataloader, load_ckpt=None, update_iter=0, learning_rate=3e-5, eval_iter=-1, eval_mode="test"):
if load_ckpt:
state_dict = self.__load_model__(load_ckpt)['state_dict']
own_state = model.state_dict()
for name, param in state_dict.items():
if name not in own_state:
print('[ERROR] Ignore {}'.format(name))
continue
own_state[name].copy_(param)
model.eval()
eval_batchs = iter(dataloader)
tot_metric_logs = defaultdict(int)
eval_loss = 0
if eval_iter <= 0:
eval_iter = len(dataloader.dataset.sampler)
tot_seq_cnt = 0
tot_logs = {}
for batch_id in tqdm(range(eval_iter)):
batch = next(eval_batchs)
batch['support'] = dataloader.dataset.batch_to_device(batch['support'], device)
batch['query'] = dataloader.dataset.batch_to_device(batch['query'], device)
if update_iter > 0:
res = model.forward_meta(batch, update_iter, learning_rate, eval_mode)
eval_loss += res['loss']
else:
res = model.forward_proto(batch, eval_mode)
eval_loss += res['loss'].item()
metric_logs, logs = model.seq_eval(res['preds'], batch['query'])
tot_seq_cnt += batch["query"]["seq_len"].size(0)
logs["support_index"] = batch["support"]["index"]
logs["support_sentence_num"] = batch["support"]["sentence_num"]
logs["support_subsentence_num"] = batch["support"]["subsentence_num"]
for k, v in logs.items():
if k not in tot_logs:
tot_logs[k] = []
tot_logs[k] += v
for k, v in metric_logs.items():
tot_metric_logs[k] += v
token_acc = tot_metric_logs["hit_cnt"] / tot_metric_logs["gold_cnt"]
print("seq num:", tot_seq_cnt, "hit cnt:", tot_metric_logs["hit_cnt"], "gold cnt:", tot_metric_logs["gold_cnt"], "acc:", token_acc)
model.train()
return eval_loss / eval_iter, token_acc, tot_logs
def train(self, model, training_args, device, trainloader, devloader, load_ckpt=None, dev_pred_fn=None, dev_log_fn=None, ignore_log=True):
if load_ckpt is not None:
state_dict = self.__load_model__(load_ckpt)['state_dict']
own_state = model.state_dict()
for name, param in state_dict.items():
if name not in own_state:
print('[ERROR] Ignore {}'.format(name))
continue
own_state[name].copy_(param)
print("load ckpt from {}".format(load_ckpt))
# Init optimizer
print('Use bert optim!')
parameters_to_optimize = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
parameters_groups = [
{'params': [p for n, p in parameters_to_optimize if ("bert." in n) and (not any(nd in n for nd in no_decay))],
'lr': training_args.bert_learning_rate, 'weight_decay': training_args.bert_weight_decay},
{'params': [p for n, p in parameters_to_optimize if ("bert." in n) and any(nd in n for nd in no_decay)],
'lr': training_args.bert_learning_rate, 'weight_decay': 0},
{'params': [p for n, p in parameters_to_optimize if "bert." not in n],
'lr': training_args.learning_rate, 'weight_decay': training_args.weight_decay}
]
optimizer = torch.optim.AdamW(parameters_groups)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=training_args.warmup_step,
num_training_steps=training_args.train_iter)
model.train()
model.zero_grad()
best_f1 = -1
train_loss = 0.0
train_acc = 0
iter_sample = 0
tot_metric_logs = defaultdict(int)
it = 0
train_batchs = iter(trainloader)
for _ in range(training_args.train_iter):
it += 1
torch.cuda.empty_cache()
model.train()
batch = next(train_batchs)
batch['support'] = trainloader.dataset.batch_to_device(batch['support'], device)
batch['query'] = trainloader.dataset.batch_to_device(batch['query'], device)
if training_args.use_maml:
progress = 1.0 * (it - 1) / training_args.train_iter
lr_inner = self.get_learning_rate(
training_args.train_inner_lr, progress, training_args.warmup_prop_inner
)
res = model.forward_meta(batch, training_args.train_inner_steps, lr_inner, "train")
for g in res['grads']:
model.load_gradients(res['names'], g) # loss backward
train_loss += res['loss']
else:
res = model.forward_proto(batch, "train")
loss = res['loss']
loss.backward()
train_loss += loss.item()
if training_args.max_grad_norm is not None:
torch.nn.utils.clip_grad_norm_(model.parameters(), training_args.max_grad_norm)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
pred = torch.cat(res['preds'], dim=0).detach().cpu().numpy()
gold = torch.cat(res['golds'], dim=0).detach().cpu().numpy()
acc = model.span_accuracy(pred, gold)
train_acc += acc
iter_sample += 1
if not ignore_log:
metric_logs, logs = model.seq_eval(res["preds"], batch["query"], model.schema)
for k, v in metric_logs.items():
tot_metric_logs[k] += v
if it % 100 == 0 or it % training_args.log_steps == 0:
if not ignore_log:
acc = tot_metric_logs["hit_cnt"] / tot_metric_logs["gold_cnt"]
print('step: {0:4} | loss: {1:2.6f} | span acc {2:.5f} [Token] acc: {3:3.4f}'\
.format(it, train_loss / iter_sample, train_acc / iter_sample, acc) + '\r')
else:
print('step: {0:4} | loss: {1:2.6f} | span acc {2:.5f}'\
.format(it, train_loss / iter_sample, train_acc / iter_sample) + '\r')
train_loss = 0
train_acc = 0
iter_sample = 0
tot_metric_logs = defaultdict(int)
if it % training_args.val_steps == 0:
eval_loss, eval_acc, eval_logs = self.eval(model, device, devloader, eval_iter=training_args.dev_iter,
update_iter=training_args.eval_inner_steps, learning_rate=training_args.eval_inner_lr, eval_mode=training_args.train_eval_mode)
print('[EVAL] step: {0:4} | loss: {1:2.6f} | F1: {2:3.4f}'\
.format(it, eval_loss, eval_acc) + '\r')
if eval_acc > best_f1:
print('Best checkpoint')
torch.save({'state_dict': model.state_dict()},
os.path.join(training_args.output_dir, "model.pth.tar"))
best_f1 = eval_acc
if dev_pred_fn is not None:
write_pos_pred_json(devloader.dataset.samples, eval_logs, dev_pred_fn)
print("Best acc: {}".format(best_f1))
if it > 1003:
break
print("\n####################\n")
print("Finish training ")
return

Просмотреть файл

@ -0,0 +1,204 @@
import json
import os, sys
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup
from collections import defaultdict
from util.log_utils import write_ent_log, write_ent_pred_json, eval_ent_log, cal_prf
from tqdm import tqdm
class SpanTrainer:
def __init__(self):
return
def __load_model__(self, ckpt):
if os.path.isfile(ckpt):
checkpoint = torch.load(ckpt)
print("Successfully loaded checkpoint '%s'" % ckpt)
return checkpoint
else:
raise Exception("No checkpoint found at '%s'" % ckpt)
def check_input_ment(self, query):
pred_cnt, gold_cnt, hit_cnt = 0, 0, 0
for i in range(len(query['spans'])):
pred_ments = query['span_indices'][i][query['span_mask'][i].eq(1)].detach().cpu().tolist()
gold_ments = [x[1:] for x in query['spans'][i]]
pred_cnt += len(pred_ments)
gold_cnt += len(gold_ments)
for x in gold_ments:
if x in pred_ments:
hit_cnt += 1
return pred_cnt, gold_cnt, hit_cnt
def get_learning_rate(self, lr, progress, warmup, schedule="linear"):
if schedule == "linear":
if progress < warmup:
lr *= progress / warmup
else:
lr *= max((progress - 1.0) / (warmup - 1.0), 0.0)
return lr
def eval(self, model, device, dataloader, load_ckpt=None, update_iter=0, learning_rate=3e-5, eval_iter=-1, overlap=False, threshold=-1):
if load_ckpt:
state_dict = self.__load_model__(load_ckpt)['state_dict']
own_state = model.state_dict()
for name, param in state_dict.items():
if name not in own_state:
print('[ERROR] Ignore {}'.format(name))
continue
own_state[name].copy_(param)
model.eval()
eval_batchs = iter(dataloader)
tot_metric_logs = defaultdict(int)
eval_loss = 0
if eval_iter <= 0:
eval_iter = len(dataloader.dataset.sampler)
tot_seq_cnt = 0
tot_logs = {}
for batch_id in range(eval_iter):
batch = next(eval_batchs)
input_pred_ment_cnt, input_gold_ment_cnt, input_hit_ment_cnt = self.check_input_ment(batch['query'])
tot_metric_logs['episode_query_ment_hit_cnt'] += input_hit_ment_cnt
tot_metric_logs['episode_query_ment_pred_cnt'] += input_pred_ment_cnt
tot_metric_logs['episode_query_ment_gold_cnt'] += input_gold_ment_cnt
batch['support'] = dataloader.dataset.batch_to_device(batch['support'], device)
batch['query'] = dataloader.dataset.batch_to_device(batch['query'], device)
if update_iter > 0:
res = model.forward_meta(batch, update_iter, learning_rate, "test")
eval_loss += res['loss']
else:
res = model.forward_proto(batch)
eval_loss += res['loss'].item()
metric_logs, logs = model.greedy_eval(res['logits'], batch['query'], overlap=overlap, threshold=threshold)
tot_seq_cnt += batch["query"]["seq_len"].size(0)
logs["support_index"] = batch["support"]["index"]
logs["support_sentence_num"] = batch["support"]["sentence_num"]
logs["support_subsentence_num"] = batch["support"]["subsentence_num"]
for k, v in logs.items():
if k not in tot_logs:
tot_logs[k] = []
tot_logs[k] += v
for k, v in metric_logs.items():
tot_metric_logs[k] += v
ment_p, ment_r, ment_f1 = cal_prf(tot_metric_logs["ment_hit_cnt"], tot_metric_logs["ment_pred_cnt"], tot_metric_logs["ment_gold_cnt"])
print("seq num:", tot_seq_cnt, "hit cnt:", tot_metric_logs["ent_hit_cnt"], "pred cnt:", tot_metric_logs["ent_pred_cnt"], "gold cnt:", tot_metric_logs["ent_gold_cnt"])
print(tot_metric_logs["ent_hit_cnt"] / tot_seq_cnt, tot_metric_logs["ent_pred_cnt"] / tot_seq_cnt, tot_metric_logs["ent_gold_cnt"] / tot_seq_cnt)
ent_p, ent_r, ent_f1 = cal_prf(tot_metric_logs["ent_hit_cnt"], tot_metric_logs["ent_pred_cnt"], tot_metric_logs["ent_gold_cnt"])
input_ment_p, input_ment_r, input_ment_f1 = cal_prf(tot_metric_logs["episode_query_ment_hit_cnt"], tot_metric_logs["episode_query_ment_pred_cnt"], tot_metric_logs["episode_query_ment_gold_cnt"])
print("episode based input mention precision {:.5f} recall {:.5f} f1 {:.5f}".format(input_ment_p, input_ment_r, input_ment_f1))
model.train()
return eval_loss / eval_iter, ment_p, ment_r, ment_f1, ent_p, ent_r, ent_f1, tot_logs
def train(self, model, training_args, device, trainloader, devloader, load_ckpt=None, dev_pred_fn=None, dev_log_fn=None):
if load_ckpt is not None:
state_dict = self.__load_model__(load_ckpt)['state_dict']
own_state = model.state_dict()
for name, param in state_dict.items():
if name not in own_state:
print('[ERROR] Ignore {}'.format(name))
continue
own_state[name].copy_(param)
print("load ckpt from {}".format(load_ckpt))
# Init optimizer
print('Use bert optim!')
parameters_to_optimize = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
parameters_groups = [
{'params': [p for n, p in parameters_to_optimize if ("bert." in n) and (not any(nd in n for nd in no_decay))],
'lr': training_args.bert_learning_rate, 'weight_decay': training_args.bert_weight_decay},
{'params': [p for n, p in parameters_to_optimize if ("bert." in n) and any(nd in n for nd in no_decay)],
'lr': training_args.bert_learning_rate, 'weight_decay': 0},
{'params': [p for n, p in parameters_to_optimize if "bert." not in n],
'lr': training_args.learning_rate, 'weight_decay': training_args.weight_decay}
]
optimizer = torch.optim.AdamW(parameters_groups)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=training_args.warmup_step,
num_training_steps=training_args.train_iter)
model.train()
model.zero_grad()
best_f1 = -1
train_loss = 0.0
train_acc = 0
iter_sample = 0
tot_metric_logs = defaultdict(int)
it = 0
train_batchs = iter(trainloader)
for _ in range(training_args.train_iter):
it += 1
torch.cuda.empty_cache()
model.train()
batch = next(train_batchs)
batch['support'] = trainloader.dataset.batch_to_device(batch['support'], device)
batch['query'] = trainloader.dataset.batch_to_device(batch['query'], device)
if training_args.use_maml:
progress = 1.0 * (it - 1) / training_args.train_iter
lr_inner = self.get_learning_rate(
training_args.train_inner_lr, progress, training_args.warmup_prop_inner
)
res = model.forward_meta(batch, training_args.train_inner_steps, lr_inner, "train")
for g in res['grads']:
model.load_gradients(res['names'], g) # loss backward
train_loss += res['loss']
else:
res = model.forward_proto(batch)
loss = res['loss']
loss.backward()
train_loss += loss.item()
if training_args.max_grad_norm is not None:
torch.nn.utils.clip_grad_norm_(model.parameters(), training_args.max_grad_norm)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
pred = torch.cat(res['preds'], dim=0).detach().cpu().numpy()
gold = torch.cat(res['golds'], dim=0).detach().cpu().numpy()
acc = model.span_accuracy(pred, gold)
train_acc += acc
iter_sample += 1
metric_logs, logs = model.greedy_eval(res["logits"], batch["query"], overlap=training_args.overlap, threshold=training_args.type_threshold)
for k, v in metric_logs.items():
tot_metric_logs[k] += v
if it % 100 == 0 or it % training_args.log_steps == 0:
precision, recall, f1 = cal_prf(tot_metric_logs["ent_hit_cnt"], tot_metric_logs["ent_pred_cnt"],
tot_metric_logs["ent_gold_cnt"])
print('step: {0:4} | loss: {1:2.6f} | span acc {2:.5f} [ENTITY] precision: {3:3.4f}, recall: {4:3.4f}, f1: {5:3.4f}'\
.format(it, train_loss / iter_sample, train_acc / iter_sample, precision, recall, f1) + '\r')
train_loss = 0
train_acc = 0
iter_sample = 0
tot_metric_logs = defaultdict(int)
if it % training_args.val_steps == 0:
eval_loss, eval_ment_p, eval_ment_r, eval_ment_f1, eval_p, eval_r, eval_f1, eval_logs = self.eval(model, device, devloader, eval_iter=training_args.dev_iter,
update_iter=training_args.eval_inner_steps, learning_rate=training_args.eval_inner_lr, overlap=training_args.overlap, threshold=training_args.type_threshold)
print('[EVAL] step: {0:4} | loss: {1:2.6f} | [MENTION] precision: {2:3.4f}, recall: {3:3.4f}, f1: {4:3.4f} [ENTITY] precision: {5:3.4f}, recall: {6:3.4f}, f1: {7:3.4f}'\
.format(it, eval_loss, eval_ment_p, eval_ment_r, eval_ment_f1, eval_p, eval_r, eval_f1) + '\r')
if eval_f1 > best_f1:
print('Best checkpoint')
torch.save({'state_dict': model.state_dict()},
os.path.join(training_args.output_dir, "model.pth.tar"))
best_f1 = eval_f1
if dev_pred_fn is not None:
write_ent_pred_json(devloader.dataset.samples, eval_logs, dev_pred_fn)
if dev_log_fn is not None:
write_ent_log(devloader.dataset.samples, eval_logs, dev_log_fn)
eval_ent_log(devloader.dataset.samples, eval_logs)
print("\n####################\n")
print("Finish training ")
return

Просмотреть файл

@ -0,0 +1,30 @@
import torch.nn as nn
class ProjAdapter(nn.Module):
def __init__(self, in_size, hidden_size):
super().__init__()
self.w0 = nn.Linear(in_size, hidden_size, bias=True)
self.w1 = nn.Linear(hidden_size, in_size, bias=True)
self.act = nn.GELU()
return
def forward(self, input_hidden_states):
hidden_states = self.act(self.w0(input_hidden_states))
hidden_states = self.w1(hidden_states)
return hidden_states + input_hidden_states
class AdapterStack(nn.Module):
def __init__(self, adapter_hidden_size, num_hidden_layers=12):
super().__init__()
self.slf_list = nn.ModuleList([ProjAdapter(768, adapter_hidden_size) for _ in range(num_hidden_layers)])
self.ffn_list = nn.ModuleList([ProjAdapter(768, adapter_hidden_size) for _ in range(num_hidden_layers)])
return
def forward(self, inputs, layer_id, name):
if name == "slf":
outs = self.slf_list[layer_id](inputs)
else:
assert name == "ffn"
outs = self.ffn_list[layer_id](inputs)
return outs

Просмотреть файл

@ -0,0 +1,783 @@
'''
The code is adapted from https://github.com/huggingface/transformers/blob/v4.2.1/src/transformers/models/bert/modeling_bert.py
'''
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions, BaseModelOutputWithPastAndCrossAttentions
import math
import os
import warnings
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
from transformers import BertPreTrainedModel
from transformers.modeling_utils import PreTrainedModel
#####################for 4.2.1##########################
from transformers.modeling_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
# from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from transformers.file_utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from transformers import BertConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "bert-base-uncased"
_CONFIG_FOR_DOC = "BertConfig"
_TOKENIZER_FOR_DOC = "BertTokenizer"
# TokenClassification docstring
_CHECKPOINT_FOR_TOKEN_CLASSIFICATION = "dbmdz/bert-large-cased-finetuned-conll03-english"
_TOKEN_CLASS_EXPECTED_OUTPUT = (
"['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC'] "
)
_TOKEN_CLASS_EXPECTED_LOSS = 0.01
# QuestionAnswering docstring
_CHECKPOINT_FOR_QA = "deepset/bert-base-cased-squad2"
_QA_EXPECTED_OUTPUT = "'a nice puppet'"
_QA_EXPECTED_LOSS = 7.41
_QA_TARGET_START_INDEX = 14
_QA_TARGET_END_INDEX = 15
# SequenceClassification docstring
_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "textattack/bert-base-uncased-yelp-polarity"
_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"
_SEQ_CLASS_EXPECTED_LOSS = 0.01
BERT_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`BertConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
BERT_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `({0})`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using [`BertTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
1]`:
- 0 corresponds to a *sentence A* token,
- 1 corresponds to a *sentence B* token.
[What are token type IDs?](../glossary#token-type-ids)
position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class BertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""
def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
def forward(
self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
):
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]
seq_length = input_shape[1]
if position_ids is None:
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + token_type_embeddings
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
class BertPooler(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
class BertSelfAttention(nn.Module):
def __init__(self, config):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
self.is_decoder = config.is_decoder
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
):
mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_layer = past_key_value[0]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
seq_length = hidden_states.size()[1]
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.Softmax(dim=-1)(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs
class BertSelfOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, adapters=None, adapter_id=None) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
if (adapters is not None):
hidden_states = adapters(hidden_states, layer_id=adapter_id, name="slf")
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BertAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.self = BertSelfAttention(config)
self.output = BertSelfOutput(config)
self.pruned_heads = set()
def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
)
# Prune linear layers
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
# Update hyper params and store pruned heads
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
adapters=None, adapter_id=None
) -> Tuple[torch.Tensor]:
self_outputs = self.self(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
attention_output = self.output(self_outputs[0], hidden_states, adapters=adapters, adapter_id=adapter_id)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
class BertIntermediate(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class BertOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, adapters=None, adapter_id=None) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
if (adapters is not None):
hidden_states = adapters(hidden_states, layer_id=adapter_id, name="ffn")
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BertLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = BertAttention(config)
self.is_decoder = config.is_decoder
self.add_cross_attention = config.add_cross_attention
if self.add_cross_attention:
assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
self.crossattention = BertAttention(config)
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
adapters=None,
adapter_id=None,
) -> Tuple[torch.Tensor]:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
head_mask,
output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
adapters=adapters,
adapter_id=adapter_id
)
attention_output = self_attention_outputs[0]
# if decoder, the last output is tuple of self-attn cache
if self.is_decoder:
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
else:
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
cross_attn_present_key_value = None
if self.is_decoder and encoder_hidden_states is not None:
assert hasattr(
self, "crossattention"
), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
cross_attn_past_key_value,
output_attentions,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
# add cross-attn cache to positions 3,4 of present_key_value tuple
cross_attn_present_key_value = cross_attention_outputs[-1]
present_key_value = present_key_value + cross_attn_present_key_value
layer_output = apply_chunking_to_forward(
lambda x: self.feed_forward_chunk(x, adapters, adapter_id), self.chunk_size_feed_forward, self.seq_len_dim, attention_output
)
outputs = (layer_output,) + outputs
# if decoder, return the attn key/values as the last output
if self.is_decoder:
outputs = outputs + (present_key_value,)
return outputs
def feed_forward_chunk(self, attention_output, adapters=None, adapter_id=None):
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output, adapters=adapters, adapter_id=adapter_id)
return layer_output
class BertEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
adapters=None,
adapter_layer_ids=None,
input_bottom_hiddens=None
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
if input_bottom_hiddens is None:
start_layer_i = 0
else:
start_layer_i = adapter_layer_ids[0]
hidden_states = input_bottom_hiddens
next_decoder_cache = () if use_cache else None
for i, layer_module in enumerate(self.layer):
if i < start_layer_i:
continue
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if adapter_layer_ids is None:
cur_adapters = adapters
cur_i = i
elif i not in adapter_layer_ids:
cur_adapters = None
cur_i = i
else:
cur_adapters = adapters
cur_i = adapter_layer_ids.index(i)
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
cur_adapters,
cur_i
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
cur_adapters,
cur_i
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
@add_start_docstrings(
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
BERT_START_DOCSTRING,
)
class BertModel(BertPreTrainedModel):
"""
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
cross-attention is added between the self-attention layers, following the architecture described in [Attention is
all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
`add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
"""
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config
self.embeddings = BertEmbeddings(config)
self.encoder = BertEncoder(config)
self.pooler = BertPooler(config) if add_pooling_layer else None
# Initialize weights and apply final processing
self.init_weights()
def get_input_embeddings(self):
return self.embeddings.word_embeddings
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="bert-base-uncased",
output_type=BaseModelOutputWithPoolingAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
adapters=None,
adapter_layer_ids=None,
input_bottom_hiddens=None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.config.is_decoder:
use_cache = use_cache if use_cache is not None else self.config.use_cache
else:
use_cache = False
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
batch_size, seq_length = input_shape
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
batch_size, seq_length = input_shape
elif input_bottom_hiddens is not None:
input_shape = input_bottom_hiddens.size()[:-1]
batch_size, seq_length = input_shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
if self.config.is_decoder and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
if input_bottom_hiddens is None:
embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
adapters=adapters,
adapter_layer_ids=adapter_layer_ids,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
past_key_values=encoder_outputs.past_key_values,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)
else:
encoder_outputs = self.encoder(
None,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
adapters=adapters,
adapter_layer_ids=adapter_layer_ids,
input_bottom_hiddens=input_bottom_hiddens
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
past_key_values=encoder_outputs.past_key_values,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)

Просмотреть файл

@ -0,0 +1,161 @@
'''
The code is adapted from https://github.com/thunlp/Few-NERD/blob/main/util/fewshotsampler.py
'''
import random
import numpy as np
import json
class FewshotSampleBase:
'''
Abstract Class
DO NOT USE
Build your own Sample class and inherit from this class
'''
def __init__(self):
self.class_count = {}
def get_class_count(self):
'''
return a dictionary of {class_name:count} in format {any : int}
'''
return self.class_count
class FewshotSampler:
'''
sample one support set and one query set
'''
def __init__(self, N, K, Q, samples, classes=None, random_state=0):
'''
N: int, how many types in each set
K: int, how many instances for each type in support set
Q: int, how many instances for each type in query set
samples: List[Sample], Sample class must have `get_class_count` attribute
classes[Optional]: List[any], all unique classes in samples. If not given, the classes will be got from samples.get_class_count()
random_state[Optional]: int, the random seed
'''
self.K = K
self.N = N
self.Q = Q
self.samples = samples
self.__check__() # check if samples have correct types
if classes:
self.classes = classes
else:
self.classes = self.__get_all_classes__()
random.seed(random_state)
def __get_all_classes__(self):
classes = []
for sample in self.samples:
classes += list(sample.get_class_count().keys())
return list(set(classes))
def __check__(self):
for idx, sample in enumerate(self.samples):
if not hasattr(sample,'get_class_count'):
print('[ERROR] samples in self.samples expected to have `get_class_count` attribute, but self.samples[{idx}] does not')
raise ValueError
def __additem__(self, index, set_class):
class_count = self.samples[index].get_class_count()
for class_name in class_count:
if class_name in set_class:
set_class[class_name] += class_count[class_name]
else:
set_class[class_name] = class_count[class_name]
def __valid_sample__(self, sample, set_class, target_classes):
threshold = 2 * set_class['k']
class_count = sample.get_class_count()
if not class_count:
return False
isvalid = False
for class_name in class_count:
if class_name not in target_classes:
return False
if class_count[class_name] + set_class.get(class_name, 0) > threshold:
return False
if set_class.get(class_name, 0) < set_class['k']:
isvalid = True
return isvalid
def __finish__(self, set_class):
if len(set_class) < self.N+1:
return False
for k in set_class:
if set_class[k] < set_class['k']:
return False
return True
def __get_candidates__(self, target_classes):
return [idx for idx, sample in enumerate(self.samples) if sample.valid(target_classes)]
def __next__(self):
'''
randomly sample one support set and one query set
return:
target_classes: List[any]
support_idx: List[int], sample index in support set in samples list
support_idx: List[int], sample index in query set in samples list
'''
support_class = {'k':self.K}
support_idx = []
query_class = {'k':self.Q}
query_idx = []
target_classes = random.sample(self.classes, self.N)
candidates = self.__get_candidates__(target_classes)
while not candidates:
target_classes = random.sample(self.classes, self.N)
candidates = self.__get_candidates__(target_classes)
# greedy search for support set
while not self.__finish__(support_class):
index = random.choice(candidates)
if index not in support_idx:
if self.__valid_sample__(self.samples[index], support_class, target_classes):
self.__additem__(index, support_class)
support_idx.append(index)
# same for query set
while not self.__finish__(query_class):
index = random.choice(candidates)
if index not in query_idx and index not in support_idx:
if self.__valid_sample__(self.samples[index], query_class, target_classes):
self.__additem__(index, query_class)
query_idx.append(index)
return target_classes, support_idx, query_idx
def __iter__(self):
return self
class DebugSampler:
def __init__(self, filepath, querypath=None):
data = []
with open(filepath, mode="r", encoding="utf-8") as fp:
for line in fp:
data.append(json.loads(line))
self.data = data
if querypath is not None:
with open(querypath, mode="r", encoding="utf-8") as fp:
self.ment_data = [json.loads(line) for line in fp]
if len(self.ment_data) != len(self.data):
print("the mention data len is different with input episode number!!!!!!!!!!!!!!!")
else:
self.ment_data = None
return
def __next__(self, idx):
idx = idx % len(self.data)
target_classes = self.data[idx]["target_classes"]
support_idx = self.data[idx]["support_idx"]
query_idx = self.data[idx]["query_idx"]
if self.ment_data is None:
return target_classes, support_idx, query_idx
return target_classes, support_idx, query_idx, self.ment_data[idx]
def __len__(self):
return len(self.data)

Просмотреть файл

@ -0,0 +1,351 @@
import json
import torch
import torch.utils.data as data
import os
from .fewshotsampler import DebugSampler, FewshotSampleBase
import numpy as np
import random
from collections import defaultdict
from .span_sample import SpanSample
from .span_loader import QuerySpanBatcher
class JointBatcher(QuerySpanBatcher):
def __init__(self, iou_thred, use_oproto):
super().__init__(iou_thred, use_oproto)
return
def batchnize_episode(self, data, mode):
support_sets, query_sets = zip(*data)
if mode == "train":
is_train = True
else:
is_train = False
batch_support = self.batchnize_sent(support_sets, "support", is_train)
batch_query = self.batchnize_sent(query_sets, "query", is_train)
batch_query["label2tag"] = []
for i in range(len(query_sets)):
batch_query["label2tag"].append(query_sets[i]["label2tag"])
return {"support": batch_support, "query": batch_query}
def batchnize_sent(self, data, mode, is_train):
batch = {"index": [], "word": [], "word_mask": [], "word_to_piece_ind": [], "word_to_piece_end": [],
"seq_len": [],
"spans": [], "sentence_num": [], "query_ments": [], "query_probs": [], "subsentence_num": [],
'split_words': [],
"ment_labels": []}
for i in range(len(data)):
for k in batch.keys():
if k == 'sentence_num':
batch[k].append(data[i][k])
else:
batch[k] += data[i][k]
max_n_piece = max([sum(x) for x in batch['word_mask']])
max_n_words = max(batch["seq_len"])
word_to_piece_ind = np.zeros(shape=(len(batch["seq_len"]), max_n_words))
word_to_piece_end = np.zeros(shape=(len(batch["seq_len"]), max_n_words))
ment_labels = np.full(shape=(len(batch['seq_len']), max_n_words), fill_value=-1, dtype='int')
for k, slen in enumerate(batch['seq_len']):
assert len(batch['word_to_piece_ind'][k]) == slen
assert len(batch['word_to_piece_end'][k]) == slen
word_to_piece_ind[k, :slen] = batch['word_to_piece_ind'][k]
word_to_piece_end[k, :slen] = batch['word_to_piece_end'][k]
ment_labels[k, :slen] = batch['ment_labels'][k]
batch['word_to_piece_ind'] = word_to_piece_ind
batch['word_to_piece_end'] = word_to_piece_end
batch['ment_labels'] = ment_labels
if mode == "support":
batch = self.make_support_batch(batch)
else:
batch = self.make_query_batch(batch, is_train)
for k, v in batch.items():
if k not in ['spans', 'sentence_num', 'label2tag', 'index', "query_ments", "query_probs", "subsentence_num",
"split_words"]:
v = np.array(v)
if k == "span_weights":
batch[k] = torch.tensor(v).float()
else:
batch[k] = torch.tensor(v).long()
batch['word'] = batch['word'][:, :max_n_piece]
batch['word_mask'] = batch['word_mask'][:, :max_n_piece]
return batch
def __call__(self, batch, mode):
return self.batchnize_episode(batch, mode)
class JointNERDataset(data.Dataset):
"""
Fewshot NER Dataset
"""
def __init__(self, filepath, encoder, N, K, Q, max_length, schema, \
bio=True, debug_file=None, query_file=None, hidden_query_label=False, labelname_fn=None):
if not os.path.exists(filepath):
print("[ERROR] Data file does not exist!")
assert (0)
self.class2sampleid = {}
self.N = N
self.K = K
self.Q = Q
self.encoder = encoder
self.schema = schema # this means the meta-train/test schema
if self.schema == 'BIO':
self.ment_tag2label = {"O": 0, "B-X": 1, "I-X": 2}
elif self.schema == 'IO':
self.ment_tag2label = {"O": 0, "I-X": 1}
elif self.schema == 'BIOES':
self.ment_tag2label = {"O": 0, "B-X": 1, "I-X": 2, "E-X": 3, "S-X": 4}
else:
raise ValueError
self.ment_label2tag = {lidx: tag for tag, lidx in self.ment_tag2label.items()}
self.label2tag = None
self.tag2label = None
self.sql_label2tag = None
self.sql_tag2label = None
self.samples, self.classes = self.__load_data_from_file__(filepath, bio)
if debug_file:
if query_file is None:
print("use golden mention for typing !!! input_fn: {}".format(filepath))
self.sampler = DebugSampler(debug_file, query_file)
self.max_length = max_length
return
def __insert_sample__(self, index, sample_classes):
for item in sample_classes:
if item in self.class2sampleid:
self.class2sampleid[item].append(index)
else:
self.class2sampleid[item] = [index]
return
def __load_data_from_file__(self, filepath, bio):
samples = []
classes = []
with open(filepath, 'r', encoding='utf-8')as f:
lines = f.readlines()
samplelines = []
index = 0
for line in lines:
line = line.strip("\n")
if len(line):
samplelines.append(line)
else:
sample = SpanSample(index, samplelines, bio)
samples.append(sample)
sample_classes = sample.get_tag_class()
self.__insert_sample__(index, sample_classes)
classes += sample_classes
samplelines = []
index += 1
if len(samplelines):
sample = SpanSample(index, samplelines, bio)
samples.append(sample)
sample_classes = sample.get_tag_class()
self.__insert_sample__(index, sample_classes)
classes += sample_classes
classes = list(set(classes))
max_span_len = -1
long_ent_num = 0
tot_ent_num = 0
tot_tok_num = 0
for eid, sample in enumerate(samples):
max_span_len = max(max_span_len, sample.get_max_ent_len())
long_ent_num += sample.get_num_of_long_ent(10)
tot_ent_num += len(sample.spans)
tot_tok_num += len(sample.words)
# convert seq labels to target schema
new_tags = ['O' for _ in range(len(sample.words))]
for sp in sample.spans:
stype = sp[0]
sp_st = sp[1]
sp_ed = sp[2]
assert stype != "O"
if self.schema == 'IO':
for k in range(sp_st, sp_ed + 1):
new_tags[k] = "I-" + stype
elif self.schema == 'BIO':
new_tags[sp_st] = "B-" + stype
for k in range(sp_st + 1, sp_ed + 1):
new_tags[k] = "I-" + stype
elif self.schema == 'BIOES':
if sp_st == sp_ed:
new_tags[sp_st] = "S-" + stype
else:
new_tags[sp_st] = "B-" + stype
new_tags[sp_ed] = "E-" + stype
for k in range(sp_st + 1, sp_ed):
new_tags[k] = "I-" + stype
else:
raise ValueError
assert len(new_tags) == len(samples[eid].tags)
samples[eid].tags = new_tags
print("Sentence num {}, token num {}, entity num {} in file {}".format(len(samples), tot_tok_num, tot_ent_num,
filepath))
print("Total classes {}: {}".format(len(classes), str(classes)))
print("The max golden entity len in the dataset is ", max_span_len)
print("The max golden entity len in the dataset is greater than 10", long_ent_num)
print("The total coverage of spans: {:.5f}".format(1 - long_ent_num / tot_ent_num))
return samples, classes
def get_ment_word_tag(self, wtag):
if wtag == "O":
return wtag
return wtag[:2] + "X"
def __getraw__(self, sample, add_split):
word, mask, word_to_piece_ind, word_to_piece_end, word_shape_ids, seq_lens = self.encoder.tokenize(sample.words)
sent_st_id = 0
split_seqs = []
ment_labels = []
word_labels = []
split_spans = []
split_querys = []
split_probs = []
split_words = []
cur_wid = 0
for k in range(len(seq_lens)):
split_words.append(sample.words[cur_wid: cur_wid + seq_lens[k]])
cur_wid += seq_lens[k]
for cur_len in seq_lens:
sent_ed_id = sent_st_id + cur_len
split_seqs.append(sample.tags[sent_st_id: sent_ed_id])
cur_ment_seqs = []
cur_word_seqs = []
split_spans.append([])
for tag, span_st, span_ed in sample.spans:
if (span_st >= sent_ed_id) or (span_ed < sent_st_id): # span totally not in subsent
continue
if (span_st >= sent_st_id) and (span_ed < sent_ed_id): # span totally in subsent
split_spans[-1].append([self.tag2label[tag], span_st - sent_st_id, span_ed - sent_st_id])
elif add_split:
if span_st >= sent_st_id:
split_spans[-1].append([self.tag2label[tag], span_st - sent_st_id, sent_ed_id - 1 - sent_st_id])
else:
split_spans[-1].append([self.tag2label[tag], 0, span_ed - sent_st_id])
split_querys.append([])
split_probs.append([])
for [span_st, span_ed], span_prob in zip(sample.query_ments, sample.query_probs):
if (span_st >= sent_ed_id) or (span_ed < sent_st_id): # span totally not in subsent
continue
if (span_st >= sent_st_id) and (span_ed < sent_ed_id): # span totally in subsent
split_querys[-1].append([span_st - sent_st_id, span_ed - sent_st_id])
split_probs[-1].append(span_prob)
elif add_split:
if span_st >= sent_st_id:
split_querys[-1].append([span_st - sent_st_id, sent_ed_id - 1 - sent_st_id])
split_probs[-1].append(span_prob)
else:
split_querys[-1].append([0, span_ed - sent_st_id])
split_probs[-1].append(span_prob)
for wtag in split_seqs[-1]:
cur_ment_seqs.append(self.ment_tag2label[self.get_ment_word_tag(wtag)])
cur_word_seqs.append(-1) # not used
ment_labels.append(cur_ment_seqs)
word_labels.append(cur_word_seqs)
sent_st_id += cur_len
item = {"word": word, "word_mask": mask, "word_to_piece_ind": word_to_piece_ind, "word_to_piece_end": word_to_piece_end,
"seq_len": seq_lens, "word_shape_ids": word_shape_ids, "ment_labels": ment_labels, "word_labels": word_labels,
"subsentence_num": len(seq_lens), "spans": split_spans, "query_ments": split_querys, "query_probs": split_probs,
"split_words": split_words}
return item
def __additem__(self, index, d, item):
d['index'].append(index)
d['word'] += item['word']
d['word_mask'] += item['word_mask']
d['seq_len'] += item['seq_len']
d['word_to_piece_ind'] += item['word_to_piece_ind']
d['word_to_piece_end'] += item['word_to_piece_end']
d['word_shape_ids'] += item['word_shape_ids']
d['spans'] += item['spans']
d['query_ments'] += item['query_ments']
d['query_probs'] += item['query_probs']
d['ment_labels'] += item['ment_labels']
d['word_labels'] += item['word_labels']
d['subsentence_num'].append(item['subsentence_num'])
d['split_words'] += item['split_words']
return
def __populate__(self, idx_list, query_ment_mp=None, savelabeldic=False, add_split=False):
dataset = {'index': [], 'word': [], 'word_mask': [], 'ment_labels': [], 'word_labels': [], 'word_to_piece_ind': [],
"word_to_piece_end": [], "seq_len": [], "word_shape_ids": [], "subsentence_num": [], 'spans': [], "query_ments": [], "query_probs": [], 'split_words': []}
for idx in idx_list:
if query_ment_mp is not None:
self.samples[idx].query_ments = [x[0] for x in query_ment_mp[str(self.samples[idx].index)]]
self.samples[idx].query_probs = [x[1] for x in query_ment_mp[str(self.samples[idx].index)]]
else:
self.samples[idx].query_ments = [[x[1], x[2]] for x in self.samples[idx].spans]
self.samples[idx].query_probs = [1 for x in self.samples[idx].spans]
item = self.__getraw__(self.samples[idx], add_split)
self.__additem__(idx, dataset, item)
if savelabeldic:
dataset['label2tag'] = self.label2tag
dataset['sql_label2tag'] = self.sql_label2tag
dataset['sentence_num'] = len(dataset["seq_len"])
assert len(dataset['word']) == len(dataset['seq_len'])
assert len(dataset['word_to_piece_ind']) == len(dataset['seq_len'])
assert len(dataset['word_to_piece_end']) == len(dataset['seq_len'])
assert len(dataset['ment_labels']) == len(dataset['seq_len'])
assert len(dataset['word_labels']) == len(dataset['seq_len'])
return dataset
def __getitem__(self, index):
tmp = self.sampler.__next__(index)
if len(tmp) == 3:
target_classes, support_idx, query_idx = tmp
query_ment_mp = None
else:
target_classes, support_idx, query_idx, query_ment_mp = tmp
target_tags = ['O']
for cname in target_classes:
if self.schema == 'IO':
target_tags.append(f"I-{cname}")
elif self.schema == 'BIO':
target_tags.append(f"B-{cname}")
target_tags.append(f"I-{cname}")
elif self.schema == 'BIOES':
target_tags.append(f"B-{cname}")
target_tags.append(f"I-{cname}")
target_tags.append(f"E-{cname}")
target_tags.append(f"S-{cname}")
else:
raise ValueError
self.sql_tag2label = {tag: idx for idx, tag in enumerate(target_tags)}
self.sql_label2tag = {idx: tag for idx, tag in enumerate(target_tags)}
distinct_tags = ['O'] + target_classes
self.tag2label = {tag: idx for idx, tag in enumerate(distinct_tags)}
self.label2tag = {idx: tag for idx, tag in enumerate(distinct_tags)}
support_set = self.__populate__(support_idx, add_split=True, savelabeldic=False)
query_set = self.__populate__(query_idx, query_ment_mp=query_ment_mp, add_split=True, savelabeldic=True)
return support_set, query_set
def __len__(self):
return 1000000
def batch_to_device(self, batch, device):
for k, v in batch.items():
if k in ['sentence_num', 'label2tag', 'spans', 'index', 'query_ments', 'query_probs', "subsentence_num", 'split_words']:
continue
batch[k] = v.to(device)
return batch
def get_joint_loader(filepath, mode, encoder, N, K, Q, batch_size, max_length, schema,
bio, shuffle, num_workers=8, debug_file=None, query_file=None,
iou_thred=None, hidden_query_label=False, label_fn=None, use_oproto=False):
batcher = JointBatcher(iou_thred=iou_thred, use_oproto=use_oproto)
dataset = JointNERDataset(filepath, encoder, N, K, Q, max_length, bio=bio, schema=schema,
debug_file=debug_file, query_file=query_file,
hidden_query_label=hidden_query_label, labelname_fn=label_fn)
dataloader = data.DataLoader(dataset=dataset,
batch_size=batch_size,
shuffle=shuffle,
pin_memory=True,
num_workers=num_workers,
collate_fn=lambda x: batcher(x, mode))
return dataloader

Просмотреть файл

@ -0,0 +1,416 @@
import json
from collections import defaultdict
import random
import torch
import prettytable as pt
import numpy as np
def cal_prf(tot_hit_cnt, tot_pred_cnt, tot_gold_cnt):
precision = tot_hit_cnt / tot_pred_cnt if tot_hit_cnt > 0 else 0
recall = tot_hit_cnt / tot_gold_cnt if tot_hit_cnt > 0 else 0
f1 = precision * recall * 2 / (precision + recall) if (tot_hit_cnt > 0) else 0
return precision, recall, f1
def cal_episode_prf(logs):
pred_list = logs["pred"]
gold_list = logs["gold"]
indexes = logs["index"]
k = 0
query_episode_start_sent_idlist = []
for ep_subsent_num in logs["sentence_num"]:
query_episode_start_sent_idlist.append(k)
cur_subsent_num = 0
while cur_subsent_num < ep_subsent_num:
cur_subsent_num += logs["subsentence_num"][k]
k += 1
query_episode_start_sent_idlist.append(k)
subsent_id = 0
sent_id = 0
ep_prf = {}
for cur_preds, cur_golds, cur_index, snum in zip(pred_list, gold_list, indexes, logs["subsentence_num"]):
if sent_id in query_episode_start_sent_idlist:
ep_id = query_episode_start_sent_idlist.index(sent_id)
ep_prf[ep_id] = {"hit": 0, "pred": 0, "gold": 0}
pred_spans = set(map(lambda x: tuple(x), cur_preds))
gold_spans = set(map(lambda x: tuple(x), cur_golds))
ep_prf[ep_id]["pred"] += len(pred_spans)
ep_prf[ep_id]["gold"] += len(gold_spans)
ep_prf[ep_id]["hit"] += len(pred_spans.intersection(gold_spans))
subsent_id += snum
sent_id += 1
ep_p = 0
ep_r = 0
ep_f1 = 0
for ep in ep_prf.values():
p, r, f1 = cal_prf(ep["hit"], ep["pred"], ep["gold"])
ep_p += p
ep_r += r
ep_f1 += f1
ep_p /= len(ep_prf)
ep_r /= len(ep_prf)
ep_f1 /= len(ep_prf)
return ep_p, ep_r, ep_f1
def eval_ment_log(samples, logs):
indexes = logs["index"]
pred_list = logs["pred"]
gold_list = logs["gold"]
gold_cnt_mp = defaultdict(int)
hit_cnt_mp = defaultdict(int)
for cur_idx, cur_preds, cur_golds in zip(indexes, pred_list, gold_list):
tagged_spans = samples[cur_idx].spans
for r, x, y in tagged_spans:
gold_cnt_mp[r] += 1
if [x, y] in cur_preds:
hit_cnt_mp[r] += 1
tot_miss = 0
tb = pt.PrettyTable(["Type", "recall", "miss_span", "tot_span"])
for gtype in sorted(gold_cnt_mp.keys()):
rscore = hit_cnt_mp[gtype] / gold_cnt_mp[gtype]
miss_cnt = gold_cnt_mp[gtype] - hit_cnt_mp[gtype]
tb.add_row([
gtype, "{:.4f}".format(rscore), miss_cnt, gold_cnt_mp[gtype]
])
tot_miss += miss_cnt
tb.add_row(["Total", "{:.4f}".format(sum(hit_cnt_mp.values()) / sum(gold_cnt_mp.values())), tot_miss, sum(gold_cnt_mp.values())])
print(tb)
return
def write_ep_ment_log_json(samples, logs, output_fn):
k = 0
subsent_id = 0
query_episode_end_subsent_idlist = []
for ep_subsent_num in logs["sentence_num"]:
cur_subsent_num = 0
while cur_subsent_num < ep_subsent_num:
cur_subsent_num += logs["subsentence_num"][k]
subsent_id += logs["subsentence_num"][k]
k += 1
query_episode_end_subsent_idlist.append(subsent_id)
indexes = logs["index"]
if "before_prob" in logs:
split_ind_list = logs["before_ind"]
split_prob_list = logs["before_prob"]
else:
split_ind_list = None
split_prob_list = None
sent_pred_list = logs["pred"]
split_slen_list = logs["seq_len"]
subsent_num_list = logs["subsentence_num"]
log_lines = []
cur_query_res = {}
subsent_id = 0
sent_id = 0
for snum, cur_index in zip(subsent_num_list, indexes):
if subsent_id in query_episode_end_subsent_idlist:
log_lines.append(cur_query_res)
cur_query_res = {}
cur_probs = []
if split_prob_list is not None:
cur_sent_st = 0
for k in range(subsent_id, subsent_id + snum):
for x, y in zip(split_ind_list[k], split_prob_list[k]):
cur_probs.append(([x[0] + cur_sent_st, x[1] + cur_sent_st], y))
cur_sent_st += split_slen_list[k]
else:
for x in sent_pred_list[sent_id]:
cur_probs.append(([x[0], x[1]], 1))
cur_query_res[samples[cur_index].index] = cur_probs
subsent_id += snum
sent_id += 1
assert subsent_id in query_episode_end_subsent_idlist
log_lines.append(cur_query_res)
cur_query_res = {}
with open(output_fn, mode="w", encoding="utf-8") as fp:
output_lines = []
for line in log_lines:
output_lines.append(json.dumps(line) + "\n")
fp.writelines(output_lines)
return
def write_ment_log(samples, logs, output_fn):
indexes = logs["index"]
pred_list = logs["pred"]
gold_list = logs["gold"]
subsent_num_list = logs["subsentence_num"]
split_slen_list = logs["seq_len"]
if "before_prob" in logs:
split_ind_list = logs["before_ind"]
split_prob_list = logs["before_prob"]
else:
split_ind_list = None
split_prob_list = None
log_lines = []
assert len(pred_list) == len(indexes)
assert len(gold_list) == len(indexes)
subsent_id = 0
for cur_preds, cur_golds, cur_index, snum in zip(pred_list, gold_list, indexes, subsent_num_list):
cur_sample = samples[cur_index]
cur_probs = []
if split_prob_list is not None:
cur_sent_st = 0
for k in range(subsent_id, subsent_id + snum):
for x, y in zip(split_ind_list[k], split_prob_list[k]):
cur_probs.append(([x[0] + cur_sent_st, x[1] + cur_sent_st], y))
cur_sent_st += split_slen_list[k]
subsent_id += snum
log_lines.append("index:{}\n".format(cur_sample.index))
log_lines.append("pred:\n")
for x in cur_preds:
log_lines.append(" ".join(cur_sample.words[x[0]: x[1] + 1]) + " " + str(x) + "\n")
log_lines.append("gold:\n")
for x in cur_golds:
log_lines.append(" ".join(cur_sample.words[x[0]: x[1] + 1]) + " " + str(x) + "\n")
log_lines.append("log:\n")
log_lines.append(json.dumps(cur_probs) + "\n")
log_lines.append("\n")
with open(output_fn, mode="w", encoding="utf-8") as fp:
fp.writelines(log_lines)
return
def eval_ent_log(samples, logs):
indexes = logs["index"]
pred_list = logs["pred"]
gold_list = logs["gold"]
assert len(indexes) == len(pred_list)
gold_cnt_mp = defaultdict(int)
hit_cnt_mp = defaultdict(int)
pred_cnt_mp = defaultdict(int)
hit_mentcnt_mp = defaultdict(int)
fp_cnt_mp = defaultdict(int)
for cur_idx, cur_preds, cur_golds in zip(indexes, pred_list, gold_list):
used_token = np.array([0 for i in range(len(samples[cur_idx].words))])
cur_ments = [[x[1], x[2]] for x in cur_preds]
tagged_spans = samples[cur_idx].spans
for r, x, y in tagged_spans:
gold_cnt_mp[r] += 1
used_token[x: y + 1] = 1
if [r, x, y] in cur_preds:
hit_cnt_mp[r] += 1
if [x, y] in cur_ments:
hit_mentcnt_mp[r] += 1
for r, x, y in cur_preds:
pred_cnt_mp[r] += 1
if sum(used_token[x: y + 1]) == 0:
fp_cnt_mp[r] += 1
tb = pt.PrettyTable(["Type", "precision", "recall", "f1", "overall_ment_recall", "error_miss_ment", "false_span", "span_type_error", "correct"])
for gtype in sorted(gold_cnt_mp.keys()):
pscore = hit_cnt_mp[gtype] / pred_cnt_mp[gtype] if hit_cnt_mp[gtype] > 0 else 0
rscore = hit_cnt_mp[gtype] / gold_cnt_mp[gtype] if hit_cnt_mp[gtype] > 0 else 0
fscore = pscore * rscore * 2 / (pscore + rscore) if hit_cnt_mp[gtype] > 0 else 0
ment_rscore = hit_mentcnt_mp[gtype] / gold_cnt_mp[gtype]
miss_ment_cnt = gold_cnt_mp[gtype] - hit_mentcnt_mp[gtype]
fp_cnt = fp_cnt_mp[gtype]
type_error_cnt = hit_mentcnt_mp[gtype] - hit_cnt_mp[gtype]
correct_cnt = hit_cnt_mp[gtype]
tb.add_row([
gtype, "{:.4f}".format(pscore), "{:.4f}".format(rscore), "{:.4f}".format(fscore), "{:.4f}".format(ment_rscore), miss_ment_cnt, fp_cnt, type_error_cnt, correct_cnt
])
pscore = sum(hit_cnt_mp.values()) / sum(pred_cnt_mp.values()) if sum(hit_cnt_mp.values()) > 0 else 0
rscore = sum(hit_cnt_mp.values()) / sum(gold_cnt_mp.values()) if sum(hit_cnt_mp.values()) > 0 else 0
fscore = pscore * rscore * 2 / (pscore + rscore) if sum(hit_cnt_mp.values()) > 0 else 0
ment_rscore = sum(hit_mentcnt_mp.values()) / sum(gold_cnt_mp.values())
miss_ment_cnt = sum(gold_cnt_mp.values()) - sum(hit_mentcnt_mp.values())
type_error_cnt = sum(hit_mentcnt_mp.values()) - sum(hit_cnt_mp.values())
tb.add_row([
"Overall", "{:.4f}".format(pscore), "{:.4f}".format(rscore), "{:.4f}".format(fscore), "{:.4f}".format(ment_rscore), miss_ment_cnt, sum(fp_cnt_mp.values()),type_error_cnt, sum(hit_cnt_mp.values())
])
print(tb)
return
def write_ent_log(samples, logs, output_fn):
k = 0
support_episode_start_sent_idlist = []
for ep_subsent_num in logs["support_sentence_num"]:
support_episode_start_sent_idlist.append(k)
cur_subsent_num = 0
while cur_subsent_num < ep_subsent_num:
cur_subsent_num += logs["support_subsentence_num"][k]
k += 1
support_episode_start_sent_idlist.append(k)
k = 0
query_episode_start_sent_idlist = []
for ep_subsent_num in logs["sentence_num"]:
query_episode_start_sent_idlist.append(k)
cur_subsent_num = 0
while cur_subsent_num < ep_subsent_num:
cur_subsent_num += logs["subsentence_num"][k]
k += 1
query_episode_start_sent_idlist.append(k)
assert len(support_episode_start_sent_idlist) == len(query_episode_start_sent_idlist)
indexes = logs["index"]
pred_list = logs["pred"]
gold_list = logs["gold"]
split_ind_list = logs["before_ind"] if "before_ind" in logs else None
split_prob_list = logs["before_prob"] if "before_prob" in logs else None
split_slen_list = logs["seq_len"]
split_label_tag_list = logs["label_tag"] if "label_tag" in logs else None
log_lines = []
subsent_id = 0
sent_id = 0
for cur_preds, cur_golds, cur_index, snum in zip(pred_list, gold_list, indexes, logs["subsentence_num"]):
if sent_id in query_episode_start_sent_idlist:
ep_id = query_episode_start_sent_idlist.index(sent_id)
support_sent_id_1 = support_episode_start_sent_idlist[ep_id]
support_sent_id_2 = support_episode_start_sent_idlist[ep_id + 1]
log_lines.append("="*20+"\n")
support_indexes = logs["support_index"][support_sent_id_1:support_sent_id_2]
log_lines.append("support:{}\n".format(str(support_indexes)))
for x in support_indexes:
log_lines.append(str(samples[x].words) + "\n")
log_lines.append(str([sp[0] + ":" + " ".join(samples[x].words[sp[1]: sp[2] + 1]) for sp in samples[x].spans]) + "\n")
if split_label_tag_list is not None:
log_lines.append(json.dumps(split_label_tag_list[subsent_id]) + "\n")
log_lines.append("\n")
cur_sample = samples[cur_index]
cur_probs = []
if split_ind_list is not None:
cur_sent_st = 0
for k in range(subsent_id, subsent_id + snum):
for x, y_list in zip(split_ind_list[k], split_prob_list[k]):
cur_probs.append(str([x[0] + cur_sent_st, x[1] + cur_sent_st])
+ ": " + ",".join(["{:.5f}".format(y) for y in y_list])
+ ", "
+ " ".join(cur_sample.words[x[0] + cur_sent_st: x[1] + cur_sent_st + 1]) + "\n")
cur_sent_st += split_slen_list[k]
subsent_id += snum
log_lines.append("index:{}\n".format(cur_index))
log_lines.append(str(cur_sample.words) + "\n")
log_lines.append("pred:\n")
for x in cur_preds:
log_lines.append(" ".join(cur_sample.words[x[1]: x[2] + 1]) + " " + str(x) + "\n")
log_lines.append("gold:\n")
for x in cur_golds:
log_lines.append(" ".join(cur_sample.words[x[1]: x[2] + 1]) + " " + str(x) + "\n")
log_lines.append("log:\n")
log_lines.extend(cur_probs)
log_lines.append("\n")
sent_id += 1
with open(output_fn, mode="w", encoding="utf-8") as fp:
fp.writelines(log_lines)
return log_lines
def write_ent_pred_json(samples, logs, output_fn):
k = 0
support_episode_start_sent_idlist = []
for ep_subsent_num in logs["support_sentence_num"]:
support_episode_start_sent_idlist.append(k)
cur_subsent_num = 0
while cur_subsent_num < ep_subsent_num:
cur_subsent_num += logs["support_subsentence_num"][k]
k += 1
support_episode_start_sent_idlist.append(k)
k = 0
query_episode_start_sent_idlist = []
for ep_subsent_num in logs["sentence_num"]:
query_episode_start_sent_idlist.append(k)
cur_subsent_num = 0
while cur_subsent_num < ep_subsent_num:
cur_subsent_num += logs["subsentence_num"][k]
k += 1
query_episode_start_sent_idlist.append(k)
assert len(support_episode_start_sent_idlist) == len(query_episode_start_sent_idlist)
indexes = logs["index"]
pred_list = logs["pred"]
log_lines = []
cur_query_res = None
support_indexes = None
subsent_id = 0
sent_id = 0
for cur_preds, cur_index, snum in zip(pred_list, indexes, logs["subsentence_num"]):
if sent_id in query_episode_start_sent_idlist:
if cur_query_res is not None:
log_lines.append({"support": support_indexes, "query": cur_query_res})
ep_id = query_episode_start_sent_idlist.index(sent_id)
support_sent_id_1 = support_episode_start_sent_idlist[ep_id]
support_sent_id_2 = support_episode_start_sent_idlist[ep_id + 1]
support_indexes = logs["support_index"][support_sent_id_1:support_sent_id_2]
cur_query_res = []
cur_sample = samples[cur_index]
subsent_id += snum
cur_query_res.append({"index":cur_index, "pred":cur_preds, "gold": cur_sample.spans})
sent_id += 1
if len(cur_query_res) > 0:
log_lines.append({"support": support_indexes, "query": cur_query_res})
with open(output_fn, mode="w", encoding="utf-8") as fp:
output_lines = []
for line in log_lines:
output_lines.append(json.dumps(line) + "\n")
fp.writelines(output_lines)
return
def write_pos_pred_json(samples, logs, output_fn):
k = 0
support_episode_start_sent_idlist = []
for ep_subsent_num in logs["support_sentence_num"]:
support_episode_start_sent_idlist.append(k)
cur_subsent_num = 0
while cur_subsent_num < ep_subsent_num:
cur_subsent_num += logs["support_subsentence_num"][k]
k += 1
support_episode_start_sent_idlist.append(k)
k = 0
query_episode_start_sent_idlist = []
for ep_subsent_num in logs["sentence_num"]:
query_episode_start_sent_idlist.append(k)
cur_subsent_num = 0
while cur_subsent_num < ep_subsent_num:
cur_subsent_num += logs["subsentence_num"][k]
k += 1
query_episode_start_sent_idlist.append(k)
assert len(support_episode_start_sent_idlist) == len(query_episode_start_sent_idlist)
indexes = logs["index"]
pred_list = logs["pred"]
log_lines = []
cur_query_res = None
support_indexes = None
subsent_id = 0
sent_id = 0
for cur_preds, cur_index, snum in zip(pred_list, indexes, logs["subsentence_num"]):
if sent_id in query_episode_start_sent_idlist:
if cur_query_res is not None:
log_lines.append({"support": support_indexes, "query": cur_query_res})
ep_id = query_episode_start_sent_idlist.index(sent_id)
support_sent_id_1 = support_episode_start_sent_idlist[ep_id]
support_sent_id_2 = support_episode_start_sent_idlist[ep_id + 1]
support_indexes = logs["support_index"][support_sent_id_1:support_sent_id_2]
cur_query_res = []
cur_sample = samples[cur_index]
subsent_id += snum
cur_query_res.append({"index":cur_index, "pred":cur_preds, "gold": cur_sample.tags})
sent_id += 1
if len(cur_query_res) > 0:
log_lines.append({"support": support_indexes, "query": cur_query_res})
with open(output_fn, mode="w", encoding="utf-8") as fp:
output_lines = []
for line in log_lines:
output_lines.append(json.dumps(line) + "\n")
fp.writelines(output_lines)
return
def save_json(content, path, indent=4, **json_dump_kwargs):
with open(path, "w") as f:
json.dump(content, f, indent=indent, **json_dump_kwargs)
return
def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
return

Просмотреть файл

@ -0,0 +1,244 @@
import json
import torch
import torch.utils.data as data
import os
from .fewshotsampler import DebugSampler, FewshotSampleBase
import numpy as np
import random
from collections import defaultdict
class TokenSample(FewshotSampleBase):
def __init__(self, idx, filelines):
super(TokenSample, self).__init__()
self.index = idx
filelines = [line.split('\t') for line in filelines]
if len(filelines[0]) == 2:
self.words, self.tags = zip(*filelines)
else:
self.words, self.postags, self.tags = zip(*filelines)
return
def __count_entities__(self):
self.class_count = {}
for tag in self.tags:
if tag in self.class_count:
self.class_count[tag] += 1
else:
self.class_count[tag] = 1
return
def get_class_count(self):
if self.class_count:
return self.class_count
else:
self.__count_entities__()
return self.class_count
def get_tag_class(self):
return list(set(self.tags))
def valid(self, target_classes):
return (set(self.get_class_count().keys()).intersection(set(target_classes))) and \
not (set(self.get_class_count().keys()).difference(set(target_classes)))
def __str__(self):
return
class SeqBatcher:
def __init__(self):
return
def batchnize_sent(self, data):
batch = {"index": [], "word": [], "word_mask": [], "word_to_piece_ind": [], "word_to_piece_end": [],
"word_shape_ids": [], "word_labels": [],
"seq_len": [], "sentence_num": [], "subsentence_num": []}
for i in range(len(data)):
for k in batch.keys():
if k == 'sentence_num':
batch[k].append(data[i][k])
else:
batch[k] += data[i][k]
max_n_piece = max([sum(x) for x in batch['word_mask']])
max_n_words = max(batch["seq_len"])
word_to_piece_ind = np.zeros(shape=(len(batch["seq_len"]), max_n_words))
word_to_piece_end = np.zeros(shape=(len(batch["seq_len"]), max_n_words))
word_shape_ids = np.zeros(shape=(len(batch["seq_len"]), max_n_words))
word_labels = np.full(shape=(len(batch['seq_len']), max_n_words), fill_value=-1, dtype='int')
for k, slen in enumerate(batch['seq_len']):
assert len(batch['word_to_piece_ind'][k]) == slen
assert len(batch['word_to_piece_end'][k]) == slen
assert len(batch['word_shape_ids'][k]) == slen
word_to_piece_ind[k, :slen] = batch['word_to_piece_ind'][k]
word_to_piece_end[k, :slen] = batch['word_to_piece_end'][k]
word_shape_ids[k, :slen] = batch['word_shape_ids'][k]
word_labels[k, :slen] = batch['word_labels'][k]
batch['word_to_piece_ind'] = word_to_piece_ind
batch['word_to_piece_end'] = word_to_piece_end
batch['word_shape_ids'] = word_shape_ids
batch['word_labels'] = word_labels
for k, v in batch.items():
if k not in ['sentence_num', 'index', 'subsentence_num', 'label2tag']:
v = np.array(v, dtype=np.int32)
batch[k] = torch.tensor(v, dtype=torch.long)
batch['word'] = batch['word'][:, :max_n_piece]
batch['word_mask'] = batch['word_mask'][:, :max_n_piece]
return batch
def __call__(self, batch, use_episode):
if use_episode:
support_sets, query_sets = zip(*batch)
batch_support = self.batchnize_sent(support_sets)
batch_query = self.batchnize_sent(query_sets)
batch_query["label2tag"] = []
for i in range(len(query_sets)):
batch_query["label2tag"].append(query_sets[i]["label2tag"])
return {"support": batch_support, "query": batch_query}
else:
batch_query = self.batchnize_sent(batch)
batch_query["label2tag"] = []
for i in range(len(batch)):
batch_query["label2tag"].append(batch[i]["label2tag"])
return batch_query
class SeqNERDataset(data.Dataset):
"""
Fewshot NER Dataset
"""
def __init__(self, filepath, encoder, max_length, debug_file=None, use_episode=True):
if not os.path.exists(filepath):
print("[ERROR] Data file does not exist!")
assert (0)
self.class2sampleid = {}
self.encoder = encoder
self.label2tag = None
self.tag2label = None
self.samples, self.classes = self.__load_data_from_file__(filepath)
if debug_file is not None:
self.sampler = DebugSampler(debug_file)
self.max_length = max_length
self.use_episode = use_episode
return
def __insert_sample__(self, index, sample_classes):
for item in sample_classes:
if item in self.class2sampleid:
self.class2sampleid[item].append(index)
else:
self.class2sampleid[item] = [index]
return
def __load_data_from_file__(self, filepath):
samples = []
classes = []
with open(filepath, 'r', encoding='utf-8')as f:
lines = f.readlines()
samplelines = []
index = 0
for line in lines:
line = line.strip("\n")
if len(line):
samplelines.append(line)
else:
sample = TokenSample(index, samplelines)
samples.append(sample)
sample_classes = sample.get_tag_class()
self.__insert_sample__(index, sample_classes)
classes += sample_classes
samplelines = []
index += 1
if len(samplelines):
sample = TokenSample(index, samplelines)
samples.append(sample)
sample_classes = sample.get_tag_class()
self.__insert_sample__(index, sample_classes)
classes += sample_classes
classes = list(set(classes))
return samples, classes
def __getraw__(self, sample):
word, mask, word_to_piece_ind, word_to_piece_end, word_shape_ids, seq_lens = self.encoder.tokenize(sample.words)
sent_st_id = 0
split_seqs = []
word_labels = []
for cur_len in seq_lens:
sent_ed_id = sent_st_id + cur_len
split_seqs.append(sample.tags[sent_st_id: sent_ed_id])
cur_word_seqs = []
for wtag in split_seqs[-1]:
if wtag not in self.tag2label:
print(wtag, self.tag2label)
cur_word_seqs.append(self.tag2label[wtag])
word_labels.append(cur_word_seqs)
sent_st_id += cur_len
item = {"word": word, "word_mask": mask, "word_to_piece_ind": word_to_piece_ind, "word_to_piece_end": word_to_piece_end,
"seq_len": seq_lens, "word_shape_ids": word_shape_ids, "word_labels": word_labels,
"subsentence_num": len(seq_lens)}
return item
def __additem__(self, index, d, item):
d['index'].append(index)
d['word'] += item['word']
d['word_mask'] += item['word_mask']
d['seq_len'] += item['seq_len']
d['word_to_piece_ind'] += item['word_to_piece_ind']
d['word_to_piece_end'] += item['word_to_piece_end']
d['word_shape_ids'] += item['word_shape_ids']
d['word_labels'] += item['word_labels']
d['subsentence_num'].append(item['subsentence_num'])
return
def __populate__(self, idx_list, savelabeldic=False):
dataset = {'index': [], 'word': [], 'word_mask': [], 'word_labels': [], 'word_to_piece_ind': [],
"word_to_piece_end": [], "seq_len": [], "word_shape_ids": [], "subsentence_num": []}
for idx in idx_list:
item = self.__getraw__(self.samples[idx])
self.__additem__(idx, dataset, item)
if savelabeldic:
dataset['label2tag'] = self.label2tag
dataset['sentence_num'] = len(dataset["seq_len"])
assert len(dataset['word']) == len(dataset['seq_len'])
assert len(dataset['word_to_piece_ind']) == len(dataset['seq_len'])
assert len(dataset['word_to_piece_end']) == len(dataset['seq_len'])
assert len(dataset['word_labels']) == len(dataset['seq_len'])
return dataset
def __getitem__(self, index):
target_tags, support_idx, query_idx = self.sampler.__next__(index)
if self.use_episode:
self.tag2label = {tag: idx for idx, tag in enumerate(target_tags)}
self.label2tag = {idx: tag for idx, tag in enumerate(target_tags)}
support_set = self.__populate__(support_idx, savelabeldic=False)
query_set = self.__populate__(query_idx, savelabeldic=True)
return support_set, query_set
else:
if self.tag2label is None:
self.tag2label = {tag: idx for idx, tag in enumerate(target_tags)}
self.label2tag = {idx: tag for idx, tag in enumerate(target_tags)}
return self.__populate__(support_idx + query_idx, savelabeldic=True)
def __len__(self):
return 1000000
def batch_to_device(self, batch, device):
for k, v in batch.items():
if k == 'sentence_num' or k == 'index' or k == 'subsentence_num' or k == 'label2tag':
continue
batch[k] = v.to(device)
return batch
def get_seq_loader(filepath, mode, encoder, batch_size, max_length, shuffle, debug_file=None, num_workers=8, use_episode=True):
batcher = SeqBatcher()
dataset = SeqNERDataset(filepath, encoder, max_length, debug_file=debug_file, use_episode=use_episode)
dataloader = data.DataLoader(dataset=dataset,
batch_size=batch_size,
shuffle=shuffle,
pin_memory=True,
num_workers=num_workers,
collate_fn=lambda x: batcher(x, use_episode))
return dataloader

Просмотреть файл

@ -0,0 +1,270 @@
import json
import torch
import torch.utils.data as data
import os
from .fewshotsampler import DebugSampler, FewshotSampleBase
import numpy as np
import random
from collections import defaultdict
from .span_sample import SpanSample
class SeqBatcher:
def __init__(self):
return
def batchnize_sent(self, data):
batch = {"index": [], "word": [], "word_mask": [], "word_to_piece_ind": [], "word_to_piece_end": [],
"word_shape_ids": [], "word_labels": [],
"seq_len": [], "ment_labels": [], "sentence_num": [], "subsentence_num": []}
for i in range(len(data)):
for k in batch.keys():
if k == 'sentence_num':
batch[k].append(data[i][k])
else:
batch[k] += data[i][k]
max_n_piece = max([sum(x) for x in batch['word_mask']])
max_n_words = max(batch["seq_len"])
word_to_piece_ind = np.zeros(shape=(len(batch["seq_len"]), max_n_words))
word_to_piece_end = np.zeros(shape=(len(batch["seq_len"]), max_n_words))
word_shape_ids = np.zeros(shape=(len(batch["seq_len"]), max_n_words))
ment_labels = np.full(shape=(len(batch['seq_len']), max_n_words), fill_value=-1, dtype='int')
word_labels = np.full(shape=(len(batch['seq_len']), max_n_words), fill_value=-1, dtype='int')
for k, slen in enumerate(batch['seq_len']):
assert len(batch['word_to_piece_ind'][k]) == slen
assert len(batch['word_to_piece_end'][k]) == slen
assert len(batch['word_shape_ids'][k]) == slen
word_to_piece_ind[k, :slen] = batch['word_to_piece_ind'][k]
word_to_piece_end[k, :slen] = batch['word_to_piece_end'][k]
word_shape_ids[k, :slen] = batch['word_shape_ids'][k]
ment_labels[k, :slen] = batch['ment_labels'][k]
word_labels[k, :slen] = batch['word_labels'][k]
batch['word_to_piece_ind'] = word_to_piece_ind
batch['word_to_piece_end'] = word_to_piece_end
batch['word_shape_ids'] = word_shape_ids
batch['ment_labels'] = ment_labels
batch['word_labels'] = word_labels
for k, v in batch.items():
if k not in ['sentence_num', 'index', 'subsentence_num', 'label2tag']:
v = np.array(v, dtype=np.int32)
batch[k] = torch.tensor(v, dtype=torch.long)
batch['word'] = batch['word'][:, :max_n_piece]
batch['word_mask'] = batch['word_mask'][:, :max_n_piece]
return batch
def __call__(self, batch):
support_sets, query_sets = zip(*batch)
batch_support = self.batchnize_sent(support_sets)
batch_query = self.batchnize_sent(query_sets)
batch_query["label2tag"] = []
for i in range(len(query_sets)):
batch_query["label2tag"].append(query_sets[i]["label2tag"])
return {"support": batch_support, "query": batch_query}
class SeqNERDataset(data.Dataset):
"""
Fewshot NER Dataset
"""
def __init__(self, filepath, encoder, max_length, schema, bio=True, debug_file=None):
if not os.path.exists(filepath):
print("[ERROR] Data file does not exist!")
assert (0)
self.class2sampleid = {}
self.encoder = encoder
self.schema = schema # this means the meta-train/test schema
if self.schema == 'BIO':
self.ment_tag2label = {"O": 0, "B-X": 1, "I-X": 2}
elif self.schema == 'IO':
self.ment_tag2label = {"O": 0, "I-X": 1}
elif self.schema == 'BIOES':
self.ment_tag2label = {"O": 0, "B-X": 1, "I-X": 2, "E-X": 3, "S-X": 4}
else:
raise ValueError
self.ment_label2tag = {lidx: tag for tag, lidx in self.ment_tag2label.items()}
self.label2tag = None
self.tag2label = None
self.samples, self.classes = self.__load_data_from_file__(filepath, bio)
if debug_file is not None:
self.sampler = DebugSampler(debug_file)
self.max_length = max_length
return
def __insert_sample__(self, index, sample_classes):
for item in sample_classes:
if item in self.class2sampleid:
self.class2sampleid[item].append(index)
else:
self.class2sampleid[item] = [index]
return
def __load_data_from_file__(self, filepath, bio):
samples = []
classes = []
with open(filepath, 'r', encoding='utf-8')as f:
lines = f.readlines()
samplelines = []
index = 0
for line in lines:
line = line.strip("\n")
if len(line):
samplelines.append(line)
else:
sample = SpanSample(index, samplelines, bio)
samples.append(sample)
sample_classes = sample.get_tag_class()
self.__insert_sample__(index, sample_classes)
classes += sample_classes
samplelines = []
index += 1
if len(samplelines):
sample = SpanSample(index, samplelines, bio)
samples.append(sample)
sample_classes = sample.get_tag_class()
self.__insert_sample__(index, sample_classes)
classes += sample_classes
classes = list(set(classes))
max_span_len = -1
long_ent_num = 0
tot_ent_num = 0
tot_tok_num = 0
for eid, sample in enumerate(samples):
max_span_len = max(max_span_len, sample.get_max_ent_len())
long_ent_num += sample.get_num_of_long_ent(10)
tot_ent_num += len(sample.spans)
tot_tok_num += len(sample.words)
# convert seq labels to target schema
new_tags = ['O' for _ in range(len(sample.words))]
for sp in sample.spans:
stype = sp[0]
sp_st = sp[1]
sp_ed = sp[2]
assert stype != "O"
if self.schema == 'IO':
for k in range(sp_st, sp_ed + 1):
new_tags[k] = "I-" + stype
elif self.schema == 'BIO':
new_tags[sp_st] = "B-" + stype
for k in range(sp_st + 1, sp_ed + 1):
new_tags[k] = "I-" + stype
elif self.schema == 'BIOES':
if sp_st == sp_ed:
new_tags[sp_st] = "S-" + stype
else:
new_tags[sp_st] = "B-" + stype
new_tags[sp_ed] = "E-" + stype
for k in range(sp_st + 1, sp_ed):
new_tags[k] = "I-" + stype
else:
raise ValueError
assert len(new_tags) == len(samples[eid].tags)
samples[eid].tags = new_tags
print("Sentence num {}, token num {}, entity num {} in file {}".format(len(samples), tot_tok_num, tot_ent_num,
filepath))
print("Total classes {}: {}".format(len(classes), str(classes)))
print("The max golden entity len in the dataset is ", max_span_len)
print("The max golden entity len in the dataset is greater than 10", long_ent_num)
print("The total coverage of spans: {:.5f}".format(1 - long_ent_num / tot_ent_num))
return samples, classes
def get_ment_word_tag(self, wtag):
if wtag == "O":
return wtag
return wtag[:2] + "X"
def __getraw__(self, sample):
word, mask, word_to_piece_ind, word_to_piece_end, word_shape_ids, seq_lens = self.encoder.tokenize(sample.words)
sent_st_id = 0
split_seqs = []
ment_labels = []
word_labels = []
for cur_len in seq_lens:
sent_ed_id = sent_st_id + cur_len
split_seqs.append(sample.tags[sent_st_id: sent_ed_id])
cur_ment_seqs = []
cur_word_seqs = []
for wtag in split_seqs[-1]:
cur_ment_seqs.append(self.ment_tag2label[self.get_ment_word_tag(wtag)])
if wtag not in self.tag2label:
print(wtag, self.tag2label)
cur_word_seqs.append(self.tag2label[wtag])
ment_labels.append(cur_ment_seqs)
word_labels.append(cur_word_seqs)
sent_st_id += cur_len
item = {"word": word, "word_mask": mask, "word_to_piece_ind": word_to_piece_ind, "word_to_piece_end": word_to_piece_end,
"seq_len": seq_lens, "word_shape_ids": word_shape_ids, "ment_labels": ment_labels, "word_labels": word_labels,
"subsentence_num": len(seq_lens)}
return item
def __additem__(self, index, d, item):
d['index'].append(index)
d['word'] += item['word']
d['word_mask'] += item['word_mask']
d['seq_len'] += item['seq_len']
d['word_to_piece_ind'] += item['word_to_piece_ind']
d['word_to_piece_end'] += item['word_to_piece_end']
d['word_shape_ids'] += item['word_shape_ids']
d['ment_labels'] += item['ment_labels']
d['word_labels'] += item['word_labels']
d['subsentence_num'].append(item['subsentence_num'])
return
def __populate__(self, idx_list, savelabeldic=False):
dataset = {'index': [], 'word': [], 'word_mask': [], 'ment_labels': [], 'word_labels': [], 'word_to_piece_ind': [],
"word_to_piece_end": [], "seq_len": [], "word_shape_ids": [], "subsentence_num": []}
for idx in idx_list:
item = self.__getraw__(self.samples[idx])
self.__additem__(idx, dataset, item)
if savelabeldic:
dataset['label2tag'] = self.label2tag
dataset['sentence_num'] = len(dataset["seq_len"])
assert len(dataset['word']) == len(dataset['seq_len'])
assert len(dataset['word_to_piece_ind']) == len(dataset['seq_len'])
assert len(dataset['word_to_piece_end']) == len(dataset['seq_len'])
assert len(dataset['ment_labels']) == len(dataset['seq_len'])
assert len(dataset['word_labels']) == len(dataset['seq_len'])
return dataset
def __getitem__(self, index):
target_classes, support_idx, query_idx = self.sampler.__next__(index)
target_tags = ['O']
for cname in target_classes:
if self.schema == 'IO':
target_tags.append(f"I-{cname}")
elif self.schema == 'BIO':
target_tags.append(f"B-{cname}")
target_tags.append(f"I-{cname}")
elif self.schema == 'BIOES':
target_tags.append(f"B-{cname}")
target_tags.append(f"I-{cname}")
target_tags.append(f"E-{cname}")
target_tags.append(f"S-{cname}")
else:
raise ValueError
self.tag2label = {tag: idx for idx, tag in enumerate(target_tags)}
self.label2tag = {idx: tag for idx, tag in enumerate(target_tags)}
support_set = self.__populate__(support_idx, savelabeldic=False)
query_set = self.__populate__(query_idx, savelabeldic=True)
return support_set, query_set
def __len__(self):
return 1000000
def batch_to_device(self, batch, device):
for k, v in batch.items():
if k == 'sentence_num' or k == 'index' or k == 'subsentence_num' or k == 'label2tag':
continue
batch[k] = v.to(device)
return batch
def get_seq_loader(filepath, mode, encoder, batch_size, max_length, schema, bio, shuffle, debug_file=None, num_workers=8):
batcher = SeqBatcher()
dataset = SeqNERDataset(filepath, encoder, max_length, schema, bio, debug_file=debug_file)
dataloader = data.DataLoader(dataset=dataset,
batch_size=batch_size,
shuffle=shuffle,
pin_memory=True,
num_workers=num_workers,
collate_fn=batcher)
return dataloader

Просмотреть файл

@ -0,0 +1,306 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import os
import string
from transformers import AutoModel, AutoTokenizer
from .adapter_layer import AdapterStack
class BERTSpanEncoder(nn.Module):
def __init__(self, pretrain_path, max_length, last_n_layer=-4, word_encode_choice='first', span_encode_choice=None,
use_width=False, width_dim=20, use_case=False, case_dim=20, drop_p=0.1, use_att=False, att_hidden_dim=100,
use_adapter=False, adapter_size=64, adapter_layer_ids=None):
nn.Module.__init__(self)
self.use_adapter = use_adapter
self.adapter_layer_ids = adapter_layer_ids
self.ada_layer_num = 12
if self.adapter_layer_ids is not None:
self.ada_layer_num = len(self.adapter_layer_ids)
if self.use_adapter:
from .bert_adapter import BertModel
self.bert = BertModel.from_pretrained(pretrain_path)
self.ment_adapters = AdapterStack(adapter_size, num_hidden_layers=self.ada_layer_num)
self.type_adapters = AdapterStack(adapter_size, num_hidden_layers=self.ada_layer_num)
print("use task specific adapter !!!!!!!!!")
else:
self.bert = AutoModel.from_pretrained(pretrain_path)
for n, p in self.bert.named_parameters():
if "pooler" in n:
p.requires_grad = False
self.tokenizer = AutoTokenizer.from_pretrained(pretrain_path)
self.max_length = max_length
self.last_n_layer = last_n_layer
self.word_encode_choice = word_encode_choice
self.span_encode_choice = span_encode_choice
self.drop = nn.Dropout(drop_p)
self.word_dim = self.bert.config.hidden_size
self.use_att = use_att
self.att_hidden_dim = att_hidden_dim
if use_att:
self.att_layer = nn.Sequential(
nn.Linear(self.word_dim, self.att_hidden_dim),
nn.Tanh(),
nn.Linear(self.att_hidden_dim, 1)
)
if span_encode_choice is None:
span_dim = self.word_dim * 2
print("span representation is [head; tail]")
else:
span_dim = self.word_dim
print("span representation is ", self.span_encode_choice)
self.use_width = use_width
if self.use_width:
self.width_dim = width_dim
self.width_mat = nn.Embedding(50, width_dim)
span_dim = span_dim + width_dim
print("use width embedding")
self.use_case = use_case
if self.use_case:
self.case_dim = case_dim
self.case_mat = nn.Embedding(10, case_dim)
span_dim = span_dim + case_dim
print("use case embedding")
self.span_dim = span_dim
print("word dim is {}, span dim is {}".format(self.word_dim, self.span_dim))
return
def combine(self, seq_hidden, start_inds, end_inds, pooling='avg'):
batch_size, max_seq_len, hidden_dim = seq_hidden.size()
max_span_num = start_inds.size(1)
if pooling == 'first':
embeddings = torch.gather(seq_hidden, 1,
start_inds.unsqueeze(-1).expand(batch_size, max_span_num, hidden_dim))
elif pooling == 'avg':
device = seq_hidden.device
span_len = end_inds - start_inds + 1
max_span_len = torch.max(span_len).item()
subtoken_offset = torch.arange(0, max_span_len).to(device).view(1, 1, max_span_len).expand(batch_size,
max_span_num,
max_span_len)
subtoken_pos = subtoken_offset + start_inds.unsqueeze(-1).expand(batch_size, max_span_num, max_span_len)
subtoken_mask = subtoken_pos.le(end_inds.view(batch_size, max_span_num, 1))
subtoken_pos = subtoken_pos.masked_fill(~subtoken_mask, 0).view(batch_size, max_span_num * max_span_len,
1).expand(batch_size,
max_span_num * max_span_len,
hidden_dim)
embeddings = torch.gather(seq_hidden, 1, subtoken_pos).view(batch_size, max_span_num, max_span_len,
hidden_dim)
embeddings = embeddings * subtoken_mask.unsqueeze(-1)
embeddings = torch.div(embeddings.sum(2), span_len.unsqueeze(-1).float())
elif pooling == 'attavg':
global_weights = self.att_layer(seq_hidden.view(-1, hidden_dim)).view(batch_size, max_seq_len, 1)
seq_hidden_w = torch.cat([seq_hidden, global_weights], dim=2)
device = seq_hidden.device
span_len = end_inds - start_inds + 1
max_span_len = torch.max(span_len).item()
subtoken_offset = torch.arange(0, max_span_len).to(device).view(1, 1, max_span_len).expand(batch_size,
max_span_num,
max_span_len)
subtoken_pos = subtoken_offset + start_inds.unsqueeze(-1).expand(batch_size, max_span_num, max_span_len)
subtoken_mask = subtoken_pos.le(end_inds.view(batch_size, max_span_num, 1))
subtoken_pos = subtoken_pos.masked_fill(~subtoken_mask, 0).view(batch_size, max_span_num * max_span_len,
1).expand(batch_size,
max_span_num * max_span_len,
hidden_dim + 1)
span_w_embeddings = torch.gather(seq_hidden_w, 1, subtoken_pos).view(batch_size, max_span_num, max_span_len,
hidden_dim + 1)
span_w_embeddings = span_w_embeddings * subtoken_mask.unsqueeze(-1)
word_embeddings = span_w_embeddings[:, :, :, :-1]
word_weights = span_w_embeddings[:, :, :, -1].masked_fill(~subtoken_mask, -1e8)
word_weights = F.softmax(word_weights, dim=2).unsqueeze(3)
embeddings = (word_weights * word_embeddings).sum(2)
elif pooling == 'max':
device = seq_hidden.device
span_len = end_inds - start_inds + 1
max_span_len = torch.max(span_len).item()
subtoken_offset = torch.arange(0, max_span_len).to(device).view(1, 1, max_span_len).expand(batch_size,
max_span_num,
max_span_len)
subtoken_pos = subtoken_offset + start_inds.unsqueeze(-1).expand(batch_size, max_span_num, max_span_len)
subtoken_mask = subtoken_pos.le(
end_inds.view(batch_size, max_span_num, 1)) # batch_size, max_span_num, max_span_len
subtoken_pos = subtoken_pos.masked_fill(~subtoken_mask, 0).view(batch_size, max_span_num * max_span_len,
1).expand(batch_size,
max_span_num * max_span_len,
hidden_dim)
embeddings = torch.gather(seq_hidden, 1, subtoken_pos).view(batch_size, max_span_num, max_span_len,
hidden_dim)
embeddings = embeddings.masked_fill(
(~subtoken_mask).unsqueeze(-1).expand(batch_size, max_span_num, max_span_len, hidden_dim), -1e8).max(
dim=2)[0]
else:
raise ValueError('encode choice not in first / avg / max')
return embeddings
def forward(self, words, word_masks, word_to_piece_inds=None, word_to_piece_ends=None, span_indices=None, word_shape_ids=None, mode=None, bottom_hiddens=None):
assert word_to_piece_inds.size(0) == words.size(0)
assert word_to_piece_inds.size(1) <= words.size(1)
assert word_to_piece_ends.size() == word_to_piece_inds.size()
output_bottom_hiddens = None
if (mode is not None) and self.use_adapter:
if mode == 'ment':
outputs = self.bert(words, attention_mask=word_masks, output_hidden_states=True, return_dict=True, adapters=self.ment_adapters, adapter_layer_ids=self.adapter_layer_ids, input_bottom_hiddens=bottom_hiddens)
else:
outputs = self.bert(words, attention_mask=word_masks, output_hidden_states=True, return_dict=True, adapters=self.type_adapters, adapter_layer_ids=self.adapter_layer_ids, input_bottom_hiddens=bottom_hiddens)
if mode == 'ment':
output_bottom_hiddens = outputs['hidden_states'][self.adapter_layer_ids[0]]
else:
outputs = self.bert(words, attention_mask=word_masks, output_hidden_states=True, return_dict=True)
# use the sum of the last 4 layers
last_four_hidden_states = torch.cat(
[hidden_state.unsqueeze(0) for hidden_state in outputs['hidden_states'][self.last_n_layer:]], 0)
del outputs
piece_embeddings = torch.sum(last_four_hidden_states, 0) # [num_sent, number_of_tokens, 768]
_, piece_len, hidden_dim = piece_embeddings.size()
word_embeddings = self.combine(piece_embeddings, word_to_piece_inds, word_to_piece_ends,
self.word_encode_choice)
if span_indices is None:
word_embeddings = self.drop(word_embeddings)
if mode == 'ment':
return word_embeddings, output_bottom_hiddens
return word_embeddings
if word_shape_ids is not None:
assert word_shape_ids.size() == word_to_piece_inds.size()
assert torch.sum(span_indices[:, :, 1].lt(span_indices[:, :, 0])).item() == 0
embeds = []
if self.span_encode_choice is None:
start_word_embeddings = self.combine(word_embeddings, span_indices[:, :, 0], None, 'first')
start_word_embeddings = self.drop(start_word_embeddings)
embeds.append(start_word_embeddings)
end_word_embeddings = self.combine(word_embeddings, span_indices[:, :, 1], None, 'first')
end_word_embeddings = self.drop(end_word_embeddings)
embeds.append(end_word_embeddings)
else:
pool_embeddings = self.combine(word_embeddings, span_indices[:, :, 0], span_indices[:, :, 1],
self.span_encode_choice)
pool_embeddings = self.drop(pool_embeddings)
embeds.append(pool_embeddings)
if self.use_width:
width_embeddings = self.width_mat(span_indices[:, :, 1] - span_indices[:, :, 0])
width_embeddings = self.drop(width_embeddings)
embeds.append(width_embeddings)
if self.use_case:
case_wemb = self.case_mat(word_shape_ids)
case_embeddings = self.combine(case_wemb, span_indices[:, :, 0], span_indices[:, :, 1], 'avg')
case_embeddings = self.drop(case_embeddings)
embeds.append(case_embeddings)
span_embeddings = torch.cat(embeds, dim=-1)
span_embeddings = span_embeddings.view(span_indices.size(0), span_indices.size(1), self.span_dim)
return span_embeddings
def get_word_case(self, token):
if token.isdigit():
tfeat = 0
elif token in string.punctuation:
tfeat = 1
elif token.isupper():
tfeat = 2
elif token[0].isupper():
tfeat = 3
elif token.islower():
tfeat = 4
else:
tfeat = 5
return tfeat
def tokenize_label(self, label_name):
token_ids = self.tokenizer.encode(label_name, add_special_tokens=True, max_length=self.max_length, truncation=True)
mask = np.zeros((self.max_length), dtype=np.int32)
mask[:len(token_ids)] = 1
# padding
while len(token_ids) < self.max_length:
token_ids.append(0)
return token_ids, mask
def tokenize(self, input_tokens, true_token_flags=None):
cur_tokens = ['[CLS]']
cur_word_to_piece_ind = []
cur_word_to_piece_end = []
cur_word_shape_ind = []
raw_tokens_list = []
word_mask_list = []
indexed_tokens_list = []
word_to_piece_ind_list = []
word_to_piece_end_list = []
word_shape_ind_list = []
seq_len = []
word_flag = True
for i in range(len(input_tokens)):
word = input_tokens[i]
if true_token_flags is None:
word_flag = True
else:
word_flag = true_token_flags[i]
word_tokens = self.tokenizer.tokenize(word)
if len(word_tokens) == 0:
word_tokens = ['[UNK]']
if len(cur_tokens) + len(word_tokens) + 2 > self.max_length:
raw_tokens_list.append(cur_tokens + ['[SEP]'])
word_to_piece_ind_list.append(cur_word_to_piece_ind)
word_to_piece_end_list.append(cur_word_to_piece_end)
word_shape_ind_list.append(cur_word_shape_ind)
seq_len.append(len(cur_word_to_piece_ind))
cur_tokens = ['[CLS]'] + word_tokens
if word_flag:
cur_word_to_piece_ind = [1]
cur_word_to_piece_end = [len(cur_tokens) - 1]
cur_word_shape_ind = [self.get_word_case(word)]
else:
cur_word_to_piece_ind = []
cur_word_to_piece_end = []
cur_word_shape_ind = []
else:
if word_flag:
cur_word_to_piece_ind.append(len(cur_tokens))
cur_tokens.extend(word_tokens)
cur_word_to_piece_end.append(len(cur_tokens) - 1)
cur_word_shape_ind.append(self.get_word_case(word))
else:
cur_tokens.extend(word_tokens)
if len(cur_tokens):
assert len(cur_tokens) < self.max_length
raw_tokens_list.append(cur_tokens + ['[SEP]'])
word_to_piece_ind_list.append(cur_word_to_piece_ind)
word_to_piece_end_list.append(cur_word_to_piece_end)
word_shape_ind_list.append(cur_word_shape_ind)
seq_len.append(len(cur_word_to_piece_ind))
assert seq_len == [len(x) for x in word_to_piece_ind_list]
if true_token_flags is None:
assert sum(seq_len) == len(input_tokens)
assert len(raw_tokens_list) == len(word_to_piece_ind_list)
assert len(raw_tokens_list) == len(word_to_piece_end_list)
for raw_tokens in raw_tokens_list:
indexed_tokens = self.tokenizer.convert_tokens_to_ids(raw_tokens)
# padding
while len(indexed_tokens) < self.max_length:
indexed_tokens.append(0)
indexed_tokens_list.append(indexed_tokens)
if len(indexed_tokens) != self.max_length:
print(input_tokens)
print(raw_tokens)
raise ValueError
assert len(indexed_tokens) == self.max_length
# mask
mask = np.zeros((self.max_length), dtype=np.int32)
mask[:len(raw_tokens)] = 1
word_mask_list.append(mask)
sent_num = len(indexed_tokens_list)
assert sent_num == len(word_mask_list)
assert sent_num == len(word_to_piece_ind_list)
assert sent_num == len(seq_len)
return indexed_tokens_list, word_mask_list, word_to_piece_ind_list, word_to_piece_end_list, word_shape_ind_list, seq_len

Просмотреть файл

@ -0,0 +1,477 @@
import json
import torch
import torch.utils.data as data
import os
from .fewshotsampler import FewshotSampler, DebugSampler
from .span_sample import SpanSample
import numpy as np
import random
from collections import defaultdict
import logging, pickle, gzip
from tqdm import tqdm
class QuerySpanSample(SpanSample):
def __init__(self, index, filelines, ment_prob_list, bio=True):
super(QuerySpanSample, self).__init__(index, filelines, bio)
self.query_ments = []
self.query_probs = []
for i in range(len(ment_prob_list)):
self.query_ments.append(ment_prob_list[i][0])
self.query_probs.append(ment_prob_list[i][1])
return
class QuerySpanBatcher:
def __init__(self, iou_thred, use_oproto):
self.iou_thred = iou_thred
self.use_oproto = use_oproto
return
def _get_span_weights(self, batch_triples, seq_len, span_indices, span_tags, thred=0.6, alpha=1):
span_weights = []
for k in range(len(span_indices)):
seq_tags = np.zeros(seq_len[k], dtype=np.int)
span_st_inds = np.zeros(seq_len[k], dtype=np.int)
span_ed_inds = np.zeros(seq_len[k], dtype=np.int)
span_weights.append([1] * len(span_indices[k]))
for [r, i, j] in batch_triples[k]:
seq_tags[i:j + 1] = r
span_st_inds[i:j + 1] = i
span_ed_inds[i:j + 1] = j
for sp_idx in range(len(span_indices[k])):
sp_st = span_indices[k][sp_idx][0]
sp_ed = span_indices[k][sp_idx][1]
sp_tag = span_tags[k][sp_idx]
if sp_tag != 0:
continue
cur_token_tags = list(seq_tags[sp_st: sp_ed + 1])
max_tag = max(cur_token_tags, key=cur_token_tags.count)
anchor_idx = cur_token_tags.index(max_tag) + sp_st
if max_tag == 0:
continue
cur_ids = set(range(sp_st, sp_ed + 1))
ref_ids = set(range(span_st_inds[anchor_idx], span_ed_inds[anchor_idx] + 1))
tag_percent = len(cur_ids & ref_ids) / len(cur_ids | ref_ids)
if tag_percent >= thred:
span_tags[k][sp_idx] = max_tag
span_weights[k][sp_idx] = tag_percent ** alpha
else:
span_weights[k][sp_idx] = (1 - tag_percent) ** alpha
return span_tags, span_weights
@staticmethod
def _make_golden_batch(batch_triples, seq_len):
span_indices = []
span_tags = []
for k in range(len(batch_triples)):
per_sp_indices = []
per_sp_tags = []
for (r, i, j) in batch_triples[k]:
if j < seq_len[k]:
per_sp_indices.append([i, j])
per_sp_tags.append(r)
else:
print("something error")
span_indices.append(per_sp_indices)
span_tags.append(per_sp_tags)
return span_indices, span_tags
@staticmethod
def _make_train_query_batch(query_ments, query_probs, golden_triples, seq_len):
span_indices = []
ment_probs = []
span_tags = []
for k in range(len(query_ments)):
per_sp_indices = []
per_sp_tags = []
per_ment_probs = []
per_tag_mp = {(i, j): tag for tag, i, j in golden_triples[k]}
for [i, j], prob in zip(query_ments[k], query_probs[k]):
if j < seq_len[k]:
per_sp_indices.append([i, j])
per_sp_tags.append(per_tag_mp.get((i, j), 0))
per_ment_probs.append(prob)
else:
print("something error")
#add golden mentions into query batch
for [r, i, j] in golden_triples[k]:
if [i, j] not in per_sp_indices:
per_sp_indices.append([i, j])
per_sp_tags.append(r)
per_ment_probs.append(0)
span_indices.append(per_sp_indices)
span_tags.append(per_sp_tags)
ment_probs.append(per_ment_probs)
return span_indices, span_tags, ment_probs
@staticmethod
def _make_test_query_batch(query_ments, query_probs, golden_triples, seq_len):
span_indices = []
ment_probs = []
span_tags = []
for k in range(len(query_ments)):
per_sp_indices = []
per_sp_tags = []
per_ment_probs = []
per_tag_mp = {(i, j): tag for tag, i, j in golden_triples[k]}
for [i, j], prob in zip(query_ments[k], query_probs[k]):
if j < seq_len[k]:
per_sp_indices.append([i, j])
per_sp_tags.append(per_tag_mp.get((i, j), 0))
per_ment_probs.append(prob)
else:
print("something error")
span_indices.append(per_sp_indices)
span_tags.append(per_sp_tags)
ment_probs.append(per_ment_probs)
return span_indices, span_tags, ment_probs
def make_support_batch(self, batch):
span_indices, span_tags = self._make_golden_batch(batch["spans"], batch["seq_len"])
if self.use_oproto:
used_token_flag = []
for k in range(len(batch["spans"])):
used_token_flag.append(np.zeros(batch["seq_len"][k]))
for [r, sp_st, sp_ed] in batch["spans"][k]:
used_token_flag[k][sp_st: sp_ed + 1] = 1
for k in range(len(batch["seq_len"])):
for j in range(batch["seq_len"][k]):
if used_token_flag[k][j] == 1:
continue
if [j, j] not in span_indices[k]:
span_indices[k].append([j, j])
span_tags[k].append(0)
else:
print("something error")
span_nums = [len(x) for x in span_indices]
max_n_spans = max(span_nums)
span_masks = np.zeros(shape=(len(span_indices), max_n_spans), dtype='int')
for k in range(len(span_indices)):
span_masks[k, :span_nums[k]] = 1
while len(span_tags[k]) < max_n_spans:
span_indices[k].append([0, 0])
span_tags[k].append(-1)
batch["span_indices"] = span_indices
batch["span_mask"] = span_masks
batch["span_tag"] = span_tags
# all example with equal weights
batch["span_weights"] = np.ones(shape=span_masks.shape, dtype='float')
return batch
def make_query_batch(self, batch, is_train):
if is_train:
span_indices, span_tags, ment_probs = self._make_train_query_batch(batch["query_ments"],
batch["query_probs"], batch["spans"],
batch["seq_len"])
if self.iou_thred is None:
span_weights = None
else:
span_tags, span_weights = self._get_span_weights(batch["spans"], batch["seq_len"], span_indices,
span_tags, thred=self.iou_thred)
else:
span_indices, span_tags, ment_probs = self._make_test_query_batch(batch["query_ments"],
batch["query_probs"], batch["spans"],
batch["seq_len"])
span_weights = None
span_nums = [len(x) for x in span_indices]
max_n_spans = max(span_nums)
span_masks = np.zeros(shape=(len(span_indices), max_n_spans), dtype='int')
new_span_weights = np.ones(shape=span_masks.shape, dtype='float')
for k in range(len(span_indices)):
span_masks[k, :span_nums[k]] = 1
if span_weights is not None:
new_span_weights[k, :span_nums[k]] = span_weights[k]
while len(span_tags[k]) < max_n_spans:
span_indices[k].append([0, 0])
span_tags[k].append(-1)
ment_probs[k].append(-100)
batch["span_indices"] = span_indices
batch["span_mask"] = span_masks
batch["span_tag"] = span_tags
batch["span_probs"] = ment_probs
batch["span_weights"] = new_span_weights
return batch
def batchnize_episode(self, data, mode):
support_sets, query_sets = zip(*data)
if mode == "train":
is_train = True
else:
is_train = False
batch_support = self.batchnize_sent(support_sets, "support", is_train)
batch_query = self.batchnize_sent(query_sets, "query", is_train)
batch_query["label2tag"] = []
for i in range(len(query_sets)):
batch_query["label2tag"].append(query_sets[i]["label2tag"])
return {"support": batch_support, "query": batch_query}
def batchnize_sent(self, data, mode, is_train):
batch = {"index": [], "word": [], "word_mask": [], "word_to_piece_ind": [], "word_to_piece_end": [],
"seq_len": [],
"spans": [], "sentence_num": [], "query_ments": [], "query_probs": [], "subsentence_num": [],
'split_words': []}
for i in range(len(data)):
for k in batch.keys():
if k == 'sentence_num':
batch[k].append(data[i][k])
else:
batch[k] += data[i][k]
max_n_piece = max([sum(x) for x in batch['word_mask']])
max_n_words = max(batch["seq_len"])
word_to_piece_ind = np.zeros(shape=(len(batch["seq_len"]), max_n_words))
word_to_piece_end = np.zeros(shape=(len(batch["seq_len"]), max_n_words))
for k, slen in enumerate(batch['seq_len']):
assert len(batch['word_to_piece_ind'][k]) == slen
assert len(batch['word_to_piece_end'][k]) == slen
word_to_piece_ind[k, :slen] = batch['word_to_piece_ind'][k]
word_to_piece_end[k, :slen] = batch['word_to_piece_end'][k]
batch['word_to_piece_ind'] = word_to_piece_ind
batch['word_to_piece_end'] = word_to_piece_end
if mode == "support":
batch = self.make_support_batch(batch)
else:
batch = self.make_query_batch(batch, is_train)
for k, v in batch.items():
if k not in ['spans', 'sentence_num', 'label2tag', 'index', "query_ments", "query_probs", "subsentence_num",
"split_words"]:
v = np.array(v)
if k == "span_weights":
batch[k] = torch.tensor(v).float()
else:
batch[k] = torch.tensor(v).long()
batch['word'] = batch['word'][:, :max_n_piece]
batch['word_mask'] = batch['word_mask'][:, :max_n_piece]
return batch
def __call__(self, batch, mode):
return self.batchnize_episode(batch, mode)
class QuerySpanNERDataset(data.Dataset):
"""
Fewshot NER Dataset
"""
def __init__(self, filepath, encoder, N, K, Q, max_length, \
bio=True, debug_file=None, query_file=None, hidden_query_label=False, labelname_fn=None):
if not os.path.exists(filepath):
print("[ERROR] Data file does not exist!")
assert (0)
self.class2sampleid = {}
self.N = N
self.K = K
self.Q = Q
self.encoder = encoder
self.cur_t = 0
self.samples, self.classes = self.__load_data_from_file__(filepath, bio)
if labelname_fn:
self.class2name = self.__load_names__(labelname_fn)
else:
self.class2name = None
self.sampler = FewshotSampler(N, K, Q, self.samples, classes=self.classes)
if debug_file:
if query_file is None:
print("use golden mention for typing !!! input_fn: {}".format(filepath))
self.sampler = DebugSampler(debug_file, query_file)
self.max_length = max_length
self.tag2label = None
self.label2tag = None
self.hidden_query_label = hidden_query_label
if self.hidden_query_label:
print("Attention! This dataset hidden query set labels !")
def __insert_sample__(self, index, sample_classes):
for item in sample_classes:
if item in self.class2sampleid:
self.class2sampleid[item].append(index)
else:
self.class2sampleid[item] = [index]
return
def __load_names__(self, label_mp_fn):
class2str = {}
with open(label_mp_fn, mode="r", encoding="utf-8") as fp:
for line in fp:
label, lstr = line.strip("\n").split(":")
class2str[label.strip()] = lstr.strip()
return class2str
def __load_data_from_file__(self, filepath, bio):
print("load input text from {}".format(filepath))
samples = []
classes = []
with open(filepath, 'r', encoding='utf-8')as f:
lines = f.readlines()
samplelines = []
index = 0
for line in lines:
line = line.strip("\n")
if len(line):
samplelines.append(line)
else:
cur_ments = []
sample = QuerySpanSample(index, samplelines, cur_ments, bio)
samples.append(sample)
sample_classes = sample.get_tag_class()
self.__insert_sample__(index, sample_classes)
classes += sample_classes
samplelines = []
index += 1
if len(samplelines):
cur_ments = []
sample = QuerySpanSample(index, samplelines, cur_ments, bio)
samples.append(sample)
sample_classes = sample.get_tag_class()
self.__insert_sample__(index, sample_classes)
classes += sample_classes
classes = list(set(classes))
max_span_len = -1
long_ent_num = 0
tot_ent_num = 0
tot_tok_num = 0
for sample in samples:
max_span_len = max(max_span_len, sample.get_max_ent_len())
long_ent_num += sample.get_num_of_long_ent(10)
tot_ent_num += len(sample.spans)
tot_tok_num += len(sample.words)
print("Sentence num {}, token num {}, entity num {} in file {}".format(len(samples), tot_tok_num, tot_ent_num,
filepath))
print("Total classes {}: {}".format(len(classes), str(classes)))
print("The max golden entity len in the dataset is ", max_span_len)
print("The max golden entity len in the dataset is greater than 10", long_ent_num)
print("The total coverage of spans: {:.5f}".format(1 - long_ent_num / tot_ent_num))
return samples, classes
def __getraw__(self, sample, add_split):
word, mask, word_to_piece_ind, word_to_piece_end, word_shape_inds, seq_lens = self.encoder.tokenize(sample.words, true_token_flags=None)
sent_st_id = 0
split_spans = []
split_querys = []
split_probs = []
split_words = []
cur_wid = 0
for k in range(len(seq_lens)):
split_words.append(sample.words[cur_wid: cur_wid + seq_lens[k]])
cur_wid += seq_lens[k]
for cur_len in seq_lens:
sent_ed_id = sent_st_id + cur_len
split_spans.append([])
for tag, span_st, span_ed in sample.spans:
if (span_st >= sent_ed_id) or (span_ed < sent_st_id): # span totally not in subsent
continue
if (span_st >= sent_st_id) and (span_ed < sent_ed_id): # span totally in subsent
split_spans[-1].append([self.tag2label[tag], span_st - sent_st_id, span_ed - sent_st_id])
elif add_split:
if span_st >= sent_st_id:
split_spans[-1].append([self.tag2label[tag], span_st - sent_st_id, sent_ed_id - 1 - sent_st_id])
else:
split_spans[-1].append([self.tag2label[tag], 0, span_ed - sent_st_id])
split_querys.append([])
split_probs.append([])
for [span_st, span_ed], span_prob in zip(sample.query_ments, sample.query_probs):
if (span_st >= sent_ed_id) or (span_ed < sent_st_id): # span totally not in subsent
continue
if (span_st >= sent_st_id) and (span_ed < sent_ed_id): # span totally in subsent
split_querys[-1].append([span_st - sent_st_id, span_ed - sent_st_id])
split_probs[-1].append(span_prob)
elif add_split:
if span_st >= sent_st_id:
split_querys[-1].append([span_st - sent_st_id, sent_ed_id - 1 - sent_st_id])
split_probs[-1].append(span_prob)
else:
split_querys[-1].append([0, span_ed - sent_st_id])
split_probs[-1].append(span_prob)
sent_st_id += cur_len
item = {"word": word, "word_mask": mask, "word_to_piece_ind": word_to_piece_ind,
"word_to_piece_end": word_to_piece_end, "spans": split_spans, "query_ments": split_querys, "query_probs": split_probs, "seq_len": seq_lens,
"subsentence_num": len(seq_lens), "split_words": split_words}
return item
def __additem__(self, index, d, item):
d['index'].append(index)
d['word'] += item['word']
d['word_mask'] += item['word_mask']
d['seq_len'] += item['seq_len']
d['word_to_piece_ind'] += item['word_to_piece_ind']
d['word_to_piece_end'] += item['word_to_piece_end']
d['spans'] += item['spans']
d['query_ments'] += item['query_ments']
d['query_probs'] += item['query_probs']
d['subsentence_num'].append(item['subsentence_num'])
d['split_words'] += item['split_words']
return
def __populate__(self, idx_list, query_ment_mp=None, savelabeldic=False, add_split=False):
dataset = {'index': [], 'word': [], 'word_mask': [], 'spans': [], 'word_to_piece_ind': [],
"word_to_piece_end": [], "seq_len": [], "query_ments": [], "query_probs": [], "subsentence_num": [],
'split_words': []}
for idx in idx_list:
if query_ment_mp is not None:
self.samples[idx].query_ments = [x[0] for x in query_ment_mp[str(self.samples[idx].index)]]
self.samples[idx].query_probs = [x[1] for x in query_ment_mp[str(self.samples[idx].index)]]
else:
self.samples[idx].query_ments = [[x[1], x[2]] for x in self.samples[idx].spans]
self.samples[idx].query_probs = [1 for x in self.samples[idx].spans]
item = self.__getraw__(self.samples[idx], add_split)
self.__additem__(idx, dataset, item)
dataset['sentence_num'] = len(dataset["seq_len"])
assert len(dataset['word']) == len(dataset['seq_len'])
assert len(dataset['word_to_piece_ind']) == len(dataset['seq_len'])
assert len(dataset['word_to_piece_end']) == len(dataset['seq_len'])
assert len(dataset['spans']) == len(dataset['seq_len'])
if savelabeldic:
dataset['label2tag'] = self.label2tag
return dataset
def __getitem__(self, index):
tmp = self.sampler.__next__(index)
if len(tmp) == 3:
target_classes, support_idx, query_idx = tmp
query_ment_mp = None
else:
target_classes, support_idx, query_idx, query_ment_mp = tmp
target_classes = sorted(target_classes)
distinct_tags = ['O'] + target_classes
self.tag2label = {tag: idx for idx, tag in enumerate(distinct_tags)}
self.label2tag = {idx: tag for idx, tag in enumerate(distinct_tags)}
support_set = self.__populate__(support_idx, add_split=True)
query_set = self.__populate__(query_idx, query_ment_mp=query_ment_mp, add_split=True, savelabeldic=True)
if self.hidden_query_label:
query_set['spans'] = [[] for i in range(query_set['sentence_num'])]
return support_set, query_set
def __len__(self):
return 1000000
def batch_to_device(self, batch, device):
for k, v in batch.items():
if k in ['sentence_num', 'label2tag', 'spans', 'index', 'query_ments', 'query_probs', "subsentence_num", 'split_words']:
continue
batch[k] = v.to(device)
return batch
def get_query_loader(filepath, mode, encoder, N, K, Q, batch_size, max_length,
bio, shuffle, num_workers=8, debug_file=None, query_file=None,
iou_thred=None, hidden_query_label=False, label_fn=None, use_oproto=False):
batcher = QuerySpanBatcher(iou_thred, use_oproto)
dataset = QuerySpanNERDataset(filepath, encoder, N, K, Q, max_length, bio,
debug_file=debug_file, query_file=query_file,
hidden_query_label=hidden_query_label, labelname_fn=label_fn)
dataloader = data.DataLoader(dataset=dataset,
batch_size=batch_size,
shuffle=shuffle,
pin_memory=True,
num_workers=num_workers,
collate_fn=lambda x: batcher(x, mode))
return dataloader

Просмотреть файл

@ -0,0 +1,145 @@
from .fewshotsampler import FewshotSampleBase
def get_class_name(rawtag):
if rawtag.startswith('B-') or rawtag.startswith('I-'):
return rawtag[2:]
else:
return rawtag
def convert_bio2spans(tags, schema):
spans = []
cur_span = []
err_cnt = 0
for i in range(len(tags)):
if schema == 'BIO':
if tags[i].startswith("B-") or tags[i] == 'O':
if len(cur_span):
cur_span.append(i - 1)
spans.append(cur_span)
cur_span = []
if tags[i].startswith("B-"):
cur_span.append(tags[i][2:])
cur_span.append(i)
elif tags[i].startswith("I-"):
if len(cur_span) == 0:
cur_span = [tags[i][2:], i]
err_cnt += 1
if cur_span[0] != tags[i][2:]:
cur_span.append(i - 1)
spans.append(cur_span)
cur_span = [tags[i][2:], i]
err_cnt += 1
assert cur_span[0] == tags[i][2:]
else:
assert tags[i] == "O"
elif schema == 'IO':
if tags[i] == "O":
if len(cur_span):
cur_span.append(i - 1)
spans.append(cur_span)
cur_span = []
elif (i == 0) or (tags[i] != tags[i - 1]):
if len(cur_span):
cur_span.append(i - 1)
spans.append(cur_span)
cur_span = []
cur_span.append(tags[i].strip("I-"))
cur_span.append(i)
else:
assert cur_span[0] == tags[i].strip("I-")
elif schema == "BIOES":
if tags[i] == "O":
if len(cur_span):
cur_span.append(i - 1)
spans.append(cur_span)
cur_span = []
elif tags[i][0] == "S":
if len(cur_span):
cur_span.append(i - 1)
spans.append(cur_span)
cur_span = []
spans.append([tags[i][2:], i, i])
elif tags[i][0] == "E":
if len(cur_span) == 0:
err_cnt += 1
continue
cur_span.append(i)
spans.append(cur_span)
cur_span = []
elif tags[i][0] == "B":
if len(cur_span):
cur_span.append(i - 1)
spans.append(cur_span)
cur_span = []
cur_span = [tags[i][2:], i]
else:
if len(cur_span) == 0:
cur_span = [tags[i][2:], i]
err_cnt += 1
continue
if cur_span[0] != tags[i][2:]:
cur_span.append(i - 1)
spans.append(cur_span)
cur_span = [tags[i][2:], i]
err_cnt += 1
continue
else:
raise ValueError
if len(cur_span):
cur_span.append(len(tags) - 1)
spans.append(cur_span)
return spans
class SpanSample(FewshotSampleBase):
def __init__(self, idx, filelines, bio=True):
super(SpanSample, self).__init__()
self.index = idx
filelines = [line.split('\t') for line in filelines]
if len(filelines[0]) == 2:
self.words, self.tags = zip(*filelines)
else:
self.words, self.postags, self.tags = zip(*filelines)
self.spans = convert_bio2spans(self.tags, "BIO" if bio else "IO")
return
def get_max_ent_len(self):
max_len = -1
for sp in self.spans:
max_len = max(max_len, sp[2] - sp[1] + 1)
return max_len
def get_num_of_long_ent(self, max_span_len):
cnt = 0
for sp in self.spans:
if sp[2] - sp[1] + 1 > max_span_len:
cnt += 1
return cnt
def __count_entities__(self):
self.class_count = {}
for tag, i, j in self.spans:
if tag in self.class_count:
self.class_count[tag] += 1
else:
self.class_count[tag] = 1
return
def get_class_count(self):
if self.class_count:
return self.class_count
else:
self.__count_entities__()
return self.class_count
def get_tag_class(self):
tag_class = list(set(map(lambda x: x[2:] if x[:2] in ['B-', 'I-', 'E-', 'S-'] else x, self.tags)))
if 'O' in tag_class:
tag_class.remove('O')
return tag_class
def valid(self, target_classes):
return (set(self.get_class_count().keys()).intersection(set(target_classes))) and \
not (set(self.get_class_count().keys()).difference(set(target_classes)))
def __str__(self):
return