* typos

* code

* format
Jirka Borovec 2021-03-18 20:56:43 +01:00 committed by GitHub
Parent 9c8dbf68db
Commit be89a1b731
No key matching this signature was found
GPG key ID: 4AEE18F83AFDEB23
40 changed files with 113 additions and 85 deletions

.github/CONTRIBUTING.md

@ -44,7 +44,7 @@ help you or finish it with you :]_
Want to keep Torchmetrics healthy? Love seeing those green tests? So do we! How do we keep it that way?
We write tests! We value tests contribution even more than new features. One of the core values of torchmetrics
is that our users can trust our metric implementation. We can only garantee this if our metrics are well tested.
is that our users can trust our metric implementation. We can only guarantee this if our metrics are well tested.
---


@ -373,7 +373,7 @@ def linkcode_resolve(domain, info):
return None
try:
filename = "%s#L%d-L%d" % find_source()
except Exception:
except Exception: # todo: specify the exception
filename = info["module"].replace(".", "/") + ".py"
# import subprocess
# tag = subprocess.Popen(['git', 'rev-parse', 'HEAD'], stdout=subprocess.PIPE,


@ -1,5 +1,4 @@
.. PyTorchtorchmetrics documentation master file, created by
sphinx-quickstart on Wed Mar 25 21:34:07 2020.
.. TorchMetrics documentation master file.
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.


@ -3,7 +3,7 @@ TorchMetrics is a collection of Machine learning metrics for distributed, scalab
* Optimized for distributed-training
* A standardized interface to increase reproducibility
* Reduces Boilerplate
* Distrubuted-training compatible
* Distributed-training compatible
* Rigorously tested
* Automatic accumulation over batches
* Automatic synchronization between multiple devices


@ -131,7 +131,7 @@ and tests gets formatted in the following way:
4. Remember to add binding to the different relevant ``__init__`` files.
5. Testing is key to keeping ``torchmetrics`` trustworty. This is why we have a very rigid testing protocol. This means
5. Testing is key to keeping ``torchmetrics`` trustworthy. This is why we have a very rigid testing protocol. This means
that we in most cases require the metric to be tested against some other common framework (``sklearn``, ``scipy`` etc).
1. Create a testing file in ``tests/"domain"/test_"new_metric".py``. Only one file is needed as it is intended to test


@ -148,7 +148,7 @@ This pattern is implemented for the following operators (with ``a`` being metric
* Inequality (``a != b``)
* Bitwise OR (``a | b``)
* Power (``a ** b``)
* Substraction (``a - b``)
* Subtraction (``a - b``)
* True Division (``a / b``)
* Bitwise XOR (``a ^ b``)
* Absolute Value (``abs(a)``)
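
The operator list above belongs to the docs on metric arithmetic. A minimal hedged sketch of how such a composition is typically used (assuming the standard ``Accuracy`` metric):

import torch
from torchmetrics import Accuracy

# Hedged sketch: composing a metric with an operator yields a new metric whose
# update() is forwarded to the wrapped metric and whose compute() applies the operator.
acc = Accuracy()
adjusted = acc - 0.5  # the "Subtraction (a - b)" case, with b a plain float

preds = torch.tensor([0, 1, 1, 0])
target = torch.tensor([0, 1, 0, 0])
adjusted.update(preds, target)            # also updates `acc`
print(acc.compute(), adjusted.compute())  # tensor(0.7500) tensor(0.2500)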


@ -1 +1,2 @@
numpy
torch>=1.3.1


@ -54,7 +54,7 @@ def test_add_state():
with pytest.raises(ValueError):
a.add_state("d4", 42, 'sum')
def custom_fx(x):
def custom_fx(_):
return -1
a.add_state("e", tensor(0), custom_fx)


@ -31,6 +31,7 @@ torch.manual_seed(42)
def _sk_auroc_binary_prob(preds, target, num_classes, average='macro', max_fpr=None, multi_class='ovr'):
# todo: `multi_class` is unused
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()
return sk_roc_auc_score(y_true=sk_target, y_score=sk_preds, average=average, max_fpr=max_fpr)


@ -35,6 +35,7 @@ torch.manual_seed(42)
def _sk_fbeta_binary_prob(preds, target, average='micro', beta=1.0):
# todo: `average` is unused
sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8)
sk_target = target.view(-1).numpy()
@ -42,6 +43,7 @@ def _sk_fbeta_binary_prob(preds, target, average='micro', beta=1.0):
def _sk_fbeta_binary(preds, target, average='micro', beta=1.0):
# todo: `average` is unused
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()


@ -143,12 +143,12 @@ class TestIoU(MetricTester):
pytest.param(True, 'none', 0, Tensor([0.5, 0.5])),
])
def test_iou(half_ones, reduction, ignore_index, expected):
pred = (torch.arange(120) % 3).view(-1, 1)
preds = (torch.arange(120) % 3).view(-1, 1)
target = (torch.arange(120) % 3).view(-1, 1)
if half_ones:
pred[:60] = 1
preds[:60] = 1
iou_val = iou(
pred=pred,
preds=preds,
target=target,
ignore_index=ignore_index,
reduction=reduction,
@ -191,7 +191,7 @@ def test_iou(half_ones, reduction, ignore_index, expected):
)
def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes, expected):
iou_val = iou(
pred=tensor(pred),
preds=tensor(pred),
target=tensor(target),
ignore_index=ignore_index,
absent_score=absent_score,
@ -221,7 +221,7 @@ def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes,
)
def test_iou_ignore_index(pred, target, ignore_index, num_classes, reduction, expected):
iou_val = iou(
pred=tensor(pred),
preds=tensor(pred),
target=tensor(target),
ignore_index=ignore_index,
num_classes=num_classes,


@ -36,6 +36,7 @@ torch.manual_seed(42)
def _sk_prec_recall(preds, target, sk_fn, num_classes, average, is_multiclass, ignore_index, mdmc_average=None):
# todo: `mdmc_average` is unused
if average == "none":
average = None
if num_classes == 1:
@ -211,6 +212,7 @@ class TestPrecisionRecall(MetricTester):
mdmc_average: Optional[str],
ignore_index: Optional[int],
):
# todo: `metric_fn` is unused
if num_classes == 1 and average != "micro":
pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)")
@ -261,6 +263,7 @@ class TestPrecisionRecall(MetricTester):
mdmc_average: Optional[str],
ignore_index: Optional[int],
):
# todo: `metric_class` is unused
if num_classes == 1 and average != "micro":
pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)")
@ -345,8 +348,7 @@ def test_top_k(
):
"""A simple test to check that top_k works as expected.
Just a sanity check, the tests in StatScores should already guarantee
the corectness of results.
Just a sanity check, the tests in StatScores should already guarantee the correctness of results.
"""
class_metric = metric_class(top_k=k, average=average, num_classes=3)


@ -35,6 +35,7 @@ torch.manual_seed(42)
def _sk_stat_scores(preds, target, reduce, num_classes, is_multiclass, ignore_index, top_k, mdmc_reduce=None):
# todo: `mdmc_reduce` is unused
preds, target, _ = _input_format_classification(
preds, target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass, top_k=top_k
)


@ -59,13 +59,13 @@ def test_to_categorical():
assert torch.allclose(result, expected.to(result.dtype))
@pytest.mark.parametrize(['pred', 'target', 'num_classes', 'expected_num_classes'], [
@pytest.mark.parametrize(['preds', 'target', 'num_classes', 'expected_num_classes'], [
pytest.param(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), 10, 10),
pytest.param(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), None, 10),
pytest.param(torch.rand(32, 28, 28), torch.randint(10, (32, 28, 28)), None, 10),
])
def test_get_num_classes(pred, target, num_classes, expected_num_classes):
assert get_num_classes(pred, target, num_classes) == expected_num_classes
def test_get_num_classes(preds, target, num_classes, expected_num_classes):
assert get_num_classes(preds, target, num_classes) == expected_num_classes
@pytest.mark.parametrize(['sample_weight', 'pos_label', "exp_shape"], [


@ -66,16 +66,7 @@ def test_multi_batch_image_gradients():
[5., 5., 5., 5., 5.],
[0., 0., 0., 0., 0.],
]
true_dx = [
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 0.],
]
true_dy = Tensor(true_dy)
true_dx = Tensor(true_dx)
dy, dx = image_gradients(image)
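
For context, a hedged sketch of the ``image_gradients`` call exercised by this test (assuming it is exposed via ``torchmetrics.functional``):

import torch
from torchmetrics.functional import image_gradients

# Hedged sketch: image_gradients expects an NCHW tensor and returns (dy, dx).
image = torch.arange(25, dtype=torch.float32).reshape(1, 1, 5, 5)
dy, dx = image_gradients(image)
# For this ramp image, dy is 5 everywhere except the last row (zeros)
# and dx is 1 everywhere except the last column (zeros).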


@ -39,7 +39,7 @@ THRESHOLD = 0.5
def setup_ddp(rank, world_size):
""" Setup ddp enviroment """
""" Setup ddp environment """
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "8088"
@ -81,7 +81,7 @@ def _class_test(
metric_class: Metric,
sk_metric: Callable,
dist_sync_on_step: bool,
metric_args: dict = {},
metric_args: dict = None,
check_dist_sync_on_step: bool = True,
check_batch: bool = True,
atol: float = 1e-8,
@ -104,6 +104,8 @@ def _class_test(
check_batch: bool, if true will check if the metric is also correctly
calculated across devices for each batch (and not just at the end)
"""
if not metric_args:
metric_args = {}
# Instantiate lightning metric
metric = metric_class(compute_on_step=True, dist_sync_on_step=dist_sync_on_step, **metric_args)
@ -145,7 +147,7 @@ def _functional_test(
target: Tensor,
metric_functional: Callable,
sk_metric: Callable,
metric_args: dict = {},
metric_args: dict = None,
atol: float = 1e-8,
):
"""Utility function doing the actual comparison between lightning functional metric
@ -158,6 +160,8 @@ def _functional_test(
sk_metric: callable function that is used for comparison
metric_args: dict with additional arguments used for class initialization
"""
if not metric_args:
metric_args = {}
metric = partial(metric_functional, **metric_args)
for i in range(NUM_BATCHES):
@ -200,7 +204,7 @@ class MetricTester:
target: Tensor,
metric_functional: Callable,
sk_metric: Callable,
metric_args: dict = {},
metric_args: dict = None,
):
"""Main method that should be used for testing functions. Call this inside
testing method
@ -229,7 +233,7 @@ class MetricTester:
metric_class: Metric,
sk_metric: Callable,
dist_sync_on_step: bool,
metric_args: dict = {},
metric_args: dict = None,
check_dist_sync_on_step: bool = True,
check_batch: bool = True,
):
@ -250,6 +254,8 @@ class MetricTester:
check_batch: bool, if true will check if the metric is also correctly
calculated across devices for each batch (and not just at the end)
"""
if not metric_args:
metric_args = {}
if ddp:
if sys.platform == "win32":
pytest.skip("DDP not supported on windows")


@ -68,7 +68,8 @@ class BoringModel(LightningModule):
def forward(self, x):
return self.layer(x)
def loss(self, batch, prediction):
@staticmethod
def loss(_, prediction):
# An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))


@ -75,6 +75,7 @@ class TestMeanError(MetricTester):
def test_mean_error_class(
self, preds, target, sk_metric, metric_class, metric_functional, sk_fn, ddp, dist_sync_on_step
):
# todo: `metric_functional` is unused
self.run_class_metric_test(
ddp=ddp,
preds=preds,
@ -85,6 +86,7 @@ class TestMeanError(MetricTester):
)
def test_mean_error_functional(self, preds, target, sk_metric, metric_class, metric_functional, sk_fn):
# todo: `metric_class` is unused
self.run_functional_metric_test(
preds=preds,
target=target,


@ -82,6 +82,7 @@ class TestR2Score(MetricTester):
)
def test_r2_functional(self, adjusted, multioutput, preds, target, sk_metric, num_outputs):
# todo: `num_outputs` is unused
self.run_functional_metric_test(
preds,
target,
@ -102,14 +103,14 @@ def test_error_on_multidim_tensors(metric_class=R2Score):
with pytest.raises(
ValueError,
match=r'Expected both prediction and target to be 1D or 2D tensors,'
r' but recevied tensors with dimension .'
r' but received tensors with dimension .'
):
metric(torch.randn(10, 20, 5), torch.randn(10, 20, 5))
def test_error_on_too_few_samples(metric_class=R2Score):
metric = metric_class()
with pytest.raises(ValueError, match='Needs atleast two samples to calculate r2 score.'):
with pytest.raises(ValueError, match='Needs at least two samples to calculate r2 score.'):
metric(torch.randn(1, ), torch.randn(1, ))
@ -118,7 +119,7 @@ def test_warning_on_too_large_adjusted(metric_class=R2Score):
with pytest.warns(
UserWarning,
match="More independent regressions than datapoints in"
match="More independent regressions than data points in"
" adjusted r2 score. Falls back to standard r2 score."
):
metric(torch.randn(10, ), torch.randn(10, ))


@ -31,7 +31,7 @@ class AUC(Metric):
Args:
reorder: AUC expects its first input to be sorted. If this is not the case,
setting this argument to ``True`` will use a stable sorting algorithm to
sort the input in decending order
sort the input in descending order
compute_on_step:
Forward only calls ``update()`` and return None if this is set to False.
dist_sync_on_step:
@ -40,8 +40,8 @@ class AUC(Metric):
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When ``None``, DDP
will be used to perform the allgather
Callback that performs the ``allgather`` operation on the metric state. When ``None``, DDP
will be used to perform the ``allgather``.
"""
def __init__(


@ -118,7 +118,7 @@ class AUROC(Metric):
)
if self.max_fpr is not None:
if (not isinstance(max_fpr, float) and 0 < max_fpr <= 1):
if not isinstance(max_fpr, float) or not 0 < max_fpr <= 1:
raise ValueError(f"`max_fpr` should be a float in range (0, 1], got: {max_fpr}")
if _TORCH_LOWER_1_6:
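
The corrected condition above rejects any ``max_fpr`` that is not a float in ``(0, 1]``. A small standalone sketch (hypothetical helper) of the intended behaviour:

def _check_max_fpr(max_fpr):
    # Hypothetical standalone version of the corrected guard above.
    if max_fpr is not None:
        if not isinstance(max_fpr, float) or not 0 < max_fpr <= 1:
            raise ValueError(f"`max_fpr` should be a float in range (0, 1], got: {max_fpr}")

_check_max_fpr(0.5)   # accepted
_check_max_fpr(None)  # accepted, check is skipped
try:
    _check_max_fpr(2.0)  # rejected: a float, but outside (0, 1]
except ValueError as err:
    print(err)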


@ -135,7 +135,7 @@ class ROC(Metric):
tensor with true positive rates.
If multiclass, this is a list of such tensors, one for each class.
thresholds:
thresholds used for computing false- and true postive rates
thresholds used for computing false- and true positive rates
"""
preds = torch.cat(self.preds, dim=0)


@ -137,7 +137,8 @@ class MetricCollection(nn.ModuleDict):
def _set_prefix(self, k: str) -> str:
return k if self.prefix is None else self.prefix + k
def _check_prefix_arg(self, prefix: str) -> Optional[str]:
@staticmethod
def _check_prefix_arg(prefix: str) -> Optional[str]:
if prefix is not None:
if isinstance(prefix, str):
return prefix
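
``_check_prefix_arg`` above validates the collection's ``prefix``. A hedged sketch of how a prefixed ``MetricCollection`` is typically used (the metric choices here are illustrative):

import torch
from torchmetrics import Accuracy, MetricCollection, Precision

# Hedged sketch: every key of the returned dict gets the validated prefix prepended.
metrics = MetricCollection([Accuracy(), Precision()], prefix="val_")
preds = torch.tensor([0, 1, 1, 0])
target = torch.tensor([0, 1, 0, 0])
out = metrics(preds, target)
print(out)  # e.g. {'val_Accuracy': tensor(...), 'val_Precision': tensor(...)}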


@ -29,6 +29,7 @@ def _accuracy_update(
) -> Tuple[Tensor, Tensor]:
preds, target, mode = _input_format_classification(preds, target, threshold=threshold, top_k=top_k)
correct, total = None, None
if mode == DataType.MULTILABEL and top_k:
raise ValueError("You can not use the `top_k` parameter to calculate accuracy for multi-label inputs.")


@ -55,7 +55,7 @@ def _auroc_compute(
# check max_fpr parameter
if max_fpr is not None:
if (not isinstance(max_fpr, float) and 0 < max_fpr <= 1):
if not isinstance(max_fpr, float) or not 0 < max_fpr <= 1:
raise ValueError(f"`max_fpr` should be a float in range (0, 1], got: {max_fpr}")
if _TORCH_LOWER_1_6:
@ -157,7 +157,7 @@ def auroc(
max_fpr:
If not ``None``, calculates standardized partial AUC over the
range [0, max_fpr]. Should be a float between 0 and 1.
sample_weight: sample weights for each data point
sample_weights: sample weights for each data point
Example (binary case):


@ -38,6 +38,7 @@ def _average_precision_compute(
pos_label: int,
sample_weights: Optional[Sequence] = None,
) -> Union[List[Tensor], Tensor]:
# todo: `sample_weights` is unused
precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes, pos_label)
# Return the step function integral
# The following works because the last entry of precision is


@ -50,7 +50,11 @@ def _cohen_kappa_compute(confmat: Tensor, weights: Optional[str] = None) -> Tens
def cohen_kappa(
preds: Tensor, target: Tensor, num_classes: int, weights: Optional[str] = None, threshold: float = 0.5
preds: Tensor,
target: Tensor,
num_classes: int,
weights: Optional[str] = None,
threshold: float = 0.5,
) -> Tensor:
r"""
Calculates `Cohen's kappa score <https://en.wikipedia.org/wiki/Cohen%27s_kappa>`_ that measures


@ -38,6 +38,7 @@ def _confusion_matrix_compute(confmat: Tensor, normalize: Optional[str] = None)
f"Argument average needs to one of the following: {allowed_normalize}"
confmat = confmat.float()
if normalize is not None and normalize != 'none':
cm = None
if normalize == 'true':
cm = confmat / confmat.sum(axis=1, keepdim=True)
elif normalize == 'pred':


@ -36,7 +36,7 @@ def _iou_from_confmat(
scores[union == 0] = absent_score
# Remove the ignored class index from the scores.
if ignore_index is not None and ignore_index >= 0 and ignore_index < num_classes:
if ignore_index is not None and 0 <= ignore_index < num_classes:
scores = torch.cat([
scores[:ignore_index],
scores[ignore_index + 1:],
@ -45,7 +45,7 @@ def _iou_from_confmat(
def iou(
pred: Tensor,
preds: Tensor,
target: Tensor,
ignore_index: Optional[int] = None,
absent_score: float = 0.0,
@ -107,6 +107,6 @@ def iou(
"""
num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes)
confmat = _confusion_matrix_update(pred, target, num_classes, threshold)
num_classes = get_num_classes(preds=preds, target=target, num_classes=num_classes)
confmat = _confusion_matrix_update(preds, target, num_classes, threshold)
return _iou_from_confmat(confmat, num_classes, ignore_index, absent_score, reduction)
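
With the argument renamed from ``pred`` to ``preds``, the functional call reads as below. A hedged sketch (assuming label tensors and the default ``elementwise_mean`` reduction):

import torch
from torchmetrics.functional import iou

# Hedged sketch of the renamed signature: preds first, then target.
preds = torch.tensor([0, 1, 2, 2])
target = torch.tensor([0, 1, 2, 1])
print(iou(preds, target))  # per-class IoU [1.0, 0.5, 0.5] averaged to ~0.6667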


@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import Optional, Tuple
import torch
from torch import Tensor
@ -28,6 +28,7 @@ def _precision_compute(
average: str,
mdmc_average: Optional[str],
) -> Tensor:
# todo: `tn` is unused
return _reduce_stat_scores(
numerator=tp,
denominator=tp + fp,
@ -178,6 +179,8 @@ def _recall_compute(
average: str,
mdmc_average: Optional[str],
) -> Tensor:
# todo: `tp` is unused
# todo: `tn` is unused
return _reduce_stat_scores(
numerator=tp,
denominator=tp + fn,
@ -330,7 +333,7 @@ def precision_recall(
threshold: float = 0.5,
top_k: Optional[int] = None,
is_multiclass: Optional[bool] = None,
) -> Tensor:
) -> Tuple[Tensor, Tensor]:
r"""
Computes `Precision and Recall <https://en.wikipedia.org/wiki/Precision_and_recall>`_:
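
The corrected annotation above reflects that ``precision_recall`` returns a pair of tensors. A hedged sketch (assuming the function is exported from ``torchmetrics.functional``):

import torch
from torchmetrics.functional import precision_recall

preds = torch.tensor([0, 1, 1, 0])
target = torch.tensor([0, 1, 0, 0])
prec, rec = precision_recall(preds, target)  # a Tuple[Tensor, Tensor], as annotated
print(prec, rec)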


@ -14,7 +14,7 @@
from typing import List, Optional, Sequence, Tuple, Union
import torch
import torch.nn.functional as F
from torch.nn import functional as F
from torch import Tensor, tensor
from torchmetrics.utilities import rank_zero_warn
@ -49,7 +49,7 @@ def _binary_clf_curve(
# the indices associated with the distinct values. We also
# concatenate a value for the end of the curve.
distinct_value_indices = torch.where(preds[1:] - preds[:-1])[0]
threshold_idxs = F.pad(distinct_value_indices, (0, 1), value=target.size(0) - 1)
threshold_idxs = F.pad(distinct_value_indices, [0, 1], value=target.size(0) - 1)
target = (target == pos_label).to(torch.long)
tps = torch.cumsum(target * weight, dim=0)[threshold_idxs]


@ -97,4 +97,11 @@ def explained_variance(
tensor([0.9677, 1.0000])
"""
n_obs, sum_error, sum_squared_error, sum_target, sum_squared_target = _explained_variance_update(preds, target)
return _explained_variance_compute(n_obs, sum_error, sum_squared_error, sum_target, sum_squared_target, multioutput)
return _explained_variance_compute(
n_obs,
sum_error,
sum_squared_error,
sum_target,
sum_squared_target,
multioutput,
)


@ -35,7 +35,7 @@ def mean_absolute_error(preds: Tensor, target: Tensor) -> Tensor:
Computes mean absolute error
Args:
pred: estimated labels
preds: estimated labels
target: ground truth labels
Return:


@ -37,7 +37,7 @@ def mean_relative_error(preds: Tensor, target: Tensor) -> Tensor:
Computes mean relative error
Args:
pred: estimated labels
preds: estimated labels
target: ground truth labels
Return:


@ -25,10 +25,10 @@ def _r2score_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor, Tens
if preds.ndim > 2:
raise ValueError(
'Expected both prediction and target to be 1D or 2D tensors,'
f' but recevied tensors with dimension {preds.shape}'
f' but received tensors with dimension {preds.shape}'
)
if len(preds) < 2:
raise ValueError('Needs atleast two samples to calculate r2 score.')
raise ValueError('Needs at least two samples to calculate r2 score.')
sum_error = torch.sum(target, dim=0)
sum_squared_error = torch.sum(torch.pow(target, 2.0), dim=0)
@ -69,7 +69,7 @@ def _r2score_compute(
if adjusted != 0:
if adjusted > total - 1:
rank_zero_warn(
"More independent regressions than datapoints in"
"More independent regressions than data points in"
" adjusted r2 score. Falls back to standard r2 score.", UserWarning
)
elif adjusted == total - 1:


@ -95,7 +95,7 @@ class Metric(nn.Module, ABC):
name: The name of the state variable. The variable will then be accessible at ``self.name``.
default: Default value of the state; can either be a ``torch.Tensor`` or an empty list. The state will be
reset to this value when ``self.reset()`` is called.
dist_reduce_fx (Optional): Function to reduce state accross mutliple processes in distributed mode.
dist_reduce_fx (Optional): Function to reduce state across multiple processes in distributed mode.
If value is ``"sum"``, ``"mean"``, or ``"cat"``, we will use ``torch.sum``, ``torch.mean``,
and ``torch.cat`` respectively, each with argument ``dim=0``. Note that the ``"cat"`` reduction
only makes sense if the state is a list, and not a tensor. The user can also pass a custom
@ -187,7 +187,7 @@ class Metric(nn.Module, ABC):
elif isinstance(output_dict[attr][0], list):
output_dict[attr] = _flatten(output_dict[attr])
assert isinstance(reduction_fn, (Callable)) or reduction_fn is None
assert isinstance(reduction_fn, Callable) or reduction_fn is None
reduced = reduction_fn(output_dict[attr]) if reduction_fn is not None else output_dict[attr]
setattr(self, attr, reduced)
@ -214,6 +214,7 @@ class Metric(nn.Module, ABC):
dist_sync_fn = gather_all_tensors
synced = False
cache = []
if self._to_sync and dist_sync_fn is not None:
# cache prior to syncing
cache = {attr: getattr(self, attr) for attr in self._defaults.keys()}
@ -276,20 +277,20 @@ class Metric(nn.Module, ABC):
"""Overwrite _apply function such that we can also move metric states
to the correct device when `.to`, `.cuda`, etc methods are called
"""
self = super()._apply(fn)
this = super()._apply(fn)
# Also apply fn to metric states
for key in self._defaults.keys():
current_val = getattr(self, key)
for key in this._defaults.keys():
current_val = getattr(this, key)
if isinstance(current_val, Tensor):
setattr(self, key, fn(current_val))
setattr(this, key, fn(current_val))
elif isinstance(current_val, Sequence):
setattr(self, key, [fn(cur_v) for cur_v in current_val])
setattr(this, key, [fn(cur_v) for cur_v in current_val])
else:
raise TypeError(
"Expected metric state to be either a Tensor"
f"or a list of Tensor, but encountered {current_val}"
)
return self
return this
def persistent(self, mode: bool = False):
"""Method for post-init to change if metric states should be saved to
@ -449,7 +450,7 @@ def _neg(tensor: Tensor):
class CompositionalMetric(Metric):
"""Composition of two metrics with a specific operator which will be executed upon metric's compute """
"""Composition of two metrics with a specific operator which will be executed upon metrics compute """
def __init__(
self,
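
The ``dist_reduce_fx`` docstring and the ``_apply`` changes above both concern metric state tensors registered via ``add_state``. A minimal hedged sketch of a custom metric using that API:

import torch
from torchmetrics import Metric

class SumMetric(Metric):
    # Hedged sketch: `total` is a registered state, reduced with torch.sum across
    # processes in DDP and moved between devices by the _apply override above.
    def __init__(self):
        super().__init__()
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, x: torch.Tensor):
        self.total += x.sum()

    def compute(self):
        return self.total

m = SumMetric()
m.update(torch.tensor([1.0, 2.0]))
print(m.compute())  # tensor(3.)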


@ -22,7 +22,7 @@ def _load_requirements(path_dir: str, file_name: str = 'requirements.txt', comme
"""Load requirements from a file
>>> _load_requirements(_PROJECT_ROOT) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
['torch...']
['numpy...', 'torch...']
"""
with open(os.path.join(path_dir, file_name), 'r') as file:
lines = [ln.strip() for ln in file.readlines()]


@ -221,8 +221,8 @@ def _check_classification_inputs(
In case where preds are floats (probabilities), it is checked whether they are in [0,1] interval.
When ``num_classes`` is given, it is checked that it is consitent with input cases (binary,
multi-label, ...), and that, if availible, the implied number of classes in the ``C``
When ``num_classes`` is given, it is checked that it is consistent with input cases (binary,
multi-label, ...), and that, if available, the implied number of classes in the ``C``
dimension is consistent with it (as well as that max label in target is smaller than it).
When ``num_classes`` is not specified in these cases, consistency of the highest target
@ -242,13 +242,13 @@ def _check_classification_inputs(
Threshold probability value for transforming probability predictions to binary
(0,1) predictions, in the case of binary or multi-label inputs.
num_classes:
Number of classes. If not explicitly set, the number of classes will be infered
Number of classes. If not explicitly set, the number of classes will be inferred
either from the shape of inputs, or the maximum label in the ``target`` and ``preds``
tensor, where applicable.
top_k:
Number of highest probability entries for each sample to convert to 1s - relevant
only for inputs with probability predictions. The default value (``None``) will be
interepreted as 1 for these inputs. If this parameter is set for multi-label inputs,
interpreted as 1 for these inputs. If this parameter is set for multi-label inputs,
it will take precedence over threshold.
Should be left unset (``None``) for inputs with label predictions.
@ -264,7 +264,7 @@ def _check_classification_inputs(
'multi-dim multi-class'
"""
# Baisc validation (that does not need case/type information)
# Basic validation (that does not need case/type information)
_basic_input_validation(preds, target, threshold, is_multiclass)
# Check that shape/types fall into one of the cases
@ -273,7 +273,7 @@ def _check_classification_inputs(
# For (multi-dim) multi-class case with prob preds, check that preds sum up to 1
if case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) and preds.is_floating_point():
if not torch.isclose(preds.sum(dim=1), torch.ones_like(preds.sum(dim=1))).all():
raise ValueError("Probabilities in `preds` must sum up to 1 accross the `C` dimension.")
raise ValueError("Probabilities in `preds` must sum up to 1 across the `C` dimension.")
# Check consistency with the `C` dimension in case of multi-class data
if preds.shape != target.shape:
@ -370,7 +370,7 @@ def _input_format_classification(
Threshold probability value for transforming probability predictions to binary
(0 or 1) predictions, in the case of binary or multi-label inputs.
num_classes:
Number of classes. If not explicitly set, the number of classes will be infered
Number of classes. If not explicitly set, the number of classes will be inferred
either from the shape of inputs, or the maximum label in the ``target`` and ``preds``
tensor, where applicable.
top_k:
@ -438,7 +438,7 @@ def _input_format_classification(
target = target.reshape(target.shape[0], -1)
preds = preds.reshape(preds.shape[0], -1)
# Some operatins above create an extra dimension for MC/binary case - this removes it
# Some operations above create an extra dimension for MC/binary case - this removes it
if preds.ndim > 2:
preds, target = preds.squeeze(-1), target.squeeze(-1)
@ -469,7 +469,7 @@ def _input_format_classification_one_hot(
raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds")
if preds.ndim == target.ndim + 1:
# multi class probabilites
# multi class probabilities
preds = torch.argmax(preds, dim=1)
if preds.ndim == target.ndim and preds.dtype in (torch.long, torch.int) and num_classes > 1 and not multilabel:
@ -478,7 +478,7 @@ def _input_format_classification_one_hot(
target = to_onehot(target, num_classes=num_classes)
elif preds.ndim == target.ndim and preds.is_floating_point():
# binary or multilabel probablities
# binary or multilabel probabilities
preds = (preds >= threshold).long()
# transpose class as first dim and reshape


@ -122,7 +122,7 @@ def to_categorical(tensor: Tensor, argmax_dim: int = 1) -> Tensor:
def get_num_classes(
pred: Tensor,
preds: Tensor,
target: Tensor,
num_classes: Optional[int] = None,
) -> int:
@ -130,7 +130,7 @@ def get_num_classes(
Calculates the number of classes for a given prediction and target tensor.
Args:
pred: predicted values
preds: predicted values
target: true labels
num_classes: number of classes if known
@ -138,7 +138,7 @@ def get_num_classes(
An integer that represents the number of classes.
"""
num_target_classes = int(target.max().detach().item() + 1)
num_pred_classes = int(pred.max().detach().item() + 1)
num_pred_classes = int(preds.max().detach().item() + 1)
num_all_classes = max(num_target_classes, num_pred_classes)
if num_classes is None:
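
For context, the inference performed by ``get_num_classes`` above amounts to taking the largest label seen in either tensor. A small sketch with plain ``torch``:

import torch

preds = torch.tensor([0, 2, 1])
target = torch.tensor([0, 1, 1])
# Mirrors the logic above: classes are inferred as 1 + the maximum label.
num_pred_classes = int(preds.max().item() + 1)     # 3
num_target_classes = int(target.max().item() + 1)  # 2
print(max(num_pred_classes, num_target_classes))   # 3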


@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from typing import Union
from typing import Optional, Union
class EnumStr(str, Enum):
@ -28,7 +28,7 @@ class EnumStr(str, Enum):
"""
@classmethod
def from_str(cls, value: str) -> 'EnumStr':
def from_str(cls, value: str) -> Optional['EnumStr']:
statuses = [status for status in dir(cls) if not status.startswith('_')]
for st in statuses:
if st.lower() == value.lower():
@ -63,7 +63,9 @@ class AverageMethod(EnumStr):
>>> None in list(AverageMethod)
True
>>> 'none' == AverageMethod.NONE == None
>>> AverageMethod.NONE == None
True
>>> AverageMethod.NONE == 'none'
True
"""