flake8

This commit is contained in:
Parent: d00b73c9ec
Commit: fd1c11e26e
@@ -51,7 +51,6 @@ release = torchmetrics.__version__
 github_user = 'PyTorchLightning'
 github_repo = project
 
-
 # -- Project documents -------------------------------------------------------
 # export the READme
 with open(os.path.join(PATH_ROOT, 'README.md'), 'r') as fp:
@@ -139,7 +138,6 @@ exclude_patterns = [
 # The name of the Pygments (syntax highlighting) style to use.
 pygments_style = None
 
-
 # -- Options for HTML output -------------------------------------------------
 
 # The theme to use for HTML and HTML Help pages. See the documentation for
@@ -178,7 +176,6 @@ html_static_path = ['_images', '_templates', '_static']
 #
 # html_sidebars = {}
 
-
 # -- Options for HTMLHelp output ---------------------------------------------
 
 # Output file base name for HTML help builder.
@@ -211,9 +208,7 @@ latex_documents = [
 
 # One entry per manual page. List of tuples
 # (source start file, name, description, authors, manual section).
-man_pages = [
-    (master_doc, project, project + ' Documentation', [author], 1)
-]
+man_pages = [(master_doc, project, project + ' Documentation', [author], 1)]
 
 # -- Options for Texinfo output ----------------------------------------------
 
@@ -221,9 +216,7 @@ man_pages = [
 # (source start file, target name, title, author,
 # dir menu entry, description, category)
 texinfo_documents = [
-    (master_doc, project, project + ' Documentation', author, project,
-     torchmetrics.__docs__,
-     'Miscellaneous'),
+    (master_doc, project, project + ' Documentation', author, project, torchmetrics.__docs__, 'Miscellaneous'),
 ]
 
 # -- Options for Epub output -------------------------------------------------
@@ -282,13 +275,10 @@ def run_apidoc(_):
         shutil.rmtree(apidoc_output_folder)
 
     for pkg in PACKAGES:
-        argv = ['-e',
-                '-o', apidoc_output_folder,
-                os.path.join(PATH_ROOT, pkg),
-                '**/test_*',
-                '--force',
-                '--private',
-                '--module-first']
+        argv = [
+            '-e', '-o', apidoc_output_folder,
+            os.path.join(PATH_ROOT, pkg), '**/test_*', '--force', '--private', '--module-first'
+        ]
 
         apidoc.main(argv)
 
@@ -338,6 +328,7 @@ autodoc_mock_imports = MOCK_PACKAGES
 # Resolve function
 # This function is used to populate the (source) links in the API
 def linkcode_resolve(domain, info):
+
     def find_source():
         # try to find the file and line number, based on code from numpy:
         # https://github.com/numpy/numpy/blob/master/doc/source/conf.py#L286
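For readers unfamiliar with the hook touched above: sphinx.ext.linkcode calls linkcode_resolve(domain, info) once per documented object and expects a URL string (or None) back; for the 'py' domain, info carries the 'module' and 'fullname' of the object. The snippet below is only an illustrative sketch of that contract, not the project's implementation: it hard-codes the master branch and skips the line-number lookup that the real find_source helper performs.

    def linkcode_resolve(domain, info):
        # Only Python objects whose module is known can be linked.
        if domain != 'py' or not info.get('module'):
            return None
        # Map the dotted module path onto a file path in the GitHub tree.
        filename = info['module'].replace('.', '/')
        return f"https://github.com/PyTorchLightning/torchmetrics/blob/master/{filename}.py"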
setup.cfg (12 changed lines)
@@ -39,6 +39,18 @@ description-file = README.md
 # long_description = file:README.md
 # long_description_content_type = text/markdown
 
+[yapf]
+based_on_style = pep8
+spaces_before_comment = 2
+split_before_logical_operator = true
+split_before_arithmetic_operator = true
+COLUMN_LIMIT = 120
+COALESCE_BRACKETS = true
+DEDENT_CLOSING_BRACKETS = true
+ALLOW_SPLIT_BEFORE_DICT_VALUE = false
+BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true
+NO_SPACES_AROUND_SELECTED_BINARY_OPERATORS = false
+
 [mypy]
 # Typing tests is low priority, but enabling type checking on the
 # untyped test functions (using `--check-untyped-defs`) is still
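The [yapf] section added above matches the mechanical reformatting seen in the rest of this commit: the 120-character COLUMN_LIMIT, the collapsed import lists, and the blank lines inserted before nested defs (BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF). As a rough illustration only (the snippet and variable names are made up, and it assumes yapf is installed and run from the repository root so that setup.cfg is found), the same style can be applied programmatically:

    # Format a throwaway snippet with the [yapf] style stored in setup.cfg.
    from yapf.yapflib.yapf_api import FormatCode

    snippet = "x = {  'a':37,'b':42,\n'c':927}\n"
    formatted, changed = FormatCode(snippet, style_config='setup.cfg')
    print(changed)    # True when yapf had to rewrite the snippet
    print(formatted)  # snippet rendered under based_on_style = pep8, COLUMN_LIMIT = 120, ...

On the command line the usual equivalent is yapf --in-place --recursive over the package directories, which picks up the same section automatically.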
setup.py (4 changed lines)
@@ -52,23 +52,19 @@ setup(
     download_url='https://github.com/PyTorchLightning/torchmetrics',
     license=torchmetrics.__license__,
     packages=find_packages(exclude=['tests', 'docs']),
-
     long_description=load_long_describtion(),
     long_description_content_type='text/markdown',
     include_package_data=True,
     zip_safe=False,
-
     keywords=['deep learning', 'pytorch', 'AI'],
     python_requires='>=3.6',
     setup_requires=[],
     install_requires=load_requirements(PATH_ROOT),
-
     project_urls={
         "Bug Tracker": "https://github.com/PyTorchLightning/torchmetrics/issues",
         "Documentation": "https://torchmetrics.rtfd.io/en/latest/",
         "Source Code": "https://github.com/PyTorchLightning/torchmetrics",
     },
-
     classifiers=[
         'Environment :: Console',
         'Natural Language :: English',
@@ -0,0 +1,2 @@
+from tests.test_metric import Dummy  # noqa: F401
+from tests.utils import NUM_BATCHES, NUM_PROCESSES, BATCH_SIZE, MetricTester  # noqa: F401
@@ -17,6 +17,7 @@ from torch.utils.data import Dataset
 
 
 class RandomDictStringDataset(Dataset):
+
     def __init__(self, size, length):
         self.len = length
         self.data = torch.randn(length, size)
@@ -29,6 +30,7 @@ class RandomDictStringDataset(Dataset):
 
 
 class RandomDataset(Dataset):
+
     def __init__(self, size, length):
         self.len = length
         self.data = torch.randn(length, size)
@@ -124,5 +124,8 @@ def test_warning_on_nan(tmpdir):
     preds = torch.randint(3, size=(20, ))
     target = torch.randint(3, size=(20, ))
 
-    with pytest.warns(UserWarning, match='.* nan values found in confusion matrix have been replaced with zeros.'):
+    with pytest.warns(
+        UserWarning,
+        match='.* nan values found in confusion matrix have been replaced with zeros.',
+    ):
         confusion_matrix(preds, target, num_classes=5, normalize='true')
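Independent of the reformat, the idiom above is worth a note: the match argument of pytest.warns is a regular expression that pytest searches against the warning message, so splitting the call across lines changes nothing. A self-contained toy version (the replace_nans function and the "3 " prefix in the message are invented for illustration):

    import warnings

    import pytest


    def replace_nans():
        # Stand-in for the library code under test.
        warnings.warn("3 nan values found in confusion matrix have been replaced with zeros.", UserWarning)


    def test_replace_nans_warns():
        with pytest.warns(
            UserWarning,
            match='.* nan values found in confusion matrix have been replaced with zeros.',
        ):
            replace_nans()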
@@ -106,7 +106,6 @@ def test_metric_lightning(tmpdir):
     # assert torch.allclose(torch.tensor(logged["sum_step"]), model.sum)
     # assert torch.allclose(torch.tensor(logged["sum_epoch"]), model.sum)
 
-
 # todo: need to be fixed
 # def test_scriptable(tmpdir):
 #     class TestModel(BoringModel):
@@ -145,7 +144,6 @@ def test_metric_lightning(tmpdir):
 #     script_output = script_model(rand_input)
 #     assert torch.allclose(output, script_output)
 
-
 # def test_metric_collection_lightning_log(tmpdir):
 #
 #     class TestModel(BoringModel):
@@ -29,28 +29,12 @@ if __LIGHTNING_SETUP__:
 else:
 
     from torchmetrics.classification import (  # noqa: F401
-        Accuracy,
-        AUC,
-        AUROC,
-        AveragePrecision,
-        ConfusionMatrix,
-        F1,
-        FBeta,
-        HammingDistance,
-        IoU,
-        Precision,
-        PrecisionRecallCurve,
-        Recall,
-        ROC,
+        Accuracy, AUC, AUROC, AveragePrecision, ConfusionMatrix, F1, FBeta,
+        HammingDistance, IoU, Precision, PrecisionRecallCurve, Recall, ROC,
         StatScores,
     )
     from torchmetrics.metric import Metric, MetricCollection  # noqa: F401
     from torchmetrics.regression import (  # noqa: F401
-        ExplainedVariance,
-        MeanAbsoluteError,
-        MeanSquaredError,
-        MeanSquaredLogError,
-        PSNR,
-        R2Score,
-        SSIM,
+        ExplainedVariance, MeanAbsoluteError, MeanSquaredError,
+        MeanSquaredLogError, PSNR, R2Score, SSIM,
     )
@@ -84,9 +84,7 @@ class CompositionalMetric(Metric):
         self.metric_b.persistent(mode=mode)
 
     def __repr__(self):
-        repr_str = (
-            self.__class__.__name__
-            + f"(\n {self.op.__name__}(\n {repr(self.metric_a)},\n {repr(self.metric_b)}\n )\n)"
-        )
+        _op_metrics = f"(\n {self.op.__name__}(\n {repr(self.metric_a)},\n {repr(self.metric_b)}\n )\n)"
+        repr_str = (self.__class__.__name__ + _op_metrics)
 
         return repr_str
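To make the refactor above easier to follow: only where the nested text is assembled changes, the resulting string is identical. A stand-in illustration of the shape of that string (the _Dummy class and the exact spacing inside the f-string are invented here, only the output shape matters):

    import operator


    class _Dummy:

        def __repr__(self):
            return "Dummy()"


    op, metric_a, metric_b = operator.add, _Dummy(), _Dummy()
    _op_metrics = f"(\n {op.__name__}(\n {repr(metric_a)},\n {repr(metric_b)}\n )\n)"
    print("CompositionalMetric" + _op_metrics)
    # CompositionalMetric(
    #  add(
    #  Dummy(),
    #  Dummy()
    #  )
    # )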
@@ -16,12 +16,8 @@ from torchmetrics.functional.auc import auc  # noqa: F401
 from torchmetrics.functional.auroc import auroc  # noqa: F401
 from torchmetrics.functional.average_precision import average_precision  # noqa: F401
 from torchmetrics.functional.classification import (  # noqa: F401
-    dice_score,
-    get_num_classes,
-    multiclass_auroc,
-    stat_scores_multiple_classes,
-    to_categorical,
-    to_onehot,
+    dice_score, get_num_classes, multiclass_auroc,
+    stat_scores_multiple_classes, to_categorical, to_onehot,
 )
 from torchmetrics.functional.confusion_matrix import confusion_matrix  # noqa: F401
 from torchmetrics.functional.explained_variance import explained_variance  # noqa: F401
@@ -318,10 +318,10 @@ class Metric(nn.Module, ABC):
         # filter all parameters based on update signature except those of
         # type VAR_POSITIONAL (*args) and VAR_KEYWORD (**kwargs)
         _params = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
+        _sign_params = self._update_signature.parameters
         filtered_kwargs = {
             k: v
-            for k, v in kwargs.items() if k in self._update_signature.parameters.keys()
-            and self._update_signature.parameters[k].kind not in _params
+            for k, v in kwargs.items() if (k in _sign_params.keys() and _sign_params[k].kind not in _params)
         }
 
         # if no kwargs filtered, return al kwargs as default
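The hunk above is only a readability refactor, but the underlying technique is worth spelling out: a kwarg is kept when it names an explicit parameter of update() and dropped when it is absent from the signature or is merely the *args/**kwargs placeholder. A self-contained sketch of the same idea outside the Metric class (the update() signature and the filter_kwargs helper are invented for illustration):

    import inspect


    def update(preds, target, extra=None, **kwargs):
        ...


    def filter_kwargs(signature, **passed):
        # Drop anything not named in the signature as well as *args/**kwargs placeholders.
        _params = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
        _sign_params = signature.parameters
        return {
            k: v
            for k, v in passed.items() if (k in _sign_params.keys() and _sign_params[k].kind not in _params)
        }


    print(filter_kwargs(inspect.signature(update), preds=1, target=2, unrelated=3))
    # {'preds': 1, 'target': 2}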