This commit is contained in:
Родитель
d76b3c6fb4
Коммит
e11a419c68
|
@ -37,7 +37,7 @@ jobs:
|
|||
pip --version
|
||||
- name: isort
|
||||
run: |
|
||||
isort . --check --diff
|
||||
isort --settings-path=./pyproject.toml . --check --diff
|
||||
|
||||
typing-check-mypy:
|
||||
runs-on: ubuntu-20.04
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
[build-system]
|
||||
requires = [
|
||||
"setuptools",
|
||||
"wheel",
|
||||
]
|
||||
|
||||
[tool.black]
|
||||
# https://github.com/psf/black
|
||||
line-length = 120
|
||||
target-version = ["py38"]
|
||||
exclude = "(.eggs|.git|.hg|.mypy_cache|.nox|.tox|.venv|.svn|_build|buck-out|build|dist)"
|
||||
|
||||
[tool.isort]
|
||||
known_first_party = [
|
||||
"torchmetrics",
|
||||
"tests",
|
||||
]
|
||||
skip_glob = []
|
||||
profile = "black"
|
||||
line_length = 120
|
|
@ -63,6 +63,10 @@ ignore_missing_imports = True
|
|||
[mypy-torchmetrics.metric]
|
||||
ignore_errors = True
|
||||
|
||||
# todo: add proper typing to this module...
|
||||
[mypy-torchmetrics.compositional]
|
||||
ignore_errors = True
|
||||
|
||||
# todo: add proper typing to this module...
|
||||
[mypy-torchmetrics.utils]
|
||||
ignore_errors = True
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Callable, Union
|
||||
from typing import Callable, Union, Any
|
||||
|
||||
import torch
|
||||
|
||||
|
@ -41,18 +41,18 @@ class CompositionalMetric(Metric):
|
|||
else:
|
||||
self.metric_b = metric_b
|
||||
|
||||
def _sync_dist(self, dist_sync_fn=None):
|
||||
def _sync_dist(self, dist_sync_fn: Callable = None) -> None:
|
||||
# No syncing required here. syncing will be done in metric_a and metric_b
|
||||
pass
|
||||
|
||||
def update(self, *args, **kwargs):
|
||||
def update(self, *args, **kwargs) -> None:
|
||||
if isinstance(self.metric_a, Metric):
|
||||
self.metric_a.update(*args, **self.metric_a._filter_kwargs(**kwargs))
|
||||
|
||||
if isinstance(self.metric_b, Metric):
|
||||
self.metric_b.update(*args, **self.metric_b._filter_kwargs(**kwargs))
|
||||
|
||||
def compute(self):
|
||||
def compute(self) -> Any:
|
||||
|
||||
# also some parsing for kwargs?
|
||||
if isinstance(self.metric_a, Metric):
|
||||
|
@ -70,20 +70,20 @@ class CompositionalMetric(Metric):
|
|||
|
||||
return self.op(val_a, val_b)
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
if isinstance(self.metric_a, Metric):
|
||||
self.metric_a.reset()
|
||||
|
||||
if isinstance(self.metric_b, Metric):
|
||||
self.metric_b.reset()
|
||||
|
||||
def persistent(self, mode: bool = False):
|
||||
def persistent(self, mode: bool = False) -> None:
|
||||
if isinstance(self.metric_a, Metric):
|
||||
self.metric_a.persistent(mode=mode)
|
||||
if isinstance(self.metric_b, Metric):
|
||||
self.metric_b.persistent(mode=mode)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
_op_metrics = f"(\n {self.op.__name__}(\n {repr(self.metric_a)},\n {repr(self.metric_b)}\n )\n)"
|
||||
repr_str = (self.__class__.__name__ + _op_metrics)
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче