This commit is contained in:
Jirka Borovec 2021-02-22 10:42:39 +01:00
Родитель d76b3c6fb4
Коммит e11a419c68
4 изменённых файлов: 32 добавлений и 8 удалений

2
.github/workflows/code-format.yml поставляемый
Просмотреть файл

@ -37,7 +37,7 @@ jobs:
pip --version
- name: isort
run: |
isort . --check --diff
isort --settings-path=./pyproject.toml . --check --diff
typing-check-mypy:
runs-on: ubuntu-20.04

20
pyproject.toml Normal file
Просмотреть файл

@ -0,0 +1,20 @@
[build-system]
requires = [
"setuptools",
"wheel",
]
[tool.black]
# https://github.com/psf/black
line-length = 120
target-version = ["py38"]
exclude = "(.eggs|.git|.hg|.mypy_cache|.nox|.tox|.venv|.svn|_build|buck-out|build|dist)"
[tool.isort]
known_first_party = [
"torchmetrics",
"tests",
]
skip_glob = []
profile = "black"
line_length = 120

Просмотреть файл

@ -63,6 +63,10 @@ ignore_missing_imports = True
[mypy-torchmetrics.metric]
ignore_errors = True
# todo: add proper typing to this module...
[mypy-torchmetrics.compositional]
ignore_errors = True
# todo: add proper typing to this module...
[mypy-torchmetrics.utils]
ignore_errors = True

Просмотреть файл

@ -1,4 +1,4 @@
from typing import Callable, Union
from typing import Callable, Union, Any
import torch
@ -41,18 +41,18 @@ class CompositionalMetric(Metric):
else:
self.metric_b = metric_b
def _sync_dist(self, dist_sync_fn=None):
def _sync_dist(self, dist_sync_fn: Callable = None) -> None:
# No syncing required here. syncing will be done in metric_a and metric_b
pass
def update(self, *args, **kwargs):
def update(self, *args, **kwargs) -> None:
if isinstance(self.metric_a, Metric):
self.metric_a.update(*args, **self.metric_a._filter_kwargs(**kwargs))
if isinstance(self.metric_b, Metric):
self.metric_b.update(*args, **self.metric_b._filter_kwargs(**kwargs))
def compute(self):
def compute(self) -> Any:
# also some parsing for kwargs?
if isinstance(self.metric_a, Metric):
@ -70,20 +70,20 @@ class CompositionalMetric(Metric):
return self.op(val_a, val_b)
def reset(self):
def reset(self) -> None:
if isinstance(self.metric_a, Metric):
self.metric_a.reset()
if isinstance(self.metric_b, Metric):
self.metric_b.reset()
def persistent(self, mode: bool = False):
def persistent(self, mode: bool = False) -> None:
if isinstance(self.metric_a, Metric):
self.metric_a.persistent(mode=mode)
if isinstance(self.metric_b, Metric):
self.metric_b.persistent(mode=mode)
def __repr__(self):
def __repr__(self) -> str:
_op_metrics = f"(\n {self.op.__name__}(\n {repr(self.metric_a)},\n {repr(self.metric_b)}\n )\n)"
repr_str = (self.__class__.__name__ + _op_metrics)