This commit is contained in:
Jirka Borovec 2021-02-22 01:34:34 +01:00
Родитель d00b73c9ec
Коммит fd1c11e26e
12 изменённых файлов: 37 добавлений и 55 удалений

Просмотреть файл

@ -51,7 +51,6 @@ release = torchmetrics.__version__
github_user = 'PyTorchLightning'
github_repo = project
# -- Project documents -------------------------------------------------------
# export the READme
with open(os.path.join(PATH_ROOT, 'README.md'), 'r') as fp:
@ -139,7 +138,6 @@ exclude_patterns = [
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = None
# -- Options for HTML output -------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
@ -178,7 +176,6 @@ html_static_path = ['_images', '_templates', '_static']
#
# html_sidebars = {}
# -- Options for HTMLHelp output ---------------------------------------------
# Output file base name for HTML help builder.
@ -211,9 +208,7 @@ latex_documents = [
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
(master_doc, project, project + ' Documentation', [author], 1)
]
man_pages = [(master_doc, project, project + ' Documentation', [author], 1)]
# -- Options for Texinfo output ----------------------------------------------
@ -221,9 +216,7 @@ man_pages = [
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(master_doc, project, project + ' Documentation', author, project,
torchmetrics.__docs__,
'Miscellaneous'),
(master_doc, project, project + ' Documentation', author, project, torchmetrics.__docs__, 'Miscellaneous'),
]
# -- Options for Epub output -------------------------------------------------
@ -282,13 +275,10 @@ def run_apidoc(_):
shutil.rmtree(apidoc_output_folder)
for pkg in PACKAGES:
argv = ['-e',
'-o', apidoc_output_folder,
os.path.join(PATH_ROOT, pkg),
'**/test_*',
'--force',
'--private',
'--module-first']
argv = [
'-e', '-o', apidoc_output_folder,
os.path.join(PATH_ROOT, pkg), '**/test_*', '--force', '--private', '--module-first'
]
apidoc.main(argv)
@ -338,6 +328,7 @@ autodoc_mock_imports = MOCK_PACKAGES
# Resolve function
# This function is used to populate the (source) links in the API
def linkcode_resolve(domain, info):
def find_source():
# try to find the file and line number, based on code from numpy:
# https://github.com/numpy/numpy/blob/master/doc/source/conf.py#L286

Просмотреть файл

@ -39,6 +39,18 @@ description-file = README.md
# long_description = file:README.md
# long_description_content_type = text/markdown
[yapf]
based_on_style = pep8
spaces_before_comment = 2
split_before_logical_operator = true
split_before_arithmetic_operator = true
COLUMN_LIMIT = 120
COALESCE_BRACKETS = true
DEDENT_CLOSING_BRACKETS = true
ALLOW_SPLIT_BEFORE_DICT_VALUE = false
BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true
NO_SPACES_AROUND_SELECTED_BINARY_OPERATORS = false
[mypy]
# Typing tests is low priority, but enabling type checking on the
# untyped test functions (using `--check-untyped-defs`) is still

Просмотреть файл

@ -52,23 +52,19 @@ setup(
download_url='https://github.com/PyTorchLightning/torchmetrics',
license=torchmetrics.__license__,
packages=find_packages(exclude=['tests', 'docs']),
long_description=load_long_describtion(),
long_description_content_type='text/markdown',
include_package_data=True,
zip_safe=False,
keywords=['deep learning', 'pytorch', 'AI'],
python_requires='>=3.6',
setup_requires=[],
install_requires=load_requirements(PATH_ROOT),
project_urls={
"Bug Tracker": "https://github.com/PyTorchLightning/torchmetrics/issues",
"Documentation": "https://torchmetrics.rtfd.io/en/latest/",
"Source Code": "https://github.com/PyTorchLightning/torchmetrics",
},
classifiers=[
'Environment :: Console',
'Natural Language :: English',

Просмотреть файл

Просмотреть файл

@ -0,0 +1,2 @@
from tests.test_metric import Dummy # noqa: F401
from tests.utils import NUM_BATCHES, NUM_PROCESSES, BATCH_SIZE, MetricTester # noqa: F401

Просмотреть файл

@ -17,6 +17,7 @@ from torch.utils.data import Dataset
class RandomDictStringDataset(Dataset):
def __init__(self, size, length):
self.len = length
self.data = torch.randn(length, size)
@ -29,6 +30,7 @@ class RandomDictStringDataset(Dataset):
class RandomDataset(Dataset):
def __init__(self, size, length):
self.len = length
self.data = torch.randn(length, size)

Просмотреть файл

@ -124,5 +124,8 @@ def test_warning_on_nan(tmpdir):
preds = torch.randint(3, size=(20, ))
target = torch.randint(3, size=(20, ))
with pytest.warns(UserWarning, match='.* nan values found in confusion matrix have been replaced with zeros.'):
with pytest.warns(
UserWarning,
match='.* nan values found in confusion matrix have been replaced with zeros.',
):
confusion_matrix(preds, target, num_classes=5, normalize='true')

Просмотреть файл

@ -106,7 +106,6 @@ def test_metric_lightning(tmpdir):
# assert torch.allclose(torch.tensor(logged["sum_step"]), model.sum)
# assert torch.allclose(torch.tensor(logged["sum_epoch"]), model.sum)
# todo: need to be fixed
# def test_scriptable(tmpdir):
# class TestModel(BoringModel):
@ -145,7 +144,6 @@ def test_metric_lightning(tmpdir):
# script_output = script_model(rand_input)
# assert torch.allclose(output, script_output)
# def test_metric_collection_lightning_log(tmpdir):
#
# class TestModel(BoringModel):

Просмотреть файл

@ -29,28 +29,12 @@ if __LIGHTNING_SETUP__:
else:
from torchmetrics.classification import ( # noqa: F401
Accuracy,
AUC,
AUROC,
AveragePrecision,
ConfusionMatrix,
F1,
FBeta,
HammingDistance,
IoU,
Precision,
PrecisionRecallCurve,
Recall,
ROC,
Accuracy, AUC, AUROC, AveragePrecision, ConfusionMatrix, F1, FBeta,
HammingDistance, IoU, Precision, PrecisionRecallCurve, Recall, ROC,
StatScores,
)
from torchmetrics.metric import Metric, MetricCollection # noqa: F401
from torchmetrics.regression import ( # noqa: F401
ExplainedVariance,
MeanAbsoluteError,
MeanSquaredError,
MeanSquaredLogError,
PSNR,
R2Score,
SSIM,
ExplainedVariance, MeanAbsoluteError, MeanSquaredError,
MeanSquaredLogError, PSNR, R2Score, SSIM,
)

Просмотреть файл

@ -84,9 +84,7 @@ class CompositionalMetric(Metric):
self.metric_b.persistent(mode=mode)
def __repr__(self):
repr_str = (
self.__class__.__name__
+ f"(\n {self.op.__name__}(\n {repr(self.metric_a)},\n {repr(self.metric_b)}\n )\n)"
)
_op_metrics = f"(\n {self.op.__name__}(\n {repr(self.metric_a)},\n {repr(self.metric_b)}\n )\n)"
repr_str = (self.__class__.__name__ + _op_metrics)
return repr_str

Просмотреть файл

@ -16,12 +16,8 @@ from torchmetrics.functional.auc import auc # noqa: F401
from torchmetrics.functional.auroc import auroc # noqa: F401
from torchmetrics.functional.average_precision import average_precision # noqa: F401
from torchmetrics.functional.classification import ( # noqa: F401
dice_score,
get_num_classes,
multiclass_auroc,
stat_scores_multiple_classes,
to_categorical,
to_onehot,
dice_score, get_num_classes, multiclass_auroc,
stat_scores_multiple_classes, to_categorical, to_onehot,
)
from torchmetrics.functional.confusion_matrix import confusion_matrix # noqa: F401
from torchmetrics.functional.explained_variance import explained_variance # noqa: F401

Просмотреть файл

@ -318,10 +318,10 @@ class Metric(nn.Module, ABC):
# filter all parameters based on update signature except those of
# type VAR_POSITIONAL (*args) and VAR_KEYWORD (**kwargs)
_params = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
_sign_params = self._update_signature.parameters
filtered_kwargs = {
k: v
for k, v in kwargs.items() if k in self._update_signature.parameters.keys()
and self._update_signature.parameters[k].kind not in _params
for k, v in kwargs.items() if (k in _sign_params.keys() and _sign_params[k].kind not in _params)
}
# if no kwargs filtered, return al kwargs as default