Parent: 9c8dbf68db
Commit: be89a1b731
@@ -44,7 +44,7 @@ help you or finish it with you :]_
 Want to keep Torchmetrics healthy? Love seeing those green tests? So do we! How to we keep it that way?
 We write tests! We value tests contribution even more than new features. One of the core values of torchmetrics
-is that our users can trust our metric implementation. We can only garantee this if our metrics are well tested.
+is that our users can trust our metric implementation. We can only guarantee this if our metrics are well tested.

 ---

@@ -373,7 +373,7 @@ def linkcode_resolve(domain, info):
         return None
     try:
         filename = "%s#L%d-L%d" % find_source()
-    except Exception:
+    except Exception: # todo: specify the exception
         filename = info["module"].replace(".", "/") + ".py"
     # import subprocess
     # tag = subprocess.Popen(['git', 'rev-parse', 'HEAD'], stdout=subprocess.PIPE,

@@ -1,5 +1,4 @@
-.. PyTorchtorchmetrics documentation master file, created by
-   sphinx-quickstart on Wed Mar 25 21:34:07 2020.
+.. TorchMetrics documentation master file.
    You can adapt this file completely to your liking, but it should at least
    contain the root `toctree` directive.

@@ -3,7 +3,7 @@ TorchMetrics is a collection of Machine learning metrics for distributed, scalab
 * Optimized for distributed-training
 * A standardized interface to increase reproducibility
 * Reduces Boilerplate
-* Distrubuted-training compatible
+* Distributed-training compatible
 * Rigorously tested
 * Automatic accumulation over batches
 * Automatic synchronization between multiple devices

@@ -131,7 +131,7 @@ and tests gets formatted in the following way:

 4. Remember to add binding to the different relevant ``__init__`` files.

-5. Testing is key to keeping ``torchmetrics`` trustworty. This is why we have a very rigid testing protocol. This means
+5. Testing is key to keeping ``torchmetrics`` trustworthy. This is why we have a very rigid testing protocol. This means
    that we in most cases require the metric to be tested against some other common framework (``sklearn``, ``scipy`` etc).

    1. Create a testing file in ``tests/"domain"/test_"new_metric".py``. Only one file is needed as it is intended to test

@@ -148,7 +148,7 @@ This pattern is implemented for the following operators (with ``a`` being metric
 * Inequality (``a != b``)
 * Bitwise OR (``a | b``)
 * Power (``a ** b``)
-* Substraction (``a - b``)
+* Subtraction (``a - b``)
 * True Division (``a / b``)
 * Bitwise XOR (``a ^ b``)
 * Absolute Value (``abs(a)``)

@@ -1 +1,2 @@
+numpy
 torch>=1.3.1

@@ -54,7 +54,7 @@ def test_add_state():
     with pytest.raises(ValueError):
         a.add_state("d4", 42, 'sum')

-    def custom_fx(x):
+    def custom_fx(_):
         return -1

     a.add_state("e", tensor(0), custom_fx)

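The `custom_fx(x)` → `custom_fx(_)` change above keeps the one-argument signature that a reduction callback needs while signalling that the value is intentionally ignored; `_` is Python's conventional name for an unused parameter. A minimal sketch of the convention (the callback name is illustrative, not from the repository):

    def constant_reduce(_):
        # reduction callback that ignores whatever state it receives and always returns -1
        return -1

    print(constant_reduce([1, 2, 3]))  # -1

Later hunks use `# todo: ... is unused` comments instead, where the parameter must keep its name because callers pass it by keyword.
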
@@ -31,6 +31,7 @@ torch.manual_seed(42)


 def _sk_auroc_binary_prob(preds, target, num_classes, average='macro', max_fpr=None, multi_class='ovr'):
+    # todo: `multi_class` is unused
     sk_preds = preds.view(-1).numpy()
     sk_target = target.view(-1).numpy()
     return sk_roc_auc_score(y_true=sk_target, y_score=sk_preds, average=average, max_fpr=max_fpr)

@@ -35,6 +35,7 @@ torch.manual_seed(42)


 def _sk_fbeta_binary_prob(preds, target, average='micro', beta=1.0):
+    # todo: `average` is unused
     sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8)
     sk_target = target.view(-1).numpy()

@@ -42,6 +43,7 @@ def _sk_fbeta_binary_prob(preds, target, average='micro', beta=1.0):


 def _sk_fbeta_binary(preds, target, average='micro', beta=1.0):
+    # todo: `average` is unused
     sk_preds = preds.view(-1).numpy()
     sk_target = target.view(-1).numpy()

@@ -143,12 +143,12 @@ class TestIoU(MetricTester):
     pytest.param(True, 'none', 0, Tensor([0.5, 0.5])),
 ])
 def test_iou(half_ones, reduction, ignore_index, expected):
-    pred = (torch.arange(120) % 3).view(-1, 1)
+    preds = (torch.arange(120) % 3).view(-1, 1)
     target = (torch.arange(120) % 3).view(-1, 1)
     if half_ones:
-        pred[:60] = 1
+        preds[:60] = 1
     iou_val = iou(
-        pred=pred,
+        preds=preds,
         target=target,
         ignore_index=ignore_index,
         reduction=reduction,

@@ -191,7 +191,7 @@ def test_iou(half_ones, reduction, ignore_index, expected):
 )
 def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes, expected):
     iou_val = iou(
-        pred=tensor(pred),
+        preds=tensor(pred),
         target=tensor(target),
         ignore_index=ignore_index,
         absent_score=absent_score,

@@ -221,7 +221,7 @@ def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes,
 )
 def test_iou_ignore_index(pred, target, ignore_index, num_classes, reduction, expected):
     iou_val = iou(
-        pred=tensor(pred),
+        preds=tensor(pred),
         target=tensor(target),
         ignore_index=ignore_index,
         num_classes=num_classes,

@@ -36,6 +36,7 @@ torch.manual_seed(42)


 def _sk_prec_recall(preds, target, sk_fn, num_classes, average, is_multiclass, ignore_index, mdmc_average=None):
+    # todo: `mdmc_average` is unused
     if average == "none":
         average = None
     if num_classes == 1:

@@ -211,6 +212,7 @@ class TestPrecisionRecall(MetricTester):
         mdmc_average: Optional[str],
         ignore_index: Optional[int],
     ):
+        # todo: `metric_fn` is unused
         if num_classes == 1 and average != "micro":
             pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)")

@@ -261,6 +263,7 @@ class TestPrecisionRecall(MetricTester):
         mdmc_average: Optional[str],
         ignore_index: Optional[int],
     ):
+        # todo: `metric_class` is unused
         if num_classes == 1 and average != "micro":
             pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)")

@@ -345,8 +348,7 @@ def test_top_k(
 ):
     """A simple test to check that top_k works as expected.

-    Just a sanity check, the tests in StatScores should already guarantee
-    the corectness of results.
+    Just a sanity check, the tests in StatScores should already guarantee the correctness of results.
     """

     class_metric = metric_class(top_k=k, average=average, num_classes=3)

@@ -35,6 +35,7 @@ torch.manual_seed(42)


 def _sk_stat_scores(preds, target, reduce, num_classes, is_multiclass, ignore_index, top_k, mdmc_reduce=None):
+    # todo: `mdmc_reduce` is unused
     preds, target, _ = _input_format_classification(
         preds, target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass, top_k=top_k
     )

@@ -59,13 +59,13 @@ def test_to_categorical():
     assert torch.allclose(result, expected.to(result.dtype))


-@pytest.mark.parametrize(['pred', 'target', 'num_classes', 'expected_num_classes'], [
+@pytest.mark.parametrize(['preds', 'target', 'num_classes', 'expected_num_classes'], [
     pytest.param(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), 10, 10),
     pytest.param(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), None, 10),
     pytest.param(torch.rand(32, 28, 28), torch.randint(10, (32, 28, 28)), None, 10),
 ])
-def test_get_num_classes(pred, target, num_classes, expected_num_classes):
-    assert get_num_classes(pred, target, num_classes) == expected_num_classes
+def test_get_num_classes(preds, target, num_classes, expected_num_classes):
+    assert get_num_classes(preds, target, num_classes) == expected_num_classes


 @pytest.mark.parametrize(['sample_weight', 'pos_label', "exp_shape"], [

@@ -66,16 +66,7 @@ def test_multi_batch_image_gradients():
         [5., 5., 5., 5., 5.],
         [0., 0., 0., 0., 0.],
     ]
-
-    true_dx = [
-        [1., 1., 1., 1., 0.],
-        [1., 1., 1., 1., 0.],
-        [1., 1., 1., 1., 0.],
-        [1., 1., 1., 1., 0.],
-        [1., 1., 1., 1., 0.],
-    ]
     true_dy = Tensor(true_dy)
-    true_dx = Tensor(true_dx)

     dy, dx = image_gradients(image)

@@ -39,7 +39,7 @@ THRESHOLD = 0.5


 def setup_ddp(rank, world_size):
-    """ Setup ddp enviroment """
+    """ Setup ddp environment """
     os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = "8088"

@@ -81,7 +81,7 @@ def _class_test(
     metric_class: Metric,
     sk_metric: Callable,
     dist_sync_on_step: bool,
-    metric_args: dict = {},
+    metric_args: dict = None,
     check_dist_sync_on_step: bool = True,
     check_batch: bool = True,
     atol: float = 1e-8,

@@ -104,6 +104,8 @@ def _class_test(
         check_batch: bool, if true will check if the metric is also correctly
             calculated across devices for each batch (and not just at the end)
     """
+    if not metric_args:
+        metric_args = {}
     # Instanciate lightning metric
     metric = metric_class(compute_on_step=True, dist_sync_on_step=dist_sync_on_step, **metric_args)

@@ -145,7 +147,7 @@ def _functional_test(
     target: Tensor,
     metric_functional: Callable,
     sk_metric: Callable,
-    metric_args: dict = {},
+    metric_args: dict = None,
     atol: float = 1e-8,
 ):
     """Utility function doing the actual comparison between lightning functional metric

@@ -158,6 +160,8 @@ def _functional_test(
         sk_metric: callable function that is used for comparison
         metric_args: dict with additional arguments used for class initialization
     """
+    if not metric_args:
+        metric_args = {}
     metric = partial(metric_functional, **metric_args)

     for i in range(NUM_BATCHES):

@@ -200,7 +204,7 @@ class MetricTester:
         target: Tensor,
         metric_functional: Callable,
         sk_metric: Callable,
-        metric_args: dict = {},
+        metric_args: dict = None,
     ):
         """Main method that should be used for testing functions. Call this inside
         testing method

@@ -229,7 +233,7 @@ class MetricTester:
         metric_class: Metric,
         sk_metric: Callable,
         dist_sync_on_step: bool,
-        metric_args: dict = {},
+        metric_args: dict = None,
         check_dist_sync_on_step: bool = True,
         check_batch: bool = True,
     ):

@@ -250,6 +254,8 @@ class MetricTester:
             check_batch: bool, if true will check if the metric is also correctly
                 calculated across devices for each batch (and not just at the end)
         """
+        if not metric_args:
+            metric_args = {}
         if ddp:
             if sys.platform == "win32":
                 pytest.skip("DDP not supported on windows")

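The `metric_args: dict = {}` → `metric_args: dict = None` changes in the hunks above, together with the added `if not metric_args: metric_args = {}` guards, avoid Python's mutable-default pitfall: a default `{}` is created once at function definition time and shared by every call that relies on the default. A minimal sketch of the pitfall and the fix, with illustrative function names (not from the repository):

    def buggy_collect(item, bucket=[]):
        # the default list is created once and reused across calls
        bucket.append(item)
        return bucket

    def safe_collect(item, bucket=None):
        # a fresh list is created for each call that does not pass one in
        if bucket is None:
            bucket = []
        bucket.append(item)
        return bucket

    print(buggy_collect(1), buggy_collect(2))  # [1, 2] [1, 2] -- state leaks between calls
    print(safe_collect(1), safe_collect(2))    # [1] [2]

The guards in the diff use `if not metric_args:` rather than `if metric_args is None:`, which also replaces an explicitly passed empty dict; for these test helpers the two behave the same.
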
@@ -68,7 +68,8 @@ class BoringModel(LightningModule):
     def forward(self, x):
         return self.layer(x)

-    def loss(self, batch, prediction):
+    @staticmethod
+    def loss(_, prediction):
         # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
         return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))

@@ -75,6 +75,7 @@ class TestMeanError(MetricTester):
     def test_mean_error_class(
         self, preds, target, sk_metric, metric_class, metric_functional, sk_fn, ddp, dist_sync_on_step
     ):
+        # todo: `metric_functional` is unused
         self.run_class_metric_test(
             ddp=ddp,
             preds=preds,

@@ -85,6 +86,7 @@ class TestMeanError(MetricTester):
         )

     def test_mean_error_functional(self, preds, target, sk_metric, metric_class, metric_functional, sk_fn):
+        # todo: `metric_class` is unused
         self.run_functional_metric_test(
             preds=preds,
             target=target,

@@ -82,6 +82,7 @@ class TestR2Score(MetricTester):
         )

     def test_r2_functional(self, adjusted, multioutput, preds, target, sk_metric, num_outputs):
+        # todo: `num_outputs` is unused
         self.run_functional_metric_test(
             preds,
             target,

@@ -102,14 +103,14 @@ def test_error_on_multidim_tensors(metric_class=R2Score):
     with pytest.raises(
         ValueError,
         match=r'Expected both prediction and target to be 1D or 2D tensors,'
-        r' but recevied tensors with dimension .'
+        r' but received tensors with dimension .'
     ):
         metric(torch.randn(10, 20, 5), torch.randn(10, 20, 5))


 def test_error_on_too_few_samples(metric_class=R2Score):
     metric = metric_class()
-    with pytest.raises(ValueError, match='Needs atleast two samples to calculate r2 score.'):
+    with pytest.raises(ValueError, match='Needs at least two samples to calculate r2 score.'):
         metric(torch.randn(1, ), torch.randn(1, ))


@@ -118,7 +119,7 @@ def test_warning_on_too_large_adjusted(metric_class=R2Score):

     with pytest.warns(
         UserWarning,
-        match="More independent regressions than datapoints in"
+        match="More independent regressions than data points in"
         " adjusted r2 score. Falls back to standard r2 score."
     ):
         metric(torch.randn(10, ), torch.randn(10, ))

@@ -31,7 +31,7 @@ class AUC(Metric):
     Args:
         reorder: AUC expects its first input to be sorted. If this is not the case,
             setting this argument to ``True`` will use a stable sorting algorithm to
-            sort the input in decending order
+            sort the input in descending order
         compute_on_step:
             Forward only calls ``update()`` and return None if this is set to False.
         dist_sync_on_step:

@@ -40,8 +40,8 @@ class AUC(Metric):
         process_group:
             Specify the process group on which synchronization is called. default: None (which selects the entire world)
         dist_sync_fn:
-            Callback that performs the allgather operation on the metric state. When ``None``, DDP
-            will be used to perform the allgather
+            Callback that performs the ``allgather`` operation on the metric state. When ``None``, DDP
+            will be used to perform the ``allgather``.
     """

     def __init__(

@@ -118,7 +118,7 @@ class AUROC(Metric):
         )

         if self.max_fpr is not None:
-            if (not isinstance(max_fpr, float) and 0 < max_fpr <= 1):
+            if not isinstance(max_fpr, float) or not 0 < max_fpr <= 1:
                 raise ValueError(f"`max_fpr` should be a float in range (0, 1], got: {max_fpr}")

             if _TORCH_LOWER_1_6:

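The `max_fpr` guard above is a logic fix, not just a style change: the old condition `not isinstance(max_fpr, float) and 0 < max_fpr <= 1` only raised when the value was simultaneously not a float and inside (0, 1], so most bad inputs slipped through, while the new `or not` form (the De Morgan complement of 'is a float in (0, 1]') rejects everything else. A standalone sketch of the corrected check; the helper name is hypothetical, not the library API:

    def check_max_fpr(max_fpr):
        # accept None (feature disabled) or a float strictly above 0 and at most 1
        if max_fpr is None:
            return
        if not isinstance(max_fpr, float) or not 0 < max_fpr <= 1:
            raise ValueError(f"`max_fpr` should be a float in range (0, 1], got: {max_fpr}")

    check_max_fpr(0.5)    # ok
    check_max_fpr(None)   # ok
    # check_max_fpr(2.0)  # raises ValueError

Note that the functional `_auroc_compute` hunk further down only drops the redundant parentheses and keeps `and`, so after this commit the two code paths still differ.
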
@@ -135,7 +135,7 @@ class ROC(Metric):
                 tensor with true positive rates.
                 If multiclass, this is a list of such tensors, one for each class.
             thresholds:
-                thresholds used for computing false- and true postive rates
+                thresholds used for computing false- and true positive rates

         """
         preds = torch.cat(self.preds, dim=0)

@@ -137,7 +137,8 @@ class MetricCollection(nn.ModuleDict):
     def _set_prefix(self, k: str) -> str:
         return k if self.prefix is None else self.prefix + k

-    def _check_prefix_arg(self, prefix: str) -> Optional[str]:
+    @staticmethod
+    def _check_prefix_arg(prefix: str) -> Optional[str]:
         if prefix is not None:
             if isinstance(prefix, str):
                 return prefix

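`_check_prefix_arg` never touches `self`, so it becomes a `@staticmethod` and the instance parameter is dropped; `BoringModel.loss` earlier in the diff gets the same treatment. A small illustration of the pattern (class and attribute names are made up for the example):

    class Collection:
        def __init__(self, prefix=None):
            # the validator needs no instance state, so it can live on the class
            self.prefix = self._check_prefix_arg(prefix)

        @staticmethod
        def _check_prefix_arg(prefix):
            if prefix is not None and not isinstance(prefix, str):
                raise ValueError("Expected `prefix` to be a string or None")
            return prefix

    print(Collection(prefix="val_").prefix)  # val_
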
@@ -29,6 +29,7 @@ def _accuracy_update(
 ) -> Tuple[Tensor, Tensor]:

     preds, target, mode = _input_format_classification(preds, target, threshold=threshold, top_k=top_k)
+    correct, total = None, None

     if mode == DataType.MULTILABEL and top_k:
         raise ValueError("You can not use the `top_k` parameter to calculate accuracy for multi-label inputs.")

@@ -55,7 +55,7 @@ def _auroc_compute(

     # check max_fpr parameter
     if max_fpr is not None:
-        if (not isinstance(max_fpr, float) and 0 < max_fpr <= 1):
+        if not isinstance(max_fpr, float) and 0 < max_fpr <= 1:
             raise ValueError(f"`max_fpr` should be a float in range (0, 1], got: {max_fpr}")

         if _TORCH_LOWER_1_6:

@@ -157,7 +157,7 @@ def auroc(
         max_fpr:
             If not ``None``, calculates standardized partial AUC over the
             range [0, max_fpr]. Should be a float between 0 and 1.
-        sample_weight: sample weights for each data point
+        sample_weights: sample weights for each data point

     Example (binary case):

@@ -38,6 +38,7 @@ def _average_precision_compute(
     pos_label: int,
     sample_weights: Optional[Sequence] = None,
 ) -> Union[List[Tensor], Tensor]:
+    # todo: `sample_weights` is unused
     precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes, pos_label)
     # Return the step function integral
     # The following works because the last entry of precision is

@@ -50,7 +50,11 @@ def _cohen_kappa_compute(confmat: Tensor, weights: Optional[str] = None) -> Tens


 def cohen_kappa(
-    preds: Tensor, target: Tensor, num_classes: int, weights: Optional[str] = None, threshold: float = 0.5
+    preds: Tensor,
+    target: Tensor,
+    num_classes: int,
+    weights: Optional[str] = None,
+    threshold: float = 0.5,
 ) -> Tensor:
     r"""
     Calculates `Cohen's kappa score <https://en.wikipedia.org/wiki/Cohen%27s_kappa>`_ that measures

@@ -38,6 +38,7 @@ def _confusion_matrix_compute(confmat: Tensor, normalize: Optional[str] = None)
         f"Argument average needs to one of the following: {allowed_normalize}"
     confmat = confmat.float()
     if normalize is not None and normalize != 'none':
+        cm = None
         if normalize == 'true':
             cm = confmat / confmat.sum(axis=1, keepdim=True)
         elif normalize == 'pred':

@@ -36,7 +36,7 @@ def _iou_from_confmat(
     scores[union == 0] = absent_score

     # Remove the ignored class index from the scores.
-    if ignore_index is not None and ignore_index >= 0 and ignore_index < num_classes:
+    if ignore_index is not None and 0 <= ignore_index < num_classes:
         scores = torch.cat([
             scores[:ignore_index],
             scores[ignore_index + 1:],

@@ -45,7 +45,7 @@ def _iou_from_confmat(


 def iou(
-    pred: Tensor,
+    preds: Tensor,
     target: Tensor,
     ignore_index: Optional[int] = None,
     absent_score: float = 0.0,

@@ -107,6 +107,6 @@ def iou(

     """

-    num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes)
-    confmat = _confusion_matrix_update(pred, target, num_classes, threshold)
+    num_classes = get_num_classes(preds=preds, target=target, num_classes=num_classes)
+    confmat = _confusion_matrix_update(preds, target, num_classes, threshold)
     return _iou_from_confmat(confmat, num_classes, ignore_index, absent_score, reduction)

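The bounds check in `_iou_from_confmat` is collapsed into a chained comparison: `0 <= ignore_index < num_classes` is equivalent to `ignore_index >= 0 and ignore_index < num_classes`, but it evaluates the middle operand only once and reads like the mathematical interval. A tiny sketch with illustrative values:

    num_classes = 3
    for ignore_index in (-1, 0, 2, 3):
        # same as (0 <= ignore_index) and (ignore_index < num_classes)
        in_range = 0 <= ignore_index < num_classes
        print(ignore_index, in_range)
    # -1 False, 0 True, 2 True, 3 False
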
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional
+from typing import Optional, Tuple

 import torch
 from torch import Tensor

@@ -28,6 +28,7 @@ def _precision_compute(
     average: str,
     mdmc_average: Optional[str],
 ) -> Tensor:
+    # todo: `tn` is unused
     return _reduce_stat_scores(
         numerator=tp,
         denominator=tp + fp,

@@ -178,6 +179,8 @@ def _recall_compute(
     average: str,
     mdmc_average: Optional[str],
 ) -> Tensor:
+    # todo: `tp` is unused
+    # todo: `tn` is unused
     return _reduce_stat_scores(
         numerator=tp,
         denominator=tp + fn,

@@ -330,7 +333,7 @@ def precision_recall(
     threshold: float = 0.5,
     top_k: Optional[int] = None,
     is_multiclass: Optional[bool] = None,
-) -> Tensor:
+) -> Tuple[Tensor, Tensor]:
     r"""
     Computes `Precision and Recall <https://en.wikipedia.org/wiki/Precision_and_recall>`_:

@@ -14,7 +14,7 @@
 from typing import List, Optional, Sequence, Tuple, Union

 import torch
-import torch.nn.functional as F
+from torch.nn import functional as F
 from torch import Tensor, tensor

 from torchmetrics.utilities import rank_zero_warn

@@ -49,7 +49,7 @@ def _binary_clf_curve(
     # the indices associated with the distinct values. We also
     # concatenate a value for the end of the curve.
     distinct_value_indices = torch.where(preds[1:] - preds[:-1])[0]
-    threshold_idxs = F.pad(distinct_value_indices, (0, 1), value=target.size(0) - 1)
+    threshold_idxs = F.pad(distinct_value_indices, [0, 1], value=target.size(0) - 1)
     target = (target == pos_label).to(torch.long)
     tps = torch.cumsum(target * weight, dim=0)[threshold_idxs]

@@ -97,4 +97,11 @@ def explained_variance(
         tensor([0.9677, 1.0000])
     """
     n_obs, sum_error, sum_squared_error, sum_target, sum_squared_target = _explained_variance_update(preds, target)
-    return _explained_variance_compute(n_obs, sum_error, sum_squared_error, sum_target, sum_squared_target, multioutput)
+    return _explained_variance_compute(
+        n_obs,
+        sum_error,
+        sum_squared_error,
+        sum_target,
+        sum_squared_target,
+        multioutput,
+    )

@@ -35,7 +35,7 @@ def mean_absolute_error(preds: Tensor, target: Tensor) -> Tensor:
     Computes mean absolute error

     Args:
-        pred: estimated labels
+        preds: estimated labels
         target: ground truth labels

     Return:

@@ -37,7 +37,7 @@ def mean_relative_error(preds: Tensor, target: Tensor) -> Tensor:
     Computes mean relative error

     Args:
-        pred: estimated labels
+        preds: estimated labels
         target: ground truth labels

     Return:

@@ -25,10 +25,10 @@ def _r2score_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor, Tens
     if preds.ndim > 2:
         raise ValueError(
             'Expected both prediction and target to be 1D or 2D tensors,'
-            f' but recevied tensors with dimension {preds.shape}'
+            f' but received tensors with dimension {preds.shape}'
         )
     if len(preds) < 2:
-        raise ValueError('Needs atleast two samples to calculate r2 score.')
+        raise ValueError('Needs at least two samples to calculate r2 score.')

     sum_error = torch.sum(target, dim=0)
     sum_squared_error = torch.sum(torch.pow(target, 2.0), dim=0)

@@ -69,7 +69,7 @@ def _r2score_compute(
     if adjusted != 0:
         if adjusted > total - 1:
             rank_zero_warn(
-                "More independent regressions than datapoints in"
+                "More independent regressions than data points in"
                 " adjusted r2 score. Falls back to standard r2 score.", UserWarning
             )
         elif adjusted == total - 1:

@@ -95,7 +95,7 @@ class Metric(nn.Module, ABC):
             name: The name of the state variable. The variable will then be accessible at ``self.name``.
             default: Default value of the state; can either be a ``torch.Tensor`` or an empty list. The state will be
                 reset to this value when ``self.reset()`` is called.
-            dist_reduce_fx (Optional): Function to reduce state accross mutliple processes in distributed mode.
+            dist_reduce_fx (Optional): Function to reduce state across multiple processes in distributed mode.
                 If value is ``"sum"``, ``"mean"``, or ``"cat"``, we will use ``torch.sum``, ``torch.mean``,
                 and ``torch.cat`` respectively, each with argument ``dim=0``. Note that the ``"cat"`` reduction
                 only makes sense if the state is a list, and not a tensor. The user can also pass a custom

@@ -187,7 +187,7 @@ class Metric(nn.Module, ABC):
             elif isinstance(output_dict[attr][0], list):
                 output_dict[attr] = _flatten(output_dict[attr])

-            assert isinstance(reduction_fn, (Callable)) or reduction_fn is None
+            assert isinstance(reduction_fn, Callable) or reduction_fn is None
             reduced = reduction_fn(output_dict[attr]) if reduction_fn is not None else output_dict[attr]
             setattr(self, attr, reduced)

@@ -214,6 +214,7 @@ class Metric(nn.Module, ABC):
                 dist_sync_fn = gather_all_tensors

            synced = False
+           cache = []
            if self._to_sync and dist_sync_fn is not None:
                # cache prior to syncing
                cache = {attr: getattr(self, attr) for attr in self._defaults.keys()}

@@ -276,20 +277,20 @@ class Metric(nn.Module, ABC):
         """Overwrite _apply function such that we can also move metric states
         to the correct device when `.to`, `.cuda`, etc methods are called
         """
-        self = super()._apply(fn)
+        this = super()._apply(fn)
         # Also apply fn to metric states
-        for key in self._defaults.keys():
-            current_val = getattr(self, key)
+        for key in this._defaults.keys():
+            current_val = getattr(this, key)
             if isinstance(current_val, Tensor):
-                setattr(self, key, fn(current_val))
+                setattr(this, key, fn(current_val))
             elif isinstance(current_val, Sequence):
-                setattr(self, key, [fn(cur_v) for cur_v in current_val])
+                setattr(this, key, [fn(cur_v) for cur_v in current_val])
             else:
                 raise TypeError(
                     "Expected metric state to be either a Tensor"
                     f"or a list of Tensor, but encountered {current_val}"
                 )
-        return self
+        return this

     def persistent(self, mode: bool = False):
         """Method for post-init to change if metric states should be saved to

@@ -449,7 +450,7 @@ def _neg(tensor: Tensor):


 class CompositionalMetric(Metric):
-    """Composition of two metrics with a specific operator which will be executed upon metric's compute """
+    """Composition of two metrics with a specific operator which will be executed upon metrics compute """

     def __init__(
         self,

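In the `_apply` hunk above, rebinding `self = super()._apply(fn)` shadows the method's own `self` argument, which static analysers flag; renaming the local to `this` keeps the behaviour (`nn.Module._apply` returns the module it just transformed) while removing the shadowing. A condensed sketch of the same pattern on a plain module whose extra tensor attribute is neither a parameter nor a buffer; the class and attribute names are illustrative, not torchmetrics internals:

    import torch
    from torch import Tensor, nn

    class WithExtraState(nn.Module):
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(2, 2)
            self.running_total = torch.zeros(2)  # plain attribute, invisible to .to()/.half()

        def _apply(self, fn):
            # let nn.Module convert parameters and buffers first, keep the module it returns
            this = super()._apply(fn)
            # then run the same conversion on the extra state so it stays in sync
            if isinstance(this.running_total, Tensor):
                this.running_total = fn(this.running_total)
            return this

    module = WithExtraState().double()
    print(module.running_total.dtype)  # torch.float64
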
@@ -22,7 +22,7 @@ def _load_requirements(path_dir: str, file_name: str = 'requirements.txt', comme
    """Load requirements from a file

    >>> _load_requirements(_PROJECT_ROOT) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
-    ['torch...']
+    ['numpy...', 'torch...']
    """
    with open(os.path.join(path_dir, file_name), 'r') as file:
        lines = [ln.strip() for ln in file.readlines()]

@@ -221,8 +221,8 @@ def _check_classification_inputs(

     In case where preds are floats (probabilities), it is checked whether they are in [0,1] interval.

-    When ``num_classes`` is given, it is checked that it is consitent with input cases (binary,
-    multi-label, ...), and that, if availible, the implied number of classes in the ``C``
+    When ``num_classes`` is given, it is checked that it is consistent with input cases (binary,
+    multi-label, ...), and that, if available, the implied number of classes in the ``C``
     dimension is consistent with it (as well as that max label in target is smaller than it).

     When ``num_classes`` is not specified in these cases, consistency of the highest target

@@ -242,13 +242,13 @@ def _check_classification_inputs(
             Threshold probability value for transforming probability predictions to binary
             (0,1) predictions, in the case of binary or multi-label inputs.
         num_classes:
-            Number of classes. If not explicitly set, the number of classes will be infered
+            Number of classes. If not explicitly set, the number of classes will be inferred
             either from the shape of inputs, or the maximum label in the ``target`` and ``preds``
             tensor, where applicable.
         top_k:
            Number of highest probability entries for each sample to convert to 1s - relevant
            only for inputs with probability predictions. The default value (``None``) will be
-           interepreted as 1 for these inputs. If this parameter is set for multi-label inputs,
+           interpreted as 1 for these inputs. If this parameter is set for multi-label inputs,
            it will take precedence over threshold.

            Should be left unset (``None``) for inputs with label predictions.

@@ -264,7 +264,7 @@ def _check_classification_inputs(
            'multi-dim multi-class'
     """

-    # Baisc validation (that does not need case/type information)
+    # Basic validation (that does not need case/type information)
     _basic_input_validation(preds, target, threshold, is_multiclass)

     # Check that shape/types fall into one of the cases

@@ -273,7 +273,7 @@ def _check_classification_inputs(
     # For (multi-dim) multi-class case with prob preds, check that preds sum up to 1
     if case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) and preds.is_floating_point():
         if not torch.isclose(preds.sum(dim=1), torch.ones_like(preds.sum(dim=1))).all():
-            raise ValueError("Probabilities in `preds` must sum up to 1 accross the `C` dimension.")
+            raise ValueError("Probabilities in `preds` must sum up to 1 across the `C` dimension.")

     # Check consistency with the `C` dimension in case of multi-class data
     if preds.shape != target.shape:

@@ -370,7 +370,7 @@ def _input_format_classification(
             Threshold probability value for transforming probability predictions to binary
             (0 or 1) predictions, in the case of binary or multi-label inputs.
         num_classes:
-            Number of classes. If not explicitly set, the number of classes will be infered
+            Number of classes. If not explicitly set, the number of classes will be inferred
             either from the shape of inputs, or the maximum label in the ``target`` and ``preds``
             tensor, where applicable.
         top_k:

@@ -438,7 +438,7 @@ def _input_format_classification(
         target = target.reshape(target.shape[0], -1)
         preds = preds.reshape(preds.shape[0], -1)

-    # Some operatins above create an extra dimension for MC/binary case - this removes it
+    # Some operations above create an extra dimension for MC/binary case - this removes it
     if preds.ndim > 2:
         preds, target = preds.squeeze(-1), target.squeeze(-1)

@@ -469,7 +469,7 @@ def _input_format_classification_one_hot(
         raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds")

     if preds.ndim == target.ndim + 1:
-        # multi class probabilites
+        # multi class probabilities
         preds = torch.argmax(preds, dim=1)

     if preds.ndim == target.ndim and preds.dtype in (torch.long, torch.int) and num_classes > 1 and not multilabel:

@@ -478,7 +478,7 @@ def _input_format_classification_one_hot(
         target = to_onehot(target, num_classes=num_classes)

     elif preds.ndim == target.ndim and preds.is_floating_point():
-        # binary or multilabel probablities
+        # binary or multilabel probabilities
         preds = (preds >= threshold).long()

     # transpose class as first dim and reshape

@@ -122,7 +122,7 @@ def to_categorical(tensor: Tensor, argmax_dim: int = 1) -> Tensor:


 def get_num_classes(
-    pred: Tensor,
+    preds: Tensor,
     target: Tensor,
     num_classes: Optional[int] = None,
 ) -> int:

@@ -130,7 +130,7 @@ def get_num_classes(
     Calculates the number of classes for a given prediction and target tensor.

     Args:
-        pred: predicted values
+        preds: predicted values
         target: true labels
         num_classes: number of classes if known

@@ -138,7 +138,7 @@ def get_num_classes(
         An integer that represents the number of classes.
     """
     num_target_classes = int(target.max().detach().item() + 1)
-    num_pred_classes = int(pred.max().detach().item() + 1)
+    num_pred_classes = int(preds.max().detach().item() + 1)
     num_all_classes = max(num_target_classes, num_pred_classes)

     if num_classes is None:

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from enum import Enum
-from typing import Union
+from typing import Optional, Union


 class EnumStr(str, Enum):

@@ -28,7 +28,7 @@ class EnumStr(str, Enum):
     """

     @classmethod
-    def from_str(cls, value: str) -> 'EnumStr':
+    def from_str(cls, value: str) -> Optional['EnumStr']:
         statuses = [status for status in dir(cls) if not status.startswith('_')]
         for st in statuses:
             if st.lower() == value.lower():

@@ -63,7 +63,9 @@ class AverageMethod(EnumStr):

     >>> None in list(AverageMethod)
     True
-    >>> 'none' == AverageMethod.NONE == None
+    >>> AverageMethod.NONE == None
     True
+    >>> AverageMethod.NONE == 'none'
+    True
     """
