зеркало из https://github.com/microsoft/torchgeo.git
Bump the torch group in /requirements with 2 updates (#2192)
* Bump the torch group in /requirements with 2 updates Bumps the torch group in /requirements with 2 updates: [torch](https://github.com/pytorch/pytorch) and [torchvision](https://github.com/pytorch/vision). Updates `torch` from 2.3.1 to 2.4.0 - [Release notes](https://github.com/pytorch/pytorch/releases) - [Changelog](https://github.com/pytorch/pytorch/blob/main/RELEASE.md) - [Commits](https://github.com/pytorch/pytorch/compare/v2.3.1...v2.4.0) Updates `torchvision` from 0.18.1 to 0.19.0 - [Release notes](https://github.com/pytorch/vision/releases) - [Commits](https://github.com/pytorch/vision/compare/v0.18.1...0.19.0) --- updated-dependencies: - dependency-name: torch dependency-type: direct:production update-type: version-update:semver-minor dependency-group: torch - dependency-name: torchvision dependency-type: direct:production update-type: version-update:semver-minor dependency-group: torch ... Signed-off-by: dependabot[bot] <support@github.com> * Fix or silence mypy warnings --------- Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Родитель
6c2ed0b93b
Коммит
fe2eae1c7e
|
@ -217,7 +217,7 @@ def main(args: argparse.Namespace) -> None:
|
|||
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
params = model.parameters()
|
||||
optimizer = optim.SGD(params, lr=0.0001)
|
||||
optimizer = optim.SGD(params, lr=0.0001) # type: ignore[attr-defined]
|
||||
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu', args.device)
|
||||
model = model.to(device)
|
||||
|
|
|
@ -17,6 +17,6 @@ rtree==1.3.0
|
|||
segmentation-models-pytorch==0.3.3
|
||||
shapely==2.0.5
|
||||
timm==0.9.2
|
||||
torch==2.3.1
|
||||
torch==2.4.0
|
||||
torchmetrics==1.4.0.post0
|
||||
torchvision==0.18.1
|
||||
torchvision==0.19.0
|
||||
|
|
|
@ -306,7 +306,7 @@ class DFC2022(NonGeoDataset):
|
|||
ncols = 2
|
||||
image = sample['image'][:3]
|
||||
image = image.to(torch.uint8)
|
||||
image = image.permute(1, 2, 0).numpy()
|
||||
image_arr = image.permute(1, 2, 0).numpy()
|
||||
|
||||
dem = sample['image'][-1].numpy()
|
||||
dem = percentile_normalization(dem, lower=0, upper=100, axis=(0, 1))
|
||||
|
@ -325,7 +325,7 @@ class DFC2022(NonGeoDataset):
|
|||
|
||||
fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(10, ncols * 10))
|
||||
|
||||
axs[0].imshow(image)
|
||||
axs[0].imshow(image_arr)
|
||||
axs[0].axis('off')
|
||||
axs[1].imshow(dem)
|
||||
axs[1].axis('off')
|
||||
|
|
|
@ -243,25 +243,25 @@ class NASAMarineDebris(NonGeoDataset):
|
|||
image = sample['image']
|
||||
if 'boxes' in sample and len(sample['boxes']):
|
||||
image = draw_bounding_boxes(image=sample['image'], boxes=sample['boxes'])
|
||||
image = image.permute((1, 2, 0)).numpy()
|
||||
image_arr = image.permute((1, 2, 0)).numpy()
|
||||
|
||||
if 'prediction_boxes' in sample and len(sample['prediction_boxes']):
|
||||
ncols += 1
|
||||
preds = draw_bounding_boxes(
|
||||
image=sample['image'], boxes=sample['prediction_boxes']
|
||||
)
|
||||
preds = preds.permute((1, 2, 0)).numpy()
|
||||
preds_arr = preds.permute((1, 2, 0)).numpy()
|
||||
|
||||
fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10))
|
||||
if ncols < 2:
|
||||
axs.imshow(image)
|
||||
axs.imshow(image_arr)
|
||||
axs.axis('off')
|
||||
if show_titles:
|
||||
axs.set_title('Ground Truth')
|
||||
else:
|
||||
axs[0].imshow(image)
|
||||
axs[0].imshow(image_arr)
|
||||
axs[0].axis('off')
|
||||
axs[1].imshow(preds)
|
||||
axs[1].imshow(preds_arr)
|
||||
axs[1].axis('off')
|
||||
|
||||
if show_titles:
|
||||
|
|
|
@ -9,7 +9,7 @@ from typing import Any
|
|||
|
||||
import lightning
|
||||
from lightning.pytorch import LightningModule
|
||||
from torch.optim import AdamW
|
||||
from torch.optim import AdamW # type: ignore[attr-defined]
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ from typing import Any
|
|||
import lightning
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.optim import SGD
|
||||
from torch.optim import SGD # type: ignore[attr-defined]
|
||||
|
||||
from .base import BaseTask
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ from lightly.models.modules import MoCoProjectionHead
|
|||
from lightly.models.utils import deactivate_requires_grad, update_momentum
|
||||
from lightly.utils.scheduler import cosine_schedule
|
||||
from torch import Tensor
|
||||
from torch.optim import SGD, AdamW, Optimizer
|
||||
from torch.optim import SGD, AdamW, Optimizer # type: ignore[attr-defined]
|
||||
from torch.optim.lr_scheduler import (
|
||||
CosineAnnealingLR,
|
||||
LinearLR,
|
||||
|
|
|
@ -16,7 +16,7 @@ import torch.nn.functional as F
|
|||
from lightly.loss import NTXentLoss
|
||||
from lightly.models.modules import SimCLRProjectionHead
|
||||
from torch import Tensor
|
||||
from torch.optim import Adam
|
||||
from torch.optim import Adam # type: ignore[attr-defined]
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
|
||||
from torchvision.models._api import WeightsEnum
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче