* restructure

* config

* transfer

* yapf
This commit is contained in:
Jirka Borovec 2021-02-23 17:03:00 +01:00 committed by GitHub
Parent d767bdcb5d
Commit 55edecbbfe
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
32 changed files: 1105 additions and 714 deletions

76
.github/CODE_OF_CONDUCT.md vendored Normal file
View file

@@ -0,0 +1,76 @@
# Contributor Covenant Code of Conduct
## Our Pledge
In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to making participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.
## Our Standards
Examples of behavior that contributes to creating a positive environment
include:
* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members
Examples of unacceptable behavior by participants include:
* The use of sexualized language or imagery and unwelcome sexual attention or
advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Our Responsibilities
Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.
Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.
## Scope
This Code of Conduct applies both within project spaces and in public spaces
when an individual is representing the project or its community. Examples of
representing a project or community include using an official project e-mail
address, posting via an official social media account, or acting as an appointed
representative at an online or offline event. Representation of a project may be
further defined and clarified by project maintainers.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at waf2107@columbia.edu. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.
Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq

5
.github/CONTRIBUTING.md vendored Normal file
View file

@@ -0,0 +1,5 @@
# Contributing
Welcome to the PyTorch Lightning community! We're building the most advanced research platform on the planet to implement the latest, best practices that the amazing PyTorch team rolls out!
**TBD**

View file

@@ -1,13 +1,13 @@
sphinx>=2.0, <3.0
sphinx>=3.0, !=3.5 # fails with sphinx.ext.viewcode
recommonmark # fails with badges
m2r # fails with multi-line text
nbsphinx
pandoc
docutils
sphinxcontrib-fulltoc
nbsphinx>=0.8
pandoc>=1.0
docutils>=0.16
sphinxcontrib-fulltoc>=1.0
sphinxcontrib-mockautodoc
git+https://github.com/PytorchLightning/lightning_sphinx_theme.git
# pip_shims
sphinx-autodoc-typehints
sphinx-paramlinks<0.4.0
https://github.com/PyTorchLightning/lightning_sphinx_theme/archive/master.zip#egg=pt-lightning-sphinx-theme
sphinx-autodoc-typehints>=1.0
sphinx-paramlinks>=0.4.0
sphinx-togglebutton>=0.2
sphinx-copybutton>=0.3

View file

@@ -1,62 +0,0 @@
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg
xmlns:dc="http://purl.org/dc/elements/1.1/"
xmlns:cc="http://creativecommons.org/ns#"
xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
xmlns:svg="http://www.w3.org/2000/svg"
xmlns="http://www.w3.org/2000/svg"
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
id="svg"
version="1.1"
width="16.000004"
height="15.999986"
viewBox="0 0 16.000004 15.999986"
sodipodi:docname="lightning_icon.svg"
inkscape:version="0.92.3 (2405546, 2018-03-11)">
<metadata
id="metadata13">
<rdf:RDF>
<cc:Work
rdf:about="">
<dc:format>image/svg+xml</dc:format>
<dc:type
rdf:resource="http://purl.org/dc/dcmitype/StillImage" />
<dc:title></dc:title>
</cc:Work>
</rdf:RDF>
</metadata>
<defs
id="defs11" />
<sodipodi:namedview
pagecolor="#ffffff"
bordercolor="#666666"
borderopacity="1"
objecttolerance="10"
gridtolerance="10"
guidetolerance="10"
inkscape:pageopacity="0"
inkscape:pageshadow="2"
inkscape:window-width="1920"
inkscape:window-height="1028"
id="namedview9"
showgrid="false"
inkscape:zoom="0.59"
inkscape:cx="-669.05062"
inkscape:cy="373.84245"
inkscape:window-x="0"
inkscape:window-y="0"
inkscape:window-maximized="1"
inkscape:current-layer="svg" />
<path
style="fill:#fbfbfb;fill-rule:evenodd;stroke:none;stroke-width:0.04002798"
inkscape:connector-curvature="0"
d="m 8.987101,1.723485 c -0.05588,0.03422 -4.121881,4.096544 -4.184645,4.180924 -0.02317,0.0311 -0.04587,0.06016 -0.05044,0.06456 -0.0087,0.0084 -0.07477,0.145063 -0.09679,0.20014 -0.05848,0.146583 -0.05804,0.44387 0.001,0.592413 0.08426,0.21243 0.08826,0.216754 1.576864,1.706274 0.779463,0.779947 1.41719,1.426877 1.41719,1.437604 0,0.0232 -0.253177,0.79848 -0.273873,0.838707 -0.0079,0.0153 -0.01433,0.04087 -0.01433,0.05684 0,0.01597 -0.0059,0.03587 -0.01313,0.04423 -0.0072,0.0084 -0.03678,0.09086 -0.06568,0.18333 -0.02893,0.09246 -0.05904,0.180647 -0.06693,0.195937 -0.0079,0.0153 -0.01437,0.04087 -0.01437,0.05684 0,0.01597 -0.0059,0.03586 -0.01313,0.04423 -0.0072,0.0084 -0.03679,0.09086 -0.06569,0.18333 -0.02893,0.09246 -0.05904,0.180643 -0.06693,0.195937 -0.0079,0.0153 -0.01437,0.04187 -0.01437,0.05908 0,0.0172 -0.0072,0.03574 -0.016,0.04119 -0.0088,0.0054 -0.016,0.02607 -0.016,0.04579 0,0.01973 -0.006,0.04271 -0.0134,0.05108 -0.0074,0.0084 -0.04439,0.112477 -0.08222,0.23136 -0.03787,0.118884 -0.151103,0.461124 -0.251693,0.760534 -0.489984,1.45874 -0.462444,1.36155 -0.413611,1.45938 0.06917,0.138657 0.23128,0.199741 0.358251,0.134974 0.07057,-0.03602 4.143298,-4.099985 4.245368,-4.236242 0.03382,-0.04515 0.09094,-0.165796 0.109916,-0.232123 0.0088,-0.03083 0.0243,-0.08498 0.03442,-0.120363 0.03346,-0.11668 0.0068,-0.361134 -0.0566,-0.520084 C 10.880518,9.229614 10.738898,9.079187 9.372744,7.714673 8.601524,6.944416 7.970523,6.302806 7.970523,6.288916 c 0,-0.01393 0.02817,-0.107833 0.0626,-0.208663 0.03442,-0.100834 0.07881,-0.237367 0.09859,-0.303414 0.0198,-0.06605 0.04207,-0.12693 0.04947,-0.135293 0.0074,-0.0084 0.0135,-0.03133 0.0135,-0.05108 0,-0.01973 0.0072,-0.04035 0.016,-0.04579 0.0088,-0.0054 0.016,-0.02707 0.016,-0.04803 0,-0.02097 0.0072,-0.04259 0.016,-0.04803 0.0088,-0.0054 0.016,-0.02707 0.016,-0.04803 0,-0.02097 0.0072,-0.04259 0.016,-0.04803 0.0088,-0.0054 0.016,-0.02707 0.016,-0.04803 0,-0.02097 0.0072,-0.04259 0.016,-0.04803 0.0088,-0.0054 0.016,-0.02707 0.016,-0.04803 0,-0.02097 0.0072,-0.04259 0.016,-0.04803 0.0088,-0.0054 0.016,-0.02707 0.016,-0.04803 0,-0.02097 0.0072,-0.04259 0.016,-0.04803 0.0088,-0.0054 0.016,-0.02707 0.016,-0.04803 0,-0.02097 0.0072,-0.04259 0.016,-0.04804 0.0088,-0.0054 0.016,-0.02707 0.016,-0.04803 0,-0.02097 0.0072,-0.04259 0.016,-0.04803 0.0088,-0.0054 0.016,-0.02707 0.016,-0.04803 0,-0.02097 0.0072,-0.04259 0.016,-0.04803 0.0088,-0.0054 0.016,-0.02707 0.016,-0.04803 0,-0.02097 0.0072,-0.04259 0.016,-0.04803 0.0088,-0.0054 0.016,-0.02707 0.016,-0.04803 0,-0.02097 0.0072,-0.04259 0.016,-0.04803 0.0088,-0.0054 0.016,-0.02707 0.016,-0.04803 0,-0.02097 0.0072,-0.04259 0.016,-0.04803 0.0088,-0.0054 0.016,-0.02707 0.016,-0.04803 0,-0.02097 0.0072,-0.04259 0.016,-0.04803 0.0088,-0.0054 0.016,-0.02707 0.016,-0.04803 0,-0.02097 0.0072,-0.04259 0.016,-0.04804 0.0088,-0.0054 0.016,-0.02707 0.016,-0.04803 0,-0.02097 0.0072,-0.04259 0.016,-0.04803 0.0088,-0.0054 0.016,-0.02397 0.016,-0.04119 0,-0.0172 0.0065,-0.04379 0.0144,-0.05908 0.0079,-0.0153 0.119204,-0.34484 0.247334,-0.73231 C 9.064507,2.979766 9.220177,2.513319 9.28226,2.330632 9.408267,1.960092 9.41367,1.921146 9.35255,1.826839 9.27225,1.703032 9.099973,1.654399 8.986893,1.723566"
id="path0" />
<path
style="fill:#540c8c;fill-rule:evenodd;stroke:none;stroke-width:0.04002798"
inkscape:connector-curvature="0"
d="m 0.07719102,0.01733399 c -0.02187,0.0111 -0.04875,0.03799 -0.05984,0.05984 -0.0161,0.03173 -0.01937,1.62421701 -0.01633,7.94479601 l 0.0038,7.905086 0.03647,0.03646 0.03646,0.03647 H 8.00241 15.927073 l 0.03646,-0.03647 0.03647,-0.03646 V 8.002393 0.07773399 l -0.03647,-0.03646 -0.03646,-0.03647 -7.905086,-0.0038 c -6.320579,-0.003 -7.91305298,2.4e-4 -7.94479598,0.01633 M 9.193764,1.668208 c 0.259903,0.09046 0.275193,0.212427 0.09363,0.74628 C 8.845834,3.776859 8.388843,5.102846 7.991127,6.302606 L 9.415644,7.72492 c 1.24415,1.242111 1.51682,1.523547 1.51682,1.565414 0,0.0051 0.0133,0.03987 0.02953,0.07718 0.12913,0.296607 0.0877,0.664983 -0.103314,0.91872 -0.141456,0.187933 -4.207341,4.228478 -4.273468,4.246848 -0.139417,0.03871 -0.248653,-0.006 -0.34324,-0.140417 -0.07665,-0.108996 -0.06985,-0.137256 0.287004,-1.194633 0.34663,-1.101761 0.75901,-2.243218 1.08916,-3.290661 0,-0.0078 -0.636164,-0.650377 -1.413707,-1.427921 C 4.877658,7.152643 4.728155,6.995813 4.673718,6.87361 4.661948,6.84718 4.645988,6.81305 4.638168,6.79776 4.630368,6.78246 4.624038,6.75689 4.624038,6.74092 c 0,-0.01597 -0.0076,-0.03659 -0.01687,-0.04587 -0.02253,-0.02253 -0.02253,-0.436904 0,-0.45944 0.0093,-0.0093 0.01687,-0.0327 0.01687,-0.05204 0,-0.0363 0.06917,-0.178363 0.130414,-0.267907 0.07965,-0.1164 4.221831,-4.237681 4.259458,-4.237921 0.02047,-1.2e-4 0.04803,-0.0072 0.06124,-0.01577 0.03147,-0.02033 0.04415,-0.01967 0.118603,0.0062"
id="path1"
sodipodi:nodetypes="ccscccccccccccscccccscccccccccsssscccc" />
</svg>

Before

Width:  |  Height:  |  Size: 6.4 KiB

View file

@@ -1,61 +0,0 @@
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg
xmlns:dc="http://purl.org/dc/elements/1.1/"
xmlns:cc="http://creativecommons.org/ns#"
xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
xmlns:svg="http://www.w3.org/2000/svg"
xmlns="http://www.w3.org/2000/svg"
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
id="svg"
version="1.1"
width="400"
height="400"
viewBox="0, 0, 400,400"
sodipodi:docname="lightning_logo.svg"
inkscape:version="0.92.3 (2405546, 2018-03-11)">
<metadata
id="metadata13">
<rdf:RDF>
<cc:Work
rdf:about="">
<dc:format>image/svg+xml</dc:format>
<dc:type
rdf:resource="http://purl.org/dc/dcmitype/StillImage" />
</cc:Work>
</rdf:RDF>
</metadata>
<defs
id="defs11" />
<sodipodi:namedview
pagecolor="#ffffff"
bordercolor="#666666"
borderopacity="1"
objecttolerance="10"
gridtolerance="10"
guidetolerance="10"
inkscape:pageopacity="0"
inkscape:pageshadow="2"
inkscape:window-width="1920"
inkscape:window-height="1028"
id="namedview9"
showgrid="false"
inkscape:zoom="9.44"
inkscape:cx="203.07907"
inkscape:cy="335.32491"
inkscape:window-x="0"
inkscape:window-y="0"
inkscape:window-maximized="1"
inkscape:current-layer="svg" />
<path
style="fill:#fbfbfb;fill-rule:evenodd;stroke:none"
inkscape:connector-curvature="0"
d="m 224.6,43.137 c -1.396,0.855 -102.975,102.342 -104.543,104.45 -0.579,0.777 -1.146,1.503 -1.26,1.613 -0.218,0.21 -1.868,3.624 -2.418,5 -1.461,3.662 -1.45,11.089 0.022,14.8 2.105,5.307 2.205,5.415 39.394,42.627 19.473,19.485 35.405,35.647 35.405,35.915 0,0.58 -6.325,19.948 -6.842,20.953 -0.197,0.382 -0.358,1.021 -0.358,1.42 0,0.399 -0.147,0.896 -0.328,1.105 -0.18,0.209 -0.919,2.27 -1.641,4.58 -0.723,2.31 -1.475,4.513 -1.672,4.895 -0.198,0.382 -0.359,1.021 -0.359,1.42 0,0.399 -0.147,0.896 -0.328,1.105 -0.18,0.209 -0.919,2.27 -1.641,4.58 -0.723,2.31 -1.475,4.513 -1.672,4.895 -0.198,0.382 -0.359,1.046 -0.359,1.476 0,0.43 -0.18,0.893 -0.4,1.029 -0.22,0.136 -0.4,0.651 -0.4,1.144 0,0.493 -0.151,1.067 -0.335,1.276 -0.184,0.209 -1.109,2.81 -2.054,5.78 -0.946,2.97 -3.775,11.52 -6.288,19 -12.241,36.443 -11.553,34.015 -10.333,36.459 1.728,3.464 5.778,4.99 8.95,3.372 1.763,-0.9 103.51,-102.428 106.06,-105.832 0.845,-1.128 2.272,-4.142 2.746,-5.799 0.22,-0.77 0.607,-2.123 0.86,-3.007 0.836,-2.915 0.171,-9.022 -1.414,-12.993 -1.493,-3.741 -5.031,-7.499 -39.161,-41.588 C 214.964,173.569 199.2,157.54 199.2,157.193 c 0,-0.348 0.704,-2.694 1.564,-5.213 0.86,-2.519 1.969,-5.93 2.463,-7.58 0.495,-1.65 1.051,-3.171 1.236,-3.38 0.186,-0.209 0.337,-0.783 0.337,-1.276 0,-0.493 0.18,-1.008 0.4,-1.144 0.22,-0.136 0.4,-0.676 0.4,-1.2 0,-0.524 0.18,-1.064 0.4,-1.2 0.22,-0.136 0.4,-0.676 0.4,-1.2 0,-0.524 0.18,-1.064 0.4,-1.2 0.22,-0.136 0.4,-0.676 0.4,-1.2 0,-0.524 0.18,-1.064 0.4,-1.2 0.22,-0.136 0.4,-0.676 0.4,-1.2 0,-0.524 0.18,-1.064 0.4,-1.2 0.22,-0.136 0.4,-0.676 0.4,-1.2 0,-0.524 0.18,-1.064 0.4,-1.2 0.22,-0.136 0.4,-0.676 0.4,-1.2 0,-0.524 0.18,-1.064 0.4,-1.2 0.22,-0.136 0.4,-0.676 0.4,-1.2 0,-0.524 0.18,-1.064 0.4,-1.2 0.22,-0.136 0.4,-0.676 0.4,-1.2 0,-0.524 0.18,-1.064 0.4,-1.2 0.22,-0.136 0.4,-0.676 0.4,-1.2 0,-0.524 0.18,-1.064 0.4,-1.2 0.22,-0.136 0.4,-0.676 0.4,-1.2 0,-0.524 0.18,-1.064 0.4,-1.2 0.22,-0.136 0.4,-0.676 0.4,-1.2 0,-0.524 0.18,-1.064 0.4,-1.2 0.22,-0.136 0.4,-0.676 0.4,-1.2 0,-0.524 0.18,-1.064 0.4,-1.2 0.22,-0.136 0.4,-0.676 0.4,-1.2 0,-0.524 0.18,-1.064 0.4,-1.2 0.22,-0.136 0.4,-0.676 0.4,-1.2 0,-0.524 0.18,-1.064 0.4,-1.2 0.22,-0.136 0.4,-0.599 0.4,-1.029 0,-0.43 0.162,-1.094 0.36,-1.476 0.197,-0.382 2.978,-8.615 6.179,-18.295 3.2,-9.68 7.089,-21.333 8.64,-25.897 3.148,-9.257 3.283,-10.23 1.756,-12.586 -2.006,-3.093 -6.31,-4.308 -9.135,-2.58"
id="path0" />
<path
style="fill:#540c8c;fill-rule:evenodd;stroke:none"
inkscape:connector-curvature="0"
d="M 2.008,0.513 C 1.462,0.79 0.79,1.462 0.513,2.008 0.111,2.801 0.029,42.585 0.105,200.489 L 0.2,397.978 1.111,398.889 2.022,399.8 H 200 397.978 l 0.911,-0.911 0.911,-0.911 V 200 2.022 L 398.889,1.111 397.978,0.2 200.489,0.105 C 42.585,0.029 2.801,0.111 2.008,0.513 m 227.755,41.243 c 6.493,2.26 6.875,5.307 2.339,18.644 -11.0313,34.035452 -22.44803,67.16196 -32.384,97.135 l 35.588,35.533 c 31.082,31.031 37.894,38.062 37.894,39.108 0,0.128 0.332,0.996 0.738,1.928 3.226,7.41 2.191,16.613 -2.581,22.952 -3.534,4.695 -105.11,105.638 -106.762,106.097 -3.483,0.967 -6.212,-0.15 -8.575,-3.508 -1.915,-2.723 -1.745,-3.429 7.17,-29.845 8.65971,-27.52475 18.96205,-56.04122 27.21,-82.209 0,-0.195 -15.893,-16.248 -35.318,-35.673 -33.146,-33.147 -36.881,-37.065 -38.241,-40.118 -0.294,-0.66 -0.693,-1.513 -0.888,-1.895 -0.194,-0.382 -0.353,-1.021 -0.353,-1.42 0,-0.399 -0.189,-0.914 -0.421,-1.146 -0.563,-0.563 -0.563,-10.915 0,-11.478 0.232,-0.232 0.421,-0.817 0.421,-1.3 0,-0.907 1.728,-4.456 3.258,-6.693 C 120.848,144.96 224.33,42 225.27,41.994 c 0.511,-0.003 1.2,-0.181 1.53,-0.394 0.786,-0.508 1.103,-0.491 2.963,0.156"
id="path1"
sodipodi:nodetypes="ccscccccccccccscccccscccccccccsssscccc" />
</svg>

Before

Width:  |  Height:  |  Size: 5.2 KiB

File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 16 KiB

Binary data
docs/source/_images/logos/lightning_logo.png

Binary file not shown.

Before

Width:  |  Height:  |  Size: 8.3 KiB

View file

@@ -1,62 +0,0 @@
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg
xmlns:dc="http://purl.org/dc/elements/1.1/"
xmlns:cc="http://creativecommons.org/ns#"
xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
xmlns:svg="http://www.w3.org/2000/svg"
xmlns="http://www.w3.org/2000/svg"
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
id="svg"
version="1.1"
width="47.999985"
height="47.999943"
viewBox="0 0 47.999985 47.999943"
sodipodi:docname="lightning_logo.svg"
inkscape:version="0.92.3 (2405546, 2018-03-11)">
<metadata
id="metadata13">
<rdf:RDF>
<cc:Work
rdf:about="">
<dc:format>image/svg+xml</dc:format>
<dc:type
rdf:resource="http://purl.org/dc/dcmitype/StillImage" />
<dc:title />
</cc:Work>
</rdf:RDF>
</metadata>
<defs
id="defs11" />
<sodipodi:namedview
pagecolor="#ffffff"
bordercolor="#666666"
borderopacity="1"
objecttolerance="10"
gridtolerance="10"
guidetolerance="10"
inkscape:pageopacity="0"
inkscape:pageshadow="2"
inkscape:window-width="1920"
inkscape:window-height="1028"
id="namedview9"
showgrid="false"
inkscape:zoom="0.59"
inkscape:cx="-347.96588"
inkscape:cy="389.84243"
inkscape:window-x="0"
inkscape:window-y="0"
inkscape:window-maximized="1"
inkscape:current-layer="svg" />
<path
style="fill:#fbfbfb;fill-rule:evenodd;stroke:none;stroke-width:0.12008391"
inkscape:connector-curvature="0"
d="m 26.961294,5.1704519 c -0.16764,0.10267 -12.36564,12.2896301 -12.55393,12.5427701 -0.0695,0.0933 -0.13762,0.18048 -0.15131,0.19369 -0.0262,0.0252 -0.22432,0.43519 -0.29036,0.60042 -0.17544,0.43975 -0.17412,1.33161 0.003,1.77724 0.25278,0.63729 0.26479,0.65026 4.73059,5.11882 2.33839,2.33984 4.25157,4.28063 4.25157,4.31281 0,0.0696 -0.75953,2.39544 -0.82162,2.51612 -0.0237,0.0459 -0.043,0.12261 -0.043,0.17052 0,0.0479 -0.0177,0.1076 -0.0394,0.13269 -0.0216,0.0251 -0.11035,0.27259 -0.19705,0.54999 -0.0868,0.27739 -0.17713,0.54194 -0.20078,0.58781 -0.0238,0.0459 -0.0431,0.1226 -0.0431,0.17052 0,0.0479 -0.0177,0.10759 -0.0394,0.13269 -0.0216,0.0251 -0.11036,0.27259 -0.19706,0.54999 -0.0868,0.27739 -0.17712,0.54193 -0.20078,0.58781 -0.0238,0.0459 -0.0431,0.1256 -0.0431,0.17724 0,0.0516 -0.0216,0.10723 -0.048,0.12357 -0.0264,0.0163 -0.048,0.0782 -0.048,0.13737 0,0.0592 -0.0181,0.12813 -0.0402,0.15323 -0.0221,0.0251 -0.13318,0.33743 -0.24666,0.69408 -0.1136,0.35665 -0.45331,1.38337 -0.75508,2.2816 -1.46995,4.37622 -1.38733,4.08465 -1.24083,4.37814 0.2075,0.41597 0.69384,0.59922 1.07475,0.40492 0.21171,-0.10807 12.42989,-12.29995 12.7361,-12.70872 0.10147,-0.13545 0.27283,-0.49739 0.32975,-0.69637 0.0264,-0.0925 0.0729,-0.25493 0.10327,-0.36109 0.10039,-0.35004 0.0205,-1.0834 -0.1698,-1.56025 -0.17928,-0.44923 -0.60414,-0.90051 -4.7026,-4.99405 -2.31366,-2.31077 -4.20666,-4.2356 -4.20666,-4.27727 0,-0.0418 0.0845,-0.3235 0.18781,-0.62599 0.10327,-0.3025 0.23644,-0.7121 0.29577,-0.91024 0.0594,-0.19814 0.1262,-0.38079 0.14842,-0.40588 0.0223,-0.0251 0.0405,-0.094 0.0405,-0.15323 0,-0.0592 0.0216,-0.12105 0.048,-0.13738 0.0264,-0.0163 0.048,-0.0812 0.048,-0.1441 0,-0.0629 0.0216,-0.12777 0.048,-0.1441 0.0264,-0.0163 0.048,-0.0812 0.048,-0.1441 0,-0.0629 0.0216,-0.12777 0.048,-0.1441 0.0264,-0.0163 0.048,-0.0812 0.048,-0.1441 0,-0.0629 0.0216,-0.12777 0.048,-0.1441 0.0264,-0.0163 0.048,-0.0812 0.048,-0.1441 0,-0.0629 0.0216,-0.12777 0.048,-0.1441 0.0264,-0.0163 0.048,-0.0812 0.048,-0.1441 0,-0.0629 0.0216,-0.12777 0.048,-0.1441 0.0264,-0.0163 0.048,-0.0812 0.048,-0.1441 0,-0.0629 0.0216,-0.12777 0.048,-0.14411 0.0264,-0.0163 0.048,-0.0812 0.048,-0.1441 0,-0.0629 0.0216,-0.12777 0.048,-0.1441 0.0264,-0.0163 0.048,-0.0812 0.048,-0.1441 0,-0.0629 0.0216,-0.12777 0.048,-0.1441 0.0264,-0.0163 0.048,-0.0812 0.048,-0.1441 0,-0.0629 0.0216,-0.12777 0.048,-0.1441 0.0264,-0.0163 0.048,-0.0812 0.048,-0.1441 0,-0.0629 0.0216,-0.12777 0.048,-0.1441 0.0264,-0.0163 0.048,-0.0812 0.048,-0.1441 0,-0.0629 0.0216,-0.12777 0.048,-0.1441 0.0264,-0.0163 0.048,-0.0812 0.048,-0.1441 0,-0.0629 0.0216,-0.12777 0.048,-0.1441 0.0264,-0.0163 0.048,-0.0812 0.048,-0.1441 0,-0.0629 0.0216,-0.12777 0.048,-0.14411 0.0264,-0.0163 0.048,-0.0812 0.048,-0.1441 0,-0.0629 0.0216,-0.12777 0.048,-0.1441 0.0264,-0.0163 0.048,-0.0719 0.048,-0.12356 0,-0.0516 0.0195,-0.13137 0.0432,-0.17725 0.0237,-0.0459 0.35761,-1.03452 0.742,-2.19693 0.38427,-1.1624101 0.85128,-2.5617501 1.03753,-3.1098101 0.37802,-1.11162 0.39423,-1.22846 0.21087,-1.51138 -0.24089,-0.37142 -0.75773,-0.51732 -1.09697,-0.30982"
id="path0" />
<path
style="fill:#540c8c;fill-rule:evenodd;stroke:none;stroke-width:0.12008391"
inkscape:connector-curvature="0"
d="m 0.2315739,0.05200186 c -0.0656,0.0333 -0.14626,0.11396 -0.17952,0.17952 -0.0483,0.0952 -0.0581,4.87265004 -0.049,23.83438014 l 0.0114,23.71525 0.1094,0.10939 0.10939,0.1094 h 23.7739701 23.77398 l 0.10939,-0.1094 0.1094,-0.10939 V 24.007172 0.23320186 l -0.1094,-0.10939 -0.10939,-0.1094 -23.71525,-0.0114 c -18.9617301,-0.009 -23.7391501,7.2e-4 -23.8343801,0.049 M 27.581274,5.0046319 c 0.77971,0.27139 0.82558,0.63728 0.28088,2.23884 -1.32468,4.0871101 -2.69565,8.0650701 -3.8888,11.6643501 l 4.27355,4.26694 c 3.73245,3.72633 4.55046,4.57064 4.55046,4.69624 0,0.0154 0.0399,0.11961 0.0886,0.23153 0.38739,0.88982 0.2631,1.99495 -0.30994,2.75616 -0.42437,0.5638 -12.62202,12.68543 -12.8204,12.74054 -0.41825,0.11613 -0.74596,-0.018 -1.02972,-0.42125 -0.22996,-0.32699 -0.20954,-0.41177 0.86101,-3.5839 1.03989,-3.30528 2.27703,-6.72965 3.26748,-9.87198 0,-0.0234 -1.90849,-1.95113 -4.24112,-4.28376 -3.98031,-3.98042 -4.42882,-4.45091 -4.59213,-4.81752 -0.0353,-0.0793 -0.0832,-0.18169 -0.10664,-0.22756 -0.0233,-0.0459 -0.0424,-0.12261 -0.0424,-0.17052 0,-0.0479 -0.0227,-0.10976 -0.0506,-0.13762 -0.0676,-0.0676 -0.0676,-1.31071 0,-1.37832 0.0279,-0.0279 0.0506,-0.0981 0.0506,-0.15611 0,-0.10891 0.20751,-0.53509 0.39124,-0.80372 0.23896,-0.3492 12.66549,-12.7130401 12.77837,-12.7137601 0.0614,-3.6e-4 0.1441,-0.0217 0.18372,-0.0473 0.0944,-0.061 0.13246,-0.059 0.35581,0.0187"
id="path1"
sodipodi:nodetypes="ccscccccccccccscccccscccccccccsssscccc" />
</svg>

Before

Width:  |  Height:  |  Size: 6.2 KiB

View file

@@ -0,0 +1,3 @@
<svg width="38" height="44" viewBox="0 0 38 44" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M18.9368 1L1 11.5V32.5L18.9375 43L36.875 32.5V11.5L18.9368 1ZM15.8314 32.5014L17.7203 24.3499L13.4729 20.1555L22.0692 11.4993L20.1768 19.6634L24.4014 23.8354L15.8314 32.5014Z" fill="#792EE5"/>
</svg>

After

Width:  |  Height:  |  Size: 305 B

Binary data
docs/source/_static/images/logo.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 21 KiB

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 12 KiB

View file

@@ -17,7 +17,6 @@ import builtins
import glob
import inspect
import os
import re
import shutil
import sys
@@ -29,6 +28,7 @@ sys.path.insert(0, os.path.abspath(PATH_ROOT))
builtins.__LIGHTNING_BOLT_SETUP__ = True
FOLDER_GENERATED = 'generated'
SPHINX_MOCK_REQUIREMENTS = int(os.environ.get("SPHINX_MOCK_REQUIREMENTS", True))
import torchmetrics # noqa: E402
@@ -36,7 +36,7 @@ import torchmetrics # noqa: E402
# -- Project information -----------------------------------------------------
# this name shall match the project name in Github as it is used for linking to code
project = "PyTorch-torchmetrics"
project = "PyTorch-Metrics"
copyright = torchmetrics.__copyright__
author = torchmetrics.__author__
@@ -50,28 +50,39 @@ release = torchmetrics.__version__
github_user = "PyTorchLightning"
github_repo = project
# -- Project documents -------------------------------------------------------
# export the README
with open(os.path.join(PATH_ROOT, "README.md"), "r") as fp:
    readme = fp.read()

# TODO: temp fix removing SVG badges and GIF, because PDF cannot show them
readme = re.sub(r"(\[!\[.*\))", "", readme)
readme = re.sub(r"(!\[.*.gif\))", "", readme)

for dir_name in (
    os.path.basename(p)
    for p in glob.glob(os.path.join(PATH_ROOT, "*"))
    if os.path.isdir(p)
):
    readme = readme.replace("](%s/" % dir_name, "](%s/%s/" % (PATH_ROOT, dir_name))

with open("readme.md", "w") as fp:
    fp.write(readme)
def _transform_changelog(path_in: str, path_out: str) -> None:
    with open(path_in, 'r') as fp:
        chlog_lines = fp.readlines()
    # enrich short subsub-titles to be unique
    chlog_ver = ''
    for i, ln in enumerate(chlog_lines):
        if ln.startswith('## '):
            chlog_ver = ln[2:].split('-')[0].strip()
        elif ln.startswith('### '):
            ln = ln.replace('###', f'### {chlog_ver} -')
        chlog_lines[i] = ln
    with open(path_out, 'w') as fp:
        fp.writelines(chlog_lines)
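
# e.g. (illustrative, not from the repo): a CHANGELOG section such as
#   ## 0.2.0 - 2021-02-xx
#   ### Added
# is rewritten so the short subsection title becomes unique:
#   ### 0.2.0 - Added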
os.makedirs(os.path.join(PATH_HERE, FOLDER_GENERATED), exist_ok=True)
# copy all documents from GH templates like contribution guide
for md in glob.glob(os.path.join(PATH_ROOT, '.github', '*.md')):
    shutil.copy(md, os.path.join(PATH_HERE, FOLDER_GENERATED, os.path.basename(md)))
# copy also the changelog
_transform_changelog(
    os.path.join(PATH_ROOT, 'CHANGELOG.md'),
    os.path.join(PATH_HERE, FOLDER_GENERATED, 'CHANGELOG.md'),
)
# -- General configuration ---------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.
needs_sphinx = "2.0"
needs_sphinx = "3.4"
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
@@ -97,7 +108,6 @@ extensions = [
"sphinx.ext.githubpages",
]
# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]
@@ -135,7 +145,7 @@ language = None
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = [
    "PULL_REQUEST_TEMPLATE.md",
    os.path.join(FOLDER_GENERATED, "PULL_REQUEST_TEMPLATE.md"),
]
# The name of the Pygments (syntax highlighting) style to use.
@@ -161,13 +171,12 @@ html_theme_options = {
    "logo_only": False,
}
# TODO
# html_logo = '_images/logos/lightning_logo-name.svg'
html_logo = '_static/images/logo.svg'
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_images", "_templates", "_static"]
html_static_path = ["_static"]
# Custom sidebar templates, must be a dictionary that maps document names
# to template names.
@@ -272,10 +281,8 @@ PACKAGES = [
    torchmetrics.__name__,
]
apidoc_output_folder = os.path.join(PATH_HERE, "api")
# def run_apidoc(_):
# apidoc_output_folder = os.path.join(PATH_HERE, "api")
# sys.path.insert(0, apidoc_output_folder)
#
# # delete api-doc files before generating them
@@ -317,7 +324,7 @@ def package_list_from_file(file):
    with open(file, "r") as fp:
        for ln in fp.readlines():
            found = [ln.index(ch) for ch in list(",=<>#") if ch in ln]
            pkg = ln[: min(found)] if found else ln
            pkg = ln[:min(found)] if found else ln
            if pkg.rstrip():
                mocked_packages.append(pkg.rstrip())
    return mocked_packages

View file

@@ -1,168 +0,0 @@
.. role:: hidden
:class: hidden-section
torchmetrics.functional
=======================
.. TODO: this should work and then autofunction doesn't need the full path
.. and then we don't need no index
.. .. currentmodule:: torchmetrics.functional
Classification functions
------------------------
:hidden:`accuracy`
~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.classification.accuracy
:noindex:
:hidden:`auc`
~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.classification.auc
:noindex:
:hidden:`auroc`
~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.classification.auroc
:noindex:
:hidden:`multiclass_auroc`
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.classification.multiclass_auroc
:noindex:
:hidden:`average_precision`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.average_precision
:noindex:
:hidden:`confusion_matrix`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.confusion_matrix
:noindex:
:hidden:`dice_score`
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.classification.dice_score
:noindex:
:hidden:`f1`
~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.f1
:noindex:
:hidden:`fbeta`
~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.fbeta
:noindex:
:hidden:`iou`
~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.classification.iou
:noindex:
:hidden:`roc`
~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.roc
:noindex:
:hidden:`precision`
~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.classification.precision
:noindex:
:hidden:`precision_recall`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.classification.precision_recall
:noindex:
:hidden:`precision_recall_curve`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.precision_recall_curve
:noindex:
:hidden:`recall`
~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.classification.recall
:noindex:
:hidden:`stat_scores`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.classification.stat_scores
:noindex:
:hidden:`stat_scores_multiple_classes`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.classification.stat_scores_multiple_classes
:noindex:
Regression functions
------------------------
:hidden:`explained_variance`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.explained_variance
:noindex:
:hidden:`mean_absolute_error`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.mean_absolute_error
:noindex:
:hidden:`mean_squared_error`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.mean_squared_error
:noindex:
:hidden:`mean_squared_log_error`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.mean_squared_log_error
:noindex:
:hidden:`psnr`
~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.psnr
:noindex:
:hidden:`ssim`
~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.ssim
:noindex:
NLP functions
------------------------
:hidden:`bleu_score`
~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.nlp.bleu_score
:noindex:
Pairwise functions
--------------------
:hidden:`embedding_similarity`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.self_supervised.embedding_similarity
:noindex:
Utility functions
--------------------
:hidden:`select_topk`
~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.utils.select_topk
:noindex:
:hidden:`to_categorical`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.utils.to_categorical
:noindex:
:hidden:`to_onehot`
~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.utils.to_onehot
:noindex:

View file

@@ -10,10 +10,22 @@ PyTorchMetrics documentation
   :name: start
   :caption: Start here

   intro
   torchmetrics
   functional
   lightning
   pages/intro
   pages/implement
   pages/overview
   pages/lightning
   references/functional
   references/modules

.. toctree::
   :maxdepth: 1
   :name: community
   :caption: Community

   generated/CODE_OF_CONDUCT.md
   generated/CONTRIBUTING.md
   generated/CHANGELOG.md
Indices and tables
==================
@@ -21,9 +33,3 @@ Indices and tables
* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`
.. This is here to make sphinx aware of the modules but not throw an error/warning

.. toctree::
   :hidden:

   readme

View file

@@ -1,108 +0,0 @@
############
Introduction
############
``torchmetrics`` is a Metrics API created for easy metric development and usage in
PyTorch and PyTorch Lightning. It is rigorously tested for all edge cases and includes a growing list of
common metric implementations.
The metrics API provides ``update()``, ``compute()``, ``reset()`` functions to the user. The metric base class inherits
``nn.Module`` which allows us to call ``metric(...)`` directly. The ``forward()`` method of the base ``Metric`` class
serves the dual purpose of calling ``update()`` on its input and simultaneously returning the value of the metric over the
provided input.
These metrics work with DDP in PyTorch and PyTorch Lightning by default. When ``.compute()`` is called in
distributed mode, the internal state of each metric is synced and reduced across each process, so that the
logic present in ``.compute()`` is applied to state information from all processes.
This metrics API is independent of PyTorch Lightning. Metrics can directly be used in PyTorch as shown in the example:
.. NOTE: this can't actually be tested as epochs, train_data, and valid_data are undefined
.. code-block:: python

    from torchmetrics.classification import Accuracy

    train_accuracy = Accuracy()
    valid_accuracy = Accuracy(compute_on_step=False)

    for epoch in range(epochs):
        for x, y in train_data:
            y_hat = model(x)

            # training step accuracy
            batch_acc = train_accuracy(y_hat, y)

        for x, y in valid_data:
            y_hat = model(x)
            valid_accuracy(y_hat, y)

    # total accuracy over all training batches
    total_train_accuracy = train_accuracy.compute()

    # total accuracy over all validation batches
    total_valid_accuracy = valid_accuracy.compute()
.. note::

    Metrics contain internal states that keep track of the data seen so far.
    Do not mix metric states across training, validation and testing.
    It is highly recommended to re-initialize the metric per mode as
    shown in the examples above.

.. note::

    Metric states are **not** added to the model's ``state_dict`` by default.
    To change this, after initializing the metric, the method ``.persistent(mode)`` can
    be used to enable (``mode=True``) or disable (``mode=False``) this behaviour.
*********************
Implementing a Metric
*********************
To implement your custom metric, subclass the base ``Metric`` class and implement the following methods:
- ``__init__()``: Each state variable should be registered using ``self.add_state(...)``.
- ``update()``: Any code needed to update the state given any inputs to the metric.
- ``compute()``: Computes a final value from the state of the metric.
All you need to do is call ``add_state`` correctly to implement a custom metric with DDP.
``reset()`` is called on metric state variables added using ``add_state()``.
To see how metric states are synchronized across distributed processes, refer to ``add_state()`` docs
from the base ``Metric`` class.
Example implementation:
.. testcode::

    from torchmetrics import Metric

    class MyAccuracy(Metric):
        def __init__(self, dist_sync_on_step=False):
            super().__init__(dist_sync_on_step=dist_sync_on_step)

            self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
            self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

        def update(self, preds: torch.Tensor, target: torch.Tensor):
            preds, target = self._input_format(preds, target)
            assert preds.shape == target.shape

            self.correct += torch.sum(preds == target)
            self.total += target.numel()

        def compute(self):
            return self.correct.float() / self.total
Metrics support backpropagation if all computations involved in the metric calculation
are differentiable. However, note that the cached state is detached from the computational
graph and cannot be backpropagated. Not doing this would mean storing the computational
graph for each update call, which can lead to out-of-memory errors.
In practice this means that:
.. code-block:: python

    metric = MyMetric()
    val = metric(pred, target)  # this value can be backpropagated
    val = metric.compute()  # this value cannot be backpropagated

View file

@@ -0,0 +1,87 @@
*********************
Implementing a Metric
*********************
To implement your custom metric, subclass the base ``Metric`` class and implement the following methods:
- ``__init__()``: Each state variable should be registered using ``self.add_state(...)``.
- ``update()``: Any code needed to update the state given any inputs to the metric.
- ``compute()``: Computes a final value from the state of the metric.
All you need to do is call ``add_state`` correctly to implement a custom metric with DDP.
``reset()`` is called on metric state variables added using ``add_state()``.
To see how metric states are synchronized across distributed processes, refer to ``add_state()`` docs
from the base ``Metric`` class.
Example implementation:
.. testcode::

    from torchmetrics import Metric

    class MyAccuracy(Metric):
        def __init__(self, dist_sync_on_step=False):
            super().__init__(dist_sync_on_step=dist_sync_on_step)

            self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
            self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

        def update(self, preds: torch.Tensor, target: torch.Tensor):
            preds, target = self._input_format(preds, target)
            assert preds.shape == target.shape

            self.correct += torch.sum(preds == target)
            self.total += target.numel()

        def compute(self):
            return self.correct.float() / self.total
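
A usage sketch for the metric above (assuming ``_input_format`` returns the inputs unchanged):

.. code-block:: python

    import torch

    my_acc = MyAccuracy()
    preds = torch.tensor([0, 1, 1, 0])
    target = torch.tensor([0, 1, 0, 0])

    my_acc.update(preds, target)
    print(my_acc.compute())  # tensor(0.7500) -- 3 of 4 predictions correct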
Metrics support backpropagation if all computations involved in the metric calculation
are differentiable. However, note that the cached state is detached from the computational
graph and cannot be backpropagated. Not doing this would mean storing the computational
graph for each update call, which can lead to out-of-memory errors.
In practice this means that:
.. code-block:: python

    metric = MyMetric()
    val = metric(pred, target)  # this value can be backpropagated
    val = metric.compute()  # this value cannot be backpropagated
Internal implementation details
-------------------------------
This section briefly describes how metrics work internally. We encourage looking at the source code for more info.

Internally, Lightning wraps the user-defined ``update()`` and ``compute()`` methods. We do this to automatically
synchronize and reduce metric states across multiple devices. More precisely, calling ``update()`` does the
following internally:

1. Clears the computed cache
2. Calls the user-defined ``update()``

Similarly, calling ``compute()`` does the following internally:

1. Syncs metric states between processes
2. Reduces the gathered metric states
3. Calls the user-defined ``compute()`` method on the gathered metric states
4. Caches the computed result
From a user's standpoint this has one important side-effect: computed results are cached. This means that no
matter how many times ``compute`` is called one after another, it will continue to return the same result.
The cache is first emptied on the next call to ``update``.
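
A minimal sketch of this caching behaviour (the tensors here are made up for illustration):

.. code-block:: python

    import torch
    from torchmetrics import Accuracy

    preds = torch.tensor([0, 1, 1, 0])
    target = torch.tensor([0, 1, 0, 0])

    acc = Accuracy()
    acc.update(preds, target)
    v1 = acc.compute()          # states are synced, reduced, and the result is cached
    v2 = acc.compute()          # returns the cached result, no recomputation
    acc.update(preds, target)   # empties the computed cache again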
``forward`` serves the dual purpose of both returning the metric on the current data and updating the internal
metric state for accumulating over multiple batches. The ``forward()`` method achieves this by combining calls
to ``update`` and ``compute`` in the following way (assuming the metric is initialized with ``compute_on_step=True``):
1. Calls ``update()`` to update the global metric states (for accumulation over multiple batches)
2. Caches the global state
3. Calls ``reset()`` to clear global metric state
4. Calls ``update()`` to update local metric state
5. Calls ``compute()`` to calculate metric for current batch
6. Restores the global state
This procedure has the consequence of calling the user-defined ``update`` **twice** during a single
forward call (once to update the global statistics and once to get the batch statistics).
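
A rough sketch of that procedure, not the actual implementation (``_copy_state`` and
``_restore_state`` are hypothetical helpers standing in for the real state handling):

.. code-block:: python

    def forward(self, *args, **kwargs):
        self.update(*args, **kwargs)       # 1. accumulate into the global state
        cache = self._copy_state()         # 2. cache the global state
        self.reset()                       # 3. clear the global state
        self.update(*args, **kwargs)       # 4. local state from the current batch only
        batch_result = self.compute()      # 5. metric value for the current batch
        self._restore_state(cache)         # 6. restore the global state
        return batch_result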

301
docs/source/pages/intro.rst Normal file
View file

@@ -0,0 +1,301 @@
############
Introduction
############
``torchmetrics`` is a Metrics API created for easy metric development and usage in
PyTorch and PyTorch Lightning. It is rigorously tested for all edge cases and includes a growing list of
common metric implementations.
The metrics API provides ``update()``, ``compute()``, ``reset()`` functions to the user. The metric base class inherits
``nn.Module`` which allows us to call ``metric(...)`` directly. The ``forward()`` method of the base ``Metric`` class
serves the dual purpose of calling ``update()`` on its input and simultaneously returning the value of the metric over the
provided input.
These metrics work with DDP in PyTorch and PyTorch Lightning by default. When ``.compute()`` is called in
distributed mode, the internal state of each metric is synced and reduced across each process, so that the
logic present in ``.compute()`` is applied to state information from all processes.
This metrics API is independent of PyTorch Lightning. Metrics can directly be used in PyTorch as shown in the example:
.. NOTE: this can't actually be tested as epochs, train_data, and valid_data are undefined
.. code-block:: python

    from torchmetrics.classification import Accuracy

    train_accuracy = Accuracy()
    valid_accuracy = Accuracy(compute_on_step=False)

    for epoch in range(epochs):
        for x, y in train_data:
            y_hat = model(x)

            # training step accuracy
            batch_acc = train_accuracy(y_hat, y)

        for x, y in valid_data:
            y_hat = model(x)
            valid_accuracy(y_hat, y)

    # total accuracy over all training batches
    total_train_accuracy = train_accuracy.compute()

    # total accuracy over all validation batches
    total_valid_accuracy = valid_accuracy.compute()
.. note::

    Metrics contain internal states that keep track of the data seen so far.
    Do not mix metric states across training, validation and testing.
    It is highly recommended to re-initialize the metric per mode as
    shown in the examples above.

.. note::

    Metric states are **not** added to the model's ``state_dict`` by default.
    To change this, after initializing the metric, the method ``.persistent(mode)`` can
    be used to enable (``mode=True``) or disable (``mode=False``) this behaviour.
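
A minimal sketch of toggling this behaviour:

.. code-block:: python

    metric = Accuracy()
    metric.persistent(True)  # metric states will now be included in state_dict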
*******************
Metrics and devices
*******************
Metrics are simple subclasses of :class:`~torch.nn.Module` and their metric states behave
similarly to buffers and parameters of modules. This means that metric states should
be moved to the same device as the input of the metric:
.. code-block:: python

    import torch
    from torchmetrics import Accuracy

    target = torch.tensor([1, 1, 0, 0], device=torch.device("cuda", 0))
    preds = torch.tensor([0, 1, 0, 0], device=torch.device("cuda", 0))

    # Metric states are always initialized on cpu and need to be moved to
    # the correct device
    metric = Accuracy(num_classes=2).to(torch.device("cuda", 0))
    out = metric(preds, target)
    print(out.device)  # cuda:0
*****************
Metric Arithmetic
*****************

Metrics support most of Python's built-in operators for arithmetic, logic and bitwise operations.

For example, for a metric that should return the sum of two different metrics, implementing a new metric is unnecessary overhead.
It can instead be done with:
.. code-block:: python
first_metric = MyFirstMetric()
second_metric = MySecondMetric()
new_metric = first_metric + second_metric
``new_metric.update(*args, **kwargs)`` now calls the ``update`` of ``first_metric`` and ``second_metric``. It forwards all positional arguments but
forwards only the keyword arguments that are available in the respective metric's ``update`` declaration.
Similarly, ``new_metric.compute()`` now calls the ``compute`` of ``first_metric`` and ``second_metric`` and adds the results up.
This pattern is implemented for the following operators (with ``a`` being metrics and ``b`` being metrics, tensors, integers or floats); a short usage sketch follows the list:
* Addition (``a + b``)
* Bitwise AND (``a & b``)
* Equality (``a == b``)
* Floor Division (``a // b``)
* Greater Equal (``a >= b``)
* Greater (``a > b``)
* Less Equal (``a <= b``)
* Less (``a < b``)
* Matrix Multiplication (``a @ b``)
* Modulo (``a % b``)
* Multiplication (``a * b``)
* Inequality (``a != b``)
* Bitwise OR (``a | b``)
* Power (``a ** b``)
* Subtraction (``a - b``)
* True Division (``a / b``)
* Bitwise XOR (``a ^ b``)
* Absolute Value (``abs(a)``)
* Inversion (``~a``)
* Negative Value (``neg(a)``)
* Positive Value (``pos(a)``)
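
A minimal usage sketch of such a composition, here with two regression metrics chosen purely for illustration:

.. code-block:: python

    import torch
    from torchmetrics import MeanAbsoluteError, MeanSquaredError

    preds = torch.tensor([1.0, 2.0, 3.0])
    target = torch.tensor([1.0, 2.0, 2.5])

    combined = MeanSquaredError() + MeanAbsoluteError()
    combined.update(preds, target)  # updates both underlying metrics
    print(combined.compute())       # mse.compute() + mae.compute()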
****************
MetricCollection
****************
In many cases it is beneficial to evaluate the model output with multiple metrics.
In this case the `MetricCollection` class may come in handy. It accepts a sequence
of metrics and wraps these into a single callable metric class, with the same
interface as any other metric.
Example:
.. testcode::

    from torchmetrics import MetricCollection, Accuracy, Precision, Recall

    target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2])
    preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2])

    metric_collection = MetricCollection([
        Accuracy(),
        Precision(num_classes=3, average='macro'),
        Recall(num_classes=3, average='macro')
    ])

    print(metric_collection(preds, target))

.. testoutput::
    :options: +NORMALIZE_WHITESPACE

    {'Accuracy': tensor(0.1250),
     'Precision': tensor(0.0667),
     'Recall': tensor(0.1111)}
Similarly, it can also reduce the amount of code required to log multiple metrics
inside your LightningModule:
.. code-block:: python

    def __init__(self):
        ...
        metrics = pl.metrics.MetricCollection(...)
        self.train_metrics = metrics.clone()
        self.valid_metrics = metrics.clone()

    def training_step(self, batch, batch_idx):
        logits = self(x)
        ...
        self.train_metrics(logits, y)
        # use log_dict instead of log
        self.log_dict(self.train_metrics, on_step=True, on_epoch=False, prefix='train')

    def validation_step(self, batch, batch_idx):
        logits = self(x)
        ...
        self.valid_metrics(logits, y)
        # use log_dict instead of log
        self.log_dict(self.valid_metrics, on_step=True, on_epoch=True, prefix='val')
.. note::

    By default, `MetricCollection` assumes that all the metrics in the collection
    have the same call signature. If this is not the case, inputs that should be
    given to different metrics can be given as keyword arguments to the collection.
.. autoclass:: torchmetrics.MetricCollection
:noindex:
***************************
Class vs Functional Metrics
***************************
The functional metrics follow the simple paradigm "input in, output out". This means they don't provide any advanced mechanisms for syncing across DDP nodes or aggregation over batches. They simply compute the metric value based on the given inputs.

Also, the integration within other parts of PyTorch Lightning will never be as tight as with the class-based interface.

If you just want to compute the values, the functional metrics are the way to go. However, if you are looking for the best integration and user experience, please consider also using the class interface.
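
As an illustration, a minimal sketch of the accuracy metric in both styles (the tensors are made up for the example):

.. code-block:: python

    import torch
    from torchmetrics import Accuracy
    from torchmetrics.functional import accuracy

    preds = torch.tensor([0, 1, 1, 0])
    target = torch.tensor([0, 1, 0, 0])

    # functional: stateless, "input in, output out"
    value = accuracy(preds, target)

    # class-based: stateful, accumulates across batches
    metric = Accuracy()
    metric.update(preds, target)
    value = metric.compute()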
**********************
Classification Metrics
**********************
Input types
-----------
For the purposes of classification metrics, inputs (predictions and targets) are split
into these categories (``N`` stands for the batch size and ``C`` for number of classes):
.. csv-table:: \*dtype ``binary`` means integers that are either 0 or 1
    :header: "Type", "preds shape", "preds dtype", "target shape", "target dtype"
    :widths: 20, 10, 10, 10, 10

    "Binary", "(N,)", "``float``", "(N,)", "``binary``\*"
    "Multi-class", "(N,)", "``int``", "(N,)", "``int``"
    "Multi-class with probabilities", "(N, C)", "``float``", "(N,)", "``int``"
    "Multi-label", "(N, ...)", "``float``", "(N, ...)", "``binary``\*"
    "Multi-dimensional multi-class", "(N, ...)", "``int``", "(N, ...)", "``int``"
    "Multi-dimensional multi-class with probabilities", "(N, C, ...)", "``float``", "(N, ...)", "``int``"
.. note::

    All dimensions of size 1 (except ``N``) are "squeezed out" at the beginning, so
    that, for example, a tensor of shape ``(N, 1)`` is treated as ``(N, )``.
When predictions or targets are integers, it is assumed that class labels start at 0, i.e.
the possible class labels are 0, 1, 2, 3, etc. Below are some examples of different input types:
.. testcode::

    # Binary inputs
    binary_preds = torch.tensor([0.6, 0.1, 0.9])
    binary_target = torch.tensor([1, 0, 1])

    # Multi-class inputs
    mc_preds = torch.tensor([0, 2, 1])
    mc_target = torch.tensor([0, 1, 2])

    # Multi-class inputs with probabilities
    mc_preds_probs = torch.tensor([[0.8, 0.2, 0], [0.1, 0.2, 0.7], [0.3, 0.6, 0.1]])
    mc_target_probs = torch.tensor([0, 1, 2])

    # Multi-label inputs
    ml_preds = torch.tensor([[0.2, 0.8, 0.9], [0.5, 0.6, 0.1], [0.3, 0.1, 0.1]])
    ml_target = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 0]])
Using the is_multiclass parameter
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
In some cases, you might have inputs which appear to be (multi-dimensional) multi-class
but are actually binary/multi-label - for example, if both predictions and targets are
integer (binary) tensors. Or it could be the other way around: you want to treat
binary/multi-label inputs as 2-class (multi-dimensional) multi-class inputs.
For these cases, the metrics where this distinction would make a difference expose the
``is_multiclass`` argument. Let's see how this is used on the example of the
:class:`~torchmetrics.StatScores` metric.
First, let's consider the case with label predictions with 2 classes, which we want to
treat as binary.
.. testcode::

    from torchmetrics.functional import stat_scores

    # These inputs are supposed to be binary, but appear as multi-class
    preds = torch.tensor([0, 1, 0])
    target = torch.tensor([1, 1, 0])
As you can see below, by default the inputs are treated
as multi-class. We can set ``is_multiclass=False`` to treat the inputs as binary -
which is the same as converting the predictions to float beforehand.
.. doctest::

    >>> stat_scores(preds, target, reduce='macro', num_classes=2)
    tensor([[1, 1, 1, 0, 1],
            [1, 0, 1, 1, 2]])
    >>> stat_scores(preds, target, reduce='macro', num_classes=1, is_multiclass=False)
    tensor([[1, 0, 1, 1, 2]])
    >>> stat_scores(preds.float(), target, reduce='macro', num_classes=1)
    tensor([[1, 0, 1, 1, 2]])
Next, consider the opposite example: inputs are binary (as predictions are probabilities),
but we would like to treat them as 2-class multi-class, to obtain the metric for both classes.
.. testcode::

    preds = torch.tensor([0.2, 0.7, 0.3])
    target = torch.tensor([1, 1, 0])
In this case we can set ``is_multiclass=True`` to treat the inputs as multi-class.
.. doctest::

    >>> stat_scores(preds, target, reduce='macro', num_classes=1)
    tensor([[1, 0, 1, 1, 2]])
    >>> stat_scores(preds, target, reduce='macro', num_classes=2, is_multiclass=True)
    tensor([[1, 1, 1, 0, 1],
            [1, 0, 1, 1, 2]])

View file

View file

@@ -0,0 +1,104 @@
.. role:: hidden
:class: hidden-section
**********************
Classification Metrics
**********************
Input types
-----------
For the purposes of classification metrics, inputs (predictions and targets) are split
into these categories (``N`` stands for the batch size and ``C`` for number of classes):
.. csv-table:: \*dtype ``binary`` means integers that are either 0 or 1
    :header: "Type", "preds shape", "preds dtype", "target shape", "target dtype"
    :widths: 20, 10, 10, 10, 10

    "Binary", "(N,)", "``float``", "(N,)", "``binary``\*"
    "Multi-class", "(N,)", "``int``", "(N,)", "``int``"
    "Multi-class with probabilities", "(N, C)", "``float``", "(N,)", "``int``"
    "Multi-label", "(N, ...)", "``float``", "(N, ...)", "``binary``\*"
    "Multi-dimensional multi-class", "(N, ...)", "``int``", "(N, ...)", "``int``"
    "Multi-dimensional multi-class with probabilities", "(N, C, ...)", "``float``", "(N, ...)", "``int``"
.. note::

    All dimensions of size 1 (except ``N``) are "squeezed out" at the beginning, so
    that, for example, a tensor of shape ``(N, 1)`` is treated as ``(N, )``.
When predictions or targets are integers, it is assumed that class labels start at 0, i.e.
the possible class labels are 0, 1, 2, 3, etc. Below are some examples of different input types:
.. testcode::

    # Binary inputs
    binary_preds = torch.tensor([0.6, 0.1, 0.9])
    binary_target = torch.tensor([1, 0, 1])

    # Multi-class inputs
    mc_preds = torch.tensor([0, 2, 1])
    mc_target = torch.tensor([0, 1, 2])

    # Multi-class inputs with probabilities
    mc_preds_probs = torch.tensor([[0.8, 0.2, 0], [0.1, 0.2, 0.7], [0.3, 0.6, 0.1]])
    mc_target_probs = torch.tensor([0, 1, 2])

    # Multi-label inputs
    ml_preds = torch.tensor([[0.2, 0.8, 0.9], [0.5, 0.6, 0.1], [0.3, 0.1, 0.1]])
    ml_target = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 0]])
Using the is_multiclass parameter
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
In some cases, you might have inputs which appear to be (multi-dimensional) multi-class
but are actually binary/multi-label - for example, if both predictions and targets are
integer (binary) tensors. Or it could be the other way around: you want to treat
binary/multi-label inputs as 2-class (multi-dimensional) multi-class inputs.
For these cases, the metrics where this distinction would make a difference expose the
``is_multiclass`` argument. Let's see how this is used on the example of the
:class:`~pytorch_lightning.metrics.StatScores` metric.
First, let's consider the case with label predictions with 2 classes, which we want to
treat as binary.
.. testcode::

    from pytorch_lightning.metrics.functional import stat_scores

    # These inputs are supposed to be binary, but appear as multi-class
    preds = torch.tensor([0, 1, 0])
    target = torch.tensor([1, 1, 0])
As you can see below, by default the inputs are treated
as multi-class. We can set ``is_multiclass=False`` to treat the inputs as binary -
which is the same as converting the predictions to float beforehand.
.. doctest::

    >>> stat_scores(preds, target, reduce='macro', num_classes=2)
    tensor([[1, 1, 1, 0, 1],
            [1, 0, 1, 1, 2]])
    >>> stat_scores(preds, target, reduce='macro', num_classes=1, is_multiclass=False)
    tensor([[1, 0, 1, 1, 2]])
    >>> stat_scores(preds.float(), target, reduce='macro', num_classes=1)
    tensor([[1, 0, 1, 1, 2]])
Next, consider the opposite example: inputs are binary (as predictions are probabilities),
but we would like to treat them as 2-class multi-class, to obtain the metric for both classes.
.. testcode::

    preds = torch.tensor([0.2, 0.7, 0.3])
    target = torch.tensor([1, 1, 0])
In this case we can set ``is_multiclass=True`` to treat the inputs as multi-class.
.. doctest::

    >>> stat_scores(preds, target, reduce='macro', num_classes=1)
    tensor([[1, 0, 1, 1, 2]])
    >>> stat_scores(preds, target, reduce='macro', num_classes=2, is_multiclass=True)
    tensor([[1, 1, 1, 0, 1],
            [1, 0, 1, 1, 2]])

View file

@@ -0,0 +1,222 @@
.. role:: hidden
:class: hidden-section
**********************
Classification Metrics
**********************
accuracy [func]
~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.accuracy
:noindex:
auc [func]
~~~~~~~~~~
.. autofunction:: torchmetrics.functional.auc
:noindex:
auroc [func]
~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.auroc
:noindex:
average_precision [func]
~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.average_precision
:noindex:
confusion_matrix [func]
~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.confusion_matrix
:noindex:
dice_score [func]
~~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.dice_score
:noindex:
f1 [func]
~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.f1
:noindex:
fbeta [func]
~~~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.fbeta
:noindex:
hamming_distance [func]
~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.hamming_distance
:noindex:
iou [func]
~~~~~~~~~~
.. autofunction:: torchmetrics.functional.iou
:noindex:
roc [func]
~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.roc
:noindex:
precision [func]
~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.precision
:noindex:
precision_recall [func]
~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.precision_recall
:noindex:
precision_recall_curve [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.precision_recall_curve
:noindex:
recall [func]
~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.recall
:noindex:
select_topk [func]
~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.utils.select_topk
:noindex:
stat_scores [func]
~~~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.stat_scores
:noindex:
stat_scores_multiple_classes [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.stat_scores_multiple_classes
:noindex:
to_categorical [func]
~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.utils.to_categorical
:noindex:
to_onehot [func]
~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.utils.to_onehot
:noindex:
******************
Regression Metrics
******************
explained_variance [func]
~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.explained_variance
:noindex:
image_gradients [func]
~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.image_gradients
:noindex:
mean_absolute_error [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.mean_absolute_error
:noindex:
mean_squared_error [func]
~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.mean_squared_error
:noindex:
mean_squared_log_error [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.mean_squared_log_error
:noindex:
psnr [func]
~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.psnr
:noindex:
ssim [func]
~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.ssim
:noindex:
r2score [func]
~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.r2score
:noindex:
***
NLP
***
bleu_score [func]
~~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.bleu_score
:noindex:
********
Pairwise
********
embedding_similarity [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: torchmetrics.functional.embedding_similarity
:noindex:
@@ -0,0 +1,141 @@
**********************
Classification Metrics
**********************
Accuracy
~~~~~~~~
.. autoclass:: torchmetrics.Accuracy
:noindex:
AveragePrecision
~~~~~~~~~~~~~~~~
.. autoclass:: torchmetrics.AveragePrecision
:noindex:
AUC
~~~
.. autoclass:: torchmetrics.AUC
:noindex:
AUROC
~~~~~
.. autoclass:: torchmetrics.AUROC
:noindex:
ConfusionMatrix
~~~~~~~~~~~~~~~
.. autoclass:: torchmetrics.ConfusionMatrix
:noindex:
F1
~~
.. autoclass:: torchmetrics.F1
:noindex:
FBeta
~~~~~
.. autoclass:: torchmetrics.FBeta
:noindex:
IoU
~~~
.. autoclass:: torchmetrics.IoU
:noindex:
Hamming Distance
~~~~~~~~~~~~~~~~
.. autoclass:: torchmetrics.HammingDistance
:noindex:
Precision
~~~~~~~~~
.. autoclass:: torchmetrics.Precision
:noindex:
PrecisionRecallCurve
~~~~~~~~~~~~~~~~~~~~
.. autoclass:: torchmetrics.PrecisionRecallCurve
:noindex:
Recall
~~~~~~
.. autoclass:: torchmetrics.Recall
:noindex:
ROC
~~~
.. autoclass:: torchmetrics.ROC
:noindex:
StatScores
~~~~~~~~~~
.. autoclass:: torchmetrics.StatScores
:noindex:
******************
Regression Metrics
******************
ExplainedVariance
~~~~~~~~~~~~~~~~~
.. autoclass:: torchmetrics.ExplainedVariance
:noindex:
MeanAbsoluteError
~~~~~~~~~~~~~~~~~
.. autoclass:: torchmetrics.MeanAbsoluteError
:noindex:
MeanSquaredError
~~~~~~~~~~~~~~~~
.. autoclass:: torchmetrics.MeanSquaredError
:noindex:
MeanSquaredLogError
~~~~~~~~~~~~~~~~~~~
.. autoclass:: torchmetrics.MeanSquaredLogError
:noindex:
PSNR
~~~~
.. autoclass:: torchmetrics.PSNR
:noindex:
SSIM
~~~~
.. autoclass:: torchmetrics.SSIM
:noindex:
R2Score
~~~~~~~
.. autoclass:: torchmetrics.R2Score
:noindex:
@@ -1,90 +0,0 @@
.. role:: hidden
:class: hidden-section
torchmetrics
===================================
Classification
--------------------
For the purposes of classification metrics, inputs (predictions and targets) are split
into these categories (``N`` stands for the batch size and ``C`` for number of classes):
.. csv-table:: \*dtype ``binary`` means integers that are either 0 or 1
:header: "Type", "preds shape", "preds dtype", "target shape", "target dtype"
:widths: 20, 10, 10, 10, 10
"Binary", "(N,)", "``float``", "(N,)", "``binary``\*"
"Multi-class", "(N,)", "``int``", "(N,)", "``int``"
"Multi-class with probabilities", "(N, C)", "``float``", "(N,)", "``int``"
"Multi-label", "(N, ...)", "``float``", "(N, ...)", "``binary``\*"
"Multi-dimensional multi-class", "(N, ...)", "``int``", "(N, ...)", "``int``"
"Multi-dimensional multi-class with probabilities", "(N, C, ...)", "``float``", "(N, ...)", "``int``"
.. note::
All dimensions of size 1 (except ``N``) are "squeezed out" at the beginning, so
that, for example, a tensor of shape ``(N, 1)`` is treated as ``(N, )``.
When predictions or targets are integers, it is assumed that class labels start at 0, i.e.
the possible class labels are 0, 1, 2, 3, etc. Below are some examples of the different input types:
.. testcode::
import torch
# Binary inputs
binary_preds = torch.tensor([0.6, 0.1, 0.9])
binary_target = torch.tensor([1, 0, 1])
# Multi-class inputs
mc_preds = torch.tensor([0, 2, 1])
mc_target = torch.tensor([0, 1, 2])
# Multi-class inputs with probabilities
mc_preds_probs = torch.tensor([[0.8, 0.2, 0], [0.1, 0.2, 0.7], [0.3, 0.6, 0.1]])
mc_target_probs = torch.tensor([0, 1, 2])
# Multi-label inputs
ml_preds = torch.tensor([[0.2, 0.8, 0.9], [0.5, 0.6, 0.1], [0.3, 0.1, 0.1]])
ml_target = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 0]])
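The table above also lists two multi-dimensional multi-class cases that the block above does not
show. A minimal sketch of what they can look like (hypothetical values; shapes follow the table):
.. testcode::
# Multi-dimensional multi-class inputs, here with one extra dimension of size 2
mdmc_preds = torch.tensor([[0, 2], [1, 2]])
mdmc_target = torch.tensor([[0, 1], [2, 2]])
# Multi-dimensional multi-class inputs with probabilities: preds have shape (N, C, ...)
mdmc_preds_probs = torch.tensor([[[0.8, 0.1], [0.1, 0.2], [0.1, 0.7]],
[[0.3, 0.2], [0.4, 0.3], [0.3, 0.5]]])
mdmc_target_probs = torch.tensor([[0, 2], [1, 2]])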
In some rare cases, you might have inputs which appear to be (multi-dimensional) multi-class
but are actually binary/multi-label - for example, when both predictions and targets are 1d
binary tensors. Or it could be the other way around: you want to treat binary/multi-label
inputs as 2-class (multi-dimensional) multi-class inputs.
For these cases, the metrics where this distinction would make a difference expose the
``is_multiclass`` argument, as sketched below.
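For illustration, a minimal sketch of the effect, reusing the ``stat_scores`` values from the
examples above (``stat_scores`` here is ``torchmetrics.functional.stat_scores``):
.. doctest::
>>> import torch
>>> from torchmetrics.functional import stat_scores
>>> preds = torch.tensor([0.2, 0.7, 0.3])
>>> target = torch.tensor([1, 1, 0])
>>> stat_scores(preds, target, reduce='macro', num_classes=1)
tensor([[1, 0, 1, 1, 2]])
>>> stat_scores(preds, target, reduce='macro', num_classes=2, is_multiclass=True)
tensor([[1, 1, 1, 0, 1],
[1, 0, 1, 1, 2]])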
.. currentmodule:: torchmetrics.classification
.. autosummary::
:toctree: api
:nosignatures:
:template: classtemplate.rst
Accuracy
AveragePrecision
ConfusionMatrix
F1
FBeta
Precision
PrecisionRecallCurve
Recall
ROC
Regression
--------------------
.. currentmodule:: torchmetrics.regression
.. autosummary::
:toctree: api
:nosignatures:
:template: classtemplate.rst
ExplainedVariance
MeanAbsoluteError
MeanSquaredError
MeanSquaredLogError
PSNR
SSIM
@@ -49,7 +49,7 @@ setup(
author=torchmetrics.__author__,
author_email=torchmetrics.__author_email__,
url=torchmetrics.__homepage__,
download_url='https://github.com/PyTorchLightning/torchmetrics',
download_url='https://github.com/PyTorchLightning/metrics/archive/master.zip',
license=torchmetrics.__license__,
packages=find_packages(exclude=['tests', 'docs']),
long_description=load_long_describtion(),
@@ -38,7 +38,7 @@ class Accuracy(Metric):
changed to subset accuracy (which requires all labels or sub-samples in the sample to
be correctly predicted) by setting ``subset_accuracy=True``.
Accepts all input types listed in :ref:`extensions/metrics:input types`.
Accepts all input types listed in :ref:`pages/overview:input types`.
Args:
threshold:
@@ -127,7 +127,7 @@ class Accuracy(Metric):
def update(self, preds: torch.Tensor, target: torch.Tensor):
"""
Update state with predictions and targets. See :ref:`extensions/metrics:input types` for more information
Update state with predictions and targets. See :ref:`pages/overview:input types` for more information
on input types.
Args:
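To make the ``subset_accuracy`` option from the docstring above concrete, here is a minimal
sketch (hypothetical values; the results in the comments follow the docstring's definition of
subset accuracy):
.. testcode::
import torch
from torchmetrics import Accuracy
preds = torch.tensor([[0.9, 0.9], [0.9, 0.1]])  # multi-label probabilities
target = torch.tensor([[1, 1], [1, 1]])
accuracy = Accuracy()
accuracy(preds, target)  # each label counted separately: 3 of 4 correct
subset_acc = Accuracy(subset_accuracy=True)
subset_acc(preds, target)  # only fully correct samples count: 1 of 2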
@@ -35,7 +35,7 @@ class HammingDistance(Metric):
treats each possible label separately - meaning that, for example, multi-class data is
treated as if it were multi-label.
Accepts all input types listed in :ref:`extensions/metrics:input types`.
Accepts all input types listed in :ref:`pages/overview:input types`.
Args:
threshold:
@@ -88,7 +88,7 @@ class HammingDistance(Metric):
def update(self, preds: torch.Tensor, target: torch.Tensor):
"""
Update state with predictions and targets. See :ref:`extensions/metrics:input types` for more information
Update state with predictions and targets. See :ref:`pages/overview:input types` for more information
on input types.
Args:
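A minimal sketch of the label-by-label behaviour described above, where multi-class inputs are
compared as if they were one-hot multi-label (hypothetical values):
.. testcode::
import torch
from torchmetrics import HammingDistance
preds = torch.tensor([0, 2, 1])  # multi-class labels
target = torch.tensor([0, 1, 1])
hamming = HammingDistance()
hamming(preds, target)  # the one-hot views disagree in 2 of 9 label positions: 2/9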
@@ -283,7 +283,7 @@ def _check_classification_inputs(
is_multiclass:
Used only in certain special cases, where you want to treat inputs as a different type
than what they appear to be. See the parameter's
:ref:`documentation section <extensions/metrics:using the is_multiclass parameter>`
:ref:`documentation section <pages/overview:using the is_multiclass parameter>`
for a more detailed explanation and examples.
@@ -409,7 +409,7 @@ def _input_format_classification(
is_multiclass:
Used only in certain special cases, where you want to treat inputs as a different type
than what they appear to be. See the parameter's
:ref:`documentation section <extensions/metrics:using the is_multiclass parameter>`
:ref:`documentation section <pages/overview:using the is_multiclass parameter>`
for a more detailed explanation and examples.
Returns:
@@ -31,7 +31,7 @@ class Precision(StatScores):
The reduction method (how the precision scores are aggregated) is controlled by the
``average`` parameter, and additionally by the ``mdmc_average`` parameter in the
multi-dimensional multi-class case. Accepts all inputs listed in :ref:`extensions/metrics:input types`.
multi-dimensional multi-class case. Accepts all inputs listed in :ref:`pages/overview:input types`.
Args:
num_classes:
@@ -67,11 +67,11 @@ class Precision(StatScores):
- ``'samplewise'``: In this case, the statistics are computed separately for each
sample on the ``N`` axis, and then averaged over samples.
The computation for each sample is done by treating the flattened extra axes ``...``
(see :ref:`extensions/metrics:input types`) as the ``N`` dimension within the sample,
(see :ref:`pages/overview:input types`) as the ``N`` dimension within the sample,
and computing the metric for the sample based on that.
- ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs
(see :ref:`extensions/metrics:input types`)
(see :ref:`pages/overview:input types`)
are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they
were ``(N_X, C)``. From here on the ``average`` parameter applies as usual.
@@ -90,7 +90,7 @@ class Precision(StatScores):
is_multiclass:
Used only in certain special cases, where you want to treat inputs as a different type
than what they appear to be. See the parameter's
:ref:`documentation section <extensions/metrics:using the is_multiclass parameter>`
:ref:`documentation section <pages/overview:using the is_multiclass parameter>`
for a more detailed explanation and examples.
compute_on_step:
@@ -181,7 +181,7 @@ class Recall(StatScores):
The reduction method (how the recall scores are aggregated) is controlled by the
``average`` parameter, and additionally by the ``mdmc_average`` parameter in the
multi-dimensional multi-class case. Accepts all inputs listed in :ref:`extensions/metrics:input types`.
multi-dimensional multi-class case. Accepts all inputs listed in :ref:`pages/overview:input types`.
Args:
num_classes:
@@ -217,11 +217,11 @@ class Recall(StatScores):
- ``'samplewise'``: In this case, the statistics are computed separately for each
sample on the ``N`` axis, and then averaged over samples.
The computation for each sample is done by treating the flattened extra axes ``...``
(see :ref:`extensions/metrics:input types`) as the ``N`` dimension within the sample,
(see :ref:`pages/overview:input types`) as the ``N`` dimension within the sample,
and computing the metric for the sample based on that.
- ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs
(see :ref:`extensions/metrics:input types`)
(see :ref:`pages/overview:input types`)
are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they
were ``(N_X, C)``. From here on the ``average`` parameter applies as usual.
@@ -241,7 +241,7 @@ class Recall(StatScores):
is_multiclass:
Used only in certain special cases, where you want to treat inputs as a different type
than what they appear to be. See the parameter's
:ref:`documentation section <extensions/metrics:using the is_multiclass parameter>`
:ref:`documentation section <pages/overview:using the is_multiclass parameter>`
for a more detailed explanation and examples.
compute_on_step:
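A minimal sketch of the two ``mdmc_average`` modes described above, shown for ``Precision``
(``Recall`` behaves analogously; hypothetical values):
.. testcode::
import torch
from torchmetrics import Precision
# Multi-dimensional multi-class inputs of shape (N, ...)
preds = torch.tensor([[0, 2], [1, 1]])
target = torch.tensor([[0, 1], [1, 0]])
# 'global': flatten the N and ... dimensions into one sample axis, then apply `average`
precision_global = Precision(num_classes=3, average='macro', mdmc_average='global')
precision_global(preds, target)
# 'samplewise': compute the metric within each sample, then average over samples
precision_samplewise = Precision(num_classes=3, average='macro', mdmc_average='samplewise')
precision_samplewise(preds, target)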
@@ -28,7 +28,7 @@ class StatScores(Metric):
``reduce`` parameter, and additionally by the ``mdmc_reduce`` parameter in the
multi-dimensional multi-class case.
Accepts all inputs listed in :ref:`extensions/metrics:input types`.
Accepts all inputs listed in :ref:`pages/overview:input types`.
Args:
threshold:
@@ -71,7 +71,7 @@ class StatScores(Metric):
one of the following:
- ``None`` [default]: Should be left unchanged if your data is not multi-dimensional
multi-class (see :ref:`extensions/metrics:input types` for the definition of input types).
multi-class (see :ref:`pages/overview:input types` for the definition of input types).
- ``'samplewise'``: In this case, the statistics are computed separately for each
sample on the ``N`` axis, and then the outputs are concatenated together. In each
@@ -86,7 +86,7 @@ class StatScores(Metric):
is_multiclass:
Used only in certain special cases, where you want to treat inputs as a different type
than what they appear to be. See the parameter's
:ref:`documentation section <extensions/metrics:using the is_multiclass parameter>`
:ref:`documentation section <pages/overview:using the is_multiclass parameter>`
for a more detailed explanation and examples.
compute_on_step:
@@ -175,7 +175,7 @@ class StatScores(Metric):
def update(self, preds: torch.Tensor, target: torch.Tensor):
"""
Update state with predictions and targets. See :ref:`extensions/metrics:input types` for more information
Update state with predictions and targets. See :ref:`pages/overview:input types` for more information
on input types.
Args:
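A minimal sketch of the class described above, which returns one row of statistics per class
under ``reduce='macro'`` (hypothetical values):
.. testcode::
import torch
from torchmetrics import StatScores
preds = torch.tensor([1, 0, 2, 1])
target = torch.tensor([1, 1, 2, 0])
stat_scores = StatScores(reduce='macro', num_classes=3)
stat_scores(preds, target)  # one row of [tp, fp, tn, fn, support] per class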
@@ -72,7 +72,7 @@ def accuracy(
changed to subset accuracy (which requires all labels or sub-samples in the sample to
be correctly predicted) by setting ``subset_accuracy=True``.
Accepts all input types listed in :ref:`extensions/metrics:input types`.
Accepts all input types listed in :ref:`pages/overview:input types`.
Args:
preds: Predictions from model (probabilities, or labels)
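The functional form mirrors the ``Accuracy`` module; a minimal sketch (hypothetical values):
.. testcode::
import torch
from torchmetrics.functional import accuracy
preds = torch.tensor([0, 2, 1, 3])
target = torch.tensor([0, 1, 2, 3])
accuracy(preds, target)  # 2 of 4 predictions correct: 0.5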
@@ -51,7 +51,7 @@ def hamming_distance(preds: torch.Tensor, target: torch.Tensor, threshold: float
treats each possible label separately - meaning that, for example, multi-class data is
treated as if it were multi-label.
Accepts all input types listed in :ref:`extensions/metrics:input types`.
Accepts all input types listed in :ref:`pages/overview:input types`.
Args:
preds: Predictions from model
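A minimal sketch of the functional form (hypothetical values):
.. testcode::
import torch
from torchmetrics.functional import hamming_distance
preds = torch.tensor([[0.9, 0.1], [0.2, 0.8]])  # multi-label probabilities
target = torch.tensor([[1, 1], [0, 0]])
hamming_distance(preds, target)  # thresholded preds disagree with target in 2 of 4 positions: 0.5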
@@ -60,7 +60,7 @@ def precision(
The reduction method (how the precision scores are aggregated) is controlled by the
``average`` parameter, and additionally by the ``mdmc_average`` parameter in the
multi-dimensional multi-class case. Accepts all inputs listed in :ref:`extensions/metrics:input types`.
multi-dimensional multi-class case. Accepts all inputs listed in :ref:`pages/overview:input types`.
Args:
preds: Predictions from model (probabilities or labels)
@@ -94,11 +94,11 @@ def precision(
- ``'samplewise'``: In this case, the statistics are computed separately for each
sample on the ``N`` axis, and then averaged over samples.
The computation for each sample is done by treating the flattened extra axes ``...``
(see :ref:`extensions/metrics:input types`) as the ``N`` dimension within the sample,
(see :ref:`pages/overview:input types`) as the ``N`` dimension within the sample,
and computing the metric for the sample based on that.
- ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs
(see :ref:`extensions/metrics:input types`)
(see :ref:`pages/overview:input types`)
are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they
were ``(N_X, C)``. From here on the ``average`` parameter applies as usual.
@@ -123,7 +123,7 @@ def precision(
is_multiclass:
Used only in certain special cases, where you want to treat inputs as a different type
than what they appear to be. See the parameter's
:ref:`documentation section <extensions/metrics:using the is_multiclass parameter>`
:ref:`documentation section <pages/overview:using the is_multiclass parameter>`
for a more detailed explanation and examples.
class_reduction:
@@ -225,7 +225,7 @@ def recall(
The reduction method (how the recall scores are aggregated) is controlled by the
``average`` parameter, and additionally by the ``mdmc_average`` parameter in the
multi-dimensional multi-class case. Accepts all inputs listed in :ref:`extensions/metrics:input types`.
multi-dimensional multi-class case. Accepts all inputs listed in :ref:`pages/overview:input types`.
Args:
preds: Predictions from model (probabilities, or labels)
@@ -256,11 +256,11 @@ def recall(
- ``'samplewise'``: In this case, the statistics are computed separately for each
sample on the ``N`` axis, and then averaged over samples.
The computation for each sample is done by treating the flattened extra axes ``...``
(see :ref:`extensions/metrics:input types`) as the ``N`` dimension within the sample,
(see :ref:`pages/overview:input types`) as the ``N`` dimension within the sample,
and computing the metric for the sample based on that.
- ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs
(see :ref:`extensions/metrics:input types`)
(see :ref:`pages/overview:input types`)
are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they
were ``(N_X, C)``. From here on the ``average`` parameter applies as usual.
@@ -285,7 +285,7 @@ def recall(
is_multiclass:
Used only in certain special cases, where you want to treat inputs as a different type
than what they appear to be. See the parameter's
:ref:`documentation section <extensions/metrics:using the is_multiclass parameter>`
:ref:`documentation section <pages/overview:using the is_multiclass parameter>`
for a more detailed explanation and examples.
class_reduction:
@@ -373,7 +373,7 @@ def precision_recall(
The reduction method (how the recall scores are aggregated) is controlled by the
``average`` parameter, and additionally by the ``mdmc_average`` parameter in the
multi-dimensional multi-class case. Accepts all inputs listed in :ref:`extensions/metrics:input types`.
multi-dimensional multi-class case. Accepts all inputs listed in :ref:`pages/overview:input types`.
Args:
preds: Predictions from model (probabilities, or labels)
@@ -404,11 +404,11 @@ def precision_recall(
- ``'samplewise'``: In this case, the statistics are computed separately for each
sample on the ``N`` axis, and then averaged over samples.
The computation for each sample is done by treating the flattened extra axes ``...``
(see :ref:`extensions/metrics:input types`) as the ``N`` dimension within the sample,
(see :ref:`pages/overview:input types`) as the ``N`` dimension within the sample,
and computing the metric for the sample based on that.
- ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs
(see :ref:`extensions/metrics:input types`)
(see :ref:`pages/overview:input types`)
are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they
were ``(N_X, C)``. From here on the ``average`` parameter applies as usual.
@@ -433,7 +433,7 @@ def precision_recall(
is_multiclass:
Used only in certain special cases, where you want to treat inputs as a different type
than what they appear to be. See the parameter's
:ref:`documentation section <extensions/metrics:using the is_multiclass parameter>`
:ref:`documentation section <pages/overview:using the is_multiclass parameter>`
for a more detailed explanation and examples.
class_reduction:
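A minimal sketch of ``precision_recall``, which computes both metrics in one call (hypothetical
values; parameter names as in the docstring above):
.. testcode::
import torch
from torchmetrics.functional import precision_recall
preds = torch.tensor([0, 2, 1, 1])
target = torch.tensor([0, 1, 1, 1])
# Computes the precision and the recall under the chosen `average`
precision_recall(preds, target, average='macro', num_classes=3)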
@@ -153,7 +153,7 @@ def stat_scores(
The reduction method (how the statistics are aggregated) is controlled by the
``reduce`` parameter, and additionally by the ``mdmc_reduce`` parameter in the
multi-dimensional multi-class case. Accepts all inputs listed in :ref:`extensions/metrics:input types`.
multi-dimensional multi-class case. Accepts all inputs listed in :ref:`pages/overview:input types`.
Args:
preds: Predictions from model (probabilities or labels)
@@ -198,7 +198,7 @@ def stat_scores(
one of the following:
- ``None`` [default]: Should be left unchanged if your data is not multi-dimensional
multi-class (see :ref:`extensions/metrics:input types` for the definition of input types).
multi-class (see :ref:`pages/overview:input types` for the definition of input types).
- ``'samplewise'``: In this case, the statistics are computed separately for each
sample on the ``N`` axis, and then the outputs are concatenated together. In each
@@ -213,7 +213,7 @@ def stat_scores(
is_multiclass:
Used only in certain special cases, where you want to treat inputs as a different type
than what they appear to be. See the parameter's
:ref:`documentation section <extensions/metrics:using the is_multiclass parameter>`
:ref:`documentation section <pages/overview:using the is_multiclass parameter>`
for a more detailed explanation and examples.
Return:
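Finally, a minimal sketch of ``mdmc_reduce='samplewise'`` from the docstring above
(hypothetical values):
.. testcode::
import torch
from torchmetrics.functional import stat_scores
# Multi-dimensional multi-class inputs of shape (N, ...)
preds = torch.tensor([[0, 1], [1, 1]])
target = torch.tensor([[0, 0], [1, 1]])
# Statistics are computed separately for each sample along N, then concatenated
stat_scores(preds, target, reduce='macro', num_classes=2, mdmc_reduce='samplewise')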