Add testing agains each feat PT version (#127)

* draft

* update

* clean
This commit is contained in:
Jirka Borovec 2021-03-25 20:44:26 +01:00 коммит произвёл GitHub
Родитель 2601642ad4
Коммит 26eae3967d
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
8 изменённых файлов: 149 добавлений и 7 удалений

76
.github/workflows/ci_test-conda.yml поставляемый Normal file
Просмотреть файл

@ -0,0 +1,76 @@
name: PyTorch & Conda
# see: https://help.github.com/en/actions/reference/events-that-trigger-workflows
on: # Trigger the workflow on push or pull request, but only for the master branch
push:
branches: [master, "release/*"]
pull_request:
branches: [master, "release/*"]
jobs:
conda:
runs-on: ubuntu-20.04
strategy:
fail-fast: false
matrix:
python-version: [3.7]
pytorch-version: [1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9]
# Timeout: https://stackoverflow.com/a/59076067/4521646
timeout-minutes: 35
steps:
- uses: actions/checkout@v2
- name: Cache conda
uses: actions/cache@v2
with:
path: ~/conda_pkgs_dir
key: conda-py${{ matrix.python-version }}-pt${{ matrix.pytorch-version }}-${{ hashFiles('environment.yml') }}
restore-keys: conda-py${{ matrix.python-version }}-pt${{ matrix.pytorch-version }}-
# Add another cache for Pip as not all packages lives in Conda env
- name: Cache pip
uses: actions/cache@v2
with:
path: ~/.cache/pip
key: pip-py${{ matrix.python-version }}-pt${{ matrix.pytorch-version }}-${{ hashFiles('requirements/base.txt') }}
restore-keys: pip-py${{ matrix.python-version }}-pt${{ matrix.pytorch-version }}-
# https://docs.conda.io/projects/conda/en/4.6.0/_downloads/52a95608c49671267e40c689e0bc00ca/conda-cheatsheet.pdf
# https://gist.github.com/mwouts/9842452d020c08faf9e84a3bba38a66f
- name: Setup Miniconda
uses: conda-incubator/setup-miniconda@v2
with:
miniconda-version: "4.7.12"
python-version: ${{ matrix.python-version }}
channels: conda-forge,pytorch,pytorch-test,pytorch-nightly
channel-priority: true
auto-activate-base: true
# environment-file: ./environment.yml
use-only-tar-bz2: true # IMPORTANT: This needs to be set for caching to work properly!
- name: Update Environment
run: |
conda info
conda install pytorch=${{ matrix.pytorch-version }} cpuonly
conda list
pip --version
pip install --requirement requirements.txt --upgrade-strategy only-if-needed --quiet
pip install --requirement requirements/test.txt --upgrade-strategy only-if-needed --quiet
pip list
python -c "import torch; assert torch.__version__[:3] == '${{ matrix.pytorch-version }}', torch.__version__"
shell: bash -l {0}
- name: Testing
run: |
# NOTE: run coverage on tests does not propagare faler status for Win, https://github.com/nedbat/coveragepy/issues/1003
python -m pytest torchmetrics tests -v --durations=35 --junitxml=junit/test-conda-py${{ matrix.python-version }}-pt${{ matrix.pytorch-version }}.xml
shell: bash -l {0}
- name: Upload pytest test results
uses: actions/upload-artifact@master
with:
name: test-conda-py${{ matrix.python-version }}-pt${{ matrix.pytorch-version }}.xml
path: junit/test-conda-py${{ matrix.python-version }}-pt${{ matrix.pytorch-version }}.xml
# Use always() to always run this step to publish test results when there are test failures
if: failure()

Просмотреть файл

@ -38,3 +38,4 @@ prune notebook*
prune temp*
prune test*
prune benchmark*
prune integration*

Просмотреть файл

@ -84,5 +84,5 @@ jobs:
condition: succeededOrFailed()
- bash: |
python -m pytest integrations --durations=25
python -m pytest integrations -v --durations=25
displayName: 'Integrations'

Просмотреть файл

@ -0,0 +1,3 @@
from torchmetrics.utilities.imports import _module_available
_PL_AVAILABLE = _module_available('pytorch_lightning')

Просмотреть файл

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from pytorch_lightning import LightningModule
from torch.utils.data import Dataset

Просмотреть файл

@ -21,7 +21,10 @@ exclude_lines =
[flake8]
max-line-length = 120
exclude = .tox,*.egg,build,temp
exclude =
*.egg
build
temp
select = E,W,F
doctests = True
verbose = 2

Просмотреть файл

@ -21,7 +21,7 @@ from sklearn.metrics import multilabel_confusion_matrix
from torch import Tensor, tensor
from tests.classification.inputs import _input_binary, _input_binary_prob, _input_multiclass
from tests.classification.inputs import _input_multiclass_prob as _input_mccls_prob
from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.classification.inputs import _input_multidim_multiclass as _input_mdmc
from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.classification.inputs import _input_multilabel as _input_mcls
@ -104,8 +104,8 @@ def _sk_stat_scores_mdim_mcls(preds, target, reduce, mdmc_reduce, num_classes, i
["macro", None, None, _input_binary, None],
["micro", None, None, _input_mdmc_prob, None],
["micro", None, None, _input_binary_prob, 0],
["micro", None, None, _input_mccls_prob, NUM_CLASSES],
["micro", None, NUM_CLASSES, _input_mccls_prob, NUM_CLASSES],
["micro", None, None, _input_mcls_prob, NUM_CLASSES],
["micro", None, NUM_CLASSES, _input_mcls_prob, NUM_CLASSES],
],
)
def test_wrong_params(reduce, mdmc_reduce, num_classes, inputs, ignore_index):
@ -141,8 +141,8 @@ def test_wrong_threshold():
(_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
(_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2),
(_input_mcls.preds, _input_mcls.target, _sk_stat_scores, None, NUM_CLASSES, False, None),
(_input_mccls_prob.preds, _input_mccls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
(_input_mccls_prob.preds, _input_mccls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2),
(_input_mcls_prob.preds, _input_mcls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
(_input_mcls_prob.preds, _input_mcls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2),
(_input_multiclass.preds, _input_multiclass.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
(_input_mdmc.preds, _input_mdmc.target, _sk_stat_scores_mdim_mcls, "samplewise", NUM_CLASSES, None, None),
(

Просмотреть файл

@ -1,6 +1,64 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from distutils.version import LooseVersion
from importlib import import_module
from importlib.util import find_spec
import torch
from pkg_resources import DistributionNotFound
def _module_available(module_path: str) -> bool:
"""
Check if a path is available in your environment
>>> _module_available('os')
True
>>> _module_available('bla.bla')
False
"""
try:
return find_spec(module_path) is not None
except AttributeError:
# Python 3.6
return False
except ModuleNotFoundError:
# Python 3.7+
return False
def _compare_version(package: str, op, version) -> bool:
"""
Compare package version with some requirements
>>> import operator
>>> _compare_version("torch", operator.ge, "0.1")
True
"""
try:
pkg = import_module(package)
except (ModuleNotFoundError, DistributionNotFound):
return False
try:
pkg_version = LooseVersion(pkg.__version__)
except AttributeError:
return False
if not (hasattr(pkg_version, "vstring") and hasattr(pkg_version, "version")):
# this is mock by sphinx, so it shall return True ro generate all summaries
return True
return op(pkg_version, LooseVersion(version))
_TORCH_LOWER_1_4 = LooseVersion(torch.__version__) < LooseVersion("1.4.0")
_TORCH_LOWER_1_5 = LooseVersion(torch.__version__) < LooseVersion("1.5.0")