зеркало из https://github.com/microsoft/torchgeo.git
Bump torchvision from 0.14.1 to 0.15.1 in /requirements (#1177)
* Bump torchvision from 0.14.1 to 0.15.1 in /requirements Bumps [torchvision](https://github.com/pytorch/vision) from 0.14.1 to 0.15.1. - [Release notes](https://github.com/pytorch/vision/releases) - [Commits](https://github.com/pytorch/vision/compare/v0.14.1...v0.15.1) --- updated-dependencies: - dependency-name: torchvision dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] <support@github.com> * bump pytorch too * Fix tests * bump precommit * blacken --------- Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Родитель
3ff642feb1
Коммит
28ce9599de
|
@ -34,5 +34,5 @@ repos:
|
|||
hooks:
|
||||
- id: mypy
|
||||
args: [--strict, --ignore-missing-imports, --show-error-codes]
|
||||
additional_dependencies: [torch>=1.13, torchmetrics>=0.10, lightning>=1.8, pytest>=6, pyvista>=0.20, omegaconf>=2.1, kornia>=0.6, numpy>=1.22.0]
|
||||
additional_dependencies: [torch>=2, torchmetrics>=0.10, lightning>=1.8, pytest>=6, pyvista>=0.20, omegaconf>=2.1, kornia>=0.6, numpy>=1.22.0]
|
||||
exclude: (build|data|dist|logo|logs|output)/
|
||||
|
|
|
@ -119,7 +119,7 @@ def run_eval_loop(
|
|||
}
|
||||
for i in range(len(batch["image"]))
|
||||
]
|
||||
with torch.inference_mode(): # type: ignore[no-untyped-call]
|
||||
with torch.inference_mode():
|
||||
y_pred = model(x)
|
||||
metrics(y_pred, y)
|
||||
results = metrics.compute()
|
||||
|
|
|
@ -17,6 +17,6 @@ scikit-learn==1.2.2
|
|||
segmentation-models-pytorch==0.3.2
|
||||
shapely==2.0.1
|
||||
timm==0.6.12
|
||||
torch==1.13.1
|
||||
torch==2.0.0
|
||||
torchmetrics==0.11.4
|
||||
torchvision==0.14.1
|
||||
torchvision==0.15.1
|
||||
|
|
|
@ -55,11 +55,11 @@ install_requires =
|
|||
# timm 0.4.12 required by segmentation-models-pytorch
|
||||
timm>=0.4.12,<0.7
|
||||
# torch 1.12+ required by torchvision
|
||||
torch>=1.12,<2
|
||||
torch>=1.12,<3
|
||||
# torchmetrics 0.10+ required for binary/multiclass/multilabel classification metrics
|
||||
torchmetrics>=0.10,<0.12
|
||||
# torchvision 0.13+ required for torchvision.models._api.WeightsEnum
|
||||
torchvision>=0.13,<0.15
|
||||
torchvision>=0.13,<0.16
|
||||
python_requires = >=3.8,<4
|
||||
packages = find:
|
||||
|
||||
|
|
|
@ -32,7 +32,10 @@ class TestResNet18:
|
|||
path = tmp_path / f"{weights}.pth"
|
||||
model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"])
|
||||
torch.save(model.state_dict(), path)
|
||||
monkeypatch.setattr(weights, "url", str(path))
|
||||
try:
|
||||
monkeypatch.setattr(weights.value, "url", str(path))
|
||||
except AttributeError:
|
||||
monkeypatch.setattr(weights, "url", str(path))
|
||||
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
|
||||
return weights
|
||||
|
||||
|
@ -59,7 +62,10 @@ class TestResNet50:
|
|||
path = tmp_path / f"{weights}.pth"
|
||||
model = timm.create_model("resnet50", in_chans=weights.meta["in_chans"])
|
||||
torch.save(model.state_dict(), path)
|
||||
monkeypatch.setattr(weights, "url", str(path))
|
||||
try:
|
||||
monkeypatch.setattr(weights.value, "url", str(path))
|
||||
except AttributeError:
|
||||
monkeypatch.setattr(weights, "url", str(path))
|
||||
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
|
||||
return weights
|
||||
|
||||
|
|
|
@ -34,7 +34,10 @@ class TestViTSmall16:
|
|||
weights.meta["model"], in_chans=weights.meta["in_chans"]
|
||||
)
|
||||
torch.save(model.state_dict(), path)
|
||||
monkeypatch.setattr(weights, "url", str(path))
|
||||
try:
|
||||
monkeypatch.setattr(weights.value, "url", str(path))
|
||||
except AttributeError:
|
||||
monkeypatch.setattr(weights, "url", str(path))
|
||||
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
|
||||
return weights
|
||||
|
||||
|
|
|
@ -123,7 +123,10 @@ class TestBYOLTask:
|
|||
weights.meta["model"], in_chans=weights.meta["in_chans"]
|
||||
)
|
||||
torch.save(model.state_dict(), path)
|
||||
monkeypatch.setattr(weights, "url", str(path))
|
||||
try:
|
||||
monkeypatch.setattr(weights.value, "url", str(path))
|
||||
except AttributeError:
|
||||
monkeypatch.setattr(weights, "url", str(path))
|
||||
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
|
||||
return weights
|
||||
|
||||
|
|
|
@ -146,7 +146,10 @@ class TestClassificationTask:
|
|||
weights.meta["model"], in_chans=weights.meta["in_chans"]
|
||||
)
|
||||
torch.save(model.state_dict(), path)
|
||||
monkeypatch.setattr(weights, "url", str(path))
|
||||
try:
|
||||
monkeypatch.setattr(weights.value, "url", str(path))
|
||||
except AttributeError:
|
||||
monkeypatch.setattr(weights, "url", str(path))
|
||||
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
|
||||
return weights
|
||||
|
||||
|
|
|
@ -109,7 +109,10 @@ class TestRegressionTask:
|
|||
weights.meta["model"], in_chans=weights.meta["in_chans"]
|
||||
)
|
||||
torch.save(model.state_dict(), path)
|
||||
monkeypatch.setattr(weights, "url", str(path))
|
||||
try:
|
||||
monkeypatch.setattr(weights.value, "url", str(path))
|
||||
except AttributeError:
|
||||
monkeypatch.setattr(weights, "url", str(path))
|
||||
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
|
||||
return weights
|
||||
|
||||
|
|
|
@ -56,9 +56,7 @@ class RQLoss(Module):
|
|||
q = probs
|
||||
|
||||
# manually normalize due to https://github.com/pytorch/pytorch/issues/70100
|
||||
z = q / q.norm( # type: ignore[no-untyped-call]
|
||||
p=1, dim=(0, 2, 3), keepdim=True
|
||||
).clamp_min(1e-12).expand_as(q)
|
||||
z = q / q.norm(p=1, dim=(0, 2, 3), keepdim=True).clamp_min(1e-12).expand_as(q)
|
||||
r = F.normalize(z * target, p=1, dim=1)
|
||||
|
||||
loss = torch.einsum("bcxy,bcxy->bxy", r, torch.log(r) - torch.log(q)).mean()
|
||||
|
|
Загрузка…
Ссылка в новой задаче