fix compatibility to PL (#192)
* fix compatibility * local * flake8 * drop 3.6 min * .
This commit is contained in:
Родитель
5b7c40e6ae
Коммит
e0f844c98b
|
@ -20,6 +20,8 @@ jobs:
|
|||
python-version: [3.6, 3.8, 3.9]
|
||||
requires: ['minimal', 'latest']
|
||||
exclude:
|
||||
- python-version: 3.6
|
||||
requires: 'latest'
|
||||
- python-version: 3.9
|
||||
requires: 'minimal'
|
||||
|
||||
|
|
2
LICENSE
2
LICENSE
|
@ -186,7 +186,7 @@
|
|||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2018-2021 William Falcon
|
||||
Copyright 2020-2021 PytorchLightning team
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
from torchmetrics.utilities.imports import _module_available
|
||||
import operator
|
||||
|
||||
_PL_AVAILABLE = _module_available('pytorch_lightning')
|
||||
from torchmetrics.utilities.imports import _compare_version
|
||||
|
||||
_LIGHTNING_GREATER_EQUAL_1_3 = _compare_version("pytorch_lightning", operator.ge, "1.3.0")
|
||||
|
|
|
@ -20,8 +20,8 @@ from torch import tensor
|
|||
from torch.utils.data import DataLoader
|
||||
|
||||
from integrations.lightning_models import BoringModel, RandomDataset
|
||||
from tests.helpers import _LIGHTNING_GREATER_EQUAL_1_3
|
||||
from torchmetrics import Accuracy, AveragePrecision, Metric
|
||||
from torchmetrics.utilities.imports import _LIGHTNING_GREATER_EQUAL_1_3
|
||||
|
||||
|
||||
class SumMetric(Metric):
|
||||
|
|
|
@ -20,9 +20,9 @@ import pytest
|
|||
import torch
|
||||
from torch import nn, tensor
|
||||
|
||||
from tests.helpers import seed_all
|
||||
from tests.helpers import _LIGHTNING_GREATER_EQUAL_1_3, seed_all
|
||||
from tests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricSum
|
||||
from torchmetrics.utilities.imports import _LIGHTNING_AVAILABLE, _LIGHTNING_GREATER_EQUAL_1_3, _TORCH_LOWER_1_6
|
||||
from torchmetrics.utilities.imports import _LIGHTNING_AVAILABLE, _TORCH_LOWER_1_6
|
||||
|
||||
seed_all(42)
|
||||
|
||||
|
|
|
@ -1,14 +1,17 @@
|
|||
import operator
|
||||
import random
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
|
||||
from torchmetrics.utilities.imports import _TORCH_LOWER_1_4, _TORCH_LOWER_1_5, _TORCH_LOWER_1_6
|
||||
from torchmetrics.utilities.imports import _TORCH_LOWER_1_4, _TORCH_LOWER_1_5, _TORCH_LOWER_1_6, _compare_version
|
||||
|
||||
_MARK_TORCH_MIN_1_4 = dict(condition=_TORCH_LOWER_1_4, reason="required PT >= 1.4")
|
||||
_MARK_TORCH_MIN_1_5 = dict(condition=_TORCH_LOWER_1_5, reason="required PT >= 1.5")
|
||||
_MARK_TORCH_MIN_1_6 = dict(condition=_TORCH_LOWER_1_6, reason="required PT >= 1.6")
|
||||
|
||||
_LIGHTNING_GREATER_EQUAL_1_3 = _compare_version("pytorch_lightning", operator.ge, "1.3.0")
|
||||
|
||||
|
||||
def seed_all(seed):
|
||||
random.seed(seed)
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
import functools
|
||||
import inspect
|
||||
import operator
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from copy import deepcopy
|
||||
|
@ -24,7 +25,7 @@ from torch import Tensor, nn
|
|||
from torchmetrics.utilities import apply_to_collection
|
||||
from torchmetrics.utilities.data import _flatten, dim_zero_cat, dim_zero_mean, dim_zero_sum
|
||||
from torchmetrics.utilities.distributed import gather_all_tensors
|
||||
from torchmetrics.utilities.imports import _LIGHTNING_AVAILABLE, _LIGHTNING_GREATER_EQUAL_1_3
|
||||
from torchmetrics.utilities.imports import _LIGHTNING_AVAILABLE, _compare_version
|
||||
|
||||
|
||||
class Metric(nn.Module, ABC):
|
||||
|
@ -70,6 +71,7 @@ class Metric(nn.Module, ABC):
|
|||
dist_sync_fn: Callable = None,
|
||||
):
|
||||
super().__init__()
|
||||
self._LIGHTNING_GREATER_EQUAL_1_3 = _compare_version("pytorch_lightning", operator.ge, "1.3.0")
|
||||
|
||||
self.dist_sync_on_step = dist_sync_on_step
|
||||
self.compute_on_step = compute_on_step
|
||||
|
@ -257,9 +259,8 @@ class Metric(nn.Module, ABC):
|
|||
"""
|
||||
This method automatically resets the metric state variables to their default value.
|
||||
"""
|
||||
# lower lightning versions requires this implicitly to log metric objects correctly
|
||||
# in self.log
|
||||
if not _LIGHTNING_AVAILABLE or _LIGHTNING_GREATER_EQUAL_1_3:
|
||||
# lower lightning versions requires this implicitly to log metric objects correctly in self.log
|
||||
if not _LIGHTNING_AVAILABLE or self._LIGHTNING_GREATER_EQUAL_1_3:
|
||||
self._computed = None
|
||||
|
||||
for attr, default in self._defaults.items():
|
||||
|
|
|
@ -47,7 +47,10 @@ def _compare_version(package: str, op, version) -> Optional[bool]:
|
|||
>>> import operator
|
||||
>>> _compare_version("torch", operator.ge, "0.1")
|
||||
True
|
||||
>>> _compare_version("any_module", operator.ge, "0.0") # is None
|
||||
"""
|
||||
if not _module_available(package):
|
||||
return None
|
||||
try:
|
||||
pkg = import_module(package)
|
||||
pkg_version = pkg.__version__
|
||||
|
@ -71,4 +74,3 @@ _TORCH_LOWER_1_6 = _compare_version("torch", operator.lt, "1.6.0")
|
|||
_TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0")
|
||||
_TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0")
|
||||
_LIGHTNING_AVAILABLE = _module_available("pytorch_lightning")
|
||||
_LIGHTNING_GREATER_EQUAL_1_3 = _compare_version("pytorch_lightning", operator.ge, "1.3.0")
|
||||
|
|
Загрузка…
Ссылка в новой задаче