Implement own BERT score (#473)
* Start adding own BERTScore implementation
* Prepare the basic backbone for own BERTScore
* Make BERTScore work with all_layers=True + init bert_score in torchmetrics
* Fix CUDA device placement
* Use IDF only if asked
* Add data collators
* Remove old parts of code
* Fix IDF rescaling and add new tests + fix type
* Add new tests for IDF
* Add explicit dtypes when using torch.cat to make this work with older PT versions
* Adjust code to work with DDP plus some changes
* Add support for the DDP mode and add the corresponding test
* Add support for verbose mode using a tqdm loader
* Fix a bug with the tokenizer and add hash_code
* Fix transformers import
* Add support for the user's own model
* Add the corresponding example
* Fix error raised by the default tokenizer
* Add support for rescaling with a baseline
* Clean some code + add some docstrings
* Run BERT DDP tests only if torch.distributed.is_available()
* Use a smaller model, 'albert-base-v2', for testing because of OOM issues
* Set join=False for mp_spawn in the DDP test

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
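A minimal usage sketch of the metric introduced by this commit (mirroring the tests and doctests in the diff below; the model checkpoint and the printed scores are illustrative only):

    from torchmetrics.text import BERTScore

    preds = ["hello there", "general kenobi"]
    refs = ["hello there", "master kenobi"]

    # Any `transformers` checkpoint can be passed via `model_name_or_path`;
    # the tests below use "albert-base-v2" to keep memory usage small.
    scorer = BERTScore(model_name_or_path="albert-base-v2", num_layers=8, idf=False, batch_size=2)
    scorer.update(predictions=preds, references=refs)
    print(scorer.compute())  # {'precision': [...], 'recall': [...], 'f1': [...]}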
This commit is contained in:
Parent: e6ad813867
Commit: 7b093813e2
@ -36,6 +36,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Removed `jiwer` as dependency for text package ([#446](https://github.com/PyTorchLightning/metrics/pull/446))

- Removed `bert-score` as dependency for text package ([#473](https://github.com/PyTorchLightning/metrics/pull/473))

### Fixed

- Fixed ranking of samples in `SpearmanCorrCoef` metric ([#448](https://github.com/PyTorchLightning/metrics/pull/448))
@ -13,6 +13,9 @@ include LICENSE
# Include marker file for PEP 561
include torchmetrics/py.typed

# Include examples
recursive-include tm_examples *.py

exclude *.sh
exclude *.toml
exclude *.svg
@ -14,6 +14,7 @@ known_first_party = [
    "torchmetrics",
    "tests",
    "integrations",
    "tm_examples",
]
skip_glob = []
profile = "black"
@ -29,3 +29,5 @@ speechmetrics @ https://github.com/aliutkus/speechmetrics/archive/refs/heads/mas
# text
jiwer>=2.2.0
rouge-score>=0.0.4
bert_score==0.3.10
transformers>=4.0
@ -1,2 +1,2 @@
nltk>=3.6
bert-score==0.3.10
tqdm>=4.41.0
@ -1,12 +1,21 @@
from typing import Any
import os
from typing import Any, Dict, List

import numpy as np
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from torchmetrics.functional import bert_score
from torchmetrics.functional import bert_score as metrics_bert_score
from torchmetrics.text import BERTScore
from torchmetrics.utilities.imports import _BERTSCORE_AVAILABLE

if _BERTSCORE_AVAILABLE:
    from bert_score import score as original_bert_score

os.environ["TOKENIZERS_PARALLELISM"] = "1"

# Examples and expected values taken from:
# https://github.com/Tiiiger/bert_score/blob/master/tests/test_scorer.py
preds = [
@ -25,11 +34,22 @@ refs = [
]

_METRICS = ["precision", "recall", "f1"]

MODEL_NAME = "albert-base-v2"


def _assert_list(preds: Any, refs: Any, threshold: float = 1e-8):
    """Assert two lists are equal."""
    assert np.allclose(preds, refs, atol=threshold, equal_nan=True)


def _parse_original_bert_score(score: torch.Tensor) -> Dict[str, List[float]]:
    """Parse the BERT score returned by the original `bert-score` package."""
    score_dict = {metric: value.tolist() for metric, value in zip(_METRICS, score)}
    return score_dict


preds_batched = [preds[0:2], preds[2:]]
refs_batched = [refs[0:2], refs[2:]]
@ -41,10 +61,137 @@ refs_batched = [refs[0:2], refs[2:]]
|
|||
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
|
||||
def test_score_fn(preds, refs):
|
||||
"""Tests for functional."""
|
||||
Score = bert_score(preds, refs, model_type="roberta-large", num_layers=17, idf=False, batch_size=3)
|
||||
_assert_list(Score["precision"], [0.9843302369117737, 0.9832239747047424, 0.9120386242866516])
|
||||
_assert_list(Score["recall"], [0.9823839068412781, 0.9732863903045654, 0.920428991317749])
|
||||
_assert_list(Score["f1"], [0.9833561182022095, 0.9782299995422363, 0.916214644908905])
|
||||
original_score = original_bert_score(preds, refs, model_type=MODEL_NAME, num_layers=8, idf=False, batch_size=3)
|
||||
original_score = _parse_original_bert_score(original_score)
|
||||
|
||||
metrics_score = metrics_bert_score(
|
||||
preds, refs, model_name_or_path=MODEL_NAME, num_layers=8, idf=False, batch_size=3
|
||||
)
|
||||
|
||||
for metric in _METRICS:
|
||||
_assert_list(metrics_score[metric], original_score[metric])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"preds,refs",
|
||||
[(preds, refs)],
|
||||
)
|
||||
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
|
||||
def test_score_fn_with_idf(preds, refs):
|
||||
"""Tests for functional with IDF rescaling."""
|
||||
original_score = original_bert_score(preds, refs, model_type=MODEL_NAME, num_layers=12, idf=True, batch_size=3)
|
||||
original_score = _parse_original_bert_score(original_score)
|
||||
|
||||
metrics_score = metrics_bert_score(
|
||||
preds, refs, model_name_or_path=MODEL_NAME, num_layers=12, idf=True, batch_size=3
|
||||
)
|
||||
|
||||
for metric in _METRICS:
|
||||
_assert_list(metrics_score[metric], original_score[metric])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"preds,refs",
|
||||
[(preds, refs)],
|
||||
)
|
||||
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
|
||||
def test_score_fn_all_layers(preds, refs):
|
||||
"""Tests for functional and all layers."""
|
||||
original_score = original_bert_score(preds, refs, model_type=MODEL_NAME, all_layers=True, idf=False, batch_size=3)
|
||||
original_score = _parse_original_bert_score(original_score)
|
||||
|
||||
metrics_score = metrics_bert_score(
|
||||
preds, refs, model_name_or_path=MODEL_NAME, all_layers=True, idf=False, batch_size=3
|
||||
)
|
||||
|
||||
for metric in _METRICS:
|
||||
_assert_list(metrics_score[metric], original_score[metric])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"preds,refs",
|
||||
[(preds, refs)],
|
||||
)
|
||||
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
|
||||
def test_score_fn_all_layers_with_idf(preds, refs):
|
||||
"""Tests for functional and all layers with IDF rescaling."""
|
||||
original_score = original_bert_score(preds, refs, model_type=MODEL_NAME, all_layers=True, idf=True, batch_size=3)
|
||||
original_score = _parse_original_bert_score(original_score)
|
||||
|
||||
metrics_score = metrics_bert_score(
|
||||
preds, refs, model_name_or_path=MODEL_NAME, all_layers=True, idf=True, batch_size=3
|
||||
)
|
||||
|
||||
for metric in _METRICS:
|
||||
_assert_list(metrics_score[metric], original_score[metric])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"preds,refs",
|
||||
[(preds, refs)],
|
||||
)
|
||||
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
|
||||
def test_score_fn_all_layers_rescale_with_baseline(preds, refs):
|
||||
"""Tests for functional with baseline rescaling."""
|
||||
original_score = original_bert_score(
|
||||
preds,
|
||||
refs,
|
||||
model_type=MODEL_NAME,
|
||||
lang="en",
|
||||
num_layers=8,
|
||||
idf=False,
|
||||
batch_size=3,
|
||||
rescale_with_baseline=True,
|
||||
)
|
||||
original_score = _parse_original_bert_score(original_score)
|
||||
|
||||
metrics_score = metrics_bert_score(
|
||||
preds,
|
||||
refs,
|
||||
model_name_or_path=MODEL_NAME,
|
||||
lang="en",
|
||||
num_layers=8,
|
||||
idf=False,
|
||||
batch_size=3,
|
||||
rescale_with_baseline=True,
|
||||
)
|
||||
|
||||
for metric in _METRICS:
|
||||
_assert_list(metrics_score[metric], original_score[metric])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"preds,refs",
|
||||
[(preds, refs)],
|
||||
)
|
||||
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
|
||||
def test_score_fn_rescale_with_baseline(preds, refs):
|
||||
"""Tests for functional with baseline rescaling with all layers."""
|
||||
original_score = original_bert_score(
|
||||
preds,
|
||||
refs,
|
||||
model_type=MODEL_NAME,
|
||||
lang="en",
|
||||
all_layers=True,
|
||||
idf=False,
|
||||
batch_size=3,
|
||||
rescale_with_baseline=True,
|
||||
)
|
||||
original_score = _parse_original_bert_score(original_score)
|
||||
|
||||
metrics_score = metrics_bert_score(
|
||||
preds,
|
||||
refs,
|
||||
model_name_or_path=MODEL_NAME,
|
||||
lang="en",
|
||||
all_layers=True,
|
||||
idf=False,
|
||||
batch_size=3,
|
||||
rescale_with_baseline=True,
|
||||
)
|
||||
|
||||
for metric in _METRICS:
|
||||
_assert_list(metrics_score[metric], original_score[metric])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -54,12 +201,69 @@ def test_score_fn(preds, refs):
|
|||
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
|
||||
def test_score(preds, refs):
|
||||
"""Tests for metric."""
|
||||
Scorer = BERTScore(model_type="roberta-large", num_layers=17, idf=False, batch_size=3)
|
||||
original_score = original_bert_score(preds, refs, model_type=MODEL_NAME, num_layers=8, idf=False, batch_size=3)
|
||||
original_score = _parse_original_bert_score(original_score)
|
||||
|
||||
Scorer = BERTScore(model_name_or_path=MODEL_NAME, num_layers=8, idf=False, batch_size=3)
|
||||
Scorer.update(predictions=preds, references=refs)
|
||||
Score = Scorer.compute()
|
||||
_assert_list(Score["precision"], [0.9843302369117737, 0.9832239747047424, 0.9120386242866516])
|
||||
_assert_list(Score["recall"], [0.9823839068412781, 0.9732863903045654, 0.920428991317749])
|
||||
_assert_list(Score["f1"], [0.9833561182022095, 0.9782299995422363, 0.916214644908905])
|
||||
metrics_score = Scorer.compute()
|
||||
|
||||
for metric in _METRICS:
|
||||
_assert_list(metrics_score[metric], original_score[metric])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"preds,refs",
|
||||
[(preds, refs)],
|
||||
)
|
||||
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
|
||||
def test_score_with_idf(preds, refs):
|
||||
"""Tests for metric with IDF rescaling."""
|
||||
original_score = original_bert_score(preds, refs, model_type=MODEL_NAME, num_layers=8, idf=True, batch_size=3)
|
||||
original_score = _parse_original_bert_score(original_score)
|
||||
|
||||
Scorer = BERTScore(model_name_or_path=MODEL_NAME, num_layers=8, idf=True, batch_size=3)
|
||||
Scorer.update(predictions=preds, references=refs)
|
||||
metrics_score = Scorer.compute()
|
||||
|
||||
for metric in _METRICS:
|
||||
_assert_list(metrics_score[metric], original_score[metric])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"preds,refs",
|
||||
[(preds, refs)],
|
||||
)
|
||||
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
|
||||
def test_score_all_layers(preds, refs):
|
||||
"""Tests for metric and all layers."""
|
||||
original_score = original_bert_score(preds, refs, model_type=MODEL_NAME, all_layers=True, idf=False, batch_size=3)
|
||||
original_score = _parse_original_bert_score(original_score)
|
||||
|
||||
Scorer = BERTScore(model_name_or_path=MODEL_NAME, all_layers=True, idf=False, batch_size=3)
|
||||
Scorer.update(predictions=preds, references=refs)
|
||||
metrics_score = Scorer.compute()
|
||||
|
||||
for metric in _METRICS:
|
||||
_assert_list(metrics_score[metric], original_score[metric])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"preds,refs",
|
||||
[(preds, refs)],
|
||||
)
|
||||
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
|
||||
def test_score_all_layers_with_idf(preds, refs):
|
||||
"""Tests for metric and all layers with IDF rescaling."""
|
||||
original_score = original_bert_score(preds, refs, model_type=MODEL_NAME, all_layers=True, idf=True, batch_size=3)
|
||||
original_score = _parse_original_bert_score(original_score)
|
||||
|
||||
Scorer = BERTScore(model_name_or_path=MODEL_NAME, all_layers=True, idf=True, batch_size=3)
|
||||
Scorer.update(predictions=preds, references=refs)
|
||||
metrics_score = Scorer.compute()
|
||||
|
||||
for metric in _METRICS:
|
||||
_assert_list(metrics_score[metric], original_score[metric])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -69,10 +273,46 @@ def test_score(preds, refs):
|
|||
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
|
||||
def test_accumulation(preds, refs):
|
||||
"""Tests for metric works with accumulation."""
|
||||
Scorer = BERTScore(model_type="roberta-large", num_layers=17, idf=False, batch_size=3)
|
||||
original_score = original_bert_score(
|
||||
sum(preds, []), sum(refs, []), model_type=MODEL_NAME, num_layers=8, idf=False, batch_size=3
|
||||
)
|
||||
original_score = _parse_original_bert_score(original_score)
|
||||
|
||||
Scorer = BERTScore(model_name_or_path=MODEL_NAME, num_layers=8, idf=False, batch_size=3)
|
||||
for p, r in zip(preds, refs):
|
||||
Scorer.update(predictions=p, references=r)
|
||||
Score = Scorer.compute()
|
||||
_assert_list(Score["precision"], [0.9843302369117737, 0.9832239747047424, 0.9120386242866516])
|
||||
_assert_list(Score["recall"], [0.9823839068412781, 0.9732863903045654, 0.920428991317749])
|
||||
_assert_list(Score["f1"], [0.9833561182022095, 0.9782299995422363, 0.916214644908905])
|
||||
metrics_score = Scorer.compute()
|
||||
|
||||
for metric in _METRICS:
|
||||
_assert_list(metrics_score[metric], original_score[metric])
|
||||
|
||||
|
||||
def _bert_score_ddp(rank, world_size, preds, refs, original_score):
|
||||
"""Define a DDP process for BERTScore."""
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = "12355"
|
||||
dist.init_process_group("gloo", rank=rank, world_size=world_size)
|
||||
Scorer = BERTScore(model_name_or_path=MODEL_NAME, num_layers=8, idf=False, batch_size=3, max_length=128)
|
||||
Scorer.update(preds, refs)
|
||||
metrics_score = Scorer.compute()
|
||||
for metric in _METRICS:
|
||||
_assert_list(metrics_score[metric], original_score[metric])
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def _test_score_ddp_fn(rank, world_size, preds, refs):
|
||||
"""Core functionality for the `test_score_ddp` test."""
|
||||
original_score = original_bert_score(preds, refs, model_type=MODEL_NAME, num_layers=8, idf=False, batch_size=3)
|
||||
original_score = _parse_original_bert_score(original_score)
|
||||
_bert_score_ddp(rank, world_size, preds, refs, original_score)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"preds,refs",
|
||||
[(preds, refs)],
|
||||
)
|
||||
@pytest.mark.skipif(not (_BERTSCORE_AVAILABLE and dist.is_available()), reason="test requires bert_score")
|
||||
def test_score_ddp(preds, refs):
|
||||
"""Tests for metric using DDP."""
|
||||
world_size = 2
|
||||
mp.spawn(_test_score_ddp_fn, args=(world_size, preds, refs), nprocs=world_size, join=False)
|
||||
|
|
|
@ -0,0 +1,134 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""An example of how to use BERTScore with a user's defined/own model and tokenizer.
|
||||
|
||||
To run: python bert_score-own_model.py
|
||||
"""
|
||||
|
||||
from pprint import pprint
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from torchmetrics import BERTScore
|
||||
|
||||
_NUM_LAYERS = 2
|
||||
_MODEL_DIM = 4
|
||||
_NHEAD = 2
|
||||
_MAX_LEN = 6
|
||||
|
||||
|
||||
class UserTokenizer:
|
||||
"""The `UserTokenizer` class is required to be defined when a non-default model (i.e. not one from
|
||||
`transformers`) is used.
|
||||
|
||||
The user's defined tokenizer is expected to return either token IDs or token embeddings that are fed into the model.
|
||||
The tokenizer vocabulary should contain some special tokens, such as a `<pad>` token so that a tokenization will run
|
||||
successfully in batches.
|
||||
"""
|
||||
|
||||
CLS_TOKEN = "<cls>"
|
||||
SEP_TOKEN = "<sep>"
|
||||
PAD_TOKEN = "<pad>"
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.word2vec = {
|
||||
"hello": 0.5 * torch.ones(1, _MODEL_DIM),
|
||||
"world": -0.5 * torch.ones(1, _MODEL_DIM),
|
||||
self.CLS_TOKEN: torch.zeros(1, _MODEL_DIM),
|
||||
self.SEP_TOKEN: torch.zeros(1, _MODEL_DIM),
|
||||
self.PAD_TOKEN: torch.zeros(1, _MODEL_DIM),
|
||||
}
|
||||
|
||||
def __call__(self, sentences: Union[str, List[str]], max_len: int = _MAX_LEN) -> Dict[str, torch.Tensor]:
|
||||
"""The `__call__` method must be defined for this class. To ensure the functionality, the `__call__` method
|
||||
should obey the input/output arguments structure described below.
|
||||
|
||||
Args:
|
||||
sentences:
|
||||
Input text. `Union[str, List[str]]`
|
||||
max_len:
|
||||
Maximum length of pre-processed text. `int`
|
||||
|
||||
Return:
|
||||
Python dictionary containing the keys `input_ids` and `attention_mask` with corresponding `torch.Tensor`
|
||||
values.
|
||||
"""
|
||||
output_dict: Dict[str, torch.Tensor] = {}
|
||||
if isinstance(sentences, str):
|
||||
sentences = [sentences]
|
||||
# Add special tokens
|
||||
sentences = [" ".join([self.CLS_TOKEN, sentence, self.SEP_TOKEN]) for sentence in sentences]
|
||||
# Tokenize sentences
|
||||
tokenized_sentences = [
|
||||
sentence.lower().split()[:max_len] + [self.PAD_TOKEN] * (max_len - len(sentence.lower().split()))
|
||||
for sentence in sentences
|
||||
]
|
||||
output_dict["input_ids"] = torch.cat(
|
||||
[torch.cat([self.word2vec[word] for word in sentence]).unsqueeze(0) for sentence in tokenized_sentences]
|
||||
)
|
||||
output_dict["attention_mask"] = torch.cat(
|
||||
[
|
||||
torch.tensor([1 if word != self.PAD_TOKEN else 0 for word in sentence]).unsqueeze(0)
|
||||
for sentence in tokenized_sentences
|
||||
]
|
||||
).long()
|
||||
|
||||
return output_dict
|
||||
|
||||
|
||||
def get_user_model_encoder(
|
||||
num_layers: int = _NUM_LAYERS, d_model: int = _MODEL_DIM, nhead: int = _NHEAD
|
||||
) -> torch.nn.Module:
|
||||
"""Initialize the Transformer encoder."""
|
||||
encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)
|
||||
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
|
||||
return transformer_encoder
|
||||
|
||||
|
||||
def user_forward_fn(model: torch.nn.Module, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""User forward function used for the computation of model embeddings.
|
||||
|
||||
This function might be arbitrarily complicated inside. However, to ensure functionality, it should obey the
|
||||
input/output argument structure described below.
|
||||
|
||||
Args:
|
||||
model:
|
||||
`torch.nn.Module`
|
||||
batch:
|
||||
`Dict[str, torch.Tensor]`
|
||||
|
||||
Return:
|
||||
The model output. `torch.Tensor`
|
||||
"""
|
||||
return model(batch["input_ids"])
|
||||
|
||||
|
||||
_PREDS = ["hello", "hello world", "world world world"]
|
||||
_REFS = ["hello", "hello hello", "hello world hello"]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tokenizer = UserTokenizer()
|
||||
model = get_user_model_encoder()
|
||||
|
||||
Scorer = BERTScore(
|
||||
model=model, user_tokenizer=tokenizer, user_forward_fn=user_forward_fn, max_length=_MAX_LEN, return_hash=False
|
||||
)
|
||||
Scorer.update(_PREDS, _REFS)
|
||||
print("Predictions")
|
||||
pprint(Scorer.predictions)
|
||||
|
||||
pprint(Scorer.compute())
|
|
@ -11,105 +11,653 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Dict, List, Optional
|
||||
import csv
|
||||
import math
|
||||
import urllib
|
||||
import warnings
|
||||
from collections import Counter, defaultdict
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
from torchmetrics.utilities.imports import _BERTSCORE_AVAILABLE
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
if _BERTSCORE_AVAILABLE:
|
||||
from bert_score import BERTScorer, get_hash, lang2model, model2layers
|
||||
from torchmetrics.utilities.imports import _TQDM_AVAILABLE, _TRANSFORMERS_AVAILABLE
|
||||
|
||||
if _TRANSFORMERS_AVAILABLE:
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
if _TQDM_AVAILABLE:
|
||||
import tqdm
|
||||
|
||||
|
||||
def _preprocess_text(
|
||||
text: List[str],
|
||||
tokenizer: Any,
|
||||
max_length: int = 512,
|
||||
truncation: bool = True,
|
||||
sort_according_length: bool = True,
|
||||
own_tokenizer: bool = False,
|
||||
) -> Dict[str, Tensor]:
|
||||
"""Default text pre-processing function using `transformers` `AutoTokenizer` instance.
|
||||
|
||||
Args:
|
||||
text:
|
||||
An iterable of sentences.
|
||||
tokenizer:
|
||||
Either `AutoTokenizer` instance from `transformers` package, or a user's own tokenizer.
|
||||
max_length:
|
||||
A maximum sequence length.
|
||||
truncation:
An indication of whether tokenized sequences should be truncated to `max_length`.
|
||||
sort_according_length:
An indication of whether tokenized sequences should be sorted from shortest to longest. This is appropriate
to do for leveraging dynamic padding during embedding calculation and thereby to speed up inference.
|
||||
own_tokenizer:
|
||||
An indication of whether a non-default user's own tokenizer is used.
|
||||
|
||||
Return:
|
||||
A dictionary of tokenized sentences including input_ids and attention_mask.
|
||||
|
||||
Raises:
|
||||
BaseException:
|
||||
If a tokenization with a user's own tokenizer is not successful.
|
||||
"""
|
||||
if not own_tokenizer:
|
||||
tokenized_data = tokenizer(
|
||||
text, padding="max_length", max_length=max_length, truncation=truncation, return_tensors="pt"
|
||||
)
|
||||
else:
|
||||
try:
|
||||
tokenized_data = tokenizer(text, max_length)
|
||||
except BaseException as e:
|
||||
raise BaseException(f"Tokenization was not successful: {e}")
|
||||
|
||||
input_ids, attention_mask = (
|
||||
_sort_data_according_length(tokenized_data["input_ids"], tokenized_data["attention_mask"])
|
||||
if sort_according_length
|
||||
else (tokenized_data["input_ids"], tokenized_data["attention_mask"])
|
||||
)
|
||||
return {"input_ids": input_ids, "attention_mask": attention_mask}
|
||||
|
||||
|
||||
def _process_attention_mask_for_special_tokens(attention_mask: Tensor) -> Tensor:
|
||||
"""Process attention mask to be zero for special [CLS] and [SEP] tokens as they're not included in a
|
||||
calculation for BERT score.
|
||||
|
||||
Args:
|
||||
attention_mask:
|
||||
An attention mask to be returned, for example, by a `transformers` tokenizer.
|
||||
|
||||
Return:
|
||||
A processed attention mask.
|
||||
"""
|
||||
# Make attention_mask zero for [CLS] token
|
||||
attention_mask[:, 0] = 0
|
||||
# Make attention_mask zero for [SEP] token
|
||||
sep_token_position = (attention_mask - 0.1).cumsum(-1).argmax(-1)
|
||||
attention_mask[torch.arange(attention_mask.size(0)).long(), sep_token_position] = 0
|
||||
return attention_mask
|
||||
|
||||
|
||||
def _sort_data_according_length(input_ids: Tensor, attention_mask: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
"""Sort tokenized sentence from the shortest to the longest one."""
|
||||
sorted_indices = attention_mask.sum(1).argsort()
|
||||
input_ids = input_ids[sorted_indices]
|
||||
attention_mask = attention_mask[sorted_indices]
|
||||
return input_ids, attention_mask
|
||||
|
||||
|
||||
def _input_data_collator(
|
||||
batch: Dict[str, Tensor], device: Optional[Union[str, torch.device]] = None
|
||||
) -> Dict[str, Tensor]:
|
||||
"""Helper function that trims model inputs to the longest sequence within the batch and put the input on the
|
||||
proper device."""
|
||||
max_len = int(batch["attention_mask"].sum(1).max().item())
|
||||
input_ids = batch["input_ids"][:, :max_len].to(device)
|
||||
attention_mask = batch["attention_mask"][:, :max_len].to(device)
|
||||
batch.update({"input_ids": input_ids, "attention_mask": attention_mask})
|
||||
return batch
|
||||
|
||||
|
||||
def _output_data_collator(model_output: Tensor, attention_mask: Tensor, target_len: int) -> Tuple[Tensor, Tensor]:
|
||||
"""Helper function that pads the model output and attention mask to the target length."""
|
||||
zeros_shape = list(model_output.shape)
|
||||
zeros_shape[2] = target_len - zeros_shape[2]
|
||||
model_output = torch.cat(
|
||||
[model_output, torch.zeros(zeros_shape, dtype=model_output.dtype).to(model_output.device)], dim=2
|
||||
)
|
||||
attention_mask = torch.cat(
|
||||
[
|
||||
attention_mask,
|
||||
torch.zeros(zeros_shape[0], zeros_shape[2], dtype=attention_mask.dtype).to(attention_mask.device),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
return model_output, attention_mask
|
||||
|
||||
|
||||
class TextDataset(Dataset):
|
||||
"""PyTorch dataset class for storing tokenized sentences and other properties used for BERT score
|
||||
calculation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text: List[str],
|
||||
tokenizer: Any,
|
||||
max_length: int = 512,
|
||||
preprocess_text_fn: Callable[[List[str], Any, int], Dict[str, Tensor]] = _preprocess_text,
|
||||
idf: bool = False,
|
||||
tokens_idf: Optional[Dict[int, float]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
text:
|
||||
An iterable of sentences.
|
||||
tokenizer:
|
||||
`AutoTokenizer` instance from `transformers` package.
|
||||
max_length:
|
||||
A maximum sequence length.
|
||||
preprocess_text_fn:
|
||||
A function used for processing the input sentences.
|
||||
idf:
|
||||
An indication of whether to calculate token inverse document frequencies to weight the model embeddings.
|
||||
tokens_idf:
|
||||
Inverse document frequencies (these should be calculated on reference sentences).
|
||||
|
||||
"""
|
||||
self.text = preprocess_text_fn(text, tokenizer, max_length)
|
||||
self.max_length = self.text["input_ids"].shape[1]
|
||||
self.num_sentences = len(text)
|
||||
self.idf = idf
|
||||
self.tokens_idf = {}
|
||||
if idf:
|
||||
self.tokens_idf = tokens_idf if tokens_idf is not None else self._get_tokens_idf()
|
||||
|
||||
def __getitem__(self, idx: int) -> Dict[str, Tensor]:
|
||||
input_ids = self.text["input_ids"][idx, :]
|
||||
attention_mask = self.text["attention_mask"][idx, :]
|
||||
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
|
||||
if self.idf:
|
||||
input_ids_idf = torch.tensor([self.tokens_idf[input_idx] for input_idx in input_ids.tolist()])
|
||||
inputs_dict["input_ids_idf"] = input_ids_idf
|
||||
return inputs_dict
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.num_sentences
|
||||
|
||||
def _get_tokens_idf(self) -> Dict[int, float]:
|
||||
"""Calculate token inverse document frequences.
|
||||
|
||||
Return:
|
||||
A python dictionary containing inverse document frequencies for token ids.
|
||||
"""
|
||||
token_counter: Counter = Counter()
|
||||
for tokens in map(self._set_of_tokens, self.text["input_ids"]):
|
||||
token_counter.update(tokens)
|
||||
|
||||
tokens_idf: Dict[int, float] = defaultdict(self._get_tokens_idf_default_value)
|
||||
tokens_idf.update(
|
||||
{idx: math.log((self.num_sentences + 1) / (occurrence + 1)) for idx, occurrence in token_counter.items()}
|
||||
)
|
||||
return tokens_idf
|
||||
|
||||
def _get_tokens_idf_default_value(self) -> float:
|
||||
"""Helper function that ensures `defaultdict` to be pickled."""
|
||||
return math.log((self.num_sentences + 1) / 1)
|
||||
|
||||
@staticmethod
|
||||
def _set_of_tokens(input_ids: Tensor) -> Set:
|
||||
"""Return set of tokens from the `input_ids` `torch.Tensor`."""
|
||||
return set(input_ids.tolist())
|
||||
|
||||
|
||||
class TokenizedDataset(TextDataset):
|
||||
"""The child class of `TextDataset` class used with already tokenized data."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_ids: Tensor,
|
||||
attention_mask: Tensor,
|
||||
idf: bool = False,
|
||||
tokens_idf: Optional[Dict[int, float]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
input_ids:
|
||||
Input ids (`torch.Tensor`).
|
||||
attention_mask:
|
||||
Attention mask (`torch.Tensor`).
|
||||
idf:
|
||||
An indication of whether to calculate token inverse document frequencies to weight the model embeddings.
|
||||
tokens_idf:
|
||||
Inverse document frequencies (these should be calculated on reference sentences).
|
||||
"""
|
||||
self.text = dict(zip(["input_ids", "attention_mask"], _sort_data_according_length(input_ids, attention_mask)))
|
||||
self.text = _input_data_collator(self.text)
|
||||
self.num_sentences = len(self.text["input_ids"])
|
||||
self.max_length = self.text["input_ids"].shape[1]
|
||||
self.idf = idf
|
||||
self.tokens_idf = {}
|
||||
if idf:
|
||||
self.tokens_idf = tokens_idf if tokens_idf is not None else self._get_tokens_idf()
|
||||
|
||||
|
||||
def _get_progress_bar(dataloader: DataLoader, verbose: bool = False) -> Union[DataLoader, tqdm.auto.tqdm]:
|
||||
"""Helper function returning either the dataloader itself when `verbose = False`, or it wraps the dataloader with
|
||||
`tqdm.auto.tqdm`, when `verbose = True` to display a progress bar during the embbeddings calculation."""
|
||||
return tqdm.auto.tqdm(dataloader) if verbose else dataloader
|
||||
|
||||
|
||||
def _check_shape_of_model_output(output: Tensor, input_ids: Tensor) -> None:
|
||||
"""Check if the shape of the user's own model output."""
|
||||
bs, seq_len = input_ids.shape[:2]
|
||||
invalid_out_shape = len(output.shape) != 3 or output.shape[0] != bs or output.shape[1] != seq_len
|
||||
if invalid_out_shape:
|
||||
raise ValueError(
|
||||
"The model output must be `torch.Tensor` of a shape `[batch_size, seq_len, model_dim]` "
|
||||
f"i.e. [{bs}, {seq_len}. , `model_dim`], but got {output.shape}."
|
||||
)
|
||||
|
||||
|
||||
def _get_embeddings_and_idf_scale(
|
||||
dataloader: DataLoader,
|
||||
target_len: int,
|
||||
model: torch.nn.Module,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
num_layers: Optional[int] = None,
|
||||
all_layers: bool = False,
|
||||
idf: bool = False,
|
||||
verbose: bool = False,
|
||||
user_forward_fn: Callable[[torch.nn.Module, Dict[str, Tensor]], Tensor] = None,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""Calculate sentence embeddings and the inverse-document-frequence scaling factor.
|
||||
Args:
|
||||
dataloader:
|
||||
`torch.utils.data.DataLoader` instance.
|
||||
target_len:
|
||||
A length of the longest sequence in the data. Used for padding the model output.
|
||||
model:
|
||||
BERT model.
|
||||
device:
|
||||
A device to be used for calculation.
|
||||
num_layers:
|
||||
The layer of representation to use.
|
||||
all_layers:
|
||||
An indication whether representation from all model layers should be used for BERTScore.
|
||||
idf:
|
||||
An indication of whether normalization using inverse document frequencies should be used.
|
||||
verbose:
|
||||
An indication of whether a progress bar should be displayed during the embeddings calculation.
|
||||
user_forward_fn:
|
||||
A user's own forward function used in combination with `user_model`. This function must take `user_model`
and a python dictionary containing `"input_ids"` and `"attention_mask"` represented by `torch.Tensor`
as an input and return the model's output represented by a single `torch.Tensor`.
|
||||
|
||||
Return:
|
||||
A tuple of torch.Tensors containing the model's embeddings and the normalized tokens IDF.
|
||||
When `idf = False`, tokens IDF is not calculated, and a matrix of mean weights is returned instead.
|
||||
For a single sentence, `mean_weight = 1/seq_len`, where `seq_len` is a sum over the corresponding
|
||||
`attention_mask`.
|
||||
|
||||
Raises:
|
||||
ValueError:
|
||||
If `all_layers = True` and a model, which is not from the `transformers` package, is used.
|
||||
"""
|
||||
embeddings_list: List[Tensor] = []
|
||||
idf_scale_list: List[Tensor] = []
|
||||
for batch in _get_progress_bar(dataloader, verbose):
|
||||
with torch.no_grad():
|
||||
batch = _input_data_collator(batch, device)
|
||||
# Output shape: batch_size x num_layers OR 1 x sequence_length x bert_dim
|
||||
if not all_layers:
|
||||
if not user_forward_fn:
|
||||
out = model(batch["input_ids"], batch["attention_mask"], output_hidden_states=True)
|
||||
out = out.hidden_states[num_layers if num_layers is not None else -1]
|
||||
else:
|
||||
out = user_forward_fn(model, batch)
|
||||
_check_shape_of_model_output(out, batch["input_ids"])
|
||||
out = out.unsqueeze(1)
|
||||
else:
|
||||
if user_forward_fn:
|
||||
raise ValueError(
|
||||
"The option `all_layers=True` can be used only with default `transformers` models."
|
||||
)
|
||||
out = model(batch["input_ids"], batch["attention_mask"], output_hidden_states=True)
|
||||
out = torch.cat([o.unsqueeze(1) for o in out.hidden_states], dim=1)
|
||||
|
||||
out /= out.norm(dim=-1).unsqueeze(-1) # normalize embeddings
|
||||
out, attention_mask = _output_data_collator(out, batch["attention_mask"], target_len)
|
||||
processed_attention_mask = _process_attention_mask_for_special_tokens(attention_mask)
|
||||
# Multiply embeddings with attention_mask (b=batch_size, l=num_layers, s=seq_len, d=emb_dim)
|
||||
out = torch.einsum("blsd, bs -> blsd", out, processed_attention_mask)
|
||||
embeddings_list.append(out.cpu())
|
||||
|
||||
# Calculate weighted (w.r.t. sentence length) input_ids IDF matrix
|
||||
input_ids_idf = (
|
||||
batch["input_ids_idf"] * processed_attention_mask if idf else processed_attention_mask.type(out.dtype)
|
||||
)
|
||||
input_ids_idf /= input_ids_idf.sum(-1, keepdim=True)
|
||||
idf_scale_list.append(input_ids_idf)
|
||||
|
||||
embeddings = torch.cat(embeddings_list)
|
||||
idf_scale = torch.cat(idf_scale_list)
|
||||
|
||||
return embeddings, idf_scale
|
||||
|
||||
|
||||
def _get_scaled_precision_or_recall(cos_sim: Tensor, metric: str, idf_scale: Tensor) -> Tensor:
|
||||
"""Helper function that calculates precision or recall, transpose it and scale it with idf_scale factor."""
|
||||
dim = 3 if metric == "precision" else 2
|
||||
res = cos_sim.max(dim=dim).values
|
||||
res = torch.einsum("bls, bs -> bls", res, idf_scale).sum(-1)
|
||||
# We transpose the results and squeeze if possible to match the format of the original BERTScore implementation
|
||||
res = res.transpose(0, 1).squeeze()
|
||||
return res
|
||||
|
||||
|
||||
def _get_precision_recall_f1(
|
||||
pred_embeddings: Tensor, ref_embeddings: Tensor, pred_idf_scale: Tensor, ref_idf_scale: Tensor
|
||||
) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
"""Calculate precision, recall and F1 score over candidate and reference sentences.
|
||||
|
||||
Args:
|
||||
pred_embeddings:
|
||||
Embeddings of candidate sentences.
|
||||
ref_embeddings:
|
||||
Embeddings of reference sentences.
|
||||
pred_idf_scale:
|
||||
An IDF scale factor for candidate sentences.
|
||||
ref_idf_scale:
|
||||
An IDF scale factor for reference sentences.
|
||||
|
||||
Return:
|
||||
Tensors containing precision, recall and F1 score, respectively.
|
||||
"""
|
||||
# Dimensions: b = batch_size, l = num_layers, p = predictions_seq_len, r = references_seq_len, d = bert_dim
|
||||
cos_sim = torch.einsum("blpd, blrd -> blpr", pred_embeddings, ref_embeddings)
|
||||
# Final metrics shape = (batch_size * num_layers | batch_size)
|
||||
precision = _get_scaled_precision_or_recall(cos_sim, "precision", pred_idf_scale)
|
||||
recall = _get_scaled_precision_or_recall(cos_sim, "recall", ref_idf_scale)
|
||||
|
||||
f1_score = 2 * precision * recall / (precision + recall)
|
||||
f1_score = f1_score.masked_fill(torch.isnan(f1_score), 0.0)
|
||||
|
||||
return precision, recall, f1_score
|
||||
|
||||
|
||||
def _get_hash(model_name_or_path: Optional[str] = None, num_layers: Optional[int] = None, idf: bool = False) -> str:
|
||||
"""Copied from https://github.com/Tiiiger/bert_score/blob/master/bert_score/utils.py and adjusted."""
|
||||
msg = f"{model_name_or_path}_L{num_layers}{'_idf' if idf else '_no-idf'}"
|
||||
return msg
|
||||
|
||||
|
||||
def _read_csv_from_local_file(baseline_path: str) -> Tensor:
|
||||
"""Helper function which reads baseline the csv file from the local file.
|
||||
|
||||
This method implemented to avoid `pandas` dependency.
|
||||
"""
|
||||
with open(baseline_path) as fname:
|
||||
csv_file = csv.reader(fname)
|
||||
baseline_list = [[float(item) for item in row] for idx, row in enumerate(csv_file) if idx > 0]
|
||||
baseline = torch.tensor(baseline_list)[:, 1:]
|
||||
return baseline
|
||||
|
||||
|
||||
def _read_csv_from_url(baseline_url: str) -> Tensor:
|
||||
"""Helper function which reads the baseline csv file from URL.
|
||||
|
||||
This method is implemented to avoid `pandas` dependency.
|
||||
"""
|
||||
with urllib.request.urlopen(baseline_url) as http_request: # type: ignore
|
||||
baseline_list = [
|
||||
[float(item) for item in row.strip().decode("utf-8").split(",")]
|
||||
for idx, row in enumerate(http_request)
|
||||
if idx > 0
|
||||
]
|
||||
baseline = torch.tensor(baseline_list)[:, 1:]
|
||||
return baseline
|
||||
|
||||
|
||||
def _load_baseline(
|
||||
lang: str = "en",
|
||||
model_name_or_path: Optional[str] = None,
|
||||
baseline_path: Optional[str] = None,
|
||||
baseline_url: Optional[str] = None,
|
||||
) -> Optional[Tensor]:
|
||||
"""Load a CSV file with the baseline values used for rescaling."""
|
||||
if baseline_path:
|
||||
baseline: Optional[Tensor] = _read_csv_from_local_file(baseline_path)
|
||||
elif baseline_url:
|
||||
baseline = _read_csv_from_url(baseline_url)
|
||||
# Read default baseline from the original `bert-score` package https://github.com/Tiiiger/bert_score
|
||||
elif lang and model_name_or_path:
|
||||
_URL_BASE = "https://raw.githubusercontent.com/Tiiiger/bert_score/master/bert_score/rescale_baseline"
|
||||
baseline_url = f"{_URL_BASE}/{lang}/{model_name_or_path}.tsv"
|
||||
baseline = _read_csv_from_url(baseline_url)
|
||||
else:
|
||||
baseline = None
|
||||
warnings.warn("Baseline was not successfully loaded. No baseline is going to be used.")
|
||||
|
||||
return baseline
|
||||
|
||||
|
||||
def _rescale_metrics_with_baseline(
|
||||
precision: Tensor,
|
||||
recall: Tensor,
|
||||
f1_score: Tensor,
|
||||
baseline: Tensor,
|
||||
num_layers: Optional[int] = None,
|
||||
all_layers: bool = False,
|
||||
) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
"""Rescale the computed metrics with the pre-computed baseline."""
|
||||
if num_layers is None and all_layers is False:
|
||||
num_layers = -1
|
||||
all_metrics = torch.stack([precision, recall, f1_score], dim=-1)
|
||||
baseline_scale = baseline.unsqueeze(1) if all_layers else baseline[num_layers]
|
||||
all_metrics = (all_metrics - baseline_scale) / (1 - baseline_scale)
|
||||
|
||||
return all_metrics[..., 0], all_metrics[..., 1], all_metrics[..., 2]
|
||||
|
||||
|
||||
def bert_score(
|
||||
predictions: List[str],
|
||||
references: List[str],
|
||||
lang: str = "en",
|
||||
model_type: Optional[str] = None,
|
||||
num_layers: int = None,
|
||||
predictions: Union[List[str], Dict[str, Tensor]],
|
||||
references: Union[List[str], Dict[str, Tensor]],
|
||||
model_name_or_path: Optional[str] = None,
|
||||
num_layers: Optional[int] = None,
|
||||
all_layers: bool = False,
|
||||
model: Optional[torch.nn.Module] = None,
|
||||
user_tokenizer: Any = None,
|
||||
user_forward_fn: Callable[[torch.nn.Module, Dict[str, Tensor]], Tensor] = None,
|
||||
verbose: bool = False,
|
||||
idf: bool = False,
|
||||
device: Optional[str] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
max_length: int = 512,
|
||||
batch_size: int = 64,
|
||||
num_threads: int = 4,
|
||||
all_layers: bool = False,
|
||||
return_hash: bool = False,
|
||||
lang: str = "en",
|
||||
rescale_with_baseline: bool = False,
|
||||
baseline_path: Optional[str] = None,
|
||||
) -> Dict:
|
||||
baseline_url: Optional[str] = None,
|
||||
) -> Dict[str, Union[List[float], str]]:
|
||||
"""`BERTScore <https://arxiv.org/abs/1904.09675>`_ leverages the pre-trained contextual embeddings from BERT
|
||||
and matches words in candidate and reference sentences by cosine similarity. It has been shown to correlate
|
||||
with human judgment on sentence-level and system-level evaluation. Moreover, BERTScore computes precision,
|
||||
recall, and F1 measure, which can be useful for evaluating different language generation tasks.
|
||||
recall, and F1 measure, which can be useful for evaluating different language generation tasks. The number of
predicted and reference sentences must be the same.
|
||||
|
||||
This implementation follows the original implementation from https://github.com/Tiiiger/bert_score.
|
||||
|
||||
Args:
|
||||
predictions: candidate sentences
|
||||
references: reference sentences
|
||||
model_type: bert specification
|
||||
num_layers: the layer of representation to use.
|
||||
verbose: turn on intermediate status update
|
||||
idf: use idf weighting, can also be a precomputed idf_dict
|
||||
device: on which the contextual embedding model will be allocated on.
|
||||
num_threads: number of threads
|
||||
batch_size: bert score processing batch size
|
||||
lang: language of the sentences
|
||||
rescale_with_baseline: rescale bertscore with pre-computed baseline
|
||||
baseline_path: customized baseline file
|
||||
predictions:
|
||||
Either an iterable of predicted sentences or a `Dict[str, torch.Tensor]` containing `input_ids` and
|
||||
`attention_mask` `torch.Tensor`.
|
||||
references:
|
||||
Either an iterable of target sentences or a `Dict[str, torch.Tensor]` containing `input_ids` and
|
||||
`attention_mask` `torch.Tensor`.
|
||||
model_name_or_path:
|
||||
A name or a model path used to load `transformers` pretrained model.
|
||||
num_layers:
|
||||
A layer of representation to use.
|
||||
all_layers:
|
||||
An indication of whether the representation from all model's layers should be used.
|
||||
If `all_layers = True`, the argument `num_layers` is ignored.
|
||||
model:
|
||||
A user's own model. Must be of `torch.nn.Module` instance.
|
||||
user_tokenizer:
|
||||
A user's own tokenizer used with the own model. This must be an instance with the `__call__` method.
|
||||
This method must take an iterable of sentences (`List[str]`) and must return a python dictionary
|
||||
containing `"input_ids"` and `"attention_mask"` represented by `torch.Tensor`. It is up to the user's model
|
||||
of whether `"input_ids"` is a `torch.Tensor` of input ids or embedding vectors.
|
||||
This tokenizer must prepend an equivalent of `[CLS]` token and append an equivalent of `[SEP]` token
|
||||
as `transformers` tokenizer does.
|
||||
user_forward_fn:
|
||||
A user's own forward function used in combination with `user_model`. This function must take `user_model`
and a python dictionary containing `"input_ids"` and `"attention_mask"` represented by `torch.Tensor`
as an input and return the model's output represented by a single `torch.Tensor`.
|
||||
verbose:
|
||||
An indication of whether a progress bar should be displayed during the embeddings calculation.
|
||||
idf:
|
||||
An indication of whether normalization using inverse document frequencies should be used.
|
||||
device:
|
||||
A device to be used for calculation.
|
||||
max_length:
|
||||
A maximum length of input sequences. Sequences longer than `max_length` are to be trimmed.
|
||||
batch_size:
|
||||
A batch size used for model processing.
|
||||
num_threads:
|
||||
A number of threads to use for a dataloader.
|
||||
return_hash:
|
||||
An indication of whether the corresponding `hash_code` should be returned.
|
||||
lang:
|
||||
A language of input sentences. It is used when the scores are rescaled with a baseline.
|
||||
rescale_with_baseline:
|
||||
An indication of whether bertscore should be rescaled with a pre-computed baseline.
|
||||
When a pretrained model from `transformers` is used, the corresponding baseline is downloaded
|
||||
from the original `bert-score` package from https://github.com/Tiiiger/bert_score if available.
|
||||
In other cases, please specify a path to the baseline csv/tsv file, which must follow the formatting
|
||||
of the files from https://github.com/Tiiiger/bert_score.
|
||||
baseline_path:
|
||||
A path to the user's own local csv/tsv file with the baseline scale.
|
||||
baseline_url:
|
||||
A url path to the user's own csv/tsv file with the baseline scale.
|
||||
|
||||
Returns:
|
||||
Dict containing the keys `precision`, `recall`, `f1` and `hashcode` with corresponding values
|
||||
Python dictionary containing the keys `precision`, `recall` and `f1` with corresponding values.
|
||||
|
||||
Raises:
|
||||
ValueError:
|
||||
If `len(predictions) != len(references)`.
|
||||
ValueError:
|
||||
If `tqdm` package is required and not installed.
|
||||
ValueError:
|
||||
If `transformers` package is required and not installed.
|
||||
ValueError:
|
||||
If `num_layer` is larger than the number of the model layers.
|
||||
ValueError:
|
||||
If invalid input is provided.
|
||||
|
||||
Example:
|
||||
>>> predictions = ["hello there", "general kenobi"]
|
||||
>>> references = ["hello there", "master kenobi"]
|
||||
>>> bert_score(predictions=predictions, references=references, lang="en") # doctest: +SKIP
|
||||
{'f1': [0.99..., 0.99...],
|
||||
'hashcode': '...',
|
||||
'precision': [0.99..., 0.99...],
|
||||
'recall': [0.99..., 0.99...]}
|
||||
{'precision': [0.99..., 0.99...],
|
||||
'recall': [0.99..., 0.99...],
|
||||
'f1': [0.99..., 0.99...]}
|
||||
"""
|
||||
if len(predictions) != len(references):
|
||||
raise ValueError("Number of predicted and reference sententes must be the same!")
|
||||
|
||||
if not _BERTSCORE_AVAILABLE:
|
||||
if verbose and (not _TQDM_AVAILABLE):
|
||||
raise ValueError(
|
||||
"bert_score metric requires that bert-score package is installed."
|
||||
" Either install with `pip install bert-score` or `pip install torchmetrics[text]`"
|
||||
"An argument `verbose = True` requires `tqdm` package be installed. Install with `pip install tqdm`."
|
||||
)
|
||||
|
||||
if model_type is None:
|
||||
model_type = lang2model[lang.lower()]
|
||||
if model is None:
|
||||
if not _TRANSFORMERS_AVAILABLE:
|
||||
raise ValueError(
|
||||
"`bert_score` metric with default models requires `transformers` package be installed. "
|
||||
"Either install with `pip install transformers>=4.0` or `pip install torchmetrics[text]`"
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
|
||||
model = AutoModel.from_pretrained(model_name_or_path)
|
||||
else:
|
||||
tokenizer = user_tokenizer
|
||||
model.eval()
|
||||
model.to(device)
|
||||
|
||||
if num_layers is None:
|
||||
num_layers = model2layers[model_type]
|
||||
try:
|
||||
if num_layers and num_layers > model.config.num_hidden_layers: # type: ignore
|
||||
raise ValueError(
|
||||
f"num_layers={num_layers} is forbidden for {model_name_or_path}. " # type: ignore
|
||||
f"Please use num_layers <= {model.config.num_hidden_layers}" # type: ignore
|
||||
)
|
||||
except AttributeError:
|
||||
warnings.warn("It was not possible to retrieve the parameter `num_layers` from the model specification.")
|
||||
|
||||
hashcode = get_hash(
|
||||
model=model_type,
|
||||
num_layers=num_layers,
|
||||
idf=idf,
|
||||
rescale_with_baseline=rescale_with_baseline,
|
||||
use_custom_baseline=baseline_path is not None,
|
||||
use_fast_tokenizer=True,
|
||||
_are_empty_lists = all(isinstance(text, list) and len(text) == 0 for text in (predictions, references))
|
||||
_are_valid_lists = all(
|
||||
isinstance(text, list) and len(text) > 0 and isinstance(text[0], str) for text in (predictions, references)
|
||||
)
|
||||
_are_valid_tensors = all(
|
||||
isinstance(text, dict) and isinstance(text["input_ids"], Tensor) for text in (predictions, references)
|
||||
)
|
||||
if _are_empty_lists:
|
||||
warnings.warn("Predictions and references are empty.")
|
||||
output_dict: Dict[str, Union[List[float], str]] = {
|
||||
"precision": [0.0],
|
||||
"recall": [0.0],
|
||||
"f1": [0.0],
|
||||
}
|
||||
if return_hash:
|
||||
output_dict.update({"hash": _get_hash(model_name_or_path, num_layers, idf)})
|
||||
return output_dict
|
||||
|
||||
# Load baselines if needed
|
||||
baseline = _load_baseline(lang, model_name_or_path, baseline_path, baseline_url) if rescale_with_baseline else None
|
||||
|
||||
# We ignore mypy typing below as the proper typing is ensured by conditions above, only mypy cannot infer that.
|
||||
if _are_valid_lists:
|
||||
ref_dataset = TextDataset(references, tokenizer, max_length, idf=idf) # type: ignore
|
||||
pred_dataset = TextDataset(
|
||||
predictions, # type: ignore
|
||||
tokenizer,
|
||||
max_length,
|
||||
idf=idf,
|
||||
tokens_idf=ref_dataset.tokens_idf,
|
||||
)
|
||||
elif _are_valid_tensors:
|
||||
ref_dataset = TokenizedDataset(**references, idf=idf) # type: ignore
|
||||
pred_dataset = TokenizedDataset(**predictions, idf=idf, tokens_idf=ref_dataset.tokens_idf) # type: ignore
|
||||
else:
|
||||
raise ValueError("Invalid input provided.")
|
||||
|
||||
ref_loader = DataLoader(ref_dataset, batch_size=batch_size, num_workers=num_threads)
|
||||
pred_loader = DataLoader(pred_dataset, batch_size=batch_size, num_workers=num_threads)
|
||||
|
||||
ref_embeddings, ref_idf_scale = _get_embeddings_and_idf_scale(
|
||||
ref_loader, ref_dataset.max_length, model, device, num_layers, all_layers, idf, verbose, user_forward_fn
|
||||
)
|
||||
pred_embeddings, pred_idf_scale = _get_embeddings_and_idf_scale(
|
||||
pred_loader, pred_dataset.max_length, model, device, num_layers, all_layers, idf, verbose, user_forward_fn
|
||||
)
|
||||
|
||||
cached_bertscorer = BERTScorer(
|
||||
model_type=model_type,
|
||||
num_layers=num_layers,
|
||||
batch_size=batch_size,
|
||||
nthreads=num_threads,
|
||||
all_layers=all_layers,
|
||||
idf=idf,
|
||||
device=device,
|
||||
lang=lang,
|
||||
rescale_with_baseline=rescale_with_baseline,
|
||||
baseline_path=baseline_path,
|
||||
precision, recall, f1_score = _get_precision_recall_f1(
|
||||
pred_embeddings, ref_embeddings, pred_idf_scale, ref_idf_scale
|
||||
)
|
||||
|
||||
prec, recall, f1 = cached_bertscorer.score(
|
||||
cands=predictions,
|
||||
refs=references,
|
||||
verbose=verbose,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
if baseline is not None:
|
||||
precision, recall, f1_score = _rescale_metrics_with_baseline(
|
||||
precision, recall, f1_score, baseline, num_layers, all_layers
|
||||
)
|
||||
|
||||
output_dict = {
|
||||
"precision": prec.tolist(),
|
||||
"precision": precision.tolist(),
|
||||
"recall": recall.tolist(),
|
||||
"f1": f1.tolist(),
|
||||
"hashcode": hashcode,
|
||||
"f1": f1_score.tolist(),
|
||||
}
|
||||
if return_hash:
|
||||
output_dict.update({"hash": _get_hash(model_name_or_path, num_layers, idf)})
|
||||
return output_dict
|
||||
|
|
|
@ -11,15 +11,30 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
import warnings
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from torchmetrics.functional import bert_score
|
||||
from torchmetrics.functional.text.bert import _preprocess_text
|
||||
from torchmetrics.metric import Metric
|
||||
from torchmetrics.utilities.imports import _TRANSFORMERS_AVAILABLE
|
||||
|
||||
if _TRANSFORMERS_AVAILABLE:
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
def _flatten(x: List[List[str]]) -> List[str]:
|
||||
"""converts list of list to single list of strings."""
|
||||
return [e for y in x for e in y]
|
||||
# Default model recommended in the original implementation.
|
||||
_DEFAULT_MODEL = "roberta-large"
|
||||
|
||||
|
||||
def _concatenate(d: Dict[str, List[torch.Tensor]]) -> Dict[str, torch.Tensor]:
|
||||
"""Concatenate list of tensors within a given dictionary."""
|
||||
output_dict: Dict[str, torch.Tensor] = {}
|
||||
for k, v in d.items():
|
||||
output_dict[k] = torch.cat(v)
|
||||
return output_dict
|
||||
|
||||
|
||||
class BERTScore(Metric):
|
||||
|
@ -28,19 +43,59 @@ class BERTScore(Metric):
|
|||
with human judgment on sentence-level and system-level evaluation. Moreover, BERTScore computes precision,
|
||||
recall, and F1 measure, which can be useful for evaluating different language generation tasks.
|
||||
|
||||
This implementation follows the original implementation from https://github.com/Tiiiger/bert_score.
|
||||
|
||||
Args:
|
||||
predictions: candidate sentences
|
||||
references: reference sentences
|
||||
model_type: bert specification
|
||||
num_layers: the layer of representation to use.
|
||||
verbose: turn on intermediate status update
|
||||
idf: use idf weighting, can also be a precomputed idf_dict
|
||||
device: on which the contextual embedding model will be allocated on.
|
||||
num_threads: number of threads
|
||||
batch_size: bert score processing batch size
|
||||
lang: language of the sentences
|
||||
rescale_with_baseline: rescale bertscore with pre-computed baseline
|
||||
baseline_path: customized baseline file
|
||||
predictions:
|
||||
An iterable of predicted sentences.
|
||||
references:
|
||||
An iterable of target sentences.
|
||||
model_type:
|
||||
A name or a model path used to load `transformers` pretrained model.
|
||||
num_layers:
|
||||
A layer of representation to use.
|
||||
all_layers:
|
||||
An indication of whether the representation from all model's layers should be used.
|
||||
If `all_layers = True`, the argument `num_layers` is ignored.
|
||||
model:
|
||||
A user's own model. Must be of `torch.nn.Module` instance.
|
||||
user_tokenizer:
|
||||
A user's own tokenizer used with the own model. This must be an instance with the `__call__` method.
|
||||
This method must take an iterable of sentences (`List[str]`) and must return a python dictionary
|
||||
containing `"input_ids"` and `"attention_mask"` represented by `torch.Tensor`. It is up to the user's model
|
||||
of whether `"input_ids"` is a `torch.Tensor` of input ids or embedding vectors.
|
||||
This tokenizer must prepend an equivalent of `[CLS]` token and append an equivalent of `[SEP]` token
|
||||
as `transformers` tokenizer does.
|
||||
user_forward_fn:
|
||||
A user's own forward function used in combination with `user_model`. This function must take `user_model`
and a python dictionary containing `"input_ids"` and `"attention_mask"` represented by `torch.Tensor`
as an input and return the model's output represented by a single `torch.Tensor`.
|
||||
verbose:
|
||||
An indication of whether a progress bar should be displayed during the embeddings calculation.
|
||||
idf:
|
||||
An indication of whether normalization using inverse document frequencies should be used.
|
||||
device:
|
||||
A device to be used for calculation.
|
||||
max_length:
|
||||
A maximum length of input sequences. Sequences longer than `max_length` are to be trimmed.
|
||||
batch_size:
|
||||
A batch size used for model processing.
|
||||
num_threads:
|
||||
A number of threads to use for a dataloader.
|
||||
return_hash:
|
||||
An indication of whether the correspodning `hash_code` should be returned.
|
||||
lang:
|
||||
A language of input sentences.
|
||||
rescale_with_baseline:
|
||||
An indication of whether bertscore should be rescaled with a pre-computed baseline.
|
||||
When a pretrained model from `transformers` model is used, the corresponding baseline is downloaded
|
||||
from the original `bert-score` package from https://github.com/Tiiiger/bert_score if available.
|
||||
In other cases, please specify a path to the baseline csv/tsv file, which must follow the formatting
|
||||
of the files from https://github.com/Tiiiger/bert_score.
|
||||
baseline_path:
|
||||
A path to the user's own local csv/tsv file with the baseline scale.
|
||||
baseline_url:
|
||||
A url path to the user's own csv/tsv file with the baseline scale.
|
||||
compute_on_step:
|
||||
Forward only calls ``update()`` and return None if this is set to False. default: True
|
||||
dist_sync_on_step:
|
||||
|
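The `user_tokenizer` / `user_forward_fn` contract described above is easiest to see in code. Below is a minimal, hypothetical sketch (not part of this commit; `ToyTokenizer`, `ToyEmbedder` and the toy vocabulary are invented for illustration) of a pair that should satisfy the documented requirements:

import torch
from typing import Dict, List


class ToyTokenizer:
    """Hypothetical tokenizer: maps whitespace-split words to ids and adds [CLS]/[SEP] equivalents."""

    def __init__(self, vocab: Dict[str, int], cls_id: int = 101, sep_id: int = 102, pad_id: int = 0) -> None:
        self.vocab, self.cls_id, self.sep_id, self.pad_id = vocab, cls_id, sep_id, pad_id

    def __call__(self, sentences: List[str]) -> Dict[str, torch.Tensor]:
        encoded = [[self.cls_id] + [self.vocab.get(w, 1) for w in s.split()] + [self.sep_id] for s in sentences]
        max_len = max(len(ids) for ids in encoded)
        input_ids = torch.full((len(encoded), max_len), self.pad_id, dtype=torch.long)
        attention_mask = torch.zeros(len(encoded), max_len, dtype=torch.long)
        for i, ids in enumerate(encoded):
            input_ids[i, : len(ids)] = torch.tensor(ids)
            attention_mask[i, : len(ids)] = 1
        return {"input_ids": input_ids, "attention_mask": attention_mask}


class ToyEmbedder(torch.nn.Module):
    """Hypothetical model producing one embedding vector per token."""

    def __init__(self, vocab_size: int = 200, dim: int = 8) -> None:
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, dim)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # Zero out the embeddings of padding positions.
        return self.embedding(input_ids) * attention_mask.unsqueeze(-1)


def toy_forward_fn(model: torch.nn.Module, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Takes the model and a tokenized batch and returns the model output as a single tensor."""
    return model(batch["input_ids"], batch["attention_mask"])

Presumably, such a pair would then be passed as `BERTScore(model=ToyEmbedder(), user_tokenizer=ToyTokenizer(vocab), user_forward_fn=toy_forward_fn)`, with `vocab` mapping words to ids.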

@@ -53,7 +108,7 @@ class BERTScore(Metric):
            will be used to perform the allgather

    Returns:
        Dict containing the keys `precision`, `recall`, `f1` and `hashcode` with corresponding values
        Python dictionary containing the keys `precision`, `recall` and `f1` with corresponding values.

    Example:
        >>> predictions = ["hello there", "general kenobi"]

@@ -61,25 +116,30 @@ class BERTScore(Metric):
        >>> bertscore = BERTScore()
        >>> bertscore.update(predictions=predictions, references=references)
        >>> bertscore.compute()  # doctest: +SKIP
        {'f1': [0.99..., 0.99...],
         'hashcode': '...',
         'precision': [0.99..., 0.99...],
         'recall': [0.99..., 0.99...]}
        {'precision': [0.99..., 0.99...],
         'recall': [0.99..., 0.99...],
         'f1': [0.99..., 0.99...]}
    """
    def __init__(
        self,
        model_type: Optional[str] = None,
        lang: str = "en",
        num_layers: int = None,
        model_name_or_path: Optional[str] = None,
        num_layers: Optional[int] = None,
        all_layers: bool = False,
        model: Optional[torch.nn.Module] = None,
        user_tokenizer: Optional[Any] = None,
        user_forward_fn: Callable[[torch.nn.Module, Dict[str, torch.Tensor]], torch.Tensor] = None,
        verbose: bool = False,
        idf: bool = False,
        device: Optional[str] = None,
        device: Optional[Union[str, torch.device]] = None,
        max_length: int = 512,
        batch_size: int = 64,
        num_threads: int = 4,
        all_layers: bool = False,
        return_hash: bool = False,
        lang: str = "en",
        rescale_with_baseline: bool = False,
        baseline_path: Optional[str] = None,
        baseline_url: Optional[str] = None,
        compute_on_step: bool = True,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,

@@ -91,47 +151,99 @@ class BERTScore(Metric):
            process_group=process_group,
            dist_sync_fn=dist_sync_fn,
        )
        self.baseline_path = baseline_path
        self.rescale_with_baseline = rescale_with_baseline
        self.lang = lang
        self.all_layers = all_layers
        self.num_threads = num_threads
        self.batch_size = batch_size
        self.embedding_device = device
        self.idf = idf
        self.verbose = verbose
        self.model_name_or_path = model_name_or_path
        self.num_layers = num_layers
        self.model_type = model_type
        self.add_state("predictions", [], dist_reduce_fx="cat")
        self.add_state("references", [], dist_reduce_fx="cat")
        self.all_layers = all_layers
        self.model = model
        self.user_forward_fn = user_forward_fn
        self.verbose = verbose
        self.idf = idf
        self.embedding_device = device
        self.max_length = max_length
        self.batch_size = batch_size
        self.num_threads = num_threads
        self.return_hash = return_hash
        self.lang = lang
        self.rescale_with_baseline = rescale_with_baseline
        self.baseline_path = baseline_path
        self.baseline_url = baseline_url
        self.predictions: Dict[str, List[torch.Tensor]] = {"input_ids": [], "attention_mask": []}
        self.references: Dict[str, List[torch.Tensor]] = {"input_ids": [], "attention_mask": []}

        if user_tokenizer:
            self.tokenizer = user_tokenizer
            self.user_tokenizer = True
        else:
            if not _TRANSFORMERS_AVAILABLE:
                raise ValueError(
                    "`BERTScore` metric with default tokenizers requires the `transformers` package to be installed. "
                    "Either install with `pip install transformers>=4.0` or `pip install torchmetrics[text]`."
                )
            if not model_name_or_path:
                model_name_or_path = _DEFAULT_MODEL
                warnings.warn(
                    "The argument `model_name_or_path` was not specified while it is required when the default "
                    "`transformers` model is used. "
                    f"The default recommended model - {_DEFAULT_MODEL} - will be used instead."
                )
            self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
            self.user_tokenizer = False

    def update(self, predictions: List[str], references: List[str]) -> None:  # type: ignore
        """Store predictions/references for computing BERT scores.
        """Store predictions/references for computing BERT scores. It is necessary to store sentences in a
        tokenized form to ensure that the DDP mode works.

        Args:
            predictions: List of predicted sentences
            references: List of references
            predictions:
                An iterable of predicted sentences.
            references:
                An iterable of target sentences.
        """
        self.predictions.append(predictions)
        self.references.append(references)
        predictions_dict = _preprocess_text(
            predictions,
            self.tokenizer,
            self.max_length,
            truncation=False,
            sort_according_length=False,
            own_tokenizer=self.user_tokenizer,
        )
        references_dict = _preprocess_text(
            references,
            self.tokenizer,
            self.max_length,
            truncation=False,
            sort_according_length=False,
            own_tokenizer=self.user_tokenizer,
        )
        self.predictions["input_ids"].append(predictions_dict["input_ids"])
        self.predictions["attention_mask"].append(predictions_dict["attention_mask"])
        self.references["input_ids"].append(references_dict["input_ids"])
        self.references["attention_mask"].append(references_dict["attention_mask"])

    def compute(self) -> Dict:
        """Calculate Bertscores.
    def compute(self) -> Dict[str, Union[List[float], str]]:
        """Calculate BERT scores.

        Return:
            Dict with Bertscores.
            Python dictionary containing the keys `precision`, `recall` and `f1` with corresponding values.
        """
        return bert_score(
            predictions=_flatten(self.predictions),
            references=_flatten(self.references),
            model_type=self.model_type,
            predictions=_concatenate(self.predictions),
            references=_concatenate(self.references),
            model_name_or_path=self.model_name_or_path,
            num_layers=self.num_layers,
            all_layers=self.all_layers,
            model=self.model,
            user_tokenizer=self.tokenizer if self.user_tokenizer else None,
            user_forward_fn=self.user_forward_fn,
            verbose=self.verbose,
            idf=self.idf,
            device=self.embedding_device,
            baseline_path=self.baseline_path,
            max_length=self.max_length,
            batch_size=self.batch_size,
            lang=self.lang,
            all_layers=self.all_layers,
            num_threads=self.num_threads,
            return_hash=self.return_hash,
            lang=self.lang,
            rescale_with_baseline=self.rescale_with_baseline,
            baseline_path=self.baseline_path,
            baseline_url=self.baseline_url,
        )
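A usage sketch (not part of the commit): since `update` only tokenizes and stores each batch while `compute` forwards everything to the functional `bert_score`, the metric can accumulate any number of batches before the embedding model is run. The model name below is only an example (the commit's tests use `albert-base-v2`), and the exact scores depend on the environment:

from torchmetrics.text import BERTScore

prediction_batches = [["hello there", "general kenobi"], ["the cat sat on the mat"]]
reference_batches = [["hello there", "master kenobi"], ["a cat sat on the mat"]]

bertscore = BERTScore(model_name_or_path="albert-base-v2", idf=True, batch_size=2)
for preds, refs in zip(prediction_batches, reference_batches):
    bertscore.update(predictions=preds, references=refs)  # tokenized and stored, no model call yet

scores = bertscore.compute()  # dict with 'precision', 'recall' and 'f1' lists, one entry per sentence pair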

@@ -80,3 +80,5 @@ _BERTSCORE_AVAILABLE: bool = _module_available("bert_score")
_SCIPY_AVAILABLE: bool = _module_available("scipy")
_TORCH_FIDELITY_AVAILABLE: bool = _module_available("torch_fidelity")
_LPIPS_AVAILABLE: bool = _module_available("lpips")
_TQDM_AVAILABLE: bool = _module_available("tqdm")
_TRANSFORMERS_AVAILABLE: bool = _module_available("transformers")