Machine learning metrics for distributed, scalable PyTorch applications.

A collection of metrics for easily evaluating machine learning models


Website • What is Torchmetrics • Installation • Docs • Built-in metrics • Own metric • Community • License




What is Torchmetrics

Torchmetrics is a metrics API created for easy metric development and usage in both PyTorch and PyTorch Lightning. It was originally part of PyTorch Lightning, but was split off so that users can take advantage of the large collection of implemented metrics without having to install PyTorch Lightning (even though we would love for you to try it out). We currently have more than 25 metrics implemented and are continuously adding more, both within already covered domains (classification, regression, etc.) and in new domains (object detection, etc.). We make sure that all our metrics are rigorously tested so that you can trust them.

Installation

Pip / conda

pip install torchmetrics -U
conda install -c conda-forge torchmetrics

Pip from source

# with git
pip install git+https://github.com/PytorchLightning/metrics.git@master

# OR from an archive
pip install https://github.com/PyTorchLightning/metrics/archive/master.zip

Built-in metrics

Similar to torch.nn, most metrics come in both a class-based version and a simple functional version.

  • The class-based metrics offer the most functionality, supporting both accumulation over multiple batches and automatic synchronization between multiple devices (see the device-placement sketch after these examples).
import torch
# import our library
import torchmetrics 

# initialize metric
metric = torchmetrics.Accuracy()

n_batches = 10
for i in range(n_batches):
    # simulate a classification problem
    preds = torch.randn(10, 5).softmax(dim=-1)
    target = torch.randint(5, (10,))
    # metric on current batch
    acc = metric(preds, target)
    print(f"Accuracy on batch {i}: {acc}")    

# metric on all batches using custom accumulation
acc = metric.compute()
print(f"Accuracy on all data: {acc}")

# Reset the internal state so that the metric is ready for new data
metric.reset()
  • Functional metrics follow a simple input-output paradigm: a single batch is fed in and the metric is computed for that batch only
import torch
# import our library
import torchmetrics

# simulate a classification problem
preds = torch.randn(10, 5).softmax(dim=-1)
target = torch.randint(5, (10,))

acc = torchmetrics.functional.accuracy(preds, target)
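
Since class-based metrics subclass torch.nn.Module, their internal states follow the usual device-placement rules. A minimal sketch of moving a metric to the GPU, assuming a CUDA device is available (it falls back to CPU otherwise):

import torch
import torchmetrics

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# moving the metric moves its internal states as well, like any nn.Module
metric = torchmetrics.Accuracy().to(device)

preds = torch.randn(10, 5).softmax(dim=-1).to(device)
target = torch.randint(5, (10,), device=device)

acc = metric(preds, target)  # states accumulate on the chosen device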

Implementing your own metric

Implementing your own metric is as easy as subclassing torch.nn.Module. Simply subclass torchmetrics.Metric and do the following:

  1. Implement __init__, where you call self.add_state for every internal state that is needed for the metric's computations
  2. Implement the update method, where all the logic necessary for updating the metric states goes
  3. Implement the compute method, where the final metric computation happens

Example: Root mean squared error

Root mean squared error is a great example to showcase why many metric computations need to be divided into two functions. It is defined as:

$\text{RMSE} = \sqrt{\frac{1}{N}\sum_{i=1}^{N}(y_i - \hat{y}_i)^2}$

To properly calculate RMSE, we need two metric states: sum_squared_errors, to keep track of the squared error between the targets and the predictions, and n_observations, to know how many observations we have encountered.

class RMSE(torchmetrics.Metric):
    def __init__(self):
        super().__init__()
        # dist_reduce_fx indicates the function that should be used to reduce
        # state from multiple processes
        self.add_state("sum_squared_errors", torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("n_observations", torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds, target):
        # update metric states: accumulate squared errors and observation count
        self.sum_squared_errors += torch.sum((preds - target) ** 2)
        self.n_observations += preds.numel()

    def compute(self):
        # compute the final result from the accumulated states
        return torch.sqrt(self.sum_squared_errors / self.n_observations)
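
As a quick sanity check, the class above can be used just like the built-in metrics; a hypothetical usage sketch with random regression data:

metric = RMSE()

for _ in range(4):
    preds = torch.randn(100)
    target = torch.randn(100)
    metric.update(preds, target)  # accumulates squared errors and counts

print(f"RMSE over all batches: {metric.compute()}")
metric.reset()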

Because sqrt(a + b) != sqrt(a) + sqrt(b), we cannot implement this metric as a simple mean of the per-batch RMSE scores; instead we need to implement all the logic that happens before the square root in update, and the remainder in compute.
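
A tiny numeric illustration of why the order of operations matters (the squared-error values here are made up for the example):

import torch

squared_errors_batch1 = torch.tensor([1.0, 3.0])
squared_errors_batch2 = torch.tensor([2.0, 10.0])

# wrong: average the per-batch RMSE values
naive = (squared_errors_batch1.mean().sqrt() + squared_errors_batch2.mean().sqrt()) / 2

# right: accumulate all squared errors first, then take a single square root
correct = torch.cat([squared_errors_batch1, squared_errors_batch2]).mean().sqrt()

print(naive, correct)  # roughly 1.93 vs exactly 2.0, so the two disagree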

Community

For help or questions, join our huge community on Slack!

Contribute!

The Lightning + TorchMetrics team is hard at work adding even more metrics. But we're looking for incredible contributors like you to submit new metrics and improve existing ones!

Join our Slack to get help becoming a contributor!

Citations

We're excited to continue the strong legacy of open-source software and have been inspired over the years by Caffe, Theano, Keras, PyTorch, torchbearer, ignite, sklearn and fast.ai. When/if a paper is written about this, we'll be happy to cite these frameworks and the corresponding authors.

License

Please observe the Apache 2.0 license that is listed in this repository. In addition, the Lightning framework is Patent Pending.