Don't use Any in sample/batch utils

This commit is contained in:
Adam J. Stewart 2024-08-21 15:50:09 +02:00
Родитель 65972c8022
Коммит e7ac62960d
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: C66C0675661156FC
1 изменённых файлов: 14 добавлений и 14 удалений

Просмотреть файл

@ -14,10 +14,10 @@ import pathlib
import shutil
import subprocess
import sys
from collections.abc import Iterable, Iterator, Sequence, Mapping
from collections.abc import Iterable, Iterator, Mapping, Sequence
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any, TypeAlias, cast, overload
from typing import Any, TypeAlias, TypeVar, cast, overload
import numpy as np
import rasterio
@ -43,6 +43,8 @@ __all__ = (
Path: TypeAlias = str | pathlib.Path
K = TypeVar('K')
V = TypeVar('V')
@dataclass(frozen=True)
@ -367,7 +369,7 @@ def working_dir(dirname: Path, create: bool = False) -> Iterator[None]:
os.chdir(cwd)
def _list_dict_to_dict_list(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, list[Any]]:
def _list_dict_to_dict_list(samples: Iterable[Mapping[K, V]]) -> dict[K, list[V]]:
"""Convert a list of dictionaries to a dictionary of lists.
Args:
@ -385,7 +387,7 @@ def _list_dict_to_dict_list(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, l
return collated
def _dict_list_to_list_dict(sample: Mapping[Any, Sequence[Any]]) -> list[dict[Any, Any]]:
def _dict_list_to_list_dict(sample: Mapping[K, Sequence[V]]) -> list[dict[K, V]]:
"""Convert a dictionary of lists to a list of dictionaries.
Args:
@ -396,16 +398,14 @@ def _dict_list_to_list_dict(sample: Mapping[Any, Sequence[Any]]) -> list[dict[An
.. versionadded:: 0.2
"""
uncollated: list[dict[Any, Any]] = [
{} for _ in range(max(map(len, sample.values())))
]
uncollated: list[dict[K, V]] = [{} for _ in range(max(map(len, sample.values())))]
for key, values in sample.items():
for i, value in enumerate(values):
uncollated[i][key] = value
return uncollated
def stack_samples(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, Any]:
def stack_samples(samples: Iterable[Mapping[K, V]]) -> dict[K, V]:
"""Stack a list of samples along a new axis.
Useful for forming a mini-batch of samples to pass to
@ -419,14 +419,14 @@ def stack_samples(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, Any]:
.. versionadded:: 0.2
"""
collated: dict[Any, Any] = _list_dict_to_dict_list(samples)
collated: dict[K, V] = _list_dict_to_dict_list(samples)
for key, value in collated.items():
if isinstance(value[0], Tensor):
collated[key] = torch.stack(value)
return collated
def concat_samples(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, Any]:
def concat_samples(samples: Iterable[Mapping[K, V]]) -> dict[K, V]:
"""Concatenate a list of samples along an existing axis.
Useful for joining samples in a :class:`torchgeo.datasets.IntersectionDataset`.
@ -439,7 +439,7 @@ def concat_samples(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, Any]:
.. versionadded:: 0.2
"""
collated: dict[Any, Any] = _list_dict_to_dict_list(samples)
collated: dict[K, V] = _list_dict_to_dict_list(samples)
for key, value in collated.items():
if isinstance(value[0], Tensor):
collated[key] = torch.cat(value)
@ -448,7 +448,7 @@ def concat_samples(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, Any]:
return collated
def merge_samples(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, Any]:
def merge_samples(samples: Iterable[Mapping[K, V]]) -> dict[K, V]:
"""Merge a list of samples.
Useful for joining samples in a :class:`torchgeo.datasets.UnionDataset`.
@ -461,7 +461,7 @@ def merge_samples(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, Any]:
.. versionadded:: 0.2
"""
collated: dict[Any, Any] = {}
collated: dict[K, V] = {}
for sample in samples:
for key, value in sample.items():
if key in collated and isinstance(value, Tensor):
@ -473,7 +473,7 @@ def merge_samples(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, Any]:
return collated
def unbind_samples(sample: Mapping[Any, Sequence[Any]]) -> list[dict[Any, Any]]:
def unbind_samples(sample: Mapping[K, Sequence[V] | Tensor]) -> list[dict[K, V]]:
"""Reverse of :func:`stack_samples`.
Useful for turning a mini-batch of samples into a list of samples. These individual