447 строки
15 KiB
Python
447 строки
15 KiB
Python
# -*- coding: utf-8 -*-
|
|
# File: viz.py
|
|
# Credit: zxytim
|
|
|
|
import numpy as np
|
|
import os
|
|
import sys
|
|
|
|
from ..utils.develop import create_dummy_func # noqa
|
|
from .argtools import shape2d
|
|
from .fs import mkdir_p
|
|
|
|
try:
|
|
import cv2
|
|
except ImportError:
|
|
pass
|
|
|
|
|
|
__all__ = ['interactive_imshow',
|
|
'stack_patches', 'gen_stack_patches',
|
|
'dump_dataflow_images', 'intensity_to_rgb',
|
|
'draw_boxes']
|
|
|
|
|
|
def interactive_imshow(img, lclick_cb=None, rclick_cb=None, **kwargs):
|
|
"""
|
|
Args:
|
|
img (np.ndarray): an image (expect BGR) to show.
|
|
lclick_cb, rclick_cb: a callback ``func(img, x, y)`` for left/right click event.
|
|
kwargs: can be {key_cb_a: callback_img, key_cb_b: callback_img}, to
|
|
specify a callback ``func(img)`` for keypress.
|
|
|
|
Some existing keypress event handler:
|
|
|
|
* q: destroy the current window
|
|
* x: execute ``sys.exit()``
|
|
* s: save image to "out.png"
|
|
"""
|
|
name = 'tensorpack_viz_window'
|
|
cv2.imshow(name, img)
|
|
|
|
def mouse_cb(event, x, y, *args):
|
|
if event == cv2.EVENT_LBUTTONUP and lclick_cb is not None:
|
|
lclick_cb(img, x, y)
|
|
elif event == cv2.EVENT_RBUTTONUP and rclick_cb is not None:
|
|
rclick_cb(img, x, y)
|
|
cv2.setMouseCallback(name, mouse_cb)
|
|
key = cv2.waitKey(-1)
|
|
while key >= 128:
|
|
key = cv2.waitKey(-1)
|
|
key = chr(key & 0xff)
|
|
cb_name = 'key_cb_' + key
|
|
if cb_name in kwargs:
|
|
kwargs[cb_name](img)
|
|
elif key == 'q':
|
|
cv2.destroyWindow(name)
|
|
elif key == 'x':
|
|
sys.exit()
|
|
elif key == 's':
|
|
cv2.imwrite('out.png', img)
|
|
elif key in ['+', '=']:
|
|
img = cv2.resize(img, None, fx=1.3, fy=1.3, interpolation=cv2.INTER_CUBIC)
|
|
interactive_imshow(img, lclick_cb, rclick_cb, **kwargs)
|
|
elif key == '-':
|
|
img = cv2.resize(img, None, fx=0.7, fy=0.7, interpolation=cv2.INTER_CUBIC)
|
|
interactive_imshow(img, lclick_cb, rclick_cb, **kwargs)
|
|
|
|
|
|
def _preprocess_patch_list(plist):
|
|
plist = np.asarray(plist)
|
|
assert plist.dtype != np.object
|
|
if plist.ndim == 3:
|
|
plist = plist[:, :, :, np.newaxis]
|
|
assert plist.ndim == 4 and plist.shape[3] in [1, 3], plist.shape
|
|
return plist
|
|
|
|
|
|
def _pad_patch_list(plist, bgcolor):
|
|
if isinstance(bgcolor, int):
|
|
bgcolor = (bgcolor, bgcolor, bgcolor)
|
|
|
|
def _pad_channel(plist):
|
|
ret = []
|
|
for p in plist:
|
|
if len(p.shape) == 2:
|
|
p = p[:, :, np.newaxis]
|
|
if p.shape[2] == 1:
|
|
p = np.repeat(p, 3, 2)
|
|
ret.append(p)
|
|
return ret
|
|
|
|
plist = _pad_channel(plist)
|
|
shapes = [x.shape for x in plist]
|
|
ph = max([s[0] for s in shapes])
|
|
pw = max([s[1] for s in shapes])
|
|
|
|
ret = np.zeros((len(plist), ph, pw, 3), dtype=plist[0].dtype)
|
|
ret[:, :, :] = bgcolor
|
|
for idx, p in enumerate(plist):
|
|
s = p.shape
|
|
sh = (ph - s[0]) // 2
|
|
sw = (pw - s[1]) // 2
|
|
ret[idx, sh:sh + s[0], sw:sw + s[1], :] = p
|
|
return ret
|
|
|
|
|
|
class Canvas(object):
|
|
def __init__(self, ph, pw,
|
|
nr_row, nr_col,
|
|
channel, border, bgcolor):
|
|
self.ph = ph
|
|
self.pw = pw
|
|
self.nr_row = nr_row
|
|
self.nr_col = nr_col
|
|
|
|
if border is None:
|
|
border = int(0.05 * min(ph, pw))
|
|
self.border = border
|
|
|
|
if isinstance(bgcolor, int):
|
|
bgchannel = 1
|
|
else:
|
|
bgchannel = 3
|
|
self.bgcolor = bgcolor
|
|
self.channel = max(channel, bgchannel)
|
|
|
|
self.canvas = np.zeros((nr_row * (ph + border) - border,
|
|
nr_col * (pw + border) - border,
|
|
self.channel), dtype='uint8')
|
|
|
|
def draw_patches(self, plist):
|
|
assert self.nr_row * self.nr_col >= len(plist), \
|
|
"{}*{} < {}".format(self.nr_row, self.nr_col, len(plist))
|
|
if self.channel == 3 and plist.shape[3] == 1:
|
|
plist = np.repeat(plist, 3, axis=3)
|
|
cur_row, cur_col = 0, 0
|
|
if self.channel == 1:
|
|
self.canvas.fill(self.bgcolor)
|
|
else:
|
|
self.canvas[:, :, :] = self.bgcolor
|
|
for patch in plist:
|
|
r0 = cur_row * (self.ph + self.border)
|
|
c0 = cur_col * (self.pw + self.border)
|
|
self.canvas[r0:r0 + self.ph, c0:c0 + self.pw] = patch
|
|
cur_col += 1
|
|
if cur_col == self.nr_col:
|
|
cur_col = 0
|
|
cur_row += 1
|
|
|
|
def get_patchid_from_coord(self, x, y):
|
|
x = x // (self.pw + self.border)
|
|
y = y // (self.pw + self.border)
|
|
idx = y * self.nr_col + x
|
|
return idx
|
|
|
|
|
|
def stack_patches(
|
|
patch_list, nr_row, nr_col, border=None,
|
|
pad=False, bgcolor=255, viz=False, lclick_cb=None):
|
|
"""
|
|
Stacked patches into grid, to produce visualizations like the following:
|
|
|
|
.. image:: https://github.com/tensorpack/tensorpack/raw/master/examples/GAN/demo/BEGAN-CelebA-samples.jpg
|
|
|
|
Args:
|
|
patch_list(list[ndarray] or ndarray): NHW or NHWC images in [0,255].
|
|
nr_row(int), nr_col(int): rows and cols of the grid.
|
|
``nr_col * nr_row`` must be no less than ``len(patch_list)``.
|
|
border(int): border length between images.
|
|
Defaults to ``0.05 * min(patch_width, patch_height)``.
|
|
pad (boolean): when `patch_list` is a list, pad all patches to the maximum height and width.
|
|
This option allows stacking patches of different shapes together.
|
|
bgcolor(int or 3-tuple): background color in [0, 255]. Either an int
|
|
or a BGR tuple.
|
|
viz(bool): whether to use :func:`interactive_imshow` to visualize the results.
|
|
lclick_cb: A callback function ``f(patch, patch index in patch_list)``
|
|
to get called when a patch get clicked in imshow.
|
|
|
|
Returns:
|
|
np.ndarray: the stacked image.
|
|
"""
|
|
if pad:
|
|
patch_list = _pad_patch_list(patch_list, bgcolor)
|
|
patch_list = _preprocess_patch_list(patch_list)
|
|
|
|
if lclick_cb is not None:
|
|
viz = True
|
|
ph, pw = patch_list.shape[1:3]
|
|
|
|
canvas = Canvas(ph, pw, nr_row, nr_col,
|
|
patch_list.shape[-1], border, bgcolor)
|
|
|
|
if lclick_cb is not None:
|
|
def lclick_callback(img, x, y):
|
|
idx = canvas.get_patchid_from_coord(x, y)
|
|
lclick_cb(patch_list[idx], idx)
|
|
else:
|
|
lclick_callback = None
|
|
|
|
canvas.draw_patches(patch_list)
|
|
if viz:
|
|
interactive_imshow(canvas.canvas, lclick_cb=lclick_callback)
|
|
return canvas.canvas
|
|
|
|
|
|
def gen_stack_patches(patch_list,
|
|
nr_row=None, nr_col=None, border=None,
|
|
max_width=1000, max_height=1000,
|
|
bgcolor=255, viz=False, lclick_cb=None):
|
|
"""
|
|
Similar to :func:`stack_patches` but with a generator interface.
|
|
It takes a much-longer list and yields stacked results one by one.
|
|
For example, if ``patch_list`` contains 1000 images and ``nr_row==nr_col==10``,
|
|
this generator yields 10 stacked images.
|
|
|
|
Args:
|
|
nr_row(int), nr_col(int): rows and cols of each result.
|
|
max_width(int), max_height(int): Maximum allowed size of the
|
|
stacked image. If ``nr_row/nr_col`` are None, this number
|
|
will be used to infer the rows and cols. Otherwise the option is
|
|
ignored.
|
|
patch_list, border, viz, lclick_cb: same as in :func:`stack_patches`.
|
|
|
|
Yields:
|
|
np.ndarray: the stacked image.
|
|
"""
|
|
# setup parameters
|
|
patch_list = _preprocess_patch_list(patch_list)
|
|
if lclick_cb is not None:
|
|
viz = True
|
|
ph, pw = patch_list.shape[1:3]
|
|
|
|
if border is None:
|
|
border = int(0.05 * min(ph, pw))
|
|
if nr_row is None:
|
|
nr_row = int(max_height / (ph + border))
|
|
if nr_col is None:
|
|
nr_col = int(max_width / (pw + border))
|
|
canvas = Canvas(ph, pw, nr_row, nr_col, patch_list.shape[-1], border, bgcolor)
|
|
|
|
nr_patch = nr_row * nr_col
|
|
start = 0
|
|
|
|
if lclick_cb is not None:
|
|
def lclick_callback(img, x, y):
|
|
idx = canvas.get_patchid_from_coord(x, y)
|
|
idx = idx + start
|
|
if idx < end:
|
|
lclick_cb(patch_list[idx], idx)
|
|
else:
|
|
lclick_callback = None
|
|
|
|
while True:
|
|
end = start + nr_patch
|
|
cur_list = patch_list[start:end]
|
|
if not len(cur_list):
|
|
return
|
|
canvas.draw_patches(cur_list)
|
|
if viz:
|
|
interactive_imshow(canvas.canvas, lclick_cb=lclick_callback)
|
|
yield canvas.canvas
|
|
start = end
|
|
|
|
|
|
def dump_dataflow_images(df, index=0, batched=True,
|
|
number=1000, output_dir=None,
|
|
scale=1, resize=None, viz=None,
|
|
flipRGB=False):
|
|
"""
|
|
Dump or visualize images of a :class:`DataFlow`.
|
|
|
|
Args:
|
|
df (DataFlow): the DataFlow.
|
|
index (int): the index of the image component.
|
|
batched (bool): whether the component contains batched images (NHW or
|
|
NHWC) or not (HW or HWC).
|
|
number (int): how many datapoint to take from the DataFlow.
|
|
output_dir (str): output directory to save images, default to not save.
|
|
scale (float): scale the value, usually either 1 or 255.
|
|
resize (tuple or None): tuple of (h, w) to resize the images to.
|
|
viz (tuple or None): tuple of (h, w) determining the grid size to use
|
|
with :func:`gen_stack_patches` for visualization. No visualization will happen by
|
|
default.
|
|
flipRGB (bool): apply a RGB<->BGR conversion or not.
|
|
"""
|
|
if output_dir:
|
|
mkdir_p(output_dir)
|
|
if viz is not None:
|
|
viz = shape2d(viz)
|
|
vizsize = viz[0] * viz[1]
|
|
if resize is not None:
|
|
resize = tuple(shape2d(resize))
|
|
vizlist = []
|
|
|
|
df.reset_state()
|
|
cnt = 0
|
|
while True:
|
|
for dp in df:
|
|
if not batched:
|
|
imgbatch = [dp[index]]
|
|
else:
|
|
imgbatch = dp[index]
|
|
for img in imgbatch:
|
|
cnt += 1
|
|
if cnt == number:
|
|
return
|
|
if scale != 1:
|
|
img = img * scale
|
|
if resize is not None:
|
|
img = cv2.resize(img, resize)
|
|
if flipRGB:
|
|
img = img[:, :, ::-1]
|
|
if output_dir:
|
|
fname = os.path.join(output_dir, '{:03d}.jpg'.format(cnt))
|
|
cv2.imwrite(fname, img)
|
|
if viz is not None:
|
|
vizlist.append(img)
|
|
if viz is not None and len(vizlist) >= vizsize:
|
|
stack_patches(
|
|
vizlist[:vizsize],
|
|
nr_row=viz[0], nr_col=viz[1], viz=True)
|
|
vizlist = vizlist[vizsize:]
|
|
|
|
|
|
def intensity_to_rgb(intensity, cmap='cubehelix', normalize=False):
|
|
"""
|
|
Convert a 1-channel matrix of intensities to an RGB image employing a colormap.
|
|
This function requires matplotlib. See `matplotlib colormaps
|
|
<http://matplotlib.org/examples/color/colormaps_reference.html>`_ for a
|
|
list of available colormap.
|
|
|
|
Args:
|
|
intensity (np.ndarray): array of intensities such as saliency.
|
|
cmap (str): name of the colormap to use.
|
|
normalize (bool): if True, will normalize the intensity so that it has
|
|
minimum 0 and maximum 1.
|
|
|
|
Returns:
|
|
np.ndarray: an RGB float32 image in range [0, 255], a colored heatmap.
|
|
"""
|
|
assert intensity.ndim == 2, intensity.shape
|
|
intensity = intensity.astype("float")
|
|
|
|
if normalize:
|
|
intensity -= intensity.min()
|
|
intensity /= intensity.max()
|
|
|
|
cmap = plt.get_cmap(cmap)
|
|
intensity = cmap(intensity)[..., :3]
|
|
return intensity.astype('float32') * 255.0
|
|
|
|
|
|
def draw_text(img, pos, text, color, font_scale=0.4):
|
|
"""
|
|
Draw text on an image.
|
|
|
|
Args:
|
|
pos (tuple): x, y; the position of the text
|
|
text (str):
|
|
font_scale (float):
|
|
color (tuple): a 3-tuple BGR color in [0, 255]
|
|
"""
|
|
img = img.astype(np.uint8)
|
|
x0, y0 = int(pos[0]), int(pos[1])
|
|
# Compute text size.
|
|
font = cv2.FONT_HERSHEY_SIMPLEX
|
|
((text_w, text_h), _) = cv2.getTextSize(text, font, font_scale, 1)
|
|
# Place text background.
|
|
if x0 + text_w > img.shape[1]:
|
|
x0 = img.shape[1] - text_w
|
|
if y0 - int(1.15 * text_h) < 0:
|
|
y0 = int(1.15 * text_h)
|
|
back_topleft = x0, y0 - int(1.3 * text_h)
|
|
back_bottomright = x0 + text_w, y0
|
|
cv2.rectangle(img, back_topleft, back_bottomright, color, -1)
|
|
# Show text.
|
|
text_bottomleft = x0, y0 - int(0.25 * text_h)
|
|
cv2.putText(img, text, text_bottomleft, font, font_scale, (222, 222, 222), lineType=cv2.LINE_AA)
|
|
return img
|
|
|
|
|
|
def draw_boxes(im, boxes, labels=None, color=None):
|
|
"""
|
|
Args:
|
|
im (np.ndarray): a BGR image in range [0,255]. It will not be modified.
|
|
boxes (np.ndarray): a numpy array of shape Nx4 where each row is [x1, y1, x2, y2].
|
|
labels: (list[str] or None)
|
|
color: a 3-tuple BGR color (in range [0, 255])
|
|
|
|
Returns:
|
|
np.ndarray: a new image.
|
|
"""
|
|
boxes = np.asarray(boxes, dtype='int32')
|
|
if labels is not None:
|
|
assert len(labels) == len(boxes), "{} != {}".format(len(labels), len(boxes))
|
|
areas = (boxes[:, 2] - boxes[:, 0] + 1) * (boxes[:, 3] - boxes[:, 1] + 1)
|
|
sorted_inds = np.argsort(-areas) # draw large ones first
|
|
assert areas.min() > 0, areas.min()
|
|
# allow equal, because we are not very strict about rounding error here
|
|
assert boxes[:, 0].min() >= 0 and boxes[:, 1].min() >= 0 \
|
|
and boxes[:, 2].max() <= im.shape[1] and boxes[:, 3].max() <= im.shape[0], \
|
|
"Image shape: {}\n Boxes:\n{}".format(str(im.shape), str(boxes))
|
|
|
|
im = im.copy()
|
|
if color is None:
|
|
color = (15, 128, 15)
|
|
if im.ndim == 2 or (im.ndim == 3 and im.shape[2] == 1):
|
|
im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)
|
|
for i in sorted_inds:
|
|
box = boxes[i, :]
|
|
if labels is not None:
|
|
im = draw_text(im, (box[0], box[1]), labels[i], color=color)
|
|
cv2.rectangle(im, (box[0], box[1]), (box[2], box[3]),
|
|
color=color, thickness=1)
|
|
return im
|
|
|
|
|
|
try:
|
|
import matplotlib.pyplot as plt
|
|
except (ImportError, RuntimeError):
|
|
intensity_to_rgb = create_dummy_func('intensity_to_rgb', 'matplotlib') # noqa
|
|
|
|
if __name__ == '__main__':
|
|
if False:
|
|
imglist = []
|
|
for i in range(100):
|
|
fname = "{:03d}.png".format(i)
|
|
imglist.append(cv2.imread(fname))
|
|
for idx, patch in enumerate(gen_stack_patches(
|
|
imglist, max_width=500, max_height=200)):
|
|
of = "patch{:02d}.png".format(idx)
|
|
cv2.imwrite(of, patch)
|
|
if False:
|
|
imglist = []
|
|
img = cv2.imread('out.png')
|
|
img2 = cv2.resize(img, (300, 300))
|
|
viz = stack_patches([img, img2], 1, 2, pad=True, viz=True)
|
|
|
|
if False:
|
|
img = cv2.imread('cat.jpg')
|
|
boxes = np.asarray([
|
|
[10, 30, 200, 100],
|
|
[20, 80, 250, 250]
|
|
])
|
|
img = draw_boxes(img, boxes, ['asdfasdf', '11111111111111'])
|
|
interactive_imshow(img)
|