toasty/toast.py: finish exposing the filtered sampling functionality in a nice API

This commit is contained in:
Peter Williams 2021-10-13 14:55:57 -04:00
Родитель 95739b9995
Коммит 7293e68d54
2 изменённых файлов: 82 добавлений и 8 удалений

Просмотреть файл

@ -0,0 +1,6 @@
sample_layer_filtered
=====================
.. currentmodule:: toasty.toast
.. autofunction:: sample_layer_filtered

Просмотреть файл

@ -29,6 +29,7 @@ create_single_tile
generate_tiles generate_tiles
generate_tiles_filtered generate_tiles_filtered
sample_layer sample_layer
sample_layer_filtered
Tile Tile
ToastCoordinateSystem ToastCoordinateSystem
toast_pixel_for_point toast_pixel_for_point
@ -37,16 +38,14 @@ toast_tile_for_point
toast_tile_get_coords toast_tile_get_coords
'''.split() '''.split()
from collections import defaultdict, namedtuple from collections import namedtuple
from enum import Enum from enum import Enum
import os
import logging
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
from ._libtoasty import subsample, mid from ._libtoasty import subsample, mid
from .image import Image from .image import Image
from .pyramid import Pos, depth2tiles, is_subtile, pos_parent, tiles_at_depth from .pyramid import Pos, tiles_at_depth
HALFPI = 0.5 * np.pi HALFPI = 0.5 * np.pi
THREEHALFPI = 1.5 * np.pi THREEHALFPI = 1.5 * np.pi
@ -483,7 +482,7 @@ def create_single_tile(pos, coordsys=ToastCoordinateSystem.ASTRONOMICAL):
""" """
if pos.n == 0: if pos.n == 0:
raise ArgumentError('cannot create a Tile for the n=0 tile') raise ValueError('cannot create a Tile for the n=0 tile')
children = _create_level1_tiles(coordsys) children = _create_level1_tiles(coordsys)
cur_n = 0 cur_n = 0
@ -700,7 +699,75 @@ def _mp_sample_worker(queue, done_event, pio, sampler, format):
pio.write_image(tile.pos, Image.from_array(sampled_data), format=format) pio.write_image(tile.pos, Image.from_array(sampled_data), format=format)
def sample_layer_filtered(
    pio,
    tile_filter,
    sampler,
    depth,
    coordsys=ToastCoordinateSystem.ASTRONOMICAL,
    parallel=None,
    cli_progress=False,
):
    """Fill in selected tiles of one TOAST pyramid layer by direct sampling.

    Only the tiles accepted by *tile_filter* are sampled. The work is handed
    off to either the serial or the multiprocess implementation, depending on
    the resolved level of parallelism.

    Parameters
    ----------
    pio : :class:`toasty.pyramid.PyramidIO`
        The :class:`~toasty.pyramid.PyramidIO` instance that manages I/O with
        the tiles in the tile pyramid.
    tile_filter : callable
        A tile filtering function, of the kind accepted by
        :func:`toasty.toast.generate_tiles_filtered`.
    sampler : callable
        The sampler callable that produces the data to be tiled.
    depth : int
        The depth of the TOAST pyramid layer to generate. A layer contains
        ``4**depth`` tiles, each 256×256 TOAST pixels, so the data are
        sampled at a refinement level of ``2**(depth + 8)``.
    coordsys : optional :class:`ToastCoordinateSystem`
        The TOAST coordinate system to use. Default is
        :attr:`ToastCoordinateSystem.ASTRONOMICAL`.
    parallel : integer or None (the default)
        The level of parallelization to use. If unspecified, all CPUs are
        used. If the OS does not support fork-based multiprocessing, parallel
        processing is not possible and serial processing is forced. Pass
        ``1`` to force serial processing.
    cli_progress : optional boolean, defaults False
        If true, a progress bar is printed to the terminal using tqdm.

    """
    from .par_util import resolve_parallelism

    # Normalize the user's request into a concrete worker count.
    n_workers = resolve_parallelism(parallel)

    if n_workers > 1:
        _sample_filtered_parallel(
            pio, tile_filter, sampler, depth, coordsys, cli_progress, n_workers
        )
        return

    _sample_filtered_serial(pio, tile_filter, sampler, depth, coordsys, cli_progress)
def _sample_filtered_serial(pio, tile_filter, sampler, depth, coordsys, cli_progress):
    """Sample every filter-matching tile at *depth* in the current process.

    Each matching tile is sampled with *sampler* at the tile's coordinates and
    the result is merged into the corresponding image through
    ``pio.update_image``. A tqdm progress bar is shown when *cli_progress* is
    true.
    """
    # The same selection criteria are used for counting and for iterating.
    selection = dict(bottom_only=True, coordsys=coordsys)
    total = count_tiles_matching_filter(depth, tile_filter, **selection)

    # A full-extent slice, used for every index argument of the buffer update.
    everything = slice(None)

    with tqdm(total=total, disable=not cli_progress) as bar:
        for tile in generate_tiles_filtered(depth, tile_filter, **selection):
            lon, lat = toast_tile_get_coords(tile)
            img = Image.from_array(sampler(lon, lat))

            with pio.update_image(tile.pos, masked_mode=img.mode, default='masked') as basis:
                img.update_into_maskable_buffer(
                    basis, everything, everything, everything, everything
                )

            bar.update(1)

    if cli_progress:
        print()

    # do not clean lockfiles, for HPC contexts where we're processing different
    # chunks in parallel.
def _sample_filtered_parallel(pio, tile_filter, sampler, depth, coordsys, cli_progress, parallel):
import multiprocessing as mp import multiprocessing as mp
n_todo = count_tiles_matching_filter(depth, tile_filter, bottom_only=True, coordsys=coordsys) n_todo = count_tiles_matching_filter(depth, tile_filter, bottom_only=True, coordsys=coordsys)
@ -710,7 +777,7 @@ def _sample_filtered_parallel(pio, format, tile_filter, sampler, depth, coordsys
workers = [] workers = []
for _ in range(parallel): for _ in range(parallel):
w = mp.Process(target=_mp_sample_filtered, args=(queue, done_event, pio, sampler, format)) w = mp.Process(target=_mp_sample_filtered, args=(queue, done_event, pio, sampler))
w.daemon = True w.daemon = True
w.start() w.start()
workers.append(w) workers.append(w)
@ -737,7 +804,8 @@ def _sample_filtered_parallel(pio, format, tile_filter, sampler, depth, coordsys
# do not clean lockfiles, for HPC contexts where we're processing different # do not clean lockfiles, for HPC contexts where we're processing different
# chunks in parallel. # chunks in parallel.
def _mp_sample_filtered(queue, done_event, pio, sampler, format):
def _mp_sample_filtered(queue, done_event, pio, sampler):
""" """
Process tiles on the queue. Process tiles on the queue.