native tests pass
This commit is contained in:
Родитель
32027d230e
Коммит
b8c40f3636
|
@ -13,6 +13,7 @@ from .mpl.line_plot import LinePlot
|
|||
from .mpl.image_plot import ImagePlot
|
||||
from .mpl.histogram import Histogram
|
||||
from .mpl.bar_plot import BarPlot
|
||||
from .mpl.pie_chart import PieChart
|
||||
from .visualizer import Visualizer
|
||||
|
||||
from .stream import Stream
|
||||
|
|
|
@ -42,7 +42,7 @@ def probabilities2classes(probs, topk=5):
|
|||
for p, c in zip(top_probs[0][0].detach().numpy(), top_probs[1][0].detach().numpy()))
|
||||
|
||||
class ImagenetLabels:
|
||||
def __init__(self, json_path='../../data/imagenet_class_index.json'):
|
||||
def __init__(self, json_path='imagenet_class_index.json'):
|
||||
self._idx2label = []
|
||||
self._idx2cls = []
|
||||
self._cls2label = {}
|
||||
|
|
|
@ -64,7 +64,7 @@ class NotebookMaker:
|
|||
lines = []
|
||||
|
||||
stream_identifier = 's'+str(stream_index)
|
||||
lines.append("{} = client.open_stream(stream_name='{}')".format(stream_identifier, stream_name))
|
||||
lines.append("{} = client.open_stream(name='{}')".format(stream_identifier, stream_name))
|
||||
|
||||
vis_identifier = 'v'+str(stream_index)
|
||||
vis_args_strs = ['stream={}'.format(stream_identifier)]
|
||||
|
|
|
@ -52,4 +52,4 @@ def logits2probabilities(logits):
|
|||
return F.softmax(logits, dim=1)
|
||||
|
||||
def tensor2numpy(t):
|
||||
return t.data().cpu().numpy()
|
||||
return t.data.cpu().numpy()
|
||||
|
|
|
@ -43,7 +43,7 @@ class Visualizer:
|
|||
|
||||
self._host_base.subscribe(stream, show=False, clear_after_end=clear_after_end, clear_after_each=clear_after_each,
|
||||
history_len=history_len, dim_history=dim_history, opacity=opacity,
|
||||
only_summary=only_summary if 'summary' != vis_type else True,
|
||||
only_summary=only_summary if vis_type is None or 'summary' != vis_type else True,
|
||||
separate_yaxis=separate_yaxis, xtitle=xtitle, ytitle=ytitle, ztitle=ztitle, color=color,
|
||||
xrange=xrange, yrange=yrange, zrange=zrange,
|
||||
draw_line=draw_line if vis_type is not None and 'scatter' in vis_type else True,
|
||||
|
@ -63,7 +63,7 @@ class Visualizer:
|
|||
if vis_type is None or vis_type in ['line',
|
||||
'mpl-line', 'mpl-line3d', 'mpl-scatter3d', 'mpl-scatter']:
|
||||
return mpl.line_plot.LinePlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height,
|
||||
is_3d=vis_type.endswith('3d'), **vis_args)
|
||||
is_3d=vis_type is not None and vis_type.endswith('3d'), **vis_args)
|
||||
if vis_type in ['image', 'mpl-image']:
|
||||
return mpl.image_plot.ImagePlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, **vis_args)
|
||||
if vis_type in ['bar', 'bar3d']:
|
||||
|
|
|
@ -87,7 +87,7 @@ class Watcher(WatcherBase):
|
|||
# request = create stream
|
||||
if clisrv_req.req_type == CliSrvReqTypes.create_stream:
|
||||
stream_req = clisrv_req.req_data
|
||||
self.create_stream(stream_name=stream_req.stream_name, devices=stream_req.devices,
|
||||
self.create_stream(name=stream_req.stream_name, devices=stream_req.devices,
|
||||
event_name=stream_req.event_name, expr=stream_req.expr, throttle=stream_req.throttle,
|
||||
vis_args=stream_req.vis_args)
|
||||
return None # ignore return as we can't send back stream obj
|
||||
|
|
|
@ -73,14 +73,14 @@ class WatcherClient(WatcherBase):
|
|||
self._zmq_srvmgmt_sub.add_stream_req(stream_req)
|
||||
|
||||
if stream_req.devices is not None:
|
||||
stream = self.open_stream(stream_name=stream_req.stream_name, devices=stream_req.devices)
|
||||
stream = self.open_stream(name=stream_req.stream_name, devices=stream_req.devices)
|
||||
else: # we cannot return remote streams that are not backed by a device
|
||||
stream = None
|
||||
return stream
|
||||
|
||||
# override to set devices default to tcp
|
||||
def open_stream(self, name:str=None, devices:Sequence[str]=None)->Stream: # overriden
|
||||
return super(WatcherClient, self).open_stream(stream_name=name, devices=devices)
|
||||
return super(WatcherClient, self).open_stream(name=name, devices=devices)
|
||||
|
||||
|
||||
# override to send request to server
|
||||
|
|
|
@ -4,8 +4,8 @@ import tensorwatch as tw
|
|||
def main():
|
||||
w = tw.Watcher()
|
||||
s1 = w.create_stream()
|
||||
s2 = w.create_stream(stream_name='accuracy', vis_args=tw.VisArgs(vis_type='line', xtitle='X-Axis', clear_after_each=False, history_len=2))
|
||||
s3 = w.create_stream(stream_name='loss', expr='lambda d:d.loss')
|
||||
s2 = w.create_stream(name='accuracy', vis_args=tw.VisArgs(vis_type='line', xtitle='X-Axis', clear_after_each=False, history_len=2))
|
||||
s3 = w.create_stream(name='loss', expr='lambda d:d.loss')
|
||||
w.make_notebook()
|
||||
|
||||
main()
|
||||
|
|
|
@ -27,22 +27,22 @@ def show_find_lr():
|
|||
|
||||
utils.wait_key()
|
||||
|
||||
def plot_grads():
|
||||
def plot_grads_plotly():
|
||||
train_cli = tw.WatcherClient()
|
||||
grads = train_cli.create_stream(event_name='batch',
|
||||
expr='lambda d:agg_params(d.model, lambda p: p.grad.abs().mean().item())', throttle=1)
|
||||
expr='lambda d:grads_abs_mean(d.model)', throttle=1)
|
||||
p = tw.plotly.line_plot.LinePlot('Demo')
|
||||
p.subscribe(grads, xtitle='Epoch', ytitle='Gradients', history_len=30, new_on_eval=True)
|
||||
p.subscribe(grads, xtitle='Layer', ytitle='Gradients', history_len=30, new_on_eval=True)
|
||||
utils.wait_key()
|
||||
|
||||
|
||||
def plot_grads1():
|
||||
def plot_grads():
|
||||
train_cli = tw.WatcherClient()
|
||||
|
||||
grads = train_cli.create_stream(event_name='batch',
|
||||
expr='lambda d:agg_params(d.model, lambda p: p.grad.abs().mean().item())', throttle=1)
|
||||
expr='lambda d:grads_abs_mean(d.model)', throttle=1)
|
||||
grad_plot = tw.LinePlot()
|
||||
grad_plot.subscribe(grads, xtitle='Epoch', ytitle='Gradients', clear_after_each=1, history_len=40, dim_history=True)
|
||||
grad_plot.subscribe(grads, xtitle='Layer', ytitle='Gradients', clear_after_each=1, history_len=40, dim_history=True)
|
||||
grad_plot.show()
|
||||
|
||||
tw.plt_loop()
|
||||
|
@ -51,9 +51,9 @@ def plot_weight():
|
|||
train_cli = tw.WatcherClient()
|
||||
|
||||
params = train_cli.create_stream(event_name='batch',
|
||||
expr='lambda d:agg_params(d.model, lambda p: p.abs().mean().item())', throttle=1)
|
||||
expr='lambda d:weights_abs_mean(d.model)', throttle=1)
|
||||
params_plot = tw.LinePlot()
|
||||
params_plot.subscribe(params, xtitle='Epoch', ytitle='avg |params|', clear_after_each=1, history_len=40, dim_history=True)
|
||||
params_plot.subscribe(params, xtitle='Layer', ytitle='avg |params|', clear_after_each=1, history_len=40, dim_history=True)
|
||||
params_plot.show()
|
||||
|
||||
tw.plt_loop()
|
||||
|
@ -89,14 +89,14 @@ def batch_stats():
|
|||
# vis=train_loss, vis_type='mpl-line')
|
||||
|
||||
train_loss.show()
|
||||
tw.image_utils.plt_loop()
|
||||
tw.plt_loop()
|
||||
|
||||
def text_stats():
|
||||
train_cli = tw.WatcherClient()
|
||||
stream = train_cli.create_stream(event_name="batch",
|
||||
expr='lambda d:(d.x, d.metrics.batch_loss)')
|
||||
expr='lambda d:(d.metrics.epoch_index, d.metrics.batch_loss)')
|
||||
|
||||
trl = tw.Visualizer(stream, vis_type=None)
|
||||
trl = tw.Visualizer(stream, vis_type='text')
|
||||
trl.show()
|
||||
input('Paused...')
|
||||
|
||||
|
@ -104,7 +104,7 @@ def text_stats():
|
|||
|
||||
#epoch_stats()
|
||||
#plot_weight()
|
||||
#plot_grads1()
|
||||
img_in_class()
|
||||
#text_stats()
|
||||
#plot_grads()
|
||||
#img_in_class()
|
||||
text_stats()
|
||||
#batch_stats()
|
|
@ -2,7 +2,7 @@ from tensorwatch.saliency import saliency
|
|||
from tensorwatch import image_utils, imagenet_utils, pytorch_utils
|
||||
|
||||
model = pytorch_utils.get_model('resnet50')
|
||||
raw_input, input, target_class = pytorch_utils.image_class2tensor('../data/dogs.png', 240, #'../data/elephant.png', 101,
|
||||
raw_input, input, target_class = pytorch_utils.image_class2tensor('../data/test_images/dogs.png', 240, #'../data/elephant.png', 101,
|
||||
image_transform=imagenet_utils.get_image_transform(), image_convert_mode='RGB')
|
||||
|
||||
results = saliency.get_image_saliency_results(model, raw_input, input, target_class)
|
||||
|
|
|
@ -20,7 +20,7 @@ def show_mpl():
|
|||
def show_text():
|
||||
cli = tw.WatcherClient()
|
||||
s1 = cli.create_stream(expr='lambda v:(v.i, v.sum)')
|
||||
text = tw.Visualizer(s1)
|
||||
text = tw.Visualizer(s1, vis_type='text')
|
||||
text.show()
|
||||
input('Waiting')
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
<ProjectGuid>9a7fe67e-93f0-42b5-b58f-77320fc639e4</ProjectGuid>
|
||||
<ProjectHome>
|
||||
</ProjectHome>
|
||||
<StartupFile>files\file_stream.py</StartupFile>
|
||||
<StartupFile>mnist\cli_mnist.py</StartupFile>
|
||||
<SearchPath>
|
||||
</SearchPath>
|
||||
<WorkingDirectory>.</WorkingDirectory>
|
||||
|
@ -24,6 +24,7 @@
|
|||
<EnableUnmanagedDebugging>false</EnableUnmanagedDebugging>
|
||||
</PropertyGroup>
|
||||
<ItemGroup>
|
||||
<Compile Include="deps\live_graph.py" />
|
||||
<Compile Include="visualizations\arr_img_plot.py">
|
||||
<SubType>Code</SubType>
|
||||
</Compile>
|
||||
|
@ -70,7 +71,7 @@
|
|||
<Compile Include="visualizations\histogram.py">
|
||||
<SubType>Code</SubType>
|
||||
</Compile>
|
||||
<Compile Include="visualizations\line_plot.py">
|
||||
<Compile Include="visualizations\line3d_plot.py">
|
||||
<SubType>Code</SubType>
|
||||
</Compile>
|
||||
<Compile Include="components\notebook_maker.py">
|
||||
|
@ -98,7 +99,7 @@
|
|||
<Compile Include="components\stream.py">
|
||||
<SubType>Code</SubType>
|
||||
</Compile>
|
||||
<Compile Include="zmq\zmq_stream_pub.py">
|
||||
<Compile Include="zmq\zmq_stream.py">
|
||||
<SubType>Code</SubType>
|
||||
</Compile>
|
||||
<Compile Include="zmq\zmq_watcher_client.py">
|
||||
|
@ -107,11 +108,11 @@
|
|||
<Compile Include="zmq\zmq_watcher_server.py">
|
||||
<SubType>Code</SubType>
|
||||
</Compile>
|
||||
<Compile Include="zmq\zmq_stream_sub.py">
|
||||
<Compile Include="zmq\zmq_sub.py">
|
||||
<SubType>Code</SubType>
|
||||
</Compile>
|
||||
<Compile Include="simple_log\srv_ij.py" />
|
||||
<Compile Include="zmq\zmq_srv.py">
|
||||
<Compile Include="zmq\zmq_pub.py">
|
||||
<SubType>Code</SubType>
|
||||
</Compile>
|
||||
<Compile Include="deps\thread.py">
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
import tensorwatch as tw
|
||||
import random, time
|
||||
|
||||
# TODO: resolve problem with Axis3D?
|
||||
|
||||
def static_line3d():
|
||||
w = tw.Watcher()
|
||||
s = w.create_stream()
|
|
@ -1,5 +1,5 @@
|
|||
from tensorwatch.watcher_base import WatcherBase
|
||||
from tensorwatch import LinePlot
|
||||
from tensorwatch.mpl.line_plot import LinePlot
|
||||
from tensorwatch.image_utils import plt_loop
|
||||
from tensorwatch.stream import Stream
|
||||
from tensorwatch.lv_types import StreamItem
|
||||
|
|
|
@ -6,7 +6,7 @@ from tensorwatch import utils
|
|||
utils.set_debug_verbosity(10)
|
||||
|
||||
def clisrv_callback(clisrv, msg):
|
||||
print(msg)
|
||||
print('from clisrv', msg)
|
||||
|
||||
stream = ZmqWrapper.Publication(port = 40859)
|
||||
clisrv = ZmqWrapper.ClientServer(40860, True, clisrv_callback)
|
|
@ -1,13 +1,11 @@
|
|||
from tensorwatch.watcher_base import WatcherBase
|
||||
from tensorwatch.stream import Stream
|
||||
from tensorwatch.zmq_stream import ZmqStream
|
||||
|
||||
def main():
|
||||
watcher = WatcherBase()
|
||||
zmq_pub = ZmqStream(for_write=True, stream_name = 'ZmqPub', console_debug=True)
|
||||
zmq_sub = ZmqStream(for_write=False, stream_name = 'ZmqSub', console_debug=True)
|
||||
|
||||
stream = watcher.create_stream(expr='lambda vars:vars.x**2')
|
||||
|
||||
zmq_pub = ZmqStream(for_write=True, stream_name = 'ZmqPub', console_debug=True)
|
||||
zmq_pub.subscribe(stream)
|
||||
|
||||
for i in range(5):
|
|
@ -3,11 +3,15 @@ import time
|
|||
from tensorwatch.zmq_wrapper import ZmqWrapper
|
||||
from tensorwatch import utils
|
||||
|
||||
def on_event(obj):
|
||||
print(obj)
|
||||
class A:
|
||||
def on_event(self, obj):
|
||||
print(obj)
|
||||
|
||||
a = A()
|
||||
|
||||
|
||||
utils.set_debug_verbosity(10)
|
||||
sub = ZmqWrapper.Subscription(40859, "Topic1", on_event)
|
||||
sub = ZmqWrapper.Subscription(40859, "Topic1", a.on_event)
|
||||
print("subscriber is waiting")
|
||||
|
||||
clisrv = ZmqWrapper.ClientServer(40860, False)
|
Загрузка…
Ссылка в новой задаче