Родитель
be89a1b731
Коммит
2e4cb70c46
|
@ -13,25 +13,29 @@
|
|||
# documentation root, use os.path.abspath to make it absolute, like shown here.
|
||||
|
||||
# import m2r
|
||||
import builtins
|
||||
import glob
|
||||
import inspect
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from importlib.util import module_from_spec, spec_from_file_location
|
||||
|
||||
import pt_lightning_sphinx_theme
|
||||
|
||||
PATH_HERE = os.path.abspath(os.path.dirname(__file__))
|
||||
PATH_ROOT = os.path.join(PATH_HERE, "..", "..")
|
||||
sys.path.insert(0, os.path.abspath(PATH_ROOT))
|
||||
|
||||
builtins.__LIGHTNING_BOLT_SETUP__ = True
|
||||
_PATH_HERE = os.path.abspath(os.path.dirname(__file__))
|
||||
_PATH_ROOT = os.path.realpath(os.path.join(_PATH_HERE, "..", ".."))
|
||||
sys.path.insert(0, os.path.abspath(_PATH_ROOT))
|
||||
|
||||
FOLDER_GENERATED = 'generated'
|
||||
SPHINX_MOCK_REQUIREMENTS = int(os.environ.get("SPHINX_MOCK_REQUIREMENTS", True))
|
||||
|
||||
import torchmetrics # noqa: E402
|
||||
try:
|
||||
from torchmetrics import info
|
||||
except ImportError:
|
||||
# alternative https://stackoverflow.com/a/67692/4521646
|
||||
spec = spec_from_file_location("torchmetrics/info.py", os.path.join(_PATH_ROOT, "torchmetrics", "info.py"))
|
||||
info = module_from_spec(spec)
|
||||
spec.loader.exec_module(info)
|
||||
|
||||
html_favicon = '_static/images/icon.svg'
|
||||
|
||||
|
@ -39,13 +43,13 @@ html_favicon = '_static/images/icon.svg'
|
|||
|
||||
# this name shall match the project name in Github as it is used for linking to code
|
||||
project = "PyTorch-Metrics"
|
||||
copyright = torchmetrics.__copyright__
|
||||
author = torchmetrics.__author__
|
||||
copyright = info.__copyright__
|
||||
author = info.__author__
|
||||
|
||||
# The short X.Y version
|
||||
version = torchmetrics.__version__
|
||||
version = info.__version__
|
||||
# The full version, including alpha/beta/rc tags
|
||||
release = torchmetrics.__version__
|
||||
release = info.__version__
|
||||
|
||||
# Options for the linkcode extension
|
||||
# ----------------------------------
|
||||
|
@ -70,14 +74,14 @@ def _transform_changelog(path_in: str, path_out: str) -> None:
|
|||
fp.writelines(chlog_lines)
|
||||
|
||||
|
||||
os.makedirs(os.path.join(PATH_HERE, FOLDER_GENERATED), exist_ok=True)
|
||||
os.makedirs(os.path.join(_PATH_HERE, FOLDER_GENERATED), exist_ok=True)
|
||||
# copy all documents from GH templates like contribution guide
|
||||
for md in glob.glob(os.path.join(PATH_ROOT, '.github', '*.md')):
|
||||
shutil.copy(md, os.path.join(PATH_HERE, FOLDER_GENERATED, os.path.basename(md)))
|
||||
for md in glob.glob(os.path.join(_PATH_ROOT, '.github', '*.md')):
|
||||
shutil.copy(md, os.path.join(_PATH_HERE, FOLDER_GENERATED, os.path.basename(md)))
|
||||
# copy also the changelog
|
||||
_transform_changelog(
|
||||
os.path.join(PATH_ROOT, 'CHANGELOG.md'),
|
||||
os.path.join(PATH_HERE, FOLDER_GENERATED, 'CHANGELOG.md'),
|
||||
os.path.join(_PATH_ROOT, 'CHANGELOG.md'),
|
||||
os.path.join(_PATH_HERE, FOLDER_GENERATED, 'CHANGELOG.md'),
|
||||
)
|
||||
|
||||
# -- General configuration ---------------------------------------------------
|
||||
|
@ -166,8 +170,8 @@ html_theme_path = [pt_lightning_sphinx_theme.get_html_theme_path()]
|
|||
# documentation.
|
||||
|
||||
html_theme_options = {
|
||||
"pytorch_project": torchmetrics.__homepage__,
|
||||
"canonical_url": torchmetrics.__homepage__,
|
||||
"pytorch_project": info.__homepage__,
|
||||
"canonical_url": info.__homepage__,
|
||||
"collapse_navigation": False,
|
||||
"display_version": True,
|
||||
"logo_only": False,
|
||||
|
@ -233,7 +237,7 @@ texinfo_documents = [
|
|||
project + " Documentation",
|
||||
author,
|
||||
project,
|
||||
torchmetrics.__docs__,
|
||||
info.__docs__,
|
||||
"Miscellaneous",
|
||||
),
|
||||
]
|
||||
|
@ -280,11 +284,11 @@ todo_include_todos = True
|
|||
|
||||
# packages for which sphinx-apidoc should generate the docs (.rst files)
|
||||
PACKAGES = [
|
||||
torchmetrics.__name__,
|
||||
info.__name__,
|
||||
]
|
||||
|
||||
# def run_apidoc(_):
|
||||
# apidoc_output_folder = os.path.join(PATH_HERE, "api")
|
||||
# apidoc_output_folder = os.path.join(_PATH_HERE, "api")
|
||||
# sys.path.insert(0, apidoc_output_folder)
|
||||
#
|
||||
# # delete api-doc files before generating them
|
||||
|
@ -294,7 +298,7 @@ PACKAGES = [
|
|||
# for pkg in PACKAGES:
|
||||
# argv = ['-e',
|
||||
# '-o', apidoc_output_folder,
|
||||
# os.path.join(PATH_ROOT, pkg),
|
||||
# os.path.join(_PATH_ROOT, pkg),
|
||||
# '**/test_*',
|
||||
# '--force',
|
||||
# '--private',
|
||||
|
@ -311,10 +315,10 @@ def setup(app):
|
|||
|
||||
|
||||
# copy all notebooks to local folder
|
||||
path_nbs = os.path.join(PATH_HERE, "notebooks")
|
||||
path_nbs = os.path.join(_PATH_HERE, "notebooks")
|
||||
if not os.path.isdir(path_nbs):
|
||||
os.mkdir(path_nbs)
|
||||
for path_ipynb in glob.glob(os.path.join(PATH_ROOT, "notebooks", "*.ipynb")):
|
||||
for path_ipynb in glob.glob(os.path.join(_PATH_ROOT, "notebooks", "*.ipynb")):
|
||||
path_ipynb2 = os.path.join(path_nbs, os.path.basename(path_ipynb))
|
||||
shutil.copy(path_ipynb, path_ipynb2)
|
||||
|
||||
|
@ -340,7 +344,7 @@ PACKAGE_MAPPING = {
|
|||
MOCK_PACKAGES = []
|
||||
if SPHINX_MOCK_REQUIREMENTS:
|
||||
# mock also base packages when we are on RTD since we don't install them there
|
||||
MOCK_PACKAGES += package_list_from_file(os.path.join(PATH_ROOT, "requirements.txt"))
|
||||
MOCK_PACKAGES += package_list_from_file(os.path.join(_PATH_ROOT, "requirements.txt"))
|
||||
MOCK_PACKAGES = [PACKAGE_MAPPING.get(pkg, pkg) for pkg in MOCK_PACKAGES]
|
||||
|
||||
autodoc_mock_imports = MOCK_PACKAGES
|
||||
|
|
44
setup.py
44
setup.py
|
@ -1,22 +1,24 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Always prefer setuptools over distutils
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
_PATH_ROOT = os.path.realpath(os.path.dirname(__file__))
|
||||
try:
|
||||
import builtins
|
||||
from torchmetrics import info, setup_tools
|
||||
except ImportError:
|
||||
import __builtin__ as builtins
|
||||
# alternative https://stackoverflow.com/a/67692/4521646
|
||||
sys.path.append("torchmetrics")
|
||||
import info
|
||||
import setup_tools
|
||||
|
||||
# https://packaging.python.org/guides/single-sourcing-package-version/
|
||||
# http://blog.ionelmc.ro/2014/05/25/python-packaging/
|
||||
PATH_ROOT = os.path.dirname(__file__)
|
||||
builtins.__LIGHTNING_SETUP__ = True
|
||||
|
||||
import torchmetrics # noqa: E402
|
||||
from torchmetrics.setup_tools import _load_readme_description, _load_requirements # noqa: E402
|
||||
long_description = setup_tools._load_readme_description(
|
||||
_PATH_ROOT,
|
||||
homepage=info.__homepage__,
|
||||
version=f'v{info.__version__}',
|
||||
)
|
||||
|
||||
# https://packaging.python.org/discussions/install-requires-vs-requirements /
|
||||
# keep the meta-data here for simplicity in reading this file... it's not obvious
|
||||
|
@ -25,26 +27,26 @@ from torchmetrics.setup_tools import _load_readme_description, _load_requirement
|
|||
# engineer specific practices
|
||||
setup(
|
||||
name='torchmetrics',
|
||||
version=torchmetrics.__version__,
|
||||
description=torchmetrics.__docs__,
|
||||
author=torchmetrics.__author__,
|
||||
author_email=torchmetrics.__author_email__,
|
||||
url=torchmetrics.__homepage__,
|
||||
download_url='https://github.com/PyTorchLightning/metrics/archive/master.zip',
|
||||
license=torchmetrics.__license__,
|
||||
version=info.__version__,
|
||||
description=info.__docs__,
|
||||
author=info.__author__,
|
||||
author_email=info.__author_email__,
|
||||
url=info.__homepage__,
|
||||
download_url=os.path.join(info.__homepage__, 'archive', 'master.zip'),
|
||||
license=info.__license__,
|
||||
packages=find_packages(exclude=['tests', 'docs']),
|
||||
long_description=_load_readme_description(PATH_ROOT, version=f'v{torchmetrics.__version__}'),
|
||||
long_description=long_description,
|
||||
long_description_content_type='text/markdown',
|
||||
include_package_data=True,
|
||||
zip_safe=False,
|
||||
keywords=['deep learning', 'machine learning', 'pytorch', 'metrics', 'AI'],
|
||||
python_requires='>=3.6',
|
||||
setup_requires=[],
|
||||
install_requires=_load_requirements(PATH_ROOT),
|
||||
install_requires=setup_tools._load_requirements(_PATH_ROOT),
|
||||
project_urls={
|
||||
"Bug Tracker": "https://github.com/PyTorchLightning/torchmetrics/issues",
|
||||
"Bug Tracker": os.path.join(info.__homepage__, 'issues'),
|
||||
"Documentation": "https://torchmetrics.rtfd.io/en/latest/",
|
||||
"Source Code": "https://github.com/PyTorchLightning/torchmetrics",
|
||||
"Source Code": info.__homepage__,
|
||||
},
|
||||
classifiers=[
|
||||
'Environment :: Console',
|
||||
|
|
|
@ -2,22 +2,15 @@
|
|||
import logging as __logging
|
||||
import os
|
||||
|
||||
__version__ = '0.2.1dev'
|
||||
__author__ = 'PyTorchLightning et al.'
|
||||
__author_email__ = 'name@pytorchlightning.ai'
|
||||
__license__ = 'Apache-2.0'
|
||||
__copyright__ = f'Copyright (c) 2020-2021, {__author__}.'
|
||||
__homepage__ = 'https://github.com/PyTorchLightning/metrics'
|
||||
__docs__ = "PyTorch native Metrics"
|
||||
__long_doc__ = """
|
||||
Torchmetrics is a metrics API created for easy metric development and usage in both PyTorch and
|
||||
[PyTorch Lightning](https://pytorch-lightning.readthedocs.io/en/stable/). It was originally a part of
|
||||
Pytorch Lightning, but got split off so users could take advantage of the large collection of metrics
|
||||
implemented without having to install Pytorch Lightning (even though we would love for you to try it out).
|
||||
We currently have around 25+ metrics implemented and we continuously is adding more metrics, both within
|
||||
already covered domains (classification, regression ect.) but also new domains (object detection ect.).
|
||||
We make sure that all our metrics are rigorously tested such that you can trust them.
|
||||
"""
|
||||
from torchmetrics.info import ( # noqa: F401
|
||||
__author__,
|
||||
__author_email__,
|
||||
__copyright__,
|
||||
__docs__,
|
||||
__homepage__,
|
||||
__license__,
|
||||
__version__,
|
||||
)
|
||||
|
||||
_logger = __logging.getLogger("torchmetrics")
|
||||
_logger.addHandler(__logging.StreamHandler())
|
||||
|
@ -26,46 +19,31 @@ _logger.setLevel(__logging.INFO)
|
|||
_PACKAGE_ROOT = os.path.dirname(__file__)
|
||||
_PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)
|
||||
|
||||
try:
|
||||
# This variable is injected in the __builtins__ by the build
|
||||
# process. It used to enable importing subpackages of skimage when
|
||||
# the binaries are not built
|
||||
_ = None if __LIGHTNING_SETUP__ else None
|
||||
except NameError:
|
||||
__LIGHTNING_SETUP__: bool = False
|
||||
|
||||
if __LIGHTNING_SETUP__:
|
||||
import sys # pragma: no-cover
|
||||
|
||||
sys.stdout.write(f'Partial import of `{__name__}` during the build process.\n') # pragma: no-cover
|
||||
# We are not importing the rest of the lightning during the build process, as it may not be compiled yet
|
||||
else:
|
||||
|
||||
from torchmetrics.classification import ( # noqa: F401
|
||||
AUC,
|
||||
AUROC,
|
||||
F1,
|
||||
ROC,
|
||||
Accuracy,
|
||||
AveragePrecision,
|
||||
CohenKappa,
|
||||
ConfusionMatrix,
|
||||
FBeta,
|
||||
HammingDistance,
|
||||
IoU,
|
||||
Precision,
|
||||
PrecisionRecallCurve,
|
||||
Recall,
|
||||
StatScores,
|
||||
)
|
||||
from torchmetrics.collections import MetricCollection # noqa: F401
|
||||
from torchmetrics.metric import Metric # noqa: F401
|
||||
from torchmetrics.regression import ( # noqa: F401
|
||||
PSNR,
|
||||
SSIM,
|
||||
ExplainedVariance,
|
||||
MeanAbsoluteError,
|
||||
MeanSquaredError,
|
||||
MeanSquaredLogError,
|
||||
R2Score,
|
||||
)
|
||||
from torchmetrics.classification import ( # noqa: F401 E402
|
||||
AUC,
|
||||
AUROC,
|
||||
F1,
|
||||
ROC,
|
||||
Accuracy,
|
||||
AveragePrecision,
|
||||
CohenKappa,
|
||||
ConfusionMatrix,
|
||||
FBeta,
|
||||
HammingDistance,
|
||||
IoU,
|
||||
Precision,
|
||||
PrecisionRecallCurve,
|
||||
Recall,
|
||||
StatScores,
|
||||
)
|
||||
from torchmetrics.collections import MetricCollection # noqa: F401 E402
|
||||
from torchmetrics.metric import Metric # noqa: F401 E402
|
||||
from torchmetrics.regression import ( # noqa: F401 E402
|
||||
PSNR,
|
||||
SSIM,
|
||||
ExplainedVariance,
|
||||
MeanAbsoluteError,
|
||||
MeanSquaredError,
|
||||
MeanSquaredLogError,
|
||||
R2Score,
|
||||
)
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from torch import Tensor, tensor
|
||||
from torch.nn import functional as F
|
||||
|
||||
from torchmetrics.utilities import rank_zero_warn
|
||||
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
__version__ = '0.2.1dev'
|
||||
__author__ = 'PyTorchLightning et al.'
|
||||
__author_email__ = 'name@pytorchlightning.ai'
|
||||
__license__ = 'Apache-2.0'
|
||||
__copyright__ = f'Copyright (c) 2020-2021, {__author__}.'
|
||||
__homepage__ = 'https://github.com/PyTorchLightning/metrics'
|
||||
__docs__ = "PyTorch native Metrics"
|
||||
__long_doc__ = """
|
||||
Torchmetrics is a metrics API created for easy metric development and usage in both PyTorch and
|
||||
[PyTorch Lightning](https://pytorch-lightning.readthedocs.io/en/stable/). It was originally a part of
|
||||
Pytorch Lightning, but got split off so users could take advantage of the large collection of metrics
|
||||
implemented without having to install Pytorch Lightning (even though we would love for you to try it out).
|
||||
We currently have around 25+ metrics implemented and we continuously is adding more metrics, both within
|
||||
already covered domains (classification, regression ect.) but also new domains (object detection ect.).
|
||||
We make sure that all our metrics are rigorously tested such that you can trust them.
|
||||
"""
|
|
@ -15,7 +15,7 @@ import os
|
|||
import re
|
||||
from typing import List
|
||||
|
||||
from torchmetrics import _PROJECT_ROOT, __homepage__, __version__
|
||||
_PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__))
|
||||
|
||||
|
||||
def _load_requirements(path_dir: str, file_name: str = 'requirements.txt', comment_char: str = '#') -> List[str]:
|
||||
|
@ -39,10 +39,10 @@ def _load_requirements(path_dir: str, file_name: str = 'requirements.txt', comme
|
|||
return reqs
|
||||
|
||||
|
||||
def _load_readme_description(path_dir: str, homepage: str = __homepage__, version: str = __version__) -> str:
|
||||
def _load_readme_description(path_dir: str, homepage: str, version: str) -> str:
|
||||
"""Load readme as decribtion
|
||||
|
||||
>>> _load_readme_description(_PROJECT_ROOT) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
|
||||
>>> _load_readme_description(_PROJECT_ROOT, "", "") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
|
||||
'<div align="center">...'
|
||||
"""
|
||||
path_readme = os.path.join(path_dir, "README.md")
|
||||
|
|
Загрузка…
Ссылка в новой задаче