From e11a419c687a3f43c536285cf593d3ccc7425510 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 22 Feb 2021 10:42:39 +0100 Subject: [PATCH] mypy --- .github/workflows/code-format.yml | 2 +- pyproject.toml | 20 ++++++++++++++++++++ setup.cfg | 4 ++++ torchmetrics/compositional.py | 14 +++++++------- 4 files changed, 32 insertions(+), 8 deletions(-) create mode 100644 pyproject.toml diff --git a/.github/workflows/code-format.yml b/.github/workflows/code-format.yml index 2eee8ea..2d3fe31 100644 --- a/.github/workflows/code-format.yml +++ b/.github/workflows/code-format.yml @@ -37,7 +37,7 @@ jobs: pip --version - name: isort run: | - isort . --check --diff + isort --settings-path=./pyproject.toml . --check --diff typing-check-mypy: runs-on: ubuntu-20.04 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..829caf5 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,20 @@ +[build-system] +requires = [ + "setuptools", + "wheel", +] + +[tool.black] +# https://github.com/psf/black +line-length = 120 +target-version = ["py38"] +exclude = "(.eggs|.git|.hg|.mypy_cache|.nox|.tox|.venv|.svn|_build|buck-out|build|dist)" + +[tool.isort] +known_first_party = [ + "torchmetrics", + "tests", +] +skip_glob = [] +profile = "black" +line_length = 120 diff --git a/setup.cfg b/setup.cfg index b94b96b..84a272b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -63,6 +63,10 @@ ignore_missing_imports = True [mypy-torchmetrics.metric] ignore_errors = True +# todo: add proper typing to this module... +[mypy-torchmetrics.compositional] +ignore_errors = True + # todo: add proper typing to this module... [mypy-torchmetrics.utils] ignore_errors = True diff --git a/torchmetrics/compositional.py b/torchmetrics/compositional.py index db22422..910d7f2 100644 --- a/torchmetrics/compositional.py +++ b/torchmetrics/compositional.py @@ -1,4 +1,4 @@ -from typing import Callable, Union +from typing import Callable, Union, Any import torch @@ -41,18 +41,18 @@ class CompositionalMetric(Metric): else: self.metric_b = metric_b - def _sync_dist(self, dist_sync_fn=None): + def _sync_dist(self, dist_sync_fn: Callable = None) -> None: # No syncing required here. syncing will be done in metric_a and metric_b pass - def update(self, *args, **kwargs): + def update(self, *args, **kwargs) -> None: if isinstance(self.metric_a, Metric): self.metric_a.update(*args, **self.metric_a._filter_kwargs(**kwargs)) if isinstance(self.metric_b, Metric): self.metric_b.update(*args, **self.metric_b._filter_kwargs(**kwargs)) - def compute(self): + def compute(self) -> Any: # also some parsing for kwargs? if isinstance(self.metric_a, Metric): @@ -70,20 +70,20 @@ class CompositionalMetric(Metric): return self.op(val_a, val_b) - def reset(self): + def reset(self) -> None: if isinstance(self.metric_a, Metric): self.metric_a.reset() if isinstance(self.metric_b, Metric): self.metric_b.reset() - def persistent(self, mode: bool = False): + def persistent(self, mode: bool = False) -> None: if isinstance(self.metric_a, Metric): self.metric_a.persistent(mode=mode) if isinstance(self.metric_b, Metric): self.metric_b.persistent(mode=mode) - def __repr__(self): + def __repr__(self) -> str: _op_metrics = f"(\n {self.op.__name__}(\n {repr(self.metric_a)},\n {repr(self.metric_b)}\n )\n)" repr_str = (self.__class__.__name__ + _op_metrics)