Add testing agains each feat PT version (#127)
* draft * update * clean
This commit is contained in:
Родитель
2601642ad4
Коммит
26eae3967d
|
@ -0,0 +1,76 @@
|
|||
name: PyTorch & Conda
|
||||
|
||||
# see: https://help.github.com/en/actions/reference/events-that-trigger-workflows
|
||||
on: # Trigger the workflow on push or pull request, but only for the master branch
|
||||
push:
|
||||
branches: [master, "release/*"]
|
||||
pull_request:
|
||||
branches: [master, "release/*"]
|
||||
|
||||
jobs:
|
||||
conda:
|
||||
runs-on: ubuntu-20.04
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: [3.7]
|
||||
pytorch-version: [1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9]
|
||||
|
||||
# Timeout: https://stackoverflow.com/a/59076067/4521646
|
||||
timeout-minutes: 35
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
- name: Cache conda
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: ~/conda_pkgs_dir
|
||||
key: conda-py${{ matrix.python-version }}-pt${{ matrix.pytorch-version }}-${{ hashFiles('environment.yml') }}
|
||||
restore-keys: conda-py${{ matrix.python-version }}-pt${{ matrix.pytorch-version }}-
|
||||
|
||||
# Add another cache for Pip as not all packages lives in Conda env
|
||||
- name: Cache pip
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: ~/.cache/pip
|
||||
key: pip-py${{ matrix.python-version }}-pt${{ matrix.pytorch-version }}-${{ hashFiles('requirements/base.txt') }}
|
||||
restore-keys: pip-py${{ matrix.python-version }}-pt${{ matrix.pytorch-version }}-
|
||||
|
||||
# https://docs.conda.io/projects/conda/en/4.6.0/_downloads/52a95608c49671267e40c689e0bc00ca/conda-cheatsheet.pdf
|
||||
# https://gist.github.com/mwouts/9842452d020c08faf9e84a3bba38a66f
|
||||
- name: Setup Miniconda
|
||||
uses: conda-incubator/setup-miniconda@v2
|
||||
with:
|
||||
miniconda-version: "4.7.12"
|
||||
python-version: ${{ matrix.python-version }}
|
||||
channels: conda-forge,pytorch,pytorch-test,pytorch-nightly
|
||||
channel-priority: true
|
||||
auto-activate-base: true
|
||||
# environment-file: ./environment.yml
|
||||
use-only-tar-bz2: true # IMPORTANT: This needs to be set for caching to work properly!
|
||||
|
||||
- name: Update Environment
|
||||
run: |
|
||||
conda info
|
||||
conda install pytorch=${{ matrix.pytorch-version }} cpuonly
|
||||
conda list
|
||||
pip --version
|
||||
pip install --requirement requirements.txt --upgrade-strategy only-if-needed --quiet
|
||||
pip install --requirement requirements/test.txt --upgrade-strategy only-if-needed --quiet
|
||||
pip list
|
||||
python -c "import torch; assert torch.__version__[:3] == '${{ matrix.pytorch-version }}', torch.__version__"
|
||||
shell: bash -l {0}
|
||||
|
||||
- name: Testing
|
||||
run: |
|
||||
# NOTE: run coverage on tests does not propagare faler status for Win, https://github.com/nedbat/coveragepy/issues/1003
|
||||
python -m pytest torchmetrics tests -v --durations=35 --junitxml=junit/test-conda-py${{ matrix.python-version }}-pt${{ matrix.pytorch-version }}.xml
|
||||
shell: bash -l {0}
|
||||
|
||||
- name: Upload pytest test results
|
||||
uses: actions/upload-artifact@master
|
||||
with:
|
||||
name: test-conda-py${{ matrix.python-version }}-pt${{ matrix.pytorch-version }}.xml
|
||||
path: junit/test-conda-py${{ matrix.python-version }}-pt${{ matrix.pytorch-version }}.xml
|
||||
# Use always() to always run this step to publish test results when there are test failures
|
||||
if: failure()
|
|
@ -38,3 +38,4 @@ prune notebook*
|
|||
prune temp*
|
||||
prune test*
|
||||
prune benchmark*
|
||||
prune integration*
|
||||
|
|
|
@ -84,5 +84,5 @@ jobs:
|
|||
condition: succeededOrFailed()
|
||||
|
||||
- bash: |
|
||||
python -m pytest integrations --durations=25
|
||||
python -m pytest integrations -v --durations=25
|
||||
displayName: 'Integrations'
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
from torchmetrics.utilities.imports import _module_available
|
||||
|
||||
_PL_AVAILABLE = _module_available('pytorch_lightning')
|
|
@ -11,6 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from pytorch_lightning import LightningModule
|
||||
from torch.utils.data import Dataset
|
||||
|
|
|
@ -21,7 +21,10 @@ exclude_lines =
|
|||
|
||||
[flake8]
|
||||
max-line-length = 120
|
||||
exclude = .tox,*.egg,build,temp
|
||||
exclude =
|
||||
*.egg
|
||||
build
|
||||
temp
|
||||
select = E,W,F
|
||||
doctests = True
|
||||
verbose = 2
|
||||
|
|
|
@ -21,7 +21,7 @@ from sklearn.metrics import multilabel_confusion_matrix
|
|||
from torch import Tensor, tensor
|
||||
|
||||
from tests.classification.inputs import _input_binary, _input_binary_prob, _input_multiclass
|
||||
from tests.classification.inputs import _input_multiclass_prob as _input_mccls_prob
|
||||
from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
|
||||
from tests.classification.inputs import _input_multidim_multiclass as _input_mdmc
|
||||
from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
|
||||
from tests.classification.inputs import _input_multilabel as _input_mcls
|
||||
|
@ -104,8 +104,8 @@ def _sk_stat_scores_mdim_mcls(preds, target, reduce, mdmc_reduce, num_classes, i
|
|||
["macro", None, None, _input_binary, None],
|
||||
["micro", None, None, _input_mdmc_prob, None],
|
||||
["micro", None, None, _input_binary_prob, 0],
|
||||
["micro", None, None, _input_mccls_prob, NUM_CLASSES],
|
||||
["micro", None, NUM_CLASSES, _input_mccls_prob, NUM_CLASSES],
|
||||
["micro", None, None, _input_mcls_prob, NUM_CLASSES],
|
||||
["micro", None, NUM_CLASSES, _input_mcls_prob, NUM_CLASSES],
|
||||
],
|
||||
)
|
||||
def test_wrong_params(reduce, mdmc_reduce, num_classes, inputs, ignore_index):
|
||||
|
@ -141,8 +141,8 @@ def test_wrong_threshold():
|
|||
(_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
|
||||
(_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2),
|
||||
(_input_mcls.preds, _input_mcls.target, _sk_stat_scores, None, NUM_CLASSES, False, None),
|
||||
(_input_mccls_prob.preds, _input_mccls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
|
||||
(_input_mccls_prob.preds, _input_mccls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2),
|
||||
(_input_mcls_prob.preds, _input_mcls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
|
||||
(_input_mcls_prob.preds, _input_mcls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2),
|
||||
(_input_multiclass.preds, _input_multiclass.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
|
||||
(_input_mdmc.preds, _input_mdmc.target, _sk_stat_scores_mdim_mcls, "samplewise", NUM_CLASSES, None, None),
|
||||
(
|
||||
|
|
|
@ -1,6 +1,64 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from distutils.version import LooseVersion
|
||||
from importlib import import_module
|
||||
from importlib.util import find_spec
|
||||
|
||||
import torch
|
||||
from pkg_resources import DistributionNotFound
|
||||
|
||||
|
||||
def _module_available(module_path: str) -> bool:
|
||||
"""
|
||||
Check if a path is available in your environment
|
||||
|
||||
>>> _module_available('os')
|
||||
True
|
||||
>>> _module_available('bla.bla')
|
||||
False
|
||||
"""
|
||||
try:
|
||||
return find_spec(module_path) is not None
|
||||
except AttributeError:
|
||||
# Python 3.6
|
||||
return False
|
||||
except ModuleNotFoundError:
|
||||
# Python 3.7+
|
||||
return False
|
||||
|
||||
|
||||
def _compare_version(package: str, op, version) -> bool:
|
||||
"""
|
||||
Compare package version with some requirements
|
||||
|
||||
>>> import operator
|
||||
>>> _compare_version("torch", operator.ge, "0.1")
|
||||
True
|
||||
"""
|
||||
try:
|
||||
pkg = import_module(package)
|
||||
except (ModuleNotFoundError, DistributionNotFound):
|
||||
return False
|
||||
try:
|
||||
pkg_version = LooseVersion(pkg.__version__)
|
||||
except AttributeError:
|
||||
return False
|
||||
if not (hasattr(pkg_version, "vstring") and hasattr(pkg_version, "version")):
|
||||
# this is mock by sphinx, so it shall return True ro generate all summaries
|
||||
return True
|
||||
return op(pkg_version, LooseVersion(version))
|
||||
|
||||
|
||||
_TORCH_LOWER_1_4 = LooseVersion(torch.__version__) < LooseVersion("1.4.0")
|
||||
_TORCH_LOWER_1_5 = LooseVersion(torch.__version__) < LooseVersion("1.5.0")
|
||||
|
|
Загрузка…
Ссылка в новой задаче