This commit is contained in:
Shital Shah 2019-05-21 23:52:06 -07:00
Родитель 32027d230e
Коммит b8c40f3636
21 изменённых файлов: 46 добавлений и 40 удалений

Просмотреть файл

@ -13,6 +13,7 @@ from .mpl.line_plot import LinePlot
from .mpl.image_plot import ImagePlot
from .mpl.histogram import Histogram
from .mpl.bar_plot import BarPlot
from .mpl.pie_chart import PieChart
from .visualizer import Visualizer
from .stream import Stream

Просмотреть файл

@ -42,7 +42,7 @@ def probabilities2classes(probs, topk=5):
for p, c in zip(top_probs[0][0].detach().numpy(), top_probs[1][0].detach().numpy()))
class ImagenetLabels:
def __init__(self, json_path='../../data/imagenet_class_index.json'):
def __init__(self, json_path='imagenet_class_index.json'):
self._idx2label = []
self._idx2cls = []
self._cls2label = {}

Просмотреть файл

@ -64,7 +64,7 @@ class NotebookMaker:
lines = []
stream_identifier = 's'+str(stream_index)
lines.append("{} = client.open_stream(stream_name='{}')".format(stream_identifier, stream_name))
lines.append("{} = client.open_stream(name='{}')".format(stream_identifier, stream_name))
vis_identifier = 'v'+str(stream_index)
vis_args_strs = ['stream={}'.format(stream_identifier)]

Просмотреть файл

@ -52,4 +52,4 @@ def logits2probabilities(logits):
return F.softmax(logits, dim=1)
def tensor2numpy(t):
return t.data().cpu().numpy()
return t.data.cpu().numpy()

Просмотреть файл

@ -43,7 +43,7 @@ class Visualizer:
self._host_base.subscribe(stream, show=False, clear_after_end=clear_after_end, clear_after_each=clear_after_each,
history_len=history_len, dim_history=dim_history, opacity=opacity,
only_summary=only_summary if 'summary' != vis_type else True,
only_summary=only_summary if vis_type is None or 'summary' != vis_type else True,
separate_yaxis=separate_yaxis, xtitle=xtitle, ytitle=ytitle, ztitle=ztitle, color=color,
xrange=xrange, yrange=yrange, zrange=zrange,
draw_line=draw_line if vis_type is not None and 'scatter' in vis_type else True,
@ -63,7 +63,7 @@ class Visualizer:
if vis_type is None or vis_type in ['line',
'mpl-line', 'mpl-line3d', 'mpl-scatter3d', 'mpl-scatter']:
return mpl.line_plot.LinePlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height,
is_3d=vis_type.endswith('3d'), **vis_args)
is_3d=vis_type is not None and vis_type.endswith('3d'), **vis_args)
if vis_type in ['image', 'mpl-image']:
return mpl.image_plot.ImagePlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, **vis_args)
if vis_type in ['bar', 'bar3d']:

Просмотреть файл

