This commit is contained in:
Shital Shah 2019-05-19 22:39:56 -07:00
Родитель fa301c09a3
Коммит f8f1154d88
18 изменённых файлов: 169 добавлений и 26 удалений

Просмотреть файл

@ -28,4 +28,5 @@
* new graph on end
* TF support
* generic visualizer -> Given obj and Box, paint in box
* visualize image and text with attention
* visualize image and text with attention
* add confidence interval for plotly: https://plot.ly/python/continuous-error-bars/

Просмотреть файл

@ -20,6 +20,12 @@
</PropertyGroup>
<ItemGroup>
<Compile Include="tensorwatch\embeddings\__init__.py" />
<Compile Include="tensorwatch\mpl\bar_plot.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="tensorwatch\mpl\histogram.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="tensorwatch\stream_union.py">
<SubType>Code</SubType>
</Compile>

Просмотреть файл

@ -9,7 +9,7 @@ from .watcher_base import WatcherBase
from .text_vis import TextVis
from .plotly import EmbeddingsPlot
from .mpl import LinePlot, ImagePlot
from .mpl import LinePlot, ImagePlot, Histogram
from .visualizer import Visualizer
from .stream import Stream

Просмотреть файл

@ -115,10 +115,12 @@ def linear_to_2d(img, size=None):
def stack_images(imgs):
return np.hstack(imgs)
def plt_loop(sleep_time=1, plt_pause=0.01):
plt.ion()
plt.show(block=False)
while(not plt.waitforbuttonpress(plt_pause)):
def plt_loop(count=None, sleep_time=1, plt_pause=0.01):
#plt.ion()
#plt.show(block=False)
while((count is None or count > 0) and not plt.waitforbuttonpress(plt_pause)):
#plt.draw()
plt.pause(plt_pause)
time.sleep(sleep_time)
time.sleep(sleep_time)
if count is not None:
count = count - 1

Просмотреть файл

@ -6,3 +6,4 @@
# https://github.com/Microsoft/ptvsd/issues/1041
from .line_plot import LinePlot
from .image_plot import ImagePlot
from .histogram import Histogram

Просмотреть файл

@ -0,0 +1 @@

Просмотреть файл

@ -57,7 +57,7 @@ class BaseMplPlot(VisBase):
if not self._ax_main:
# by default assign one subplot to whole graph
self._ax_main = self.figure.add_subplot(111)
self._ax_main.grid(True)
self._ax_main.grid(self.is_show_grid())
# change the color of the top and right spines to opaque gray
self._ax_main.spines['right'].set_color((.8,.8,.8))
self._ax_main.spines['top'].set_color((.8,.8,.8))
@ -66,6 +66,10 @@ class BaseMplPlot(VisBase):
title.set_weight('bold')
return self._ax_main
# overridable
def is_show_grid(self):
return True
def _on_update(self, frame): # pylint: disable=unused-argument
try:
self._update_stream_plots()

Просмотреть файл

@ -0,0 +1,73 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .base_mpl_plot import BaseMplPlot
import matplotlib
import matplotlib.pyplot as plt
from .. import utils
import numpy as np
import ipywidgets as widgets
class Histogram(BaseMplPlot):
def init_stream_plot(self, stream_vis,
xtitle='', ytitle='', color=None,
bins=None, normed=None, histtype='bar', edge_color=None, linewidth=2,
opacity=None, **stream_vis_args):
stream_vis.xylabel_refs = [] # annotation references
# add main subplot
stream_vis.bins, stream_vis.normed, stream_vis.linewidth = bins, normed, linewidth
stream_vis.ax = self.get_main_axis()
stream_vis.data = []
stream_vis.hist_bars = [] # stores previously drawn bars
#TODO: improve color selection
color = color or plt.cm.Dark2((len(self._stream_vises)%8)/8) # pylint: disable=no-member
stream_vis.color = color
stream_vis.edge_color = 'black'
stream_vis.histtype = histtype
stream_vis.opacity = opacity
stream_vis.ax.set_xlabel(xtitle)
stream_vis.ax.set_ylabel(ytitle)
stream_vis.ax.yaxis.label.set_color(color)
stream_vis.ax.yaxis.label.set_style('italic')
stream_vis.ax.xaxis.label.set_style('italic')
def is_show_grid(self): #override
return False
def clear_bars(self, stream_vis):
for bar in stream_vis.hist_bars:
bar.remove()
stream_vis.hist_bars.clear()
def clear_plot(self, stream_vis, clear_history):
stream_vis.data.clear()
self.clear_bars(stream_vis)
def _show_stream_items(self, stream_vis, stream_items):
"""Paint the given stream_items in to visualizer. If visualizer is dirty then return False else True.
"""
vals = self._extract_vals(stream_items)
if not len(vals):
return True
stream_vis.data += vals
self.clear_bars(stream_vis)
n, bins, stream_vis.hist_bars = stream_vis.ax.hist(stream_vis.data, bins=stream_vis.bins,
normed=stream_vis.normed, color=stream_vis.color, edgecolor=stream_vis.edge_color,
histtype=stream_vis.histtype, alpha=stream_vis.opacity,
linewidth=stream_vis.linewidth)
stream_vis.ax.set_xticks(bins)
#stream_vis.ax.relim()
#stream_vis.ax.autoscale_view()
return False

