|
@ -0,0 +1,76 @@
|
|||
# Contributor Covenant Code of Conduct
|
||||
|
||||
## Our Pledge
|
||||
|
||||
In the interest of fostering an open and welcoming environment, we as
|
||||
contributors and maintainers pledge to making participation in our project and
|
||||
our community a harassment-free experience for everyone, regardless of age, body
|
||||
size, disability, ethnicity, sex characteristics, gender identity and expression,
|
||||
level of experience, education, socio-economic status, nationality, personal
|
||||
appearance, race, religion, or sexual identity and orientation.
|
||||
|
||||
## Our Standards
|
||||
|
||||
Examples of behavior that contributes to creating a positive environment
|
||||
include:
|
||||
|
||||
* Using welcoming and inclusive language
|
||||
* Being respectful of differing viewpoints and experiences
|
||||
* Gracefully accepting constructive criticism
|
||||
* Focusing on what is best for the community
|
||||
* Showing empathy towards other community members
|
||||
|
||||
Examples of unacceptable behavior by participants include:
|
||||
|
||||
* The use of sexualized language or imagery and unwelcome sexual attention or
|
||||
advances
|
||||
* Trolling, insulting/derogatory comments, and personal or political attacks
|
||||
* Public or private harassment
|
||||
* Publishing others' private information, such as a physical or electronic
|
||||
address, without explicit permission
|
||||
* Other conduct which could reasonably be considered inappropriate in a
|
||||
professional setting
|
||||
|
||||
## Our Responsibilities
|
||||
|
||||
Project maintainers are responsible for clarifying the standards of acceptable
|
||||
behavior and are expected to take appropriate and fair corrective action in
|
||||
response to any instances of unacceptable behavior.
|
||||
|
||||
Project maintainers have the right and responsibility to remove, edit, or
|
||||
reject comments, commits, code, wiki edits, issues, and other contributions
|
||||
that are not aligned to this Code of Conduct, or to ban temporarily or
|
||||
permanently any contributor for other behaviors that they deem inappropriate,
|
||||
threatening, offensive, or harmful.
|
||||
|
||||
## Scope
|
||||
|
||||
This Code of Conduct applies both within project spaces and in public spaces
|
||||
when an individual is representing the project or its community. Examples of
|
||||
representing a project or community include using an official project e-mail
|
||||
address, posting via an official social media account, or acting as an appointed
|
||||
representative at an online or offline event. Representation of a project may be
|
||||
further defined and clarified by project maintainers.
|
||||
|
||||
## Enforcement
|
||||
|
||||
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
||||
reported by contacting the project team at waf2107@columbia.edu. All
|
||||
complaints will be reviewed and investigated and will result in a response that
|
||||
is deemed necessary and appropriate to the circumstances. The project team is
|
||||
obligated to maintain confidentiality with regard to the reporter of an incident.
|
||||
Further details of specific enforcement policies may be posted separately.
|
||||
|
||||
Project maintainers who do not follow or enforce the Code of Conduct in good
|
||||
faith may face temporary or permanent repercussions as determined by other
|
||||
members of the project's leadership.
|
||||
|
||||
## Attribution
|
||||
|
||||
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
|
||||
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
|
||||
|
||||
[homepage]: https://www.contributor-covenant.org
|
||||
|
||||
For answers to common questions about this code of conduct, see
|
||||
https://www.contributor-covenant.org/faq
|
|
@ -0,0 +1,5 @@
|
|||
# Contributing
|
||||
|
||||
Welcome to the PyTorch Lightning community! We're building the most advanced research platform on the planet to implement the latest, best practices that the amazing PyTorch team rolls out!
|
||||
|
||||
**TBD**
|
|
@ -1,13 +1,13 @@
|
|||
sphinx>=2.0, <3.0
|
||||
sphinx>=3.0, !=3.5 # fails with sphinx.ext.viewcode
|
||||
recommonmark # fails with badges
|
||||
m2r # fails with multi-line text
|
||||
nbsphinx
|
||||
pandoc
|
||||
docutils
|
||||
sphinxcontrib-fulltoc
|
||||
nbsphinx>=0.8
|
||||
pandoc>=1.0
|
||||
docutils>=0.16
|
||||
sphinxcontrib-fulltoc>=1.0
|
||||
sphinxcontrib-mockautodoc
|
||||
|
||||
git+https://github.com/PytorchLightning/lightning_sphinx_theme.git
|
||||
# pip_shims
|
||||
sphinx-autodoc-typehints
|
||||
sphinx-paramlinks<0.4.0
|
||||
https://github.com/PyTorchLightning/lightning_sphinx_theme/archive/master.zip#egg=pt-lightning-sphinx-theme
|
||||
sphinx-autodoc-typehints>=1.0
|
||||
sphinx-paramlinks>=0.4.0
|
||||
sphinx-togglebutton>=0.2
|
||||
sphinx-copybutton>=0.3
|
|
@ -1,62 +0,0 @@
|
|||
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
|
||||
<svg
|
||||
xmlns:dc="http://purl.org/dc/elements/1.1/"
|
||||
xmlns:cc="http://creativecommons.org/ns#"
|
||||
xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
|
||||
xmlns:svg="http://www.w3.org/2000/svg"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
|
||||
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
|
||||
id="svg"
|
||||
version="1.1"
|
||||
width="16.000004"
|
||||
height="15.999986"
|
||||
viewBox="0 0 16.000004 15.999986"
|
||||
sodipodi:docname="lightning_icon.svg"
|
||||
inkscape:version="0.92.3 (2405546, 2018-03-11)">
|
||||
<metadata
|
||||
id="metadata13">
|
||||
<rdf:RDF>
|
||||
<cc:Work
|
||||
rdf:about="">
|
||||
<dc:format>image/svg+xml</dc:format>
|
||||
<dc:type
|
||||
rdf:resource="http://purl.org/dc/dcmitype/StillImage" />
|
||||
<dc:title></dc:title>
|
||||
</cc:Work>
|
||||
</rdf:RDF>
|
||||
</metadata>
|
||||
<defs
|
||||
id="defs11" />
|
||||
<sodipodi:namedview
|
||||
pagecolor="#ffffff"
|
||||
bordercolor="#666666"
|
||||
borderopacity="1"
|
||||
objecttolerance="10"
|
||||
gridtolerance="10"
|
||||
guidetolerance="10"
|
||||
inkscape:pageopacity="0"
|
||||
inkscape:pageshadow="2"
|
||||
inkscape:window-width="1920"
|
||||
inkscape:window-height="1028"
|
||||
id="namedview9"
|
||||
showgrid="false"
|
||||
inkscape:zoom="0.59"
|
||||
inkscape:cx="-669.05062"
|
||||
inkscape:cy="373.84245"
|
||||
inkscape:window-x="0"
|
||||
inkscape:window-y="0"
|
||||
inkscape:window-maximized="1"
|
||||
inkscape:current-layer="svg" />
|
||||
<path
|
||||
style="fill:#fbfbfb;fill-rule:evenodd;stroke:none;stroke-width:0.04002798"
|
||||
inkscape:connector-curvature="0"
|
||||
d="m 8.987101,1.723485 c -0.05588,0.03422 -4.121881,4.096544 -4.184645,4.180924 -0.02317,0.0311 -0.04587,0.06016 -0.05044,0.06456 -0.0087,0.0084 -0.07477,0.145063 -0.09679,0.20014 -0.05848,0.146583 -0.05804,0.44387 0.001,0.592413 0.08426,0.21243 0.08826,0.216754 1.576864,1.706274 0.779463,0.779947 1.41719,1.426877 1.41719,1.437604 0,0.0232 -0.253177,0.79848 -0.273873,0.838707 -0.0079,0.0153 -0.01433,0.04087 -0.01433,0.05684 0,0.01597 -0.0059,0.03587 -0.01313,0.04423 -0.0072,0.0084 -0.03678,0.09086 -0.06568,0.18333 -0.02893,0.09246 -0.05904,0.180647 -0.06693,0.195937 -0.0079,0.0153 -0.01437,0.04087 -0.01437,0.05684 0,0.01597 -0.0059,0.03586 -0.01313,0.04423 -0.0072,0.0084 -0.03679,0.09086 -0.06569,0.18333 -0.02893,0.09246 -0.05904,0.180643 -0.06693,0.195937 -0.0079,0.0153 -0.01437,0.04187 -0.01437,0.05908 0,0.0172 -0.0072,0.03574 -0.016,0.04119 -0.0088,0.0054 -0.016,0.02607 -0.016,0.04579 0,0.01973 -0.006,0.04271 -0.0134,0.05108 -0.0074,0.0084 -0.04439,0.112477 -0.08222,0.23136 -0.03787,0.118884 -0.151103,0.461124 -0.251693,0.760534 -0.489984,1.45874 -0.462444,1.36155 -0.413611,1.45938 0.06917,0.138657 0.23128,0.199741 0.358251,0.134974 0.07057,-0.03602 4.143298,-4.099985 4.245368,-4.236242 0.03382,-0.04515 0.09094,-0.165796 0.109916,-0.232123 0.0088,-0.03083 0.0243,-0.08498 0.03442,-0.120363 0.03346,-0.11668 0.0068,-0.361134 -0.0566,-0.520084 C 10.880518,9.229614 10.738898,9.079187 9.372744,7.714673 8.601524,6.944416 7.970523,6.302806 7.970523,6.288916 c 0,-0.01393 0.02817,-0.107833 0.0626,-0.208663 0.03442,-0.100834 0.07881,-0.237367 0.09859,-0.303414 0.0198,-0.06605 0.04207,-0.12693 0.04947,-0.135293 0.0074,-0.0084 0.0135,-0.03133 0.0135,-0.05108 0,-0.01973 0.0072,-0.04035 0.016,-0.04579 0.0088,-0.0054 0.016,-0.02707 0.016,-0.04803 0,-0.02097 0.0072,-0.04259 0.016,-0.04803 0.0088,-0.0054 0.016,-0.02707 0.016,-0.04803 0,-0.02097 0.0072,-0.04259 0.016,-0.04803 0.0088,-0.0054 0.016,-0.02707 0.016,-0.04803 0,-0.02097 0.0072,-0.04259 0.016,-0.04803 0.0088,-0.0054 0.016,-0.02707 0.016,-0.04803 0,-0.02097 0.0072,-0.04259 0.016,-0.04803 0.0088,-0.0054 0.016,-0.02707 0.016,-0.04803 0,-0.02097 0.0072,-0.04259 0.016,-0.04803 0.0088,-0.0054 0.016,-0.02707 0.016,-0.04803 0,-0.02097 0.0072,-0.04259 0.016,-0.04804 0.0088,-0.0054 0.016,-0.02707 0.016,-0.04803 0,-0.02097 0.0072,-0.04259 0.016,-0.04803 0.0088,-0.0054 0.016,-0.02707 0.016,-0.04803 0,-0.02097 0.0072,-0.04259 0.016,-0.04803 0.0088,-0.0054 0.016,-0.02707 0.016,-0.04803 0,-0.02097 0.0072,-0.04259 0.016,-0.04803 0.0088,-0.0054 0.016,-0.02707 0.016,-0.04803 0,-0.02097 0.0072,-0.04259 0.016,-0.04803 0.0088,-0.0054 0.016,-0.02707 0.016,-0.04803 0,-0.02097 0.0072,-0.04259 0.016,-0.04803 0.0088,-0.0054 0.016,-0.02707 0.016,-0.04803 0,-0.02097 0.0072,-0.04259 0.016,-0.04803 0.0088,-0.0054 0.016,-0.02707 0.016,-0.04803 0,-0.02097 0.0072,-0.04259 0.016,-0.04804 0.0088,-0.0054 0.016,-0.02707 0.016,-0.04803 0,-0.02097 0.0072,-0.04259 0.016,-0.04803 0.0088,-0.0054 0.016,-0.02397 0.016,-0.04119 0,-0.0172 0.0065,-0.04379 0.0144,-0.05908 0.0079,-0.0153 0.119204,-0.34484 0.247334,-0.73231 C 9.064507,2.979766 9.220177,2.513319 9.28226,2.330632 9.408267,1.960092 9.41367,1.921146 9.35255,1.826839 9.27225,1.703032 9.099973,1.654399 8.986893,1.723566"
|
||||
id="path0" />
|
||||
<path
|
||||
style="fill:#540c8c;fill-rule:evenodd;stroke:none;stroke-width:0.04002798"
|
||||
inkscape:connector-curvature="0"
|
||||
d="m 0.07719102,0.01733399 c -0.02187,0.0111 -0.04875,0.03799 -0.05984,0.05984 -0.0161,0.03173 -0.01937,1.62421701 -0.01633,7.94479601 l 0.0038,7.905086 0.03647,0.03646 0.03646,0.03647 H 8.00241 15.927073 l 0.03646,-0.03647 0.03647,-0.03646 V 8.002393 0.07773399 l -0.03647,-0.03646 -0.03646,-0.03647 -7.905086,-0.0038 c -6.320579,-0.003 -7.91305298,2.4e-4 -7.94479598,0.01633 M 9.193764,1.668208 c 0.259903,0.09046 0.275193,0.212427 0.09363,0.74628 C 8.845834,3.776859 8.388843,5.102846 7.991127,6.302606 L 9.415644,7.72492 c 1.24415,1.242111 1.51682,1.523547 1.51682,1.565414 0,0.0051 0.0133,0.03987 0.02953,0.07718 0.12913,0.296607 0.0877,0.664983 -0.103314,0.91872 -0.141456,0.187933 -4.207341,4.228478 -4.273468,4.246848 -0.139417,0.03871 -0.248653,-0.006 -0.34324,-0.140417 -0.07665,-0.108996 -0.06985,-0.137256 0.287004,-1.194633 0.34663,-1.101761 0.75901,-2.243218 1.08916,-3.290661 0,-0.0078 -0.636164,-0.650377 -1.413707,-1.427921 C 4.877658,7.152643 4.728155,6.995813 4.673718,6.87361 4.661948,6.84718 4.645988,6.81305 4.638168,6.79776 4.630368,6.78246 4.624038,6.75689 4.624038,6.74092 c 0,-0.01597 -0.0076,-0.03659 -0.01687,-0.04587 -0.02253,-0.02253 -0.02253,-0.436904 0,-0.45944 0.0093,-0.0093 0.01687,-0.0327 0.01687,-0.05204 0,-0.0363 0.06917,-0.178363 0.130414,-0.267907 0.07965,-0.1164 4.221831,-4.237681 4.259458,-4.237921 0.02047,-1.2e-4 0.04803,-0.0072 0.06124,-0.01577 0.03147,-0.02033 0.04415,-0.01967 0.118603,0.0062"
|
||||
id="path1"
|
||||
sodipodi:nodetypes="ccscccccccccccscccccscccccccccsssscccc" />
|
||||
</svg>
|
До Ширина: | Высота: | Размер: 6.4 KiB |
|
@ -1,61 +0,0 @@
|
|||
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
|
||||
<svg
|
||||
xmlns:dc="http://purl.org/dc/elements/1.1/"
|
||||
xmlns:cc="http://creativecommons.org/ns#"
|
||||
xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
|
||||
xmlns:svg="http://www.w3.org/2000/svg"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
|
||||
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
|
||||
id="svg"
|
||||
version="1.1"
|
||||
width="400"
|
||||
height="400"
|
||||
viewBox="0, 0, 400,400"
|
||||
sodipodi:docname="lightning_logo.svg"
|
||||
inkscape:version="0.92.3 (2405546, 2018-03-11)">
|
||||
<metadata
|
||||
id="metadata13">
|
||||
<rdf:RDF>
|
||||
<cc:Work
|
||||
rdf:about="">
|
||||
<dc:format>image/svg+xml</dc:format>
|
||||
<dc:type
|
||||
rdf:resource="http://purl.org/dc/dcmitype/StillImage" />
|
||||
</cc:Work>
|
||||
</rdf:RDF>
|
||||
</metadata>
|
||||
<defs
|
||||
id="defs11" />
|
||||
<sodipodi:namedview
|
||||
pagecolor="#ffffff"
|
||||
bordercolor="#666666"
|
||||
borderopacity="1"
|
||||
objecttolerance="10"
|
||||
gridtolerance="10"
|
||||
guidetolerance="10"
|
||||
inkscape:pageopacity="0"
|
||||
inkscape:pageshadow="2"
|
||||
inkscape:window-width="1920"
|
||||
inkscape:window-height="1028"
|
||||
id="namedview9"
|
||||
showgrid="false"
|
||||
inkscape:zoom="9.44"
|
||||
inkscape:cx="203.07907"
|
||||
inkscape:cy="335.32491"
|
||||
inkscape:window-x="0"
|
||||
inkscape:window-y="0"
|
||||
inkscape:window-maximized="1"
|
||||
inkscape:current-layer="svg" />
|
||||
<path
|
||||
style="fill:#fbfbfb;fill-rule:evenodd;stroke:none"
|
||||
inkscape:connector-curvature="0"
|
||||
d="m 224.6,43.137 c -1.396,0.855 -102.975,102.342 -104.543,104.45 -0.579,0.777 -1.146,1.503 -1.26,1.613 -0.218,0.21 -1.868,3.624 -2.418,5 -1.461,3.662 -1.45,11.089 0.022,14.8 2.105,5.307 2.205,5.415 39.394,42.627 19.473,19.485 35.405,35.647 35.405,35.915 0,0.58 -6.325,19.948 -6.842,20.953 -0.197,0.382 -0.358,1.021 -0.358,1.42 0,0.399 -0.147,0.896 -0.328,1.105 -0.18,0.209 -0.919,2.27 -1.641,4.58 -0.723,2.31 -1.475,4.513 -1.672,4.895 -0.198,0.382 -0.359,1.021 -0.359,1.42 0,0.399 -0.147,0.896 -0.328,1.105 -0.18,0.209 -0.919,2.27 -1.641,4.58 -0.723,2.31 -1.475,4.513 -1.672,4.895 -0.198,0.382 -0.359,1.046 -0.359,1.476 0,0.43 -0.18,0.893 -0.4,1.029 -0.22,0.136 -0.4,0.651 -0.4,1.144 0,0.493 -0.151,1.067 -0.335,1.276 -0.184,0.209 -1.109,2.81 -2.054,5.78 -0.946,2.97 -3.775,11.52 -6.288,19 -12.241,36.443 -11.553,34.015 -10.333,36.459 1.728,3.464 5.778,4.99 8.95,3.372 1.763,-0.9 103.51,-102.428 106.06,-105.832 0.845,-1.128 2.272,-4.142 2.746,-5.799 0.22,-0.77 0.607,-2.123 0.86,-3.007 0.836,-2.915 0.171,-9.022 -1.414,-12.993 -1.493,-3.741 -5.031,-7.499 -39.161,-41.588 C 214.964,173.569 199.2,157.54 199.2,157.193 c 0,-0.348 0.704,-2.694 1.564,-5.213 0.86,-2.519 1.969,-5.93 2.463,-7.58 0.495,-1.65 1.051,-3.171 1.236,-3.38 0.186,-0.209 0.337,-0.783 0.337,-1.276 0,-0.493 0.18,-1.008 0.4,-1.144 0.22,-0.136 0.4,-0.676 0.4,-1.2 0,-0.524 0.18,-1.064 0.4,-1.2 0.22,-0.136 0.4,-0.676 0.4,-1.2 0,-0.524 0.18,-1.064 0.4,-1.2 0.22,-0.136 0.4,-0.676 0.4,-1.2 0,-0.524 0.18,-1.064 0.4,-1.2 0.22,-0.136 0.4,-0.676 0.4,-1.2 0,-0.524 0.18,-1.064 0.4,-1.2 0.22,-0.136 0.4,-0.676 0.4,-1.2 0,-0.524 0.18,-1.064 0.4,-1.2 0.22,-0.136 0.4,-0.676 0.4,-1.2 0,-0.524 0.18,-1.064 0.4,-1.2 0.22,-0.136 0.4,-0.676 0.4,-1.2 0,-0.524 0.18,-1.064 0.4,-1.2 0.22,-0.136 0.4,-0.676 0.4,-1.2 0,-0.524 0.18,-1.064 0.4,-1.2 0.22,-0.136 0.4,-0.676 0.4,-1.2 0,-0.524 0.18,-1.064 0.4,-1.2 0.22,-0.136 0.4,-0.676 0.4,-1.2 0,-0.524 0.18,-1.064 0.4,-1.2 0.22,-0.136 0.4,-0.676 0.4,-1.2 0,-0.524 0.18,-1.064 0.4,-1.2 0.22,-0.136 0.4,-0.676 0.4,-1.2 0,-0.524 0.18,-1.064 0.4,-1.2 0.22,-0.136 0.4,-0.676 0.4,-1.2 0,-0.524 0.18,-1.064 0.4,-1.2 0.22,-0.136 0.4,-0.676 0.4,-1.2 0,-0.524 0.18,-1.064 0.4,-1.2 0.22,-0.136 0.4,-0.599 0.4,-1.029 0,-0.43 0.162,-1.094 0.36,-1.476 0.197,-0.382 2.978,-8.615 6.179,-18.295 3.2,-9.68 7.089,-21.333 8.64,-25.897 3.148,-9.257 3.283,-10.23 1.756,-12.586 -2.006,-3.093 -6.31,-4.308 -9.135,-2.58"
|
||||
id="path0" />
|
||||
<path
|
||||
style="fill:#540c8c;fill-rule:evenodd;stroke:none"
|
||||
inkscape:connector-curvature="0"
|
||||
d="M 2.008,0.513 C 1.462,0.79 0.79,1.462 0.513,2.008 0.111,2.801 0.029,42.585 0.105,200.489 L 0.2,397.978 1.111,398.889 2.022,399.8 H 200 397.978 l 0.911,-0.911 0.911,-0.911 V 200 2.022 L 398.889,1.111 397.978,0.2 200.489,0.105 C 42.585,0.029 2.801,0.111 2.008,0.513 m 227.755,41.243 c 6.493,2.26 6.875,5.307 2.339,18.644 -11.0313,34.035452 -22.44803,67.16196 -32.384,97.135 l 35.588,35.533 c 31.082,31.031 37.894,38.062 37.894,39.108 0,0.128 0.332,0.996 0.738,1.928 3.226,7.41 2.191,16.613 -2.581,22.952 -3.534,4.695 -105.11,105.638 -106.762,106.097 -3.483,0.967 -6.212,-0.15 -8.575,-3.508 -1.915,-2.723 -1.745,-3.429 7.17,-29.845 8.65971,-27.52475 18.96205,-56.04122 27.21,-82.209 0,-0.195 -15.893,-16.248 -35.318,-35.673 -33.146,-33.147 -36.881,-37.065 -38.241,-40.118 -0.294,-0.66 -0.693,-1.513 -0.888,-1.895 -0.194,-0.382 -0.353,-1.021 -0.353,-1.42 0,-0.399 -0.189,-0.914 -0.421,-1.146 -0.563,-0.563 -0.563,-10.915 0,-11.478 0.232,-0.232 0.421,-0.817 0.421,-1.3 0,-0.907 1.728,-4.456 3.258,-6.693 C 120.848,144.96 224.33,42 225.27,41.994 c 0.511,-0.003 1.2,-0.181 1.53,-0.394 0.786,-0.508 1.103,-0.491 2.963,0.156"
|
||||
id="path1"
|
||||
sodipodi:nodetypes="ccscccccccccccscccccscccccccccsssscccc" />
|
||||
</svg>
|
До Ширина: | Высота: | Размер: 5.2 KiB |
До Ширина: | Высота: | Размер: 16 KiB |
Двоичные данные
docs/source/_images/logos/lightning_logo.png
До Ширина: | Высота: | Размер: 8.3 KiB |
|
@ -1,62 +0,0 @@
|
|||
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
|
||||
<svg
|
||||
xmlns:dc="http://purl.org/dc/elements/1.1/"
|
||||
xmlns:cc="http://creativecommons.org/ns#"
|
||||
xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
|
||||
xmlns:svg="http://www.w3.org/2000/svg"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
|
||||
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
|
||||
id="svg"
|
||||
version="1.1"
|
||||
width="47.999985"
|
||||
height="47.999943"
|
||||
viewBox="0 0 47.999985 47.999943"
|
||||
sodipodi:docname="lightning_logo.svg"
|
||||
inkscape:version="0.92.3 (2405546, 2018-03-11)">
|
||||
<metadata
|
||||
id="metadata13">
|
||||
<rdf:RDF>
|
||||
<cc:Work
|
||||
rdf:about="">
|
||||
<dc:format>image/svg+xml</dc:format>
|
||||
<dc:type
|
||||
rdf:resource="http://purl.org/dc/dcmitype/StillImage" />
|
||||
<dc:title />
|
||||
</cc:Work>
|
||||
</rdf:RDF>
|
||||
</metadata>
|
||||
<defs
|
||||
id="defs11" />
|
||||
<sodipodi:namedview
|
||||
pagecolor="#ffffff"
|
||||
bordercolor="#666666"
|
||||
borderopacity="1"
|
||||
objecttolerance="10"
|
||||
gridtolerance="10"
|
||||
guidetolerance="10"
|
||||
inkscape:pageopacity="0"
|
||||
inkscape:pageshadow="2"
|
||||
inkscape:window-width="1920"
|
||||
inkscape:window-height="1028"
|
||||
id="namedview9"
|
||||
showgrid="false"
|
||||
inkscape:zoom="0.59"
|
||||
inkscape:cx="-347.96588"
|
||||
inkscape:cy="389.84243"
|
||||
inkscape:window-x="0"
|
||||
inkscape:window-y="0"
|
||||
inkscape:window-maximized="1"
|
||||
inkscape:current-layer="svg" />
|
||||
<path
|
||||
style="fill:#fbfbfb;fill-rule:evenodd;stroke:none;stroke-width:0.12008391"
|
||||
inkscape:connector-curvature="0"
|
||||
d="m 26.961294,5.1704519 c -0.16764,0.10267 -12.36564,12.2896301 -12.55393,12.5427701 -0.0695,0.0933 -0.13762,0.18048 -0.15131,0.19369 -0.0262,0.0252 -0.22432,0.43519 -0.29036,0.60042 -0.17544,0.43975 -0.17412,1.33161 0.003,1.77724 0.25278,0.63729 0.26479,0.65026 4.73059,5.11882 2.33839,2.33984 4.25157,4.28063 4.25157,4.31281 0,0.0696 -0.75953,2.39544 -0.82162,2.51612 -0.0237,0.0459 -0.043,0.12261 -0.043,0.17052 0,0.0479 -0.0177,0.1076 -0.0394,0.13269 -0.0216,0.0251 -0.11035,0.27259 -0.19705,0.54999 -0.0868,0.27739 -0.17713,0.54194 -0.20078,0.58781 -0.0238,0.0459 -0.0431,0.1226 -0.0431,0.17052 0,0.0479 -0.0177,0.10759 -0.0394,0.13269 -0.0216,0.0251 -0.11036,0.27259 -0.19706,0.54999 -0.0868,0.27739 -0.17712,0.54193 -0.20078,0.58781 -0.0238,0.0459 -0.0431,0.1256 -0.0431,0.17724 0,0.0516 -0.0216,0.10723 -0.048,0.12357 -0.0264,0.0163 -0.048,0.0782 -0.048,0.13737 0,0.0592 -0.0181,0.12813 -0.0402,0.15323 -0.0221,0.0251 -0.13318,0.33743 -0.24666,0.69408 -0.1136,0.35665 -0.45331,1.38337 -0.75508,2.2816 -1.46995,4.37622 -1.38733,4.08465 -1.24083,4.37814 0.2075,0.41597 0.69384,0.59922 1.07475,0.40492 0.21171,-0.10807 12.42989,-12.29995 12.7361,-12.70872 0.10147,-0.13545 0.27283,-0.49739 0.32975,-0.69637 0.0264,-0.0925 0.0729,-0.25493 0.10327,-0.36109 0.10039,-0.35004 0.0205,-1.0834 -0.1698,-1.56025 -0.17928,-0.44923 -0.60414,-0.90051 -4.7026,-4.99405 -2.31366,-2.31077 -4.20666,-4.2356 -4.20666,-4.27727 0,-0.0418 0.0845,-0.3235 0.18781,-0.62599 0.10327,-0.3025 0.23644,-0.7121 0.29577,-0.91024 0.0594,-0.19814 0.1262,-0.38079 0.14842,-0.40588 0.0223,-0.0251 0.0405,-0.094 0.0405,-0.15323 0,-0.0592 0.0216,-0.12105 0.048,-0.13738 0.0264,-0.0163 0.048,-0.0812 0.048,-0.1441 0,-0.0629 0.0216,-0.12777 0.048,-0.1441 0.0264,-0.0163 0.048,-0.0812 0.048,-0.1441 0,-0.0629 0.0216,-0.12777 0.048,-0.1441 0.0264,-0.0163 0.048,-0.0812 0.048,-0.1441 0,-0.0629 0.0216,-0.12777 0.048,-0.1441 0.0264,-0.0163 0.048,-0.0812 0.048,-0.1441 0,-0.0629 0.0216,-0.12777 0.048,-0.1441 0.0264,-0.0163 0.048,-0.0812 0.048,-0.1441 0,-0.0629 0.0216,-0.12777 0.048,-0.1441 0.0264,-0.0163 0.048,-0.0812 0.048,-0.1441 0,-0.0629 0.0216,-0.12777 0.048,-0.14411 0.0264,-0.0163 0.048,-0.0812 0.048,-0.1441 0,-0.0629 0.0216,-0.12777 0.048,-0.1441 0.0264,-0.0163 0.048,-0.0812 0.048,-0.1441 0,-0.0629 0.0216,-0.12777 0.048,-0.1441 0.0264,-0.0163 0.048,-0.0812 0.048,-0.1441 0,-0.0629 0.0216,-0.12777 0.048,-0.1441 0.0264,-0.0163 0.048,-0.0812 0.048,-0.1441 0,-0.0629 0.0216,-0.12777 0.048,-0.1441 0.0264,-0.0163 0.048,-0.0812 0.048,-0.1441 0,-0.0629 0.0216,-0.12777 0.048,-0.1441 0.0264,-0.0163 0.048,-0.0812 0.048,-0.1441 0,-0.0629 0.0216,-0.12777 0.048,-0.1441 0.0264,-0.0163 0.048,-0.0812 0.048,-0.1441 0,-0.0629 0.0216,-0.12777 0.048,-0.14411 0.0264,-0.0163 0.048,-0.0812 0.048,-0.1441 0,-0.0629 0.0216,-0.12777 0.048,-0.1441 0.0264,-0.0163 0.048,-0.0719 0.048,-0.12356 0,-0.0516 0.0195,-0.13137 0.0432,-0.17725 0.0237,-0.0459 0.35761,-1.03452 0.742,-2.19693 0.38427,-1.1624101 0.85128,-2.5617501 1.03753,-3.1098101 0.37802,-1.11162 0.39423,-1.22846 0.21087,-1.51138 -0.24089,-0.37142 -0.75773,-0.51732 -1.09697,-0.30982"
|
||||
id="path0" />
|
||||
<path
|
||||
style="fill:#540c8c;fill-rule:evenodd;stroke:none;stroke-width:0.12008391"
|
||||
inkscape:connector-curvature="0"
|
||||
d="m 0.2315739,0.05200186 c -0.0656,0.0333 -0.14626,0.11396 -0.17952,0.17952 -0.0483,0.0952 -0.0581,4.87265004 -0.049,23.83438014 l 0.0114,23.71525 0.1094,0.10939 0.10939,0.1094 h 23.7739701 23.77398 l 0.10939,-0.1094 0.1094,-0.10939 V 24.007172 0.23320186 l -0.1094,-0.10939 -0.10939,-0.1094 -23.71525,-0.0114 c -18.9617301,-0.009 -23.7391501,7.2e-4 -23.8343801,0.049 M 27.581274,5.0046319 c 0.77971,0.27139 0.82558,0.63728 0.28088,2.23884 -1.32468,4.0871101 -2.69565,8.0650701 -3.8888,11.6643501 l 4.27355,4.26694 c 3.73245,3.72633 4.55046,4.57064 4.55046,4.69624 0,0.0154 0.0399,0.11961 0.0886,0.23153 0.38739,0.88982 0.2631,1.99495 -0.30994,2.75616 -0.42437,0.5638 -12.62202,12.68543 -12.8204,12.74054 -0.41825,0.11613 -0.74596,-0.018 -1.02972,-0.42125 -0.22996,-0.32699 -0.20954,-0.41177 0.86101,-3.5839 1.03989,-3.30528 2.27703,-6.72965 3.26748,-9.87198 0,-0.0234 -1.90849,-1.95113 -4.24112,-4.28376 -3.98031,-3.98042 -4.42882,-4.45091 -4.59213,-4.81752 -0.0353,-0.0793 -0.0832,-0.18169 -0.10664,-0.22756 -0.0233,-0.0459 -0.0424,-0.12261 -0.0424,-0.17052 0,-0.0479 -0.0227,-0.10976 -0.0506,-0.13762 -0.0676,-0.0676 -0.0676,-1.31071 0,-1.37832 0.0279,-0.0279 0.0506,-0.0981 0.0506,-0.15611 0,-0.10891 0.20751,-0.53509 0.39124,-0.80372 0.23896,-0.3492 12.66549,-12.7130401 12.77837,-12.7137601 0.0614,-3.6e-4 0.1441,-0.0217 0.18372,-0.0473 0.0944,-0.061 0.13246,-0.059 0.35581,0.0187"
|
||||
id="path1"
|
||||
sodipodi:nodetypes="ccscccccccccccscccccscccccccccsssscccc" />
|
||||
</svg>
|
До Ширина: | Высота: | Размер: 6.2 KiB |
|
@ -0,0 +1,3 @@
|
|||
<svg width="38" height="44" viewBox="0 0 38 44" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M18.9368 1L1 11.5V32.5L18.9375 43L36.875 32.5V11.5L18.9368 1ZM15.8314 32.5014L17.7203 24.3499L13.4729 20.1555L22.0692 11.4993L20.1768 19.6634L24.4014 23.8354L15.8314 32.5014Z" fill="#792EE5"/>
|
||||
</svg>
|
После Ширина: | Высота: | Размер: 305 B |
После Ширина: | Высота: | Размер: 21 KiB |
После Ширина: | Высота: | Размер: 12 KiB |
|
@ -17,7 +17,6 @@ import builtins
|
|||
import glob
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
|
@ -29,6 +28,7 @@ sys.path.insert(0, os.path.abspath(PATH_ROOT))
|
|||
|
||||
builtins.__LIGHTNING_BOLT_SETUP__ = True
|
||||
|
||||
FOLDER_GENERATED = 'generated'
|
||||
SPHINX_MOCK_REQUIREMENTS = int(os.environ.get("SPHINX_MOCK_REQUIREMENTS", True))
|
||||
|
||||
import torchmetrics # noqa: E402
|
||||
|
@ -36,7 +36,7 @@ import torchmetrics # noqa: E402
|
|||
# -- Project information -----------------------------------------------------
|
||||
|
||||
# this name shall match the project name in Github as it is used for linking to code
|
||||
project = "PyTorch-torchmetrics"
|
||||
project = "PyTorch-Metrics"
|
||||
copyright = torchmetrics.__copyright__
|
||||
author = torchmetrics.__author__
|
||||
|
||||
|
@ -50,28 +50,39 @@ release = torchmetrics.__version__
|
|||
github_user = "PyTorchLightning"
|
||||
github_repo = project
|
||||
|
||||
|
||||
# -- Project documents -------------------------------------------------------
|
||||
# export the READme
|
||||
with open(os.path.join(PATH_ROOT, "README.md"), "r") as fp:
|
||||
readme = fp.read()
|
||||
# TODO: temp fix removing SVG badges and GIF, because PDF cannot show them
|
||||
readme = re.sub(r"(\[!\[.*\))", "", readme)
|
||||
readme = re.sub(r"(!\[.*.gif\))", "", readme)
|
||||
for dir_name in (
|
||||
os.path.basename(p)
|
||||
for p in glob.glob(os.path.join(PATH_ROOT, "*"))
|
||||
if os.path.isdir(p)
|
||||
):
|
||||
readme = readme.replace("](%s/" % dir_name, "](%s/%s/" % (PATH_ROOT, dir_name))
|
||||
with open("readme.md", "w") as fp:
|
||||
fp.write(readme)
|
||||
|
||||
|
||||
def _transform_changelog(path_in: str, path_out: str) -> None:
|
||||
with open(path_in, 'r') as fp:
|
||||
chlog_lines = fp.readlines()
|
||||
# enrich short subsub-titles to be unique
|
||||
chlog_ver = ''
|
||||
for i, ln in enumerate(chlog_lines):
|
||||
if ln.startswith('## '):
|
||||
chlog_ver = ln[2:].split('-')[0].strip()
|
||||
elif ln.startswith('### '):
|
||||
ln = ln.replace('###', f'### {chlog_ver} -')
|
||||
chlog_lines[i] = ln
|
||||
with open(path_out, 'w') as fp:
|
||||
fp.writelines(chlog_lines)
|
||||
|
||||
|
||||
os.makedirs(os.path.join(PATH_HERE, FOLDER_GENERATED), exist_ok=True)
|
||||
# copy all documents from GH templates like contribution guide
|
||||
for md in glob.glob(os.path.join(PATH_ROOT, '.github', '*.md')):
|
||||
shutil.copy(md, os.path.join(PATH_HERE, FOLDER_GENERATED, os.path.basename(md)))
|
||||
# copy also the changelog
|
||||
_transform_changelog(
|
||||
os.path.join(PATH_ROOT, 'CHANGELOG.md'),
|
||||
os.path.join(PATH_HERE, FOLDER_GENERATED, 'CHANGELOG.md'),
|
||||
)
|
||||
|
||||
# -- General configuration ---------------------------------------------------
|
||||
|
||||
# If your documentation needs a minimal Sphinx version, state it here.
|
||||
|
||||
needs_sphinx = "2.0"
|
||||
needs_sphinx = "3.4"
|
||||
|
||||
# Add any Sphinx extension module names here, as strings. They can be
|
||||
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
|
||||
|
@ -97,7 +108,6 @@ extensions = [
|
|||
"sphinx.ext.githubpages",
|
||||
]
|
||||
|
||||
|
||||
# Add any paths that contain templates here, relative to this directory.
|
||||
templates_path = ["_templates"]
|
||||
|
||||
|
@ -135,7 +145,7 @@ language = None
|
|||
# directories to ignore when looking for source files.
|
||||
# This pattern also affects html_static_path and html_extra_path.
|
||||
exclude_patterns = [
|
||||
"PULL_REQUEST_TEMPLATE.md",
|
||||
os.path.join(FOLDER_GENERATED, "PULL_REQUEST_TEMPLATE.md"),
|
||||
]
|
||||
|
||||
# The name of the Pygments (syntax highlighting) style to use.
|
||||
|
@ -161,13 +171,12 @@ html_theme_options = {
|
|||
"logo_only": False,
|
||||
}
|
||||
|
||||
# TODO
|
||||
# html_logo = '_images/logos/lightning_logo-name.svg'
|
||||
html_logo = '_static/images/logo.svg'
|
||||
|
||||
# Add any paths that contain custom static files (such as style sheets) here,
|
||||
# relative to this directory. They are copied after the builtin static files,
|
||||
# so a file named "default.css" will overwrite the builtin "default.css".
|
||||
html_static_path = ["_images", "_templates", "_static"]
|
||||
html_static_path = ["_static"]
|
||||
|
||||
# Custom sidebar templates, must be a dictionary that maps document names
|
||||
# to template names.
|
||||
|
@ -272,10 +281,8 @@ PACKAGES = [
|
|||
torchmetrics.__name__,
|
||||
]
|
||||
|
||||
apidoc_output_folder = os.path.join(PATH_HERE, "api")
|
||||
|
||||
|
||||
# def run_apidoc(_):
|
||||
# apidoc_output_folder = os.path.join(PATH_HERE, "api")
|
||||
# sys.path.insert(0, apidoc_output_folder)
|
||||
#
|
||||
# # delete api-doc files before generating them
|
||||
|
@ -317,7 +324,7 @@ def package_list_from_file(file):
|
|||
with open(file, "r") as fp:
|
||||
for ln in fp.readlines():
|
||||
found = [ln.index(ch) for ch in list(",=<>#") if ch in ln]
|
||||
pkg = ln[: min(found)] if found else ln
|
||||
pkg = ln[:min(found)] if found else ln
|
||||
if pkg.rstrip():
|
||||
mocked_packages.append(pkg.rstrip())
|
||||
return mocked_packages
|
||||
|
|
|
@ -1,168 +0,0 @@
|
|||
.. role:: hidden
|
||||
:class: hidden-section
|
||||
|
||||
torchmetrics.functional
|
||||
=======================
|
||||
.. TODO: this should work and then autofunction doesn't need the full path
|
||||
.. and then we don't need no index
|
||||
.. .. currentmodule:: torchmetrics.functional
|
||||
|
||||
|
||||
Classification functions
|
||||
------------------------
|
||||
|
||||
:hidden:`accuracy`
|
||||
~~~~~~~~~~~~~~~~~~~~~
|
||||
.. autofunction:: torchmetrics.functional.classification.accuracy
|
||||
:noindex:
|
||||
|
||||
:hidden:`auc`
|
||||
~~~~~~~~~~~~~~~~
|
||||
.. autofunction:: torchmetrics.functional.classification.auc
|
||||
:noindex:
|
||||
|
||||
:hidden:`auroc`
|
||||
~~~~~~~~~~~~~~~~
|
||||
.. autofunction:: torchmetrics.functional.classification.auroc
|
||||
:noindex:
|
||||
|
||||
:hidden:`multiclass_auroc`
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
.. autofunction:: torchmetrics.functional.classification.multiclass_auroc
|
||||
:noindex:
|
||||
|
||||
:hidden:`average_precision`
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
.. autofunction:: torchmetrics.functional.average_precision
|
||||
:noindex:
|
||||
|
||||
:hidden:`confusion_matrix`
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
.. autofunction:: torchmetrics.functional.confusion_matrix
|
||||
:noindex:
|
||||
|
||||
:hidden:`dice_score`
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
.. autofunction:: torchmetrics.functional.classification.dice_score
|
||||
:noindex:
|
||||
|
||||
:hidden:`f1`
|
||||
~~~~~~~~~~~~~~~~
|
||||
.. autofunction:: torchmetrics.functional.f1
|
||||
:noindex:
|
||||
|
||||
:hidden:`fbeta`
|
||||
~~~~~~~~~~~~~~~~
|
||||
.. autofunction:: torchmetrics.functional.fbeta
|
||||
:noindex:
|
||||
|
||||
:hidden:`iou`
|
||||
~~~~~~~~~~~~~~~~
|
||||
.. autofunction:: torchmetrics.functional.classification.iou
|
||||
:noindex:
|
||||
|
||||
:hidden:`roc`
|
||||
~~~~~~~~~~~~~~~~
|
||||
.. autofunction:: torchmetrics.functional.roc
|
||||
:noindex:
|
||||
|
||||
:hidden:`precision`
|
||||
~~~~~~~~~~~~~~~~~~~~~
|
||||
.. autofunction:: torchmetrics.functional.classification.precision
|
||||
:noindex:
|
||||
|
||||
:hidden:`precision_recall`
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
.. autofunction:: torchmetrics.functional.classification.precision_recall
|
||||
:noindex:
|
||||
|
||||
:hidden:`precision_recall_curve`
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
.. autofunction:: torchmetrics.functional.precision_recall_curve
|
||||
:noindex:
|
||||
|
||||
:hidden:`recall`
|
||||
~~~~~~~~~~~~~~~~
|
||||
.. autofunction:: torchmetrics.functional.classification.recall
|
||||
:noindex:
|
||||
|
||||
:hidden:`stat_scores`
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
.. autofunction:: torchmetrics.functional.classification.stat_scores
|
||||
:noindex:
|
||||
|
||||
:hidden:`stat_scores_multiple_classes`
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
.. autofunction:: torchmetrics.functional.classification.stat_scores_multiple_classes
|
||||
:noindex:
|
||||
|
||||
|
||||
Regression functions
|
||||
------------------------
|
||||
|
||||
:hidden:`explained_variance`
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
.. autofunction:: torchmetrics.functional.explained_variance
|
||||
:noindex:
|
||||
|
||||
|
||||
:hidden:`mean_absolute_error`
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
.. autofunction:: torchmetrics.functional.mean_absolute_error
|
||||
:noindex:
|
||||
|
||||
:hidden:`mean_squared_error`
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
.. autofunction:: torchmetrics.functional.mean_squared_error
|
||||
:noindex:
|
||||
|
||||
:hidden:`mean_squared_log_error`
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
.. autofunction:: torchmetrics.functional.mean_squared_log_error
|
||||
:noindex:
|
||||
|
||||
:hidden:`psnr`
|
||||
~~~~~~~~~~~~~~~~
|
||||
.. autofunction:: torchmetrics.functional.psnr
|
||||
:noindex:
|
||||
|
||||
:hidden:`ssim`
|
||||
~~~~~~~~~~~~~~~
|
||||
.. autofunction:: torchmetrics.functional.ssim
|
||||
:noindex:
|
||||
|
||||
|
||||
NLP functions
|
||||
------------------------
|
||||
|
||||
:hidden:`bleu_score`
|
||||
~~~~~~~~~~~~~~~~~~~~~
|
||||
.. autofunction:: torchmetrics.functional.nlp.bleu_score
|
||||
:noindex:
|
||||
|
||||
Pairwise functions
|
||||
--------------------
|
||||
:hidden:`embedding_similarity`
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
.. autofunction:: torchmetrics.functional.self_supervised.embedding_similarity
|
||||
:noindex:
|
||||
|
||||
|
||||
Utility functions
|
||||
--------------------
|
||||
|
||||
:hidden:`select_topk`
|
||||
~~~~~~~~~~~~~~~~~~~~~
|
||||
.. autofunction:: torchmetrics.utils.select_topk
|
||||
:noindex:
|
||||
|
||||
|
||||
:hidden:`to_categorical`
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
.. autofunction:: torchmetrics.utils.to_categorical
|
||||
:noindex:
|
||||
|
||||
:hidden:`to_onehot`
|
||||
~~~~~~~~~~~~~~~~~~~~~~
|
||||
.. autofunction:: torchmetrics.utils.to_onehot
|
||||
:noindex:
|
|
@ -10,10 +10,22 @@ PyTorchMetrics documentation
|
|||
:name: start
|
||||
:caption: Start here
|
||||
|
||||
intro
|
||||
torchmetrics
|
||||
functional
|
||||
lightning
|
||||
pages/intro
|
||||
pages/implement
|
||||
pages/overview
|
||||
pages/lightning
|
||||
references/functional
|
||||
references/modules
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:name: community
|
||||
:caption: Community
|
||||
|
||||
|
||||
generated/CODE_OF_CONDUCT.md
|
||||
generated/CONTRIBUTING.md
|
||||
generated/CHANGELOG.md
|
||||
|
||||
Indices and tables
|
||||
==================
|
||||
|
@ -21,9 +33,3 @@ Indices and tables
|
|||
* :ref:`genindex`
|
||||
* :ref:`modindex`
|
||||
* :ref:`search`
|
||||
|
||||
.. This is here to make sphinx aware of the modules but not throw an error/warning
|
||||
.. toctree::
|
||||
:hidden:
|
||||
|
||||
readme
|
||||
|
|
|
@ -1,108 +0,0 @@
|
|||
|
||||
############
|
||||
Introduction
|
||||
############
|
||||
|
||||
``torchmetrics`` is a Metrics API created for easy metric development and usage in
|
||||
PyTorch and PyTorch Lightning. It is rigorously tested for all edge cases and includes a growing list of
|
||||
common metric implementations.
|
||||
|
||||
The metrics API provides ``update()``, ``compute()``, ``reset()`` functions to the user. The metric base class inherits
|
||||
``nn.Module`` which allows us to call ``metric(...)`` directly. The ``forward()`` method of the base ``Metric`` class
|
||||
serves the dual purpose of calling ``update()`` on its input and simultaneously returning the value of the metric over the
|
||||
provided input.
|
||||
|
||||
These metrics work with DDP in PyTorch and PyTorch Lightning by default. When ``.compute()`` is called in
|
||||
distributed mode, the internal state of each metric is synced and reduced across each process, so that the
|
||||
logic present in ``.compute()`` is applied to state information from all processes.
|
||||
|
||||
This metrics API is independent of PyTorch Lightning. Metrics can directly be used in PyTorch as shown in the example:
|
||||
|
||||
.. NOTE: this can't actually be tested as epochs, train_data, and valid_data are undefined
|
||||
.. code-block:: python
|
||||
|
||||
from torchmetrics.classification import Accuracy
|
||||
|
||||
train_accuracy = Accuracy()
|
||||
valid_accuracy = Accuracy(compute_on_step=False)
|
||||
|
||||
for epoch in range(epochs):
|
||||
for x, y in train_data:
|
||||
y_hat = model(x)
|
||||
|
||||
# training step accuracy
|
||||
batch_acc = train_accuracy(y_hat, y)
|
||||
|
||||
for x, y in valid_data:
|
||||
y_hat = model(x)
|
||||
valid_accuracy(y_hat, y)
|
||||
|
||||
# total accuracy over all training batches
|
||||
total_train_accuracy = train_accuracy.compute()
|
||||
|
||||
# total accuracy over all validation batches
|
||||
total_valid_accuracy = valid_accuracy.compute()
|
||||
|
||||
.. note::
|
||||
|
||||
Metrics contain internal states that keep track of the data seen so far.
|
||||
Do not mix metric states across training, validation and testing.
|
||||
It is highly recommended to re-initialize the metric per mode as
|
||||
shown in the examples above.
|
||||
|
||||
.. note::
|
||||
|
||||
Metric states are **not** added to the models ``state_dict`` by default.
|
||||
To change this, after initializing the metric, the method ``.persistent(mode)`` can
|
||||
be used to enable (``mode=True``) or disable (``mode=False``) this behaviour.
|
||||
|
||||
*********************
|
||||
Implementing a Metric
|
||||
*********************
|
||||
|
||||
To implement your custom metric, subclass the base ``Metric`` class and implement the following methods:
|
||||
|
||||
- ``__init__()``: Each state variable should be called using ``self.add_state(...)``.
|
||||
- ``update()``: Any code needed to update the state given any inputs to the metric.
|
||||
- ``compute()``: Computes a final value from the state of the metric.
|
||||
|
||||
All you need to do is call ``add_state`` correctly to implement a custom metric with DDP.
|
||||
``reset()`` is called on metric state variables added using ``add_state()``.
|
||||
|
||||
To see how metric states are synchronized across distributed processes, refer to ``add_state()`` docs
|
||||
from the base ``Metric`` class.
|
||||
|
||||
Example implementation:
|
||||
|
||||
.. testcode::
|
||||
|
||||
from torchmetrics import Metric
|
||||
|
||||
class MyAccuracy(Metric):
|
||||
def __init__(self, dist_sync_on_step=False):
|
||||
super().__init__(dist_sync_on_step=dist_sync_on_step)
|
||||
|
||||
self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
|
||||
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
|
||||
|
||||
def update(self, preds: torch.Tensor, target: torch.Tensor):
|
||||
preds, target = self._input_format(preds, target)
|
||||
assert preds.shape == target.shape
|
||||
|
||||
self.correct += torch.sum(preds == target)
|
||||
self.total += target.numel()
|
||||
|
||||
def compute(self):
|
||||
return self.correct.float() / self.total
|
||||
|
||||
Metrics support backpropagation, if all computations involved in the metric calculation
|
||||
are differentiable. However, note that the cached state is detached from the computational
|
||||
graph and cannot be backpropagated. Not doing this would mean storing the computational
|
||||
graph for each update call, which can lead to out-of-memory errors.
|
||||
In practise this means that:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
metric = MyMetric()
|
||||
val = metric(pred, target) # this value can be backpropagated
|
||||
val = metric.compute() # this value cannot be backpropagated
|
|
@ -0,0 +1,87 @@
|
|||
*********************
|
||||
Implementing a Metric
|
||||
*********************
|
||||
|
||||
To implement your custom metric, subclass the base ``Metric`` class and implement the following methods:
|
||||
|
||||
- ``__init__()``: Each state variable should be called using ``self.add_state(...)``.
|
||||
- ``update()``: Any code needed to update the state given any inputs to the metric.
|
||||
- ``compute()``: Computes a final value from the state of the metric.
|
||||
|
||||
All you need to do is call ``add_state`` correctly to implement a custom metric with DDP.
|
||||
``reset()`` is called on metric state variables added using ``add_state()``.
|
||||
|
||||
To see how metric states are synchronized across distributed processes, refer to ``add_state()`` docs
|
||||
from the base ``Metric`` class.
|
||||
|
||||
Example implementation:
|
||||
|
||||
.. testcode::
|
||||
|
||||
from torchmetrics import Metric
|
||||
|
||||
class MyAccuracy(Metric):
|
||||
def __init__(self, dist_sync_on_step=False):
|
||||
super().__init__(dist_sync_on_step=dist_sync_on_step)
|
||||
|
||||
self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
|
||||
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
|
||||
|
||||
def update(self, preds: torch.Tensor, target: torch.Tensor):
|
||||
preds, target = self._input_format(preds, target)
|
||||
assert preds.shape == target.shape
|
||||
|
||||
self.correct += torch.sum(preds == target)
|
||||
self.total += target.numel()
|
||||
|
||||
def compute(self):
|
||||
return self.correct.float() / self.total
|
||||
|
||||
Metrics support backpropagation, if all computations involved in the metric calculation
|
||||
are differentiable. However, note that the cached state is detached from the computational
|
||||
graph and cannot be backpropagated. Not doing this would mean storing the computational
|
||||
graph for each update call, which can lead to out-of-memory errors.
|
||||
In practise this means that:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
metric = MyMetric()
|
||||
val = metric(pred, target) # this value can be backpropagated
|
||||
val = metric.compute() # this value cannot be backpropagated
|
||||
|
||||
|
||||
Internal implementation details
|
||||
-------------------------------
|
||||
|
||||
This section briefly describe how metrics work internally. We encourage looking at the source code for more info.
|
||||
Internally, Lightning wraps the user defined ``update()`` and ``compute()`` method. We do this to automatically
|
||||
synchronize and reduce metric states across multiple devices. More precisely, calling ``update()`` does the
|
||||
following internally:
|
||||
|
||||
1. Clears computed cache
|
||||
2. Calls user-defined ``update()``
|
||||
|
||||
Simiarly, calling ``compute()`` does the following internally
|
||||
|
||||
1. Syncs metric states between processes
|
||||
2. Reduce gathered metric states
|
||||
3. Calls the user defined ``compute()`` method on the gathered metric states
|
||||
4. Cache computed result
|
||||
|
||||
From a user's standpoint this has one important side-effect: computed results are cached. This means that no
|
||||
matter how many times ``compute`` is called after one and another, it will continue to return the same result.
|
||||
The cache is first emptied on the next call to ``update``.
|
||||
|
||||
``forward`` serves the dual purpose of both returning the metric on the current data and updating the internal
|
||||
metric state for accumulating over multiple batches. The ``forward()`` method achives this by combining calls
|
||||
to ``update`` and ``compute`` in the following way (assuming metric is initialized with ``compute_on_step=True``):
|
||||
|
||||
1. Calls ``update()`` to update the global metric states (for accumulation over multiple batches)
|
||||
2. Caches the global state
|
||||
3. Calls ``reset()`` to clear global metric state
|
||||
4. Calls ``update()`` to update local metric state
|
||||
5. Calls ``compute()`` to calculate metric for current batch
|
||||
6. Restores the global state
|
||||
|
||||
This procedure has the consequence of calling the user defined ``update`` **twice** during a single
|
||||
forward call (one to update global statistics and one for getting the batch statistics).
|
|
@ -0,0 +1,301 @@
|
|||
|
||||
############
|
||||
Introduction
|
||||
############
|
||||
|
||||
``torchmetrics`` is a Metrics API created for easy metric development and usage in
|
||||
PyTorch and PyTorch Lightning. It is rigorously tested for all edge cases and includes a growing list of
|
||||
common metric implementations.
|
||||
|
||||
The metrics API provides ``update()``, ``compute()``, ``reset()`` functions to the user. The metric base class inherits
|
||||
``nn.Module`` which allows us to call ``metric(...)`` directly. The ``forward()`` method of the base ``Metric`` class
|
||||
serves the dual purpose of calling ``update()`` on its input and simultaneously returning the value of the metric over the
|
||||
provided input.
|
||||
|
||||
These metrics work with DDP in PyTorch and PyTorch Lightning by default. When ``.compute()`` is called in
|
||||
distributed mode, the internal state of each metric is synced and reduced across each process, so that the
|
||||
logic present in ``.compute()`` is applied to state information from all processes.
|
||||
|
||||
This metrics API is independent of PyTorch Lightning. Metrics can directly be used in PyTorch as shown in the example:
|
||||
|
||||
.. NOTE: this can't actually be tested as epochs, train_data, and valid_data are undefined
|
||||
.. code-block:: python
|
||||
|
||||
from torchmetrics.classification import Accuracy
|
||||
|
||||
train_accuracy = Accuracy()
|
||||
valid_accuracy = Accuracy(compute_on_step=False)
|
||||
|
||||
for epoch in range(epochs):
|
||||
for x, y in train_data:
|
||||
y_hat = model(x)
|
||||
|
||||
# training step accuracy
|
||||
batch_acc = train_accuracy(y_hat, y)
|
||||
|
||||
for x, y in valid_data:
|
||||
y_hat = model(x)
|
||||
valid_accuracy(y_hat, y)
|
||||
|
||||
# total accuracy over all training batches
|
||||
total_train_accuracy = train_accuracy.compute()
|
||||
|
||||
# total accuracy over all validation batches
|
||||
total_valid_accuracy = valid_accuracy.compute()
|
||||
|
||||
.. note::
|
||||
|
||||
Metrics contain internal states that keep track of the data seen so far.
|
||||
Do not mix metric states across training, validation and testing.
|
||||
It is highly recommended to re-initialize the metric per mode as
|
||||
shown in the examples above.
|
||||
|
||||
.. note::
|
||||
|
||||
Metric states are **not** added to the models ``state_dict`` by default.
|
||||
To change this, after initializing the metric, the method ``.persistent(mode)`` can
|
||||
be used to enable (``mode=True``) or disable (``mode=False``) this behaviour.
|
||||
|
||||
|
||||
*******************
|
||||
Metrics and devices
|
||||
*******************
|
||||
|
||||
Metrics are simple subclasses of :class:`~torch.nn.Module` and their metric states behave
|
||||
similar to buffers and parameters of modules. This means that metrics states should
|
||||
be moved to the same device as the input of the metric:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from torchmetrics import Accuracy
|
||||
|
||||
target = torch.tensor([1, 1, 0, 0], device=torch.device("cuda", 0))
|
||||
preds = torch.tensor([0, 1, 0, 0], device=torch.device("cuda", 0))
|
||||
|
||||
# Metric states are always initialized on cpu, and needs to be moved to
|
||||
# the correct device
|
||||
confmat = Accuracy(num_classes=2).to(torch.device("cuda", 0))
|
||||
out = confmat(preds, target)
|
||||
print(out.device) # cuda:0
|
||||
|
||||
|
||||
******************
|
||||
Metric Arithmetics
|
||||
******************
|
||||
|
||||
Metrics support most of python built-in operators for arithmetic, logic and bitwise operations.
|
||||
|
||||
For example for a metric that should return the sum of two different metrics, implementing a new metric is an overhead that is not necessary.
|
||||
It can now be done with:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
first_metric = MyFirstMetric()
|
||||
second_metric = MySecondMetric()
|
||||
|
||||
new_metric = first_metric + second_metric
|
||||
|
||||
``new_metric.update(*args, **kwargs)`` now calls update of ``first_metric`` and ``second_metric``. It forwards all positional arguments but
|
||||
forwards only the keyword arguments that are available in respective metric's update declaration.
|
||||
|
||||
Similarly ``new_metric.compute()`` now calls compute of ``first_metric`` and ``second_metric`` and adds the results up.
|
||||
|
||||
This pattern is implemented for the following operators (with ``a`` being metrics and ``b`` being metrics, tensors, integer or floats):
|
||||
|
||||
* Addition (``a + b``)
|
||||
* Bitwise AND (``a & b``)
|
||||
* Equality (``a == b``)
|
||||
* Floordivision (``a // b``)
|
||||
* Greater Equal (``a >= b``)
|
||||
* Greater (``a > b``)
|
||||
* Less Equal (``a <= b``)
|
||||
* Less (``a < b``)
|
||||
* Matrix Multiplication (``a @ b``)
|
||||
* Modulo (``a % b``)
|
||||
* Multiplication (``a * b``)
|
||||
* Inequality (``a != b``)
|
||||
* Bitwise OR (``a | b``)
|
||||
* Power (``a ** b``)
|
||||
* Substraction (``a - b``)
|
||||
* True Division (``a / b``)
|
||||
* Bitwise XOR (``a ^ b``)
|
||||
* Absolute Value (``abs(a)``)
|
||||
* Inversion (``~a``)
|
||||
* Negative Value (``neg(a)``)
|
||||
* Positive Value (``pos(a)``)
|
||||
|
||||
****************
|
||||
MetricCollection
|
||||
****************
|
||||
|
||||
In many cases it is beneficial to evaluate the model output by multiple metrics.
|
||||
In this case the `MetricCollection` class may come in handy. It accepts a sequence
|
||||
of metrics and wraps theses into a single callable metric class, with the same
|
||||
interface as any other metric.
|
||||
|
||||
Example:
|
||||
|
||||
.. testcode::
|
||||
|
||||
from torchmetrics import MetricCollection, Accuracy, Precision, Recall
|
||||
target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2])
|
||||
preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2])
|
||||
metric_collection = MetricCollection([
|
||||
Accuracy(),
|
||||
Precision(num_classes=3, average='macro'),
|
||||
Recall(num_classes=3, average='macro')
|
||||
])
|
||||
print(metric_collection(preds, target))
|
||||
|
||||
.. testoutput::
|
||||
:options: +NORMALIZE_WHITESPACE
|
||||
|
||||
{'Accuracy': tensor(0.1250),
|
||||
'Precision': tensor(0.0667),
|
||||
'Recall': tensor(0.1111)}
|
||||
|
||||
Similarly it can also reduce the amount of code required to log multiple metrics
|
||||
inside your LightningModule
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def __init__(self):
|
||||
...
|
||||
metrics = pl.metrics.MetricCollection(...)
|
||||
self.train_metrics = metrics.clone()
|
||||
self.valid_metrics = metrics.clone()
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
logits = self(x)
|
||||
...
|
||||
self.train_metrics(logits, y)
|
||||
# use log_dict instead of log
|
||||
self.log_dict(self.train_metrics, on_step=True, on_epoch=False, prefix='train')
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
logits = self(x)
|
||||
...
|
||||
self.valid_metrics(logits, y)
|
||||
# use log_dict instead of log
|
||||
self.log_dict(self.valid_metrics, on_step=True, on_epoch=True, prefix='val')
|
||||
|
||||
.. note::
|
||||
|
||||
`MetricCollection` as default assumes that all the metrics in the collection
|
||||
have the same call signature. If this is not the case, input that should be
|
||||
given to different metrics can given as keyword arguments to the collection.
|
||||
|
||||
.. autoclass:: torchmetrics.MetricCollection
|
||||
:noindex:
|
||||
|
||||
|
||||
***************************
|
||||
Class vs Functional Metrics
|
||||
***************************
|
||||
|
||||
The functional metrics follow the simple paradigm input in, output out. This means, they don't provide any advanced mechanisms for syncing across DDP nodes or aggregation over batches. They simply compute the metric value based on the given inputs.
|
||||
|
||||
Also, the integration within other parts of PyTorch Lightning will never be as tight as with the class-based interface.
|
||||
If you look for just computing the values, the functional metrics are the way to go. However, if you are looking for the best integration and user experience, please consider also using the class interface.
|
||||
|
||||
**********************
|
||||
Classification Metrics
|
||||
**********************
|
||||
|
||||
Input types
|
||||
-----------
|
||||
|
||||
For the purposes of classification metrics, inputs (predictions and targets) are split
|
||||
into these categories (``N`` stands for the batch size and ``C`` for number of classes):
|
||||
|
||||
.. csv-table:: \*dtype ``binary`` means integers that are either 0 or 1
|
||||
:header: "Type", "preds shape", "preds dtype", "target shape", "target dtype"
|
||||
:widths: 20, 10, 10, 10, 10
|
||||
|
||||
"Binary", "(N,)", "``float``", "(N,)", "``binary``\*"
|
||||
"Multi-class", "(N,)", "``int``", "(N,)", "``int``"
|
||||
"Multi-class with probabilities", "(N, C)", "``float``", "(N,)", "``int``"
|
||||
"Multi-label", "(N, ...)", "``float``", "(N, ...)", "``binary``\*"
|
||||
"Multi-dimensional multi-class", "(N, ...)", "``int``", "(N, ...)", "``int``"
|
||||
"Multi-dimensional multi-class with probabilities", "(N, C, ...)", "``float``", "(N, ...)", "``int``"
|
||||
|
||||
.. note::
|
||||
All dimensions of size 1 (except ``N``) are "squeezed out" at the beginning, so
|
||||
that, for example, a tensor of shape ``(N, 1)`` is treated as ``(N, )``.
|
||||
|
||||
When predictions or targets are integers, it is assumed that class labels start at 0, i.e.
|
||||
the possible class labels are 0, 1, 2, 3, etc. Below are some examples of different input types
|
||||
|
||||
.. testcode::
|
||||
|
||||
# Binary inputs
|
||||
binary_preds = torch.tensor([0.6, 0.1, 0.9])
|
||||
binary_target = torch.tensor([1, 0, 2])
|
||||
|
||||
# Multi-class inputs
|
||||
mc_preds = torch.tensor([0, 2, 1])
|
||||
mc_target = torch.tensor([0, 1, 2])
|
||||
|
||||
# Multi-class inputs with probabilities
|
||||
mc_preds_probs = torch.tensor([[0.8, 0.2, 0], [0.1, 0.2, 0.7], [0.3, 0.6, 0.1]])
|
||||
mc_target_probs = torch.tensor([0, 1, 2])
|
||||
|
||||
# Multi-label inputs
|
||||
ml_preds = torch.tensor([[0.2, 0.8, 0.9], [0.5, 0.6, 0.1], [0.3, 0.1, 0.1]])
|
||||
ml_target = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 0]])
|
||||
|
||||
|
||||
Using the is_multiclass parameter
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
In some cases, you might have inputs which appear to be (multi-dimensional) multi-class
|
||||
but are actually binary/multi-label - for example, if both predictions and targets are
|
||||
integer (binary) tensors. Or it could be the other way around, you want to treat
|
||||
binary/multi-label inputs as 2-class (multi-dimensional) multi-class inputs.
|
||||
|
||||
For these cases, the metrics where this distinction would make a difference, expose the
|
||||
``is_multiclass`` argument. Let's see how this is used on the example of
|
||||
:class:`~torchmetrics.StatScores` metric.
|
||||
|
||||
First, let's consider the case with label predictions with 2 classes, which we want to
|
||||
treat as binary.
|
||||
|
||||
.. testcode::
|
||||
|
||||
from torchmetrics.functional import stat_scores
|
||||
|
||||
# These inputs are supposed to be binary, but appear as multi-class
|
||||
preds = torch.tensor([0, 1, 0])
|
||||
target = torch.tensor([1, 1, 0])
|
||||
|
||||
As you can see below, by default the inputs are treated
|
||||
as multi-class. We can set ``is_multiclass=False`` to treat the inputs as binary -
|
||||
which is the same as converting the predictions to float beforehand.
|
||||
|
||||
.. doctest::
|
||||
|
||||
>>> stat_scores(preds, target, reduce='macro', num_classes=2)
|
||||
tensor([[1, 1, 1, 0, 1],
|
||||
[1, 0, 1, 1, 2]])
|
||||
>>> stat_scores(preds, target, reduce='macro', num_classes=1, is_multiclass=False)
|
||||
tensor([[1, 0, 1, 1, 2]])
|
||||
>>> stat_scores(preds.float(), target, reduce='macro', num_classes=1)
|
||||
tensor([[1, 0, 1, 1, 2]])
|
||||
|
||||
Next, consider the opposite example: inputs are binary (as predictions are probabilities),
|
||||
but we would like to treat them as 2-class multi-class, to obtain the metric for both classes.
|
||||
|
||||
.. testcode::
|
||||
|
||||
preds = torch.tensor([0.2, 0.7, 0.3])
|
||||
target = torch.tensor([1, 1, 0])
|
||||
|
||||
In this case we can set ``is_multiclass=True``, to treat the inputs as multi-class.
|
||||
|
||||
.. doctest::
|
||||
|
||||
>>> stat_scores(preds, target, reduce='macro', num_classes=1)
|
||||
tensor([[1, 0, 1, 1, 2]])
|
||||
>>> stat_scores(preds, target, reduce='macro', num_classes=2, is_multiclass=True)
|
||||
tensor([[1, 1, 1, 0, 1],
|
||||
[1, 0, 1, 1, 2]])
|
|
@ -0,0 +1,104 @@
|
|||
.. role:: hidden
|
||||
:class: hidden-section
|
||||
|
||||
**********************
|
||||
Classification Metrics
|
||||
**********************
|
||||
|
||||
Input types
|
||||
-----------
|
||||
|
||||
For the purposes of classification metrics, inputs (predictions and targets) are split
|
||||
into these categories (``N`` stands for the batch size and ``C`` for number of classes):
|
||||
|
||||
.. csv-table:: \*dtype ``binary`` means integers that are either 0 or 1
|
||||
:header: "Type", "preds shape", "preds dtype", "target shape", "target dtype"
|
||||
:widths: 20, 10, 10, 10, 10
|
||||
|
||||
"Binary", "(N,)", "``float``", "(N,)", "``binary``\*"
|
||||
"Multi-class", "(N,)", "``int``", "(N,)", "``int``"
|
||||
"Multi-class with probabilities", "(N, C)", "``float``", "(N,)", "``int``"
|
||||
"Multi-label", "(N, ...)", "``float``", "(N, ...)", "``binary``\*"
|
||||
"Multi-dimensional multi-class", "(N, ...)", "``int``", "(N, ...)", "``int``"
|
||||
"Multi-dimensional multi-class with probabilities", "(N, C, ...)", "``float``", "(N, ...)", "``int``"
|
||||
|
||||
.. note::
|
||||
All dimensions of size 1 (except ``N``) are "squeezed out" at the beginning, so
|
||||
that, for example, a tensor of shape ``(N, 1)`` is treated as ``(N, )``.
|
||||
|
||||
When predictions or targets are integers, it is assumed that class labels start at 0, i.e.
|
||||
the possible class labels are 0, 1, 2, 3, etc. Below are some examples of different input types
|
||||
|
||||
.. testcode::
|
||||
|
||||
# Binary inputs
|
||||
binary_preds = torch.tensor([0.6, 0.1, 0.9])
|
||||
binary_target = torch.tensor([1, 0, 2])
|
||||
|
||||
# Multi-class inputs
|
||||
mc_preds = torch.tensor([0, 2, 1])
|
||||
mc_target = torch.tensor([0, 1, 2])
|
||||
|
||||
# Multi-class inputs with probabilities
|
||||
mc_preds_probs = torch.tensor([[0.8, 0.2, 0], [0.1, 0.2, 0.7], [0.3, 0.6, 0.1]])
|
||||
mc_target_probs = torch.tensor([0, 1, 2])
|
||||
|
||||
# Multi-label inputs
|
||||
ml_preds = torch.tensor([[0.2, 0.8, 0.9], [0.5, 0.6, 0.1], [0.3, 0.1, 0.1]])
|
||||
ml_target = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 0]])
|
||||
|
||||
|
||||
Using the is_multiclass parameter
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
In some cases, you might have inputs which appear to be (multi-dimensional) multi-class
|
||||
but are actually binary/multi-label - for example, if both predictions and targets are
|
||||
integer (binary) tensors. Or it could be the other way around, you want to treat
|
||||
binary/multi-label inputs as 2-class (multi-dimensional) multi-class inputs.
|
||||
|
||||
For these cases, the metrics where this distinction would make a difference, expose the
|
||||
``is_multiclass`` argument. Let's see how this is used on the example of
|
||||
:class:`~pytorch_lightning.metrics.StatScores` metric.
|
||||
|
||||
First, let's consider the case with label predictions with 2 classes, which we want to
|
||||
treat as binary.
|
||||
|
||||
.. testcode::
|
||||
|
||||
from pytorch_lightning.metrics.functional import stat_scores
|
||||
|
||||
# These inputs are supposed to be binary, but appear as multi-class
|
||||
preds = torch.tensor([0, 1, 0])
|
||||
target = torch.tensor([1, 1, 0])
|
||||
|
||||
As you can see below, by default the inputs are treated
|
||||
as multi-class. We can set ``is_multiclass=False`` to treat the inputs as binary -
|
||||
which is the same as converting the predictions to float beforehand.
|
||||
|
||||
.. doctest::
|
||||
|
||||
>>> stat_scores(preds, target, reduce='macro', num_classes=2)
|
||||
tensor([[1, 1, 1, 0, 1],
|
||||
[1, 0, 1, 1, 2]])
|
||||
>>> stat_scores(preds, target, reduce='macro', num_classes=1, is_multiclass=False)
|
||||
tensor([[1, 0, 1, 1, 2]])
|
||||
>>> stat_scores(preds.float(), target, reduce='macro', num_classes=1)
|
||||
tensor([[1, 0, 1, 1, 2]])
|
||||
|
||||
Next, consider the opposite example: inputs are binary (as predictions are probabilities),
|
||||
but we would like to treat them as 2-class multi-class, to obtain the metric for both classes.
|
||||
|
||||
.. testcode::
|
||||
|
||||
preds = torch.tensor([0.2, 0.7, 0.3])
|
||||
target = torch.tensor([1, 1, 0])
|
||||
|
||||
In this case we can set ``is_multiclass=True``, to treat the inputs as multi-class.
|
||||
|
||||
.. doctest::
|
||||
|
||||
>>> stat_scores(preds, target, reduce='macro', num_classes=1)
|
||||
tensor([[1, 0, 1, 1, 2]])
|
||||
>>> stat_scores(preds, target, reduce='macro', num_classes=2, is_multiclass=True)
|
||||
tensor([[1, 1, 1, 0, 1],
|
||||
[1, 0, 1, 1, 2]])
|
|
@ -0,0 +1,222 @@
|
|||
.. role:: hidden
|
||||
:class: hidden-section
|
||||
|
||||
**********************
|
||||
Classification Metrics
|
||||
**********************
|
||||
|
||||
accuracy [func]
|
||||
~~~~~~~~~~~~~~~
|
||||
|
||||
.. autofunction:: torchmetrics.functional.accuracy
|
||||
:noindex:
|
||||
|
||||
|
||||
auc [func]
|
||||
~~~~~~~~~~
|
||||
|
||||
.. autofunction:: torchmetrics.functional.auc
|
||||
:noindex:
|
||||
|
||||
|
||||
auroc [func]
|
||||
~~~~~~~~~~~~
|
||||
|
||||
.. autofunction:: torchmetrics.functional.auroc
|
||||
:noindex:
|
||||
|
||||
|
||||
average_precision [func]
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autofunction:: torchmetrics.functional.average_precision
|
||||
:noindex:
|
||||
|
||||
|
||||
confusion_matrix [func]
|
||||
~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autofunction:: torchmetrics.functional.confusion_matrix
|
||||
:noindex:
|
||||
|
||||
|
||||
dice_score [func]
|
||||
~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autofunction:: torchmetrics.functional.dice_score
|
||||
:noindex:
|
||||
|
||||
|
||||
f1 [func]
|
||||
~~~~~~~~~~~~~~~
|
||||
|
||||
.. autofunction:: torchmetrics.functional.f1
|
||||
:noindex:
|
||||
|
||||
|
||||
fbeta [func]
|
||||
~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autofunction:: torchmetrics.functional.fbeta
|
||||
:noindex:
|
||||
|
||||
hamming_distance [func]
|
||||
~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autofunction:: torchmetrics.functional.hamming_distance
|
||||
:noindex:
|
||||
|
||||
iou [func]
|
||||
~~~~~~~~~~
|
||||
|
||||
.. autofunction:: torchmetrics.functional.iou
|
||||
:noindex:
|
||||
|
||||
|
||||
roc [func]
|
||||
~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autofunction:: torchmetrics.functional.roc
|
||||
:noindex:
|
||||
|
||||
|
||||
precision [func]
|
||||
~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autofunction:: torchmetrics.functional.precision
|
||||
:noindex:
|
||||
|
||||
|
||||
precision_recall [func]
|
||||
~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autofunction:: torchmetrics.functional.precision_recall
|
||||
:noindex:
|
||||
|
||||
|
||||
precision_recall_curve [func]
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autofunction:: torchmetrics.functional.precision_recall_curve
|
||||
:noindex:
|
||||
|
||||
|
||||
recall [func]
|
||||
~~~~~~~~~~~~~
|
||||
|
||||
.. autofunction:: torchmetrics.functional.recall
|
||||
:noindex:
|
||||
|
||||
select_topk [func]
|
||||
~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autofunction:: torchmetrics.utils.select_topk
|
||||
:noindex:
|
||||
|
||||
|
||||
stat_scores [func]
|
||||
~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autofunction:: torchmetrics.functional.stat_scores
|
||||
:noindex:
|
||||
|
||||
|
||||
stat_scores_multiple_classes [func]
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autofunction:: torchmetrics.functional.stat_scores_multiple_classes
|
||||
:noindex:
|
||||
|
||||
|
||||
to_categorical [func]
|
||||
~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autofunction:: torchmetrics.utils.to_categorical
|
||||
:noindex:
|
||||
|
||||
|
||||
to_onehot [func]
|
||||
~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autofunction:: torchmetrics.utils.to_onehot
|
||||
:noindex:
|
||||
|
||||
|
||||
******************
|
||||
Regression Metrics
|
||||
******************
|
||||
|
||||
explained_variance [func]
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autofunction:: torchmetrics.functional.explained_variance
|
||||
:noindex:
|
||||
|
||||
|
||||
image_gradients [func]
|
||||
~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autofunction:: torchmetrics.functional.image_gradients
|
||||
:noindex:
|
||||
|
||||
|
||||
mean_absolute_error [func]
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autofunction:: torchmetrics.functional.mean_absolute_error
|
||||
:noindex:
|
||||
|
||||
|
||||
mean_squared_error [func]
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autofunction:: torchmetrics.functional.mean_squared_error
|
||||
:noindex:
|
||||
|
||||
|
||||
mean_squared_log_error [func]
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autofunction:: torchmetrics.functional.mean_squared_log_error
|
||||
:noindex:
|
||||
|
||||
|
||||
psnr [func]
|
||||
~~~~~~~~~~~
|
||||
|
||||
.. autofunction:: torchmetrics.functional.psnr
|
||||
:noindex:
|
||||
|
||||
|
||||
ssim [func]
|
||||
~~~~~~~~~~~
|
||||
|
||||
.. autofunction:: torchmetrics.functional.ssim
|
||||
:noindex:
|
||||
|
||||
|
||||
r2score [func]
|
||||
~~~~~~~~~~~~~~
|
||||
|
||||
.. autofunction:: torchmetrics.functional.r2score
|
||||
:noindex:
|
||||
|
||||
***
|
||||
NLP
|
||||
***
|
||||
|
||||
bleu_score [func]
|
||||
~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autofunction:: torchmetrics.functional.bleu_score
|
||||
:noindex:
|
||||
|
||||
********
|
||||
Pairwise
|
||||
********
|
||||
|
||||
embedding_similarity [func]
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autofunction:: torchmetrics.functional.embedding_similarity
|
||||
:noindex:
|
|
@ -0,0 +1,141 @@
|
|||
**********************
|
||||
Classification Metrics
|
||||
**********************
|
||||
|
||||
Accuracy
|
||||
~~~~~~~~
|
||||
|
||||
.. autoclass:: torchmetrics.Accuracy
|
||||
:noindex:
|
||||
|
||||
AveragePrecision
|
||||
~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: torchmetrics.AveragePrecision
|
||||
:noindex:
|
||||
|
||||
AUC
|
||||
~~~
|
||||
|
||||
.. autoclass:: torchmetrics.AUC
|
||||
:noindex:
|
||||
|
||||
AUROC
|
||||
~~~~~
|
||||
|
||||
.. autoclass:: torchmetrics.AUROC
|
||||
:noindex:
|
||||
|
||||
ConfusionMatrix
|
||||
~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: torchmetrics.ConfusionMatrix
|
||||
:noindex:
|
||||
|
||||
F1
|
||||
~~
|
||||
|
||||
.. autoclass:: torchmetrics.F1
|
||||
:noindex:
|
||||
|
||||
FBeta
|
||||
~~~~~
|
||||
|
||||
.. autoclass:: torchmetrics.FBeta
|
||||
:noindex:
|
||||
|
||||
IoU
|
||||
~~~
|
||||
|
||||
.. autoclass:: torchmetrics.IoU
|
||||
:noindex:
|
||||
|
||||
Hamming Distance
|
||||
~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: torchmetrics.HammingDistance
|
||||
:noindex:
|
||||
|
||||
Precision
|
||||
~~~~~~~~~
|
||||
|
||||
.. autoclass:: torchmetrics.Precision
|
||||
:noindex:
|
||||
|
||||
PrecisionRecallCurve
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: torchmetrics.PrecisionRecallCurve
|
||||
:noindex:
|
||||
|
||||
Recall
|
||||
~~~~~~
|
||||
|
||||
.. autoclass:: torchmetrics.Recall
|
||||
:noindex:
|
||||
|
||||
ROC
|
||||
~~~
|
||||
|
||||
.. autoclass:: torchmetrics.ROC
|
||||
:noindex:
|
||||
|
||||
|
||||
StatScores
|
||||
~~~~~~~~~~
|
||||
|
||||
.. autoclass:: torchmetrics.StatScores
|
||||
:noindex:
|
||||
|
||||
|
||||
******************
|
||||
Regression Metrics
|
||||
******************
|
||||
|
||||
ExplainedVariance
|
||||
~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: torchmetrics.ExplainedVariance
|
||||
:noindex:
|
||||
|
||||
|
||||
MeanAbsoluteError
|
||||
~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: torchmetrics.MeanAbsoluteError
|
||||
:noindex:
|
||||
|
||||
|
||||
MeanSquaredError
|
||||
~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: torchmetrics.MeanSquaredError
|
||||
:noindex:
|
||||
|
||||
|
||||
MeanSquaredLogError
|
||||
~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: torchmetrics.MeanSquaredLogError
|
||||
:noindex:
|
||||
|
||||
|
||||
PSNR
|
||||
~~~~
|
||||
|
||||
.. autoclass:: torchmetrics.PSNR
|
||||
:noindex:
|
||||
|
||||
|
||||
SSIM
|
||||
~~~~
|
||||
|
||||
.. autoclass:: torchmetrics.SSIM
|
||||
:noindex:
|
||||
|
||||
|
||||
R2Score
|
||||
~~~~~~~
|
||||
|
||||
.. autoclass:: torchmetrics.R2Score
|
||||
:noindex:
|
|
@ -1,90 +0,0 @@
|
|||
.. role:: hidden
|
||||
:class: hidden-section
|
||||
|
||||
torchmetrics
|
||||
===================================
|
||||
|
||||
Classification
|
||||
--------------------
|
||||
|
||||
For the purposes of classification metrics, inputs (predictions and targets) are split
|
||||
into these categories (``N`` stands for the batch size and ``C`` for number of classes):
|
||||
|
||||
.. csv-table:: \*dtype ``binary`` means integers that are either 0 or 1
|
||||
:header: "Type", "preds shape", "preds dtype", "target shape", "target dtype"
|
||||
:widths: 20, 10, 10, 10, 10
|
||||
|
||||
"Binary", "(N,)", "``float``", "(N,)", "``binary``\*"
|
||||
"Multi-class", "(N,)", "``int``", "(N,)", "``int``"
|
||||
"Multi-class with probabilities", "(N, C)", "``float``", "(N,)", "``int``"
|
||||
"Multi-label", "(N, ...)", "``float``", "(N, ...)", "``binary``\*"
|
||||
"Multi-dimensional multi-class", "(N, ...)", "``int``", "(N, ...)", "``int``"
|
||||
"Multi-dimensional multi-class with probabilities", "(N, C, ...)", "``float``", "(N, ...)", "``int``"
|
||||
|
||||
.. note::
|
||||
All dimensions of size 1 (except ``N``) are "squeezed out" at the beginning, so
|
||||
that, for example, a tensor of shape ``(N, 1)`` is treated as ``(N, )``.
|
||||
|
||||
When predictions or targets are integers, it is assumed that class labels start at 0, i.e.
|
||||
the possible class labels are 0, 1, 2, 3, etc. Below are some examples of different input types
|
||||
|
||||
.. testcode::
|
||||
|
||||
# Binary inputs
|
||||
binary_preds = torch.tensor([0.6, 0.1, 0.9])
|
||||
binary_target = torch.tensor([1, 0, 2])
|
||||
|
||||
# Multi-class inputs
|
||||
mc_preds = torch.tensor([0, 2, 1])
|
||||
mc_target = torch.tensor([0, 1, 2])
|
||||
|
||||
# Multi-class inputs with probabilities
|
||||
mc_preds_probs = torch.tensor([[0.8, 0.2, 0], [0.1, 0.2, 0.7], [0.3, 0.6, 0.1]])
|
||||
mc_target_probs = torch.tensor([0, 1, 2])
|
||||
|
||||
# Multi-label inputs
|
||||
ml_preds = torch.tensor([[0.2, 0.8, 0.9], [0.5, 0.6, 0.1], [0.3, 0.1, 0.1]])
|
||||
ml_target = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 0]])
|
||||
|
||||
In some rare cases, you might have inputs which appear to be (multi-dimensional) multi-class
|
||||
but are actually binary/multi-label. For example, if both predictions and targets are 1d
|
||||
binary tensors. Or it could be the other way around, you want to treat binary/multi-label
|
||||
inputs as 2-class (multi-dimensional) multi-class inputs.
|
||||
|
||||
For these cases, the metrics where this distinction would make a difference, expose the
|
||||
``is_multiclass`` argument.
|
||||
|
||||
.. currentmodule:: torchmetrics.classification
|
||||
|
||||
.. autosummary::
|
||||
:toctree: api
|
||||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
Accuracy
|
||||
AveragePrecision
|
||||
ConfusionMatrix
|
||||
F1
|
||||
FBeta
|
||||
Precision
|
||||
PrecisionRecallCurve
|
||||
Recall
|
||||
ROC
|
||||
|
||||
|
||||
Regression
|
||||
--------------------
|
||||
|
||||
.. currentmodule:: torchmetrics.regression
|
||||
|
||||
.. autosummary::
|
||||
:toctree: api
|
||||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
ExplainedVariance
|
||||
MeanAbsoluteError
|
||||
MeanSquaredError
|
||||
MeanSquaredLogError
|
||||
PSNR
|
||||
SSIM
|
2
setup.py
|
@ -49,7 +49,7 @@ setup(
|
|||
author=torchmetrics.__author__,
|
||||
author_email=torchmetrics.__author_email__,
|
||||
url=torchmetrics.__homepage__,
|
||||
download_url='https://github.com/PyTorchLightning/torchmetrics',
|
||||
download_url='https://github.com/PyTorchLightning/metrics/archive/master.zip',
|
||||
license=torchmetrics.__license__,
|
||||
packages=find_packages(exclude=['tests', 'docs']),
|
||||
long_description=load_long_describtion(),
|
||||
|
|
|
@ -38,7 +38,7 @@ class Accuracy(Metric):
|
|||
changed to subset accuracy (which requires all labels or sub-samples in the sample to
|
||||
be correctly predicted) by setting ``subset_accuracy=True``.
|
||||
|
||||
Accepts all input types listed in :ref:`extensions/metrics:input types`.
|
||||
Accepts all input types listed in :ref:`pages/overview:input types`.
|
||||
|
||||
Args:
|
||||
threshold:
|
||||
|
@ -127,7 +127,7 @@ class Accuracy(Metric):
|
|||
|
||||
def update(self, preds: torch.Tensor, target: torch.Tensor):
|
||||
"""
|
||||
Update state with predictions and targets. See :ref:`extensions/metrics:input types` for more information
|
||||
Update state with predictions and targets. See :ref:`pages/overview:input types` for more information
|
||||
on input types.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -35,7 +35,7 @@ class HammingDistance(Metric):
|
|||
treats each possible label separately - meaning that, for example, multi-class data is
|
||||
treated as if it were multi-label.
|
||||
|
||||
Accepts all input types listed in :ref:`extensions/metrics:input types`.
|
||||
Accepts all input types listed in :ref:`pages/overview:input types`.
|
||||
|
||||
Args:
|
||||
threshold:
|
||||
|
@ -88,7 +88,7 @@ class HammingDistance(Metric):
|
|||
|
||||
def update(self, preds: torch.Tensor, target: torch.Tensor):
|
||||
"""
|
||||
Update state with predictions and targets. See :ref:`extensions/metrics:input types` for more information
|
||||
Update state with predictions and targets. See :ref:`pages/overview:input types` for more information
|
||||
on input types.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -283,7 +283,7 @@ def _check_classification_inputs(
|
|||
is_multiclass:
|
||||
Used only in certain special cases, where you want to treat inputs as a different type
|
||||
than what they appear to be. See the parameter's
|
||||
:ref:`documentation section <extensions/metrics:using the is_multiclass parameter>`
|
||||
:ref:`documentation section <pages/overview:using the is_multiclass parameter>`
|
||||
for a more detailed explanation and examples.
|
||||
|
||||
|
||||
|
@ -409,7 +409,7 @@ def _input_format_classification(
|
|||
is_multiclass:
|
||||
Used only in certain special cases, where you want to treat inputs as a different type
|
||||
than what they appear to be. See the parameter's
|
||||
:ref:`documentation section <extensions/metrics:using the is_multiclass parameter>`
|
||||
:ref:`documentation section <pages/overview:using the is_multiclass parameter>`
|
||||
for a more detailed explanation and examples.
|
||||
|
||||
Returns:
|
||||
|
|
|
@ -31,7 +31,7 @@ class Precision(StatScores):
|
|||
|
||||
The reduction method (how the precision scores are aggregated) is controlled by the
|
||||
``average`` parameter, and additionally by the ``mdmc_average`` parameter in the
|
||||
multi-dimensional multi-class case. Accepts all inputs listed in :ref:`extensions/metrics:input types`.
|
||||
multi-dimensional multi-class case. Accepts all inputs listed in :ref:`pages/overview:input types`.
|
||||
|
||||
Args:
|
||||
num_classes:
|
||||
|
@ -67,11 +67,11 @@ class Precision(StatScores):
|
|||
- ``'samplewise'``: In this case, the statistics are computed separately for each
|
||||
sample on the ``N`` axis, and then averaged over samples.
|
||||
The computation for each sample is done by treating the flattened extra axes ``...``
|
||||
(see :ref:`extensions/metrics:input types`) as the ``N`` dimension within the sample,
|
||||
(see :ref:`pages/overview:input types`) as the ``N`` dimension within the sample,
|
||||
and computing the metric for the sample based on that.
|
||||
|
||||
- ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs
|
||||
(see :ref:`extensions/metrics:input types`)
|
||||
(see :ref:`pages/overview:input types`)
|
||||
are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they
|
||||
were ``(N_X, C)``. From here on the ``average`` parameter applies as usual.
|
||||
|
||||
|
@ -90,7 +90,7 @@ class Precision(StatScores):
|
|||
is_multiclass:
|
||||
Used only in certain special cases, where you want to treat inputs as a different type
|
||||
than what they appear to be. See the parameter's
|
||||
:ref:`documentation section <extensions/metrics:using the is_multiclass parameter>`
|
||||
:ref:`documentation section <pages/overview:using the is_multiclass parameter>`
|
||||
for a more detailed explanation and examples.
|
||||
|
||||
compute_on_step:
|
||||
|
@ -181,7 +181,7 @@ class Recall(StatScores):
|
|||
|
||||
The reduction method (how the recall scores are aggregated) is controlled by the
|
||||
``average`` parameter, and additionally by the ``mdmc_average`` parameter in the
|
||||
multi-dimensional multi-class case. Accepts all inputs listed in :ref:`extensions/metrics:input types`.
|
||||
multi-dimensional multi-class case. Accepts all inputs listed in :ref:`pages/overview:input types`.
|
||||
|
||||
Args:
|
||||
num_classes:
|
||||
|
@ -217,11 +217,11 @@ class Recall(StatScores):
|
|||
- ``'samplewise'``: In this case, the statistics are computed separately for each
|
||||
sample on the ``N`` axis, and then averaged over samples.
|
||||
The computation for each sample is done by treating the flattened extra axes ``...``
|
||||
(see :ref:`extensions/metrics:input types`) as the ``N`` dimension within the sample,
|
||||
(see :ref:`pages/overview:input types`) as the ``N`` dimension within the sample,
|
||||
and computing the metric for the sample based on that.
|
||||
|
||||
- ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs
|
||||
(see :ref:`extensions/metrics:input types`)
|
||||
(see :ref:`pages/overview:input types`)
|
||||
are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they
|
||||
were ``(N_X, C)``. From here on the ``average`` parameter applies as usual.
|
||||
|
||||
|
@ -241,7 +241,7 @@ class Recall(StatScores):
|
|||
is_multiclass:
|
||||
Used only in certain special cases, where you want to treat inputs as a different type
|
||||
than what they appear to be. See the parameter's
|
||||
:ref:`documentation section <extensions/metrics:using the is_multiclass parameter>`
|
||||
:ref:`documentation section <pages/overview:using the is_multiclass parameter>`
|
||||
for a more detailed explanation and examples.
|
||||
|
||||
compute_on_step:
|
||||
|
|
|
@ -28,7 +28,7 @@ class StatScores(Metric):
|
|||
``reduce`` parameter, and additionally by the ``mdmc_reduce`` parameter in the
|
||||
multi-dimensional multi-class case.
|
||||
|
||||
Accepts all inputs listed in :ref:`extensions/metrics:input types`.
|
||||
Accepts all inputs listed in :ref:`pages/overview:input types`.
|
||||
|
||||
Args:
|
||||
threshold:
|
||||
|
@ -71,7 +71,7 @@ class StatScores(Metric):
|
|||
one of the following:
|
||||
|
||||
- ``None`` [default]: Should be left unchanged if your data is not multi-dimensional
|
||||
multi-class (see :ref:`extensions/metrics:input types` for the definition of input types).
|
||||
multi-class (see :ref:`pages/overview:input types` for the definition of input types).
|
||||
|
||||
- ``'samplewise'``: In this case, the statistics are computed separately for each
|
||||
sample on the ``N`` axis, and then the outputs are concatenated together. In each
|
||||
|
@ -86,7 +86,7 @@ class StatScores(Metric):
|
|||
is_multiclass:
|
||||
Used only in certain special cases, where you want to treat inputs as a different type
|
||||
than what they appear to be. See the parameter's
|
||||
:ref:`documentation section <extensions/metrics:using the is_multiclass parameter>`
|
||||
:ref:`documentation section <pages/overview:using the is_multiclass parameter>`
|
||||
for a more detailed explanation and examples.
|
||||
|
||||
compute_on_step:
|
||||
|
@ -175,7 +175,7 @@ class StatScores(Metric):
|
|||
|
||||
def update(self, preds: torch.Tensor, target: torch.Tensor):
|
||||
"""
|
||||
Update state with predictions and targets. See :ref:`extensions/metrics:input types` for more information
|
||||
Update state with predictions and targets. See :ref:`pages/overview:input types` for more information
|
||||
on input types.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -72,7 +72,7 @@ def accuracy(
|
|||
changed to subset accuracy (which requires all labels or sub-samples in the sample to
|
||||
be correctly predicted) by setting ``subset_accuracy=True``.
|
||||
|
||||
Accepts all input types listed in :ref:`extensions/metrics:input types`.
|
||||
Accepts all input types listed in :ref:`pages/overview:input types`.
|
||||
|
||||
Args:
|
||||
preds: Predictions from model (probabilities, or labels)
|
||||
|
|
|
@ -51,7 +51,7 @@ def hamming_distance(preds: torch.Tensor, target: torch.Tensor, threshold: float
|
|||
treats each possible label separately - meaning that, for example, multi-class data is
|
||||
treated as if it were multi-label.
|
||||
|
||||
Accepts all input types listed in :ref:`extensions/metrics:input types`.
|
||||
Accepts all input types listed in :ref:`pages/overview:input types`.
|
||||
|
||||
Args:
|
||||
preds: Predictions from model
|
||||
|
|
|
@ -60,7 +60,7 @@ def precision(
|
|||
|
||||
The reduction method (how the precision scores are aggregated) is controlled by the
|
||||
``average`` parameter, and additionally by the ``mdmc_average`` parameter in the
|
||||
multi-dimensional multi-class case. Accepts all inputs listed in :ref:`extensions/metrics:input types`.
|
||||
multi-dimensional multi-class case. Accepts all inputs listed in :ref:`pages/overview:input types`.
|
||||
|
||||
Args:
|
||||
preds: Predictions from model (probabilities or labels)
|
||||
|
@ -94,11 +94,11 @@ def precision(
|
|||
- ``'samplewise'``: In this case, the statistics are computed separately for each
|
||||
sample on the ``N`` axis, and then averaged over samples.
|
||||
The computation for each sample is done by treating the flattened extra axes ``...``
|
||||
(see :ref:`extensions/metrics:input types`) as the ``N`` dimension within the sample,
|
||||
(see :ref:`pages/overview:input types`) as the ``N`` dimension within the sample,
|
||||
and computing the metric for the sample based on that.
|
||||
|
||||
- ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs
|
||||
(see :ref:`extensions/metrics:input types`)
|
||||
(see :ref:`pages/overview:input types`)
|
||||
are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they
|
||||
were ``(N_X, C)``. From here on the ``average`` parameter applies as usual.
|
||||
|
||||
|
@ -123,7 +123,7 @@ def precision(
|
|||
is_multiclass:
|
||||
Used only in certain special cases, where you want to treat inputs as a different type
|
||||
than what they appear to be. See the parameter's
|
||||
:ref:`documentation section <extensions/metrics:using the is_multiclass parameter>`
|
||||
:ref:`documentation section <pages/overview:using the is_multiclass parameter>`
|
||||
for a more detailed explanation and examples.
|
||||
|
||||
class_reduction:
|
||||
|
@ -225,7 +225,7 @@ def recall(
|
|||
|
||||
The reduction method (how the recall scores are aggregated) is controlled by the
|
||||
``average`` parameter, and additionally by the ``mdmc_average`` parameter in the
|
||||
multi-dimensional multi-class case. Accepts all inputs listed in :ref:`extensions/metrics:input types`.
|
||||
multi-dimensional multi-class case. Accepts all inputs listed in :ref:`pages/overview:input types`.
|
||||
|
||||
Args:
|
||||
preds: Predictions from model (probabilities, or labels)
|
||||
|
@ -256,11 +256,11 @@ def recall(
|
|||
- ``'samplewise'``: In this case, the statistics are computed separately for each
|
||||
sample on the ``N`` axis, and then averaged over samples.
|
||||
The computation for each sample is done by treating the flattened extra axes ``...``
|
||||
(see :ref:`extensions/metrics:input types`) as the ``N`` dimension within the sample,
|
||||
(see :ref:`pages/overview:input types`) as the ``N`` dimension within the sample,
|
||||
and computing the metric for the sample based on that.
|
||||
|
||||
- ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs
|
||||
(see :ref:`extensions/metrics:input types`)
|
||||
(see :ref:`pages/overview:input types`)
|
||||
are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they
|
||||
were ``(N_X, C)``. From here on the ``average`` parameter applies as usual.
|
||||
|
||||
|
@ -285,7 +285,7 @@ def recall(
|
|||
is_multiclass:
|
||||
Used only in certain special cases, where you want to treat inputs as a different type
|
||||
than what they appear to be. See the parameter's
|
||||
:ref:`documentation section <extensions/metrics:using the is_multiclass parameter>`
|
||||
:ref:`documentation section <pages/overview:using the is_multiclass parameter>`
|
||||
for a more detailed explanation and examples.
|
||||
|
||||
class_reduction:
|
||||
|
@ -373,7 +373,7 @@ def precision_recall(
|
|||
|
||||
The reduction method (how the recall scores are aggregated) is controlled by the
|
||||
``average`` parameter, and additionally by the ``mdmc_average`` parameter in the
|
||||
multi-dimensional multi-class case. Accepts all inputs listed in :ref:`extensions/metrics:input types`.
|
||||
multi-dimensional multi-class case. Accepts all inputs listed in :ref:`pages/overview:input types`.
|
||||
|
||||
Args:
|
||||
preds: Predictions from model (probabilities, or labels)
|
||||
|
@ -404,11 +404,11 @@ def precision_recall(
|
|||
- ``'samplewise'``: In this case, the statistics are computed separately for each
|
||||
sample on the ``N`` axis, and then averaged over samples.
|
||||
The computation for each sample is done by treating the flattened extra axes ``...``
|
||||
(see :ref:`extensions/metrics:input types`) as the ``N`` dimension within the sample,
|
||||
(see :ref:`pages/overview:input types`) as the ``N`` dimension within the sample,
|
||||
and computing the metric for the sample based on that.
|
||||
|
||||
- ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs
|
||||
(see :ref:`extensions/metrics:input types`)
|
||||
(see :ref:`pages/overview:input types`)
|
||||
are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they
|
||||
were ``(N_X, C)``. From here on the ``average`` parameter applies as usual.
|
||||
|
||||
|
@ -433,7 +433,7 @@ def precision_recall(
|
|||
is_multiclass:
|
||||
Used only in certain special cases, where you want to treat inputs as a different type
|
||||
than what they appear to be. See the parameter's
|
||||
:ref:`documentation section <extensions/metrics:using the is_multiclass parameter>`
|
||||
:ref:`documentation section <pages/overview:using the is_multiclass parameter>`
|
||||
for a more detailed explanation and examples.
|
||||
|
||||
class_reduction:
|
||||
|
|
|
@ -153,7 +153,7 @@ def stat_scores(
|
|||
|
||||
The reduction method (how the statistics are aggregated) is controlled by the
|
||||
``reduce`` parameter, and additionally by the ``mdmc_reduce`` parameter in the
|
||||
multi-dimensional multi-class case. Accepts all inputs listed in :ref:`extensions/metrics:input types`.
|
||||
multi-dimensional multi-class case. Accepts all inputs listed in :ref:`pages/overview:input types`.
|
||||
|
||||
Args:
|
||||
preds: Predictions from model (probabilities or labels)
|
||||
|
@ -198,7 +198,7 @@ def stat_scores(
|
|||
one of the following:
|
||||
|
||||
- ``None`` [default]: Should be left unchanged if your data is not multi-dimensional
|
||||
multi-class (see :ref:`extensions/metrics:input types` for the definition of input types).
|
||||
multi-class (see :ref:`pages/overview:input types` for the definition of input types).
|
||||
|
||||
- ``'samplewise'``: In this case, the statistics are computed separately for each
|
||||
sample on the ``N`` axis, and then the outputs are concatenated together. In each
|
||||
|
@ -213,7 +213,7 @@ def stat_scores(
|
|||
is_multiclass:
|
||||
Used only in certain special cases, where you want to treat inputs as a different type
|
||||
than what they appear to be. See the parameter's
|
||||
:ref:`documentation section <extensions/metrics:using the is_multiclass parameter>`
|
||||
:ref:`documentation section <pages/overview:using the is_multiclass parameter>`
|
||||
for a more detailed explanation and examples.
|
||||
|
||||
Return:
|
||||
|
|