@ -87,7 +87,7 @@ class Watcher(WatcherBase):
# request = create stream
if clisrv_req.req_type == CliSrvReqTypes.create_stream:
stream_req = clisrv_req.req_data
self.create_stream(stream_name=stream_req.stream_name, devices=stream_req.devices,
self.create_stream(name=stream_req.stream_name, devices=stream_req.devices,
event_name=stream_req.event_name, expr=stream_req.expr, throttle=stream_req.throttle,
vis_args=stream_req.vis_args)
return None # ignore return as we can't send back stream obj

Просмотреть файл

@ -73,14 +73,14 @@ class WatcherClient(WatcherBase):
self._zmq_srvmgmt_sub.add_stream_req(stream_req)
if stream_req.devices is not None:
stream = self.open_stream(stream_name=stream_req.stream_name, devices=stream_req.devices)
stream = self.open_stream(name=stream_req.stream_name, devices=stream_req.devices)
else: # we cannot return remote streams that are not backed by a device
stream = None
return stream
# override to set devices default to tcp
def open_stream(self, name:str=None, devices:Sequence[str]=None)->Stream: # overriden
return super(WatcherClient, self).open_stream(stream_name=name, devices=devices)
return super(WatcherClient, self).open_stream(name=name, devices=devices)
# override to send request to server

Просмотреть файл

@ -4,8 +4,8 @@ import tensorwatch as tw
def main():
w = tw.Watcher()
s1 = w.create_stream()
s2 = w.create_stream(stream_name='accuracy', vis_args=tw.VisArgs(vis_type='line', xtitle='X-Axis', clear_after_each=False, history_len=2))
s3 = w.create_stream(stream_name='loss', expr='lambda d:d.loss')
s2 = w.create_stream(name='accuracy', vis_args=tw.VisArgs(vis_type='line', xtitle='X-Axis', clear_after_each=False, history_len=2))
s3 = w.create_stream(name='loss', expr='lambda d:d.loss')
w.make_notebook()
main()

Просмотреть файл

Просмотреть файл

@ -27,22 +27,22 @@ def show_find_lr():
utils.wait_key()
def plot_grads():
def plot_grads_plotly():
train_cli = tw.WatcherClient()
grads = train_cli.create_stream(event_name='batch',
expr='lambda d:agg_params(d.model, lambda p: p.grad.abs().mean().item())', throttle=1)
expr='lambda d:grads_abs_mean(d.model)', throttle=1)
p = tw.plotly.line_plot.LinePlot('Demo')
p.subscribe(grads, xtitle='Epoch', ytitle='Gradients', history_len=30, new_on_eval=True)
p.subscribe(grads, xtitle='Layer', ytitle='Gradients', history_len=30, new_on_eval=True)
utils.wait_key()
def plot_grads1():
def plot_grads():
train_cli = tw.WatcherClient()
grads = train_cli.create_stream(event_name='batch',
expr='lambda d:agg_params(d.model, lambda p: p.grad.abs().mean().item())', throttle=1)
expr='lambda d:grads_abs_mean(d.model)', throttle=1)
grad_plot = tw.LinePlot()
grad_plot.subscribe(grads, xtitle='Epoch', ytitle='Gradients', clear_after_each=1, history_len=40, dim_history=True)
grad_plot.subscribe(grads, xtitle='Layer', ytitle='Gradients', clear_after_each=1, history_len=40, dim_history=True)
grad_plot.show()
tw.plt_loop()
@ -51,9 +51,9 @@ def plot_weight():
train_cli = tw.WatcherClient()
params = train_cli.create_stream(event_name='batch',
expr='lambda d:agg_params(d.model, lambda p: p.abs().mean().item())', throttle=1)
expr='lambda d:weights_abs_mean(d.model)', throttle=1)
params_plot = tw.LinePlot()
params_plot.subscribe(params, xtitle='Epoch', ytitle='avg |params|', clear_after_each=1, history_len=40, dim_history=True)
params_plot.subscribe(params, xtitle='Layer', ytitle='avg |params|', clear_after_each=1, history_len=40, dim_history=True)
params_plot.show()
tw.plt_loop()
@ -89,14 +89,14 @@ def batch_stats():
# vis=train_loss, vis_type='mpl-line')
train_loss.show()
tw.image_utils.plt_loop()
tw.plt_loop()
def text_stats():
train_cli = tw.WatcherClient()
stream = train_cli.create_stream(event_name="batch",
expr='lambda d:(d.x, d.metrics.batch_loss)')
expr='lambda d:(d.metrics.epoch_index, d.metrics.batch_loss)')
trl = tw.Visualizer(stream, vis_type=None)
trl = tw.Visualizer(stream, vis_type='text')
trl.show()
input('Paused...')
@ -104,7 +104,7 @@ def text_stats():
#epoch_stats()
#plot_weight()
#plot_grads1()
img_in_class()
#text_stats()
#plot_grads()
#img_in_class()
text_stats()
#batch_stats()

Просмотреть файл

