This commit is contained in:
Shital Shah 2020-03-03 22:42:29 -08:00
Родитель 6f7bcbe510
Коммит 30f770cc97
2 изменённых файлов: 6 добавлений и 6 удалений

Просмотреть файл

@ -8,7 +8,7 @@ with open("README.md", "r") as fh:
setuptools.setup(
name="tensorwatch",
version="0.9.0",
version="0.9.1",
author="Shital Shah",
author_email="shitals@microsoft.com",
description="Interactive Realtime Debugging and Visualization for AI",
@ -25,7 +25,7 @@ setuptools.setup(
include_package_data=True,
install_requires=[
'matplotlib', 'numpy', 'pyzmq', 'plotly', 'ipywidgets',
'pydot @ git+https://github.com/sytelus/pydot@v1.5.0#egg=pydot',
'pydotz',
'nbformat', 'scikit-image', 'nbformat', 'pyyaml', 'scikit-image', 'graphviz' # , 'receptivefield'
]
)

Просмотреть файл

@ -16,11 +16,11 @@ class DotWrapper:
# directory, file_name = os.path.split(path)
# # Remove extension from file name. dot.render() adds it.
# file_name = file_name.replace("." + format, "")
# self.dot.render(file_name, directory=directory, cleanup=True)
# self.dot.render(file_name, directory=directory, cleanup=True)
if filename is not None:
png = self.dot.create_png()
with open(os.path.expanduser(filename), 'wb') as fid:
fid.write(png) #
fid.write(png) #
def draw_graph(model, args):
if args is None:
@ -33,7 +33,7 @@ def draw_graph(model, args):
args = torch.ones(args)
dot = draw_img_classifier(model, args)
return DotWrapper(dot)
return DotWrapper(dot)
def draw_img_classifier(model, dataset=None, display_param_nodes=False,
rankdir='TB', styles=None, input_shape=None):
@ -110,7 +110,7 @@ def sgraph2dot(sgraph, display_param_nodes=False, rankdir='TB', styles=None):
if op['type'] == 'Conv':
return ["sh={}".format(distiller.size2str(op['attrs']['kernel_shape'])),
"g={}".format(str(op['attrs']['group']))]
return ''
return ''
op_nodes = [op['name'] for op in sgraph.ops.values()]
data_nodes = []