зеркало из https://github.com/microsoft/torchgeo.git
Don't use Any in sample/batch utils
This commit is contained in:
Родитель
65972c8022
Коммит
e7ac62960d
|
@ -14,10 +14,10 @@ import pathlib
|
|||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from collections.abc import Iterable, Iterator, Sequence, Mapping
|
||||
from collections.abc import Iterable, Iterator, Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, TypeAlias, cast, overload
|
||||
from typing import Any, TypeAlias, TypeVar, cast, overload
|
||||
|
||||
import numpy as np
|
||||
import rasterio
|
||||
|
@ -43,6 +43,8 @@ __all__ = (
|
|||
|
||||
|
||||
Path: TypeAlias = str | pathlib.Path
|
||||
K = TypeVar('K')
|
||||
V = TypeVar('V')
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
@ -367,7 +369,7 @@ def working_dir(dirname: Path, create: bool = False) -> Iterator[None]:
|
|||
os.chdir(cwd)
|
||||
|
||||
|
||||
def _list_dict_to_dict_list(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, list[Any]]:
|
||||
def _list_dict_to_dict_list(samples: Iterable[Mapping[K, V]]) -> dict[K, list[V]]:
|
||||
"""Convert a list of dictionaries to a dictionary of lists.
|
||||
|
||||
Args:
|
||||
|
@ -385,7 +387,7 @@ def _list_dict_to_dict_list(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, l
|
|||
return collated
|
||||
|
||||
|
||||
def _dict_list_to_list_dict(sample: Mapping[Any, Sequence[Any]]) -> list[dict[Any, Any]]:
|
||||
def _dict_list_to_list_dict(sample: Mapping[K, Sequence[V]]) -> list[dict[K, V]]:
|
||||
"""Convert a dictionary of lists to a list of dictionaries.
|
||||
|
||||
Args:
|
||||
|
@ -396,16 +398,14 @@ def _dict_list_to_list_dict(sample: Mapping[Any, Sequence[Any]]) -> list[dict[An
|
|||
|
||||
.. versionadded:: 0.2
|
||||
"""
|
||||
uncollated: list[dict[Any, Any]] = [
|
||||
{} for _ in range(max(map(len, sample.values())))
|
||||
]
|
||||
uncollated: list[dict[K, V]] = [{} for _ in range(max(map(len, sample.values())))]
|
||||
for key, values in sample.items():
|
||||
for i, value in enumerate(values):
|
||||
uncollated[i][key] = value
|
||||
return uncollated
|
||||
|
||||
|
||||
def stack_samples(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, Any]:
|
||||
def stack_samples(samples: Iterable[Mapping[K, V]]) -> dict[K, V]:
|
||||
"""Stack a list of samples along a new axis.
|
||||
|
||||
Useful for forming a mini-batch of samples to pass to
|
||||
|
@ -419,14 +419,14 @@ def stack_samples(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, Any]:
|
|||
|
||||
.. versionadded:: 0.2
|
||||
"""
|
||||
collated: dict[Any, Any] = _list_dict_to_dict_list(samples)
|
||||
collated: dict[K, V] = _list_dict_to_dict_list(samples)
|
||||
for key, value in collated.items():
|
||||
if isinstance(value[0], Tensor):
|
||||
collated[key] = torch.stack(value)
|
||||
return collated
|
||||
|
||||
|
||||
def concat_samples(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, Any]:
|
||||
def concat_samples(samples: Iterable[Mapping[K, V]]) -> dict[K, V]:
|
||||
"""Concatenate a list of samples along an existing axis.
|
||||
|
||||
Useful for joining samples in a :class:`torchgeo.datasets.IntersectionDataset`.
|
||||
|
@ -439,7 +439,7 @@ def concat_samples(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, Any]:
|
|||
|
||||
.. versionadded:: 0.2
|
||||
"""
|
||||
collated: dict[Any, Any] = _list_dict_to_dict_list(samples)
|
||||
collated: dict[K, V] = _list_dict_to_dict_list(samples)
|
||||
for key, value in collated.items():
|
||||
if isinstance(value[0], Tensor):
|
||||
collated[key] = torch.cat(value)
|
||||
|
@ -448,7 +448,7 @@ def concat_samples(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, Any]:
|
|||
return collated
|
||||
|
||||
|
||||
def merge_samples(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, Any]:
|
||||
def merge_samples(samples: Iterable[Mapping[K, V]]) -> dict[K, V]:
|
||||
"""Merge a list of samples.
|
||||
|
||||
Useful for joining samples in a :class:`torchgeo.datasets.UnionDataset`.
|
||||
|
@ -461,7 +461,7 @@ def merge_samples(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, Any]:
|
|||
|
||||
.. versionadded:: 0.2
|
||||
"""
|
||||
collated: dict[Any, Any] = {}
|
||||
collated: dict[K, V] = {}
|
||||
for sample in samples:
|
||||
for key, value in sample.items():
|
||||
if key in collated and isinstance(value, Tensor):
|
||||
|
@ -473,7 +473,7 @@ def merge_samples(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, Any]:
|
|||
return collated
|
||||
|
||||
|
||||
def unbind_samples(sample: Mapping[Any, Sequence[Any]]) -> list[dict[Any, Any]]:
|
||||
def unbind_samples(sample: Mapping[K, Sequence[V] | Tensor]) -> list[dict[K, V]]:
|
||||
"""Reverse of :func:`stack_samples`.
|
||||
|
||||
Useful for turning a mini-batch of samples into a list of samples. These individual
|
||||
|
|
Загрузка…
Ссылка в новой задаче