зеркало из https://github.com/microsoft/torchgeo.git
SSL4EO download script: fix type hints (#1319)
This commit is contained in:
Родитель
71732511fb
Коммит
63c1840c3e
|
@ -54,7 +54,7 @@ import os
|
|||
import time
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from datetime import date, datetime, timedelta
|
||||
from datetime import date, timedelta
|
||||
from multiprocessing.dummy import Lock, Pool
|
||||
from typing import Any
|
||||
|
||||
|
@ -66,11 +66,11 @@ from rasterio.transform import Affine
|
|||
warnings.simplefilter("ignore", UserWarning)
|
||||
|
||||
|
||||
def date2str(date: datetime) -> str:
|
||||
def date2str(date: date) -> str:
|
||||
return date.strftime("%Y-%m-%d")
|
||||
|
||||
|
||||
def get_period(date: datetime, days: int = 5) -> tuple[str, str, str, str]:
|
||||
def get_period(date: date, days: int = 5) -> tuple[str, str, str, str]:
|
||||
date1 = date - timedelta(days=days / 2)
|
||||
date2 = date + timedelta(days=days / 2)
|
||||
date3 = date1 - timedelta(days=365)
|
||||
|
@ -86,7 +86,7 @@ def get_period(date: datetime, days: int = 5) -> tuple[str, str, str, str]:
|
|||
"""get collection and remove clouds from ee"""
|
||||
|
||||
|
||||
def mask_clouds(args: Any, image: ee.Image) -> ee.Image:
|
||||
def mask_clouds(args: argparse.Namespace, image: ee.Image) -> ee.Image:
|
||||
qa = image.select(args.qa_band)
|
||||
cloudBitMask = 1 << args.qa_cloud_bit
|
||||
# Both flags should be set to zero, indicating clear conditions.
|
||||
|
@ -110,7 +110,7 @@ def get_collection(
|
|||
|
||||
def filter_collection(
|
||||
collection: ee.ImageCollection,
|
||||
coords: list[float],
|
||||
coords: tuple[float, float],
|
||||
period: tuple[str, str, str, str],
|
||||
) -> ee.ImageCollection:
|
||||
filtered = collection
|
||||
|
@ -133,8 +133,8 @@ def filter_collection(
|
|||
|
||||
|
||||
def center_crop(
|
||||
img: np.ndarray[Any, np.dtype[Any]], out_size: int
|
||||
) -> np.ndarray[Any, np.dtype[Any]]:
|
||||
img: "np.typing.NDArray[np.float32]", out_size: int
|
||||
) -> "np.typing.NDArray[np.float32]":
|
||||
image_height, image_width = img.shape[:2]
|
||||
crop_height = crop_width = out_size
|
||||
pad_height = max(crop_height - image_height, 0)
|
||||
|
@ -163,7 +163,7 @@ def adjust_coords(
|
|||
|
||||
def get_patch(
|
||||
collection: ee.ImageCollection,
|
||||
center_coord: list[float],
|
||||
center_coord: tuple[float, float],
|
||||
radius: float,
|
||||
bands: list[str],
|
||||
original_resolutions: list[int],
|
||||
|
@ -215,13 +215,13 @@ def get_random_patches_match(
|
|||
new_resolutions: list[int],
|
||||
dtype: str,
|
||||
meta_cloud_name: str,
|
||||
dates: list[Any],
|
||||
dates: list[date],
|
||||
radius: float,
|
||||
debug: bool = False,
|
||||
match_coords: dict[str, Any] = {},
|
||||
) -> tuple[list[dict[str, Any]], list[float]]:
|
||||
match_coords: dict[int, tuple[float, float]] = {},
|
||||
) -> tuple[list[dict[str, Any]], tuple[float, float]]:
|
||||
# (lon,lat) of idx patch
|
||||
coords = match_coords[str(idx)]
|
||||
coords = match_coords[idx]
|
||||
|
||||
# random +- 30 days of random days within 1 year from the reference dates
|
||||
periods = [get_period(date, days=60) for date in dates]
|
||||
|
@ -252,7 +252,9 @@ def get_random_patches_match(
|
|||
|
||||
|
||||
def save_geotiff(
|
||||
img: np.ndarray[Any, np.dtype[Any]], coords: list[list[float]], filename: str
|
||||
img: "np.typing.NDArray[np.float32]",
|
||||
coords: list[tuple[float, float]],
|
||||
filename: str,
|
||||
) -> None:
|
||||
height, width, channels = img.shape
|
||||
xres = (coords[1][0] - coords[0][0]) / width
|
||||
|
@ -275,8 +277,8 @@ def save_geotiff(
|
|||
|
||||
|
||||
def save_patch(
|
||||
raster: dict[int, Any],
|
||||
coords: list[list[float]],
|
||||
raster: dict[int, "np.typing.NDArray[np.float32]"],
|
||||
coords: list[tuple[float, float]],
|
||||
metadata: dict[str, Any],
|
||||
bands: list[str],
|
||||
new_resolutions: list[int],
|
||||
|
@ -445,7 +447,7 @@ if __name__ == "__main__":
|
|||
with open(ext_path) as csv_file:
|
||||
reader = csv.reader(csv_file)
|
||||
for row in reader:
|
||||
key = row[0]
|
||||
key = int(row[0])
|
||||
val1 = float(row[1])
|
||||
val2 = float(row[2])
|
||||
ext_coords[key] = (val1, val2) # lon, lat
|
||||
|
@ -458,7 +460,7 @@ if __name__ == "__main__":
|
|||
with open(args.match_file) as csv_file:
|
||||
reader = csv.reader(csv_file)
|
||||
for row in reader:
|
||||
key = row[0]
|
||||
key = int(row[0])
|
||||
val1 = float(row[1])
|
||||
val2 = float(row[2])
|
||||
match_coords[key] = (val1, val2) # lon, lat
|
||||
|
@ -467,7 +469,7 @@ if __name__ == "__main__":
|
|||
counter = Counter()
|
||||
|
||||
def worker(idx: int) -> None:
|
||||
if str(idx) in ext_coords.keys():
|
||||
if idx in ext_coords.keys():
|
||||
return
|
||||
|
||||
worker_start = time.time()
|
||||
|
|
Загрузка…
Ссылка в новой задаче