зеркало из https://github.com/microsoft/torchgeo.git
SSL4EO download script: optionally download nodata pixels (#1335)
This commit is contained in:
Родитель
e26213409d
Коммит
6ce1c4ef65
|
@ -22,6 +22,7 @@ YEAR=2002 # SLC-on
|
|||
BANDS=(SR_B1 SR_B2 SR_B3 SR_B4 SR_B5 SR_B7)
|
||||
ORIGINAL_RESOLUTIONS=30
|
||||
NEW_RESOLUTIONS=30
|
||||
DEFAULT_VALUE=0
|
||||
|
||||
# Generic parameters
|
||||
SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd)
|
||||
|
@ -43,6 +44,7 @@ time python3 "$SCRIPT_DIR/download_ssl4eo.py" \
|
|||
--original-resolutions $ORIGINAL_RESOLUTIONS \
|
||||
--new-resolutions $NEW_RESOLUTIONS \
|
||||
--dtype $DTYPE \
|
||||
--default-value $DEFAULT_VALUE \
|
||||
--num-workers $NUM_WORKERS \
|
||||
--log-freq $LOG_FREQ \
|
||||
--match-file "$MATCH_FILE" \
|
||||
|
|
|
@ -22,6 +22,7 @@ YEAR=2022
|
|||
BANDS=(SR_B1 SR_B2 SR_B3 SR_B4 SR_B5 SR_B6 SR_B7)
|
||||
ORIGINAL_RESOLUTIONS=30
|
||||
NEW_RESOLUTIONS=30
|
||||
DEFAULT_VALUE=0
|
||||
|
||||
# Generic parameters
|
||||
SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd)
|
||||
|
@ -43,6 +44,7 @@ time python3 "$SCRIPT_DIR/download_ssl4eo.py" \
|
|||
--original-resolutions $ORIGINAL_RESOLUTIONS \
|
||||
--new-resolutions $NEW_RESOLUTIONS \
|
||||
--dtype $DTYPE \
|
||||
--default-value $DEFAULT_VALUE \
|
||||
--num-workers $NUM_WORKERS \
|
||||
--log-freq $LOG_FREQ \
|
||||
--match-file "$MATCH_FILE" \
|
||||
|
|
|
@ -55,7 +55,7 @@ import time
|
|||
from collections import defaultdict
|
||||
from datetime import date, timedelta
|
||||
from multiprocessing.dummy import Lock, Pool
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
import ee
|
||||
import numpy as np
|
||||
|
@ -167,6 +167,7 @@ def get_patch(
|
|||
new_resolutions: list[int],
|
||||
dtype: str = "float32",
|
||||
meta_cloud_name: str = "CLOUD_COVER",
|
||||
default_value: Optional[float] = None,
|
||||
) -> dict[str, Any]:
|
||||
image = collection.sort(meta_cloud_name).first()
|
||||
region = ee.Geometry.Point(center_coord).buffer(radius).bounds()
|
||||
|
@ -183,7 +184,7 @@ def get_patch(
|
|||
patch = image.select(*bands_group)
|
||||
if orig_res != new_res:
|
||||
patch = patch.reproject(patch.projection().crs(), scale=new_res)
|
||||
patch = patch.sampleRectangle(region)
|
||||
patch = patch.sampleRectangle(region, defaultValue=default_value)
|
||||
features = patch.getInfo()
|
||||
for i, band in zip(indices, bands_group):
|
||||
x = features["properties"][band]
|
||||
|
@ -212,6 +213,7 @@ def get_random_patches_match(
|
|||
new_resolutions: list[int],
|
||||
dtype: str,
|
||||
meta_cloud_name: str,
|
||||
default_value: Optional[float],
|
||||
dates: list[date],
|
||||
radius: float,
|
||||
debug: bool = False,
|
||||
|
@ -237,6 +239,7 @@ def get_random_patches_match(
|
|||
new_resolutions,
|
||||
dtype,
|
||||
meta_cloud_name,
|
||||
default_value,
|
||||
)
|
||||
for c in filtered_collections
|
||||
]
|
||||
|
@ -384,6 +387,10 @@ if __name__ == "__main__":
|
|||
help="new band resolutions in meters",
|
||||
)
|
||||
parser.add_argument("--dtype", type=str, default="float32", help="data type")
|
||||
# If None, don't download patches with nodata pixels
|
||||
parser.add_argument(
|
||||
"--default-value", type=float, default=None, help="default fill value"
|
||||
)
|
||||
# download settings
|
||||
parser.add_argument("--num-workers", type=int, default=8, help="number of workers")
|
||||
parser.add_argument("--log-freq", type=int, default=10, help="print frequency")
|
||||
|
@ -478,6 +485,7 @@ if __name__ == "__main__":
|
|||
new_resolutions,
|
||||
dtype,
|
||||
args.meta_cloud_name,
|
||||
args.default_value,
|
||||
dates,
|
||||
radius=args.radius,
|
||||
debug=args.debug,
|
||||
|
|
Загрузка…
Ссылка в новой задаче