From 30f770cc97ae399ec4de222c64fb77958704241e Mon Sep 17 00:00:00 2001 From: Shital Shah Date: Tue, 3 Mar 2020 22:42:29 -0800 Subject: [PATCH] dependency on pydotz --- setup.py | 4 ++-- tensorwatch/model_graph/hiddenlayer/pytorch_draw_model.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/setup.py b/setup.py index 1fb73b5..b6a27fc 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ with open("README.md", "r") as fh: setuptools.setup( name="tensorwatch", - version="0.9.0", + version="0.9.1", author="Shital Shah", author_email="shitals@microsoft.com", description="Interactive Realtime Debugging and Visualization for AI", @@ -25,7 +25,7 @@ setuptools.setup( include_package_data=True, install_requires=[ 'matplotlib', 'numpy', 'pyzmq', 'plotly', 'ipywidgets', - 'pydot @ git+https://github.com/sytelus/pydot@v1.5.0#egg=pydot', + 'pydotz', 'nbformat', 'scikit-image', 'nbformat', 'pyyaml', 'scikit-image', 'graphviz' # , 'receptivefield' ] ) diff --git a/tensorwatch/model_graph/hiddenlayer/pytorch_draw_model.py b/tensorwatch/model_graph/hiddenlayer/pytorch_draw_model.py index 1ec9835..02d5336 100644 --- a/tensorwatch/model_graph/hiddenlayer/pytorch_draw_model.py +++ b/tensorwatch/model_graph/hiddenlayer/pytorch_draw_model.py @@ -16,11 +16,11 @@ class DotWrapper: # directory, file_name = os.path.split(path) # # Remove extension from file name. dot.render() adds it. # file_name = file_name.replace("." + format, "") - # self.dot.render(file_name, directory=directory, cleanup=True) + # self.dot.render(file_name, directory=directory, cleanup=True) if filename is not None: png = self.dot.create_png() with open(os.path.expanduser(filename), 'wb') as fid: - fid.write(png) # + fid.write(png) # def draw_graph(model, args): if args is None: @@ -33,7 +33,7 @@ def draw_graph(model, args): args = torch.ones(args) dot = draw_img_classifier(model, args) - return DotWrapper(dot) + return DotWrapper(dot) def draw_img_classifier(model, dataset=None, display_param_nodes=False, rankdir='TB', styles=None, input_shape=None): @@ -110,7 +110,7 @@ def sgraph2dot(sgraph, display_param_nodes=False, rankdir='TB', styles=None): if op['type'] == 'Conv': return ["sh={}".format(distiller.size2str(op['attrs']['kernel_shape'])), "g={}".format(str(op['attrs']['group']))] - return '' + return '' op_nodes = [op['name'] for op in sgraph.ops.values()] data_nodes = []