@ -2,7 +2,7 @@ from tensorwatch.saliency import saliency
from tensorwatch import image_utils, imagenet_utils, pytorch_utils
model = pytorch_utils.get_model('resnet50')
raw_input, input, target_class = pytorch_utils.image_class2tensor('../data/dogs.png', 240, #'../data/elephant.png', 101,
raw_input, input, target_class = pytorch_utils.image_class2tensor('../data/test_images/dogs.png', 240, #'../data/elephant.png', 101,
image_transform=imagenet_utils.get_image_transform(), image_convert_mode='RGB')
results = saliency.get_image_saliency_results(model, raw_input, input, target_class)

Просмотреть файл

@ -20,7 +20,7 @@ def show_mpl():
def show_text():
cli = tw.WatcherClient()
s1 = cli.create_stream(expr='lambda v:(v.i, v.sum)')
text = tw.Visualizer(s1)
text = tw.Visualizer(s1, vis_type='text')
text.show()
input('Waiting')

Просмотреть файл

Просмотреть файл

@ -5,7 +5,7 @@
<ProjectGuid>9a7fe67e-93f0-42b5-b58f-77320fc639e4</ProjectGuid>
<ProjectHome>
</ProjectHome>
<StartupFile>files\file_stream.py</StartupFile>
<StartupFile>mnist\cli_mnist.py</StartupFile>
<SearchPath>
</SearchPath>
<WorkingDirectory>.</WorkingDirectory>
@ -24,6 +24,7 @@
<EnableUnmanagedDebugging>false</EnableUnmanagedDebugging>
</PropertyGroup>
<ItemGroup>
<Compile Include="deps\live_graph.py" />
<Compile Include="visualizations\arr_img_plot.py">
<SubType>Code</SubType>
</Compile>
@ -70,7 +71,7 @@
<Compile Include="visualizations\histogram.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="visualizations\line_plot.py">
<Compile Include="visualizations\line3d_plot.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="components\notebook_maker.py">
@ -98,7 +99,7 @@
<Compile Include="components\stream.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="zmq\zmq_stream_pub.py">
<Compile Include="zmq\zmq_stream.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="zmq\zmq_watcher_client.py">
@ -107,11 +108,11 @@
<Compile Include="zmq\zmq_watcher_server.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="zmq\zmq_stream_sub.py">
<Compile Include="zmq\zmq_sub.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="simple_log\srv_ij.py" />
<Compile Include="zmq\zmq_srv.py">
<Compile Include="zmq\zmq_pub.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="deps\thread.py">

Просмотреть файл

@ -1,6 +1,8 @@
import tensorwatch as tw
import random, time
# TODO: resolve problem with Axis3D?
def static_line3d():
w = tw.Watcher()
s = w.create_stream()

Просмотреть файл

@ -1,5 +1,5 @@
from tensorwatch.watcher_base import WatcherBase
from tensorwatch import LinePlot
from tensorwatch.mpl.line_plot import LinePlot
from tensorwatch.image_utils import plt_loop
from tensorwatch.stream import Stream
from tensorwatch.lv_types import StreamItem

Просмотреть файл

@ -6,7 +6,7 @@ from tensorwatch import utils
utils.set_debug_verbosity(10)
def clisrv_callback(clisrv, msg):
print(msg)
print('from clisrv', msg)
stream = ZmqWrapper.Publication(port = 40859)
clisrv = ZmqWrapper.ClientServer(40860, True, clisrv_callback)

Просмотреть файл

@ -1,13 +1,11 @@
from tensorwatch.watcher_base import WatcherBase
from tensorwatch.stream import Stream
from tensorwatch.zmq_stream import ZmqStream
def main():
watcher = WatcherBase()
zmq_pub = ZmqStream(for_write=True, stream_name = 'ZmqPub', console_debug=True)
zmq_sub = ZmqStream(for_write=False, stream_name = 'ZmqSub', console_debug=True)
stream = watcher.create_stream(expr='lambda vars:vars.x**2')
zmq_pub = ZmqStream(for_write=True, stream_name = 'ZmqPub', console_debug=True)
zmq_pub.subscribe(stream)
for i in range(5):

Просмотреть файл

@ -3,11 +3,15 @@ import time
from tensorwatch.zmq_wrapper import ZmqWrapper
from tensorwatch import utils
def on_event(obj):
print(obj)
class A:
def on_event(self, obj):
print(obj)
a = A()
utils.set_debug_verbosity(10)
sub = ZmqWrapper.Subscription(40859, "Topic1", on_event)
sub = ZmqWrapper.Subscription(40859, "Topic1", a.on_event)
print("subscriber is waiting")
clisrv = ZmqWrapper.ClientServer(40860, False)