toasty/pyramid.py: add Pyramid.visit_leaves()

Peter Williams 2022-09-03 09:06:36 -04:00
Parent c1d5d03a18
Commit 05a5089662
1 changed file with 149 additions and 4 deletions


@@ -28,6 +28,7 @@ import glob
from collections import namedtuple
from contextlib import contextmanager
import os.path
import time
from .image import ImageLoader, SUPPORTED_FORMATS, get_format_vertical_parity_sign
from .progress import progress_bar
@@ -888,8 +889,6 @@ class Pyramid(object):
self._walk_serial(callback, cli_progress)
def _walk_serial(self, callback, cli_progress):
import time
if self.depth > 9 and self._tile_filter is not None:
# This is around where there are enough tiles that the prep stage
# might take a noticeable amount of time.
@@ -925,8 +924,6 @@ class Pyramid(object):
def _walk_parallel(self, callback, cli_progress, parallel):
import multiprocessing as mp
from queue import Empty
import time
from tqdm import tqdm
# When dispatching we keep track of finished tiles (reported in
# `done_queue`) and notify workers when new tiles are ready to process
@@ -1052,6 +1049,137 @@ class Pyramid(object):
for w in workers:
w.join()
def visit_leaves(
self,
callback,
parallel=None,
cli_progress=False,
):
"""Traverse the pyramid, calling the callback for each
leaf tile.
Parameters
----------
        callback : function(:class:`Pos`, Optional[:class:`~toasty.toast.Tile`]) -> None
A function to be called for all of the leaves.
parallel : integer or None (the default)
The level of parallelization to use. If unspecified, defaults to
using all CPUs. If the OS does not support fork-based
multiprocessing, parallel processing is not possible and serial
processing will be forced. Pass ``1`` to force serial processing.
cli_progress : optional boolean, defaults False
            If true, a progress bar will be printed to the terminal.

        Returns
        -------
        None.

        Notes
        -----
        Use this function to visit all of the leaves of the pyramid, with the
        possibility of parallelizing the computation substantially using
        Python's multiprocessing framework.

        Here, "all" of the leaves means those that have not been filtered out
        by a TOAST tile filter and/or subpyramid selection.

        The callback is passed the position of the leaf tile and, if the
        pyramid has been defined as a TOAST tile pyramid, its TOAST coordinate
        information. If not, the second argument to the callback will be None.

        In the corner case that this pyramid has depth 0, the TOAST tile
        information will be None even if this is a TOAST pyramid, because the
        TOAST tile information is not well-defined at level 0.

        In the parallelized case, callbacks may occur in different processes
        and so cannot communicate with each other in memory, nor can they
        communicate with the parent process (that is, the process in which
        this function was invoked). No guarantees are made about the order in
        which leaf tiles are visited.
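
        Examples
        --------
        A minimal usage sketch (illustrative, not taken from the library
        docs), assuming ``pyramid`` is an existing :class:`Pyramid` instance.
        With ``parallel=1`` processing is serial, so the callback may safely
        accumulate results in local state::

            positions = []

            def record_leaf(pos, tile):
                # `tile` is a TOAST Tile for TOAST pyramids, otherwise None
                positions.append(pos)

            pyramid.visit_leaves(record_leaf, parallel=1)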
"""
from .par_util import resolve_parallelism
parallel = resolve_parallelism(parallel)
# Unlike walk(), where the parallel case needs some extra logic to
# maintain our ordering guarantees, the serial and parallel cases here
# are pretty similar, so we can reuse more code. First, assess the
# amount of work to do.
if self.depth > 9 and self._tile_filter is not None:
# This is around where there are enough tiles that the prep stage
# might take a noticeable amount of time.
print("Counting tiles ...")
t0 = time.time()
total = self.count_leaf_tiles()
if self.depth > 9 and self._tile_filter is not None:
print(f"... {time.time() - t0:.1f}s elapsed")
if total == 0:
print("- Nothing to do.")
return
# Now the meat of it.
if parallel > 1:
self._visit_leaves_parallel(callback, total, cli_progress, parallel)
else:
self._visit_leaves_serial(callback, total, cli_progress)
def _visit_leaves_serial(self, callback, total, cli_progress):
riter = self._make_iter_reducer()
with progress_bar(total=total, show=cli_progress) as progress:
for pos, tile, is_leaf, _data in riter:
if is_leaf:
callback(pos, tile)
progress.update(1)
                riter.set_data(None)  # this traversal computes no per-tile reduction values
def _visit_leaves_parallel(self, callback, total, cli_progress, parallel):
import multiprocessing as mp
        # Unlike _walk_parallel(), we don't need a `done_queue` here: leaves
        # have no interdependencies and no ordering guarantees are made, so
        # the dispatcher only feeds work and signals completion. The ready
        # queue is bounded so that dispatch exerts backpressure instead of
        # buffering every pending leaf.
        ready_queue = mp.Queue(maxsize=2 * parallel)
done_event = mp.Event()
# Create workers:
workers = []
for _ in range(parallel):
w = mp.Process(
target=_mp_visit_worker,
args=(ready_queue, done_event, callback),
)
w.daemon = True
w.start()
workers.append(w)
# Dispatch:
riter = self._make_iter_reducer()
with progress_bar(total=total, show=cli_progress) as progress:
for pos, tile, is_leaf, _data in riter:
if is_leaf:
ready_queue.put((pos, tile))
progress.update(1)
riter.set_data(None)
# All done!
ready_queue.close()
ready_queue.join_thread()
done_event.set()
for w in workers:
w.join()
class PyramidReductionIterator(object):
"""Non-public helper class for a performing a "reduction iteration" over a
@@ -1252,3 +1380,20 @@ def _mp_walk_worker(done_queue, ready_queue, done_event, callback):
callback(pos)
done_queue.put(pos)
def _mp_visit_worker(ready_queue, done_event, callback):
"""
Process tiles that are ready.
"""
from queue import Empty
    while True:
        try:
            # Poll with a timeout so that, once the dispatcher has drained
            # the queue and set `done_event`, we notice and exit.
            args = ready_queue.get(True, timeout=1)
        except Empty:
            if done_event.is_set():
                break

            continue

callback(*args)
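
# A minimal standalone sketch (illustrative, not part of toasty) of the
# shutdown protocol used above: the dispatcher feeds a bounded queue and then
# sets a "done" event; workers poll the queue with a timeout so that they can
# notice the event once the queue drains.


def _demo_protocol_worker(ready_queue, done_event):
    from queue import Empty

    while True:
        try:
            item = ready_queue.get(True, timeout=1)
        except Empty:
            if done_event.is_set():
                break
            continue

        print("processed", item)


def _demo_protocol():
    # On spawn-based platforms, call this under `if __name__ == "__main__":`.
    import multiprocessing as mp

    ready_queue = mp.Queue(maxsize=4)  # bounded: put() blocks when full
    done_event = mp.Event()
    worker = mp.Process(target=_demo_protocol_worker, args=(ready_queue, done_event))
    worker.daemon = True
    worker.start()

    for item in range(10):
        ready_queue.put(item)  # exerts backpressure if the worker lags

    ready_queue.close()
    ready_queue.join_thread()
    done_event.set()
    worker.join()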