* refactor setup

* info

* importing

* flake8
This commit is contained in:
Jirka Borovec 2021-03-19 15:38:56 +01:00 коммит произвёл GitHub
Родитель be89a1b731
Коммит 2e4cb70c46
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
6 изменённых файлов: 109 добавлений и 109 удалений

Просмотреть файл

@ -13,25 +13,29 @@
# documentation root, use os.path.abspath to make it absolute, like shown here.
# import m2r
import builtins
import glob
import inspect
import os
import shutil
import sys
from importlib.util import module_from_spec, spec_from_file_location
import pt_lightning_sphinx_theme
PATH_HERE = os.path.abspath(os.path.dirname(__file__))
PATH_ROOT = os.path.join(PATH_HERE, "..", "..")
sys.path.insert(0, os.path.abspath(PATH_ROOT))
builtins.__LIGHTNING_BOLT_SETUP__ = True
_PATH_HERE = os.path.abspath(os.path.dirname(__file__))
_PATH_ROOT = os.path.realpath(os.path.join(_PATH_HERE, "..", ".."))
sys.path.insert(0, os.path.abspath(_PATH_ROOT))
FOLDER_GENERATED = 'generated'
SPHINX_MOCK_REQUIREMENTS = int(os.environ.get("SPHINX_MOCK_REQUIREMENTS", True))
import torchmetrics # noqa: E402
try:
from torchmetrics import info
except ImportError:
# alternative https://stackoverflow.com/a/67692/4521646
spec = spec_from_file_location("torchmetrics/info.py", os.path.join(_PATH_ROOT, "torchmetrics", "info.py"))
info = module_from_spec(spec)
spec.loader.exec_module(info)
html_favicon = '_static/images/icon.svg'
@ -39,13 +43,13 @@ html_favicon = '_static/images/icon.svg'
# this name shall match the project name in Github as it is used for linking to code
project = "PyTorch-Metrics"
copyright = torchmetrics.__copyright__
author = torchmetrics.__author__
copyright = info.__copyright__
author = info.__author__
# The short X.Y version
version = torchmetrics.__version__
version = info.__version__
# The full version, including alpha/beta/rc tags
release = torchmetrics.__version__
release = info.__version__
# Options for the linkcode extension
# ----------------------------------
@ -70,14 +74,14 @@ def _transform_changelog(path_in: str, path_out: str) -> None:
fp.writelines(chlog_lines)
os.makedirs(os.path.join(PATH_HERE, FOLDER_GENERATED), exist_ok=True)
os.makedirs(os.path.join(_PATH_HERE, FOLDER_GENERATED), exist_ok=True)
# copy all documents from GH templates like contribution guide
for md in glob.glob(os.path.join(PATH_ROOT, '.github', '*.md')):
shutil.copy(md, os.path.join(PATH_HERE, FOLDER_GENERATED, os.path.basename(md)))
for md in glob.glob(os.path.join(_PATH_ROOT, '.github', '*.md')):
shutil.copy(md, os.path.join(_PATH_HERE, FOLDER_GENERATED, os.path.basename(md)))
# copy also the changelog
_transform_changelog(
os.path.join(PATH_ROOT, 'CHANGELOG.md'),
os.path.join(PATH_HERE, FOLDER_GENERATED, 'CHANGELOG.md'),
os.path.join(_PATH_ROOT, 'CHANGELOG.md'),
os.path.join(_PATH_HERE, FOLDER_GENERATED, 'CHANGELOG.md'),
)
# -- General configuration ---------------------------------------------------
@ -166,8 +170,8 @@ html_theme_path = [pt_lightning_sphinx_theme.get_html_theme_path()]
# documentation.
html_theme_options = {
"pytorch_project": torchmetrics.__homepage__,
"canonical_url": torchmetrics.__homepage__,
"pytorch_project": info.__homepage__,
"canonical_url": info.__homepage__,
"collapse_navigation": False,
"display_version": True,
"logo_only": False,
@ -233,7 +237,7 @@ texinfo_documents = [
project + " Documentation",
author,
project,
torchmetrics.__docs__,
info.__docs__,
"Miscellaneous",
),
]
@ -280,11 +284,11 @@ todo_include_todos = True
# packages for which sphinx-apidoc should generate the docs (.rst files)
PACKAGES = [
torchmetrics.__name__,
info.__name__,
]
# def run_apidoc(_):
# apidoc_output_folder = os.path.join(PATH_HERE, "api")
# apidoc_output_folder = os.path.join(_PATH_HERE, "api")
# sys.path.insert(0, apidoc_output_folder)
#
# # delete api-doc files before generating them
@ -294,7 +298,7 @@ PACKAGES = [
# for pkg in PACKAGES:
# argv = ['-e',
# '-o', apidoc_output_folder,
# os.path.join(PATH_ROOT, pkg),
# os.path.join(_PATH_ROOT, pkg),
# '**/test_*',
# '--force',
# '--private',
@ -311,10 +315,10 @@ def setup(app):
# copy all notebooks to local folder
path_nbs = os.path.join(PATH_HERE, "notebooks")
path_nbs = os.path.join(_PATH_HERE, "notebooks")
if not os.path.isdir(path_nbs):
os.mkdir(path_nbs)
for path_ipynb in glob.glob(os.path.join(PATH_ROOT, "notebooks", "*.ipynb")):
for path_ipynb in glob.glob(os.path.join(_PATH_ROOT, "notebooks", "*.ipynb")):
path_ipynb2 = os.path.join(path_nbs, os.path.basename(path_ipynb))
shutil.copy(path_ipynb, path_ipynb2)
@ -340,7 +344,7 @@ PACKAGE_MAPPING = {
MOCK_PACKAGES = []
if SPHINX_MOCK_REQUIREMENTS:
# mock also base packages when we are on RTD since we don't install them there
MOCK_PACKAGES += package_list_from_file(os.path.join(PATH_ROOT, "requirements.txt"))
MOCK_PACKAGES += package_list_from_file(os.path.join(_PATH_ROOT, "requirements.txt"))
MOCK_PACKAGES = [PACKAGE_MAPPING.get(pkg, pkg) for pkg in MOCK_PACKAGES]
autodoc_mock_imports = MOCK_PACKAGES

Просмотреть файл

@ -1,22 +1,24 @@
#!/usr/bin/env python
import os
import sys
# Always prefer setuptools over distutils
from setuptools import find_packages, setup
_PATH_ROOT = os.path.realpath(os.path.dirname(__file__))
try:
import builtins
from torchmetrics import info, setup_tools
except ImportError:
import __builtin__ as builtins
# alternative https://stackoverflow.com/a/67692/4521646
sys.path.append("torchmetrics")
import info
import setup_tools
# https://packaging.python.org/guides/single-sourcing-package-version/
# http://blog.ionelmc.ro/2014/05/25/python-packaging/
PATH_ROOT = os.path.dirname(__file__)
builtins.__LIGHTNING_SETUP__ = True
import torchmetrics # noqa: E402
from torchmetrics.setup_tools import _load_readme_description, _load_requirements # noqa: E402
long_description = setup_tools._load_readme_description(
_PATH_ROOT,
homepage=info.__homepage__,
version=f'v{info.__version__}',
)
# https://packaging.python.org/discussions/install-requires-vs-requirements /
# keep the meta-data here for simplicity in reading this file... it's not obvious
@ -25,26 +27,26 @@ from torchmetrics.setup_tools import _load_readme_description, _load_requirement
# engineer specific practices
setup(
name='torchmetrics',
version=torchmetrics.__version__,
description=torchmetrics.__docs__,
author=torchmetrics.__author__,
author_email=torchmetrics.__author_email__,
url=torchmetrics.__homepage__,
download_url='https://github.com/PyTorchLightning/metrics/archive/master.zip',
license=torchmetrics.__license__,
version=info.__version__,
description=info.__docs__,
author=info.__author__,
author_email=info.__author_email__,
url=info.__homepage__,
download_url=os.path.join(info.__homepage__, 'archive', 'master.zip'),
license=info.__license__,
packages=find_packages(exclude=['tests', 'docs']),
long_description=_load_readme_description(PATH_ROOT, version=f'v{torchmetrics.__version__}'),
long_description=long_description,
long_description_content_type='text/markdown',
include_package_data=True,
zip_safe=False,
keywords=['deep learning', 'machine learning', 'pytorch', 'metrics', 'AI'],
python_requires='>=3.6',
setup_requires=[],
install_requires=_load_requirements(PATH_ROOT),
install_requires=setup_tools._load_requirements(_PATH_ROOT),
project_urls={
"Bug Tracker": "https://github.com/PyTorchLightning/torchmetrics/issues",
"Bug Tracker": os.path.join(info.__homepage__, 'issues'),
"Documentation": "https://torchmetrics.rtfd.io/en/latest/",
"Source Code": "https://github.com/PyTorchLightning/torchmetrics",
"Source Code": info.__homepage__,
},
classifiers=[
'Environment :: Console',

Просмотреть файл

@ -2,22 +2,15 @@
import logging as __logging
import os
__version__ = '0.2.1dev'
__author__ = 'PyTorchLightning et al.'
__author_email__ = 'name@pytorchlightning.ai'
__license__ = 'Apache-2.0'
__copyright__ = f'Copyright (c) 2020-2021, {__author__}.'
__homepage__ = 'https://github.com/PyTorchLightning/metrics'
__docs__ = "PyTorch native Metrics"
__long_doc__ = """
Torchmetrics is a metrics API created for easy metric development and usage in both PyTorch and
[PyTorch Lightning](https://pytorch-lightning.readthedocs.io/en/stable/). It was originally a part of
Pytorch Lightning, but got split off so users could take advantage of the large collection of metrics
implemented without having to install Pytorch Lightning (even though we would love for you to try it out).
We currently have around 25+ metrics implemented and we continuously is adding more metrics, both within
already covered domains (classification, regression ect.) but also new domains (object detection ect.).
We make sure that all our metrics are rigorously tested such that you can trust them.
"""
from torchmetrics.info import ( # noqa: F401
__author__,
__author_email__,
__copyright__,
__docs__,
__homepage__,
__license__,
__version__,
)
_logger = __logging.getLogger("torchmetrics")
_logger.addHandler(__logging.StreamHandler())
@ -26,46 +19,31 @@ _logger.setLevel(__logging.INFO)
_PACKAGE_ROOT = os.path.dirname(__file__)
_PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)
try:
# This variable is injected in the __builtins__ by the build
# process. It used to enable importing subpackages of skimage when
# the binaries are not built
_ = None if __LIGHTNING_SETUP__ else None
except NameError:
__LIGHTNING_SETUP__: bool = False
if __LIGHTNING_SETUP__:
import sys # pragma: no-cover
sys.stdout.write(f'Partial import of `{__name__}` during the build process.\n') # pragma: no-cover
# We are not importing the rest of the lightning during the build process, as it may not be compiled yet
else:
from torchmetrics.classification import ( # noqa: F401
AUC,
AUROC,
F1,
ROC,
Accuracy,
AveragePrecision,
CohenKappa,
ConfusionMatrix,
FBeta,
HammingDistance,
IoU,
Precision,
PrecisionRecallCurve,
Recall,
StatScores,
)
from torchmetrics.collections import MetricCollection # noqa: F401
from torchmetrics.metric import Metric # noqa: F401
from torchmetrics.regression import ( # noqa: F401
PSNR,
SSIM,
ExplainedVariance,
MeanAbsoluteError,
MeanSquaredError,
MeanSquaredLogError,
R2Score,
)
from torchmetrics.classification import ( # noqa: F401 E402
AUC,
AUROC,
F1,
ROC,
Accuracy,
AveragePrecision,
CohenKappa,
ConfusionMatrix,
FBeta,
HammingDistance,
IoU,
Precision,
PrecisionRecallCurve,
Recall,
StatScores,
)
from torchmetrics.collections import MetricCollection # noqa: F401 E402
from torchmetrics.metric import Metric # noqa: F401 E402
from torchmetrics.regression import ( # noqa: F401 E402
PSNR,
SSIM,
ExplainedVariance,
MeanAbsoluteError,
MeanSquaredError,
MeanSquaredLogError,
R2Score,
)

Просмотреть файл

@ -14,8 +14,8 @@
from typing import List, Optional, Sequence, Tuple, Union
import torch
from torch.nn import functional as F
from torch import Tensor, tensor
from torch.nn import functional as F
from torchmetrics.utilities import rank_zero_warn

16
torchmetrics/info.py Normal file
Просмотреть файл

@ -0,0 +1,16 @@
__version__ = '0.2.1dev'
__author__ = 'PyTorchLightning et al.'
__author_email__ = 'name@pytorchlightning.ai'
__license__ = 'Apache-2.0'
__copyright__ = f'Copyright (c) 2020-2021, {__author__}.'
__homepage__ = 'https://github.com/PyTorchLightning/metrics'
__docs__ = "PyTorch native Metrics"
__long_doc__ = """
Torchmetrics is a metrics API created for easy metric development and usage in both PyTorch and
[PyTorch Lightning](https://pytorch-lightning.readthedocs.io/en/stable/). It was originally a part of
Pytorch Lightning, but got split off so users could take advantage of the large collection of metrics
implemented without having to install Pytorch Lightning (even though we would love for you to try it out).
We currently have around 25+ metrics implemented and we continuously is adding more metrics, both within
already covered domains (classification, regression ect.) but also new domains (object detection ect.).
We make sure that all our metrics are rigorously tested such that you can trust them.
"""

Просмотреть файл

@ -15,7 +15,7 @@ import os
import re
from typing import List
from torchmetrics import _PROJECT_ROOT, __homepage__, __version__
_PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__))
def _load_requirements(path_dir: str, file_name: str = 'requirements.txt', comment_char: str = '#') -> List[str]:
@ -39,10 +39,10 @@ def _load_requirements(path_dir: str, file_name: str = 'requirements.txt', comme
return reqs
def _load_readme_description(path_dir: str, homepage: str = __homepage__, version: str = __version__) -> str:
def _load_readme_description(path_dir: str, homepage: str, version: str) -> str:
"""Load readme as decribtion
>>> _load_readme_description(_PROJECT_ROOT) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
>>> _load_readme_description(_PROJECT_ROOT, "", "") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
'<div align="center">...'
"""
path_readme = os.path.join(path_dir, "README.md")