Просмотреть файл

@ -29,6 +29,9 @@ class ImagePlot(BaseMplPlot):
img.set_data(np.zeros((x, y)))
def _show_stream_items(self, stream_vis, stream_items):
"""Paint the given stream_items in to visualizer. If visualizer is dirty then return False else True.
"""
# as we repaint each image plot, select last if multiple events were pending
stream_item = None
for er in reversed(stream_items):
@ -36,7 +39,7 @@ class ImagePlot(BaseMplPlot):
stream_item = er
break
if stream_item is None:
return False
return True
row, col, i = 0, 0, 0
dirty = False
@ -86,7 +89,7 @@ class ImagePlot(BaseMplPlot):
break
i += 1
return dirty
return not dirty
def has_legend(self):

Просмотреть файл

@ -78,9 +78,12 @@ class LinePlot(BaseMplPlot):
stream_vis.xylabel_refs.clear()
def _show_stream_items(self, stream_vis, stream_items):
"""Paint the given stream_items in to visualizer. If visualizer is dirty then return False else True.
"""
vals = self._extract_vals(stream_items)
if not len(vals):
return False
return True # not dirty
line = stream_vis.ax.get_lines()[-1]
xdata, ydata = line.get_data()
@ -147,7 +150,7 @@ class LinePlot(BaseMplPlot):
stream_vis.ax.relim()
stream_vis.ax.autoscale_view()
return True
return False # dirty

Просмотреть файл

@ -98,7 +98,9 @@ class BasePlotlyPlot(VisBase):
pass
@abstractmethod
def _show_stream_items(self, stream_vis, stream_items):
"""(for derived class) Plot the data in given axes"""
"""Paint the given stream_items in to visualizer. If visualizer is dirty then return False else True.
"""
pass
@abstractmethod
def _setup_layout(self, stream_vis):

Просмотреть файл

@ -86,9 +86,12 @@ class LinePlot(BasePlotlyPlot):
return self._create_2d_trace(stream_vis, mode, hoverinfo, marker, line)
def _show_stream_items(self, stream_vis, stream_items):
"""Paint the given stream_items in to visualizer. If visualizer is dirty then return False else True.
"""
vals = self._extract_vals(stream_items)
if not len(vals):
return False
return True
# get trace data
trace = self.widget.data[stream_vis.trace_index]
@ -170,7 +173,7 @@ class LinePlot(BasePlotlyPlot):
exisitng += clrdata
self.widget.data[stream_vis.trace_index].marker.color = exisitng
return True
return False # dirty
def clear_plot(self, stream_vis, clear_history):
traces = range(len(stream_vis.trace_history)) if clear_history else (stream_vis.trace_index,)

Просмотреть файл

