SSL4EO download script: fix type hints (#1319)

This commit is contained in:
Adam J. Stewart 2023-05-10 17:03:47 -05:00 коммит произвёл GitHub
Родитель 71732511fb
Коммит 63c1840c3e
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 20 добавлений и 18 удалений

Просмотреть файл

@ -54,7 +54,7 @@ import os
import time
import warnings
from collections import defaultdict
from datetime import date, datetime, timedelta
from datetime import date, timedelta
from multiprocessing.dummy import Lock, Pool
from typing import Any
@ -66,11 +66,11 @@ from rasterio.transform import Affine
warnings.simplefilter("ignore", UserWarning)
def date2str(date: datetime) -> str:
def date2str(date: date) -> str:
return date.strftime("%Y-%m-%d")
def get_period(date: datetime, days: int = 5) -> tuple[str, str, str, str]:
def get_period(date: date, days: int = 5) -> tuple[str, str, str, str]:
date1 = date - timedelta(days=days / 2)
date2 = date + timedelta(days=days / 2)
date3 = date1 - timedelta(days=365)
@ -86,7 +86,7 @@ def get_period(date: datetime, days: int = 5) -> tuple[str, str, str, str]:
"""get collection and remove clouds from ee"""
def mask_clouds(args: Any, image: ee.Image) -> ee.Image:
def mask_clouds(args: argparse.Namespace, image: ee.Image) -> ee.Image:
qa = image.select(args.qa_band)
cloudBitMask = 1 << args.qa_cloud_bit
# Both flags should be set to zero, indicating clear conditions.
@ -110,7 +110,7 @@ def get_collection(
def filter_collection(
collection: ee.ImageCollection,
coords: list[float],
coords: tuple[float, float],
period: tuple[str, str, str, str],
) -> ee.ImageCollection:
filtered = collection
@ -133,8 +133,8 @@ def filter_collection(
def center_crop(
img: np.ndarray[Any, np.dtype[Any]], out_size: int
) -> np.ndarray[Any, np.dtype[Any]]:
img: "np.typing.NDArray[np.float32]", out_size: int
) -> "np.typing.NDArray[np.float32]":
image_height, image_width = img.shape[:2]
crop_height = crop_width = out_size
pad_height = max(crop_height - image_height, 0)
@ -163,7 +163,7 @@ def adjust_coords(
def get_patch(
collection: ee.ImageCollection,
center_coord: list[float],
center_coord: tuple[float, float],
radius: float,
bands: list[str],
original_resolutions: list[int],
@ -215,13 +215,13 @@ def get_random_patches_match(
new_resolutions: list[int],
dtype: str,
meta_cloud_name: str,
dates: list[Any],
dates: list[date],
radius: float,
debug: bool = False,
match_coords: dict[str, Any] = {},
) -> tuple[list[dict[str, Any]], list[float]]:
match_coords: dict[int, tuple[float, float]] = {},
) -> tuple[list[dict[str, Any]], tuple[float, float]]:
# (lon,lat) of idx patch
coords = match_coords[str(idx)]
coords = match_coords[idx]
# random +- 30 days of random days within 1 year from the reference dates
periods = [get_period(date, days=60) for date in dates]
@ -252,7 +252,9 @@ def get_random_patches_match(
def save_geotiff(
img: np.ndarray[Any, np.dtype[Any]], coords: list[list[float]], filename: str
img: "np.typing.NDArray[np.float32]",
coords: list[tuple[float, float]],
filename: str,
) -> None:
height, width, channels = img.shape
xres = (coords[1][0] - coords[0][0]) / width
@ -275,8 +277,8 @@ def save_geotiff(
def save_patch(
raster: dict[int, Any],
coords: list[list[float]],
raster: dict[int, "np.typing.NDArray[np.float32]"],
coords: list[tuple[float, float]],
metadata: dict[str, Any],
bands: list[str],
new_resolutions: list[int],
@ -445,7 +447,7 @@ if __name__ == "__main__":
with open(ext_path) as csv_file:
reader = csv.reader(csv_file)
for row in reader:
key = row[0]
key = int(row[0])
val1 = float(row[1])
val2 = float(row[2])
ext_coords[key] = (val1, val2) # lon, lat
@ -458,7 +460,7 @@ if __name__ == "__main__":
with open(args.match_file) as csv_file:
reader = csv.reader(csv_file)
for row in reader:
key = row[0]
key = int(row[0])
val1 = float(row[1])
val2 = float(row[2])
match_coords[key] = (val1, val2) # lon, lat
@ -467,7 +469,7 @@ if __name__ == "__main__":
counter = Counter()
def worker(idx: int) -> None:
if str(idx) in ext_coords.keys():
if idx in ext_coords.keys():
return
worker_start = time.time()