зеркало из https://github.com/microsoft/torchgeo.git
Bump torchmetrics from 0.11.4 to 1.0.0 in /requirements (#1465)
* Bump torchmetrics from 0.11.4 to 1.0.0 in /requirements Bumps [torchmetrics](https://github.com/Lightning-AI/torchmetrics) from 0.11.4 to 1.0.0. - [Release notes](https://github.com/Lightning-AI/torchmetrics/releases) - [Changelog](https://github.com/Lightning-AI/torchmetrics/blob/master/CHANGELOG.md) - [Commits](https://github.com/Lightning-AI/torchmetrics/compare/v0.11.4...v1.0.0) --- updated-dependencies: - dependency-name: torchmetrics dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] <support@github.com> * Bump max version * mdmc_average -> multidim_average * debugging * Workaround for bug in MAP dict * Workaround for bug in MAP dict * Workaround for bug in MAP dict * Workaround for bug in MAP dict --------- Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Родитель
2aabf3381d
Коммит
e75edab85b
|
@ -68,7 +68,7 @@ dependencies = [
|
|||
# torch 1.12+ required by torchvision
|
||||
"torch>=1.12,<3",
|
||||
# torchmetrics 0.10+ required for binary/multiclass/multilabel classification metrics
|
||||
"torchmetrics>=0.10,<0.12",
|
||||
"torchmetrics>=0.10,<2",
|
||||
# torchvision 0.13+ required for torchvision.models._api.WeightsEnum
|
||||
"torchvision>=0.13,<0.16",
|
||||
]
|
||||
|
|
|
@ -17,5 +17,5 @@ segmentation-models-pytorch==0.3.3
|
|||
shapely==2.0.1
|
||||
timm==0.9.2
|
||||
torch==2.0.1
|
||||
torchmetrics==0.11.4
|
||||
torchmetrics==1.0.0
|
||||
torchvision==0.15.2
|
||||
|
|
|
@ -12,6 +12,7 @@ import torchvision.models.detection
|
|||
from lightning.pytorch import LightningModule
|
||||
from torch import Tensor
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||
from torchmetrics import MetricCollection
|
||||
from torchmetrics.detection.mean_ap import MeanAveragePrecision
|
||||
from torchvision.models import resnet as R
|
||||
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
|
||||
|
@ -187,8 +188,9 @@ class ObjectDetectionTask(LightningModule):
|
|||
|
||||
self.config_task()
|
||||
|
||||
self.val_metrics = MeanAveragePrecision()
|
||||
self.test_metrics = MeanAveragePrecision()
|
||||
metrics = MetricCollection([MeanAveragePrecision()])
|
||||
self.val_metrics = metrics.clone(prefix="val_")
|
||||
self.test_metrics = metrics.clone(prefix="test_")
|
||||
|
||||
def forward(self, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Forward pass of the model.
|
||||
|
@ -273,8 +275,11 @@ class ObjectDetectionTask(LightningModule):
|
|||
def on_validation_epoch_end(self) -> None:
|
||||
"""Logs epoch level validation metrics."""
|
||||
metrics = self.val_metrics.compute()
|
||||
renamed_metrics = {f"val_{i}": metrics[i] for i in metrics.keys()}
|
||||
self.log_dict(renamed_metrics)
|
||||
|
||||
# https://github.com/Lightning-AI/torchmetrics/pull/1832#issuecomment-1623890714
|
||||
metrics.pop("val_classes", None)
|
||||
|
||||
self.log_dict(metrics)
|
||||
self.val_metrics.reset()
|
||||
|
||||
def test_step(self, *args: Any, **kwargs: Any) -> None:
|
||||
|
@ -297,8 +302,11 @@ class ObjectDetectionTask(LightningModule):
|
|||
def on_test_epoch_end(self) -> None:
|
||||
"""Logs epoch level test metrics."""
|
||||
metrics = self.test_metrics.compute()
|
||||
renamed_metrics = {f"test_{i}": metrics[i] for i in metrics.keys()}
|
||||
self.log_dict(renamed_metrics)
|
||||
|
||||
# https://github.com/Lightning-AI/torchmetrics/pull/1832#issuecomment-1623890714
|
||||
metrics.pop("test_classes", None)
|
||||
|
||||
self.log_dict(metrics)
|
||||
self.test_metrics.reset()
|
||||
|
||||
def predict_step(self, *args: Any, **kwargs: Any) -> list[dict[str, Tensor]]:
|
||||
|
|
|
@ -175,7 +175,7 @@ class SemanticSegmentationTask(LightningModule):
|
|||
MulticlassAccuracy(
|
||||
num_classes=self.hyperparams["num_classes"],
|
||||
ignore_index=self.ignore_index,
|
||||
mdmc_average="global",
|
||||
multidim_average="global",
|
||||
average="micro",
|
||||
),
|
||||
MulticlassJaccardIndex(
|
||||
|
|
Загрузка…
Ссылка в новой задаче