metrics/tm_examples/bert_score-own_model.py

135 строки
4.6 KiB
Python

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An example of how to use BERTScore with a user's defined/own model and tokenizer.
To run: python bert_score-own_model.py
"""
from pprint import pprint
from typing import Dict, List, Union
import torch
import torch.nn as nn
from torchmetrics import BERTScore
_NUM_LAYERS = 2
_MODEL_DIM = 4
_NHEAD = 2
_MAX_LEN = 6
class UserTokenizer:
"""The `UserTokenizer` class is required to be defined when a non-default model (i.e. not one from
`transformers`) is used.
The user's defined tokenizer is expected to return either token IDs or token embeddings that are fed into the model.
The tokenizer vocabulary should contain some special tokens, such as a `<pad>` token so that a tokenization will run
successfully in batches.
"""
CLS_TOKEN = "<cls>"
SEP_TOKEN = "<sep>"
PAD_TOKEN = "<pad>"
def __init__(self) -> None:
self.word2vec = {
"hello": 0.5 * torch.ones(1, _MODEL_DIM),
"world": -0.5 * torch.ones(1, _MODEL_DIM),
self.CLS_TOKEN: torch.zeros(1, _MODEL_DIM),
self.SEP_TOKEN: torch.zeros(1, _MODEL_DIM),
self.PAD_TOKEN: torch.zeros(1, _MODEL_DIM),
}
def __call__(self, sentences: Union[str, List[str]], max_len: int = _MAX_LEN) -> Dict[str, torch.Tensor]:
"""The `__call__` method must be defined for this class. To ensure the functionality, the `__call__` method
should obey the input/output arguments structure described below.
Args:
sentences:
Input text. `Union[str, List[str]]`
max_len:
Maximum length of pre-processed text. `int`
Return:
Python dictionary containing the keys `input_ids` and `attention_mask` with corresponding `torch.Tensor`
values.
"""
output_dict: Dict[str, torch.Tensor] = {}
if isinstance(sentences, str):
sentences = [sentences]
# Add special tokens
sentences = [" ".join([self.CLS_TOKEN, sentence, self.SEP_TOKEN]) for sentence in sentences]
# Tokennize sentence
tokenized_sentences = [
sentence.lower().split()[:max_len] + [self.PAD_TOKEN] * (max_len - len(sentence.lower().split()))
for sentence in sentences
]
output_dict["input_ids"] = torch.cat(
[torch.cat([self.word2vec[word] for word in sentence]).unsqueeze(0) for sentence in tokenized_sentences]
)
output_dict["attention_mask"] = torch.cat(
[
torch.tensor([1 if word != self.PAD_TOKEN else 0 for word in sentence]).unsqueeze(0)
for sentence in tokenized_sentences
]
).long()
return output_dict
def get_user_model_encoder(
num_layers: int = _NUM_LAYERS, d_model: int = _MODEL_DIM, nhead: int = _NHEAD
) -> torch.nn.Module:
"""Initialize the Transformer encoder."""
encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
return transformer_encoder
def user_forward_fn(model: torch.nn.Module, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
"""User forward function used for the computation of model embeddings.
This function might be arbitrarily complicated inside. However, to ensure functionality, it should obey the
input/output argument structure described below.
Args:
model:
`torch.nn.Module`
batch:
`Dict[str, torch.Tensor]`
Return:
The model output. `torch.Tensor`
"""
return model(batch["input_ids"])
_PREDS = ["hello", "hello world", "world world world"]
_REFS = ["hello", "hello hello", "hello world hello"]
if __name__ == "__main__":
tokenizer = UserTokenizer()
model = get_user_model_encoder()
Scorer = BERTScore(
model=model, user_tokenizer=tokenizer, user_forward_fn=user_forward_fn, max_length=_MAX_LEN, return_hash=False
)
Scorer.update(_PREDS, _REFS)
print("Predictions")
pprint(Scorer.predictions)
pprint(Scorer.compute())