diff --git a/tensorwatch/mpl/base_mpl_plot.py b/tensorwatch/mpl/base_mpl_plot.py index 36d3aef..fa424b9 100644 --- a/tensorwatch/mpl/base_mpl_plot.py +++ b/tensorwatch/mpl/base_mpl_plot.py @@ -1,169 +1,173 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -#from IPython import get_ipython, display -#if get_ipython(): -# get_ipython().magic('matplotlib notebook') - -#import matplotlib -#if os.name == 'posix' and "DISPLAY" not in os.environ: -# matplotlib.use('Agg') # Must be before importing matplotlib.pyplot or pylab! - -#from ipywidgets.widgets.interaction import show_inline_matplotlib_plots -#from ipykernel.pylab.backend_inline import flush_figures - -from ..vis_base import VisBase - -import sys, logging -from abc import abstractmethod -from .. import utils - - -class BaseMplPlot(VisBase): - def __init__(self, cell:VisBase.widgets.Box=None, title:str=None, show_legend:bool=None, is_3d:bool=False, - stream_name:str=None, console_debug:bool=False, **vis_args): - super(BaseMplPlot, self).__init__(VisBase.widgets.Output(), cell, title, show_legend, - stream_name=stream_name, console_debug=console_debug, **vis_args) - - self._fig_init_done = False - self.show_legend = show_legend - self.is_3d = is_3d - if is_3d: - # this is needed for some reason - from mpl_toolkits.mplot3d import Axes3D - # graph objects - self.figure = None - self._ax_main = None - # matplotlib animation - self.animation = None - self.anim_interval = None - #print(matplotlib.get_backend()) - #display.display(self.cell) - - # anim_interval in seconds - def init_fig(self, anim_interval:float=1.0): - import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue - - """(for derived class) Initializes matplotlib figure""" - if self._fig_init_done: - return False - - # create figure and animation - self.figure = plt.figure(figsize=(8, 3)) - self.anim_interval = anim_interval - - # default color pallet - import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue - - plt.set_cmap('Dark2') - plt.rcParams['image.cmap']='Dark2' - - self._fig_init_done = True - return True - - def get_main_axis(self): - # if we don't yet have main axis, create one - if not self._ax_main: - # by default assign one subplot to whole graph - self._ax_main = self.figure.add_subplot(111, - projection=None if not self.is_3d else '3d') - self._ax_main.grid(self.is_show_grid()) - # change the color of the top and right spines to opaque gray - self._ax_main.spines['right'].set_color((.8,.8,.8)) - self._ax_main.spines['top'].set_color((.8,.8,.8)) - if self.title is not None: - title = self._ax_main.set_title(self.title) - title.set_weight('bold') - return self._ax_main - - # overridable - def is_show_grid(self): - return True - - def _on_update(self, frame): # pylint: disable=unused-argument - try: - self._update_stream_plots() - except Exception as ex: - # when exception occurs here, animation will stop and there - # will be no further plot updates - # TODO: may be we don't need all of below but none of them - # are popping up exception in Jupyter Notebook because these - # exceptions occur in background? - self.last_ex = ex - logging.exception('Exception in matplotlib update loop') - - - def show(self, blocking=False): - if not self.is_shown and self.anim_interval: - from matplotlib.animation import FuncAnimation # function-level import as this one is expensive - self.animation = FuncAnimation(self.figure, self._on_update, interval=self.anim_interval*1000.0) - super(BaseMplPlot, self).show(blocking) - - def _post_update_stream_plot(self, stream_vis): - import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue - - utils.debug_log("Plot updated", stream_vis.stream.stream_name, verbosity=5) - - if self.layout_dirty: - # do not do tight_layout() call on every update - # that would jumble up the graphs! it should only called - # once each time there is change in layout - self.figure.tight_layout() - self.layout_dirty = False - - # below forces redraw and it was helpful to - # repaint even if there was error in interval loop - # but it does work in native UX and not in Jupyter Notebook - #self.figure.canvas.draw() - #self.figure.canvas.flush_events() - - if self._use_hbox and VisBase.get_ipython(): - self.widget.clear_output(wait=True) - with self.widget: - plt.show(self.figure) - - # everything else that doesn't work - #self.figure.show() - #display.clear_output(wait=True) - #display.display(self.figure) - #flush_figures() - #plt.show() - #show_inline_matplotlib_plots() - #elif not get_ipython(): - # self.figure.canvas.draw() - - def _post_add_subscription(self, stream_vis, **stream_vis_args): - import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue - - # make sure figure is initialized - self.init_fig() - self.init_stream_plot(stream_vis, **stream_vis_args) - - # redo the legend - #self.figure.legend(loc='center right', bbox_to_anchor=(1.5, 0.5)) - if self.show_legend: - self.figure.legend(loc='lower right') - plt.subplots_adjust(hspace=0.6) - - def _show_widget_native(self, blocking:bool): - import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue - - #plt.ion() - #plt.show() - return plt.show(block=blocking) - - def _show_widget_notebook(self): - # no need to return anything because %matplotlib notebook will - # detect spawning of figure and paint it - # if self.figure is returned then you will see two of them - return None - #plt.show() - #return self.figure - - def _can_update_stream_plots(self): - return False # we run interval timer which will flush the key - - @abstractmethod - def init_stream_plot(self, stream_vis, **stream_vis_args): - """(for derived class) Create new plot info for this stream""" - pass +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +#from IPython import get_ipython, display +#if get_ipython(): +# get_ipython().magic('matplotlib notebook') + +#import matplotlib +#if os.name == 'posix' and "DISPLAY" not in os.environ: +# matplotlib.use('Agg') # Must be before importing matplotlib.pyplot or pylab! + +#from ipywidgets.widgets.interaction import show_inline_matplotlib_plots +#from ipykernel.pylab.backend_inline import flush_figures + +from ..vis_base import VisBase + +import sys, logging +from abc import abstractmethod +from .. import utils + + +class BaseMplPlot(VisBase): + def __init__(self, cell:VisBase.widgets.Box=None, title:str=None, show_legend:bool=None, is_3d:bool=False, + stream_name:str=None, console_debug:bool=False, **vis_args): + super(BaseMplPlot, self).__init__(VisBase.widgets.Output(), cell, title, show_legend, + stream_name=stream_name, console_debug=console_debug, **vis_args) + + self._fig_init_done = False + self.show_legend = show_legend + self.is_3d = is_3d + if is_3d: + # this is needed for some reason + from mpl_toolkits.mplot3d import Axes3D + # graph objects + self.figure = None + self._ax_main = None + # matplotlib animation + self.animation = None + self.anim_interval = None + #print(matplotlib.get_backend()) + #display.display(self.cell) + + # anim_interval in seconds + def init_fig(self, anim_interval:float=1.0): + import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue + + """(for derived class) Initializes matplotlib figure""" + if self._fig_init_done: + return False + + # create figure and animation + self.figure = plt.figure(figsize=(8, 3)) + self.anim_interval = anim_interval + + # default color pallet + import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue + + plt.set_cmap('Dark2') + plt.rcParams['image.cmap']='Dark2' + + self._fig_init_done = True + return True + + def get_main_axis(self): + # if we don't yet have main axis, create one + if not self._ax_main: + # by default assign one subplot to whole graph + self._ax_main = self.figure.add_subplot(111, + projection=None if not self.is_3d else '3d') + self._ax_main.grid(self.is_show_grid()) + # change the color of the top and right spines to opaque gray + self._ax_main.spines['right'].set_color((.8,.8,.8)) + self._ax_main.spines['top'].set_color((.8,.8,.8)) + if self.title is not None: + title = self._ax_main.set_title(self.title) + title.set_weight('bold') + return self._ax_main + + # overridable + def is_show_grid(self): + return True + + def _on_update(self, frame): # pylint: disable=unused-argument + try: + self._update_stream_plots() + except Exception as ex: + # when exception occurs here, animation will stop and there + # will be no further plot updates + # TODO: may be we don't need all of below but none of them + # are popping up exception in Jupyter Notebook because these + # exceptions occur in background? + self.last_ex = ex + logging.exception('Exception in matplotlib update loop') + + + def show(self, blocking=False): + if not self.is_shown and self.anim_interval: + from matplotlib.animation import FuncAnimation # function-level import as this one is expensive + self.animation = FuncAnimation(self.figure, self._on_update, interval=self.anim_interval*1000.0) + super(BaseMplPlot, self).show(blocking) + + def _post_update_stream_plot(self, stream_vis): + import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue + + utils.debug_log("Plot updated", stream_vis.stream.stream_name, verbosity=5) + + if self.layout_dirty: + # do not do tight_layout() call on every update + # that would jumble up the graphs! it should only called + # once each time there is change in layout + self.figure.tight_layout() + self.layout_dirty = False + + # below forces redraw and it was helpful to + # repaint even if there was error in interval loop + # but it does work in native UX and not in Jupyter Notebook + #self.figure.canvas.draw() + #self.figure.canvas.flush_events() + + if self._use_hbox and VisBase.get_ipython(): + self.widget.clear_output(wait=True) + with self.widget: + plt.show(self.figure) + + # everything else that doesn't work + #self.figure.show() + #display.clear_output(wait=True) + #display.display(self.figure) + #flush_figures() + #plt.show() + #show_inline_matplotlib_plots() + #elif not get_ipython(): + # self.figure.canvas.draw() + + def _post_add_subscription(self, stream_vis, **stream_vis_args): + import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue + + # make sure figure is initialized + self.init_fig() + self.init_stream_plot(stream_vis, **stream_vis_args) + + # redo the legend + #self.figure.legend(loc='center right', bbox_to_anchor=(1.5, 0.5)) + if self.show_legend: + self.figure.legend(loc='lower right') + plt.subplots_adjust(hspace=0.6) + + def _show_widget_native(self, blocking:bool): + import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue + + #plt.ion() + #plt.show() + return plt.show(block=blocking) + + def _show_widget_notebook(self): + # no need to return anything because %matplotlib notebook will + # detect spawning of figure and paint it + # if self.figure is returned then you will see two of them + return None + #plt.show() + #return self.figure + + def _can_update_stream_plots(self): + return False # we run interval timer which will flush the key + + @abstractmethod + def init_stream_plot(self, stream_vis, **stream_vis_args): + """(for derived class) Create new plot info for this stream""" + pass + + def _save_widget(self, filepath:str)->None: + self._update_stream_plots() + self.figure.savefig(filepath) \ No newline at end of file diff --git a/tensorwatch/vis_base.py b/tensorwatch/vis_base.py index 9139128..9a32805 100644 --- a/tensorwatch/vis_base.py +++ b/tensorwatch/vis_base.py @@ -1,180 +1,185 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -import sys, time, threading, queue, functools -from typing import Any -from types import MethodType -from abc import ABCMeta, abstractmethod - -from .lv_types import StreamVisInfo, StreamItem -from . import utils -from .stream import Stream - - -class VisBase(Stream, metaclass=ABCMeta): - # these are expensive import so we attach to base class so derived class can use them - from IPython import get_ipython, display - import ipywidgets as widgets - - def __init__(self, widget, cell:widgets.Box, title:str, show_legend:bool, stream_name:str=None, console_debug:bool=False, **vis_args): - super(VisBase, self).__init__(stream_name=stream_name, console_debug=console_debug) - - self.lock = threading.Lock() - self._use_hbox = True - #utils.set_default(vis_args, 'cell_width', '100%') - - self.widget = widget - - self.cell = cell or VisBase.widgets.HBox(layout=VisBase.widgets.Layout(\ - width=vis_args.get('cell_width', None))) if self._use_hbox else None - if self._use_hbox: - self.cell.children += (self.widget,) - self._stream_vises = {} - self.is_shown = cell is not None - self.title = title - self.last_ex = None - self.layout_dirty = False - self.q_last_processed = 0 - - def subscribe(self, stream:Stream, title=None, clear_after_end=False, clear_after_each=False, - show:bool=False, history_len=1, dim_history=True, opacity=None, **stream_vis_args): - # in this ovedrride we don't call base class method - with self.lock: - self.layout_dirty = True - - stream_vis = StreamVisInfo(stream, title, clear_after_end, - clear_after_each, history_len, dim_history, opacity, - len(self._stream_vises), stream_vis_args, 0) - stream_vis._clear_pending = False - stream_vis._pending_items = queue.Queue() - self._stream_vises[stream.stream_name] = stream_vis - - self._post_add_subscription(stream_vis, **stream_vis_args) - - super(VisBase, self).subscribe(stream) - - if show or (show is None and not self.is_shown): - return self.show() - - def show(self, blocking:bool=False): - self.is_shown = True - if VisBase.get_ipython(): - if self._use_hbox: - VisBase.display.display(self.cell) # this method doesn't need returns - #return self.cell - else: - return self._show_widget_notebook() - else: - return self._show_widget_native(blocking) - - def write(self, val:Any, from_stream:'Stream'=None): - stream_item = self.to_stream_item(val) - - stream_vis:StreamVisInfo = None - if from_stream: - stream_vis = self._stream_vises.get(from_stream.stream_name, None) - - if not stream_vis: # select the first one we have - stream_vis = next(iter(self._stream_vises.values())) - - VisBase.write_stream_plot(self, stream_vis, stream_item) - - super(VisBase, self).write(stream_item) - - - @staticmethod - def write_stream_plot(vis, stream_vis:StreamVisInfo, stream_item:StreamItem): - with vis.lock: # this could be from separate thread! - #if stream_vis is None: - # utils.debug_log('stream_vis not specified in VisBase.write') - # stream_vis = next(iter(vis._stream_vises.values())) # use first as default - utils.debug_log("Stream received: {}".format(stream_item.stream_name), verbosity=5) - stream_vis._pending_items.put(stream_item) - - # if we accumulated enough of pending items then let's process them - if vis._can_update_stream_plots(): - vis._update_stream_plots() - - def _extract_results(self, stream_vis): - stream_items, clear_current, clear_history = [], False, False - while not stream_vis._pending_items.empty(): - stream_item = stream_vis._pending_items.get() - if stream_item.stream_reset: - utils.debug_log("Stream reset", stream_item.stream_name) - stream_items.clear() # no need to process these events - clear_current, clear_history = True, True - else: - # check if there was an exception - if stream_item.exception is not None: - #TODO: need better handling here? - print(stream_item.exception, file=sys.stderr) - raise stream_item.exception - - # state management for _clear_pending - # if we need to clear plot before putting in data, do so - if stream_vis._clear_pending: - stream_items.clear() - clear_current = True - stream_vis._clear_pending = False - if stream_vis.clear_after_each or (stream_item.ended and stream_vis.clear_after_end): - stream_vis._clear_pending = True - - stream_items.append(stream_item) - - return stream_items, clear_current, clear_history - - def _extract_vals(self, stream_items): - vals = [] - for stream_item in stream_items: - if stream_item.ended or stream_item.value is None: - pass # no values to add - else: - if utils.is_array_like(stream_item.value, tuple_is_array=False): - vals.extend(stream_item.value) - else: - vals.append(stream_item.value) - return vals - - @abstractmethod - def clear_plot(self, stream_vis, clear_history): - """(for derived class) Clears the data in specified plot before new data is redrawn""" - pass - @abstractmethod - def _show_stream_items(self, stream_vis, stream_items): - """Paint the given stream_items in to visualizer. If visualizer is dirty then return False else True. - """ - - pass - @abstractmethod - def _post_add_subscription(self, stream_vis, **stream_vis_args): - pass - - # typically we want to batch up items for performance - def _can_update_stream_plots(self): - return True - - @abstractmethod - def _post_update_stream_plot(self, stream_vis): - pass - - def _update_stream_plots(self): - with self.lock: - self.q_last_processed = time.time() - for stream_vis in self._stream_vises.values(): - stream_items, clear_current, clear_history = self._extract_results(stream_vis) - - if clear_current: - self.clear_plot(stream_vis, clear_history) - - # if we have something to render - dirty = not self._show_stream_items(stream_vis, stream_items) - if dirty: - self._post_update_stream_plot(stream_vis) - stream_vis.last_update = time.time() - - @abstractmethod - def _show_widget_native(self, blocking:bool): - pass - @abstractmethod - def _show_widget_notebook(self): - pass \ No newline at end of file +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import sys, time, threading, queue, functools +from typing import Any +from types import MethodType +from abc import ABCMeta, abstractmethod + +from .lv_types import StreamVisInfo, StreamItem +from . import utils +from .stream import Stream + + +class VisBase(Stream, metaclass=ABCMeta): + # these are expensive import so we attach to base class so derived class can use them + from IPython import get_ipython, display + import ipywidgets as widgets + + def __init__(self, widget, cell:widgets.Box, title:str, show_legend:bool, stream_name:str=None, console_debug:bool=False, **vis_args): + super(VisBase, self).__init__(stream_name=stream_name, console_debug=console_debug) + + self.lock = threading.Lock() + self._use_hbox = True + #utils.set_default(vis_args, 'cell_width', '100%') + + self.widget = widget + + self.cell = cell or VisBase.widgets.HBox(layout=VisBase.widgets.Layout(\ + width=vis_args.get('cell_width', None))) if self._use_hbox else None + if self._use_hbox: + self.cell.children += (self.widget,) + self._stream_vises = {} + self.is_shown = cell is not None + self.title = title + self.last_ex = None + self.layout_dirty = False + self.q_last_processed = 0 + + def subscribe(self, stream:Stream, title=None, clear_after_end=False, clear_after_each=False, + show:bool=False, history_len=1, dim_history=True, opacity=None, **stream_vis_args): + # in this ovedrride we don't call base class method + with self.lock: + self.layout_dirty = True + + stream_vis = StreamVisInfo(stream, title, clear_after_end, + clear_after_each, history_len, dim_history, opacity, + len(self._stream_vises), stream_vis_args, 0) + stream_vis._clear_pending = False + stream_vis._pending_items = queue.Queue() + self._stream_vises[stream.stream_name] = stream_vis + + self._post_add_subscription(stream_vis, **stream_vis_args) + + super(VisBase, self).subscribe(stream) + + if show or (show is None and not self.is_shown): + return self.show() + + def show(self, blocking:bool=False): + self.is_shown = True + if VisBase.get_ipython(): + if self._use_hbox: + VisBase.display.display(self.cell) # this method doesn't need returns + #return self.cell + else: + return self._show_widget_notebook() + else: + return self._show_widget_native(blocking) + + def save(self, filepath:str)->None: + self._save_widget(filepath) + + def write(self, val:Any, from_stream:'Stream'=None): + stream_item = self.to_stream_item(val) + + stream_vis:StreamVisInfo = None + if from_stream: + stream_vis = self._stream_vises.get(from_stream.stream_name, None) + + if not stream_vis: # select the first one we have + stream_vis = next(iter(self._stream_vises.values())) + + VisBase.write_stream_plot(self, stream_vis, stream_item) + + super(VisBase, self).write(stream_item) + + + @staticmethod + def write_stream_plot(vis, stream_vis:StreamVisInfo, stream_item:StreamItem): + with vis.lock: # this could be from separate thread! + #if stream_vis is None: + # utils.debug_log('stream_vis not specified in VisBase.write') + # stream_vis = next(iter(vis._stream_vises.values())) # use first as default + utils.debug_log("Stream received: {}".format(stream_item.stream_name), verbosity=5) + stream_vis._pending_items.put(stream_item) + + # if we accumulated enough of pending items then let's process them + if vis._can_update_stream_plots(): + vis._update_stream_plots() + + def _extract_results(self, stream_vis): + stream_items, clear_current, clear_history = [], False, False + while not stream_vis._pending_items.empty(): + stream_item = stream_vis._pending_items.get() + if stream_item.stream_reset: + utils.debug_log("Stream reset", stream_item.stream_name) + stream_items.clear() # no need to process these events + clear_current, clear_history = True, True + else: + # check if there was an exception + if stream_item.exception is not None: + #TODO: need better handling here? + print(stream_item.exception, file=sys.stderr) + raise stream_item.exception + + # state management for _clear_pending + # if we need to clear plot before putting in data, do so + if stream_vis._clear_pending: + stream_items.clear() + clear_current = True + stream_vis._clear_pending = False + if stream_vis.clear_after_each or (stream_item.ended and stream_vis.clear_after_end): + stream_vis._clear_pending = True + + stream_items.append(stream_item) + + return stream_items, clear_current, clear_history + + def _extract_vals(self, stream_items): + vals = [] + for stream_item in stream_items: + if stream_item.ended or stream_item.value is None: + pass # no values to add + else: + if utils.is_array_like(stream_item.value, tuple_is_array=False): + vals.extend(stream_item.value) + else: + vals.append(stream_item.value) + return vals + + @abstractmethod + def clear_plot(self, stream_vis, clear_history): + """(for derived class) Clears the data in specified plot before new data is redrawn""" + pass + @abstractmethod + def _show_stream_items(self, stream_vis, stream_items): + """Paint the given stream_items in to visualizer. If visualizer is dirty then return False else True. + """ + + pass + @abstractmethod + def _post_add_subscription(self, stream_vis, **stream_vis_args): + pass + + # typically we want to batch up items for performance + def _can_update_stream_plots(self): + return True + + @abstractmethod + def _post_update_stream_plot(self, stream_vis): + pass + + def _update_stream_plots(self): + with self.lock: + self.q_last_processed = time.time() + for stream_vis in self._stream_vises.values(): + stream_items, clear_current, clear_history = self._extract_results(stream_vis) + + if clear_current: + self.clear_plot(stream_vis, clear_history) + + # if we have something to render + dirty = not self._show_stream_items(stream_vis, stream_items) + if dirty: + self._post_update_stream_plot(stream_vis) + stream_vis.last_update = time.time() + + @abstractmethod + def _show_widget_native(self, blocking:bool): + pass + @abstractmethod + def _show_widget_notebook(self): + pass + def _save_widget(self, filepath:str)->None: + raise NotImplementedError('Save functionality is not implemented') \ No newline at end of file diff --git a/tensorwatch/visualizer.py b/tensorwatch/visualizer.py index 075f045..e6066fd 100644 --- a/tensorwatch/visualizer.py +++ b/tensorwatch/visualizer.py @@ -1,88 +1,91 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -from .stream import Stream -from .vis_base import VisBase -from . import mpl -from . import plotly - -class Visualizer: - """Constructs visualizer for specified vis_type. - - NOTE: If you modify arguments here then also sync VisArgs contructor. - """ - def __init__(self, stream:Stream, vis_type:str=None, host:'Visualizer'=None, - cell:'Visualizer'=None, title:str=None, - clear_after_end=False, clear_after_each=False, history_len=1, dim_history=True, opacity=None, - - rows=2, cols=5, img_width=None, img_height=None, img_channels=None, - colormap=None, viz_img_scale=None, - - # these image params are for hover on point for t-sne - hover_images=None, hover_image_reshape=None, cell_width:str=None, cell_height:str=None, - - only_summary=False, separate_yaxis=True, xtitle=None, ytitle=None, ztitle=None, color=None, - xrange=None, yrange=None, zrange=None, draw_line=True, draw_marker=False, - - # histogram - bins=None, normed=None, histtype='bar', edge_color=None, linewidth=None, bar_width=None, - - # pie chart - autopct=None, shadow=None, - - vis_args={}, stream_vis_args={})->None: - - cell = cell._host_base.cell if cell is not None else None - - if host: - self._host_base = host._host_base - else: - self._host_base = self._get_vis_base(vis_type, cell, title, hover_images=hover_images, hover_image_reshape=hover_image_reshape, - cell_width=cell_width, cell_height=cell_height, - **vis_args) - - self._host_base.subscribe(stream, show=False, clear_after_end=clear_after_end, clear_after_each=clear_after_each, - history_len=history_len, dim_history=dim_history, opacity=opacity, - only_summary=only_summary if vis_type is None or 'summary' != vis_type else True, - separate_yaxis=separate_yaxis, xtitle=xtitle, ytitle=ytitle, ztitle=ztitle, color=color, - xrange=xrange, yrange=yrange, zrange=zrange, - draw_line=draw_line if vis_type is not None and 'scatter' in vis_type else True, - draw_marker=draw_marker, - rows=rows, cols=cols, img_width=img_width, img_height=img_height, img_channels=img_channels, - colormap=colormap, viz_img_scale=viz_img_scale, - bins=bins, normed=normed, histtype=histtype, edge_color=edge_color, linewidth=linewidth, bar_width = bar_width, - autopct=autopct, shadow=shadow, - **stream_vis_args) - - stream.load() - - def show(self): - return self._host_base.show() - - def _get_vis_base(self, vis_type, cell:VisBase.widgets.Box, title, hover_images=None, hover_image_reshape=None, cell_width=None, cell_height=None, **vis_args)->VisBase: - if vis_type is None or vis_type in ['line', - 'mpl-line', 'mpl-line3d', 'mpl-scatter3d', 'mpl-scatter']: - return mpl.line_plot.LinePlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, - is_3d=vis_type is not None and vis_type.endswith('3d'), **vis_args) - if vis_type in ['image', 'mpl-image']: - return mpl.image_plot.ImagePlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, **vis_args) - if vis_type in ['bar', 'bar3d']: - return mpl.bar_plot.BarPlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, - is_3d=vis_type.endswith('3d'), **vis_args) - if vis_type in ['histogram']: - return mpl.histogram.Histogram(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, **vis_args) - if vis_type in ['pie']: - return mpl.pie_chart.PieChart(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, **vis_args) - if vis_type in ['text', 'summary']: - from .text_vis import TextVis - return TextVis(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, **vis_args) - if vis_type in ['line3d', 'scatter', 'scatter3d', - 'plotly-line', 'plotly-line3d', 'plotly-scatter', 'plotly-scatter3d', 'mesh3d']: - return plotly.line_plot.LinePlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, - is_3d=vis_type.endswith('3d'), **vis_args) - if vis_type in ['tsne', 'embeddings', 'tsne2d', 'embeddings2d']: - return plotly.embeddings_plot.EmbeddingsPlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, - is_3d='2d' not in vis_type, - hover_images=hover_images, hover_image_reshape=hover_image_reshape, **vis_args) - else: - raise ValueError('Render vis_type parameter has invalid value: "{}"'.format(vis_type)) +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from .stream import Stream +from .vis_base import VisBase +from . import mpl +from . import plotly + +class Visualizer: + """Constructs visualizer for specified vis_type. + + NOTE: If you modify arguments here then also sync VisArgs contructor. + """ + def __init__(self, stream:Stream, vis_type:str=None, host:'Visualizer'=None, + cell:'Visualizer'=None, title:str=None, + clear_after_end=False, clear_after_each=False, history_len=1, dim_history=True, opacity=None, + + rows=2, cols=5, img_width=None, img_height=None, img_channels=None, + colormap=None, viz_img_scale=None, + + # these image params are for hover on point for t-sne + hover_images=None, hover_image_reshape=None, cell_width:str=None, cell_height:str=None, + + only_summary=False, separate_yaxis=True, xtitle=None, ytitle=None, ztitle=None, color=None, + xrange=None, yrange=None, zrange=None, draw_line=True, draw_marker=False, + + # histogram + bins=None, normed=None, histtype='bar', edge_color=None, linewidth=None, bar_width=None, + + # pie chart + autopct=None, shadow=None, + + vis_args={}, stream_vis_args={})->None: + + cell = cell._host_base.cell if cell is not None else None + + if host: + self._host_base = host._host_base + else: + self._host_base = self._get_vis_base(vis_type, cell, title, hover_images=hover_images, hover_image_reshape=hover_image_reshape, + cell_width=cell_width, cell_height=cell_height, + **vis_args) + + self._host_base.subscribe(stream, show=False, clear_after_end=clear_after_end, clear_after_each=clear_after_each, + history_len=history_len, dim_history=dim_history, opacity=opacity, + only_summary=only_summary if vis_type is None or 'summary' != vis_type else True, + separate_yaxis=separate_yaxis, xtitle=xtitle, ytitle=ytitle, ztitle=ztitle, color=color, + xrange=xrange, yrange=yrange, zrange=zrange, + draw_line=draw_line if vis_type is not None and 'scatter' in vis_type else True, + draw_marker=draw_marker, + rows=rows, cols=cols, img_width=img_width, img_height=img_height, img_channels=img_channels, + colormap=colormap, viz_img_scale=viz_img_scale, + bins=bins, normed=normed, histtype=histtype, edge_color=edge_color, linewidth=linewidth, bar_width = bar_width, + autopct=autopct, shadow=shadow, + **stream_vis_args) + + stream.load() + + def show(self): + return self._host_base.show() + + def save(self, filepath:str)->None: + self._host_base.save(filepath) + + def _get_vis_base(self, vis_type, cell:VisBase.widgets.Box, title, hover_images=None, hover_image_reshape=None, cell_width=None, cell_height=None, **vis_args)->VisBase: + if vis_type is None or vis_type in ['line', + 'mpl-line', 'mpl-line3d', 'mpl-scatter3d', 'mpl-scatter']: + return mpl.line_plot.LinePlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, + is_3d=vis_type is not None and vis_type.endswith('3d'), **vis_args) + if vis_type in ['image', 'mpl-image']: + return mpl.image_plot.ImagePlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, **vis_args) + if vis_type in ['bar', 'bar3d']: + return mpl.bar_plot.BarPlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, + is_3d=vis_type.endswith('3d'), **vis_args) + if vis_type in ['histogram']: + return mpl.histogram.Histogram(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, **vis_args) + if vis_type in ['pie']: + return mpl.pie_chart.PieChart(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, **vis_args) + if vis_type in ['text', 'summary']: + from .text_vis import TextVis + return TextVis(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, **vis_args) + if vis_type in ['line3d', 'scatter', 'scatter3d', + 'plotly-line', 'plotly-line3d', 'plotly-scatter', 'plotly-scatter3d', 'mesh3d']: + return plotly.line_plot.LinePlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, + is_3d=vis_type.endswith('3d'), **vis_args) + if vis_type in ['tsne', 'embeddings', 'tsne2d', 'embeddings2d']: + return plotly.embeddings_plot.EmbeddingsPlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, + is_3d='2d' not in vis_type, + hover_images=hover_images, hover_image_reshape=hover_image_reshape, **vis_args) + else: + raise ValueError('Render vis_type parameter has invalid value: "{}"'.format(vis_type)) diff --git a/test/visualizations/arr_mpl_line.py b/test/visualizations/arr_mpl_line.py new file mode 100644 index 0000000..2146816 --- /dev/null +++ b/test/visualizations/arr_mpl_line.py @@ -0,0 +1,7 @@ +import tensorwatch as tw + +stream = tw.ArrayStream([(i, i*i) for i in range(50)]) +img_plot = tw.Visualizer(stream, vis_type='mpl-line', viz_img_scale=3, xtitle='Epochs', ytitle='Gain') +# img_plot.show() +# tw.plt_loop() +img_plot.save(r'c:\temp\fig1.png')