@ -49,6 +49,9 @@ class TextVis(VisBase):
self.df = self.df.iloc[0:0]
def _show_stream_items(self, stream_vis, stream_items):
"""Paint the given stream_items in to visualizer. If visualizer is dirty then return False else True.
"""
for stream_item in stream_items:
if stream_item.ended:
self.df = self.df.append(pd.Series({'Ended':True}),
@ -56,7 +59,7 @@ class TextVis(VisBase):
else:
vals = self._extract_vals((stream_item,))
self._append(stream_vis, vals)
return True
return False # dirty
def _post_update_stream_plot(self, stream_vis):
if get_ipython():

Просмотреть файл

@ -139,7 +139,9 @@ class VisBase(Stream, metaclass=ABCMeta):
pass
@abstractmethod
def _show_stream_items(self, stream_vis, stream_items):
"""(for derived class) Plot the data in given axes"""
"""Paint the given stream_items in to visualizer. If visualizer is dirty then return False else True.
"""
pass
@abstractmethod
def _post_add_subscription(self, stream_vis, **stream_vis_args):
@ -163,7 +165,7 @@ class VisBase(Stream, metaclass=ABCMeta):
self.clear_plot(stream_vis, clear_history)
# if we have something to render
dirty = self._show_stream_items(stream_vis, stream_items)
dirty = not self._show_stream_items(stream_vis, stream_items)
if dirty:
self._post_update_stream_plot(stream_vis)
stream_vis.last_update = time.time()

Просмотреть файл

@ -19,6 +19,9 @@ class Visualizer:
only_summary=False, separate_yaxis=True, xtitle=None, ytitle=None, ztitle=None, color=None,
xrange=None, yrange=None, zrange=None, draw_line=True, draw_marker=False,
# histogram
bins=None, normed=None, histtype='bar', edge_color=None, linewidth=2,
vis_args={}, stream_vis_args={})->None:
cell = cell._host_base.cell if cell is not None else None
@ -39,6 +42,7 @@ class Visualizer:
draw_marker=draw_marker,
rows=rows, cols=cols, img_width=img_width, img_height=img_height, img_channels=img_channels,
colormap=colormap, viz_img_scale=viz_img_scale,
bins=bins, normed=normed, histtype=histtype, edge_color=edge_color, linewidth=linewidth,
**stream_vis_args)
stream.load()
@ -50,6 +54,15 @@ class Visualizer:
if vis_type is None:
from .text_vis import TextVis
return TextVis(cell=cell, title=title, **vis_args)
if vis_type in ['line', 'mpl-line', 'mpl-scatter']:
from . import mpl
return mpl.LinePlot(cell=cell, title=title, **vis_args)
if vis_type in ['image', 'mpl-image']:
from . import mpl
return mpl.ImagePlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, **vis_args)
if vis_type in ['histogram']:
from . import mpl
return mpl.Histogram(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, **vis_args)
if vis_type in ['text', 'summary']:
from .text_vis import TextVis
return TextVis(cell=cell, title=title, **vis_args)
@ -58,12 +71,6 @@ class Visualizer:
from . import plotly
return plotly.LinePlot(cell=cell, title=title,
is_3d=vis_type in ['line3d', 'scatter3d', 'mesh3d'], **vis_args)
if vis_type in ['image', 'mpl-image']:
from . import mpl
return mpl.ImagePlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, **vis_args)
if vis_type in ['line', 'mpl-line', 'mpl-scatter']:
from . import mpl
return mpl.LinePlot(cell=cell, title=title, **vis_args)
if vis_type in ['tsne', 'embeddings', 'tsne2d', 'embeddings2d']:
from . import plotly
return plotly.EmbeddingsPlot(cell=cell, title=title, is_3d='2d' not in vis_type,

Просмотреть файл

@ -0,0 +1,29 @@
import tensorwatch as tw
import random, time
def static_hist():
w = tw.Watcher()
s = w.create_stream()
v = tw.Visualizer(s, vis_type='histogram', bins=6)
v.show()
for i in range(100):
i = float(i)
s.write(int(random.random()*10))
tw.plt_loop()
def dynamic_hist():
w = tw.Watcher()
s = w.create_stream()
v = tw.Visualizer(s, vis_type='histogram', bins=6, clear_after_each=True)
v.show()
for i in range(100):
s.write([int(random.random()*10) for i in range(100)])
tw.plt_loop(count=3)
dynamic_hist()

Просмотреть файл

@ -5,7 +5,7 @@
<ProjectGuid>9a7fe67e-93f0-42b5-b58f-77320fc639e4</ProjectGuid>
<ProjectHome>
</ProjectHome>
<StartupFile>simple_log\confidence_interval.py</StartupFile>
<StartupFile>simple_log\histogram.py</StartupFile>
<SearchPath>
</SearchPath>
<WorkingDirectory>.</WorkingDirectory>
@ -61,12 +61,15 @@
<Compile Include="simple_log\cli_sum_log_2.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="simple_log\confidence_interval.py">
<Compile Include="simple_log\confidence_int.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="simple_log\file_only_test.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="simple_log\histogram.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="simple_log\srv_sum_log_2.py" />
<Compile Include="simple_log\srv_sum_log.py">
<SubType>Code</SubType>