* fix compatibility

* local

* flake8

* drop 3.6 min

* .
Jirka Borovec 2021-04-21 15:55:09 +02:00, committed by GitHub
Parent 5b7c40e6ae
Commit e0f844c98b
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 22 additions and 12 deletions

.github/workflows/ci_test-full.yml

@@ -20,6 +20,8 @@ jobs:
         python-version: [3.6, 3.8, 3.9]
         requires: ['minimal', 'latest']
         exclude:
+          - python-version: 3.6
+            requires: 'latest'
           - python-version: 3.9
             requires: 'minimal'


@@ -186,7 +186,7 @@
       same "printed page" as the copyright notice for easier
       identification within third-party archives.
-   Copyright 2018-2021 William Falcon
+   Copyright 2020-2021 PytorchLightning team
    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.


@@ -1,3 +1,5 @@
 from torchmetrics.utilities.imports import _module_available
+import operator
 _PL_AVAILABLE = _module_available('pytorch_lightning')
+from torchmetrics.utilities.imports import _compare_version
+_LIGHTNING_GREATER_EQUAL_1_3 = _compare_version("pytorch_lightning", operator.ge, "1.3.0")
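For context, a minimal sketch (not part of this commit) of how the two flags defined above can guard an optional Lightning import; the fallback to a plain object base class is an assumption, not code from the repository:

import operator

from torchmetrics.utilities.imports import _compare_version, _module_available

_PL_AVAILABLE = _module_available("pytorch_lightning")
_LIGHTNING_GREATER_EQUAL_1_3 = _compare_version("pytorch_lightning", operator.ge, "1.3.0")

# The version flag is None when pytorch_lightning is missing, so a plain
# truthiness check covers both "not installed" and "too old".
if _PL_AVAILABLE and _LIGHTNING_GREATER_EQUAL_1_3:
    from pytorch_lightning import LightningModule
else:
    LightningModule = object  # hypothetical fallback so this sketch also runs without Lightning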


@@ -20,8 +20,8 @@ from torch import tensor
 from torch.utils.data import DataLoader
 from integrations.lightning_models import BoringModel, RandomDataset
+from tests.helpers import _LIGHTNING_GREATER_EQUAL_1_3
 from torchmetrics import Accuracy, AveragePrecision, Metric
-from torchmetrics.utilities.imports import _LIGHTNING_GREATER_EQUAL_1_3
 class SumMetric(Metric):


@@ -20,9 +20,9 @@ import pytest
 import torch
 from torch import nn, tensor
-from tests.helpers import seed_all
+from tests.helpers import _LIGHTNING_GREATER_EQUAL_1_3, seed_all
 from tests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricSum
-from torchmetrics.utilities.imports import _LIGHTNING_AVAILABLE, _LIGHTNING_GREATER_EQUAL_1_3, _TORCH_LOWER_1_6
+from torchmetrics.utilities.imports import _LIGHTNING_AVAILABLE, _TORCH_LOWER_1_6
 seed_all(42)


@@ -1,14 +1,17 @@
+import operator
 import random
 import numpy
 import torch
-from torchmetrics.utilities.imports import _TORCH_LOWER_1_4, _TORCH_LOWER_1_5, _TORCH_LOWER_1_6
+from torchmetrics.utilities.imports import _TORCH_LOWER_1_4, _TORCH_LOWER_1_5, _TORCH_LOWER_1_6, _compare_version
 _MARK_TORCH_MIN_1_4 = dict(condition=_TORCH_LOWER_1_4, reason="required PT >= 1.4")
 _MARK_TORCH_MIN_1_5 = dict(condition=_TORCH_LOWER_1_5, reason="required PT >= 1.5")
 _MARK_TORCH_MIN_1_6 = dict(condition=_TORCH_LOWER_1_6, reason="required PT >= 1.6")
+_LIGHTNING_GREATER_EQUAL_1_3 = _compare_version("pytorch_lightning", operator.ge, "1.3.0")
 def seed_all(seed):
     random.seed(seed)
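A brief usage sketch (a hypothetical test, not part of this commit) showing how the helpers added above are meant to be consumed; the test name and assertion are assumptions:

import pytest
import torch

from tests.helpers import _LIGHTNING_GREATER_EQUAL_1_3, _MARK_TORCH_MIN_1_6, seed_all

seed_all(42)


# skipif accepts the condition/reason pair unpacked from the helper dict;
# the flag is None when pytorch_lightning is missing, so that case is skipped too.
@pytest.mark.skipif(**_MARK_TORCH_MIN_1_6)
@pytest.mark.skipif(not _LIGHTNING_GREATER_EQUAL_1_3, reason="requires pytorch_lightning >= 1.3")
def test_dummy_metric_sum():
    # hypothetical check; real tests would exercise a Metric subclass
    assert torch.tensor(2.0) == 2.0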


@@ -13,6 +13,7 @@
 # limitations under the License.
 import functools
 import inspect
+import operator
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
 from copy import deepcopy
@@ -24,7 +25,7 @@ from torch import Tensor, nn
 from torchmetrics.utilities import apply_to_collection
 from torchmetrics.utilities.data import _flatten, dim_zero_cat, dim_zero_mean, dim_zero_sum
 from torchmetrics.utilities.distributed import gather_all_tensors
-from torchmetrics.utilities.imports import _LIGHTNING_AVAILABLE, _LIGHTNING_GREATER_EQUAL_1_3
+from torchmetrics.utilities.imports import _LIGHTNING_AVAILABLE, _compare_version
 class Metric(nn.Module, ABC):
@@ -70,6 +71,7 @@ class Metric(nn.Module, ABC):
         dist_sync_fn: Callable = None,
     ):
         super().__init__()
+        self._LIGHTNING_GREATER_EQUAL_1_3 = _compare_version("pytorch_lightning", operator.ge, "1.3.0")
         self.dist_sync_on_step = dist_sync_on_step
         self.compute_on_step = compute_on_step
@@ -257,9 +259,8 @@
         """
         This method automatically resets the metric state variables to their default value.
         """
-        # lower lightning versions requires this implicitly to log metric objects correctly
-        # in self.log
-        if not _LIGHTNING_AVAILABLE or _LIGHTNING_GREATER_EQUAL_1_3:
+        # lower lightning versions requires this implicitly to log metric objects correctly in self.log
+        if not _LIGHTNING_AVAILABLE or self._LIGHTNING_GREATER_EQUAL_1_3:
             self._computed = None
         for attr, default in self._defaults.items():
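To make the intent of the change above easier to follow, here is a condensed, self-contained paraphrase of the new pattern (a sketch with an invented class name, not the real Metric class): the Lightning version is now checked when an instance is created rather than when torchmetrics.utilities.imports is loaded, and the per-instance flag drives the reset behaviour.

import operator

from torchmetrics.utilities.imports import _LIGHTNING_AVAILABLE, _compare_version


class VersionAwareMetricSketch:
    """Hypothetical stand-in for torchmetrics.Metric illustrating the pattern above."""

    def __init__(self):
        # None when pytorch_lightning is not installed, True/False otherwise;
        # evaluated when the instance is constructed, not at module import time.
        self._LIGHTNING_GREATER_EQUAL_1_3 = _compare_version("pytorch_lightning", operator.ge, "1.3.0")
        self._computed = None

    def reset(self):
        # Lightning < 1.3 implicitly relies on the cached result staying in place
        # so that self.log can still read the metric object, hence the guard.
        if not _LIGHTNING_AVAILABLE or self._LIGHTNING_GREATER_EQUAL_1_3:
            self._computed = None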


@@ -47,7 +47,10 @@ def _compare_version(package: str, op, version) -> Optional[bool]:
     >>> import operator
     >>> _compare_version("torch", operator.ge, "0.1")
     True
+    >>> _compare_version("any_module", operator.ge, "0.0")  # is None
     """
+    if not _module_available(package):
+        return None
     try:
         pkg = import_module(package)
         pkg_version = pkg.__version__
@@ -71,4 +74,3 @@ _TORCH_LOWER_1_6 = _compare_version("torch", operator.lt, "1.6.0")
 _TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0")
 _TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0")
 _LIGHTNING_AVAILABLE = _module_available("pytorch_lightning")
-_LIGHTNING_GREATER_EQUAL_1_3 = _compare_version("pytorch_lightning", operator.ge, "1.3.0")
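A short usage sketch (not part of the diff) of the behaviour introduced by the early return above; with it, a missing package yields None instead of going through the import attempt, so callers can tell "not installed" apart from "installed but too old". The printed messages are illustrative only.

import operator

from torchmetrics.utilities.imports import _compare_version

flag = _compare_version("pytorch_lightning", operator.ge, "1.3.0")
if flag is None:
    # the early return above: the package could not be found at all
    print("pytorch_lightning is not installed")
elif flag:
    print("pytorch_lightning >= 1.3.0 is available")
else:
    print("pytorch_lightning is installed but older than 1.3.0")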