This commit is contained in:
Shital Shah 2019-05-15 01:56:53 -07:00
Parent 8d9d7490f0
Commit 9e5f1f3355
111 changed files with 18275 additions and 29 deletions

345
.gitignore vendored

@@ -1,3 +1,6 @@
data/
!data/log-20180928-175828.dlclog
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
@@ -102,3 +105,345 @@ venv.bak/
# mypy
.mypy_cache/
## Ignore Visual Studio temporary files, build results, and
## files generated by popular Visual Studio add-ons.
##
## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore
# User-specific files
*.rsuser
*.suo
*.user
*.userosscache
*.sln.docstates
# User-specific files (MonoDevelop/Xamarin Studio)
*.userprefs
# Build results
[Dd]ebug/
[Dd]ebugPublic/
[Rr]elease/
[Rr]eleases/
x64/
x86/
[Aa][Rr][Mm]/
[Aa][Rr][Mm]64/
bld/
[Bb]in/
[Oo]bj/
[Ll]og/
# Visual Studio 2015/2017 cache/options directory
.vs/
# Uncomment if you have tasks that create the project's static files in wwwroot
#wwwroot/
# Visual Studio 2017 auto generated files
Generated\ Files/
# MSTest test Results
[Tt]est[Rr]esult*/
[Bb]uild[Ll]og.*
# NUNIT
*.VisualState.xml
TestResult.xml
# Build Results of an ATL Project
[Dd]ebugPS/
[Rr]eleasePS/
dlldata.c
# Benchmark Results
BenchmarkDotNet.Artifacts/
# .NET Core
project.lock.json
project.fragment.lock.json
artifacts/
# StyleCop
StyleCopReport.xml
# Files built by Visual Studio
*_i.c
*_p.c
*_h.h
*.ilk
*.meta
*.obj
*.iobj
*.pch
*.pdb
*.ipdb
*.pgc
*.pgd
*.rsp
*.sbr
*.tlb
*.tli
*.tlh
*.tmp
*.tmp_proj
*_wpftmp.csproj
*.log
*.vspscc
*.vssscc
.builds
*.pidb
*.svclog
*.scc
# Chutzpah Test files
_Chutzpah*
# Visual C++ cache files
ipch/
*.aps
*.ncb
*.opendb
*.opensdf
*.sdf
*.cachefile
*.VC.db
*.VC.VC.opendb
# Visual Studio profiler
*.psess
*.vsp
*.vspx
*.sap
# Visual Studio Trace Files
*.e2e
# TFS 2012 Local Workspace
$tf/
# Guidance Automation Toolkit
*.gpState
# ReSharper is a .NET coding add-in
_ReSharper*/
*.[Rr]e[Ss]harper
*.DotSettings.user
# JustCode is a .NET coding add-in
.JustCode
# TeamCity is a build add-in
_TeamCity*
# DotCover is a Code Coverage Tool
*.dotCover
# AxoCover is a Code Coverage Tool
.axoCover/*
!.axoCover/settings.json
# Visual Studio code coverage results
*.coverage
*.coveragexml
# NCrunch
_NCrunch_*
.*crunch*.local.xml
nCrunchTemp_*
# MightyMoose
*.mm.*
AutoTest.Net/
# Web workbench (sass)
.sass-cache/
# Installshield output folder
[Ee]xpress/
# DocProject is a documentation generator add-in
DocProject/buildhelp/
DocProject/Help/*.HxT
DocProject/Help/*.HxC
DocProject/Help/*.hhc
DocProject/Help/*.hhk
DocProject/Help/*.hhp
DocProject/Help/Html2
DocProject/Help/html
# Click-Once directory
publish/
# Publish Web Output
*.[Pp]ublish.xml
*.azurePubxml
# Note: Comment the next line if you want to checkin your web deploy settings,
# but database connection strings (with potential passwords) will be unencrypted
*.pubxml
*.publishproj
# Microsoft Azure Web App publish settings. Comment the next line if you want to
# checkin your Azure Web App publish settings, but sensitive information contained
# in these scripts will be unencrypted
PublishScripts/
# NuGet Packages
*.nupkg
# The packages folder can be ignored because of Package Restore
**/[Pp]ackages/*
# except build/, which is used as an MSBuild target.
!**/[Pp]ackages/build/
# Uncomment if necessary however generally it will be regenerated when needed
#!**/[Pp]ackages/repositories.config
# NuGet v3's project.json files produces more ignorable files
*.nuget.props
*.nuget.targets
# Microsoft Azure Build Output
csx/
*.build.csdef
# Microsoft Azure Emulator
ecf/
rcf/
# Windows Store app package directories and files
AppPackages/
BundleArtifacts/
Package.StoreAssociation.xml
_pkginfo.txt
*.appx
# Visual Studio cache files
# files ending in .cache can be ignored
*.[Cc]ache
# but keep track of directories ending in .cache
!?*.[Cc]ache/
# Others
ClientBin/
~$*
*~
*.dbmdl
*.dbproj.schemaview
*.jfm
*.pfx
*.publishsettings
orleans.codegen.cs
# Including strong name files can present a security risk
# (https://github.com/github/gitignore/pull/2483#issue-259490424)
#*.snk
# Since there are multiple workflows, uncomment next line to ignore bower_components
# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
#bower_components/
# ASP.NET Core default setup: bower directory is configured as wwwroot/lib/ and bower restore is true
**/wwwroot/lib/
# RIA/Silverlight projects
Generated_Code/
# Backup & report files from converting an old project file
# to a newer Visual Studio version. Backup files are not needed,
# because we have git ;-)
_UpgradeReport_Files/
Backup*/
UpgradeLog*.XML
UpgradeLog*.htm
ServiceFabricBackup/
*.rptproj.bak
# SQL Server files
*.mdf
*.ldf
*.ndf
# Business Intelligence projects
*.rdl.data
*.bim.layout
*.bim_*.settings
*.rptproj.rsuser
# Microsoft Fakes
FakesAssemblies/
# GhostDoc plugin setting file
*.GhostDoc.xml
# Node.js Tools for Visual Studio
.ntvs_analysis.dat
node_modules/
# Visual Studio 6 build log
*.plg
# Visual Studio 6 workspace options file
*.opt
# Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
*.vbw
# Visual Studio LightSwitch build output
**/*.HTMLClient/GeneratedArtifacts
**/*.DesktopClient/GeneratedArtifacts
**/*.DesktopClient/ModelManifest.xml
**/*.Server/GeneratedArtifacts
**/*.Server/ModelManifest.xml
_Pvt_Extensions
# Paket dependency manager
.paket/paket.exe
paket-files/
# FAKE - F# Make
.fake/
# JetBrains Rider
.idea/
*.sln.iml
# CodeRush personal settings
.cr/personal
# Python Tools for Visual Studio (PTVS)
__pycache__/
*.pyc
# Cake - Uncomment if you are using it
# tools/**
# !tools/packages.config
# Tabs Studio
*.tss
# Telerik's JustMock configuration file
*.jmconfig
# BizTalk build output
*.btp.cs
*.btm.cs
*.odx.cs
*.xsd.cs
# OpenCover UI analysis results
OpenCover/
# Azure Stream Analytics local run output
ASALocalRun/
# MSBuild Binary and Structured Log
*.binlog
# NVidia Nsight GPU debugger configuration file
*.nvuser
# MFractors (Xamarin productivity tool) working folder
.mfractor/
# Local History for Visual Studio
.localhistory/
# BeatPulse healthcheck temp database
healthchecksdb

15
.pylintrc Normal file

@@ -0,0 +1,15 @@
[MASTER]
ignore=model_graph,
receptive_field,
saliency
[TYPECHECK]
ignored-modules=numpy,torch,matplotlib,pyplot,zmq
[MESSAGES CONTROL]
disable=protected-access,
broad-except,
global-statement,
fixme,
C,
R
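As an aside, this configuration can be exercised programmatically; a minimal sketch using pylint's lint.Run entry point (assumes pylint 2.x is installed and the script is run from the repo root):

# Minimal sketch: run pylint on the package with the .pylintrc above.
# Equivalent to `pylint --rcfile=.pylintrc tensorwatch` from the repo root.
from pylint.lint import Run

# exit=False keeps the interpreter alive so this can run inside a script
# (the `exit` keyword assumes pylint 2.x).
Run(["--rcfile=.pylintrc", "tensorwatch"], exit=False)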

5
AUTHORS.md Normal file

@@ -0,0 +1,5 @@
# Authors
TensorWatch was originally conceived and created by [Shital Shah](https://www.shitalshah.com) during early 2019.
A list of all contributors since our first release in May 2019 can be [found here](https://github.com/Microsoft/tensorwatch/graphs/contributors).

6
CHANGELOG.md Normal file

@@ -0,0 +1,6 @@
# What's new
Below is a summarized list of important changes. It does not include minor changes, bug fixes, or documentation updates. The list is updated every few months. For complete, detailed changes, please review the [commit history](https://github.com/Microsoft/tensorwatch/commits/master).
### May, 2019
* First release!

14
CONTRIBUTING.md Normal file

@@ -0,0 +1,14 @@
# Contributing
This project welcomes contributions and suggestions. Most contributions require you to
agree to a Contributor License Agreement (CLA) declaring that you have the right to,
and actually do, grant us the rights to use your contribution. For details, visit
https://cla.microsoft.com.
When you submit a pull request, a CLA-bot will automatically determine whether you need
to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the
instructions provided by the bot. You will only need to do this once across all repositories using our CLA.
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.

18
ISSUE_TEMPLATE.md Normal file

@@ -0,0 +1,18 @@
# Read This First
## If you are reporting a bug
* Make sure to write **all reproduction steps**
* Include the full error message in text form
* Search existing issues for the error message before filing a new one
* Attach screenshot if applicable
* Include code to run if applicable
## If you have a question
* Add a clear and concise title
* Add your OS, TensorWatch version, and Python version if applicable
* Include context on what you are trying to achieve
* Include details of what you already did to find answers
**What's better than filing an issue? Filing a pull request :).**
------------------------------------ (Remove above before filing the issue) ------------------------------------

21
LICENSE

@@ -1,21 +0,0 @@
MIT License
Copyright (c) Microsoft Corporation. All rights reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE

22
LICENSE.TXT Normal file

@@ -0,0 +1,22 @@
Copyright (c) Microsoft Corporation. All rights reserved.
MIT License
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

127
NOTICE.md Normal file

@@ -0,0 +1,127 @@
NOTICES AND INFORMATION
Do Not Translate or Localize
This software incorporates material from third parties. Microsoft makes certain
open source code available at http://3rdpartysource.microsoft.com, or you may
send a check or money order for US $5.00, including the product name, the open
source component name, and version number, to:
Source Code Compliance Team
Microsoft Corporation
One Microsoft Way
Redmond, WA 98052
USA
Notwithstanding any other terms, you may reverse engineer this software to the
extent required to debug changes to any libraries licensed under the GNU Lesser
General Public License.
**Component.** https://github.com/Swall0w/torchstat
**Open Source License/Copyright Notice.**
MIT License
Copyright (c) 2018 Swall0w - Alan
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
**Component.** https://github.com/waleedka/hiddenlayer
**Open Source License/Copyright Notice.**
MIT License
Copyright (c) 2018 Waleed Abdulla, Phil Ferriere
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
**Component.** https://github.com/yulongwang12/visual-attribution
**Open Source License/Copyright Notice.**
BSD 2-Clause License
Copyright (c) 2019, Yulong Wang
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
**Component.** https://github.com/marcotcr/lime
**Open Source License/Copyright Notice.**
Copyright (c) 2016, Marco Tulio Correia Ribeiro
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


@@ -1,14 +1,18 @@
This package contains Python library for [tensorwatch](https://github.com/sytelus/tensorwatch).
# Contributing
This project welcomes contributions and suggestions. Most contributions require you to
agree to a Contributor License Agreement (CLA) declaring that you have the right to,
and actually do, grant us the rights to use your contribution. For details, visit
https://cla.microsoft.com.
When you submit a pull request, a CLA-bot will automatically determine whether you need
to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the
instructions provided by the bot. You will only need to do this once across all repositories using our CLA.
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.

6
SUPPORT.md Normal file

@@ -0,0 +1,6 @@
# Support
We highly recommend taking a look at the source code and contributing to the project. Please also consider [contributing](CONTRIBUTING.md) new features and fixes :).
* [Join TensorWatch Facebook Group](https://www.facebook.com/groups/378075159472803/)
* [File GitHub Issue](https://github.com/Microsoft/tensorwatch/issues)

23
TODO.md Normal file

@@ -0,0 +1,23 @@
* Fix cell size issue
* Refactor _plot* interface to accept all values, for ImagePlot only use last value
* Refactor ImagePlot for arbitrary number of images with alpha, cmap
* Change tw.open -> tw.create_viz
* Make sure streams have names as key, each data point has index
* Add tw.open_viz(stream_name, from_index)
* Add persist=device_name option for streams
* Ability to use streams in standalone mode
* tw.create_viz on server side
* tw.log for server side
* experiment with IPC channel
* confusion matrix as in https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html
* Speed up import
* Do linting
* live perf data
* NaN tracing
* PCA
* Remove error if MNIST notebook is on and we run fruits
* Remove 2nd image from fruits
* clear existing streams when starting client
* ImagePlotItem should accept numpy array or pillow or torch tensor
* image plot getting refreshed at 12 Hz instead of 2 Hz in MNIST
* image plot doesn't show its title

11
install_jupyterlab.bat Normal file

@@ -0,0 +1,11 @@
conda install -c conda-forge jupyterlab nodejs
conda install ipywidgets
conda install -c plotly plotly-orca psutil
set NODE_OPTIONS=--max-old-space-size=4096
jupyter labextension install @jupyter-widgets/jupyterlab-manager --no-build
jupyter labextension install plotlywidget --no-build
jupyter labextension install @jupyterlab/plotly-extension --no-build
jupyter labextension install jupyterlab-chart-editor --no-build
jupyter lab build
set NODE_OPTIONS=


@@ -0,0 +1,89 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"import tensorwatch as tw"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from regim import *\n",
"ds = DataUtils.mnist_datasets(linearize=True, train_test=False)\n",
"ds = DataUtils.sample_by_class(ds, k=50, shuffle=True, as_np=True, no_test=True)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"components = tw.get_tsne_components(ds)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4a8099b3e2094f2e89924e61b5ad4701",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FigureWidget({\n",
" 'data': [{'hoverinfo': 'text',\n",
" 'line': {'color': 'rgb(31, 119,…"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"comp_stream = tw.ArrayStream(components)\n",
"vis = tw.Visualizer(comp_stream, vis_type='tsne', \n",
" hover_images=ds[0], hover_image_reshape=(28,28))\n",
"vis.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
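Condensed from the cells above, the same t-SNE flow as a standalone script (a sketch; the regim DataUtils helpers and the TensorWatch calls are used exactly as in the notebook, and vis.show() assumes a Jupyter frontend for the widget):

import tensorwatch as tw
from regim import DataUtils  # helper package used by the notebook above

# Load MNIST with flattened images, then sample 50 points per class.
ds = DataUtils.mnist_datasets(linearize=True, train_test=False)
ds = DataUtils.sample_by_class(ds, k=50, shuffle=True, as_np=True, no_test=True)

# Compute t-SNE components and visualize them with hover images (28x28 digits).
components = tw.get_tsne_components(ds)
comp_stream = tw.ArrayStream(components)
vis = tw.Visualizer(comp_stream, vis_type='tsne',
                    hover_images=ds[0], hover_image_reshape=(28, 28))
vis.show()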

File diff suppressed because one or more lines are too long

3409
notebooks/mnist.ipynb Normal file

File diff suppressed because one or more lines are too long

Просмотреть файл

@@ -0,0 +1,722 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import tensorwatch as tw\n",
"import torchvision.models"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"alexnet_model = torchvision.models.alexnet()\n",
"vgg16_model = torchvision.models.vgg16()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\r\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n",
"<!-- Generated by graphviz version 2.38.0 (20140413.2041)\r\n",
" -->\r\n",
"<!-- Title: %3 Pages: 1 -->\r\n",
"<svg width=\"387pt\" height=\"1403pt\"\r\n",
" viewBox=\"0.00 0.00 387.00 1403.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(72 1331)\">\r\n",
"<title>%3</title>\r\n",
"<polygon fill=\"#ffffff\" stroke=\"none\" points=\"-72,72 -72,-1331 315,-1331 315,72 -72,72\"/>\r\n",
"<!-- AlexNet/Sequential[features]/MaxPool2d[2]/outputs/19 -->\r\n",
"<g id=\"node1\" class=\"node\"><title>AlexNet/Sequential[features]/MaxPool2d[2]/outputs/19</title>\r\n",
"<g id=\"a_node1\"><a xlink:title=\"{&#x27;kernel_shape&#x27;: [3, 3], &#x27;pads&#x27;: [0, 0, 0, 0], &#x27;strides&#x27;: [2, 2]}\">\r\n",
"<polygon fill=\"#8de5a1\" stroke=\"#7c96bc\" points=\"188,-1176 114,-1176 114,-1140 188,-1140 188,-1176\"/>\r\n",
"<text text-anchor=\"start\" x=\"122\" y=\"-1155\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">MaxPool3x3</text>\r\n",
"</a>\r\n",
"</g>\r\n",
"</g>\r\n",
"<!-- 5377442926079053455 -->\r\n",
"<g id=\"node16\" class=\"node\"><title>5377442926079053455</title>\r\n",
"<g id=\"a_node16\"><a xlink:title=\"{&#x27;dilations&#x27;: [1, 1], &#x27;group&#x27;: 1, &#x27;kernel_shape&#x27;: [5, 5], &#x27;pads&#x27;: [2, 2, 2, 2], &#x27;strides&#x27;: [1, 1]}\">\r\n",
"<polygon fill=\"#a1c9f4\" stroke=\"#7c96bc\" points=\"197.5,-1093 104.5,-1093 104.5,-1057 197.5,-1057 197.5,-1093\"/>\r\n",
"<text text-anchor=\"start\" x=\"113\" y=\"-1072\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">Conv5x5 &gt; Relu</text>\r\n",
"</a>\r\n",
"</g>\r\n",
"</g>\r\n",
"<!-- AlexNet/Sequential[features]/MaxPool2d[2]/outputs/19&#45;&gt;5377442926079053455 -->\r\n",
"<g id=\"edge12\" class=\"edge\"><title>AlexNet/Sequential[features]/MaxPool2d[2]/outputs/19&#45;&gt;5377442926079053455</title>\r\n",
"<path fill=\"none\" stroke=\"#7c96bc\" d=\"M151,-1139.82C151,-1129.19 151,-1115.31 151,-1103.2\"/>\r\n",
"<polygon fill=\"#7c96bc\" stroke=\"#7c96bc\" points=\"154.5,-1103.15 151,-1093.15 147.5,-1103.15 154.5,-1103.15\"/>\r\n",
"<text text-anchor=\"middle\" x=\"182\" y=\"-1114\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\"> 1x64x27x27</text>\r\n",
"</g>\r\n",
"<!-- AlexNet/Sequential[features]/MaxPool2d[5]/outputs/22 -->\r\n",
"<g id=\"node2\" class=\"node\"><title>AlexNet/Sequential[features]/MaxPool2d[5]/outputs/22</title>\r\n",
"<g id=\"a_node2\"><a xlink:title=\"{&#x27;kernel_shape&#x27;: [3, 3], &#x27;pads&#x27;: [0, 0, 0, 0], &#x27;strides&#x27;: [2, 2]}\">\r\n",
"<polygon fill=\"#8de5a1\" stroke=\"#7c96bc\" points=\"188,-1010 114,-1010 114,-974 188,-974 188,-1010\"/>\r\n",
"<text text-anchor=\"start\" x=\"122\" y=\"-989\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">MaxPool3x3</text>\r\n",
"</a>\r\n",
"</g>\r\n",
"</g>\r\n",
"<!-- 7417167147267928641 -->\r\n",
"<g id=\"node19\" class=\"node\"><title>7417167147267928641</title>\r\n",
"<g id=\"a_node19\"><a xlink:title=\"{&#x27;dilations&#x27;: [1, 1], &#x27;group&#x27;: 1, &#x27;kernel_shape&#x27;: [3, 3], &#x27;pads&#x27;: [1, 1, 1, 1], &#x27;strides&#x27;: [1, 1]}\">\r\n",
"<polygon fill=\"#a1c9f4\" stroke=\"#7c96bc\" points=\"197.5,-927 104.5,-927 104.5,-883 197.5,-883 197.5,-927\"/>\r\n",
"<text text-anchor=\"start\" x=\"113\" y=\"-911\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">Conv3x3 &gt; Relu</text>\r\n",
"<text text-anchor=\"start\" x=\"182\" y=\"-890\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">x3</text>\r\n",
"</a>\r\n",
"</g>\r\n",
"</g>\r\n",
"<!-- AlexNet/Sequential[features]/MaxPool2d[5]/outputs/22&#45;&gt;7417167147267928641 -->\r\n",
"<g id=\"edge18\" class=\"edge\"><title>AlexNet/Sequential[features]/MaxPool2d[5]/outputs/22&#45;&gt;7417167147267928641</title>\r\n",
"<path fill=\"none\" stroke=\"#7c96bc\" d=\"M151,-973.799C151,-963.369 151,-949.742 151,-937.443\"/>\r\n",
"<polygon fill=\"#7c96bc\" stroke=\"#7c96bc\" points=\"154.5,-937.09 151,-927.09 147.5,-937.09 154.5,-937.09\"/>\r\n",
"<text text-anchor=\"middle\" x=\"185\" y=\"-948\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\"> 1x192x13x13</text>\r\n",
"</g>\r\n",
"<!-- AlexNet/Sequential[features]/MaxPool2d[12]/outputs/29 -->\r\n",
"<g id=\"node3\" class=\"node\"><title>AlexNet/Sequential[features]/MaxPool2d[12]/outputs/29</title>\r\n",
"<g id=\"a_node3\"><a xlink:title=\"{&#x27;kernel_shape&#x27;: [3, 3], &#x27;pads&#x27;: [0, 0, 0, 0], &#x27;strides&#x27;: [2, 2]}\">\r\n",
"<polygon fill=\"#8de5a1\" stroke=\"#7c96bc\" points=\"188,-836 114,-836 114,-800 188,-800 188,-836\"/>\r\n",
"<text text-anchor=\"start\" x=\"122\" y=\"-815\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">MaxPool3x3</text>\r\n",
"</a>\r\n",
"</g>\r\n",
"</g>\r\n",
"<!-- AlexNet/outputs/31 -->\r\n",
"<g id=\"node5\" class=\"node\"><title>AlexNet/outputs/31</title>\r\n",
"<polygon fill=\"#d0bbff\" stroke=\"#7c96bc\" points=\"147,-753 93,-753 93,-717 147,-717 147,-753\"/>\r\n",
"<text text-anchor=\"start\" x=\"105\" y=\"-732\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">Shape</text>\r\n",
"</g>\r\n",
"<!-- AlexNet/Sequential[features]/MaxPool2d[12]/outputs/29&#45;&gt;AlexNet/outputs/31 -->\r\n",
"<g id=\"edge1\" class=\"edge\"><title>AlexNet/Sequential[features]/MaxPool2d[12]/outputs/29&#45;&gt;AlexNet/outputs/31</title>\r\n",
"<path fill=\"none\" stroke=\"#7c96bc\" d=\"M129.203,-799.762C124.256,-794.629 119.675,-788.594 117,-782 114.648,-776.204 114.012,-769.66 114.263,-763.364\"/>\r\n",
"<polygon fill=\"#7c96bc\" stroke=\"#7c96bc\" points=\"117.769,-763.501 115.377,-753.179 110.811,-762.74 117.769,-763.501\"/>\r\n",
"<text text-anchor=\"middle\" x=\"145\" y=\"-774\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\"> 1x256x6x6</text>\r\n",
"</g>\r\n",
"<!-- AlexNet/outputs/37 -->\r\n",
"<g id=\"node11\" class=\"node\"><title>AlexNet/outputs/37</title>\r\n",
"<polygon fill=\"#fffea3\" stroke=\"#7c96bc\" points=\"194,-451 136,-451 136,-415 194,-415 194,-451\"/>\r\n",
"<text text-anchor=\"start\" x=\"144\" y=\"-430\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">Reshape</text>\r\n",
"</g>\r\n",
"<!-- AlexNet/Sequential[features]/MaxPool2d[12]/outputs/29&#45;&gt;AlexNet/outputs/37 -->\r\n",
"<g id=\"edge2\" class=\"edge\"><title>AlexNet/Sequential[features]/MaxPool2d[12]/outputs/29&#45;&gt;AlexNet/outputs/37</title>\r\n",
"<path fill=\"none\" stroke=\"#7c96bc\" d=\"M163.591,-799.957C173.972,-784.197 187,-759.699 187,-736 187,-736 187,-736 187,-505 187,-489.781 182.331,-473.515 177.294,-460.398\"/>\r\n",
"<polygon fill=\"#7c96bc\" stroke=\"#7c96bc\" points=\"180.489,-458.962 173.443,-451.05 174.016,-461.629 180.489,-458.962\"/>\r\n",
"<text text-anchor=\"middle\" x=\"215\" y=\"-618\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\"> 1x256x6x6</text>\r\n",
"</g>\r\n",
"<!-- AlexNet/outputs/30 -->\r\n",
"<g id=\"node4\" class=\"node\"><title>AlexNet/outputs/30</title>\r\n",
"<g id=\"a_node4\"><a xlink:title=\"{&#x27;value&#x27;: tensor(0)}\">\r\n",
"<polygon fill=\"#ff9f9b\" stroke=\"#7c96bc\" points=\"75,-753 13,-753 13,-717 75,-717 75,-753\"/>\r\n",
"<text text-anchor=\"start\" x=\"21\" y=\"-732\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">Constant</text>\r\n",
"</a>\r\n",
"</g>\r\n",
"</g>\r\n",
"<!-- AlexNet/outputs/32 -->\r\n",
"<g id=\"node6\" class=\"node\"><title>AlexNet/outputs/32</title>\r\n",
"<g id=\"a_node6\"><a xlink:title=\"{&#x27;axis&#x27;: 0}\">\r\n",
"<polygon fill=\"#debb9b\" stroke=\"#7c96bc\" points=\"147,-680 93,-680 93,-644 147,-644 147,-680\"/>\r\n",
"<text text-anchor=\"start\" x=\"103\" y=\"-659\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">Gather</text>\r\n",
"</a>\r\n",
"</g>\r\n",
"</g>\r\n",
"<!-- AlexNet/outputs/30&#45;&gt;AlexNet/outputs/32 -->\r\n",
"<g id=\"edge3\" class=\"edge\"><title>AlexNet/outputs/30&#45;&gt;AlexNet/outputs/32</title>\r\n",
"<path fill=\"none\" stroke=\"#7c96bc\" d=\"M62.3975,-716.813C72.0128,-707.83 83.9339,-696.693 94.4323,-686.886\"/>\r\n",
"<polygon fill=\"#7c96bc\" stroke=\"#7c96bc\" points=\"96.8539,-689.413 101.772,-680.029 92.0752,-684.298 96.8539,-689.413\"/>\r\n",
"</g>\r\n",
"<!-- AlexNet/outputs/31&#45;&gt;AlexNet/outputs/32 -->\r\n",
"<g id=\"edge4\" class=\"edge\"><title>AlexNet/outputs/31&#45;&gt;AlexNet/outputs/32</title>\r\n",
"<path fill=\"none\" stroke=\"#7c96bc\" d=\"M120,-716.813C120,-708.789 120,-699.047 120,-690.069\"/>\r\n",
"<polygon fill=\"#7c96bc\" stroke=\"#7c96bc\" points=\"123.5,-690.029 120,-680.029 116.5,-690.029 123.5,-690.029\"/>\r\n",
"</g>\r\n",
"<!-- /outputs/34 -->\r\n",
"<g id=\"node8\" class=\"node\"><title>/outputs/34</title>\r\n",
"<g id=\"a_node8\"><a xlink:title=\"{&#x27;axes&#x27;: [0]}\">\r\n",
"<polygon fill=\"#bcd6fc\" stroke=\"#7c96bc\" points=\"158,-597 88,-597 88,-561 158,-561 158,-597\"/>\r\n",
"<text text-anchor=\"start\" x=\"96\" y=\"-576\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">Unsqueeze</text>\r\n",
"</a>\r\n",
"</g>\r\n",
"</g>\r\n",
"<!-- AlexNet/outputs/32&#45;&gt;/outputs/34 -->\r\n",
"<g id=\"edge5\" class=\"edge\"><title>AlexNet/outputs/32&#45;&gt;/outputs/34</title>\r\n",
"<path fill=\"none\" stroke=\"#7c96bc\" d=\"M120.636,-643.822C121.03,-633.19 121.544,-619.306 121.992,-607.204\"/>\r\n",
"<polygon fill=\"#7c96bc\" stroke=\"#7c96bc\" points=\"125.492,-607.276 122.365,-597.153 118.497,-607.017 125.492,-607.276\"/>\r\n",
"</g>\r\n",
"<!-- AlexNet/outputs/33 -->\r\n",
"<g id=\"node7\" class=\"node\"><title>AlexNet/outputs/33</title>\r\n",
"<g id=\"a_node7\"><a xlink:title=\"{&#x27;value&#x27;: tensor(9216)}\">\r\n",
"<polygon fill=\"#ff9f9b\" stroke=\"#7c96bc\" points=\"66,-680 4,-680 4,-644 66,-644 66,-680\"/>\r\n",
"<text text-anchor=\"start\" x=\"12\" y=\"-659\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">Constant</text>\r\n",
"</a>\r\n",
"</g>\r\n",
"</g>\r\n",
"<!-- /outputs/35 -->\r\n",
"<g id=\"node9\" class=\"node\"><title>/outputs/35</title>\r\n",
"<g id=\"a_node9\"><a xlink:title=\"{&#x27;axes&#x27;: [0]}\">\r\n",
"<polygon fill=\"#bcd6fc\" stroke=\"#7c96bc\" points=\"70,-597 0,-597 0,-561 70,-561 70,-597\"/>\r\n",
"<text text-anchor=\"start\" x=\"8\" y=\"-576\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">Unsqueeze</text>\r\n",
"</a>\r\n",
"</g>\r\n",
"</g>\r\n",
"<!-- AlexNet/outputs/33&#45;&gt;/outputs/35 -->\r\n",
"<g id=\"edge6\" class=\"edge\"><title>AlexNet/outputs/33&#45;&gt;/outputs/35</title>\r\n",
"<path fill=\"none\" stroke=\"#7c96bc\" d=\"M35,-643.822C35,-633.19 35,-619.306 35,-607.204\"/>\r\n",
"<polygon fill=\"#7c96bc\" stroke=\"#7c96bc\" points=\"38.5001,-607.153 35,-597.153 31.5001,-607.153 38.5001,-607.153\"/>\r\n",
"</g>\r\n",
"<!-- /outputs/36 -->\r\n",
"<g id=\"node10\" class=\"node\"><title>/outputs/36</title>\r\n",
"<g id=\"a_node10\"><a xlink:title=\"{&#x27;axis&#x27;: 0}\">\r\n",
"<polygon fill=\"#d0bbff\" stroke=\"#7c96bc\" points=\"150,-524 96,-524 96,-488 150,-488 150,-524\"/>\r\n",
"<text text-anchor=\"start\" x=\"105\" y=\"-503\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">Concat</text>\r\n",
"</a>\r\n",
"</g>\r\n",
"</g>\r\n",
"<!-- /outputs/34&#45;&gt;/outputs/36 -->\r\n",
"<g id=\"edge7\" class=\"edge\"><title>/outputs/34&#45;&gt;/outputs/36</title>\r\n",
"<path fill=\"none\" stroke=\"#7c96bc\" d=\"M123,-560.813C123,-552.789 123,-543.047 123,-534.069\"/>\r\n",
"<polygon fill=\"#7c96bc\" stroke=\"#7c96bc\" points=\"126.5,-534.029 123,-524.029 119.5,-534.029 126.5,-534.029\"/>\r\n",
"</g>\r\n",
"<!-- /outputs/35&#45;&gt;/outputs/36 -->\r\n",
"<g id=\"edge8\" class=\"edge\"><title>/outputs/35&#45;&gt;/outputs/36</title>\r\n",
"<path fill=\"none\" stroke=\"#7c96bc\" d=\"M56.3023,-560.813C67.652,-551.656 81.7763,-540.26 94.1015,-530.316\"/>\r\n",
"<polygon fill=\"#7c96bc\" stroke=\"#7c96bc\" points=\"96.3088,-533.032 101.894,-524.029 91.9133,-527.584 96.3088,-533.032\"/>\r\n",
"</g>\r\n",
"<!-- /outputs/36&#45;&gt;AlexNet/outputs/37 -->\r\n",
"<g id=\"edge9\" class=\"edge\"><title>/outputs/36&#45;&gt;AlexNet/outputs/37</title>\r\n",
"<path fill=\"none\" stroke=\"#7c96bc\" d=\"M133.167,-487.813C138.12,-479.441 144.179,-469.197 149.677,-459.903\"/>\r\n",
"<polygon fill=\"#7c96bc\" stroke=\"#7c96bc\" points=\"152.848,-461.418 154.927,-451.029 146.823,-457.854 152.848,-461.418\"/>\r\n",
"</g>\r\n",
"<!-- AlexNet/Sequential[classifier]/Dropout[0]/outputs/38/39 -->\r\n",
"<g id=\"node12\" class=\"node\"><title>AlexNet/Sequential[classifier]/Dropout[0]/outputs/38/39</title>\r\n",
"<g id=\"a_node12\"><a xlink:title=\"{&#x27;ratio&#x27;: 0.5}\">\r\n",
"<polygon fill=\"#b9f2f0\" stroke=\"#7c96bc\" points=\"202.5,-368 127.5,-368 127.5,-332 202.5,-332 202.5,-368\"/>\r\n",
"<text text-anchor=\"start\" x=\"136\" y=\"-347\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">Dropout 0.5</text>\r\n",
"</a>\r\n",
"</g>\r\n",
"</g>\r\n",
"<!-- AlexNet/outputs/37&#45;&gt;AlexNet/Sequential[classifier]/Dropout[0]/outputs/38/39 -->\r\n",
"<g id=\"edge10\" class=\"edge\"><title>AlexNet/outputs/37&#45;&gt;AlexNet/Sequential[classifier]/Dropout[0]/outputs/38/39</title>\r\n",
"<path fill=\"none\" stroke=\"#7c96bc\" d=\"M165,-414.822C165,-404.19 165,-390.306 165,-378.204\"/>\r\n",
"<polygon fill=\"#7c96bc\" stroke=\"#7c96bc\" points=\"168.5,-378.153 165,-368.153 161.5,-378.153 168.5,-378.153\"/>\r\n",
"<text text-anchor=\"middle\" x=\"185\" y=\"-389\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\"> 1x9216</text>\r\n",
"</g>\r\n",
"<!-- 10523010716743172207 -->\r\n",
"<g id=\"node17\" class=\"node\"><title>10523010716743172207</title>\r\n",
"<g id=\"a_node17\"><a xlink:title=\"{&#x27;alpha&#x27;: 1.0, &#x27;beta&#x27;: 1.0, &#x27;transB&#x27;: 1}\">\r\n",
"<polygon fill=\"#8de5a1\" stroke=\"#7c96bc\" points=\"205,-285 125,-285 125,-249 205,-249 205,-285\"/>\r\n",
"<text text-anchor=\"start\" x=\"133\" y=\"-264\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">Linear &gt; Relu</text>\r\n",
"</a>\r\n",
"</g>\r\n",
"</g>\r\n",
"<!-- AlexNet/Sequential[classifier]/Dropout[0]/outputs/38/39&#45;&gt;10523010716743172207 -->\r\n",
"<g id=\"edge14\" class=\"edge\"><title>AlexNet/Sequential[classifier]/Dropout[0]/outputs/38/39&#45;&gt;10523010716743172207</title>\r\n",
"<path fill=\"none\" stroke=\"#7c96bc\" d=\"M165,-331.822C165,-321.19 165,-307.306 165,-295.204\"/>\r\n",
"<polygon fill=\"#7c96bc\" stroke=\"#7c96bc\" points=\"168.5,-295.153 165,-285.153 161.5,-295.153 168.5,-295.153\"/>\r\n",
"<text text-anchor=\"middle\" x=\"185\" y=\"-306\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\"> 1x9216</text>\r\n",
"</g>\r\n",
"<!-- AlexNet/Sequential[classifier]/Dropout[3]/outputs/42/43 -->\r\n",
"<g id=\"node13\" class=\"node\"><title>AlexNet/Sequential[classifier]/Dropout[3]/outputs/42/43</title>\r\n",
"<g id=\"a_node13\"><a xlink:title=\"{&#x27;ratio&#x27;: 0.5}\">\r\n",
"<polygon fill=\"#b9f2f0\" stroke=\"#7c96bc\" points=\"202.5,-202 127.5,-202 127.5,-166 202.5,-166 202.5,-202\"/>\r\n",
"<text text-anchor=\"start\" x=\"136\" y=\"-181\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">Dropout 0.5</text>\r\n",
"</a>\r\n",
"</g>\r\n",
"</g>\r\n",
"<!-- 4117491511057718684 -->\r\n",
"<g id=\"node18\" class=\"node\"><title>4117491511057718684</title>\r\n",
"<g id=\"a_node18\"><a xlink:title=\"{&#x27;alpha&#x27;: 1.0, &#x27;beta&#x27;: 1.0, &#x27;transB&#x27;: 1}\">\r\n",
"<polygon fill=\"#8de5a1\" stroke=\"#7c96bc\" points=\"205,-119 125,-119 125,-83 205,-83 205,-119\"/>\r\n",
"<text text-anchor=\"start\" x=\"133\" y=\"-98\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">Linear &gt; Relu</text>\r\n",
"</a>\r\n",
"</g>\r\n",
"</g>\r\n",
"<!-- AlexNet/Sequential[classifier]/Dropout[3]/outputs/42/43&#45;&gt;4117491511057718684 -->\r\n",
"<g id=\"edge16\" class=\"edge\"><title>AlexNet/Sequential[classifier]/Dropout[3]/outputs/42/43&#45;&gt;4117491511057718684</title>\r\n",
"<path fill=\"none\" stroke=\"#7c96bc\" d=\"M165,-165.822C165,-155.19 165,-141.306 165,-129.204\"/>\r\n",
"<polygon fill=\"#7c96bc\" stroke=\"#7c96bc\" points=\"168.5,-129.153 165,-119.153 161.5,-129.153 168.5,-129.153\"/>\r\n",
"<text text-anchor=\"middle\" x=\"185\" y=\"-140\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\"> 1x4096</text>\r\n",
"</g>\r\n",
"<!-- AlexNet/Sequential[classifier]/ReLU[5]/outputs/46 -->\r\n",
"<g id=\"node14\" class=\"node\"><title>AlexNet/Sequential[classifier]/ReLU[5]/outputs/46</title>\r\n",
"<g id=\"a_node14\"><a xlink:title=\"{&#x27;alpha&#x27;: 1.0, &#x27;beta&#x27;: 1.0, &#x27;transB&#x27;: 1}\">\r\n",
"<polygon fill=\"#4878d0\" stroke=\"#7c96bc\" points=\"192,-36 138,-36 138,-0 192,-0 192,-36\"/>\r\n",
"<text text-anchor=\"start\" x=\"150\" y=\"-15\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">Linear</text>\r\n",
"</a>\r\n",
"</g>\r\n",
"</g>\r\n",
"<!-- 10377602221935690008 -->\r\n",
"<g id=\"node15\" class=\"node\"><title>10377602221935690008</title>\r\n",
"<g id=\"a_node15\"><a xlink:title=\"{&#x27;dilations&#x27;: [1, 1], &#x27;group&#x27;: 1, &#x27;kernel_shape&#x27;: [11, 11], &#x27;pads&#x27;: [2, 2, 2, 2], &#x27;strides&#x27;: [4, 4]}\">\r\n",
"<polygon fill=\"#a1c9f4\" stroke=\"#7c96bc\" points=\"203.5,-1259 98.5,-1259 98.5,-1223 203.5,-1223 203.5,-1259\"/>\r\n",
"<text text-anchor=\"start\" x=\"107\" y=\"-1238\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">Conv11x11 &gt; Relu</text>\r\n",
"</a>\r\n",
"</g>\r\n",
"</g>\r\n",
"<!-- 10377602221935690008&#45;&gt;AlexNet/Sequential[features]/MaxPool2d[2]/outputs/19 -->\r\n",
"<g id=\"edge11\" class=\"edge\"><title>10377602221935690008&#45;&gt;AlexNet/Sequential[features]/MaxPool2d[2]/outputs/19</title>\r\n",
"<path fill=\"none\" stroke=\"#7c96bc\" d=\"M151,-1222.82C151,-1212.19 151,-1198.31 151,-1186.2\"/>\r\n",
"<polygon fill=\"#7c96bc\" stroke=\"#7c96bc\" points=\"154.5,-1186.15 151,-1176.15 147.5,-1186.15 154.5,-1186.15\"/>\r\n",
"<text text-anchor=\"middle\" x=\"182\" y=\"-1197\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\"> 1x64x55x55</text>\r\n",
"</g>\r\n",
"<!-- 5377442926079053455&#45;&gt;AlexNet/Sequential[features]/MaxPool2d[5]/outputs/22 -->\r\n",
"<g id=\"edge13\" class=\"edge\"><title>5377442926079053455&#45;&gt;AlexNet/Sequential[features]/MaxPool2d[5]/outputs/22</title>\r\n",
"<path fill=\"none\" stroke=\"#7c96bc\" d=\"M151,-1056.82C151,-1046.19 151,-1032.31 151,-1020.2\"/>\r\n",
"<polygon fill=\"#7c96bc\" stroke=\"#7c96bc\" points=\"154.5,-1020.15 151,-1010.15 147.5,-1020.15 154.5,-1020.15\"/>\r\n",
"<text text-anchor=\"middle\" x=\"185\" y=\"-1031\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\"> 1x192x27x27</text>\r\n",
"</g>\r\n",
"<!-- 10523010716743172207&#45;&gt;AlexNet/Sequential[classifier]/Dropout[3]/outputs/42/43 -->\r\n",
"<g id=\"edge15\" class=\"edge\"><title>10523010716743172207&#45;&gt;AlexNet/Sequential[classifier]/Dropout[3]/outputs/42/43</title>\r\n",
"<path fill=\"none\" stroke=\"#7c96bc\" d=\"M165,-248.822C165,-238.19 165,-224.306 165,-212.204\"/>\r\n",
"<polygon fill=\"#7c96bc\" stroke=\"#7c96bc\" points=\"168.5,-212.153 165,-202.153 161.5,-212.153 168.5,-212.153\"/>\r\n",
"<text text-anchor=\"middle\" x=\"185\" y=\"-223\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\"> 1x4096</text>\r\n",
"</g>\r\n",
"<!-- 4117491511057718684&#45;&gt;AlexNet/Sequential[classifier]/ReLU[5]/outputs/46 -->\r\n",
"<g id=\"edge17\" class=\"edge\"><title>4117491511057718684&#45;&gt;AlexNet/Sequential[classifier]/ReLU[5]/outputs/46</title>\r\n",
"<path fill=\"none\" stroke=\"#7c96bc\" d=\"M165,-82.822C165,-72.1903 165,-58.306 165,-46.2035\"/>\r\n",
"<polygon fill=\"#7c96bc\" stroke=\"#7c96bc\" points=\"168.5,-46.1532 165,-36.1533 161.5,-46.1533 168.5,-46.1532\"/>\r\n",
"<text text-anchor=\"middle\" x=\"185\" y=\"-57\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\"> 1x4096</text>\r\n",
"</g>\r\n",
"<!-- 7417167147267928641&#45;&gt;AlexNet/Sequential[features]/MaxPool2d[12]/outputs/29 -->\r\n",
"<g id=\"edge19\" class=\"edge\"><title>7417167147267928641&#45;&gt;AlexNet/Sequential[features]/MaxPool2d[12]/outputs/29</title>\r\n",
"<path fill=\"none\" stroke=\"#7c96bc\" d=\"M151,-882.989C151,-871.923 151,-858.219 151,-846.336\"/>\r\n",
"<polygon fill=\"#7c96bc\" stroke=\"#7c96bc\" points=\"154.5,-846.062 151,-836.062 147.5,-846.062 154.5,-846.062\"/>\r\n",
"<text text-anchor=\"middle\" x=\"185\" y=\"-857\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\"> 1x256x13x13</text>\r\n",
"</g>\r\n",
"</g>\r\n",
"</svg>\r\n"
],
"text/plain": [
"<tensorwatch.model_graph.hiddenlayer.graph.Graph at 0x1d63302d5f8>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tw.draw_model(alexnet_model, [1, 3, 224, 224])"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[MAdd]: Dropout is not supported!\n",
"[Flops]: Dropout is not supported!\n",
"[Memory]: Dropout is not supported!\n",
"[MAdd]: Dropout is not supported!\n",
"[Flops]: Dropout is not supported!\n",
"[Memory]: Dropout is not supported!\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>module name</th>\n",
" <th>input shape</th>\n",
" <th>output shape</th>\n",
" <th>params</th>\n",
" <th>memory(MB)</th>\n",
" <th>MAdd</th>\n",
" <th>Flops</th>\n",
" <th>MemRead(B)</th>\n",
" <th>MemWrite(B)</th>\n",
" <th>duration[%]</th>\n",
" <th>MemR+W(B)</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>features.0</td>\n",
" <td>3 224 224</td>\n",
" <td>64 55 55</td>\n",
" <td>23296.0</td>\n",
" <td>0.74</td>\n",
" <td>140,553,600.0</td>\n",
" <td>70,470,400.0</td>\n",
" <td>695296.0</td>\n",
" <td>774400.0</td>\n",
" <td>9.08%</td>\n",
" <td>1469696.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>features.1</td>\n",
" <td>64 55 55</td>\n",
" <td>64 55 55</td>\n",
" <td>0.0</td>\n",
" <td>0.74</td>\n",
" <td>193,600.0</td>\n",
" <td>193,600.0</td>\n",
" <td>774400.0</td>\n",
" <td>774400.0</td>\n",
" <td>2.59%</td>\n",
" <td>1548800.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>features.2</td>\n",
" <td>64 55 55</td>\n",
" <td>64 27 27</td>\n",
" <td>0.0</td>\n",
" <td>0.18</td>\n",
" <td>373,248.0</td>\n",
" <td>193,600.0</td>\n",
" <td>774400.0</td>\n",
" <td>186624.0</td>\n",
" <td>6.48%</td>\n",
" <td>961024.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>features.3</td>\n",
" <td>64 27 27</td>\n",
" <td>192 27 27</td>\n",
" <td>307392.0</td>\n",
" <td>0.53</td>\n",
" <td>447,897,600.0</td>\n",
" <td>224,088,768.0</td>\n",
" <td>1416192.0</td>\n",
" <td>559872.0</td>\n",
" <td>10.38%</td>\n",
" <td>1976064.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>features.4</td>\n",
" <td>192 27 27</td>\n",
" <td>192 27 27</td>\n",
" <td>0.0</td>\n",
" <td>0.53</td>\n",
" <td>139,968.0</td>\n",
" <td>139,968.0</td>\n",
" <td>559872.0</td>\n",
" <td>559872.0</td>\n",
" <td>1.30%</td>\n",
" <td>1119744.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>features.5</td>\n",
" <td>192 27 27</td>\n",
" <td>192 13 13</td>\n",
" <td>0.0</td>\n",
" <td>0.12</td>\n",
" <td>259,584.0</td>\n",
" <td>139,968.0</td>\n",
" <td>559872.0</td>\n",
" <td>129792.0</td>\n",
" <td>3.89%</td>\n",
" <td>689664.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>features.6</td>\n",
" <td>192 13 13</td>\n",
" <td>384 13 13</td>\n",
" <td>663936.0</td>\n",
" <td>0.25</td>\n",
" <td>224,280,576.0</td>\n",
" <td>112,205,184.0</td>\n",
" <td>2785536.0</td>\n",
" <td>259584.0</td>\n",
" <td>5.19%</td>\n",
" <td>3045120.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>features.7</td>\n",
" <td>384 13 13</td>\n",
" <td>384 13 13</td>\n",
" <td>0.0</td>\n",
" <td>0.25</td>\n",
" <td>64,896.0</td>\n",
" <td>64,896.0</td>\n",
" <td>259584.0</td>\n",
" <td>259584.0</td>\n",
" <td>0.00%</td>\n",
" <td>519168.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>features.8</td>\n",
" <td>384 13 13</td>\n",
" <td>256 13 13</td>\n",
" <td>884992.0</td>\n",
" <td>0.17</td>\n",
" <td>299,040,768.0</td>\n",
" <td>149,563,648.0</td>\n",
" <td>3799552.0</td>\n",
" <td>173056.0</td>\n",
" <td>10.37%</td>\n",
" <td>3972608.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>features.9</td>\n",
" <td>256 13 13</td>\n",
" <td>256 13 13</td>\n",
" <td>0.0</td>\n",
" <td>0.17</td>\n",
" <td>43,264.0</td>\n",
" <td>43,264.0</td>\n",
" <td>173056.0</td>\n",
" <td>173056.0</td>\n",
" <td>1.30%</td>\n",
" <td>346112.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <td>features.10</td>\n",
" <td>256 13 13</td>\n",
" <td>256 13 13</td>\n",
" <td>590080.0</td>\n",
" <td>0.17</td>\n",
" <td>199,360,512.0</td>\n",
" <td>99,723,520.0</td>\n",
" <td>2533376.0</td>\n",
" <td>173056.0</td>\n",
" <td>11.67%</td>\n",
" <td>2706432.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <td>features.11</td>\n",
" <td>256 13 13</td>\n",
" <td>256 13 13</td>\n",
" <td>0.0</td>\n",
" <td>0.17</td>\n",
" <td>43,264.0</td>\n",
" <td>43,264.0</td>\n",
" <td>173056.0</td>\n",
" <td>173056.0</td>\n",
" <td>0.00%</td>\n",
" <td>346112.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <td>features.12</td>\n",
" <td>256 13 13</td>\n",
" <td>256 6 6</td>\n",
" <td>0.0</td>\n",
" <td>0.04</td>\n",
" <td>73,728.0</td>\n",
" <td>43,264.0</td>\n",
" <td>173056.0</td>\n",
" <td>36864.0</td>\n",
" <td>1.30%</td>\n",
" <td>209920.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <td>classifier.0</td>\n",
" <td>9216</td>\n",
" <td>9216</td>\n",
" <td>0.0</td>\n",
" <td>0.04</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.00%</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <td>classifier.1</td>\n",
" <td>9216</td>\n",
" <td>4096</td>\n",
" <td>37752832.0</td>\n",
" <td>0.02</td>\n",
" <td>75,493,376.0</td>\n",
" <td>37,748,736.0</td>\n",
" <td>151048192.0</td>\n",
" <td>16384.0</td>\n",
" <td>22.82%</td>\n",
" <td>151064576.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15</th>\n",
" <td>classifier.2</td>\n",
" <td>4096</td>\n",
" <td>4096</td>\n",
" <td>0.0</td>\n",
" <td>0.02</td>\n",
" <td>4,096.0</td>\n",
" <td>4,096.0</td>\n",
" <td>16384.0</td>\n",
" <td>16384.0</td>\n",
" <td>0.00%</td>\n",
" <td>32768.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16</th>\n",
" <td>classifier.3</td>\n",
" <td>4096</td>\n",
" <td>4096</td>\n",
" <td>0.0</td>\n",
" <td>0.02</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.00%</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>17</th>\n",
" <td>classifier.4</td>\n",
" <td>4096</td>\n",
" <td>4096</td>\n",
" <td>16781312.0</td>\n",
" <td>0.02</td>\n",
" <td>33,550,336.0</td>\n",
" <td>16,777,216.0</td>\n",
" <td>67141632.0</td>\n",
" <td>16384.0</td>\n",
" <td>9.76%</td>\n",
" <td>67158016.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18</th>\n",
" <td>classifier.5</td>\n",
" <td>4096</td>\n",
" <td>4096</td>\n",
" <td>0.0</td>\n",
" <td>0.02</td>\n",
" <td>4,096.0</td>\n",
" <td>4,096.0</td>\n",
" <td>16384.0</td>\n",
" <td>16384.0</td>\n",
" <td>1.30%</td>\n",
" <td>32768.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19</th>\n",
" <td>classifier.6</td>\n",
" <td>4096</td>\n",
" <td>1000</td>\n",
" <td>4097000.0</td>\n",
" <td>0.00</td>\n",
" <td>8,191,000.0</td>\n",
" <td>4,096,000.0</td>\n",
" <td>16404384.0</td>\n",
" <td>4000.0</td>\n",
" <td>2.59%</td>\n",
" <td>16408384.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>total</th>\n",
" <td></td>\n",
" <td></td>\n",
" <td></td>\n",
" <td>61100840.0</td>\n",
" <td>4.15</td>\n",
" <td>1,429,567,512.0</td>\n",
" <td>715,543,488.0</td>\n",
" <td>16404384.0</td>\n",
" <td>4000.0</td>\n",
" <td>100.00%</td>\n",
" <td>253606976.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" module name input shape output shape params memory(MB) MAdd Flops MemRead(B) MemWrite(B) duration[%] MemR+W(B)\n",
"0 features.0 3 224 224 64 55 55 23296.0 0.74 140,553,600.0 70,470,400.0 695296.0 774400.0 9.08% 1469696.0\n",
"1 features.1 64 55 55 64 55 55 0.0 0.74 193,600.0 193,600.0 774400.0 774400.0 2.59% 1548800.0\n",
"2 features.2 64 55 55 64 27 27 0.0 0.18 373,248.0 193,600.0 774400.0 186624.0 6.48% 961024.0\n",
"3 features.3 64 27 27 192 27 27 307392.0 0.53 447,897,600.0 224,088,768.0 1416192.0 559872.0 10.38% 1976064.0\n",
"4 features.4 192 27 27 192 27 27 0.0 0.53 139,968.0 139,968.0 559872.0 559872.0 1.30% 1119744.0\n",
"5 features.5 192 27 27 192 13 13 0.0 0.12 259,584.0 139,968.0 559872.0 129792.0 3.89% 689664.0\n",
"6 features.6 192 13 13 384 13 13 663936.0 0.25 224,280,576.0 112,205,184.0 2785536.0 259584.0 5.19% 3045120.0\n",
"7 features.7 384 13 13 384 13 13 0.0 0.25 64,896.0 64,896.0 259584.0 259584.0 0.00% 519168.0\n",
"8 features.8 384 13 13 256 13 13 884992.0 0.17 299,040,768.0 149,563,648.0 3799552.0 173056.0 10.37% 3972608.0\n",
"9 features.9 256 13 13 256 13 13 0.0 0.17 43,264.0 43,264.0 173056.0 173056.0 1.30% 346112.0\n",
"10 features.10 256 13 13 256 13 13 590080.0 0.17 199,360,512.0 99,723,520.0 2533376.0 173056.0 11.67% 2706432.0\n",
"11 features.11 256 13 13 256 13 13 0.0 0.17 43,264.0 43,264.0 173056.0 173056.0 0.00% 346112.0\n",
"12 features.12 256 13 13 256 6 6 0.0 0.04 73,728.0 43,264.0 173056.0 36864.0 1.30% 209920.0\n",
"13 classifier.0 9216 9216 0.0 0.04 0.0 0.0 0.0 0.0 0.00% 0.0\n",
"14 classifier.1 9216 4096 37752832.0 0.02 75,493,376.0 37,748,736.0 151048192.0 16384.0 22.82% 151064576.0\n",
"15 classifier.2 4096 4096 0.0 0.02 4,096.0 4,096.0 16384.0 16384.0 0.00% 32768.0\n",
"16 classifier.3 4096 4096 0.0 0.02 0.0 0.0 0.0 0.0 0.00% 0.0\n",
"17 classifier.4 4096 4096 16781312.0 0.02 33,550,336.0 16,777,216.0 67141632.0 16384.0 9.76% 67158016.0\n",
"18 classifier.5 4096 4096 0.0 0.02 4,096.0 4,096.0 16384.0 16384.0 1.30% 32768.0\n",
"19 classifier.6 4096 1000 4097000.0 0.00 8,191,000.0 4,096,000.0 16404384.0 4000.0 2.59% 16408384.0\n",
"total 61100840.0 4.15 1,429,567,512.0 715,543,488.0 16404384.0 4000.0 100.00% 253606976.0"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tw.model_stats(alexnet_model, [1, 3, 224, 224])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
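For reference, the two model-introspection calls exercised by this notebook as a standalone sketch (the input shape [1, 3, 224, 224] is copied from the cells; rendering the graph assumes a Jupyter-like frontend):

import tensorwatch as tw
import torchvision.models

alexnet_model = torchvision.models.alexnet()

# Draw the network as a graph (rendered as SVG in Jupyter) and tabulate
# per-layer statistics (params, memory, MAdd, Flops, ...), as in the cells above.
graph = tw.draw_model(alexnet_model, [1, 3, 224, 224])
stats = tw.model_stats(alexnet_model, [1, 3, 224, 224])
print(stats)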

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

28
setup.py Normal file

@@ -0,0 +1,28 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import setuptools
with open("README.md", "r") as fh:
long_description = fh.read()
setuptools.setup(
name="tensorwatch",
version="0.6.0",
author="Shital Shah",
author_email="shitals@microsoft.com",
description="Interactive Realtime Debugging and Visualization for AI",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/sytelus/tensorwatch",
packages=setuptools.find_packages(),
license='MIT',
classifiers=(
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
),
install_requires=[
'matplotlib', 'numpy', 'pyzmq', 'plotly', 'torchstat' # , 'receptivefield'
]
)
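A quick post-install smoke test (a minimal sketch; assumes `pip install -e .` was run from the repo root, and checks only entry points that the notebooks above actually use):

# Minimal sketch: confirm the package imports and its public entry points resolve.
import tensorwatch as tw

for name in ("draw_model", "model_stats", "Visualizer", "ArrayStream", "get_tsne_components"):
    assert hasattr(tw, name), "missing entry point: " + name
print("tensorwatch import OK")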

175
tensorwatch.pyproj Normal file

@@ -0,0 +1,175 @@
<?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003" DefaultTargets="Build">
<PropertyGroup>
<Configuration Condition=" '$(Configuration)' == '' ">Debug</Configuration>
<SchemaVersion>2.0</SchemaVersion>
<ProjectGuid>{cc8abc7f-ede1-4e13-b6b7-0041a5ec66a7}</ProjectGuid>
<ProjectHome />
<StartupFile>tensorwatch\v2\zmq_watcher_server.py</StartupFile>
<SearchPath />
<WorkingDirectory>.</WorkingDirectory>
<OutputPath>.</OutputPath>
<ProjectTypeGuids>{888888a0-9f3d-457c-b088-3a5042f75d52}</ProjectTypeGuids>
<LaunchProvider>Standard Python launcher</LaunchProvider>
<InterpreterId>Global|ContinuumAnalytics|Anaconda36-64</InterpreterId>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)' == 'Debug'" />
<PropertyGroup Condition="'$(Configuration)' == 'Release'" />
<PropertyGroup>
<VisualStudioVersion Condition=" '$(VisualStudioVersion)' == '' ">10.0</VisualStudioVersion>
</PropertyGroup>
<ItemGroup>
<Compile Include="tensorwatch\embeddings\__init__.py" />
<Compile Include="tensorwatch\stream_union.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="tensorwatch\file_stream.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="tensorwatch\filtered_stream.py" />
<Compile Include="tensorwatch\stream.py" />
<Compile Include="tensorwatch\stream_factory.py" />
<Compile Include="tensorwatch\visualizer.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="tensorwatch\vis_base.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="tensorwatch\saliency\lime\lime_base.py" />
<Compile Include="tensorwatch\saliency\lime\lime_image.py" />
<Compile Include="tensorwatch\saliency\lime\wrappers\scikit_image.py" />
<Compile Include="tensorwatch\saliency\lime\wrappers\__init__.py" />
<Compile Include="tensorwatch\saliency\lime\__init__.py" />
<Compile Include="tensorwatch\pytorch_utils.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="tensorwatch\array_stream.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="tensorwatch\data_utils.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="tensorwatch\imagenet_utils.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="tensorwatch\image_utils.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="tensorwatch\model_graph\hiddenlayer\ge.py" />
<Compile Include="tensorwatch\model_graph\hiddenlayer\graph.py" />
<Compile Include="tensorwatch\model_graph\hiddenlayer\pytorch_builder.py" />
<Compile Include="tensorwatch\model_graph\hiddenlayer\tf_builder.py" />
<Compile Include="tensorwatch\model_graph\hiddenlayer\transforms.py" />
<Compile Include="tensorwatch\model_graph\hiddenlayer\__init__.py" />
<Compile Include="tensorwatch\model_graph\torchstat_utils.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="tensorwatch\model_graph\__init__.py" />
<Compile Include="tensorwatch\mpl\base_mpl_plot.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="tensorwatch\evaler.py" />
<Compile Include="tensorwatch\evaler_utils.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="tensorwatch\mpl\image_plot.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="tensorwatch\lv_types.py" />
<Compile Include="tensorwatch\mpl\line_plot.py" />
<Compile Include="tensorwatch\mpl\__init__.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="tensorwatch\plotly\base_plotly_plot.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="tensorwatch\plotly\line_plot.py" />
<Compile Include="tensorwatch\plotly\embeddings_plot.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="tensorwatch\plotly\__init__.py" />
<Compile Include="tensorwatch\receptive_field\rf_utils.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="tensorwatch\repeated_timer.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="tensorwatch\embeddings\tsne_utils.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="tensorwatch\saliency\backprop.py" />
<Compile Include="tensorwatch\saliency\deeplift.py" />
<Compile Include="tensorwatch\saliency\epsilon_lrp.py" />
<Compile Include="tensorwatch\saliency\gradcam.py" />
<Compile Include="tensorwatch\saliency\inverter_util.py" />
<Compile Include="tensorwatch\saliency\lime_image_explainer.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="tensorwatch\saliency\occlusion.py" />
<Compile Include="tensorwatch\saliency\saliency.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="tensorwatch\saliency\__init__.py" />
<Compile Include="tensorwatch\text_vis.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="tensorwatch\utils.py" />
<Compile Include="tensorwatch\watcher_base.py" />
<Compile Include="tensorwatch\zmq_mgmt_stream.py">
<SubType>Code</SubType>
</Compile>
<Compile Include="tensorwatch\zmq_wrapper.py" />
<Compile Include="tensorwatch\zmq_stream.py" />
<Compile Include="tensorwatch\watcher_client.py" />
<Compile Include="tensorwatch\watcher.py" />
<Compile Include="tensorwatch\__init__.py" />
<Compile Include="setup.py" />
</ItemGroup>
<ItemGroup>
<Folder Include="tensorwatch" />
<Folder Include="tensorwatch\saliency\lime\wrappers\" />
<Folder Include="tensorwatch\model_graph\" />
<Folder Include="tensorwatch\model_graph\hiddenlayer\" />
<Folder Include="tensorwatch\model_graph\hiddenlayer\__pycache__\" />
<Folder Include="tensorwatch\mpl\" />
<Folder Include="tensorwatch\embeddings\" />
<Folder Include="tensorwatch\saliency\lime\" />
<Folder Include="tensorwatch\receptive_field\" />
<Folder Include="tensorwatch\plotly\" />
<Folder Include="tensorwatch\saliency\" />
</ItemGroup>
<ItemGroup>
<InterpreterReference Include="Global|ContinuumAnalytics|Anaconda36-64" />
</ItemGroup>
<ItemGroup>
<Content Include=".gitignore" />
<Content Include=".pylintrc">
<SubType>Code</SubType>
</Content>
<Content Include="CONTRIBUTING.md">
<SubType>Code</SubType>
</Content>
<Content Include="install_jupyterlab.bat" />
<Content Include="LICENSE.TXT">
<SubType>Code</SubType>
</Content>
<Content Include="NOTICE.md">
<SubType>Code</SubType>
</Content>
<Content Include="README.md" />
<Content Include="tensorwatch\model_graph\hiddenlayer\README.md">
<SubType>Code</SubType>
</Content>
<Content Include="tensorwatch\model_graph\hiddenlayer\__pycache__\ge.cpython-36.pyc" />
<Content Include="tensorwatch\model_graph\hiddenlayer\__pycache__\graph.cpython-36.pyc" />
<Content Include="tensorwatch\model_graph\hiddenlayer\__pycache__\pytorch_builder.cpython-36.pyc" />
<Content Include="tensorwatch\model_graph\hiddenlayer\__pycache__\transforms.cpython-36.pyc" />
<Content Include="tensorwatch\model_graph\hiddenlayer\__pycache__\__init__.cpython-36.pyc" />
<Content Include="tensorwatch\saliency\README.md">
<SubType>Code</SubType>
</Content>
<Content Include="TODO.md" />
<Content Include="update_package.bat" />
</ItemGroup>
<Import Project="$(MSBuildExtensionsPath32)\Microsoft\VisualStudio\v$(VisualStudioVersion)\Python Tools\Microsoft.PythonTools.targets" />
</Project>

23
tensorwatch.sln Normal file
View File

@ -0,0 +1,23 @@

Microsoft Visual Studio Solution File, Format Version 12.00
# Visual Studio 15
VisualStudioVersion = 15.0.27428.2043
MinimumVisualStudioVersion = 10.0.40219.1
Project("{888888A0-9F3D-457C-B088-3A5042F75D52}") = "tensorwatch", "tensorwatch.pyproj", "{CC8ABC7F-EDE1-4E13-B6B7-0041A5EC66A7}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Release|Any CPU = Release|Any CPU
EndGlobalSection
GlobalSection(ProjectConfigurationPlatforms) = postSolution
{CC8ABC7F-EDE1-4E13-B6B7-0041A5EC66A7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{CC8ABC7F-EDE1-4E13-B6B7-0041A5EC66A7}.Release|Any CPU.ActiveCfg = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {99E7AEC7-2CDE-48C8-B98B-4E28E4F840B6}
EndGlobalSection
EndGlobal

34
tensorwatch/__init__.py Normal file
View File

@ -0,0 +1,34 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Iterable, Sequence, Union
from .watcher_client import WatcherClient
from .watcher import Watcher
from .watcher_base import WatcherBase
from .text_vis import TextVis
from .plotly import EmbeddingsPlot
from .mpl import LinePlot, ImagePlot
from .visualizer import Visualizer
from .stream import Stream
from .array_stream import ArrayStream
from .lv_types import ImagePlotItem, VisParams
from . import utils
###### Import methods for tw namespace #########
#from .receptive_field.rf_utils import plot_receptive_field, plot_grads_at
from .embeddings.tsne_utils import get_tsne_components
from .model_graph.torchstat_utils import model_stats
from .image_utils import show_image, open_image, img2pyt, linear_to_2d, plt_loop
from .data_utils import pyt_ds2list, sample_by_class, col2array, search_similar
def draw_model(model, input_shape=None, orientation='TB'): #orientation = 'LR' for landscape
from .model_graph.hiddenlayer import graph
g = graph.build_graph(model, input_shape, orientation=orientation)
return g
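# Illustrative usage sketch (not part of the commit), mirroring the notebook call
# earlier in this diff; assumes a torchvision model:
#   import tensorwatch as tw, torchvision.models
#   model = torchvision.models.alexnet()
#   tw.draw_model(model, [1, 3, 224, 224])    # renders the graph in Jupyter
#   tw.model_stats(model, [1, 3, 224, 224])   # per-layer MAdd/memory table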

21
tensorwatch/array_stream.py Normal file
View File

@ -0,0 +1,21 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .stream import Stream
from .lv_types import StreamItem
import uuid
class ArrayStream(Stream):
def __init__(self, array, stream_name:str=None, console_debug:bool=False):
super(ArrayStream, self).__init__(stream_name=stream_name, console_debug=console_debug)
self.stream_name = stream_name
self.array = array
self.creator_id = str(uuid.uuid4())
def load(self, from_stream:'Stream'=None):
if self.array is not None:
stream_item = StreamItem(item_index=0, value=self.array,
stream_name=self.stream_name, creator_id=self.creator_id, stream_index=0)
self.write(stream_item)

60
tensorwatch/data_utils.py Normal file
View File

@ -0,0 +1,60 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import random
import scipy.spatial.distance
import heapq
import numpy as np
import torch
def pyt_tensor2np(pyt_tensor):
if pyt_tensor is None:
return None
if isinstance(pyt_tensor, torch.Tensor):
n = pyt_tensor.data.cpu().numpy()
if len(n.shape) == 1:
return n[0]
else:
return n
elif isinstance(pyt_tensor, np.ndarray):
return pyt_tensor
else:
return np.array(pyt_tensor)
def pyt_tuple2np(pyt_tuple):
return tuple((pyt_tensor2np(t) for t in pyt_tuple))
def pyt_ds2list(pyt_ds, count=None):
count = count or len(pyt_ds)
return [pyt_tuple2np(t) for t, c in zip(pyt_ds, range(count))]
def sample_by_class(data, n_samples, class_col=1, shuffle=True):
if shuffle:
random.shuffle(data)
samples = {}
for i, t in enumerate(data):
cls = t[class_col]
if cls not in samples:
samples[cls] = []
if len(samples[cls]) < n_samples:
samples[cls].append(data[i])
samples = sum(samples.values(), [])
return samples
def col2array(dataset, col):
return [row[col] for row in dataset]
def search_similar(inputs, compare_to, algorithm='euclidean', topk=5, invert_score=True):
all_scores = scipy.spatial.distance.cdist(inputs, compare_to, algorithm)
all_results = []
for input_val, scores in zip(inputs, all_scores):
result = []
for i, (score, data) in enumerate(zip(scores, compare_to)):
if invert_score:
score = 1/(score + 1.0E-6)
if len(result) < topk:
heapq.heappush(result, (score, (i, input_val, data)))
else:
heapq.heappushpop(result, (score, (i, input_val, data)))
all_results.append(result)
return all_results
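# Hypothetical usage sketch (not part of the commit): top-2 matches per query row;
# all names below are illustrative.
if __name__ == '__main__':
    queries = np.array([[0.0, 0.0], [1.0, 1.0]])
    refs = np.array([[0.0, 0.1], [5.0, 5.0], [1.0, 1.1], [9.0, 9.0]])
    # each result is a heap of (score, (ref_index, query_row, ref_row)) tuples,
    # with distances inverted so larger scores mean closer matches
    print(search_similar(queries, refs, topk=2))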

4
tensorwatch/embeddings/__init__.py Normal file
View File

@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

38
tensorwatch/embeddings/tsne_utils.py Normal file
View File

@ -0,0 +1,38 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler
def _standardize_data(data, col, whitten, flatten):
if col is not None:
data = data[col]
#TODO: enable auto flattening
#if data is tensor then flatten it first
#if flatten and len(data) > 0 and hasattr(data[0], 'shape') and \
# utils.has_method(data[0], 'reshape'):
# data = [d.reshape((-1,)) for d in data]
if whitten:
data = StandardScaler().fit_transform(data)
return data
def get_tsne_components(data, features_col=0, labels_col=1, whitten=True, n_components=3, perplexity=20, flatten=True, for_plot=True):
features = _standardize_data(data, features_col, whitten, flatten)
tsne = TSNE(n_components=n_components, perplexity=perplexity)
tsne_results = tsne.fit_transform(features)
if for_plot:
comps = tsne_results.tolist()
labels = data[labels_col]
for i, item in enumerate(comps):
item.append(None) # annotation
item.append(str(labels[i])) # text
item.append(labels[i]) # color
return comps
return tsne_results
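# Hypothetical usage sketch (not part of the commit): 100 random 64-d feature vectors
# with integer class labels; each returned row is [x, y, z, annotation, text, color].
if __name__ == '__main__':
    import numpy as np
    features = np.random.rand(100, 64)
    labels = np.random.randint(0, 10, size=100)
    rows = get_tsne_components((features, labels))
    print(rows[0])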

123
tensorwatch/evaler.py Normal file
View File

@ -0,0 +1,123 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import threading
import sys
from collections.abc import Iterator
from .lv_types import EventVars
# pylint: disable=unused-wildcard-import
# pylint: disable=wildcard-import
# pylint: disable=unused-import
from functools import *
from itertools import *
from statistics import *
import numpy as np
from .evaler_utils import *
class Evaler:
class EvalReturn:
def __init__(self, result=None, is_valid=False, exception=None):
self.result, self.exception, self.is_valid = \
result, exception, is_valid
def reset(self):
self.result, self.exception, self.is_valid = \
None, None, False
class PostableIterator:
def __init__(self, eval_wait):
self.eval_wait = eval_wait
self.post_wait = threading.Event()
self.event_vars, self.ended = None, None # define attributes in init
self.reset()
def reset(self):
self.event_vars, self.ended = None, False
self.post_wait.clear()
def abort(self):
self.ended = True
self.post_wait.set()
def post(self, event_vars:EventVars=None, ended=False):
self.event_vars, self.ended = event_vars, ended
self.post_wait.set()
def get_vals(self):
while True:
self.post_wait.wait()
self.post_wait.clear()
if self.ended:
break
else:
yield self.event_vars
# below will cause result=None, is_valid=False when
# expression has reduce
self.eval_wait.set()
def __init__(self, expr):
self.eval_wait = threading.Event()
self.reset_wait = threading.Event()
self.g = Evaler.PostableIterator(self.eval_wait)
self.expr = expr
self.eval_return, self.continue_thread = None, None # define in __init__
self.reset()
self.th = threading.Thread(target=self._runner, daemon=True, name='evaler')
self.th.start()
self.running = True
def reset(self):
self.g.reset()
self.eval_wait.clear()
self.reset_wait.clear()
self.eval_return = Evaler.EvalReturn()
self.continue_thread = True
def _runner(self):
while True:
# this var will be used by eval
l = self.g.get_vals() # pylint: disable=unused-variable
try:
result = eval(self.expr) # pylint: disable=eval-used
if isinstance(result, Iterator):
for item in result:
self.eval_return = Evaler.EvalReturn(item, True)
else:
self.eval_return = Evaler.EvalReturn(result, True)
except Exception as ex: # pylint: disable=broad-except
print(ex, file=sys.stderr)
self.eval_return = Evaler.EvalReturn(None, True, ex)
self.eval_wait.set()
self.reset_wait.wait()
if not self.continue_thread:
break
self.reset()
self.running = False
utils.debug_log('eval runner ended!')
def abort(self):
utils.debug_log('Evaler Aborted')
self.continue_thread = False
self.g.abort()
self.eval_wait.set()
self.reset_wait.set()
def post(self, event_vars:EventVars=None, ended=False, continue_thread=True):
if not self.running:
utils.debug_log('post was called when Evaler is not running')
return None, False
self.eval_return.reset()
self.g.post(event_vars, ended)
self.eval_wait.wait()
self.eval_wait.clear()
# save result before it would get reset
eval_return = self.eval_return
self.reset_wait.set()
self.continue_thread = continue_thread
if isinstance(eval_return.result, Iterator):
eval_return.result = list(eval_return.result)
return eval_return
def join(self):
self.th.join()

109
tensorwatch/evaler_utils.py Normal file
View File

@ -0,0 +1,109 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
import math
import random
from . import utils
from .lv_types import ImagePlotItem
from collections import OrderedDict
from itertools import groupby, islice
def skip_mod(mod, g):
for index, item in enumerate(g):
if index % mod == 0:
yield item
# sort items by key, group by key, apply the val function to each value in a group, then aggregate each group's values
def groupby2(l, key=lambda x:x, val=lambda x:x, agg=lambda x:x, sort=True):
if sort:
l = sorted(l, key=key)
grp = ((k,v) for k,v in groupby(l, key=key))
valx = ((k, (val(x) for x in v)) for k,v in grp)
aggx = ((k, agg(v)) for k,v in valx)
return aggx
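# Illustrative sketch (not part of the commit): group pairs by key and sum the values.
#   list(groupby2([('a', 1), ('b', 2), ('a', 3)],
#                 key=lambda t: t[0], val=lambda t: t[1], agg=sum))
#   -> [('a', 4), ('b', 2)]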
# aggregate weights or biases, use p2v to transform a tensor to a scalar
def agg_params(model, p2v, weight_or_bias=True):
for i, (n, p) in enumerate(model.named_parameters()):
if p.requires_grad:
is_bias = 'bias' in n
if (weight_or_bias and not is_bias) or (not weight_or_bias and is_bias):
yield i, p2v(p), n
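# Illustrative sketch (not part of the commit): mean absolute weight per layer,
# assuming `model` is any torch.nn.Module.
#   for i, v, n in agg_params(model, lambda p: p.abs().mean().item()):
#       print(i, n, v)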
# use this for image to class problems
def pyt_img_class_out_xform(item): # (net_input, target, in_weight, out_weight, net_output, loss)
net_input = item[0].data.cpu().numpy()
# turn log-probabilities in to (max log-probability, class ID)
net_output = torch.max(item[4],0)
# return image, text
return ImagePlotItem((net_input,), title="T:{},Pb:{:.2f},pd:{:.2f},L:{:.2f}".\
format(item[1], math.exp(net_output[0]), net_output[1], item[5]))
# use this for image to image translation problems
def pyt_img_img_out_xform(item): # (net_input, target, in_weight, out_weight, net_output, loss)
net_input = item[0].data.cpu().numpy()
net_output = item[4].data.cpu().numpy()
target = item[1].data.cpu().numpy()
tar_weight = item[3].data.cpu().numpy() if item[3] is not None else None
# return in-image, text, out-image, target-image
return ImagePlotItem((net_input, target, net_output, tar_weight),
title="L:{:.2f}, S:{:.2f}, {:.2f}-{:.2f}, {:.2f}-{:.2f}".\
format(item[5], net_input.std(), net_input.min(), net_input.max(), net_output.min(), net_output.max()))
def cols2rows(batch):
in_weight = utils.fill_like(batch.in_weight, batch.input)
tar_weight = utils.fill_like(batch.tar_weight, batch.input)
losses = [l.mean() for l in batch.loss_all]
targets = [t.item() if len(t.shape)==0 else t for t in batch.target]
return list(zip(batch.input, targets, in_weight, tar_weight,
batch.output, losses))
def top(l, topk=1, order='dsc', group_key=None, out_xform=lambda x:x):
min_result = OrderedDict()
for event_vars in l:
batch = cols2rows(event_vars.batch)
# by default group items in batch by target value
group_key = group_key or (lambda b: b[1]) #target
by_class = groupby2(batch, group_key)
# pick the first values for each class after sorting by loss
reverse, sf, ls_cmp = True, lambda b: b[5], False
if order=='asc':
reverse = False
elif order=='rnd':
ls_cmp, sf = True, lambda t: random.random()
elif order=='dsc':
pass
else:
raise ValueError('order parameter must be dsc, asc or rnd')
# sort grouped objects by sort function then
# take first k values in each group
# create (key, topk-sized list) tuples for each group
s = ((k, list(islice(sorted(v, key=sf, reverse=reverse), topk))) \
for k,v in by_class)
# for each group, maintain the global top-k values for each key
changed = False
for k,va in s:
# get global k values for this key, if it doesn't exist
# then put current in global min
cur_min = min_result.get(k, None)
if cur_min is None:
min_result[k] = va
changed = True
else:
# for each k value in this group, we will compare
for i, (va_k, cur_k) in enumerate(zip(va, cur_min)):
if ls_cmp or (reverse and cur_k[5] < va_k[5]) \
or (not reverse and cur_k[5] > va_k[5]):
cur_min[i] = va[i]
changed = True
if changed:
# flatten each list in dictionary value
yield (out_xform(t) for va in min_result.values() for t in va)

38
tensorwatch/file_stream.py Normal file
View File

@ -0,0 +1,38 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .stream import Stream
import pickle
from typing import Any
from . import utils
class FileStream(Stream):
def __init__(self, for_write:bool, file_name:str, stream_name:str=None, console_debug:bool=False):
super(FileStream, self).__init__(stream_name=stream_name or file_name, console_debug=console_debug)
self._file = open(file_name, 'wb' if for_write else 'rb')
self.file_name = file_name
self.for_write = for_write
utils.debug_log('FileStream started', self.file_name, verbosity=1)
def close(self):
if not self._file.closed:
self._file.close()
self._file = None
utils.debug_log('FileStream is closed', self.file_name, verbosity=1)
super(FileStream, self).close()
def write(self, val:Any, from_stream:'Stream'=None):
if self.for_write:
pickle.dump(val, self._file)
super(FileStream, self).write(val)
def load(self, from_stream:'Stream'=None):
if self.for_write:
raise IOError('Cannot use load() call because FileStream is opened with for_write=True')
if self._file is not None:
while not utils.is_eof(self._file):
stream_item = pickle.load(self._file)
self.write(stream_item)
super(FileStream, self).load()
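# Hypothetical round-trip sketch (not part of the commit); 'demo.dlclog' is an
# illustrative path.
#   w = FileStream(for_write=True, file_name='demo.dlclog')
#   w.write({'step': 1, 'loss': 0.5})   # pickled to disk, then echoed to subscribers
#   w.close()
#   r = FileStream(for_write=False, file_name='demo.dlclog')
#   r.load()                            # replays each pickled item through write()
#   r.close()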

24
tensorwatch/filtered_stream.py Normal file
View File

@ -0,0 +1,24 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .stream import Stream
from typing import Callable, Any
class FilteredStream(Stream):
def __init__(self, source_stream:Stream, filter_expr:Callable, stream_name:str=None,
console_debug:bool=False)->None:
stream_name = stream_name or '{}|{}'.format(source_stream.stream_name, str(filter_expr))
super(FilteredStream, self).__init__(stream_name=stream_name, console_debug=console_debug)
self.subscribe(source_stream)
self.filter_expr = filter_expr
def write(self, val:Any, from_stream:'Stream'=None):
result, is_valid = self.filter_expr(val) \
if self.filter_expr is not None \
else (val, True)
if is_valid:
return super(FilteredStream, self).write(result)
# else ignore this call
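# Note: filter_expr must return a (result, is_valid) pair. Illustrative sketch
# (not part of the commit); `source` stands for any existing Stream:
#   keep_even = lambda item: (item, item % 2 == 0)
#   fs = FilteredStream(source_stream=source, filter_expr=keep_even)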

124
tensorwatch/image_utils.py Normal file
View File

@ -0,0 +1,124 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
import numpy as np
import math
import time
def guess_image_dims(img):
if len(img.shape) == 1:
# assume 2D monochrome (MNIST)
width = height = round(math.sqrt(img.shape[0]))
if width*height != img.shape[0]:
        # assume 3 channels (CIFAR, ImageNet)
width = height = round(math.sqrt(img.shape[0] / 3))
if width*height*3 != img.shape[0]:
raise ValueError("Cannot guess image dimensions for linearized pixels")
return (3, height, width)
return (1, height, width)
return img.shape
def to_imshow_array(img, width=None, height=None):
# array from Pytorch has shape: [[channels,] height, width]
# image needed for imshow needs: [height, width, channels]
if img is not None:
if isinstance(img, Image.Image):
img = np.array(img)
if len(img.shape) >= 2:
return img # img is already compatible to imshow
# force max 3 dimensions
if len(img.shape) > 3:
# TODO allow config
# select first one in batch
            img = img[0] # slicing [0:1] would keep the batch dimension
if len(img.shape) == 1: # linearized pixels typically used for MLPs
if not(width and height):
# pylint: disable=unused-variable
channels, height, width = guess_image_dims(img)
img = img.reshape((-1, height, width))
if len(img.shape) == 3:
if img.shape[0] == 1: # single channel images
img = img.squeeze(0)
else:
img = np.swapaxes(img, 0, 2) # transpose H,W for imshow
img = np.swapaxes(img, 0, 1)
elif len(img.shape) == 2:
img = np.swapaxes(img, 0, 1) # transpose H,W for imshow
else: #zero dimensions
img = None
return img
#width_dim=1 for imshow, 2 for pytorch arrays
def stitch_horizontal(images, width_dim=1):
return np.concatenate(images, axis=width_dim)
def _resize_image(img, size=None):
if size is not None or (hasattr(img, 'shape') and len(img.shape) == 1):
if size is None:
# make guess for 1-dim tensors
h = int(math.sqrt(img.shape[0]))
w = int(img.shape[0] / h)
size = h,w
img = np.reshape(img, size)
return img
def show_image(img, size=None, alpha=None, cmap=None,
img2=None, size2=None, alpha2=None, cmap2=None, ax=None):
img =_resize_image(img, size)
img2 =_resize_image(img2, size2)
(ax or plt).imshow(img, alpha=alpha, cmap=cmap)
if img2 is not None:
(ax or plt).imshow(img2, alpha=alpha2, cmap=cmap2)
return ax or plt.show()
# convert_mode param is mode: https://pillow.readthedocs.io/en/5.1.x/handbook/concepts.html#modes
# use convert_mode='RGB' to force 3 channels
def open_image(path, resize=None, resample=Image.ANTIALIAS, convert_mode=None):
img = Image.open(path)
if resize is not None:
img = img.resize(resize, resample)
if convert_mode is not None:
img = img.convert(convert_mode)
return img
def img2pyt(img, add_batch_dim=True, resize=None):
ts = []
if resize is not None:
ts.append(transforms.RandomResizedCrop(resize))
ts.append(transforms.ToTensor())
img_pyt = transforms.Compose(ts)(img)
if add_batch_dim:
img_pyt.unsqueeze_(0)
return img_pyt
def linear_to_2d(img, size=None):
if size is not None or (hasattr(img, 'shape') and len(img.shape) == 1):
if size is None:
# make guess for 1-dim tensors
h = int(math.sqrt(img.shape[0]))
w = int(img.shape[0] / h)
size = h,w
img = np.reshape(img, size)
return img
def stack_images(imgs):
return np.hstack(imgs)
def plt_loop(sleep_time=1, plt_pause=0.01):
plt.ion()
plt.show(block=False)
while(not plt.waitforbuttonpress(plt_pause)):
#plt.draw()
plt.pause(plt_pause)
time.sleep(sleep_time)
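# Hypothetical usage sketch (not part of the commit); 'cat.jpg' is an illustrative path.
if __name__ == '__main__':
    img = open_image('cat.jpg', resize=(224, 224), convert_mode='RGB')
    pyt = img2pyt(img)                            # tensor of shape [1, 3, 224, 224]
    show_image(to_imshow_array(pyt[0].numpy()))   # back to [height, width, channels]
    plt_loop()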

64
tensorwatch/imagenet_utils.py Normal file
View File

@ -0,0 +1,64 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from torchvision import transforms
from . import pytorch_utils
import json
def get_image_transform():
transf = transforms.Compose([ #TODO: cache these transforms?
get_resize_transform(),
transforms.ToTensor(),
get_normalize_transform()
])
return transf
def get_resize_transform():
return transforms.Resize((224, 224))
def get_normalize_transform():
return transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
def predict(model, images, image_transform=None, device=None):
logits = pytorch_utils.batch_predict(model, images,
input_transform=image_transform or get_image_transform(), device=device)
probs = pytorch_utils.logits2probabilities(logits) #2-dim array, one column per class, one row per input
return probs
_imagenet_labels = None
def get_imagenet_labels():
# pylint: disable=global-statement
global _imagenet_labels
_imagenet_labels = _imagenet_labels or ImagenetLabels()
return _imagenet_labels
def probabilities2classes(probs, topk=5):
labels = get_imagenet_labels()
top_probs = probs.topk(topk)
# return (probability, class_id, class_label, class_code)
return tuple((p,c, labels.index2label_text(c), labels.index2label_code(c)) \
for p, c in zip(top_probs[0][0].detach().numpy(), top_probs[1][0].detach().numpy()))
class ImagenetLabels:
def __init__(self, json_path='../../data/imagenet_class_index.json'):
self._idx2label = []
self._idx2cls = []
self._cls2label = {}
self._cls2idx = {}
with open(json_path, "r") as read_file:
class_json = json.load(read_file)
self._idx2label = [class_json[str(k)][1] for k in range(len(class_json))]
self._idx2cls = [class_json[str(k)][0] for k in range(len(class_json))]
self._cls2label = {class_json[str(k)][0]: class_json[str(k)][1] for k in range(len(class_json))}
self._cls2idx = {class_json[str(k)][0]: k for k in range(len(class_json))}
def index2label_text(self, index):
return self._idx2label[index]
def index2label_code(self, index):
return self._idx2cls[index]
def label_code2label_text(self, label_code):
return self._cls2label[label_code]
def label_code2index(self, label_code):
return self._cls2idx[label_code]
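# Illustrative sketch (not part of the commit); the json path is an assumption and
# index 0 is 'tench' in the standard ImageNet class-index file.
#   labels = ImagenetLabels(json_path='data/imagenet_class_index.json')
#   labels.index2label_text(0)   # -> 'tench'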

130
tensorwatch/lv_types.py Normal file
View File

@ -0,0 +1,130 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import List, Callable, Any, Sequence
from . import utils
import uuid
class EventVars:
def __init__(self, globals_val, **vars_val):
if globals_val is not None:
for key in globals_val:
setattr(self, key, globals_val[key])
for key in vars_val:
setattr(self, key, vars_val[key])
def __str__(self):
sb = []
for key in self.__dict__:
val = self.__dict__[key]
if utils.is_scalar(val):
sb.append('{key}={value}'.format(key=key, value=val))
else:
sb.append('{key}="{value}"'.format(key=key, value=val))
return ', '.join(sb)
EventsVars = List[EventVars]
class StreamItem:
def __init__(self, item_index:int, value:Any,
stream_name:str, creator_id:str, stream_index:int,
ended:bool=False, exception:Exception=None, stream_reset:bool=False):
self.value = value
self.exception = exception
self.stream_name = stream_name
self.item_index = item_index
self.ended = ended
self.creator_id = creator_id
self.stream_index = stream_index
self.stream_reset = stream_reset
def __repr__(self):
return str(self.__dict__)
EventEvalFunc = Callable[[EventsVars], StreamItem]
class VisParams:
def __init__(self, vis_type=None, host_vis=None,
cell=None, title=None,
clear_after_end=False, clear_after_each=False, history_len=1, dim_history=True, opacity=None,
images=None, images_reshape=None, width=None, height=None, vis_args=None, stream_vis_args=None)->None:
self.vis_type=vis_type
        self.host_vis=host_vis
self.cell=cell
self.title=title
self.clear_after_end=clear_after_end
self.clear_after_each=clear_after_each
self.history_len=history_len
self.dim_history=dim_history
self.opacity=opacity
self.images=images
self.images_reshape=images_reshape
self.width=width
self.height=height
self.vis_args=vis_args or {}
self.stream_vis_args=stream_vis_args or {}
class StreamOpenRequest:
def __init__(self, stream_name:str, devices:Sequence[str]=None,
event_name:str='')->None:
self.stream_name = stream_name or str(uuid.uuid4())
self.devices = devices
self.event_name = event_name
class StreamCreateRequest:
def __init__(self, stream_name:str, devices:Sequence[str]=None, event_name:str='',
expr:str=None, throttle:float=None, vis_params:VisParams=None):
self.event_name = event_name
self.expr = expr
self.stream_name = stream_name or str(uuid.uuid4())
self.devices = devices
self.vis_params = vis_params
# max throughput on a Lenovo P50 laptop for MNIST
# text console -> 0.1s
# matplotlib line graph -> 0.5s
self.throttle = throttle
class ClientServerRequest:
def __init__(self, req_type:str, req_data:Any):
self.req_type = req_type
self.req_data = req_data
class CliSrvReqTypes:
create_stream = 'CreateStream'
del_stream = 'DeleteStream'
class StreamPlot:
def __init__(self, stream, title, clear_after_end,
clear_after_each, history_len, dim_history, opacity,
index, stream_vis_args, last_update):
self.stream = stream
self.title, self.opacity = title, opacity
self.clear_after_end, self.clear_after_each = clear_after_end, clear_after_each
self.history_len, self.dim_history = history_len, dim_history
self.index, self.stream_vis_args, self.last_update = index, stream_vis_args, last_update
class ImagePlotItem:
# images are numpy array of shape [[channels,] height, width]
def __init__(self, images=None, title=None, alpha=None, cmap=None):
if not isinstance(images, tuple):
images = (images,)
self.images, self.alpha, self.cmap, self.title = images, alpha, cmap, title
class DefaultPorts:
PubSub = 40859
CliSrv = 41459
class PublisherTopics:
StreamItem = 'StreamItem'
ServerMgmt = 'ServerMgmt'
class ServerMgmtMsg:
EventServerStart = 'ServerStart'
def __init__(self, event_name:str, event_args:Any=None):
self.event_name = event_name
self.event_args = event_args

3
tensorwatch/model_graph/__init__.py Normal file
View File

@ -0,0 +1,3 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

5
tensorwatch/model_graph/hiddenlayer/README.md Normal file
View File

@ -0,0 +1,5 @@
# Credits
Code in this folder is adapted from:
* https://github.com/waleedka/hiddenlayer

1
tensorwatch/model_graph/hiddenlayer/__init__.py Normal file
View File

@ -0,0 +1 @@

169
tensorwatch/model_graph/hiddenlayer/ge.py Normal file
View File

@ -0,0 +1,169 @@
"""
HiddenLayer
Implementation of graph expressions to find nodes in a graph based on a pattern.
Written by Waleed Abdulla
Licensed under the MIT License
"""
import re
class GEParser():
def __init__(self, text):
self.index = 0
self.text = text
def parse(self):
return self.serial() or self.parallel() or self.expression()
def parallel(self):
index = self.index
expressions = []
while len(expressions) == 0 or self.token("|"):
e = self.expression()
if not e:
break
expressions.append(e)
if len(expressions) >= 2:
return ParallelPattern(expressions)
# No match. Reset index
self.index = index
def serial(self):
index = self.index
expressions = []
while len(expressions) == 0 or self.token(">"):
e = self.expression()
if not e:
break
expressions.append(e)
if len(expressions) >= 2:
return SerialPattern(expressions)
self.index = index
def expression(self):
index = self.index
if self.token("("):
e = self.serial() or self.parallel() or self.op()
if e and self.token(")"):
return e
self.index = index
e = self.op()
return e
def op(self):
t = self.re(r"\w+")
if t:
c = self.condition()
return NodePattern(t, c)
def condition(self):
# TODO: not implemented yet. This function is a placeholder
index = self.index
if self.token("["):
c = self.token("1x1") or self.token("3x3")
if c:
if self.token("]"):
return c
self.index = index
def token(self, s):
return self.re(r"\s*(" + re.escape(s) + r")\s*", 1)
def string(self, s):
if s == self.text[self.index:self.index+len(s)]:
self.index += len(s)
return s
def re(self, regex, group=0):
m = re.match(regex, self.text[self.index:])
if m:
self.index += len(m.group(0))
return m.group(group)
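# Illustrative sketch (not part of the commit): parsing a serial pattern of two ops.
#   GEParser("Conv > Relu").parse()
#   -> SerialPattern([NodePattern('Conv'), NodePattern('Relu')])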
class NodePattern():
def __init__(self, op, condition=None):
self.op = op
self.condition = condition # TODO: not implemented yet
def match(self, graph, node):
if isinstance(node, list):
return [], None
if self.op == node.op:
following = graph.outgoing(node)
if len(following) == 1:
following = following[0]
return [node], following
else:
return [], None
class SerialPattern():
def __init__(self, patterns):
self.patterns = patterns
def match(self, graph, node):
all_matches = []
for i, p in enumerate(self.patterns):
matches, following = p.match(graph, node)
if not matches:
return [], None
all_matches.extend(matches)
if i < len(self.patterns) - 1:
node = following # Might be more than one node
return all_matches, following
class ParallelPattern():
def __init__(self, patterns):
self.patterns = patterns
def match(self, graph, nodes):
if not nodes:
return [], None
nodes = nodes if isinstance(nodes, list) else [nodes]
# If a single node, assume we need to match with its siblings
if len(nodes) == 1:
nodes = graph.siblings(nodes[0])
else:
# Verify all nodes have the same parent or all have no parent
parents = [graph.incoming(n) for n in nodes]
matches = [set(p) == set(parents[0]) for p in parents[1:]]
if not all(matches):
return [], None
# TODO: If more nodes than patterns, we should consider
# all permutations of the nodes
if len(self.patterns) != len(nodes):
return [], None
patterns = self.patterns.copy()
nodes = nodes.copy()
all_matches = []
end_node = None
for p in patterns:
found = False
for n in nodes:
matches, following = p.match(graph, n)
if matches:
found = True
nodes.remove(n)
all_matches.extend(matches)
# Verify all branches end in the same node
if end_node:
if end_node != following:
return [], None
else:
end_node = following
break
if not found:
return [], None
return all_matches, end_node

402
tensorwatch/model_graph/hiddenlayer/graph.py Normal file
View File

@ -0,0 +1,402 @@
"""
HiddenLayer
Implementation of the Graph class. A framework independent directed graph to
represent a neural network.
Written by Waleed Abdulla. Additions by Phil Ferriere.
Licensed under the MIT License
"""
from __future__ import absolute_import, division, print_function
import os
import re
from random import getrandbits
import inspect
import numpy as np
import html
THEMES = {
"basic": {
"background_color": "#FFFFFF",
"fill_color": "#E8E8E8",
"outline_color": "#000000",
"font_color": "#000000",
"font_name": "Times",
"font_size": "10",
"margin": "0,0",
"padding": "1.0,0.5",
},
"blue": {
"shape": "box",
"background_color": "#FFFFFF",
"fill_color": "#BCD6FC",
"outline_color": "#7C96BC",
"font_color": "#202020",
"font_name": "Verdana",
"font_size": "10",
"margin": "0,0",
"padding": "1.0",
"layer_overrides": { #TODO: change names of these keys to same as dot params
"ConvRelu": {"shape":"box", "fillcolor":"#A1C9F4"},
"Conv": {"shape":"box", "fillcolor":"#FAB0E4"},
"MaxPool": {"shape":"box", "fillcolor":"#8DE5A1"},
"Constant": {"shape":"box", "fillcolor":"#FF9F9B"},
"Shape": {"shape":"box", "fillcolor":"#D0BBFF"},
"Gather": {"shape":"box", "fillcolor":"#DEBB9B"},
"Unsqeeze": {"shape":"box", "fillcolor":"#CFCFCF"},
"Sqeeze": {"shape":"box", "fillcolor":"#FFFEA3"},
"Dropout": {"shape":"box", "fillcolor":"#B9F2F0"},
"LinearRelu": {"shape":"box", "fillcolor":"#8DE5A1"},
"Linear": {"shape":"box", "fillcolor":"#4878D0"},
"Concat": {"shape":"box", "fillcolor":"#D0BBFF"},
"Reshape": {"shape":"box", "fillcolor":"#FFFEA3"},
}
},
}
###########################################################################
# Utility Functions
###########################################################################
def detect_framework(value):
# Get all base classes
classes = inspect.getmro(value.__class__)
for c in classes:
if c.__module__.startswith("torch"):
return "torch"
elif c.__module__.startswith("tensorflow"):
return "tensorflow"
###########################################################################
# Node
###########################################################################
class Node():
"""Represents a framework-agnostic neural network layer in a directed graph."""
def __init__(self, uid, name, op, output_shape=None, params=None, combo_params=None):
"""
uid: unique ID for the layer that doesn't repeat in the computation graph.
name: Name to display
op: Framework-agnostic operation name.
"""
self.id = uid
self.name = name # TODO: clarify the use of op vs name vs title
self.op = op
self.repeat = 1
if output_shape:
assert isinstance(output_shape, (tuple, list)),\
"output_shape must be a tuple or list but received {}".format(type(output_shape))
self.output_shape = output_shape
self.params = params if params else {}
self._caption = ""
self.combo_params = combo_params
@property
def title(self):
# Default
title = self.name or self.op
if self.op == 'Dropout':
if 'ratio' in self.params:
title += ' ' + str(self.params['ratio'])
if "kernel_shape" in self.params:
# Kernel
kernel = self.params["kernel_shape"]
title += "x".join(map(str, kernel))
if "stride" in self.params:
stride = self.params["stride"]
if np.unique(stride).size == 1:
stride = stride[0]
if stride != 1:
title += "/s{}".format(str(stride))
# # Transposed
# if node.transposed:
# name = "Transposed" + name
return title
@property
def caption(self):
if self._caption:
return self._caption
caption = ""
# Stride
# if "stride" in self.params:
# stride = self.params["stride"]
# if np.unique(stride).size == 1:
# stride = stride[0]
# if stride != 1:
# caption += "/{}".format(str(stride))
return caption
def __repr__(self):
args = (self.op, self.name, self.id, self.title, self.repeat)
f = "<Node: op: {}, name: {}, id: {}, title: {}, repeat: {}"
if self.output_shape:
args += (str(self.output_shape),)
f += ", shape: {:}"
if self.params:
args += (str(self.params),)
f += ", params: {:}"
f += ">"
return f.format(*args)
###########################################################################
# Graph
###########################################################################
def build_graph(model=None, args=None, input_names=None,
transforms="default", framework_transforms="default", orientation='TB'):
# Initialize an empty graph
g = Graph(orientation=orientation)
    # Detect framework
framework = detect_framework(model)
if framework == "torch":
from .pytorch_builder import import_graph, FRAMEWORK_TRANSFORMS
import_graph(g, model, args)
elif framework == "tensorflow":
from .tf_builder import import_graph, FRAMEWORK_TRANSFORMS
import_graph(g, model)
else:
raise ValueError("`model` input param must be a PyTorch, TensorFlow, or Keras-with-TensorFlow-backend model.")
# Apply Transforms
if framework_transforms:
if framework_transforms == "default":
framework_transforms = FRAMEWORK_TRANSFORMS
for t in framework_transforms:
g = t.apply(g)
if transforms:
if transforms == "default":
from .transforms import SIMPLICITY_TRANSFORMS
transforms = SIMPLICITY_TRANSFORMS
for t in transforms:
g = t.apply(g)
return g
class Graph():
"""Tracks nodes and edges of a directed graph and supports basic operations on them."""
def __init__(self, model=None, args=None, input_names=None,
transforms="default", framework_transforms="default",
meaningful_ids=False, orientation='TB'):
self.nodes = {}
self.edges = []
self.meaningful_ids = meaningful_ids # TODO
self.theme = THEMES["blue"].copy()
self.orientation = orientation
if model:
            # Detect framework
framework = detect_framework(model)
if framework == "torch":
from .pytorch_builder import import_graph, FRAMEWORK_TRANSFORMS
import_graph(self, model, args)
elif framework == "tensorflow":
from .tf_builder import import_graph, FRAMEWORK_TRANSFORMS
import_graph(self, model)
# Apply Transforms
if framework_transforms:
if framework_transforms == "default":
framework_transforms = FRAMEWORK_TRANSFORMS
for t in framework_transforms:
t.apply(self)
if transforms:
if transforms == "default":
from .transforms import SIMPLICITY_TRANSFORMS
transforms = SIMPLICITY_TRANSFORMS
for t in transforms:
t.apply(self)
def id(self, node):
"""Returns a unique node identifier. If the node has an id
attribute (preferred), it's used. Otherwise, the hash() is returned."""
return node.id if hasattr(node, "id") else hash(node)
def add_node(self, node):
id = self.id(node)
# assert(id not in self.nodes)
self.nodes[id] = node
def add_edge(self, node1, node2, label=None):
# If the edge is already present, don't add it again.
# TODO: If an edge exists with a different label, still don't add it again.
edge = (self.id(node1), self.id(node2), label)
if edge not in self.edges:
self.edges.append(edge)
def add_edge_by_id(self, vid1, vid2, label=None):
self.edges.append((vid1, vid2, label))
def outgoing(self, node):
"""Returns nodes connecting out of the given node (or list of nodes)."""
nodes = node if isinstance(node, list) else [node]
node_ids = [self.id(n) for n in nodes]
# Find edges outgoing from this group but not incoming to it
outgoing = [self[e[1]] for e in self.edges
if e[0] in node_ids and e[1] not in node_ids]
return outgoing
def incoming(self, node):
"""Returns nodes connecting to the given node (or list of nodes)."""
nodes = node if isinstance(node, list) else [node]
node_ids = [self.id(n) for n in nodes]
# Find edges incoming to this group but not outgoing from it
incoming = [self[e[0]] for e in self.edges
if e[1] in node_ids and e[0] not in node_ids]
return incoming
def siblings(self, node):
"""Returns all nodes that share the same parent (incoming node) with
the given node, including the node itself.
"""
incoming = self.incoming(node)
# TODO: Not handling the case of multiple incoming nodes yet
if len(incoming) == 1:
incoming = incoming[0]
siblings = self.outgoing(incoming)
return siblings
else:
return [node]
def __getitem__(self, key):
if isinstance(key, list):
return [self.nodes.get(k) for k in key]
else:
return self.nodes.get(key)
def remove(self, nodes):
"""Remove a node and its edges."""
nodes = nodes if isinstance(nodes, list) else [nodes]
for node in nodes:
k = self.id(node)
self.edges = list(filter(lambda e: e[0] != k and e[1] != k, self.edges))
del self.nodes[k]
def replace(self, nodes, node):
"""Replace nodes with node. Edges incoming to nodes[0] are connected to
the new node, and nodes outgoing from nodes[-1] become outgoing from
the new node."""
nodes = nodes if isinstance(nodes, list) else [nodes]
# Is the new node part of the replace nodes (i.e. want to collapse
# a group of nodes into one of them)?
collapse = self.id(node) in self.nodes
# Add new node and edges
if not collapse:
self.add_node(node)
for in_node in self.incoming(nodes):
# TODO: check specifically for output_shape is not generic. Consider refactoring.
self.add_edge(in_node, node, in_node.output_shape if hasattr(in_node, "output_shape") else None)
for out_node in self.outgoing(nodes):
self.add_edge(node, out_node, node.output_shape if hasattr(node, "output_shape") else None)
# Remove the old nodes
for n in nodes:
if collapse and n == node:
continue
self.remove(n)
def search(self, pattern):
"""Searches the graph for a sub-graph that matches the given pattern
and returns the first match it finds.
"""
for node in self.nodes.values():
match, following = pattern.match(self, node)
if match:
return match, following
return [], None
def sequence_id(self, sequence):
"""Make up an ID for a sequence (list) of nodes.
Note: `getrandbits()` is very uninformative as a "readable" ID. Here, we build a name
such that when the mouse hovers over the drawn node in Jupyter, one can figure out
which original nodes make up the sequence. This is actually quite useful.
"""
if self.meaningful_ids:
# TODO: This might fail if the ID becomes too long
return "><".join([node.id for node in sequence])
else:
return getrandbits(64)
def build_dot(self, orientation):
"""Generate a GraphViz Dot graph.
Returns a GraphViz Digraph object.
"""
from graphviz import Digraph
# Build GraphViz Digraph
dot = Digraph()
dot.attr("graph",
bgcolor=self.theme["background_color"],
color=self.theme["outline_color"],
fontsize=self.theme["font_size"],
fontcolor=self.theme["font_color"],
fontname=self.theme["font_name"],
margin=self.theme["margin"],
rankdir=orientation,
pad=self.theme["padding"])
dot.attr("node", shape="box",
style="filled", margin="0,0",
fillcolor=self.theme["fill_color"],
color=self.theme["outline_color"],
fontsize=self.theme["font_size"],
fontcolor=self.theme["font_color"],
fontname=self.theme["font_name"])
dot.attr("edge", style="solid",
color=self.theme["outline_color"],
fontsize=self.theme["font_size"],
fontcolor=self.theme["font_color"],
fontname=self.theme["font_name"])
for k, n in self.nodes.items():
label = "<tr><td cellpadding='6'>{}</td></tr>".format(n.title)
if n.caption:
label += "<tr><td>{}</td></tr>".format(n.caption)
if n.repeat > 1:
label += "<tr><td align='right' cellpadding='2'>x{}</td></tr>".format(n.repeat)
label = "<<table border='0' cellborder='0' cellpadding='0'>{}</table>>".\
format(label)
# figure out tooltip
tooltips = set()
if len(n.params):
tooltips.update([str(n.params)])
if n.combo_params:
for params in n.combo_params:
if len(params):
tooltips.update([str(params)])
# figure out shape and color
layer_overrides = self.theme.get('layer_overrides', {})
op_overrides = layer_overrides.get(n.op or n.name, {})
dot.node(str(k), label, tooltip=html.escape(' '.join(tooltips)), **op_overrides)
for a, b, label in self.edges:
if isinstance(label, (list, tuple)):
label = ' ' + "x".join([str(l or "?") for l in label])
dot.edge(str(a), str(b), label)
return dot
def _repr_svg_(self):
"""Allows Jupyter notebook to render the graph automatically."""
return self.build_dot(self.orientation)._repr_svg_()
def save(self, path, format="pdf"):
# TODO: assert on acceptable format values
dot = self.build_dot(self.orientation)
dot.format = format
directory, file_name = os.path.split(path)
# Remove extension from file name. dot.render() adds it.
file_name = file_name.replace("." + format, "")
dot.render(file_name, directory=directory, cleanup=True)
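# Illustrative sketch (not part of the commit): build a two-node graph by hand and
# search it with a graph expression (requires `from . import ge`).
#   g = Graph()
#   a = Node(uid='a', name='conv1', op='Conv'); b = Node(uid='b', name='relu1', op='Relu')
#   g.add_node(a); g.add_node(b); g.add_edge(a, b)
#   matches, _ = g.search(ge.GEParser('Conv > Relu').parse())   # -> [a, b]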

137
tensorwatch/model_graph/hiddenlayer/pytorch_builder.py Normal file
View File

@ -0,0 +1,137 @@
"""
HiddenLayer
PyTorch graph importer.
Written by Waleed Abdulla
Licensed under the MIT License
"""
from __future__ import absolute_import, division, print_function
import re
from .graph import Graph, Node
from . import transforms as ht
import torch
from collections import abc
import numpy as np
# PyTorch Graph Transforms
FRAMEWORK_TRANSFORMS = [
# Hide onnx: prefix
ht.Rename(op=r"onnx::(.*)", to=r"\1"),
# ONNX uses Gemm for linear layers (stands for General Matrix Multiplication).
# It's an odd name that no one recognizes. Rename it.
ht.Rename(op=r"Gemm", to=r"Linear"),
# PyTorch layers that don't have an ONNX counterpart
ht.Rename(op=r"aten::max\_pool2d\_with\_indices", to="MaxPool"),
# Shorten op name
ht.Rename(op=r"BatchNormalization", to="BatchNorm"),
]
def dump_pytorch_graph(graph):
"""List all the nodes in a PyTorch graph."""
f = "{:25} {:40} {} -> {}"
print(f.format("kind", "scopeName", "inputs", "outputs"))
for node in graph.nodes():
print(f.format(node.kind(), node.scopeName(),
[i.unique() for i in node.inputs()],
[i.unique() for i in node.outputs()]
))
def pytorch_id(node):
"""Returns a unique ID for a node."""
# After ONNX simplification, the scopeName is not unique anymore
# so append node outputs to guarantee uniqueness
return node.scopeName() + "/outputs/" + "/".join([o.uniqueName() for o in node.outputs()])
def get_shape(torch_node):
"""Return the output shape of the given Pytorch node."""
# Extract node output shape from the node string representation
# This is a hack because there doesn't seem to be an official way to do it.
# See my question in the PyTorch forum:
# https://discuss.pytorch.org/t/node-output-shape-from-trace-graph/24351/2
# TODO: find a better way to extract output shape
# TODO: Assuming the node has one output. Update if we encounter a multi-output node.
m = re.match(r".*Float\(([\d\s\,]+)\).*", str(next(torch_node.outputs())))
if m:
shape = m.group(1)
shape = shape.split(",")
shape = tuple(map(int, shape))
else:
shape = None
return shape
def calc_rf(model, input_shape):
    # set weights to 1 and biases to 0 so gradients trace connectivity only
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if 'bias' in n:
            p.data.fill_(0)
        elif 'weight' in n:
            p.data.fill_(1)
    input = torch.ones(input_shape, requires_grad=True)
    output = model(input)
    out_shape = output.size()
    ndims = len(out_shape)
    grad = torch.zeros(out_shape)
    l_tmp = []
    for i in range(ndims):  # xrange is Python 2 only
        if i == 0 or i == 1:  # batch or channel dimension
            l_tmp.append(0)
        else:  # integer division so the index is valid
            l_tmp.append(out_shape[i] // 2)
    grad[tuple(l_tmp)] = 1
    output.backward(gradient=grad)
    # the gradient w.r.t. the input marks the receptive field of the center output pixel
    grad_np = input.grad[0, 0].data.numpy()
    idx_nonzeros = np.where(grad_np != 0)
    RF = [np.max(idx) - np.min(idx) + 1 for idx in idx_nonzeros]
    return RF
def import_graph(hl_graph, model, args, input_names=None, verbose=False):
# TODO: add input names to graph
if args is None:
args = [1, 3, 224, 224] # assume ImageNet default
# if args is not Tensor but is array like then convert it to torch tensor
if not isinstance(args, torch.Tensor) and \
hasattr(args, "__len__") and hasattr(args, '__getitem__') and \
not isinstance(args, (str, abc.ByteString)):
args = torch.ones(args)
# Run the Pytorch graph to get a trace and generate a graph from it
trace, out = torch.jit.get_trace_graph(model, args)
torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)
torch_graph = trace.graph()
# Dump list of nodes (DEBUG only)
if verbose:
dump_pytorch_graph(torch_graph)
# Loop through nodes and build HL graph
for torch_node in torch_graph.nodes():
# Op
op = torch_node.kind()
# Parameters
params = {k: torch_node[k] for k in torch_node.attributeNames()}
# Inputs/outputs
# TODO: inputs = [i.unique() for i in node.inputs()]
outputs = [o.unique() for o in torch_node.outputs()]
# Get output shape
shape = get_shape(torch_node)
# Add HL node
hl_node = Node(uid=pytorch_id(torch_node), name=None, op=op,
output_shape=shape, params=params)
hl_graph.add_node(hl_node)
# Add edges
for target_torch_node in torch_graph.nodes():
target_inputs = [i.unique() for i in target_torch_node.inputs()]
if set(outputs) & set(target_inputs):
hl_graph.add_edge_by_id(pytorch_id(torch_node), pytorch_id(target_torch_node), shape)
return hl_graph

142
tensorwatch/model_graph/hiddenlayer/tf_builder.py Normal file
View File

@ -0,0 +1,142 @@
"""
HiddenLayer
TensorFlow graph importer.
Written by Phil Ferriere. Edits by Waleed Abdulla.
Licensed under the MIT License
"""
from __future__ import absolute_import, division, print_function, unicode_literals
import logging
import tensorflow as tf
from .graph import Graph, Node
from . import transforms as ht
FRAMEWORK_TRANSFORMS = [
# Rename VariableV2 op to Variable. Same for anything V2, V3, ...etc.
ht.Rename(op=r"(\w+)V\d", to=r"\1"),
ht.Prune("Const"),
ht.Prune("PlaceholderWithDefault"),
ht.Prune("Variable"),
ht.Prune("VarIsInitializedOp"),
ht.Prune("VarHandleOp"),
ht.Prune("ReadVariableOp"),
ht.PruneBranch("Assign"),
ht.PruneBranch("AssignSub"),
ht.PruneBranch("AssignAdd"),
ht.PruneBranch("AssignVariableOp"),
ht.Prune("ApplyMomentum"),
ht.Prune("ApplyAdam"),
ht.FoldId(r"^(gradients)/.*", "NoOp"), # Fold to NoOp then delete in the next step
ht.Prune("NoOp"),
ht.Rename(op=r"DepthwiseConv2dNative", to="SeparableConv"),
ht.Rename(op=r"Conv2D", to="Conv"),
ht.Rename(op=r"FusedBatchNorm", to="BatchNorm"),
ht.Rename(op=r"MatMul", to="Linear"),
ht.Fold("Conv > BiasAdd", "__first__"),
ht.Fold("Linear > BiasAdd", "__first__"),
ht.Fold("Shape > StridedSlice > Pack > Reshape", "__last__"),
ht.FoldId(r"(.+)/dropout/.*", "Dropout"),
ht.FoldId(r"(softmax_cross\_entropy)\_with\_logits.*", "SoftmaxCrossEntropy"),
]
def dump_tf_graph(tfgraph, tfgraphdef):
"""List all the nodes in a TF graph.
tfgraph: A TF Graph object.
tfgraphdef: A TF GraphDef object.
"""
print("Nodes ({})".format(len(tfgraphdef.node)))
f = "{:15} {:59} {:20} {}"
print(f.format("kind", "scopeName", "shape", "inputs"))
for node in tfgraphdef.node:
scopename = node.name
kind = node.op
inputs = node.input
shape = tf.graph_util.tensor_shape_from_node_def_name(tfgraph, scopename)
print(f.format(kind, scopename, str(shape), inputs))
def import_graph(hl_graph, tf_graph, output=None, verbose=False):
"""Convert TF graph to directed graph
tf_graph: A TF Graph object.
output: Name of the output node (string).
verbose: Set to True for debug print output
"""
# Get clean(er) list of nodes
graph_def = tf_graph.as_graph_def(add_shapes=True)
graph_def = tf.graph_util.remove_training_nodes(graph_def)
# Dump list of TF nodes (DEBUG only)
if verbose:
dump_tf_graph(tf_graph, graph_def)
# Loop through nodes and build the matching directed graph
for tf_node in graph_def.node:
# Read node details
try:
op, uid, name, shape, params = import_node(tf_node, tf_graph, verbose)
except:
if verbose:
logging.exception("Failed to read node {}".format(tf_node))
continue
# Add node
hl_node = Node(uid=uid, name=name, op=op, output_shape=shape, params=params)
hl_graph.add_node(hl_node)
# Add edges
for target_node in graph_def.node:
target_inputs = target_node.input
if uid in target_node.input:
hl_graph.add_edge_by_id(uid, target_node.name, shape)
return hl_graph
def import_node(tf_node, tf_graph, verbose=False):
# Operation type and name
op = tf_node.op
uid = tf_node.name
name = None
# Shape
shape = None
if tf_node.op != "NoOp":
try:
shape = tf.graph_util.tensor_shape_from_node_def_name(tf_graph, tf_node.name)
# If the shape is known, convert it to a list
if shape.ndims is not None:
shape = shape.as_list()
except:
if verbose:
logging.exception("Error reading shape of {}".format(tf_node.name))
# Parameters
# At this stage, we really only care about two parameters:
# 1/ the kernel size used by convolution layers
# 2/ the stride used by convolutional and pooling layers (TODO: not fully working yet)
# 1/ The kernel size is actually not stored in the convolution tensor but in its weight input.
# The weights input has the shape [shape=[kernel, kernel, in_channels, filters]]
# So we must fish for it
params = {}
if op == "Conv2D" or op == "DepthwiseConv2dNative":
kernel_shape = tf.graph_util.tensor_shape_from_node_def_name(tf_graph, tf_node.input[1])
kernel_shape = [int(a) for a in kernel_shape]
params["kernel_shape"] = kernel_shape[0:2]
if 'strides' in tf_node.attr.keys():
strides = [int(a) for a in tf_node.attr['strides'].list.i]
params["stride"] = strides[1:3]
elif op == "MaxPool" or op == "AvgPool":
# 2/ the stride used by pooling layers
# See https://stackoverflow.com/questions/44124942/how-to-access-values-in-protos-in-tensorflow
if 'ksize' in tf_node.attr.keys():
kernel_shape = [int(a) for a in tf_node.attr['ksize'].list.i]
params["kernel_shape"] = kernel_shape[1:3]
if 'strides' in tf_node.attr.keys():
strides = [int(a) for a in tf_node.attr['strides'].list.i]
params["stride"] = strides[1:3]
return op, uid, name, shape, params

212
tensorwatch/model_graph/hiddenlayer/transforms.py Normal file
View File

@ -0,0 +1,212 @@
"""
HiddenLayer
Transforms that apply to and modify graph nodes.
Written by Waleed Abdulla
Licensed under the MIT License
"""
import re
import copy
from .graph import Node
from . import ge
###########################################################################
# Transforms
###########################################################################
def _concate_params(matches):
combo_params = [match.params for match in matches]
combo_params += [cb for match in matches if match.combo_params for cb in match.combo_params]
return combo_params
class Fold():
def __init__(self, pattern, op, name=None):
# TODO: validate that op and name are valid
self.pattern = ge.GEParser(pattern).parse()
self.op = op
self.name = name
def apply(self, graph):
# Copy the graph. Don't change the original.
graph = copy.deepcopy(graph)
while True:
matches, _ = graph.search(self.pattern)
if not matches:
break
# Replace pattern with new node
if self.op == "__first__":
combo = matches[0]
elif self.op == "__last__":
combo = matches[-1]
else:
combo = Node(uid=graph.sequence_id(matches),
name=self.name or " &gt; ".join([l.title for l in matches]),
op=self.op or self.pattern,
output_shape=matches[-1].output_shape,
combo_params=_concate_params(matches))
combo._caption = "/".join(filter(None, [l.caption for l in matches]))
graph.replace(matches, combo)
return graph
class FoldId():
def __init__(self, id_regex, op, name=None):
# TODO: validate op and name are valid
self.id_regex = re.compile(id_regex)
self.op = op
self.name = name
def apply(self, graph):
# Copy the graph. Don't change the original.
graph = copy.deepcopy(graph)
# Group nodes by the first matching group of the regex
groups = {}
for node in graph.nodes.values():
m = self.id_regex.match(node.id)
if not m:
continue
assert m.groups(), "Regular expression must have a matching group to avoid folding unrelated nodes."
key = m.group(1)
if key not in groups:
groups[key] = []
groups[key].append(node)
# Fold each group of nodes together
for key, nodes in groups.items():
# Replace with a new node
# TODO: Find last node in the sub-graph and get the output shape from it
combo = Node(uid=key,
name=self.name,
op=self.op,
combo_params=_concate_params(nodes))
graph.replace(nodes, combo)
return graph
class Prune():
def __init__(self, pattern):
self.pattern = ge.GEParser(pattern).parse()
def apply(self, graph):
# Copy the graph. Don't change the original.
graph = copy.deepcopy(graph)
while True:
matches, _ = graph.search(self.pattern)
if not matches:
break
# Remove found nodes
graph.remove(matches)
return graph
class PruneBranch():
def __init__(self, pattern):
self.pattern = ge.GEParser(pattern).parse()
def tag(self, node, tag, graph, conditional=False):
# Return if the node is already tagged
if hasattr(node, "__tag__") and node.__tag__ == "tag":
return
# If conditional, then tag the node if and only if all its
# outgoing nodes already have the same tag.
if conditional:
# Are all outgoing nodes already tagged?
outgoing = graph.outgoing(node)
tagged = filter(lambda n: hasattr(n, "__tag__") and n.__tag__ == tag,
outgoing)
if len(list(tagged)) != len(outgoing):
# Not all outgoing are tagged
return
# Tag the node
node.__tag__ = tag
# Tag incoming nodes
for n in graph.incoming(node):
self.tag(n, tag, graph, conditional=True)
def apply(self, graph):
# Copy the graph. Don't change the original.
graph = copy.deepcopy(graph)
while True:
matches, _ = graph.search(self.pattern)
if not matches:
break
# Tag found nodes and their incoming branches
for n in matches:
self.tag(n, "delete", graph)
# Find all tagged nodes and delete them
tagged = [n for n in graph.nodes.values()
if hasattr(n, "__tag__") and n.__tag__ == "delete"]
graph.remove(tagged)
return graph
class FoldDuplicates():
def apply(self, graph):
# Copy the graph. Don't change the original.
graph = copy.deepcopy(graph)
matches = True
while matches:
for node in graph.nodes.values():
pattern = ge.SerialPattern([ge.NodePattern(node.op), ge.NodePattern(node.op)])
matches, _ = pattern.match(graph, node)
if matches:
# Use op and name from the first node, and output_shape from the last
combo = Node(uid=graph.sequence_id(matches),
name=node.name,
op=node.op,
output_shape=matches[-1].output_shape,
combo_params=_concate_params(matches))
combo._caption = node.caption
combo.repeat = sum([n.repeat for n in matches])
graph.replace(matches, combo)
break
return graph
class Rename():
def __init__(self, op=None, name=None, to=None):
assert op or name, "Either op or name must be provided"
assert not(op and name), "Either op or name should be provided, but not both"
assert bool(to), "The to parameter is required"
self.to = to
self.op = re.compile(op) if op else None
self.name = re.compile(name) if name else None
def apply(self, graph):
# Copy the graph. Don't change the original.
graph = copy.deepcopy(graph)
for node in graph.nodes.values():
if self.op:
node.op = self.op.sub(self.to, node.op)
# TODO: name is not tested yet
if self.name:
node.name = self.name.sub(self.to, node.name)
return graph
# Transforms to simplify graphs by folding layers that tend to be
# used together often, such as Conv/BN/Relu.
# These transforms are used AFTER the framework specific transforms
# that map TF and PyTorch graphs to a common representation.
SIMPLICITY_TRANSFORMS = [
Fold("Conv > Conv > BatchNorm > Relu", "ConvConvBnRelu"),
Fold("Conv > BatchNorm > Relu", "ConvBnRelu"),
Fold("Conv > BatchNorm", "ConvBn"),
Fold("Conv > Relu", "ConvRelu"),
Fold("Linear > Relu", "LinearRelu"),
# Fold("ConvBnRelu > MaxPool", "ConvBnReluMaxpool"),
# Fold("ConvRelu > MaxPool", "ConvReluMaxpool"),
FoldDuplicates(),
]
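# A minimal usage sketch (illustrative; assumes a Graph instance `g` built by the
# framework-specific importers elsewhere in this package):
#   for t in SIMPLICITY_TRANSFORMS:
#       g = t.apply(g)
# Each transform returns a modified deep copy; the original graph is untouched.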

@@ -0,0 +1,89 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torchstat
import pandas as pd
def model_stats(model, input_shape):
if len(input_shape) > 3:
input_shape = input_shape[1:4]
ms = torchstat.statistics.ModelStat(model, input_shape, 1)
collected_nodes = ms._analyze_model()
return _report_format(collected_nodes)
def _round_value(value, binary=False):
divisor = 1024. if binary else 1000.
if value // divisor**4 > 0:
return str(round(value / divisor**4, 2)) + 'T'
elif value // divisor**3 > 0:
return str(round(value / divisor**3, 2)) + 'G'
elif value // divisor**2 > 0:
return str(round(value / divisor**2, 2)) + 'M'
elif value // divisor > 0:
return str(round(value / divisor, 2)) + 'K'
return str(value)
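# Worked examples (decimal vs. binary divisors):
#   _round_value(1500000)                # -> '1.5M' (divisor 1000)
#   _round_value(1572864, binary=True)   # -> '1.5M' (divisor 1024)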
def _report_format(collected_nodes):
pd.set_option('display.width', 1000)
pd.set_option('display.max_rows', 10000)
pd.set_option('display.max_columns', 10000)
data = list()
for node in collected_nodes:
name = node.name
input_shape = ' '.join(['{:>3d}'] * len(node.input_shape)).format(
*[e for e in node.input_shape])
output_shape = ' '.join(['{:>3d}'] * len(node.output_shape)).format(
*[e for e in node.output_shape])
parameter_quantity = node.parameter_quantity
inference_memory = node.inference_memory
MAdd = node.MAdd
Flops = node.Flops
mread, mwrite = [i for i in node.Memory]
duration = node.duration
data.append([name, input_shape, output_shape, parameter_quantity,
inference_memory, MAdd, duration, Flops, mread,
mwrite])
df = pd.DataFrame(data)
df.columns = ['module name', 'input shape', 'output shape',
'params', 'memory(MB)',
'MAdd', 'duration', 'Flops', 'MemRead(B)', 'MemWrite(B)']
df['duration[%]'] = df['duration'] / (df['duration'].sum() + 1e-7)
df['MemR+W(B)'] = df['MemRead(B)'] + df['MemWrite(B)']
total_parameters_quantity = df['params'].sum()
total_memory = df['memory(MB)'].sum()
total_operation_quantity = df['MAdd'].sum()
total_flops = df['Flops'].sum()
total_duration = df['duration[%]'].sum()
total_mread = df['MemRead(B)'].sum()
total_mwrite = df['MemWrite(B)'].sum()
total_memrw = df['MemR+W(B)'].sum()
del df['duration']
# Add Total row
total_df = pd.Series([total_parameters_quantity, total_memory,
total_operation_quantity, total_flops,
total_duration, total_mread, total_mwrite, total_memrw],
index=['params', 'memory(MB)', 'MAdd', 'Flops', 'duration[%]',
'MemRead(B)', 'MemWrite(B)', 'MemR+W(B)'],
name='total')
df = df.append(total_df)
df = df.fillna(' ')
df['memory(MB)'] = df['memory(MB)'].apply(
lambda x: '{:.2f}'.format(x))
df['duration[%]'] = df['duration[%]'].apply(lambda x: '{:.2%}'.format(x))
df['MAdd'] = df['MAdd'].apply(lambda x: '{:,}'.format(x))
df['Flops'] = df['Flops'].apply(lambda x: '{:,}'.format(x))
#summary = "Total params: {:,}\n".format(total_parameters_quantity)
#summary += "-" * len(str(df).split('\n')[0])
#summary += '\n'
#summary += "Total memory: {:.2f}MB\n".format(total_memory)
#summary += "Total MAdd: {}MAdd\n".format(_round_value(total_operation_quantity))
#summary += "Total Flops: {}Flops\n".format(_round_value(total_flops))
#summary += "Total MemR+W: {}B\n".format(_round_value(total_memrw, True))
return df
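# A minimal usage sketch (illustrative; the model and input shape are hypothetical):
#   import torchvision.models as models
#   df = model_stats(models.resnet18(), input_shape=(1, 3, 224, 224))
#   print(df)   # per-layer params, memory(MB), MAdd, Flops, plus a 'total' row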

@@ -0,0 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# import matplotlib before anything else
# because of VS debugger issue for multiprocessing
# https://github.com/Microsoft/ptvsd/issues/1041
from .line_plot import LinePlot
from .image_plot import ImagePlot

@@ -0,0 +1,149 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
#from IPython import get_ipython, display
#if get_ipython():
# get_ipython().magic('matplotlib notebook')
#import matplotlib
#if os.name == 'posix' and "DISPLAY" not in os.environ:
# matplotlib.use('Agg') # Must be before importing matplotlib.pyplot or pylab!
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
#from ipywidgets.widgets.interaction import show_inline_matplotlib_plots
#from ipykernel.pylab.backend_inline import flush_figures
from ..vis_base import VisBase
import sys, traceback, logging
from abc import abstractmethod
from .. import utils
from IPython import get_ipython #, display
import ipywidgets as widgets
class BaseMplPlot(VisBase):
def __init__(self, cell:widgets.Box=None, title:str=None, show_legend:bool=None, stream_name:str=None, console_debug:bool=False, **vis_args):
super(BaseMplPlot, self).__init__(widgets.Output(), cell, title, show_legend, stream_name=stream_name, console_debug=console_debug, **vis_args)
self._fig_init_done = False
self.show_legend = show_legend
# graph objects
self.figure = None
self._ax_main = None
# matplotlib animation
self.animation = None
self.anim_interval = None
#print(matplotlib.get_backend())
#display.display(self.cell)
# anim_interval in seconds
def init_fig(self, anim_interval:float=1.0):
"""(for derived class) Initializes matplotlib figure"""
if self._fig_init_done:
return False
# create figure and animation
self.figure = plt.figure(figsize=(8, 3))
self.anim_interval = anim_interval
plt.set_cmap('Dark2')
plt.rcParams['image.cmap']='Dark2'
self._fig_init_done = True
return True
def get_main_axis(self):
# if we don't yet have main axis, create one
if not self._ax_main:
# by default assign one subplot to whole graph
self._ax_main = self.figure.add_subplot(111)
self._ax_main.grid(True)
# change the color of the top and right spines to opaque gray
self._ax_main.spines['right'].set_color((.8,.8,.8))
self._ax_main.spines['top'].set_color((.8,.8,.8))
if self.title is not None:
title = self._ax_main.set_title(self.title)
title.set_weight('bold')
return self._ax_main
def _on_update(self, frame): # pylint: disable=unused-argument
try:
self._update_stream_plots()
except Exception as ex:
# if an exception occurs here, the animation stops and there
# will be no further plot updates
# TODO: maybe we don't need all of the below, but none of them
# surface the exception in Jupyter Notebook, presumably because
# these exceptions occur on a background thread
self.last_ex = ex
print(ex)
logging.fatal(ex, exc_info=True)
traceback.print_exc(file=sys.stdout)
def show(self, blocking=False):
if not self.is_shown and self.anim_interval:
self.animation = FuncAnimation(self.figure, self._on_update, interval=self.anim_interval*1000.0)
super(BaseMplPlot, self).show(blocking)
def _post_update_stream_plot(self, stream_vis):
utils.debug_log("Plot updated", stream_vis.stream.stream_name, verbosity=5)
if self.layout_dirty:
# do not call tight_layout() on every update -
# that would jumble up the graphs! It should only be called
# once each time the layout changes
self.figure.tight_layout()
self.layout_dirty = False
# the code below forces a redraw, which helped to
# repaint even if there was an error in the interval loop,
# but it only works in the native UX, not in Jupyter Notebook
#self.figure.canvas.draw()
#self.figure.canvas.flush_events()
if self._use_hbox and get_ipython():
self.widget.clear_output(wait=True)
with self.widget:
plt.show(self.figure)
# everything else that doesn't work
#self.figure.show()
#display.clear_output(wait=True)
#display.display(self.figure)
#flush_figures()
#plt.show()
#show_inline_matplotlib_plots()
#elif not get_ipython():
# self.figure.canvas.draw()
def _post_add_subscription(self, stream_vis, **stream_vis_args):
# make sure figure is initialized
self.init_fig()
self.init_stream_plot(stream_vis, **stream_vis_args)
# redo the legend
#self.figure.legend(loc='center right', bbox_to_anchor=(1.5, 0.5))
if self.show_legend:
self.figure.legend(loc='lower right')
plt.subplots_adjust(hspace=0.6)
def _show_widget_native(self, blocking:bool):
#plt.ion()
#plt.show()
return plt.show(block=blocking)
def _show_widget_notebook(self):
# no need to return anything: %matplotlib notebook detects the
# newly spawned figure and paints it
# if self.figure were returned, the figure would show up twice
return None
#plt.show()
#return self.figure
def _can_update_stream_plots(self):
return False # we run an interval timer which will flush the queue
@abstractmethod
def init_stream_plot(self, stream_vis, **stream_vis_args):
"""(for derived class) Create new plot info for this stream"""
pass

@@ -0,0 +1,91 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .base_mpl_plot import BaseMplPlot
from .. import utils, image_utils
import numpy as np
import skimage.transform
#from IPython import get_ipython
class ImagePlot(BaseMplPlot):
def init_stream_plot(self, stream_vis,
rows=2, cols=5, img_width=None, img_height=None, img_channels=None,
colormap=None, viz_img_scale=None, **stream_vis_args):
stream_vis.rows, stream_vis.cols = rows, cols
stream_vis.img_channels, stream_vis.colormap = img_channels, colormap
stream_vis.img_width, stream_vis.img_height = img_width, img_height
stream_vis.viz_img_scale = viz_img_scale
# subplots holding each image
stream_vis.axs = [[None for _ in range(cols)] for _ in range(rows)]
# axis image
stream_vis.ax_imgs = [[None for _ in range(cols)] for _ in range(rows)]
def clear_plot(self, stream_vis, clear_history):
for row in range(stream_vis.rows):
for col in range(stream_vis.cols):
img = stream_vis.ax_imgs[row][col]
if img:
x, y = img.get_size()
img.set_data(np.zeros((x, y)))
def _show_stream_items(self, stream_vis, stream_items):
# as we repaint each image plot, select last if multiple events were pending
stream_item = None
for er in reversed(stream_items):
if not(er.ended or er.value is None):
stream_item = er
break
if stream_item is None:
return False
row, col, i = 0, 0, 0
dirty = False
# stream_item.value is expected to be ImagePlotItems
for image_list in stream_item.value:
# convert to imshow compatible, stitch images
images = [image_utils.to_imshow_array(img, stream_vis.img_width, stream_vis.img_height) \
for img in image_list.images if img is not None]
img_viz = image_utils.stitch_horizontal(images, width_dim=1)
# resize if requested
if stream_vis.viz_img_scale is not None:
img_viz = skimage.transform.rescale(img_viz,
(stream_vis.viz_img_scale, stream_vis.viz_img_scale), mode='reflect', preserve_range=False)
# create subplot if it doesn't exist
ax = stream_vis.axs[row][col]
if ax is None:
ax = stream_vis.axs[row][col] = \
self.figure.add_subplot(stream_vis.rows, stream_vis.cols, i+1)
ax.set_xticks([])
ax.set_yticks([])
cmap = image_list.cmap or ('Greys' if stream_vis.colormap is None and \
len(img_viz.shape) == 2 else stream_vis.colormap)
stream_vis.ax_imgs[row][col] = ax.imshow(img_viz, interpolation="none", cmap=cmap, alpha=image_list.alpha)
dirty = True
# set title
title = image_list.title
if len(title) > 12: #wordwrap if too long
title = utils.wrap_string(title) if len(title) > 24 else title
fontsize = 8
else:
fontsize = 12
ax.set_title(title, fontsize=fontsize) #'fontweight': 'light'
#ax.autoscale_view() # not needed
col = col + 1
if col >= stream_vis.cols:
col = 0
row = row + 1
if row >= stream_vis.rows:
break
i += 1
return dirty
def has_legend(self):
return self.show_legend or False

@@ -0,0 +1,148 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .base_mpl_plot import BaseMplPlot
import matplotlib
import matplotlib.pyplot as plt
from .. import utils
import numpy as np
from ..lv_types import EventVars
import ipywidgets as widgets
class LinePlot(BaseMplPlot):
def __init__(self, cell:widgets.Box=None, title=None, show_legend:bool=True, stream_name:str=None, console_debug:bool=False, is_3d:bool=False, **vis_args):
super(LinePlot, self).__init__(cell, title, show_legend, stream_name=stream_name, console_debug=console_debug, **vis_args)
self.is_3d = is_3d #TODO: not implemented for mpl
def init_stream_plot(self, stream_vis,
xtitle='', ytitle='', color=None, xrange=None, yrange=None, **stream_vis_args):
stream_vis.xylabel_refs = [] # annotation references
# add main subplot
if len(self._stream_vises) == 0:
stream_vis.ax = self.get_main_axis()
else:
stream_vis.ax = self.get_main_axis().twinx()
#TODO: improve color selection
color = color or plt.cm.Dark2((len(self._stream_vises)%8)/8) # pylint: disable=no-member
# add default line in subplot
stream_vis.line = matplotlib.lines.Line2D([], [],
label=stream_vis.title or ytitle or str(stream_vis.index), color=color) #, linewidth=3
if stream_vis.opacity is not None:
stream_vis.line.set_alpha(stream_vis.opacity)
stream_vis.ax.add_line(stream_vis.line)
# if more than 2 y-axis then place additional outside
if len(self._stream_vises) > 1:
pos = (len(self._stream_vises)) * 30
stream_vis.ax.spines['right'].set_position(('outward', pos))
stream_vis.ax.set_xlabel(xtitle)
stream_vis.ax.set_ylabel(ytitle)
stream_vis.ax.yaxis.label.set_color(color)
stream_vis.ax.yaxis.label.set_style('italic')
stream_vis.ax.xaxis.label.set_style('italic')
if xrange is not None:
stream_vis.ax.set_xlim(*xrange)
if yrange is not None:
stream_vis.ax.set_ylim(*yrange)
def clear_plot(self, stream_vis, clear_history):
lines = stream_vis.ax.get_lines()
# if we need to keep history
if stream_vis.history_len > 1:
# make sure we have history len - 1 lines
lines_keep = 0 if clear_history else stream_vis.history_len-1
while len(lines) > lines_keep:
lines.pop(0).remove()
# dim old lines
if stream_vis.dim_history and len(lines) > 0:
alphas = np.linspace(0.05, 1, len(lines))
for line, opacity in zip(lines, alphas):
line.set_alpha(opacity)
line.set_linewidth(1)
# add new line
line = matplotlib.lines.Line2D([], [], linewidth=3)
stream_vis.ax.add_line(line)
else: #clear current line
lines[-1].set_data([], [])
# remove annotations
for label_info in stream_vis.xylabel_refs:
label_info.set_visible(False)
label_info.remove()
stream_vis.xylabel_refs.clear()
def _show_stream_items(self, stream_vis, stream_items):
vals = self._extract_vals(stream_items)
if not len(vals):
return False
line = stream_vis.ax.get_lines()[-1]
xdata, ydata = line.get_data()
zdata, anndata, txtdata, clrdata = [], [], [], []
unpacker = lambda a0=None,a1=None,a2=None,a3=None,a4=None,a5=None, *_:(a0,a1,a2,a3,a4,a5)
# add each value to the trace data
# each value is of one of these forms:
# 2D graphs:
#   y
#   x [, y [, annotation [, text [, color]]]]
# 3D graphs:
#   y
#   x [, y [, z [, annotation [, text [, color]]]]]
for val in vals:
# set defaults
x, y, z = len(xdata), None, None
ann, txt, clr = None, None, None
# if val turns out to be array-like, extract x,y
val_l = utils.is_scaler_array(val)
if val_l >= 0:
if self.is_3d:
x, y, z, ann, txt, clr = unpacker(*val)
else:
x, y, ann, txt, clr, _ = unpacker(*val)
elif isinstance(val, EventVars):
x = val.x if hasattr(val, 'x') else x
y = val.y if hasattr(val, 'y') else y
z = val.z if hasattr(val, 'z') else z
ann = val.ann if hasattr(val, 'ann') else ann
txt = val.txt if hasattr(val, 'txt') else txt
clr = val.clr if hasattr(val, 'clr') else clr
if y is None:
y = next(iter(val.__dict__.values()))
else:
y = val
if ann is not None:
ann = str(ann)
if txt is not None:
txt = str(txt)
xdata.append(x)
ydata.append(y)
zdata.append(z)
if txt:
txtdata.append(txt)
if clr:
clrdata.append(clr)
if ann: #TODO: yref should be y2 for different y axis
anndata.append(dict(x=x, y=y, xref='x', yref='y', text=ann, showarrow=False))
line.set_data(xdata, ydata)
for ann in anndata:
stream_vis.xylabel_refs.append(stream_vis.ax.text( \
ann['x'], ann['y'], ann['text']))
stream_vis.ax.relim()
stream_vis.ax.autoscale_view()
return True
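# Value format sketch (illustrative; the stream plumbing lives elsewhere in this
# package): a stream item carrying (3, 0.5, 'best') appends the point x=3, y=0.5
# to the current line and draws 'best' as a text annotation at that point.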

@@ -0,0 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .line_plot import LinePlot
from .embeddings_plot import EmbeddingsPlot
#from .vis_base import *

@@ -0,0 +1,100 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import plotly
import plotly.graph_objs as go
import ipywidgets as widgets
from ..vis_base import VisBase
import time
from abc import abstractmethod
from .. import utils
class BasePlotlyPlot(VisBase):
def __init__(self, cell:widgets.Box=None, title=None, show_legend:bool=None, stream_name:str=None, console_debug:bool=False, **vis_args):
super(BasePlotlyPlot, self).__init__(go.FigureWidget(), cell, title, show_legend,
stream_name=stream_name, console_debug=console_debug, **vis_args)
self.widget.layout.title = title
self.widget.layout.showlegend = show_legend if show_legend is not None else True
def _add_trace(self, stream_vis):
stream_vis.trace_index = len(self.widget.data)
trace = self._create_trace(stream_vis)
if stream_vis.opacity is not None:
trace.opacity = stream_vis.opacity
self.widget.add_trace(trace)
def _add_trace_with_history(self, stream_vis):
# if history buffer isn't full
if stream_vis.history_len > len(stream_vis.trace_history):
self._add_trace(stream_vis)
stream_vis.trace_history.append(len(self.widget.data)-1)
stream_vis.cur_history_index = len(stream_vis.trace_history)-1
#if stream_vis.cur_history_index:
# self.widget.data[trace_index].showlegend = False
else:
# rotate trace
stream_vis.cur_history_index = (stream_vis.cur_history_index + 1) % stream_vis.history_len
stream_vis.trace_index = stream_vis.trace_history[stream_vis.cur_history_index]
self.clear_plot(stream_vis, False)
self.widget.data[stream_vis.trace_index].opacity = stream_vis.opacity or 1
cur_history_len = len(stream_vis.trace_history)
if stream_vis.dim_history and cur_history_len > 1:
max_opacity = stream_vis.opacity or 1
min_alpha, max_alpha, dimmed_len = max_opacity*0.05, max_opacity*0.8, cur_history_len-1
alphas = list(utils.frange(max_alpha, min_alpha, steps=dimmed_len))
for i, thi in enumerate(range(stream_vis.cur_history_index+1,
stream_vis.cur_history_index+cur_history_len)):
trace_index = stream_vis.trace_history[thi % cur_history_len]
self.widget.data[trace_index].opacity = alphas[i]
@staticmethod
def get_pallet_color(i:int):
return plotly.colors.DEFAULT_PLOTLY_COLORS[i % len(plotly.colors.DEFAULT_PLOTLY_COLORS)]
@staticmethod
def _get_axis_common_props(title:str, axis_range:tuple):
props = {'showline':True, 'showgrid': True,
'showticklabels': True, 'ticks':'inside'}
if title:
props['title'] = title
if axis_range:
props['range'] = list(axis_range)
return props
def _can_update_stream_plots(self):
return time.time() - self.q_last_processed > 0.5 # make configurable
def _post_add_subscription(self, stream_vis, **stream_vis_args):
stream_vis.trace_history, stream_vis.cur_history_index = [], None
self._add_trace_with_history(stream_vis)
self._setup_layout(stream_vis)
if not self.widget.layout.title:
self.widget.layout.title = stream_vis.title
# TODO: better way for below?
if stream_vis.history_len > 1:
self.widget.layout.showlegend = False
def _show_widget_native(self, blocking:bool):
pass
#TODO: save image, spawn browser?
def _show_widget_notebook(self):
#plotly.offline.iplot(self.widget)
return None
def _post_update_stream_plot(self, stream_vis):
# not needed for plotly as FigureWidget stays up to date
pass
@abstractmethod
def _setup_layout(self, stream_vis):
pass
@abstractmethod
def _create_trace(self, stream_vis):
pass

@@ -0,0 +1,82 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import matplotlib.pyplot as plt
from ipywidgets import Output #, Layout
from IPython.display import display, clear_output
import numpy as np
from .line_plot import LinePlot
import time
from .. import utils
import ipywidgets as widgets
class EmbeddingsPlot(LinePlot):
def __init__(self, cell:widgets.Box=None, title=None, show_legend:bool=False, stream_name:str=None, console_debug:bool=False,
is_3d:bool=True, hover_images=None, hover_image_reshape=None, **vis_args):
utils.set_default(vis_args, 'height', '8in')
super(EmbeddingsPlot, self).__init__(cell, title, show_legend,
stream_name=stream_name, console_debug=console_debug, is_3d=is_3d, **vis_args)
if hover_images is not None:
plt.ioff()
self.image_output = Output()
self.image_figure = plt.figure(figsize=(2,2))
self.image_ax = self.image_figure.add_subplot(111)
self.cell.children += (self.image_output,)
plt.ion()
self.hover_images, self.hover_image_reshape = hover_images, hover_image_reshape
self.last_ind, self.last_ind_time = -1, 0
def hover_fn(self, trace, points, state): # pylint: disable=unused-argument
if not points:
return
ind = points.point_inds[0]
if ind == self.last_ind or ind >= len(self.hover_images) or ind < 0:
return
if self.last_ind == -1:
self.last_ind, self.last_ind_time = ind, time.time()
else:
elapsed = time.time() - self.last_ind_time
if elapsed < 0.3:
self.last_ind, self.last_ind_time = ind, time.time()
if elapsed < 1:
return
# else too much time since update
# else we have stable ind
with self.image_output:
plt.ioff()
if self.hover_image_reshape:
img = np.reshape(self.hover_images[ind], self.hover_image_reshape)
else:
img = self.hover_images[ind]
if img is not None:
clear_output(wait=True)
self.image_ax.imshow(img)
display(self.image_figure)
plt.ion()
return None
def _create_trace(self, stream_vis):
stream_vis.stream_vis_args.clear() #TODO remove this
utils.set_default(stream_vis.stream_vis_args, 'draw_line', False)
utils.set_default(stream_vis.stream_vis_args, 'draw_marker', True)
utils.set_default(stream_vis.stream_vis_args, 'draw_marker_text', True)
utils.set_default(stream_vis.stream_vis_args, 'hoverinfo', 'text')
utils.set_default(stream_vis.stream_vis_args, 'marker', {})
marker = stream_vis.stream_vis_args['marker']
utils.set_default(marker, 'size', 6)
utils.set_default(marker, 'colorscale', 'Jet')
utils.set_default(marker, 'showscale', False)
utils.set_default(marker, 'opacity', 0.8)
return super(EmbeddingsPlot, self)._create_trace(stream_vis)
def subscribe(self, stream, **stream_vis_args):
super(EmbeddingsPlot, self).subscribe(stream)
stream_vis = self._stream_vises[stream.stream_name]
if stream_vis.index == 0 and self.hover_images is not None:
self.widget.data[stream_vis.trace_index].on_hover(self.hover_fn)

@@ -0,0 +1,186 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import plotly.graph_objs as go
from .base_plotly_plot import BasePlotlyPlot
from ..lv_types import EventVars
from .. import utils
import ipywidgets as widgets
class LinePlot(BasePlotlyPlot):
def __init__(self, cell:widgets.Box=None, title=None, show_legend:bool=True, stream_name:str=None, console_debug:bool=False,
is_3d:bool=False, **vis_args):
super(LinePlot, self).__init__(cell, title, show_legend, stream_name=stream_name, console_debug=console_debug, **vis_args)
self.is_3d = is_3d
def _setup_layout(self, stream_vis):
# x axis; multiple y axes are handled below
xaxis = 'xaxis' + str(stream_vis.index+1)
axis_props = BasePlotlyPlot._get_axis_common_props(stream_vis.xtitle, stream_vis.xrange)
#axis_props['rangeslider'] = dict(visible = True)
self.widget.layout[xaxis] = axis_props
# handle multiple Y-Axis plots
color = self.widget.data[stream_vis.trace_index].line.color
yaxis = 'yaxis' + (str(stream_vis.index + 1) if stream_vis.separate_yaxis else '')
axis_props = BasePlotlyPlot._get_axis_common_props(stream_vis.ytitle, stream_vis.yrange)
axis_props['linecolor'] = color
axis_props['tickfont'] = axis_props['titlefont'] = dict(color=color)
if stream_vis.index > 0 and stream_vis.separate_yaxis:
axis_props['overlaying'] = 'y'
axis_props['side'] = 'right'
if stream_vis.index > 1:
self.widget.layout.xaxis = dict(domain=[0, 1 - 0.085*(stream_vis.index-1)])
axis_props['anchor'] = 'free'
axis_props['position'] = 1 - 0.085*(stream_vis.index-2)
self.widget.layout[yaxis] = axis_props
if self.is_3d:
zaxis = 'zaxis' #+ str(stream_vis.index+1)
axis_props = BasePlotlyPlot._get_axis_common_props(stream_vis.ztitle, stream_vis.zrange)
self.widget.layout.scene[zaxis] = axis_props
self.widget.layout.margin = dict(l=0, r=0, b=0, t=0)
self.widget.layout.hoverdistance = 1
def _create_2d_trace(self, stream_vis, mode, hoverinfo, marker, line):
yaxis = 'y' + (str(stream_vis.index + 1) if stream_vis.separate_yaxis else '')
trace = go.Scatter(x=[], y=[], mode=mode, name=stream_vis.title or stream_vis.ytitle, yaxis=yaxis, hoverinfo=hoverinfo,
line=line, marker=marker)
return trace
def _create_3d_trace(self, stream_vis, mode, hoverinfo, marker, line):
trace = go.Scatter3d(x=[], y=[], z=[], mode=mode, name=stream_vis.title or stream_vis.ytitle, hoverinfo=hoverinfo,
line=line, marker=marker)
return trace
def _create_trace(self, stream_vis):
stream_vis.separate_yaxis = stream_vis.stream_vis_args.get('separate_yaxis', True)
stream_vis.xtitle = stream_vis.stream_vis_args.get('xtitle',None)
stream_vis.ytitle = stream_vis.stream_vis_args.get('ytitle',None)
stream_vis.ztitle = stream_vis.stream_vis_args.get('ztitle',None)
stream_vis.color = stream_vis.stream_vis_args.get('color',None)
stream_vis.xrange = stream_vis.stream_vis_args.get('xrange',None)
stream_vis.yrange = stream_vis.stream_vis_args.get('yrange',None)
stream_vis.zrange = stream_vis.stream_vis_args.get('zrange',None)
draw_line = stream_vis.stream_vis_args.get('draw_line',True)
draw_marker = stream_vis.stream_vis_args.get('draw_marker',True)
draw_marker_text = stream_vis.stream_vis_args.get('draw_marker_text',False)
hoverinfo = stream_vis.stream_vis_args.get('hoverinfo',None)
marker = stream_vis.stream_vis_args.get('marker',{})
line = stream_vis.stream_vis_args.get('line',{})
utils.set_default(line, 'color', stream_vis.color or BasePlotlyPlot.get_pallet_color(stream_vis.index))
mode = 'lines' if draw_line else ''
if draw_marker:
mode = ('' if mode=='' else mode+'+') + 'markers'
if draw_marker_text:
mode = ('' if mode=='' else mode+'+') + 'text'
if self.is_3d:
return self._create_3d_trace(stream_vis, mode, hoverinfo, marker, line)
else:
return self._create_2d_trace(stream_vis, mode, hoverinfo, marker, line)
def _show_stream_items(self, stream_vis, stream_items):
vals = self._extract_vals(stream_items)
if not len(vals):
return False
# get trace data
trace = self.widget.data[stream_vis.trace_index]
xdata, ydata, zdata, anndata, txtdata, clrdata = list(trace.x), list(trace.y), [], [], [], []
if self.is_3d:
zdata = list(trace.z)
unpacker = lambda a0=None,a1=None,a2=None,a3=None,a4=None,a5=None, *_:(a0,a1,a2,a3,a4,a5)
# add each value to the trace data
# each value is of one of these forms:
# 2D graphs:
#   y
#   x [, y [, annotation [, text [, color]]]]
# 3D graphs:
#   y
#   x [, y [, z [, annotation [, text [, color]]]]]
for val in vals:
# set defaults
x, y, z = len(xdata), None, None
ann, txt, clr = None, None, None
# if val turns out to be array-like, extract x,y
val_l = utils.is_scaler_array(val)
if val_l >= 0:
if self.is_3d:
x, y, z, ann, txt, clr = unpacker(*val)
else:
x, y, ann, txt, clr, _ = unpacker(*val)
elif isinstance(val, EventVars):
x = val.x if hasattr(val, 'x') else x
y = val.y if hasattr(val, 'y') else y
z = val.z if hasattr(val, 'z') else z
ann = val.ann if hasattr(val, 'ann') else ann
txt = val.txt if hasattr(val, 'txt') else txt
clr = val.clr if hasattr(val, 'clr') else clr
if y is None:
y = next(iter(val.__dict__.values()))
else:
y = val
if ann is not None:
ann = str(ann)
if txt is not None:
txt = str(txt)
xdata.append(x)
ydata.append(y)
zdata.append(z)
if txt is not None:
txtdata.append(txt)
if clr is not None:
clrdata.append(clr)
if ann: #TODO: yref should be y2 for different y axis
anndata.append(dict(x=x, y=y, xref='x', yref='y', text=ann, showarrow=False))
self.widget.data[stream_vis.trace_index].x = xdata
self.widget.data[stream_vis.trace_index].y = ydata
if self.is_3d:
self.widget.data[stream_vis.trace_index].z = zdata
# add text
if len(txtdata):
existing = self.widget.data[stream_vis.trace_index].text
existing = list(existing) if utils.is_array_like(existing) else []
existing += txtdata
self.widget.data[stream_vis.trace_index].text = existing
# add annotation
if len(anndata):
existing = list(self.widget.layout.annotations)
existing += anndata
self.widget.layout.annotations = existing
# add color
if len(clrdata):
existing = self.widget.data[stream_vis.trace_index].marker.color
existing = list(existing) if utils.is_array_like(existing) else []
existing += clrdata
self.widget.data[stream_vis.trace_index].marker.color = existing
return True
def clear_plot(self, stream_vis, clear_history):
traces = range(len(stream_vis.trace_history)) if clear_history else (stream_vis.trace_index,)
for i in traces:
stream_vis.trace_index = i
self.widget.data[stream_vis.trace_index].x = []
self.widget.data[stream_vis.trace_index].y = []
if self.is_3d:
self.widget.data[stream_vis.trace_index].z = []
self.widget.data[stream_vis.trace_index].text = ""
# TODO: avoid removing annotations for other streams
self.widget.layout.annotations = []

Просмотреть файл

@ -0,0 +1,55 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from torchvision import models, transforms
import torch
import torch.nn.functional as F
from . import utils, image_utils
import os
def get_model(model_name):
model = models.__dict__[model_name](pretrained=True)
return model
def tensors2batch(tensors, preprocess_transform=None):
if preprocess_transform:
tensors = tuple(preprocess_transform(i) for i in tensors)
if not utils.is_array_like(tensors):
tensors = tuple(tensors)
return torch.stack(tensors, dim=0)
def int2tensor(val):
return torch.LongTensor([val])
def image_class2tensor(image_path, class_index=None, image_convert_mode=None,
image_transform=None):
raw_input = image_utils.open_image(os.path.abspath(image_path), convert_mode=image_convert_mode)
if image_transform:
input_x = image_transform(raw_input)
else:
input_x = transforms.ToTensor()(raw_input)
input_x = input_x.unsqueeze(0) #convert to batch of 1
target_class = int2tensor(class_index) if class_index is not None else None
return raw_input, input_x, target_class
def batch_predict(model, inputs, input_transform=None, device=None):
if input_transform:
batch = torch.stack(tuple(input_transform(i) for i in inputs), dim=0)
else:
batch = torch.stack(inputs, dim=0)
device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.eval()
model.to(device)
batch = batch.to(device)
outputs = model(batch)
return outputs
def logits2probabilities(logits):
return F.softmax(logits, dim=1)
def tensor2numpy(t):
return t.detach().cpu().numpy()
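# A minimal end-to-end sketch (illustrative; the image path and class index are
# hypothetical):
#   model = get_model('resnet18')
#   raw_img, input_x, target = image_class2tensor('data/cat.jpg', class_index=281)
#   logits = batch_predict(model, [input_x.squeeze(0)])
#   probs = tensor2numpy(logits2probabilities(logits))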

@@ -0,0 +1,29 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
#from receptivefield.pytorch import PytorchReceptiveField
##from receptivefield.image import get_default_image
#import numpy as np
#def _get_rf(model, sample_pil_img):
# # define model functions
# def model_fn():
# model.eval()
# return model
# input_shape = np.array(sample_pil_img).shape
# rf = PytorchReceptiveField(model_fn)
# rf_params = rf.compute(input_shape=input_shape)
# return rf, rf_params
#def plot_receptive_field(model, sample_pil_img, layout=(2, 2), figsize=(6, 6)):
# rf, rf_params = _get_rf(model, sample_pil_img) # pylint: disable=unused-variable
# return rf.plot_rf_grids(
# custom_image=sample_pil_img,
# figsize=figsize,
# layout=layout)
#def plot_grads_at(model, sample_pil_img, feature_map_index=0, point=(8,8), figsize=(6, 6)):
# rf, rf_params = _get_rf(model, sample_pil_img) # pylint: disable=unused-variable
# return rf.plot_gradient_at(fm_id=feature_map_index, point=point, image=None, figsize=figsize)

@@ -0,0 +1,71 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import threading
import time
import weakref
class RepeatedTimer:
class State:
Stopped=0
Paused=1
Running=2
def __init__(self, secs, callback, count=None):
self.secs = secs
self.callback = weakref.WeakMethod(callback) if callback else None
self._thread = None
self._state = RepeatedTimer.State.Stopped
self.pause_wait = threading.Event()
self.pause_wait.set()
self._continue_thread = False
self.count = count
def start(self):
self._continue_thread = True
self.pause_wait.set()
if self._thread is None or not self._thread.is_alive():
self._thread = threading.Thread(target=self._runner, name='RepeatedTimer', daemon=True)
self._thread.start()
self._state = RepeatedTimer.State.Running
def stop(self, block=False):
self.pause_wait.set()
self._continue_thread = False
if block and self._thread is not None and self._thread.is_alive():
self._thread.join()
self._state = RepeatedTimer.State.Stopped
def get_state(self):
return self._state
def pause(self):
if self._state == RepeatedTimer.State.Running:
self.pause_wait.clear()
self._state = RepeatedTimer.State.Paused
# else nothing to do
def unpause(self):
if self._state == RepeatedTimer.State.Paused:
self.pause_wait.set()
self._state = RepeatedTimer.State.Running
# else nothing to do
def _runner(self):
while (self._continue_thread):
if self.count:
self.count -= 1
if not self.count:
self._continue_thread = False
if self._continue_thread:
self.pause_wait.wait()
if self.callback and self.callback():
self.callback()()
if self._continue_thread:
time.sleep(self.secs)
self._thread = None
self._state = RepeatedTimer.State.Stopped
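# A minimal usage sketch (illustrative; `some_object.on_tick` is a hypothetical
# bound method -- the callback is held via weakref.WeakMethod, so the timer will
# not keep its owner alive):
#   timer = RepeatedTimer(secs=1.0, callback=some_object.on_tick, count=10)
#   timer.start()    # invokes on_tick roughly every second, at most 10 times
#   timer.pause(); timer.unpause(); timer.stop(block=True)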

@@ -0,0 +1,5 @@
# Credits
Code in this folder is adapted from
* https://github.com/yulongwang12/visual-attribution
* https://github.com/marcotcr/lime

@@ -0,0 +1,3 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

@@ -0,0 +1,153 @@
import numpy as np
from torch.autograd import Variable, Function
import torch
import types
class VanillaGradExplainer(object):
def __init__(self, model):
self.model = model
def _backprop(self, inp, ind):
inp.requires_grad = True
if inp.grad is not None:
inp.grad.zero_()
if ind is not None and ind.grad is not None:
ind.grad.zero_()
self.model.eval()
self.model.zero_grad()
output = self.model(inp)
if ind is None:
ind = output.max(1)[1]
grad_out = output.clone()
grad_out.fill_(0.0)
grad_out.scatter_(1, ind.unsqueeze(0).t(), 1.0)
output.backward(grad_out)
return inp.grad
def explain(self, inp, ind=None, raw_inp=None):
return self._backprop(inp, ind)
class GradxInputExplainer(VanillaGradExplainer):
def __init__(self, model):
super(GradxInputExplainer, self).__init__(model)
def explain(self, inp, ind=None, raw_inp=None):
grad = self._backprop(inp, ind)
return inp * grad
class SaliencyExplainer(VanillaGradExplainer):
def __init__(self, model):
super(SaliencyExplainer, self).__init__(model)
def explain(self, inp, ind=None, raw_inp=None):
grad = self._backprop(inp, ind)
return grad.abs()
class IntegrateGradExplainer(VanillaGradExplainer):
def __init__(self, model, steps=100):
super(IntegrateGradExplainer, self).__init__(model)
self.steps = steps
def explain(self, inp, ind=None, raw_inp=None):
grad = 0
inp_data = inp.clone()
for alpha in np.arange(1 / self.steps, 1.0, 1 / self.steps):
new_inp = Variable(inp_data * alpha, requires_grad=True)
g = self._backprop(new_inp, ind)
grad += g
return grad * inp_data / self.steps
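# The loop above is a Riemann-sum approximation of integrated gradients with a
# zero baseline:  IG(x) ~= x * (1/m) * sum_{k=1..m-1} grad f((k/m) * x)
# where m = self.steps and the sum runs over interior points of the path 0 -> x.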
class DeconvExplainer(VanillaGradExplainer):
def __init__(self, model):
super(DeconvExplainer, self).__init__(model)
self._override_backward()
def _override_backward(self):
class _ReLU(Function):
@staticmethod
def forward(ctx, input):
output = torch.clamp(input, min=0)
return output
@staticmethod
def backward(ctx, grad_output):
grad_inp = torch.clamp(grad_output, min=0)
return grad_inp
def new_forward(self, x):
return _ReLU.apply(x)
def replace(m):
if m.__class__.__name__ == 'ReLU':
m.forward = types.MethodType(new_forward, m)
self.model.apply(replace)
class GuidedBackpropExplainer(VanillaGradExplainer):
def __init__(self, model):
super(GuidedBackpropExplainer, self).__init__(model)
self._override_backward()
def _override_backward(self):
class _ReLU(Function):
@staticmethod
def forward(ctx, input):
output = torch.clamp(input, min=0)
ctx.save_for_backward(output)
return output
@staticmethod
def backward(ctx, grad_output):
output, = ctx.saved_tensors
mask1 = (output > 0).float()
mask2 = (grad_output > 0).float()
grad_inp = mask1 * mask2 * grad_output
grad_output.copy_(grad_inp)
return grad_output
def new_forward(self, x):
return _ReLU.apply(x)
def replace(m):
if m.__class__.__name__ == 'ReLU':
m.forward = types.MethodType(new_forward, m)
self.model.apply(replace)
# modified from https://github.com/PAIR-code/saliency/blob/master/saliency/base.py#L80
class SmoothGradExplainer(object):
def __init__(self, model, base_explainer=None, stdev_spread=0.15,
nsamples=25, magnitude=True):
self.base_explainer = base_explainer or VanillaGradExplainer(model)
self.stdev_spread = stdev_spread
self.nsamples = nsamples
self.magnitude = magnitude
def explain(self, inp, ind=None, raw_inp=None):
stdev = self.stdev_spread * (inp.max() - inp.min())
total_gradients = 0
for i in range(self.nsamples):
noise = torch.randn_like(inp) * stdev
noisy_inp = inp + noise
noisy_inp.retain_grad()
grad = self.base_explainer.explain(noisy_inp, ind)
if self.magnitude:
total_gradients += grad ** 2
else:
total_gradients += grad
return total_gradients / self.nsamples
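# A minimal usage sketch (illustrative; `model` and `inp` are hypothetical,
# inp being a batched input tensor):
#   explainer = SmoothGradExplainer(model, base_explainer=SaliencyExplainer(model))
#   saliency = explainer.explain(inp)   # mean of squared noisy gradients, same shape as inp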

@@ -0,0 +1,80 @@
from .backprop import GradxInputExplainer
import types
import torch.nn.functional as F
from torch.autograd import Variable
# Based on formulation in DeepExplain, https://arxiv.org/abs/1711.06104
# https://github.com/marcoancona/DeepExplain/blob/master/deepexplain/tensorflow/methods.py#L221-L272
class DeepLIFTRescaleExplainer(GradxInputExplainer):
def __init__(self, model):
super(DeepLIFTRescaleExplainer, self).__init__(model)
self._prepare_reference()
self.baseline_inp = None
self._override_backward()
def _prepare_reference(self):
def init_refs(m):
name = m.__class__.__name__
if name.find('ReLU') != -1:
m.ref_inp_list = []
m.ref_out_list = []
def ref_forward(self, x):
self.ref_inp_list.append(x.data.clone())
out = F.relu(x)
self.ref_out_list.append(out.data.clone())
return out
def ref_replace(m):
name = m.__class__.__name__
if name.find('ReLU') != -1:
m.forward = types.MethodType(ref_forward, m)
self.model.apply(init_refs)
self.model.apply(ref_replace)
def _reset_reference(self):
def reset_refs(m):
name = m.__class__.__name__
if name.find('ReLU') != -1:
m.ref_inp_list = []
m.ref_out_list = []
self.model.apply(reset_refs)
def _baseline_forward(self, inp):
if self.baseline_inp is None:
self.baseline_inp = inp.data.clone()
self.baseline_inp.fill_(0.0)
self.baseline_inp = Variable(self.baseline_inp)
else:
self.baseline_inp.fill_(0.0)
# get ref
_ = self.model(self.baseline_inp)
def _override_backward(self):
def new_backward(self, grad_out):
ref_inp, inp = self.ref_inp_list
ref_out, out = self.ref_out_list
delta_out = out - ref_out
delta_in = inp - ref_inp
g1 = (delta_in.abs() > 1e-5).float() * grad_out * \
delta_out / delta_in
mask = ((ref_inp + inp) > 0).float()
g2 = (delta_in.abs() <= 1e-5).float() * 0.5 * mask * grad_out
return g1 + g2
def backward_replace(m):
name = m.__class__.__name__
if name.find('ReLU') != -1:
m.backward = types.MethodType(new_backward, m)
self.model.apply(backward_replace)
def explain(self, inp, ind=None, raw_inp=None):
self._reset_reference()
self._baseline_forward(inp)
g = super(DeepLIFTRescaleExplainer, self).explain(inp, ind)
return g

@@ -0,0 +1,242 @@
import torch
import numpy as np
from .inverter_util import RelevancePropagator
class EpsilonLrp(object):
def __init__(self, model):
self.model = InnvestigateModel(model)
def explain(self, inp, ind=None, raw_inp=None):
predictions, saliency = self.model.innvestigate(inp, ind)
return saliency
class InnvestigateModel(torch.nn.Module):
"""
ATTENTION:
Currently, innvestigating a network only works if all
layers that have to be inverted are specified explicitly
and registered as a module. If, for example,
only the functional max_poolnd is used, the inversion will not work.
"""
def __init__(self, the_model, lrp_exponent=1, beta=.5, epsilon=1e-6,
method="e-rule"):
"""
Model wrapper for pytorch models to 'innvestigate' them
with layer-wise relevance propagation (LRP) as introduced by Bach et al.
(https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0130140).
Given a class level probability produced by the model under consideration,
the LRP algorithm attributes this probability to the nodes in each layer.
This allows for visualizing the relevance of input pixels on the resulting
class probability.
Args:
the_model: Pytorch model, e.g. a pytorch.nn.Sequential consisting of
different layers. Not all layers are supported yet.
lrp_exponent: Exponent for rescaling the importance values per node
in a layer when using the e-rule method.
beta: Beta value allows for placing more (large beta) emphasis on
nodes that positively contribute to the activation of a given node
in the subsequent layer. Low beta value allows for placing more emphasis
on inhibitory neurons in a layer. Only relevant for method 'b-rule'.
epsilon: Stabilizing term to avoid numerical instabilities if the norm (denominator
for distributing the relevance) is close to zero.
method: Different rules for the LRP algorithm, b-rule allows for placing
more or less focus on positive / negative contributions, whereas
the e-rule treats them equally. For more information,
see the paper linked above.
"""
super(InnvestigateModel, self).__init__()
self.model = the_model
self.device = torch.device("cpu", 0)
self.prediction = None
self.r_values_per_layer = None
self.only_max_score = None
# Initialize the 'Relevance Propagator' with the chosen rule.
# This will be used to back-propagate the relevance values
# through the layers in the innvestigate method.
self.inverter = RelevancePropagator(lrp_exponent=lrp_exponent,
beta=beta, method=method, epsilon=epsilon,
device=self.device)
# Parsing the individual model layers
self.register_hooks(self.model)
if method == "b-rule" and float(beta) in (-1., 0):
which = "positive" if beta == -1 else "negative"
which_opp = "negative" if beta == -1 else "positive"
print("WARNING: With the chosen beta value, "
"only " + which + " contributions "
"will be taken into account.\nHence, "
"if in any layer only " + which_opp +
" contributions exist, the "
"overall relevance will not be conserved.\n")
def cuda(self, device=None):
self.device = torch.device("cuda", device)
self.inverter.device = self.device
return super(InnvestigateModel, self).cuda(device)
def cpu(self):
self.device = torch.device("cpu", 0)
self.inverter.device = self.device
return super(InnvestigateModel, self).cpu()
def register_hooks(self, parent_module):
"""
Recursively unrolls a model and registers the required
hooks to save all the necessary values for LRP in the forward pass.
Args:
parent_module: Model to unroll and register hooks for.
Returns:
None
"""
for mod in parent_module.children():
if list(mod.children()):
self.register_hooks(mod)
continue
mod.register_forward_hook(
self.inverter.get_layer_fwd_hook(mod))
if isinstance(mod, torch.nn.ReLU):
mod.register_backward_hook(
self.relu_hook_function
)
@staticmethod
def relu_hook_function(module, grad_in, grad_out):
"""
If there is a negative gradient, change it to zero.
"""
return (torch.clamp(grad_in[0], min=0.0),)
def __call__(self, in_tensor):
"""
The innvestigate wrapper returns the same prediction as the
original model, but wraps the model call method in the evaluate
method to save the last prediction.
Args:
in_tensor: Model input to pass through the pytorch model.
Returns:
Model output.
"""
return self.evaluate(in_tensor)
def evaluate(self, in_tensor):
"""
Evaluates the model on a new input. The registered forward hooks will
save all the data that is necessary to compute the relevance per neuron per layer.
Args:
in_tensor: New input for which to predict an output.
Returns:
Model prediction
"""
# Reset module list. In case the structure changes dynamically,
# the module list is tracked for every forward pass.
self.inverter.reset_module_list()
self.prediction = self.model(in_tensor)
return self.prediction
def get_r_values_per_layer(self):
if self.r_values_per_layer is None:
print("No relevances have been calculated yet, returning None in"
" get_r_values_per_layer.")
return self.r_values_per_layer
def innvestigate(self, in_tensor=None, rel_for_class=None):
"""
Method for 'innvestigating' the model with the LRP rule chosen at
the initialization of the InnvestigateModel.
Args:
in_tensor: Input for which to evaluate the LRP algorithm.
If input is None, the last evaluation is used.
If no evaluation has been performed since initialization,
an error is raised.
rel_for_class (int): Index of the class for which the relevance
distribution is to be analyzed. If None, the 'winning' class
is used for indexing.
Returns:
Model output and relevances of nodes in the input layer.
In order to get relevance distributions in other layers, use
the get_r_values_per_layer method.
"""
if self.r_values_per_layer is not None:
for elt in self.r_values_per_layer:
del elt
self.r_values_per_layer = None
with torch.no_grad():
# Check if innvestigation can be performed.
if in_tensor is None and self.prediction is None:
raise RuntimeError("Model needs to be evaluated at least "
"once before an innvestigation can be "
"performed. Please evaluate model first "
"or call innvestigate with a new input to "
"evaluate.")
# Evaluate the model anew if a new input is supplied.
if in_tensor is not None:
self.evaluate(in_tensor)
# If no class index is specified, analyze for class
# with highest prediction.
if rel_for_class is None:
# Default behaviour is innvestigating the output
# on an arg-max-basis, if no class is specified.
org_shape = self.prediction.size()
# Make sure shape is just a 1D vector per batch example.
self.prediction = self.prediction.view(org_shape[0], -1)
max_v, _ = torch.max(self.prediction, dim=1, keepdim=True)
only_max_score = torch.zeros_like(self.prediction).to(self.device)
only_max_score[max_v == self.prediction] = self.prediction[max_v == self.prediction]
relevance_tensor = only_max_score.view(org_shape)
self.prediction = self.prediction.view(org_shape)
else:
org_shape = self.prediction.size()
self.prediction = self.prediction.view(org_shape[0], -1)
only_max_score = torch.zeros_like(self.prediction).to(self.device)
only_max_score[:, rel_for_class] += self.prediction[:, rel_for_class]
relevance_tensor = only_max_score.view(org_shape)
self.prediction = self.prediction.view(org_shape)
# We have to iterate through the model backwards.
# The module list is computed for every forward pass
# by the model inverter.
rev_model = self.inverter.module_list[::-1]
relevance = relevance_tensor.detach()
del relevance_tensor
# List to save relevance distributions per layer
r_values_per_layer = [relevance]
for layer in rev_model:
# Compute layer specific backwards-propagation of relevance values
relevance = self.inverter.compute_propagated_relevance(layer, relevance)
r_values_per_layer.append(relevance.cpu())
self.r_values_per_layer = r_values_per_layer
del relevance
if self.device.type == "cuda":
torch.cuda.empty_cache()
return self.prediction, r_values_per_layer[-1]
def forward(self, in_tensor):
return self.model.forward(in_tensor)
def extra_repr(self):
r"""Set the extra representation of the module
To print customized extra information, you should re-implement
this method in your own modules. Both single-line and multi-line
strings are acceptable.
"""
return self.model.extra_repr()
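# A minimal usage sketch (illustrative; `model` and `inp` are hypothetical):
#   inn_model = InnvestigateModel(model, lrp_exponent=1, method="e-rule")
#   prediction, input_relevance = inn_model.innvestigate(in_tensor=inp)
#   per_layer = inn_model.get_r_values_per_layer()   # relevances from output back to input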

@@ -0,0 +1,56 @@
import torch
from .backprop import VanillaGradExplainer
def _get_layer(model, key_list):
a = model
for key in key_list:
a = a._modules[key]
return a
class GradCAMExplainer(VanillaGradExplainer):
def __init__(self, model, target_layer_name_keys=None, use_inp=False):
super(GradCAMExplainer, self).__init__(model)
self.target_layer = _get_layer(model, target_layer_name_keys)
self.use_inp = use_inp
self.intermediate_act = []
self.intermediate_grad = []
self._register_forward_backward_hook()
def _register_forward_backward_hook(self):
def forward_hook_input(m, i, o):
self.intermediate_act.append(i[0].data.clone())
def forward_hook_output(m, i, o):
self.intermediate_act.append(o.data.clone())
def backward_hook(m, grad_i, grad_o):
self.intermediate_grad.append(grad_o[0].data.clone())
if self.use_inp:
self.target_layer.register_forward_hook(forward_hook_input)
else:
self.target_layer.register_forward_hook(forward_hook_output)
self.target_layer.register_backward_hook(backward_hook)
def _reset_intermediate_lists(self):
self.intermediate_act = []
self.intermediate_grad = []
def explain(self, inp, ind=None, raw_inp=None):
self._reset_intermediate_lists()
_ = super(GradCAMExplainer, self)._backprop(inp, ind)
grad = self.intermediate_grad[0]
act = self.intermediate_act[0]
weights = grad.sum(-1).sum(-1).unsqueeze(-1).unsqueeze(-1)
cam = weights * act
cam = cam.sum(1).unsqueeze(1)
cam = torch.clamp(cam, min=0)
return cam
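# A minimal usage sketch (illustrative; the layer keys are hypothetical and depend
# on the model's module tree, e.g. ['features', '29'] for a torchvision VGG16):
#   explainer = GradCAMExplainer(model, target_layer_name_keys=['features', '29'])
#   cam = explainer.explain(inp)   # coarse [N, 1, H, W] class-activation map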

@@ -0,0 +1,526 @@
import torch
import torch.nn
import numpy as np
import torch.nn.functional as F
class Flatten(torch.nn.Module):
def __init__(self):
super(Flatten, self).__init__()
def forward(self, in_tensor):
return in_tensor.view((in_tensor.size()[0], -1))
def module_tracker(fwd_hook_func):
"""
Wrapper for tracking the layers throughout the forward pass.
Args:
fwd_hook_func: Forward hook function to be wrapped.
Returns:
Wrapped method.
"""
def hook_wrapper(relevance_propagator_instance, layer, *args):
relevance_propagator_instance.module_list.append(layer)
return fwd_hook_func(relevance_propagator_instance, layer, *args)
return hook_wrapper
class RelevancePropagator:
"""
Class for computing the relevance propagation and supplying
the necessary forward hooks for all layers.
"""
# All layers that do not require any specific forward hooks.
# This is due to the fact that they are all one-to-one
# mappings and hence no normalization is needed (each node only
# influences exactly one other node -> relevance conservation
# ensures that relevance is just inherited in a one-to-one manner, too).
allowed_pass_layers = (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d,
torch.nn.BatchNorm3d,
torch.nn.ReLU, torch.nn.ELU, Flatten,
torch.nn.Dropout, torch.nn.Dropout2d,
torch.nn.Dropout3d,
torch.nn.Softmax,
torch.nn.LogSoftmax)
# Implemented rules for relevance propagation.
available_methods = ["e-rule", "b-rule"]
def __init__(self, lrp_exponent, beta, method, epsilon, device):
self.device = device
self.layer = None
self.p = lrp_exponent
self.beta = beta
self.eps = epsilon
self.warned_log_softmax = False
self.module_list = []
if method not in self.available_methods:
raise NotImplementedError("Only methods available are: " +
str(self.available_methods))
self.method = method
def reset_module_list(self):
"""
The module list is reset for every evaluation, in case the order or number
of layers changes dynamically.
Returns:
None
"""
self.module_list = []
# Try to free memory
if self.device.type == "cuda":
torch.cuda.empty_cache()
def compute_propagated_relevance(self, layer, relevance):
"""
This method computes the backward pass for the incoming relevance
for the specified layer.
Args:
layer: Layer to be reverted.
relevance: Incoming relevance from higher up in the network.
Returns:
The relevance propagated backwards through the given layer.
"""
if isinstance(layer,
(torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d)):
return self.max_pool_nd_inverse(layer, relevance).detach()
elif isinstance(layer,
(torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)):
return self.conv_nd_inverse(layer, relevance).detach()
elif isinstance(layer, torch.nn.LogSoftmax):
# Only layer that does not conserve relevance. Mainly used
# to make probability out of the log values. Should probably
# be changed to pure passing and the user should make sure
# the layer outputs are sensible (0 would be 100% class probability,
# but no relevance could be passed on).
if relevance.sum() < 0:
relevance[relevance == 0] = -1e6
relevance = relevance.exp()
if not self.warned_log_softmax:
print("WARNING: LogSoftmax layer was "
"turned into probabilities.")
self.warned_log_softmax = True
return relevance
elif isinstance(layer, self.allowed_pass_layers):
# The above layers are one-to-one mappings of input to
# output nodes. All the relevance in the output will come
# entirely from the input node. Given the conservation
# of relevance, the input is as relevant as the output.
return relevance
elif isinstance(layer, torch.nn.Linear):
return self.linear_inverse(layer, relevance).detach()
else:
raise NotImplementedError("The network contains layers that"
" are currently not supported {0:s}".format(str(layer)))
def get_layer_fwd_hook(self, layer):
"""
Each layer might need to save very specific data during the forward
pass in order to allow for relevance propagation in the backward
pass. For example, for max_pooling, we need to store the
indices of the max values. In convolutional layers, we need to calculate
the normalizations, to ensure the overall amount of relevance is conserved.
Args:
layer: Layer instance for which forward hook is needed.
Returns:
Layer-specific forward hook.
"""
if isinstance(layer,
(torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d)):
return self.max_pool_nd_fwd_hook
if isinstance(layer,
(torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)):
return self.conv_nd_fwd_hook
if isinstance(layer, self.allowed_pass_layers):
return self.silent_pass # No hook needed.
if isinstance(layer, torch.nn.Linear):
return self.linear_fwd_hook
else:
raise NotImplementedError("The network contains layers that"
" are currently not supported {0:s}".format(str(layer)))
@staticmethod
def get_conv_method(conv_module):
"""
Get dimension-specific convolution.
The forward pass and inversion are made in a
'dimensionality-agnostic' manner and are the same for
all nd instances of the layer, except for the functional
that needs to be used.
Args:
conv_module: instance of convolutional layer.
Returns:
The correct functional used in the convolutional layer.
"""
conv_func_mapper = {
torch.nn.Conv1d: F.conv1d,
torch.nn.Conv2d: F.conv2d,
torch.nn.Conv3d: F.conv3d
}
return conv_func_mapper[type(conv_module)]
@staticmethod
def get_inv_conv_method(conv_module):
"""
Get dimension-specific convolution inversion layer.
The forward pass and inversion are made in a
'dimensionality-agnostic' manner and are the same for
all nd instances of the layer, except for the functional
that needs to be used.
Args:
conv_module: instance of convolutional layer.
Returns:
The correct functional used for inverting the convolutional layer.
"""
conv_func_mapper = {
torch.nn.Conv1d: F.conv_transpose1d,
torch.nn.Conv2d: F.conv_transpose2d,
torch.nn.Conv3d: F.conv_transpose3d
}
return conv_func_mapper[type(conv_module)]
@module_tracker
def silent_pass(self, m, in_tensor: torch.Tensor,
out_tensor: torch.Tensor):
# Placeholder forward hook for layers that do not need
# to store any specific data. Still useful for module tracking.
pass
@staticmethod
def get_inv_max_pool_method(max_pool_instance):
"""
Get the dimension-specific max-unpooling functional.
The forward pass and inversion are made in a
'dimensionality-agnostic' manner and are the same for
all nd instances of the layer, except for the functional
that needs to be used.
Args:
max_pool_instance: instance of max_pool layer.
Returns:
The correct functional for inverting the max_pooling layer.
"""
conv_func_mapper = {
torch.nn.MaxPool1d: F.max_unpool1d,
torch.nn.MaxPool2d: F.max_unpool2d,
torch.nn.MaxPool3d: F.max_unpool3d
}
return conv_func_mapper[type(max_pool_instance)]
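# Minimal sketch of the unpooling round trip these mappers rely on (not part of
# the original file): pooling with return_indices=True and unpooling with the
# same indices puts each value back at its original spatial position.
#
#   x = torch.randn(1, 3, 8, 8)
#   y, idx = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)
#   x_hat = F.max_unpool2d(y, idx, kernel_size=2, stride=2, output_size=x.size())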
def linear_inverse(self, m, relevance_in):
if self.method == "e-rule":
m.in_tensor = m.in_tensor.pow(self.p)
w = m.weight.pow(self.p)
norm = F.linear(m.in_tensor, w, bias=None)
norm = norm + torch.sign(norm) * self.eps
relevance_in[norm == 0] = 0
norm[norm == 0] = 1
relevance_out = F.linear(relevance_in / norm,
w.t(), bias=None)
relevance_out *= m.in_tensor
del m.in_tensor, norm, w, relevance_in
return relevance_out
if self.method == "b-rule":
out_c, in_c = m.weight.size()
w = m.weight.repeat((4, 1))
# The first and third channel repetitions contain only the positive weights
w[:out_c][w[:out_c] < 0] = 0
w[2 * out_c:3 * out_c][w[2 * out_c:3 * out_c] < 0] = 0
# The second and fourth channel repetitions contain only the negative weights
w[1 * out_c:2 * out_c][w[1 * out_c:2 * out_c] > 0] = 0
w[-out_c:][w[-out_c:] > 0] = 0
# Repeat across channel dimension (pytorch always has channels first)
m.in_tensor = m.in_tensor.repeat((1, 4))
m.in_tensor[:, :in_c][m.in_tensor[:, :in_c] < 0] = 0
m.in_tensor[:, -in_c:][m.in_tensor[:, -in_c:] < 0] = 0
m.in_tensor[:, 1 * in_c:3 * in_c][m.in_tensor[:, 1 * in_c:3 * in_c] > 0] = 0
# Normalize such that, for each output neuron j, the importance values of
# the contributing input neurons sum to 1 after division by the norm
# (v_ij in the paper). The norm simply sums the importance values of the
# inputs contributing to each output j; dividing by it splits the
# relevance of j properly amongst its roots.
norm_shape = list(m.out_shape)  # copy so the stored shape is not mutated
norm_shape[1] *= 4
norm = torch.zeros(norm_shape).to(self.device)
for i in range(4):
norm[:, out_c * i:(i + 1) * out_c] = F.linear(
m.in_tensor[:, in_c * i:(i + 1) * in_c], w[out_c * i:(i + 1) * out_c], bias=None)
# Double number of output channels for positive and negative norm per
# channel.
norm_shape[1] = norm_shape[1] // 2
new_norm = torch.zeros(norm_shape).to(self.device)
new_norm[:, :out_c] = norm[:, :out_c] + norm[:, out_c:2 * out_c]
new_norm[:, out_c:] = norm[:, 2 * out_c:3 * out_c] + norm[:, 3 * out_c:]
norm = new_norm
# Some 'rare' neurons receive either only positive or only negative
# inputs. Conservation of relevance does not hold if we also rescale
# those neurons by (1+beta) or -beta. Therefore, catch those first and
# scale the norm by the corresponding value, so that it cancels in the
# fraction. First, however, avoid NaNs.
mask = norm == 0
# Set the norm to anything non-zero, e.g. 1.
# The actual inputs are zero at this point anyways, that
# is why norm is zero in the first place.
norm[mask] = 1
# The norm in the b-rule has shape (N, 2*out_c, *spatial_dims).
# The first out_c block corresponds to the positive norms,
# the second out_c block corresponds to the negative norms.
# We find the rare neurons by choosing those nodes per channel
# in which either the positive norm ([:, :out_c]) is zero or
# the negative norm ([:, out_c:]) is zero.
rare_neurons = (mask[:, :out_c] + mask[:, out_c:])
# Also, catch new possibilities for norm == 0 to avoid NaNs.
# The actual value of norm again does not really matter, since
# the pre-factor will be zero in this case.
norm[:, :out_c][rare_neurons] *= 1 if self.beta == -1 else 1 + self.beta
norm[:, out_c:][rare_neurons] *= 1 if self.beta == 0 else -self.beta
# Add stabilizer term to norm to avoid numerical instabilities.
norm += self.eps * torch.sign(norm)
input_relevance = relevance_in.squeeze(dim=-1).repeat(1, 4)
input_relevance[:, :2*out_c] *= (1+self.beta)/norm[:, :out_c].repeat(1, 2)
input_relevance[:, 2*out_c:] *= -self.beta/norm[:, out_c:].repeat(1, 2)
inv_w = w.t()
relevance_out = torch.zeros_like(m.in_tensor)
for i in range(4):
relevance_out[:, i*in_c:(i+1)*in_c] = F.linear(
input_relevance[:, i*out_c:(i+1)*out_c],
weight=inv_w[:, i*out_c:(i+1)*out_c], bias=None)
relevance_out *= m.in_tensor
sum_weights = torch.zeros([in_c, in_c * 4, 1]).to(self.device)
for i in range(in_c):
sum_weights[i, i::in_c] = 1
relevance_out = F.conv1d(relevance_out[:, :, None], weight=sum_weights, bias=None)
del sum_weights, input_relevance, norm, rare_neurons, \
mask, new_norm, m.in_tensor, w, inv_w
return relevance_out
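# The e-rule in formulas (sketch, not part of the original file): with
# z_ji = x_i^p * w_ji^p and z_j = sum_i z_ji + eps * sign(z_j), the incoming
# relevance R_j is redistributed to the inputs as R_i = sum_j z_ji * R_j / z_j,
# which is exactly the pair of F.linear calls in the e-rule branch above.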
@module_tracker
def linear_fwd_hook(self, m, in_tensor: torch.Tensor,
out_tensor: torch.Tensor):
setattr(m, "in_tensor", in_tensor[0])
setattr(m, "out_shape", list(out_tensor.size()))
return
def max_pool_nd_inverse(self, layer_instance, relevance_in):
# In case the output had been reshaped for a linear layer,
# make sure the relevance is put into the same shape as before.
relevance_in = relevance_in.view(layer_instance.out_shape)
invert_pool = self.get_inv_max_pool_method(layer_instance)
inverted = invert_pool(relevance_in, layer_instance.indices,
layer_instance.kernel_size, layer_instance.stride,
layer_instance.padding, output_size=layer_instance.in_shape)
del layer_instance.indices
return inverted
@module_tracker
def max_pool_nd_fwd_hook(self, m, in_tensor: torch.Tensor,
out_tensor: torch.Tensor):
# Ignore unused for pylint
_ = self
# Save the return_indices setting so it can be restored afterwards
tmp_return_indices = bool(m.return_indices)
m.return_indices = True
_, indices = m.forward(in_tensor[0])
m.return_indices = tmp_return_indices
setattr(m, "indices", indices)
setattr(m, 'out_shape', out_tensor.size())
setattr(m, 'in_shape', in_tensor[0].size())
def conv_nd_inverse(self, m, relevance_in):
# In case the output had been reshaped for a linear layer,
# make sure the relevance is put into the same shape as before.
relevance_in = relevance_in.view(m.out_shape)
# Get required values from layer
inv_conv_nd = self.get_inv_conv_method(m)
conv_nd = self.get_conv_method(m)
if self.method == "e-rule":
with torch.no_grad():
m.in_tensor = m.in_tensor.pow(self.p).detach()
w = m.weight.pow(self.p).detach()
norm = conv_nd(m.in_tensor, weight=w, bias=None,
stride=m.stride, padding=m.padding,
groups=m.groups)
norm = norm + torch.sign(norm) * self.eps
relevance_in[norm == 0] = 0
norm[norm == 0] = 1
relevance_out = inv_conv_nd(relevance_in/norm,
weight=w, bias=None,
padding=m.padding, stride=m.stride,
groups=m.groups)
relevance_out *= m.in_tensor
del m.in_tensor, norm, w
return relevance_out
if self.method == "b-rule":
with torch.no_grad():
w = m.weight
out_c, in_c = m.out_channels, m.in_channels
repeats = np.array(np.ones_like(w.size()).flatten(), dtype=int)
repeats[0] *= 4
w = w.repeat(tuple(repeats))
# The first and third channel repetitions contain only the positive weights
w[:out_c][w[:out_c] < 0] = 0
w[2 * out_c:3 * out_c][w[2 * out_c:3 * out_c] < 0] = 0
# The second and fourth channel repetitions contain only the negative weights
w[1 * out_c:2 * out_c][w[1 * out_c:2 * out_c] > 0] = 0
w[-out_c:][w[-out_c:] > 0] = 0
repeats = np.array(np.ones_like(m.in_tensor.size()).flatten(), dtype=int)
repeats[1] *= 4
# Repeat across channel dimension (pytorch always has channels first)
m.in_tensor = m.in_tensor.repeat(tuple(repeats))
m.in_tensor[:, :in_c][m.in_tensor[:, :in_c] < 0] = 0
m.in_tensor[:, -in_c:][m.in_tensor[:, -in_c:] < 0] = 0
m.in_tensor[:, 1 * in_c:3 * in_c][m.in_tensor[:, 1 * in_c:3 * in_c] > 0] = 0
groups = 4
# Normalize such that, for each output neuron j, the importance values of
# the contributing input neurons sum to 1 after division by the norm
# (v_ij in the paper). The norm simply sums the importance values of the
# inputs contributing to each output j; dividing by it splits the
# relevance of j properly amongst its roots.
norm = conv_nd(m.in_tensor, weight=w, bias=None, stride=m.stride,
padding=m.padding, dilation=m.dilation, groups=groups * m.groups)
# Double number of output channels for positive and negative norm per
# channel. Using list with out_tensor.size() allows for ND generalization
new_shape = list(m.out_shape)  # copy so the stored shape is not mutated
new_shape[1] *= 2
new_norm = torch.zeros(new_shape).to(self.device)
new_norm[:, :out_c] = norm[:, :out_c] + norm[:, out_c:2 * out_c]
new_norm[:, out_c:] = norm[:, 2 * out_c:3 * out_c] + norm[:, 3 * out_c:]
norm = new_norm
# Some 'rare' neurons receive either only positive or only negative
# inputs. Conservation of relevance does not hold if we also rescale
# those neurons by (1+beta) or -beta. Therefore, catch those first and
# scale the norm by the corresponding value, so that it cancels in the
# fraction. First, however, avoid NaNs.
mask = norm == 0
# Set the norm to anything non-zero, e.g. 1.
# The actual inputs are zero at this point anyways, that
# is why norm is zero in the first place.
norm[mask] = 1
# The norm in the b-rule has shape (N, 2*out_c, *spatial_dims).
# The first out_c block corresponds to the positive norms,
# the second out_c block corresponds to the negative norms.
# We find the rare neurons by choosing those nodes per channel
# in which either the positive norm ([:, :out_c]) is zero or
# the negative norm ([:, out_c:]) is zero.
rare_neurons = (mask[:, :out_c] + mask[:, out_c:])
# Also, catch new possibilities for norm == 0 to avoid NaNs.
# The actual value of norm again does not really matter, since
# the pre-factor will be zero in this case.
norm[:, :out_c][rare_neurons] *= 1 if self.beta == -1 else 1 + self.beta
norm[:, out_c:][rare_neurons] *= 1 if self.beta == 0 else -self.beta
# Add stabilizer term to norm to avoid numerical instabilities.
norm += self.eps * torch.sign(norm)
spatial_dims = [1] * len(relevance_in.size()[2:])
input_relevance = relevance_in.repeat(1, 4, *spatial_dims)
input_relevance[:, :2*out_c] *= (1+self.beta)/norm[:, :out_c].repeat(1, 2, *spatial_dims)
input_relevance[:, 2*out_c:] *= -self.beta/norm[:, out_c:].repeat(1, 2, *spatial_dims)
# Each of the positive / negative entries needs its own
# convolution. TODO: Can this be done in groups, too?
relevance_out = torch.zeros_like(m.in_tensor)
# The transposed convolution can come out smaller than the stored input
# size when stride > 1; pad its result back up to the expected shape.
tmp_result = result = None
for i in range(4):
tmp_result = inv_conv_nd(
input_relevance[:, i*out_c:(i+1)*out_c],
weight=w[i*out_c:(i+1)*out_c],
bias=None, padding=m.padding, stride=m.stride,
groups=m.groups)
result = torch.zeros_like(relevance_out[:, i*in_c:(i+1)*in_c])
tmp_size = tmp_result.size()
slice_list = tuple(slice(0, l) for l in tmp_size)
result[slice_list] += tmp_result
relevance_out[:, i*in_c:(i+1)*in_c] = result
relevance_out *= m.in_tensor
sum_weights = torch.zeros([in_c, in_c * 4, *spatial_dims]).to(self.device)
for i in range(m.in_channels):
sum_weights[i, i::in_c] = 1
relevance_out = conv_nd(relevance_out, weight=sum_weights, bias=None)
del sum_weights, m.in_tensor, result, mask, rare_neurons, norm, \
new_norm, input_relevance, tmp_result, w
return relevance_out
@module_tracker
def conv_nd_fwd_hook(self, m, in_tensor: torch.Tensor,
out_tensor: torch.Tensor):
setattr(m, "in_tensor", in_tensor[0])
setattr(m, 'out_shape', list(out_tensor.size()))
return
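# Usage sketch (hypothetical setup, not part of the original file): the hooks
# above are attached once per leaf module before the forward pass, e.g.
#
#   for layer in model.modules():
#       if not list(layer.children()):   # leaf modules only
#           layer.register_forward_hook(explainer.get_layer_fwd_hook(layer))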

@@ -0,0 +1,179 @@
"""
Contains abstract functionality for learning a locally linear sparse model.
"""
from __future__ import print_function
import numpy as np
from sklearn.linear_model import Ridge, lars_path
from sklearn.utils import check_random_state
class LimeBase(object):
"""Class for learning a locally linear sparse model from perturbed data"""
def __init__(self,
kernel_fn,
verbose=False,
random_state=None):
"""Init function
Args:
kernel_fn: function that transforms an array of distances into an
array of proximity values (floats).
verbose: if true, print local prediction values from linear model.
random_state: an integer or numpy.RandomState that will be used to
generate random numbers. If None, the random state will be
initialized using the internal numpy seed.
"""
self.kernel_fn = kernel_fn
self.verbose = verbose
self.random_state = check_random_state(random_state)
@staticmethod
def generate_lars_path(weighted_data, weighted_labels):
"""Generates the lars path for weighted data.
Args:
weighted_data: data that has been weighted by kernel
weighted_labels: labels, weighted by kernel
Returns:
(alphas, coefs), both are arrays corresponding to the
regularization parameter and coefficients, respectively
"""
x_vector = weighted_data
alphas, _, coefs = lars_path(x_vector,
weighted_labels,
method='lasso',
verbose=False)
return alphas, coefs
def forward_selection(self, data, labels, weights, num_features):
"""Iteratively adds features to the model"""
clf = Ridge(alpha=0, fit_intercept=True, random_state=self.random_state)
used_features = []
for _ in range(min(num_features, data.shape[1])):
max_ = -100000000
best = 0
for feature in range(data.shape[1]):
if feature in used_features:
continue
clf.fit(data[:, used_features + [feature]], labels,
sample_weight=weights)
score = clf.score(data[:, used_features + [feature]],
labels,
sample_weight=weights)
if score > max_:
best = feature
max_ = score
used_features.append(best)
return np.array(used_features)
def feature_selection(self, data, labels, weights, num_features, method):
"""Selects features for the model. see explain_instance_with_data to
understand the parameters."""
if method == 'none':
return np.array(range(data.shape[1]))
elif method == 'forward_selection':
return self.forward_selection(data, labels, weights, num_features)
elif method == 'highest_weights':
clf = Ridge(alpha=0, fit_intercept=True,
random_state=self.random_state)
clf.fit(data, labels, sample_weight=weights)
feature_weights = sorted(zip(range(data.shape[1]),
clf.coef_ * data[0]),
key=lambda x: np.abs(x[1]),
reverse=True)
return np.array([x[0] for x in feature_weights[:num_features]])
elif method == 'lasso_path':
weighted_data = ((data - np.average(data, axis=0, weights=weights))
* np.sqrt(weights[:, np.newaxis]))
weighted_labels = ((labels - np.average(labels, weights=weights))
* np.sqrt(weights))
nonzero = range(weighted_data.shape[1])
_, coefs = self.generate_lars_path(weighted_data,
weighted_labels)
for i in range(len(coefs.T) - 1, 0, -1):
nonzero = coefs.T[i].nonzero()[0]
if len(nonzero) <= num_features:
break
used_features = nonzero
return used_features
elif method == 'auto':
if num_features <= 6:
n_method = 'forward_selection'
else:
n_method = 'highest_weights'
return self.feature_selection(data, labels, weights,
num_features, n_method)
def explain_instance_with_data(self,
neighborhood_data,
neighborhood_labels,
distances,
label,
num_features,
feature_selection='auto',
model_regressor=None):
"""Takes perturbed data, labels and distances, returns explanation.
Args:
neighborhood_data: perturbed data, 2d array. first element is
assumed to be the original data point.
neighborhood_labels: corresponding perturbed labels. should have as
many columns as the number of possible labels.
distances: distances to original data point.
label: label for which we want an explanation
num_features: maximum number of features in explanation
feature_selection: how to select num_features. options are:
'forward_selection': iteratively add features to the model.
This is costly when num_features is high
'highest_weights': selects the features that have the highest
product of absolute weight * original data point when
learning with all the features
'lasso_path': chooses features based on the lasso
regularization path
'none': uses all features, ignores num_features
'auto': uses forward_selection if num_features <= 6, and
'highest_weights' otherwise.
model_regressor: sklearn regressor to use in explanation.
Defaults to Ridge regression if None. Must have
model_regressor.coef_ and 'sample_weight' as a parameter
to model_regressor.fit()
Returns:
(intercept, exp, score, local_pred):
intercept is a float.
exp is a sorted list of tuples, where each tuple (x,y) corresponds
to the feature id (x) and the local weight (y). The list is sorted
by decreasing absolute value of y.
score is the R^2 value of the returned explanation
local_pred is the prediction of the explanation model on the original instance
"""
weights = self.kernel_fn(distances)
labels_column = neighborhood_labels[:, label]
used_features = self.feature_selection(neighborhood_data,
labels_column,
weights,
num_features,
feature_selection)
if model_regressor is None:
model_regressor = Ridge(alpha=1, fit_intercept=True,
random_state=self.random_state)
easy_model = model_regressor
easy_model.fit(neighborhood_data[:, used_features],
labels_column, sample_weight=weights)
prediction_score = easy_model.score(
neighborhood_data[:, used_features],
labels_column, sample_weight=weights)
local_pred = easy_model.predict(neighborhood_data[0, used_features].reshape(1, -1))
if self.verbose:
print('Intercept', easy_model.intercept_)
print('Prediction_local', local_pred,)
print('Right:', neighborhood_labels[0, label])
return (easy_model.intercept_,
sorted(zip(used_features, easy_model.coef_),
key=lambda x: np.abs(x[1]), reverse=True),
prediction_score, local_pred)
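if __name__ == '__main__':
    # Illustrative smoke test, not part of the original file: fit LimeBase on
    # synthetic binary neighborhood data with an exponential kernel
    # (the kernel width of 5 is an arbitrary choice).
    rng = np.random.RandomState(0)
    data = rng.randint(0, 2, size=(100, 10)).astype(float)     # perturbed samples
    labels = np.column_stack([1.0 - data[:, 0], data[:, 0]])   # fake 2-class probabilities
    distances = np.sqrt(((data - data[0]) ** 2).sum(axis=1))   # euclidean to sample 0
    base = LimeBase(kernel_fn=lambda d: np.sqrt(np.exp(-(d ** 2) / 5.0 ** 2)))
    intercept, exp, score, local_pred = base.explain_instance_with_data(
        data, labels, distances, label=1, num_features=3)
    print('R^2:', score, 'top features:', exp)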

@@ -0,0 +1,261 @@
"""
Functions for explaining classifiers that use Image data.
"""
import copy
from functools import partial
import numpy as np
import sklearn
import sklearn.preprocessing
from sklearn.utils import check_random_state
from skimage.color import gray2rgb
from . import lime_base
from .wrappers.scikit_image import SegmentationAlgorithm
class ImageExplanation(object):
def __init__(self, image, segments):
"""Init function.
Args:
image: 3d numpy array
segments: 2d numpy array, with the output from skimage.segmentation
"""
self.image = image
self.segments = segments
self.intercept = {}
self.local_exp = {}
self.local_pred = None
def get_image_and_mask(self, label, positive_only=True, hide_rest=False,
num_features=5, min_weight=0.):
"""Init function.
Args:
label: label to explain
positive_only: if True, only take superpixels that contribute to
the prediction of the label. Otherwise, use the top
num_features superpixels, which can be positive or negative
towards the label
hide_rest: if True, make the non-explanation part of the return
image gray
num_features: number of superpixels to include in explanation
min_weight: minimum absolute weight a superpixel must have to be included
Returns:
(image, mask), where image is a 3d numpy array and mask is a 2d
numpy array that can be used with
skimage.segmentation.mark_boundaries
"""
if label not in self.local_exp:
raise KeyError('Label not in explanation')
segments = self.segments
image = self.image
exp = self.local_exp[label]
mask = np.zeros(segments.shape, segments.dtype)
if hide_rest:
temp = np.zeros(self.image.shape)
else:
temp = self.image.copy()
if positive_only:
fs = [x[0] for x in exp
if x[1] > 0 and x[1] > min_weight][:num_features]
for f in fs:
temp[segments == f] = image[segments == f].copy()
mask[segments == f] = 1
return temp, mask
else:
for f, w in exp[:num_features]:
if np.abs(w) < min_weight:
continue
c = 0 if w < 0 else 1
mask[segments == f] = 1 if w < 0 else 2
temp[segments == f] = image[segments == f].copy()
temp[segments == f, c] = np.max(image)
for cp in [0, 1, 2]:
if c == cp:
continue
# temp[segments == f, cp] *= 0.5
return temp, mask
class LimeImageExplainer(object):
"""Explains predictions on Image (i.e. matrix) data.
For numerical features, perturb them by sampling from a Normal(0,1) and
doing the inverse operation of mean-centering and scaling, according to the
means and stds in the training data. For categorical features, perturb by
sampling according to the training distribution, and making a binary
feature that is 1 when the value is the same as the instance being
explained."""
def __init__(self, kernel_width=.25, kernel=None, verbose=False,
feature_selection='auto', random_state=None):
"""Init function.
Args:
kernel_width: kernel width for the exponential kernel (default 0.25).
kernel: similarity kernel that takes euclidean distances and kernel
width as input and outputs weights in (0,1). If None, defaults to
an exponential kernel.
verbose: if true, print local prediction values from linear model
feature_selection: feature selection method. can be
'forward_selection', 'lasso_path', 'none' or 'auto'.
See function 'explain_instance_with_data' in lime_base.py for
details on what each of the options does.
random_state: an integer or numpy.RandomState that will be used to
generate random numbers. If None, the random state will be
initialized using the internal numpy seed.
"""
kernel_width = float(kernel_width)
if kernel is None:
def kernel(d, kernel_width):
return np.sqrt(np.exp(-(d ** 2) / kernel_width ** 2))
kernel_fn = partial(kernel, kernel_width=kernel_width)
self.random_state = check_random_state(random_state)
self.feature_selection = feature_selection
self.base = lime_base.LimeBase(kernel_fn, verbose, random_state=self.random_state)
def explain_instance(self, image, classifier_fn, labels=(1,),
hide_color=None,
top_labels=5, num_features=100000, num_samples=1000,
batch_size=10,
segmentation_fn=None,
distance_metric='cosine',
model_regressor=None,
random_seed=None):
"""Generates explanations for a prediction.
First, we generate neighborhood data by randomly perturbing features
from the instance (see __data_inverse). We then learn locally weighted
linear models on this neighborhood data to explain each of the classes
in an interpretable way (see lime_base.py).
Args:
image: 3 dimension RGB image. If this is only two dimensional,
we will assume it's a grayscale image and call gray2rgb.
classifier_fn: classifier prediction probability function, which
takes a numpy array and outputs prediction probabilities. For
ScikitClassifiers, this is classifier.predict_proba.
labels: iterable with labels to be explained.
hide_color: value used for superpixels that are turned off; if None,
each superpixel is replaced by its mean color.
top_labels: if not None, ignore labels and produce explanations for
the K labels with highest prediction probabilities, where K is
this parameter.
num_features: maximum number of features present in explanation
num_samples: size of the neighborhood to learn the linear model
batch_size: classifier_fn will be called on batches of this size.
distance_metric: the distance metric to use for weights.
model_regressor: sklearn regressor to use in explanation. Defaults
to Ridge regression in LimeBase. Must have model_regressor.coef_
and 'sample_weight' as a parameter to model_regressor.fit()
segmentation_fn: SegmentationAlgorithm, wrapped skimage
segmentation function
random_seed: integer used as random seed for the segmentation
algorithm. If None, a random integer, between 0 and 1000,
will be generated using the internal random number generator.
Returns:
An Explanation object (see explanation.py) with the corresponding
explanations.
"""
if len(image.shape) == 2:
image = gray2rgb(image)
if random_seed is None:
random_seed = self.random_state.randint(0, high=1000)
if segmentation_fn is None:
segmentation_fn = SegmentationAlgorithm('quickshift', kernel_size=4,
max_dist=200, ratio=0.2,
random_seed=random_seed)
segments = segmentation_fn(image)
fudged_image = image.copy()
if hide_color is None:
for x in np.unique(segments):
fudged_image[segments == x] = (
np.mean(image[segments == x][:, 0]),
np.mean(image[segments == x][:, 1]),
np.mean(image[segments == x][:, 2]))
else:
fudged_image[:] = hide_color
top = labels
data, labels = self.data_labels(image, fudged_image, segments,
classifier_fn, num_samples,
batch_size=batch_size)
distances = sklearn.metrics.pairwise_distances(
data,
data[0].reshape(1, -1),
metric=distance_metric
).ravel()
ret_exp = ImageExplanation(image, segments)
if top_labels:
top = np.argsort(labels[0])[-top_labels:]
ret_exp.top_labels = list(top)
ret_exp.top_labels.reverse()
for label in top:
(ret_exp.intercept[label],
ret_exp.local_exp[label],
ret_exp.score, ret_exp.local_pred) = self.base.explain_instance_with_data(
data, labels, distances, label, num_features,
model_regressor=model_regressor,
feature_selection=self.feature_selection)
return ret_exp
def data_labels(self,
image,
fudged_image,
segments,
classifier_fn,
num_samples,
batch_size=10):
"""Generates images and predictions in the neighborhood of this image.
Args:
image: 3d numpy array, the image
fudged_image: 3d numpy array, image to replace original image when
superpixel is turned off
segments: segmentation of the image
classifier_fn: function that takes a list of images and returns a
matrix of prediction probabilities
num_samples: size of the neighborhood to learn the linear model
batch_size: classifier_fn will be called on batches of this size.
Returns:
A tuple (data, labels), where:
data: dense num_samples * num_superpixels
labels: prediction probabilities matrix
"""
n_features = np.unique(segments).shape[0]
data = self.random_state.randint(0, 2, num_samples * n_features)\
.reshape((num_samples, n_features))
labels = []
data[0, :] = 1
imgs = []
for row in data:
temp = copy.deepcopy(image)
zeros = np.where(row == 0)[0]
mask = np.zeros(segments.shape).astype(bool)
for z in zeros:
mask[segments == z] = True
temp[mask] = fudged_image[mask]
imgs.append(temp)
if len(imgs) == batch_size:
preds = classifier_fn(np.array(imgs))
labels.extend(preds)
imgs = []
if len(imgs) > 0:
preds = classifier_fn(np.array(imgs))
labels.extend(preds)
return data, np.array(labels)
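if __name__ == '__main__':
    # Illustrative sketch, not part of the original file (run as a module so the
    # relative imports resolve): a dummy classifier scores each perturbed image
    # by the mean intensity of its first channel.
    def dummy_classifier(images):
        score = images[:, :, :, 0].mean(axis=(1, 2))
        return np.column_stack([1.0 - score, score])

    rng = np.random.RandomState(0)
    img = rng.uniform(size=(32, 32, 3))                  # random RGB image in [0, 1]
    explainer = LimeImageExplainer()
    explanation = explainer.explain_instance(img, dummy_classifier,
                                             top_labels=2, num_samples=50)
    temp, mask = explanation.get_image_and_mask(explanation.top_labels[0])
    print('mask covers', int(mask.sum()), 'pixels')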

@@ -0,0 +1,117 @@
import types
from lime.utils.generic_utils import has_arg
from skimage.segmentation import felzenszwalb, slic, quickshift
class BaseWrapper(object):
"""Base class for LIME Scikit-Image wrapper
Args:
target_fn: callable function or class instance
target_params: dict, parameters to pass to the target_fn
'target_params' takes the parameters required to instantiate the
desired Scikit-Image class/model
"""
def __init__(self, target_fn=None, **target_params):
self.target_fn = target_fn
self.target_params = target_params
def _check_params(self, parameters):
"""Checks for mistakes in 'parameters'
Args :
parameters: dict, parameters to be checked
Raises :
ValueError: if any parameter is not a valid argument for the target function
or the target function is not defined
TypeError: if argument parameters is not iterable
"""
a_valid_fn = []
if self.target_fn is None:
if callable(self):
a_valid_fn.append(self.__call__)
else:
raise TypeError('invalid argument: tested object is not callable,\
please provide a valid target_fn')
elif isinstance(self.target_fn, types.FunctionType) \
or isinstance(self.target_fn, types.MethodType):
a_valid_fn.append(self.target_fn)
else:
a_valid_fn.append(self.target_fn.__call__)
if not isinstance(parameters, str):
for p in parameters:
for fn in a_valid_fn:
if has_arg(fn, p):
pass
else:
raise ValueError('{} is not a valid parameter'.format(p))
else:
raise TypeError('invalid argument: list or dictionary expected')
def set_params(self, **params):
"""Sets the parameters of this estimator.
Args:
**params: Dictionary of parameter names mapped to their values.
Raises :
ValueError: if any parameter is not a valid argument
for the target function
"""
self._check_params(params)
self.target_params = params
def filter_params(self, fn, override=None):
"""Filters `target_params` and return those in `fn`'s arguments.
Args:
fn : arbitrary function
override: dict, values to override target_params
Returns:
result : dict, dictionary containing variables
in both target_params and fn's arguments.
"""
override = override or {}
result = {}
for name, value in self.target_params.items():
if has_arg(fn, name):
result.update({name: value})
result.update(override)
return result
class SegmentationAlgorithm(BaseWrapper):
""" Define the image segmentation function based on Scikit-Image
implementation and a set of provided parameters
Args:
algo_type: string, segmentation algorithm among the following:
'quickshift', 'slic', 'felzenszwalb'
target_params: dict, algorithm parameters (valid model parameters
as defined in the Scikit-Image documentation)
"""
def __init__(self, algo_type, **target_params):
self.algo_type = algo_type
if (self.algo_type == 'quickshift'):
BaseWrapper.__init__(self, quickshift, **target_params)
kwargs = self.filter_params(quickshift)
self.set_params(**kwargs)
elif (self.algo_type == 'felzenszwalb'):
BaseWrapper.__init__(self, felzenszwalb, **target_params)
kwargs = self.filter_params(felzenszwalb)
self.set_params(**kwargs)
elif (self.algo_type == 'slic'):
BaseWrapper.__init__(self, slic, **target_params)
kwargs = self.filter_params(slic)
self.set_params(**kwargs)
else:
raise ValueError('Unknown segmentation algorithm: {}'.format(algo_type))
def __call__(self, *args):
return self.target_fn(args[0], **self.target_params)
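if __name__ == '__main__':
    # Illustrative sketch, not part of the original file: wrap quickshift with a
    # few (arbitrary) parameters and segment a random RGB image.
    import numpy as np
    segmenter = SegmentationAlgorithm('quickshift', kernel_size=4,
                                      max_dist=200, ratio=0.2)
    img = np.random.uniform(size=(64, 64, 3))
    segments = segmenter(img)
    print('found', segments.max() + 1, 'superpixels')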

@@ -0,0 +1,55 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from skimage.segmentation import mark_boundaries
from torchvision import transforms
import torch
from .lime import lime_image
import numpy as np
from .. import imagenet_utils, pytorch_utils, utils
class LimeImageExplainer:
def __init__(self, model, predict_fn):
self.model = model
self.predict_fn = predict_fn
def preprocess_input(self, inp):
return inp
def preprocess_label(self, label):
return label
def explain(self, inp, ind=None, raw_inp=None, top_labels=5, hide_color=0, num_samples=1000,
positive_only=True, num_features=5, hide_rest=True, pixel_val_max=255.0):
explainer = lime_image.LimeImageExplainer()
explanation = explainer.explain_instance(self.preprocess_input(raw_inp), self.predict_fn,
top_labels=top_labels, hide_color=hide_color, num_samples=num_samples)
temp, mask = explanation.get_image_and_mask(self.preprocess_label(ind) or explanation.top_labels[0],
positive_only=positive_only, num_features=num_features, hide_rest=hide_rest)
img = mark_boundaries(temp/pixel_val_max, mask)
img = torch.from_numpy(img)
img = torch.transpose(img, 0, 2)
img = torch.transpose(img, 1, 2)
return img.unsqueeze(0)
class LimeImagenetExplainer(LimeImageExplainer):
def __init__(self, model, predict_fn=None):
super(LimeImagenetExplainer, self).__init__(model, predict_fn or self._imagenet_predict)
def _preprocess_transform(self):
transf = transforms.Compose([
transforms.ToTensor(),
imagenet_utils.get_normalize_transform()
])
return transf
def preprocess_input(self, inp):
return np.array(imagenet_utils.get_resize_transform()(inp))
def preprocess_label(self, label):
return label.item() if label is not None and utils.has_method(label, 'item') else label
def _imagenet_predict(self, images):
probs = imagenet_utils.predict(self.model, images, image_transform=self._preprocess_transform())
return pytorch_utils.tensor2numpy(probs)

@@ -0,0 +1,51 @@
import torch
import numpy as np
from torch.autograd import Variable
from skimage.util import view_as_windows
# modified from https://github.com/marcoancona/DeepExplain/blob/master/deepexplain/tensorflow/methods.py#L291-L342
# note the different dim order in pytorch (NCHW) and tensorflow (NHWC)
class OcclusionExplainer:
def __init__(self, model, window_shape=10, step=1):
self.model = model
self.window_shape = window_shape
self.step = step
def explain(self, inp, ind=None, raw_inp=None):
self.model.eval()
with torch.no_grad():
return OcclusionExplainer._occlusion(inp, self.model, self.window_shape, self.step)
@staticmethod
def _occlusion(inp, model, window_shape, step=None):
if type(window_shape) == int:
window_shape = (window_shape, window_shape, 3)
if step is None:
step = 1
n, c, h, w = inp.data.size()
total_dim = c * h * w
index_matrix = np.arange(total_dim).reshape(h, w, c)
idx_patches = view_as_windows(index_matrix, window_shape, step).reshape(
(-1,) + window_shape)
heatmap = np.zeros((n, h, w, c), dtype=np.float32).reshape((-1), total_dim)
weights = np.zeros_like(heatmap)
inp_data = inp.data.clone()
new_inp = Variable(inp_data)
eval0 = model(new_inp)
pred_id = eval0.max(1)[1].data[0]
for i, p in enumerate(idx_patches):
mask = np.ones((h, w, c)).flatten()
mask[p.flatten()] = 0
th_mask = torch.from_numpy(mask.reshape(1, h, w, c).transpose(0, 3, 1, 2)).float().to(inp_data.device)
masked_xs = Variable(th_mask * inp_data)
delta = (eval0[0, pred_id] - model(masked_xs)[0, pred_id]).data.cpu().numpy()
delta_aggregated = np.sum(delta.reshape(n, -1), -1, keepdims=True)
heatmap[:, p.flatten()] += delta_aggregated
weights[:, p.flatten()] += p.size
attribution = np.reshape(heatmap / (weights + 1e-10), (n, h, w, c)).transpose(0, 3, 1, 2)
return torch.from_numpy(attribution)
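if __name__ == '__main__':
    # Illustrative sketch, not part of the original file: occlusion heatmap for a
    # tiny CNN on a random input; window and step sizes are arbitrary.
    import torch.nn as nn
    model = nn.Sequential(nn.Conv2d(3, 4, 3, padding=1), nn.ReLU(),
                          nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(4, 2))
    explainer = OcclusionExplainer(model, window_shape=4, step=4)
    heatmap = explainer.explain(torch.randn(1, 3, 16, 16))
    print(heatmap.shape)  # torch.Size([1, 3, 16, 16])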

@@ -0,0 +1,112 @@
from .gradcam import GradCAMExplainer
from .backprop import VanillaGradExplainer, GradxInputExplainer, SaliencyExplainer, \
IntegrateGradExplainer, DeconvExplainer, GuidedBackpropExplainer, SmoothGradExplainer
from .deeplift import DeepLIFTRescaleExplainer
from .occlusion import OcclusionExplainer
from .epsilon_lrp import EpsilonLrp
from .lime_image_explainer import LimeImageExplainer, LimeImagenetExplainer
import skimage.transform
import torch
import matplotlib.pyplot as plt
import math
from .. import image_utils
class ImageSaliencyResult:
def __init__(self, raw_image, saliency, title, saliency_alpha=0.4, saliency_cmap='jet'):
self.raw_image, self.saliency, self.title = raw_image, saliency, title
self.saliency_alpha, self.saliency_cmap = saliency_alpha, saliency_cmap
def _get_explainer(explainer_name, model, layer_path=None):
if explainer_name == 'gradcam':
return GradCAMExplainer(model, target_layer_name_keys=layer_path, use_inp=True)
if explainer_name == 'vanilla_grad':
return VanillaGradExplainer(model)
if explainer_name == 'grad_x_input':
return GradxInputExplainer(model)
if explainer_name == 'saliency':
return SaliencyExplainer(model)
if explainer_name == 'integrate_grad':
return IntegrateGradExplainer(model)
if explainer_name == 'deconv':
return DeconvExplainer(model)
if explainer_name == 'guided_backprop':
return GuidedBackpropExplainer(model)
if explainer_name == 'smooth_grad':
return SmoothGradExplainer(model)
if explainer_name == 'deeplift':
return DeepLIFTRescaleExplainer(model)
if explainer_name == 'occlusion':
return OcclusionExplainer(model)
if explainer_name == 'lrp':
return EpsilonLrp(model)
if explainer_name == 'lime_imagenet':
return LimeImagenetExplainer(model)
raise ValueError('Explainer {} is not recognized'.format(explainer_name))
def _get_layer_path(model):
if model.__class__.__name__ == 'VGG':
return ['features', '30'] # pool5
elif model.__class__.__name__ == 'GoogleNet':
return ['pool5']
elif model.__class__.__name__ == 'ResNet':
return ['avgpool'] #layer4
elif model.__class__.__name__ == 'Inception3':
return ['Mixed_7c', 'branch_pool'] # ['conv2d_94'], 'mixed10'
else: #unknown network
return None
def get_saliency(model, raw_input, input, label, method='integrate_grad', layer_path=None):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
input = input.to(device)
if label is not None:
label = label.to(device)
if input.grad is not None:
input.grad.zero_()
if label is not None and label.grad is not None:
label.grad.zero_()
model.eval()
model.zero_grad()
layer_path = layer_path or _get_layer_path(model)
exp = _get_explainer(method, model, layer_path)
saliency = exp.explain(input, label, raw_input)
saliency = saliency.abs().sum(dim=1)[0].squeeze()
saliency -= saliency.min()
saliency /= (saliency.max() + 1e-20)
return saliency.detach().cpu().numpy()
def get_image_saliency_results(model, raw_image, input, label,
methods=['lime_imagenet', 'gradcam', 'smooth_grad',
'guided_backprop', 'deeplift', 'grad_x_input'],
layer_path=None):
results = []
for method in methods:
sal = get_saliency(model, raw_image, input, label, method=method, layer_path=layer_path)
results.append(ImageSaliencyResult(raw_image, sal, method))
return results
def get_image_saliency_plot(image_saliency_results, cols = 2, figsize = None):
rows = math.ceil(len(image_saliency_results) / cols)
figsize=figsize or (8, 3 * rows)
figure = plt.figure(figsize=figsize) #figsize=(8, 3)
for i, r in enumerate(image_saliency_results):
ax = figure.add_subplot(rows, cols, i+1)
ax.set_xticks([])
ax.set_yticks([])
ax.set_title(r.title, fontdict={'fontsize': 24}) #'fontweight': 'light'
#upsampler = nn.Upsample(size=(raw_image.height, raw_image.width), mode='bilinear')
saliency_upsampled = skimage.transform.resize(r.saliency,
(r.raw_image.height, r.raw_image.width))
image_utils.show_image(r.raw_image, img2=saliency_upsampled,
alpha2=r.saliency_alpha, cmap2=r.saliency_cmap, ax=ax)
return figure
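# Usage sketch (hypothetical model and inputs, not part of the original file):
#
#   import torchvision.models as models
#   model = models.resnet50(pretrained=True)
#   results = get_image_saliency_results(model, raw_image, input_tensor, label,
#                                        methods=['gradcam', 'smooth_grad'])
#   figure = get_image_saliency_plot(results)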

tensorwatch/stream.py Normal file
@@ -0,0 +1,52 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import weakref, uuid
from typing import Any
from . import utils
class Stream:
def __init__(self, stream_name:str=None, console_debug:bool=False):
self._subscribers = weakref.WeakSet()
self._subscribed_to = weakref.WeakSet()
self.held_refs = set() # on some rare occasion we might want stream to hold references of other streams
self.closed = False
self.console_debug = console_debug
self.stream_name = stream_name or str(uuid.uuid4()) # useful to use as key and avoid circular references
def subscribe(self, stream:'Stream'): # notify other stream
utils.debug_log('{} added {} as subscription'.format(self.stream_name, stream.stream_name))
stream._subscribers.add(self)
self._subscribed_to.add(stream)
def unsubscribe(self, stream:'Stream'):
utils.debug_log('{} removed {} as subscription'.format(self.stream_name, stream.stream_name))
stream._subscribers.discard(self)
self._subscribed_to.discard(stream)
self.held_refs.discard(stream)
#stream.held_refs.discard(self) # not needed as only subscriber should hold ref
def write(self, val:Any, from_stream:'Stream'=None):
if self.console_debug:
print(self.stream_name, val)
for subscriber in self._subscribers:
subscriber.write(val, from_stream=self)
def load(self, from_stream:'Stream'=None):
for subscribed_to in self._subscribed_to:
subscribed_to.load(from_stream=self)
def close(self):
if not self.closed:
for subscribed_to in self._subscribed_to:
subscribed_to._subscribers.discard(self)
self._subscribed_to.clear()
self.closed = True
def __enter__(self):
return self
def __exit__(self, exception_type, exception_value, traceback):
self.close()
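if __name__ == '__main__':
    # Illustrative sketch, not part of the original file (run with
    # `python -m tensorwatch.stream` so the relative import resolves):
    # values written to a stream are forwarded to all of its subscribers.
    source = Stream(stream_name='source', console_debug=True)
    sink = Stream(stream_name='sink', console_debug=True)
    sink.subscribe(source)   # sink now receives everything written to source
    source.write(42)         # prints "source 42" followed by "sink 42"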

@@ -0,0 +1,84 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Dict, Sequence
from .zmq_stream import ZmqStream
from .file_stream import FileStream
from .stream import Stream
from .stream_union import StreamUnion
class StreamFactory:
r"""Allows to create shared stream such as file and ZMQ streams
"""
def __init__(self)->None:
self.closed = None
self._streams:Dict[str, Stream] = None
self._reset()
def _reset(self):
self._streams:Dict[str, Stream] = {}
self.closed = False
def close(self):
if not self.closed:
for stream in self._streams.values():
stream.close()
self._reset()
self.closed = True
def __enter__(self):
return self
def __exit__(self, exception_type, exception_value, traceback):
self.close()
def get_streams(self, stream_types:Sequence[str], for_write:bool=None)->Sequence[Stream]:
streams = [self._create_stream_by_string(stream_type, for_write) for stream_type in stream_types]
return streams
def get_combined_stream(self, stream_types:Sequence[str], for_write:bool=None)->Stream:
streams = [self._create_stream_by_string(stream_type, for_write) for stream_type in stream_types]
if len(streams) == 1:
return streams[0]
else:
# otherwise create a new union of the child streams; a fresh union
# each time is not strictly necessary, but keeps things simple
return StreamUnion(streams, for_write=for_write)
def _create_stream_by_string(self, stream_spec:str, for_write:bool)->Stream:
parts = stream_spec.split(':', 1) if stream_spec is not None else ['']
stream_type = parts[0]
stream_args = parts[1] if len(parts) > 1 else None
if stream_type == 'tcp':
port = int(stream_args or 0)
stream_name = '{}:{}:{}'.format(stream_type, port, for_write)
if stream_name not in self._streams:
self._streams[stream_name] = ZmqStream(for_write=for_write,
port=port, stream_name=stream_name, block_until_connected=False)
# else we already have this stream
return self._streams[stream_name]
if stream_args is None: # file name specified without 'file:' prefix
stream_args = stream_type
stream_type = 'file'
if len(stream_type) == 1: # windows drive letter
stream_type = 'file'
stream_args = stream_spec
if stream_type == 'file':
if stream_args is None:
raise ValueError('File name must be specified for stream type "file"')
stream_name = '{}:{}:{}'.format(stream_type, stream_args, for_write)
if stream_name not in self._streams:
self._streams[stream_name] = FileStream(for_write=for_write,
file_name=stream_args, stream_name=stream_name)
# else we already have this stream
return self._streams[stream_name]
if stream_type == '':
return Stream()
raise ValueError('stream_type "{}" has unknown type'.format(stream_type))
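# Stream specifications accepted by _create_stream_by_string (sketch, not part
# of the original file):
#   'tcp:40859'          -> shared ZmqStream on port 40859
#   'file:my_log.twlog'  -> shared FileStream for my_log.twlog
#   'my_log.twlog'       -> same; the 'file:' prefix is inferred
#   r'c:\logs\my.twlog'  -> a single-letter type is treated as a drive letter
#   ''                   -> plain in-memory Stream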

@@ -0,0 +1,21 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .stream import Stream
from typing import Iterator
class StreamUnion(Stream):
def __init__(self, child_streams:Iterator[Stream], for_write:bool, stream_name:str=None, console_debug:bool=False) -> None:
super(StreamUnion, self).__init__(stream_name=stream_name, console_debug=console_debug)
# save references; child streams go away only when the parent goes away
self.child_streams = child_streams
# when someone does write to us, we write to all our listeners
if for_write:
for child_stream in child_streams:
child_stream.subscribe(self)
else:
# union of all child streams
for child_stream in child_streams:
self.subscribe(child_stream)

tensorwatch/text_vis.py Normal file
@@ -0,0 +1,82 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from . import utils
import pandas as pd
import ipywidgets as widgets
from IPython import get_ipython #, display
from .vis_base import VisBase
class TextVis(VisBase):
def __init__(self, cell:widgets.Box=None, title:str=None, show_legend:bool=None,
stream_name:str=None, console_debug:bool=False, **vis_args):
super(TextVis, self).__init__(widgets.HTML(), cell, title, show_legend,
stream_name=stream_name, console_debug=console_debug, **vis_args)
self.df = pd.DataFrame([])
def _get_column_prefix(self, stream_vis, i):
return '[S.{}]:{}'.format(stream_vis.index, i)
def _get_title(self, stream_vis):
title = stream_vis.title or 'Stream ' + str(len(self._stream_vises))
return title
# this will be called from _show_stream_items
def _append(self, stream_vis, vals):
if vals is None:
self.df = self.df.append(pd.Series({self._get_column_prefix(stream_vis, 0) : None}),
sort=False, ignore_index=True)
return
for val in vals:
if val is None or utils.is_scalar(val):
self.df = self.df.append(pd.Series({self._get_column_prefix(stream_vis, 0) : val}),
sort=False, ignore_index=True)
elif utils.is_array_like(val):
val_dict = {}
for i,val_i in enumerate(val):
val_dict[self._get_column_prefix(stream_vis, i)] = val_i
self.df = self.df.append(pd.Series(val_dict), sort=False, ignore_index=True)
else:
self.df = self.df.append(pd.Series(val.__dict__), sort=False, ignore_index=True)
def _post_add_subscription(self, stream_vis, **stream_vis_args):
only_summary = stream_vis_args.get('only_summary', False)
stream_vis.text = self._get_title(stream_vis)
stream_vis.only_summary = only_summary
def clear_plot(self, stream_vis, clear_history):
self.df = self.df.iloc[0:0]
def _show_stream_items(self, stream_vis, stream_items):
for stream_item in stream_items:
if stream_item.ended:
self.df = self.df.append(pd.Series({'Ended':True}),
sort=False, ignore_index=True)
else:
vals = self._extract_vals((stream_item,))
self._append(stream_vis, vals)
return True
def _post_update_stream_plot(self, stream_vis):
if get_ipython():
if not stream_vis.only_summary:
self.widget.value = self.df.to_html(classes=['output_html', 'rendered_html'])
else:
self.widget.value = self.df.describe().to_html(classes=['output_html', 'rendered_html'])
# below doesn't work because of threading issue
#self.widget.clear_output(wait=True)
#with self.widget:
# display.display(self.df)
else:
last_recs = self.df.iloc[[-1]].to_dict('records')
if len(last_recs) == 1:
print(last_recs[0])
else:
print(last_recs)
def _show_widget_native(self, blocking:bool):
return None # we will be using console
def _show_widget_notebook(self):
return self.widget

tensorwatch/utils.py Normal file
@@ -0,0 +1,352 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import numpy as np #pip install numpy
import math
import time
import sys
import os
import inspect
import re
import uuid
from collections import abc
import textwrap
from functools import wraps
import gc
import timeit
def MeasureTime(f):
@wraps(f)
def _wrapper(*args, **kwargs):
gcold = gc.isenabled()
gc.disable()
start_time = timeit.default_timer()
try:
result = f(*args, **kwargs)
finally:
elapsed = timeit.default_timer() - start_time
if gcold:
gc.enable()
print('Function "{}": {}s'.format(f.__name__, elapsed))
return result
return _wrapper
class MeasureBlockTime:
def __init__(self, name="(block)", no_print = False, disable_gc = True, format_str=":.2f"):
self.name = name
self.no_print = no_print
self.disable_gc = disable_gc
self.format_str = format_str
self.gcold = None
self.start_time = None
self.elapsed = None
def __enter__(self):
if self.disable_gc:
self.gcold = gc.isenabled()
gc.disable()
self.start_time = timeit.default_timer()
return self
def __exit__(self,ty,val,tb):
self.elapsed = timeit.default_timer() - self.start_time
if self.disable_gc and self.gcold:
gc.enable()
if not self.no_print:
print(('{}: {' + self.format_str + '}s').format(self.name, self.elapsed))
return False #re-raise any exceptions
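# Usage sketch, not part of the original file. Both helpers disable the garbage
# collector while timing (by default) so collection pauses don't skew results:
#
#   @MeasureTime
#   def work():
#       return sum(i * i for i in range(10 ** 6))
#
#   with MeasureBlockTime('loop') as t:
#       work()
#   print(t.elapsed)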
def getTime():
return timeit.default_timer()
def getElapsedTime(start_time):
return timeit.default_timer() - start_time
def string_to_uint8_array(bstr):
return np.frombuffer(bstr, np.uint8)
def string_to_float_array(bstr):
return np.frombuffer(bstr, np.float32)
def list_to_2d_float_array(flst, width, height):
return np.reshape(np.asarray(flst, np.float32), (height, width))
def get_pfm_array(response):
return list_to_2d_float_array(response.image_data_float, response.width, response.height)
# returns val unchanged if it is already array-like with the same length as seq, otherwise a list of len(seq) copies of val
def fill_like(val, seq):
l = len(seq)
if is_array_like(val) and len(val) == l:
return val
return [val] * len(seq)
def is_array_like(obj, string_is_array=False, tuple_is_array=True):
result = hasattr(obj, "__len__") and hasattr(obj, '__getitem__')
if result and not string_is_array and isinstance(obj, (str, abc.ByteString)):
result = False
if result and not tuple_is_array and isinstance(obj, tuple):
result = False
return result
def is_scalar(x):
return x is None or np.isscalar(x)
def is_scaler_array(x): #detects (x,y) or [x, y]
if is_array_like(x):
if len(x) > 0:
return len(x) if is_scalar(x[0]) else -1
else:
return 0
else:
return -1
def get_public_fields(obj):
return [attr for attr in dir(obj)
if not (attr.startswith("_")
or inspect.isbuiltin(attr)
or inspect.isfunction(attr)
or inspect.ismethod(attr))]
def set_default(dictionary, key, default_val, replace_none=True):
if key not in dictionary or (replace_none and dictionary[key] is None):
dictionary[key] = default_val
def to_array_like(val):
if is_array_like(val):
return val
return [val]
def to_dict(obj):
return dict([attr, getattr(obj, attr)] for attr in get_public_fields(obj))
def to_str(obj):
return str(to_dict(obj))
def write_file(filename, bstr):
with open(filename, 'wb') as afile:
afile.write(bstr)
def has_method(o, name):
return callable(getattr(o, name, None))
# helper method for converting getOrientation to roll/pitch/yaw
# https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles
def to_eularian_angles(q):
z = q.z_val
y = q.y_val
x = q.x_val
w = q.w_val
ysqr = y * y
# roll (x-axis rotation)
t0 = +2.0 * (w*x + y*z)
t1 = +1.0 - 2.0*(x*x + ysqr)
roll = math.atan2(t0, t1)
# pitch (y-axis rotation)
t2 = +2.0 * (w*y - z*x)
if (t2 > 1.0):
t2 = 1
if (t2 < -1.0):
t2 = -1.0
pitch = math.asin(t2)
# yaw (z-axis rotation)
t3 = +2.0 * (w*z + x*y)
t4 = +1.0 - 2.0 * (ysqr + z*z)
yaw = math.atan2(t3, t4)
return (pitch, roll, yaw)
# TODO: sync with AirSim utils.py
def wait_key(message = ''):
''' Wait for a key press on the console and return it. '''
if message != '':
print (message)
result = None
if os.name == 'nt':
import msvcrt
result = msvcrt.getch()
else:
# pylint: disable=import-error
import termios # pylint: disable=import-error
fd = sys.stdin.fileno()
oldterm = termios.tcgetattr(fd)
newattr = termios.tcgetattr(fd)
newattr[3] = newattr[3] & ~termios.ICANON & ~termios.ECHO
termios.tcsetattr(fd, termios.TCSANOW, newattr)
try:
result = sys.stdin.read(1)
except IOError:
pass
finally:
termios.tcsetattr(fd, termios.TCSAFLUSH, oldterm)
return result
def read_pfm(file):
""" Read a pfm file """
file = open(file, 'rb')
color = None
width = None
height = None
scale = None
endian = None
header = file.readline().rstrip()
header = str(bytes.decode(header, encoding='utf-8'))
if header == 'PF':
color = True
elif header == 'Pf':
color = False
else:
raise Exception('Not a PFM file.')
temp_str = str(bytes.decode(file.readline(), encoding='utf-8'))
dim_match = re.match(r'^(\d+)\s(\d+)\s$', temp_str)
if dim_match:
width, height = map(int, dim_match.groups())
else:
raise Exception('Malformed PFM header.')
scale = float(file.readline().rstrip())
if scale < 0: # little-endian
endian = '<'
scale = -scale
else:
endian = '>' # big-endian
data = np.fromfile(file, endian + 'f')
shape = (height, width, 3) if color else (height, width)
data = np.reshape(data, shape)
# DEY: I don't know why this was there.
#data = np.flipud(data)
file.close()
return data, scale
def write_pfm(file, image, scale=1):
""" Write a pfm file """
file = open(file, 'wb')
color = None
if image.dtype.name != 'float32':
raise Exception('Image dtype must be float32.')
image = np.flipud(image)
if len(image.shape) == 3 and image.shape[2] == 3: # color image
color = True
elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale
color = False
else:
raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')
file.write('PF\n'.encode('utf-8') if color else 'Pf\n'.encode('utf-8'))
temp_str = '%d %d\n' % (image.shape[1], image.shape[0])
file.write(temp_str.encode('utf-8'))
endian = image.dtype.byteorder
if endian == '<' or endian == '=' and sys.byteorder == 'little':
scale = -scale
temp_str = '%f\n' % scale
file.write(temp_str.encode('utf-8'))
image.tofile(file)
def write_png(filename, image):
""" image must be numpy array H X W X channels
"""
import zlib, struct
buf = image.flatten().tobytes()
width = image.shape[1]
height = image.shape[0]
# reverse the vertical line order and add null bytes at the start
width_byte_4 = width * 4
raw_data = b''.join(b'\x00' + buf[span:span + width_byte_4]
for span in range((height - 1) * width_byte_4, -1, - width_byte_4))
def png_pack(png_tag, data):
chunk_head = png_tag + data
return (struct.pack("!I", len(data)) +
chunk_head +
struct.pack("!I", 0xFFFFFFFF & zlib.crc32(chunk_head)))
png_bytes = b''.join([
b'\x89PNG\r\n\x1a\n',
png_pack(b'IHDR', struct.pack("!2I5B", width, height, 8, 6, 0, 0, 0)),
png_pack(b'IDAT', zlib.compress(raw_data, 9)),
png_pack(b'IEND', b'')])
write_file(filename, png_bytes)
def add_windows_ctrl_c():
def handler(a,b=None): # pylint: disable=unused-argument
sys.exit(1)
add_windows_ctrl_c.is_handler_installed = \
vars(add_windows_ctrl_c).setdefault('is_handler_installed',False)
if sys.platform == "win32" and not add_windows_ctrl_c.is_handler_installed:
if sys.stdin is not None and sys.stdin.isatty():
#this is Console based application
import win32api
win32api.SetConsoleCtrlHandler(handler, True)
#else do not install handler for non-console applications
add_windows_ctrl_c.is_handler_installed = True
_utils_debug_verbosity=0
_utils_start_time = time.time()
def set_debug_verbosity(verbosity=0):
global _utils_debug_verbosity # pylint: disable=global-statement
_utils_debug_verbosity = verbosity
def debug_log(msg, param=None, verbosity=3):
global _utils_debug_verbosity # pylint: disable=global-statement
if _utils_debug_verbosity is not None and _utils_debug_verbosity >= verbosity:
print("[Debug][{}]: {} : {} : t={:.2f}".format(verbosity, msg, param, time.time()-_utils_start_time))
def get_uuid(is_hex = False):
return str(uuid.uuid4()) if not is_hex else uuid.uuid4().hex
def is_uuid4(s, is_hex=False):
try:
val = uuid.UUID(s, version=4)
return val.hex == s if is_hex else str(val) == s
except ValueError:
return False
def frange(start, stop=None, step=None, steps=None):
if stop is None:
start, stop = 0, start
if steps is None:
if step is None:
step = 1
steps = int((stop-start)/step)
else:
if step is not None:
raise ValueError("Both step and steps cannot be specified")
step = (stop-start)/steps
for _ in range(steps):
yield start
start += step
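# Sketch, not part of the original file: frange(3) yields 0, 1, 2;
# frange(0, 1, steps=4) yields 0.0, 0.25, 0.5, 0.75.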
def wrap_string(s, chars_per_line=12):
return "\n".join(textwrap.wrap(s, chars_per_line))
def is_eof(f):
s = f.read(1)
if s != b'': # restore position
f.seek(-1, os.SEEK_CUR)
return s == b''

tensorwatch/vis_base.py Normal file
@@ -0,0 +1,174 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import sys, time, threading, queue, functools
from typing import Any
from types import MethodType
from abc import ABCMeta, abstractmethod
from .lv_types import StreamPlot, StreamItem
from . import utils
from .stream import Stream
from IPython import get_ipython, display
import ipywidgets as widgets
class VisBase(Stream, metaclass=ABCMeta):
def __init__(self, widget, cell:widgets.Box, title:str, show_legend:bool, stream_name:str=None, console_debug:bool=False, **vis_args):
super(VisBase, self).__init__(stream_name=stream_name, console_debug=console_debug)
self.lock = threading.Lock()
self._use_hbox = True
utils.set_default(vis_args, 'cell_width', '100%')
self.widget = widget
self.cell = cell or widgets.HBox(layout=widgets.Layout(\
width=vis_args['cell_width'])) if self._use_hbox else None
if self._use_hbox:
self.cell.children += (self.widget,)
self._stream_vises = {}
self.is_shown = cell is not None
self.title = title
self.last_ex = None
self.layout_dirty = False
self.q_last_processed = 0
def subscribe(self, stream:Stream, title=None, clear_after_end=False, clear_after_each=False,
show:bool=False, history_len=1, dim_history=True, opacity=None, **stream_vis_args):
# in this override the base class method is called only after the stream_vis is registered
with self.lock:
self.layout_dirty = True
stream_vis = StreamPlot(stream, title, clear_after_end,
clear_after_each, history_len, dim_history, opacity,
len(self._stream_vises), stream_vis_args, 0)
stream_vis._clear_pending = False
stream_vis._pending_items = queue.Queue()
self._stream_vises[stream.stream_name] = stream_vis
self._post_add_subscription(stream_vis, **stream_vis_args)
super(VisBase, self).subscribe(stream)
if show or (show is None and not self.is_shown):
return self.show()
def show(self, blocking:bool=False):
self.is_shown = True
if get_ipython():
if self._use_hbox:
display.display(self.cell) # this method doesn't need returns
#return self.cell
else:
return self._show_widget_notebook()
else:
return self._show_widget_native(blocking)
def write(self, val:Any, from_stream:'Stream'=None):
# let the base class know about new item, this will notify any subscribers
super(VisBase, self).write(val)
stream_vis:StreamPlot = None
if from_stream:
stream_vis = self._stream_vises.get(from_stream.stream_name, None)
if not stream_vis: # select the first one we have
stream_vis = next(iter(self._stream_vises.values()))
VisBase.write_stream_plot(self, stream_vis, val)
@staticmethod
def write_stream_plot(vis, stream_vis:StreamPlot, stream_item:StreamItem):
with vis.lock: # this could be from separate thread!
#if stream_vis is None:
# utils.debug_log('stream_vis not specified in VisBase.write')
# stream_vis = next(iter(vis._stream_vises.values())) # use first as default
utils.debug_log("Stream received: {}".format(stream_item.stream_name), verbosity=5)
stream_vis._pending_items.put(stream_item)
# if we accumulated enough of pending items then let's process them
if vis._can_update_stream_plots():
vis._update_stream_plots()
def _extract_results(self, stream_vis):
stream_items, clear_current, clear_history = [], False, False
while not stream_vis._pending_items.empty():
stream_item = stream_vis._pending_items.get()
if stream_item.stream_reset:
utils.debug_log("Stream reset", stream_item.stream_name)
stream_items.clear() # no need to process these events
clear_current, clear_history = True, True
else:
# check if there was an exception
if stream_item.exception is not None:
#TODO: need better handling here?
print(stream_item.exception, file=sys.stderr)
raise stream_item.exception
# state management for _clear_pending
# if we need to clear plot before putting in data, do so
if stream_vis._clear_pending:
stream_items.clear()
clear_current = True
stream_vis._clear_pending = False
if stream_vis.clear_after_each or (stream_item.ended and stream_vis.clear_after_end):
stream_vis._clear_pending = True
stream_items.append(stream_item)
return stream_items, clear_current, clear_history
def _extract_vals(self, stream_items):
vals = []
for stream_item in stream_items:
if stream_item.ended or stream_item.value is None:
pass # no values to add
else:
if utils.is_array_like(stream_item.value, tuple_is_array=False):
vals.extend(stream_item.value)
else:
vals.append(stream_item.value)
return vals
@abstractmethod
def clear_plot(self, stream_vis, clear_history):
"""(for derived class) Clears the data in specified plot before new data is redrawn"""
pass
@abstractmethod
def _show_stream_items(self, stream_vis, stream_items):
"""(for derived class) Plot the data in given axes"""
pass
@abstractmethod
def _post_add_subscription(self, stream_vis, **stream_vis_args):
pass
# typically we want to batch up items for performance
def _can_update_stream_plots(self):
return True
@abstractmethod
def _post_update_stream_plot(self, stream_vis):
pass
def _update_stream_plots(self):
with self.lock:
self.q_last_processed = time.time()
for stream_vis in self._stream_vises.values():
stream_items, clear_current, clear_history = self._extract_results(stream_vis)
if clear_current:
self.clear_plot(stream_vis, clear_history)
# if we have something to render
dirty = self._show_stream_items(stream_vis, stream_items)
if dirty:
self._post_update_stream_plot(stream_vis)
stream_vis.last_update = time.time()
@abstractmethod
def _show_widget_native(self, blocking:bool):
pass
@abstractmethod
def _show_widget_notebook(self):
pass
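To make the abstract surface above concrete, here is a minimal, hypothetical subclass; the PrintVis name and its print-based rendering are illustrative only and not part of this commit:

import ipywidgets as widgets
from tensorwatch.vis_base import VisBase

class PrintVis(VisBase):
    """Illustrative only: render stream items by printing them."""
    def __init__(self, cell=None, title=None, **vis_args):
        super(PrintVis, self).__init__(widgets.HTML(), cell, title, False, **vis_args)
    def clear_plot(self, stream_vis, clear_history):
        print('[cleared]')
    def _show_stream_items(self, stream_vis, stream_items):
        vals = self._extract_vals(stream_items)
        for val in vals:
            print(val)
        return len(vals) > 0 # dirty flag: something was rendered
    def _post_add_subscription(self, stream_vis, **stream_vis_args):
        pass
    def _post_update_stream_plot(self, stream_vis):
        pass
    def _show_widget_native(self, blocking):
        pass
    def _show_widget_notebook(self):
        return self.widget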

72
tensorwatch/visualizer.py Normal file
View File

@@ -0,0 +1,72 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .stream import Stream
from .vis_base import VisBase
import ipywidgets as widgets
class Visualizer:
def __init__(self, stream:Stream, vis_type:str=None, host:'Visualizer'=None,
cell:'Visualizer'=None, title:str=None,
clear_after_end=False, clear_after_each=False, history_len=1, dim_history=True, opacity=None,
rows=2, cols=5, img_width=None, img_height=None, img_channels=None,
colormap=None, viz_img_scale=None,
# these image params are for hover on point for t-sne
hover_images=None, hover_image_reshape=None, cell_width:str=None, cell_height:str=None,
only_summary=False, separate_yaxis=True, xtitle=None, ytitle=None, ztitle=None, color=None,
xrange=None, yrange=None, zrange=None, draw_line=True, draw_marker=False,
vis_args={}, stream_vis_args={})->None:
cell = cell._host_base.cell if cell is not None else None
if host:
self._host_base = host._host_base
else:
self._host_base = self._get_vis_base(vis_type, cell, title, hover_images=hover_images, hover_image_reshape=hover_image_reshape,
cell_width=cell_width, cell_height=cell_height,
**vis_args)
self._host_base.subscribe(stream, show=False, clear_after_end=clear_after_end, clear_after_each=clear_after_each,
history_len=history_len, dim_history=dim_history, opacity=opacity,
only_summary=only_summary if 'summary' != vis_type else True,
separate_yaxis=separate_yaxis, xtitle=xtitle, ytitle=ytitle, ztitle=ztitle, color=color,
xrange=xrange, yrange=yrange, zrange=zrange,
draw_line=draw_line if vis_type is not None and 'scatter' in vis_type else True,
draw_marker=draw_marker,
rows=rows, cols=cols, img_width=img_width, img_height=img_height, img_channels=img_channels,
colormap=colormap, viz_img_scale=viz_img_scale,
**stream_vis_args)
stream.load()
def show(self):
return self._host_base.show()
def _get_vis_base(self, vis_type, cell:widgets.Box, title, hover_images=None, hover_image_reshape=None, cell_width=None, cell_height=None, **vis_args)->VisBase:
        if vis_type is None or vis_type in ['text', 'summary']:
            from .text_vis import TextVis
            return TextVis(cell=cell, title=title, **vis_args)
if vis_type in ['plotly-line', 'scatter', 'plotly-scatter',
'line3d', 'scatter3d', 'mesh3d']:
from . import plotly
return plotly.LinePlot(cell=cell, title=title,
is_3d=vis_type in ['line3d', 'scatter3d', 'mesh3d'], **vis_args)
if vis_type in ['image', 'mpl-image']:
from . import mpl
return mpl.ImagePlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, **vis_args)
if vis_type in ['line', 'mpl-line', 'mpl-scatter']:
from . import mpl
return mpl.LinePlot(cell=cell, title=title, **vis_args)
if vis_type in ['tsne', 'embeddings', 'tsne2d', 'embeddings2d']:
from . import plotly
return plotly.EmbeddingsPlot(cell=cell, title=title, is_3d='2d' not in vis_type,
hover_images=hover_images, hover_image_reshape=hover_image_reshape, **vis_args)
else:
raise ValueError('vis_type parameter has an invalid value: "{}"'.format(vis_type))
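End to end, the facade is typically driven like this sketch; the expression, variable names and titles are illustrative (see also the tests later in this commit):

import tensorwatch as tw
from tensorwatch.watcher_base import WatcherBase

watcher = WatcherBase()
stream = watcher.create_stream(expr='lambda vars: (vars.i, vars.loss)')
vis = tw.Visualizer(stream, vis_type='mpl-line', xtitle='Iteration', ytitle='Loss')
vis.show()
for i in range(10):
    watcher.observe(i=i, loss=1.0/(i+1))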

92
tensorwatch/watcher.py Normal file
View File

@@ -0,0 +1,92 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import uuid
from typing import Sequence
from .zmq_wrapper import ZmqWrapper
from .watcher_base import WatcherBase
from .lv_types import CliSrvReqTypes
from .lv_types import DefaultPorts, PublisherTopics, ServerMgmtMsg
from . import utils
import threading, time
class Watcher(WatcherBase):
def __init__(self, filename:str=None, port:int=0, srv_name:str=None):
super(Watcher, self).__init__()
self.port = port
self.filename = filename
# used to detect server restarts
self.srv_name = srv_name or str(uuid.uuid4())
# define vars in __init__
self._clisrv = None
self._zmq_stream_pub = None
self._file = None
self._th = None
self._open_devices()
def _open_devices(self):
if self.port is not None:
self._clisrv = ZmqWrapper.ClientServer(port=DefaultPorts.CliSrv+self.port,
is_server=True, callback=self._clisrv_callback)
# notify existing listeners of our ID
self._zmq_stream_pub = self._stream_factory.get_streams(stream_types=['tcp:'+str(self.port)], for_write=True)[0]
# ZMQ quirk: we must wait a bit after opening port and before sending message
# TODO: can we do better?
self._th = threading.Thread(target=self._send_server_start)
self._th.start()
if self.filename is not None:
self._file = self._stream_factory.get_streams(stream_types=['file:'+self.filename], for_write=True)[0]
def _send_server_start(self):
time.sleep(2)
self._zmq_stream_pub.write(ServerMgmtMsg(event_name=ServerMgmtMsg.EventServerStart,
event_args=self.srv_name), topic=PublisherTopics.ServerMgmt)
def default_devices(self)->Sequence[str]: # overridden
devices = []
if self.port is not None:
devices.append('tcp:' + str(self.port))
if self.filename is not None:
devices.append('file:' + self.filename)
return devices
def close(self):
if not self.closed:
if self._clisrv is not None:
self._clisrv.close()
if self._zmq_stream_pub is not None:
self._zmq_stream_pub.close()
if self._file is not None:
self._file.close()
utils.debug_log("Watcher is closed", verbosity=1)
super(Watcher, self).close()
def _reset(self):
self._clisrv = None
self._zmq_stream_pub = None
self._file = None
self._th = None
utils.debug_log("Watcher reset", verbosity=1)
super(Watcher, self)._reset()
def _clisrv_callback(self, clisrv, clisrv_req): # pylint: disable=unused-argument
utils.debug_log("Received client request", clisrv_req.req_type)
# request = create stream
if clisrv_req.req_type == CliSrvReqTypes.create_stream:
stream_req = clisrv_req.req_data
self.create_stream(stream_name=stream_req.stream_name, devices=stream_req.devices,
event_name=stream_req.event_name, expr=stream_req.expr, throttle=stream_req.throttle,
vis_params=stream_req.vis_params)
return None # ignore return as we can't send back stream obj
elif clisrv_req.req_type == CliSrvReqTypes.del_stream:
stream_name = clisrv_req.req_data
return self.del_stream(stream_name)
else:
raise ValueError('ClientServer request type {} is not recognized'.format(clisrv_req.req_type))
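On the server side (typically the training process), usage might look like this sketch; train_one_batch is a hypothetical stand-in for real work:

import tensorwatch as tw

watcher = tw.Watcher(port=0)            # opens the ClientServer and ZMQ publisher sockets
for step in range(1000):
    loss = train_one_batch()            # hypothetical training step
    watcher.observe('batch', loss=loss) # evaluated lazily against client-created streams
watcher.close()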

218
tensorwatch/watcher_base.py Normal file
View File

@@ -0,0 +1,218 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Dict, Any, Sequence
from .lv_types import EventVars, StreamItem, StreamCreateRequest, VisParams
from .evaler import Evaler
from .stream import Stream
from .stream_factory import StreamFactory
from .filtered_stream import FilteredStream
import uuid
import time
from . import utils
class WatcherBase:
class StreamInfo:
def __init__(self, req:StreamCreateRequest, evaler:Evaler, stream:Stream,
index:int, disabled=False, last_sent:float=None)->None:
r"""Holds togaher stream_req, stream and evaler
"""
self.req, self.evaler, self.stream = req, evaler, stream
self.index, self.disabled, self.last_sent = index, disabled, last_sent
self.item_count = 0 # the creator of each StreamItem is responsible for setting its item number
def __init__(self) -> None:
self.closed = None
self._reset()
def _reset(self):
# for each event, store (stream_name, stream_info)
self._stream_infos:Dict[str, Dict[str, WatcherBase.StreamInfo]] = {}
self._global_vars:Dict[str, Any] = {}
self._stream_count = 0
# factory streams are shared per watcher instance
self._stream_factory = StreamFactory()
# each StreamItem should be stamped by its creator
self.creator_id = str(uuid.uuid4())
self.closed = False
def close(self):
if not self.closed:
# close all the streams
for stream_infos in self._stream_infos.values(): # per event
for stream_info in stream_infos.values():
stream_info.stream.close()
self._stream_factory.close()
self._reset() # clean variables
self.closed = True
def __enter__(self):
return self
def __exit__(self, exception_type, exception_value, traceback):
self.close()
def default_devices(self)->Sequence[str]:
return None
def open_stream(self, stream_name:str=None, devices:Sequence[str]=None,
event_name:str='')->Stream:
r"""Opens stream from specified devices or returns one by name if
it was created before.
"""
# TODO: what if devices were specified AND stream exists in cache?
# create device streams, if any
devices = devices or self.default_devices()
device_streams = None
if devices is not None:
# we open devices in read-only mode
device_streams = self._stream_factory.get_streams(stream_types=devices,
for_write=False)
# if no devices then open stream by name from cache
if device_streams is None:
if stream_name is None:
raise ValueError('Both device and stream_name cannot be None')
# first search by event
stream_infos = self._stream_infos.get(event_name, None)
if stream_infos is None:
raise ValueError('Requested event was not found: ' + event_name)
# then search by stream name
stream_info = stream_infos.get(stream_name, None)
if stream_info is None:
raise ValueError('Requested stream was not found: ' + stream_name)
return stream_info.stream
# if we have device, first create stream and then attach device to it
stream = Stream(stream_name=stream_name)
for device_stream in device_streams:
# each device may have multiple streams so let's filter it
filtered_stream = FilteredStream(source_stream=device_stream,
filter_expr=((lambda stream_item: (stream_item, stream_item.stream_name == stream_name)) \
if stream_name is not None \
else None))
stream.subscribe(filtered_stream)
stream.held_refs.add(filtered_stream) # otherwise filtered stream will be destroyed by gc
return stream
def create_stream(self, stream_name:str=None, devices:Sequence[str]=None, event_name:str='',
expr=None, throttle:float=None, vis_params:VisParams=None)->Stream:
r"""Create stream with or without expression and attach to devices where
it will be written to.
"""
stream_name = stream_name or str(uuid.uuid4())
# we allow a few shortcuts, so rewrite the expression if needed
if expr == '' or expr == 'x':
    expr = 'map(lambda x:x, l)'
elif expr is not None and expr.strip().startswith('lambda '):
    expr = 'map({}, l)'.format(expr)
# else no rewrites
# if no expression specified then we don't create evaler
evaler = Evaler(expr) if expr is not None else None
# get stream infos for this event
stream_infos = self._stream_infos.get(event_name, None)
# if first for this event, create dictionary
if stream_infos is None:
stream_infos = self._stream_infos[event_name] = {}
stream_info = stream_infos.get(stream_name, None)
if not stream_info:
utils.debug_log("Creating stream", stream_name)
stream = Stream(stream_name=stream_name)
devices = devices or self.default_devices()
if devices is not None:
# attached devices are opened in write-only mode
device_streams = self._stream_factory.get_streams(stream_types=devices,
for_write=True)
for device_stream in device_streams:
device_stream.subscribe(stream)
stream_req = StreamCreateRequest(stream_name=stream_name, devices=devices, event_name=event_name,
expr=expr, throttle=throttle, vis_params=vis_params)
stream_info = stream_infos[stream_name] = WatcherBase.StreamInfo(
stream_req, evaler, stream, self._stream_count)
self._stream_count += 1
else:
# TODO: throw error?
utils.debug_log("Stream already exist, not creating again", stream_name)
return stream_info.stream
def set_globals(self, **global_vars):
self._global_vars.update(global_vars)
def observe(self, event_name:str='', **obs_vars) -> None:
# get stream requests for this event
stream_infos = self._stream_infos.get(event_name, {})
# TODO: remove list() call - currently needed because the dictionary
# can change during iteration when multiple clients get started
for stream_info in list(stream_infos.values()):
if stream_info.disabled or stream_info.evaler is None:
continue
# apply throttle
if stream_info.req.throttle is None or stream_info.last_sent is None or \
time.time() - stream_info.last_sent >= stream_info.req.throttle:
stream_info.last_sent = time.time()
events_vars = EventVars(self._global_vars, **obs_vars)
self._eval_write(stream_info, events_vars)
else:
utils.debug_log("Throttled", event_name, verbosity=5)
def _eval_write(self, stream_info:'WatcherBase.StreamInfo', event_vars:EventVars):
eval_return = stream_info.evaler.post(event_vars)
if eval_return.is_valid:
event_name = stream_info.req.event_name
stream_item = StreamItem(stream_info.item_count,
eval_return.result, stream_info.req.stream_name, self.creator_id, stream_info.index,
exception=eval_return.exception)
stream_info.stream.write(stream_item)
stream_info.item_count += 1
utils.debug_log("eval_return sent", event_name, verbosity=5)
else:
utils.debug_log("Invalid eval_return not sent", verbosity=5)
def end_event(self, event_name:str='', disable_streams=False) -> None:
stream_infos = self._stream_infos.get(event_name, {})
for stream_info in stream_infos.values():
if not stream_info.disabled:
self._end_stream_req(stream_info, disable_streams)
def _end_stream_req(self, stream_info:'WatcherBase.StreamInfo', disable_stream:bool):
eval_return = stream_info.evaler.post(ended=True,
continue_thread=not disable_stream)
# TODO: check eval_return.is_valid ?
# event_name = stream_info.req.event_name
if disable_stream:
stream_info.disabled = True
utils.debug_log("{} stream disabled".format(stream_info.req.stream_name), verbosity=1)
stream_item = StreamItem(item_index=stream_info.item_count,
value=eval_return.result, stream_name=stream_info.req.stream_name,
creator_id=self.creator_id, stream_index=stream_info.index,
exception=eval_return.exception, ended=True)
stream_info.stream.write(stream_item)
stream_info.item_count += 1
def del_stream(self, stream_name:str) -> bool:
utils.debug_log("deleting stream", stream_name)
for stream_infos in self._stream_infos.values(): # per event
stream_info = stream_infos.get(stream_name, None)
if stream_info:
stream_info.disabled = True
stream_info.evaler.abort()
return True
#TODO: to enable delete we need to protect iteration in set_vars
#del stream_reqs[stream_info.req.stream_name]
return False
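Note the expression shortcuts in create_stream above: the following three calls produce equivalent streams (a sketch; stream names are illustrative):

w = WatcherBase()
s1 = w.create_stream(stream_name='a', expr='x')                  # rewritten to 'map(lambda x:x, l)'
s2 = w.create_stream(stream_name='b', expr='lambda x: x')        # wrapped as 'map(lambda x: x, l)'
s3 = w.create_stream(stream_name='c', expr='map(lambda x:x, l)') # passed through unchanged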

View File

@@ -0,0 +1,87 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Any, Dict, Sequence, List
from .zmq_wrapper import ZmqWrapper
from .lv_types import CliSrvReqTypes, ClientServerRequest, DefaultPorts
from .lv_types import VisParams, PublisherTopics, ServerMgmtMsg, StreamCreateRequest
from .stream import Stream
from .zmq_mgmt_stream import ZmqMgmtStream
from . import utils
from .watcher_base import WatcherBase
class WatcherClient(WatcherBase):
r"""Extends watcher to add methods so calls for create and delete stream can be sent to server.
"""
def __init__(self, filename:str=None, port:int=0):
super(WatcherClient, self).__init__()
self.port = port
self.filename = filename
# define vars in __init__
self._clisrv = None # client-server socket allows sending create/del stream requests
self._zmq_srvmgmt_sub = None
self._file = None
self._open()
def _reset(self):
self._clisrv = None
self._zmq_srvmgmt_sub = None
self._file = None
utils.debug_log("WatcherClient reset", verbosity=1)
super(WatcherClient, self)._reset()
def _open(self):
if self.port is not None:
self._clisrv = ZmqWrapper.ClientServer(port=DefaultPorts.CliSrv+self.port,
is_server=False)
# create subscription where we will receive server management events
self._zmq_srvmgmt_sub = ZmqMgmtStream(clisrv=self._clisrv, for_write=False, port=self.port,
stream_name='zmq_srvmgmt_sub:'+str(self.port)+':False')
if self.filename is not None:
self._file = self._stream_factory.get_streams(stream_types=['file:'+self.filename], for_write=False)[0]
def close(self):
if not self.closed:
self._zmq_srvmgmt_sub.close()
self._clisrv.close()
utils.debug_log("WatcherClient is closed", verbosity=1)
super(WatcherClient, self).close()
def default_devices(self)->Sequence[str]: # overridden
devices = []
if self.port is not None:
devices.append('tcp:' + str(self.port))
if self.filename is not None:
devices.append('file:' + self.filename)
return devices
# override to send the request to the server instead of handling it locally in WatcherBase
def create_stream(self, stream_name:str=None, devices:Sequence[str]=None, event_name:str='',
expr=None, throttle:float=1, vis_params:VisParams=None)->Stream: # overridden
stream_req = StreamCreateRequest(stream_name=stream_name, devices=devices or self.default_devices(),
event_name=event_name, expr=expr, throttle=throttle, vis_params=vis_params)
self._zmq_srvmgmt_sub.add_stream_req(stream_req)
if stream_req.devices is not None:
stream = self.open_stream(stream_name=stream_req.stream_name,
devices=stream_req.devices, event_name=stream_req.event_name)
else: # we cannot return remote streams that are not backed by a device
stream = None
return stream
# override to set devices default to tcp
def open_stream(self, stream_name:str=None, devices:Sequence[str]=None,
event_name:str='')->Stream: # overridden
return super(WatcherClient, self).open_stream(stream_name=stream_name, devices=devices,
event_name=event_name)
# override to send request to server
def del_stream(self, stream_name:str) -> None:
self._zmq_srvmgmt_sub.del_stream(stream_name)
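Client-side usage mirrors the server sketch after watcher.py above; the event name and expression are illustrative:

import tensorwatch as tw

cli = tw.WatcherClient(port=0)   # connects the REQ socket and the ServerMgmt subscription
stream = cli.create_stream(event_name='batch', expr='lambda d: d.loss')
vis = tw.Visualizer(stream, vis_type='line')
vis.show()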

View File

@@ -0,0 +1,55 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Any, Dict
from .zmq_stream import ZmqStream
from .lv_types import PublisherTopics, ServerMgmtMsg, StreamCreateRequest
from .zmq_wrapper import ZmqWrapper
from .lv_types import CliSrvReqTypes, ClientServerRequest
from . import utils
class ZmqMgmtStream(ZmqStream):
# default topic is mgmt
def __init__(self, clisrv:ZmqWrapper.ClientServer, for_write:bool, port:int=0, topic=PublisherTopics.ServerMgmt, block_until_connected=True,
stream_name:str=None, console_debug:bool=False):
super(ZmqMgmtStream, self).__init__(for_write=for_write, port=port, topic=topic,
block_until_connected=block_until_connected, stream_name=stream_name, console_debug=console_debug)
self._clisrv = clisrv
self._stream_reqs:Dict[str,StreamCreateRequest] = {}
def write(self, mgmt_msg:Any, from_stream:'Stream'=None):
r"""Handles server management events.
"""
utils.debug_log("Received - SeverMgmtevent", mgmt_msg)
# if server was restarted then send create stream requests again
if mgmt_msg.event_name == ServerMgmtMsg.EventServerStart:
for stream_req in self._stream_reqs.values():
self._send_create_stream(stream_req)
super(ZmqMgmtStream, self).write(mgmt_msg)
def add_stream_req(self, stream_req:StreamCreateRequest)->None:
self._send_create_stream(stream_req)
# save this for later for resend if server restarts
self._stream_reqs[stream_req.stream_name] = stream_req
# override to send request to server
def del_stream(self, stream_name:str) -> None:
clisrv_req = ClientServerRequest(CliSrvReqTypes.del_stream, stream_name)
self._clisrv.send_obj(clisrv_req)
self._stream_reqs.pop(stream_name, None)
def _send_create_stream(self, stream_req):
utils.debug_log("sending create streamreq...")
clisrv_req = ClientServerRequest(CliSrvReqTypes.create_stream, stream_req)
self._clisrv.send_obj(clisrv_req)
utils.debug_log("sent create streamreq")
def close(self):
if not self.closed:
self._stream_reqs = {}
self._clisrv = None
utils.debug_log('ZmqMgmtStream is closed', verbosity=1)
super(ZmqMgmtStream, self).close()

49
tensorwatch/zmq_stream.py Normal file
View File

@@ -0,0 +1,49 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Any
from .zmq_wrapper import ZmqWrapper
from .stream import Stream
from .lv_types import DefaultPorts, PublisherTopics
from . import utils
# on writes send data on ZMQ transport
class ZmqStream(Stream):
def __init__(self, for_write:bool, port:int=0, topic=PublisherTopics.StreamItem, block_until_connected=True,
stream_name:str=None, console_debug:bool=False):
super(ZmqStream, self).__init__(stream_name=stream_name, console_debug=console_debug)
self.for_write = for_write
self._zmq = None
self.topic = topic
self._open(for_write, port, block_until_connected)
utils.debug_log('ZmqStream started', verbosity=1)
def _open(self, for_write:bool, port:int, block_until_connected:bool):
if for_write:
self._zmq = ZmqWrapper.Publication(port=DefaultPorts.PubSub+port,
block_until_connected=block_until_connected)
else:
self._zmq = ZmqWrapper.Subscription(port=DefaultPorts.PubSub+port,
topic=self.topic, callback=self._on_subscription_item)
def close(self):
if not self.closed:
self._zmq.close()
self._zmq = None
utils.debug_log('ZmqStream is closed', verbosity=1)
super(ZmqStream, self).close()
def _on_subscription_item(self, val:Any):
utils.debug_log('Received subscription item', verbosity=5)
self.write(val)
def write(self, val:Any, from_stream:'Stream'=None, topic=None):
super(ZmqStream, self).write(val)
if self.for_write:
topic = topic or self.topic
utils.debug_log('Sent subscription item', verbosity=5)
self._zmq.send_obj(val, topic)
# else if this was opened for read then we have subscription and
# we shouldn't be calling send_obj
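A hedged end-to-end sketch of the stream-over-ZMQ path: a writer publishes, a reader subscribes and forwards to its own subscribers; port 0 resolves against DefaultPorts.PubSub as in _open above, and the payload is illustrative:

import time
from tensorwatch.stream import Stream
from tensorwatch.zmq_stream import ZmqStream

pub = ZmqStream(for_write=True, port=0)    # binds a PUB socket
sub = ZmqStream(for_write=False, port=0)   # connects a SUB socket on the same topic
console = Stream(stream_name='console', console_debug=True)
console.subscribe(sub)                     # print whatever the subscription receives
time.sleep(1)                              # allow for the ZMQ slow-joiner delay
pub.write('hello over zmq')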

290
tensorwatch/zmq_wrapper.py Normal file
View File

@@ -0,0 +1,290 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import zmq
import errno
import pickle
from zmq.eventloop import ioloop, zmqstream
import zmq.utils.monitor
import functools, sys, logging, traceback
from threading import Thread, Event
from . import utils
import weakref
class ZmqWrapper:
_thread:Thread = None
_ioloop:ioloop.IOLoop = None
_start_event:Event = None
_ioloop_block:Event = None # indicates if there is any blocking IOLoop call in progress
@staticmethod
def initialize():
# create thread that will wait on IO Loop
if ZmqWrapper._thread is None:
ZmqWrapper._thread = Thread(target=ZmqWrapper._run_io_loop, name='ZMQIOLoop', daemon=True)
ZmqWrapper._start_event = Event()
ZmqWrapper._ioloop_block = Event()
ZmqWrapper._ioloop_block.set() # no blocking call in progress right now
ZmqWrapper._thread.start()
# this is needed to make sure IO Loop has enough time to start
ZmqWrapper._start_event.wait()
@staticmethod
def close():
# terminate the IO Loop
if ZmqWrapper._thread is not None:
ZmqWrapper._ioloop_block.set() # free any blocking call
ZmqWrapper._ioloop.add_callback(ZmqWrapper._ioloop.stop)
ZmqWrapper._thread = None
ZmqWrapper._ioloop = None
print("ZMQ IOLoop is now closed")
@staticmethod
def get_timer(secs, callback, start=True):
utils.debug_log("Adding PeriodicCallback", secs)
pc = ioloop.PeriodicCallback(callback, secs * 1e3)
if start:
pc.start()
return pc
@staticmethod
def _run_io_loop():
if 'asyncio' in sys.modules:
# tornado may be using asyncio,
# ensure an eventloop exists for this thread
import asyncio
asyncio.set_event_loop(asyncio.new_event_loop())
ZmqWrapper._ioloop = ioloop.IOLoop()
ZmqWrapper._ioloop.make_current()
while ZmqWrapper._thread is not None:
try:
ZmqWrapper._start_event.set()
utils.debug_log("starting ioloop...")
ZmqWrapper._ioloop.start()
except zmq.ZMQError as ex:
if ex.errno == errno.EINTR:
print("Cannot start IOLoop! ZMQError: {}".format(ex), file=sys.stderr)
continue
else:
raise
# Utility method to run the given function on the IOLoop;
# this is a blocking call if has_result=True.
# It can be called from any thread.
@staticmethod
def _io_loop_call(has_result, f, *kargs, **kwargs):
class Result:
def __init__(self, val=None):
self.val = val
def wrapper(f, r, *kargs, **kwargs):
    try:
        r.val = f(*kargs, **kwargs)
    except Exception as ex:
        print(ex)
        logging.fatal(ex, exc_info=True)
        traceback.print_exc(file=sys.stdout)
    finally:
        # release any waiter even if the call raised, otherwise _io_loop_call would block forever
        ZmqWrapper._ioloop_block.set()
# We add the callback to the IO Loop and then wait for it to complete.
# If a result is expected then we wait, otherwise fire and forget.
if has_result:
if not ZmqWrapper._ioloop_block.is_set():
# TODO: better way to raise this error?
print('Previous blocking call on IOLoop is not yet complete!')
ZmqWrapper._ioloop_block.clear()
r = Result()
f_wrapped = functools.partial(wrapper, f, r, *kargs, **kwargs)
ZmqWrapper._ioloop.add_callback(f_wrapped)
utils.debug_log("Waiting for call on ioloop", f, verbosity=5)
ZmqWrapper._ioloop_block.wait()
utils.debug_log("Call on ioloop done", f, verbosity=5)
return r.val
else:
f_wrapped = functools.partial(f, *kargs, **kwargs)
ZmqWrapper._ioloop.add_callback(f_wrapped)
class Publication:
def __init__(self, port, host="*", block_until_connected=True):
# define vars
self._socket = None
self._mon_socket = None
self._mon_stream = None
ZmqWrapper.initialize()
utils.debug_log('Creating Publication', port, verbosity=1)
# make sure the call blocks until connection is made
ZmqWrapper._io_loop_call(block_until_connected, self._start_srv, port, host)
def _start_srv(self, port, host):
context = zmq.Context()
self._socket = context.socket(zmq.PUB)
utils.debug_log('Binding socket', (host, port), verbosity=5)
self._socket.bind("tcp://%s:%d" % (host, port))
utils.debug_log('Bound socket', (host, port), verbosity=5)
self._mon_socket = self._socket.get_monitor_socket(zmq.EVENT_CONNECTED | zmq.EVENT_DISCONNECTED)
self._mon_stream = zmqstream.ZMQStream(self._mon_socket)
self._mon_stream.on_recv(self._on_mon)
def close(self):
if self._socket:
ZmqWrapper._io_loop_call(False, self._socket.close)
# we need this wrapper method as self._socket might not be there yet
def _send_multipart(self, parts):
#utils.debug_log('_send_multipart', parts, verbosity=6)
return self._socket.send_multipart(parts)
def send_obj(self, obj, topic=""):
ZmqWrapper._io_loop_call(False, self._send_multipart,
[topic.encode(), pickle.dumps(obj)])
def _on_mon(self, msg):
ev = zmq.utils.monitor.parse_monitor_message(msg)
event = ev['event']
endpoint = ev['endpoint']
if event == zmq.EVENT_CONNECTED:
utils.debug_log("Subscriber connect event", endpoint, verbosity=1)
elif event == zmq.EVENT_DISCONNECTED:
utils.debug_log("Subscriber disconnect event", endpoint, verbosity=1)
class Subscription:
# subscribe to topic, call callback when object is received on topic
def __init__(self, port, topic="", callback=None, host="localhost"):
self._socket = None
self._stream = None
self.topic = None
ZmqWrapper.initialize()
utils.debug_log('Creating Subscription', port, verbosity=1)
ZmqWrapper._io_loop_call(False, self._add_sub,
port, topic=topic, callback=callback, host=host)
def close(self):
if self._socket:
ZmqWrapper._io_loop_call(False, self._socket.close)
def _add_sub(self, port, topic, callback, host):
def callback_wrapper(weak_callback, msg):
[topic, obj_s] = msg # pylint: disable=unused-variable
try:
if weak_callback and weak_callback():
weak_callback()(pickle.loads(obj_s))
except Exception as ex:
print(ex, file=sys.stderr) # TODO: standardize this
raise
# connect to stream socket
context = zmq.Context()
self.topic = topic.encode()
self._socket = context.socket(zmq.SUB)
utils.debug_log("Subscriber connecting...", (host, port), verbosity=1)
self._socket.connect("tcp://%s:%d" % (host, port))
utils.debug_log("Subscriber connected!", (host, port), verbosity=1)
# setup socket filtering
if topic != "":
self._socket.setsockopt(zmq.SUBSCRIBE, self.topic)
# if callback is specified then create a stream and set it
# for on_recv event - this would require running ioloop
if callback is not None:
self._stream = zmqstream.ZMQStream(self._socket)
wr_cb = weakref.WeakMethod(callback)
wrapper = functools.partial(callback_wrapper, wr_cb)
self._stream.on_recv(wrapper)
#else use receive_obj
def _receive_obj(self):
[topic, obj_s] = self._socket.recv_multipart() # pylint: disable=unbalanced-tuple-unpacking
if topic != self.topic:
raise ValueError("Expected topic: %s, Received topic: %s" % (topic, self.topic))
return pickle.loads(obj_s)
def receive_obj(self):
return ZmqWrapper._io_loop_call(True, self._receive_obj)
def _get_socket_identity(self):
ep_id = self._socket.getsockopt(zmq.LAST_ENDPOINT)
return ep_id
def get_socket_identity(self):
return ZmqWrapper._io_loop_call(True, self._get_socket_identity)
class ClientServer:
def __init__(self, port, is_server, callback=None, host=None):
self._socket = None
self._stream = None
ZmqWrapper.initialize()
utils.debug_log('Creating ClientServer', (is_server, port), verbosity=1)
# make sure call blocks until connection is made
# otherwise variables would not be available
ZmqWrapper._io_loop_call(True, self._connect,
port, is_server, callback, host)
def close(self):
if self._socket:
ZmqWrapper._io_loop_call(False, self._socket.close)
def _connect(self, port, is_server, callback, host):
def callback_wrapper(callback, msg):
utils.debug_log("Server received request...", verbosity=6)
[obj_s] = msg
try:
ret = callback(self, pickle.loads(obj_s))
# we must send reply to complete the cycle
self._socket.send_multipart([pickle.dumps((ret, None))])
except Exception as ex:
print("ClientServer call raised exception: ", ex, file=sys.stderr)
# we must send reply to complete the cycle
self._socket.send_multipart([pickle.dumps((None, ex))])
utils.debug_log("Server sent response", verbosity=6)
context = zmq.Context()
if is_server:
host = host or "127.0.0.1"
self._socket = context.socket(zmq.REP)
utils.debug_log('Binding socket', (host, port), verbosity=5)
self._socket.bind("tcp://%s:%d" % (host, port))
utils.debug_log('Bound socket', (host, port), verbosity=5)
else:
host = host or "localhost"
self._socket = context.socket(zmq.REQ)
self._socket.setsockopt(zmq.REQ_CORRELATE, 1)
self._socket.setsockopt(zmq.REQ_RELAXED, 1)
utils.debug_log("Client connecting...", verbosity=1)
self._socket.connect("tcp://%s:%d" % (host, port))
utils.debug_log("Client connected!", verbosity=1)
if callback is not None:
self._stream = zmqstream.ZMQStream(self._socket)
wrapper = functools.partial(callback_wrapper, callback)
self._stream.on_recv(wrapper)
#else use receive_obj
def send_obj(self, obj):
ZmqWrapper._io_loop_call(False, self._socket.send_multipart,
[pickle.dumps(obj)])
def receive_obj(self):
# pylint: disable=unpacking-non-sequence
[obj_s] = ZmqWrapper._io_loop_call(True, self._socket.recv_multipart)
return pickle.loads(obj_s)
def request(self, req_obj):
utils.debug_log("Client sending request...", verbosity=6)
self.send_obj(req_obj)
r = self.receive_obj()
utils.debug_log("Client received response", verbosity=6)
return r
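A minimal sketch of the wrapper's two messaging patterns; the ports and payloads are illustrative, and the subscription callback must be a bound method because Subscription holds it via weakref.WeakMethod:

class Printer:
    def on_item(self, obj):
        print('received:', obj)

p = Printer()
pub = ZmqWrapper.Publication(port=40860)
sub = ZmqWrapper.Subscription(port=40860, topic='demo', callback=p.on_item)
pub.send_obj({'x': 1}, topic='demo')       # delivered to p.on_item on the IOLoop thread

srv = ZmqWrapper.ClientServer(port=40861, is_server=True,
    callback=lambda clisrv, req: req * 2)  # the return value is pickled back to the client
cli = ZmqWrapper.ClientServer(port=40861, is_server=False)
print(cli.request(21))                     # -> (42, None), i.e. (result, exception)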

View File

@@ -0,0 +1,18 @@
import tensorwatch as tw
import numpy as np
import matplotlib.pyplot as plt
import time
import torchvision.datasets as datasets
fruits_ds = datasets.ImageFolder(r'D:\datasets\fruits-360\Training')
mnist_ds = datasets.MNIST('../data', train=True, download=True)
images = [tw.ImagePlotItem(fruits_ds[i][0], title=str(i)) for i in range(5)] + \
[tw.ImagePlotItem(mnist_ds[i][0], title=str(i)) for i in range(5)]
stream = tw.ArrayStream(images)
img_plot = tw.Visualizer(stream, vis_type='image', viz_img_scale=3)
img_plot.show()
tw.image_utils.plt_loop()

View File

@@ -0,0 +1,15 @@
import tensorwatch as tw
import objgraph, time #pip install objgraph
cli = tw.WatcherClient()
time.sleep(10)
del cli
import gc
gc.collect()
time.sleep(2)
objgraph.show_backrefs(objgraph.by_type('WatcherClient'), refcounts=True, filename='b.png')

View File

@@ -0,0 +1,10 @@
import tensorwatch as tw
from tensorwatch import evaler
e = evaler.Evaler('reduce(lambda x,y: x+y, map(lambda x:x**2, filter(lambda x: x%2==0, l)))')
for i in range(5):
eval_return = e.post(i)
print(i, eval_return)
eval_return = e.post(ended=True)
print(i, eval_return)

View File

@@ -0,0 +1,15 @@
import ipywidgets as widgets
from IPython import get_ipython
class PrinterX:
def __init__(self):
self.w = widgets.HTML()
def show(self):
return self.w
def write(self,s):
self.w.value = s
print("Running from within ipython?", get_ipython() is not None)
p=PrinterX()
p.show()
p.write('ffffffffff')

View File

@@ -0,0 +1,20 @@
from tensorwatch.watcher_base import WatcherBase
from tensorwatch.mpl.line_plot import LinePlot
from tensorwatch.image_utils import plt_loop
from tensorwatch.stream import Stream
from tensorwatch.lv_types import StreamItem
def main():
watcher = WatcherBase()
line_plot = LinePlot()
stream = watcher.create_stream(expr='lambda vars:vars.x')
line_plot.subscribe(stream)
line_plot.show()
for i in range(5):
watcher.observe(x=(i, i*i))
plt_loop()
main()

View File

@@ -0,0 +1,13 @@
import pandas as pd
class SomeThing:
def __init__(self, x, y):
self.x, self.y = x, y
things = [SomeThing(1,2), SomeThing(3,4), SomeThing(4,5)]
df = pd.DataFrame([t.__dict__ for t in things ])
print(df.iloc[[-1]].to_dict('records')[0])
#print(df.to_html())
print(df.style.render()) # note: Styler.render() became Styler.to_html() in newer pandas

View File

@@ -0,0 +1,19 @@
from tensorwatch.watcher_base import WatcherBase
from tensorwatch.plotly.line_plot import LinePlot
from tensorwatch.image_utils import plt_loop
from tensorwatch.stream import Stream
from tensorwatch.lv_types import StreamItem
def main():
watcher = WatcherBase()
line_plot = LinePlot()
stream = watcher.create_stream(expr='lambda vars:vars.x')
line_plot.subscribe(stream)
line_plot.show()
for i in range(5):
watcher.observe(x=(i, i*i))
main()

View File

@@ -0,0 +1,14 @@
from tensorwatch.stream import Stream
s1 = Stream(stream_name='s1', console_debug=True)
s2 = Stream(stream_name='s2', console_debug=True)
s3 = Stream(stream_name='s3', console_debug=True)
s1.subscribe(s2)
s2.subscribe(s3)
s3.write('S3 wrote this')
s2.write('S2 wrote this')
s1.write('S1 wrote this')

View File

@@ -0,0 +1,23 @@
import threading
import time
import sys
def handler(a,b=None):
sys.exit(1)
def install_handler():
if sys.platform == "win32":
if sys.stdin is not None and sys.stdin.isatty():
#this is Console based application
import win32api
win32api.SetConsoleCtrlHandler(handler, True)
def work():
time.sleep(10000)
t = threading.Thread(target=work, name='ThreadTest')
t.daemon = True
t.start()
while True:
    t.join(0.1) # 100ms ~ typical human response time
    # pressing Ctrl+C raises KeyboardInterrupt in the main thread

View File

@@ -0,0 +1,15 @@
from tensorwatch.watcher_base import WatcherBase
from tensorwatch.stream import Stream
def main():
watcher = WatcherBase()
console_pub = Stream(stream_name = 'S1', console_debug=True)
stream = watcher.create_stream(expr='lambda vars:vars.x**2')
console_pub.subscribe(stream)
for i in range(5):
watcher.observe(x=i)
main()

31
test/dlc/dlc.py Normal file
View File

@@ -0,0 +1,31 @@
import tensorwatch as tw
import time
import math
from tensorwatch import utils
utils.set_debug_verbosity(4)
def dlc_show_rand_outputs():
cli = cli_train = tw.WatcherClient()
imgs = cli.create_stream(event_name='batch',
expr="top(l, out_xform=pyt_img_img_out_xform, group_key=lambda x:'', topk=10, order='rnd')",
throttle=1)
img_plot = tw.mpl.ImagePlot()
img_plot.show(imgs, img_width=39, img_height=69, viz_img_scale=10)
utils.wait_key()
def img2img_rnd():
cli_train = tw.WatcherClient()
cli = tw.WatcherClient()
imgs = cli_train.create_stream(event_name='batch',
expr="top(l, out_xform=pyt_img_img_out_xform, group_key=lambda x:'', topk=2, order='rnd')",
throttle=1)
img_plot = tw.mpl.ImagePlot()
img_plot.show(imgs, img_width=100, img_height=100, viz_img_scale=3, cols=1)
utils.wait_key()
dlc_show_rand_outputs()
img2img_rnd()

View File

@@ -0,0 +1,28 @@
from tensorwatch.watcher_base import WatcherBase
from tensorwatch.stream import Stream
from tensorwatch.file_stream import FileStream
from tensorwatch.mpl.line_plot import LinePlot
from tensorwatch.image_utils import plt_loop
import tensorwatch as tw
def file_write():
watcher = WatcherBase()
stream = watcher.create_stream(expr='lambda vars:(vars.x, vars.x**2)',
devices=[r'c:\temp\obs.txt'])
for i in range(5):
watcher.observe(x=i)
def file_read():
watcher = WatcherBase()
stream = watcher.open_stream(devices=[r'c:\temp\obs.txt'])
vis = tw.Visualizer(stream, vis_type='mpl-line')
vis.show()
plt_loop()
def main():
file_write()
file_read()
main()

31
test/live_graph.py Normal file
View File

@@ -0,0 +1,31 @@
from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
from random import randrange
from threading import Thread
import time
class LiveGraph:
def __init__(self):
self.x_data, self.y_data = [], []
self.figure = plt.figure()
self.line, = plt.plot(self.x_data, self.y_data)
self.animation = FuncAnimation(self.figure, self.update, interval=1000)
self.th = Thread(target=self.thread_f, name='LiveGraph', daemon=True)
self.th.start()
def update(self, frame):
self.line.set_data(self.x_data, self.y_data)
self.figure.gca().relim()
self.figure.gca().autoscale_view()
return self.line,
def show(self):
plt.show()
def thread_f(self):
x = 0
while True:
self.x_data.append(x)
x += 1
self.y_data.append(randrange(0, 100))
time.sleep(1)
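Driving the class above from a script is a two-liner (a sketch; show() blocks until the window closes):

g = LiveGraph()   # the background thread starts appending one point per second
g.show()          # blocks in plt.show(); FuncAnimation redraws every second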

110
test/mnist/cli_mnist.py Normal file
View File

@@ -0,0 +1,110 @@
import tensorwatch as tw
import time
import math
from tensorwatch import utils
import matplotlib.pyplot as plt
utils.set_debug_verbosity(4)
def img_in_class():
cli_train = tw.WatcherClient()
imgs = cli_train.create_stream(event_name='batch',
expr="top(l, out_xform=pyt_img_class_out_xform, order='rnd')", throttle=1)
img_plot = tw.mpl.ImagePlot()
img_plot.subscribe(imgs, viz_img_scale=3)
img_plot.show()
tw.image_utils.plt_loop()
def show_find_lr():
cli_train = tw.WatcherClient()
plot = tw.mpl.LinePlot()
train_batch_loss = cli_train.create_stream(event_name='batch',
expr='lambda d:(d.tt.scheduler.get_lr()[0], d.metrics.batch_loss)')
plot.subscribe(train_batch_loss, xtitle='Epoch', ytitle='Loss')
utils.wait_key()
def plot_grads():
train_cli = tw.WatcherClient()
grads = train_cli.create_stream(event_name='batch',
expr='lambda d:agg_params(d.model, lambda p: p.grad.abs().mean().item())', throttle=1)
p = tw.plotly.LinePlot('Demo')
p.subscribe(grads, xtitle='Epoch', ytitle='Gradients', history_len=30, new_on_eval=True)
utils.wait_key()
def plot_grads1():
train_cli = tw.WatcherClient()
grads = train_cli.create_stream(event_name='batch',
expr='lambda d:agg_params(d.model, lambda p: p.grad.abs().mean().item())', throttle=1)
grad_plot = tw.mpl.LinePlot()
grad_plot.subscribe(grads, xtitle='Epoch', ytitle='Gradients', clear_after_each=1, history_len=40, dim_history=True)
grad_plot.show()
tw.plt_loop()
def plot_weight():
train_cli = tw.WatcherClient()
params = train_cli.create_stream(event_name='batch',
expr='lambda d:agg_params(d.model, lambda p: p.abs().mean().item())', throttle=1)
params_plot = tw.mpl.LinePlot()
params_plot.subscribe(params, xtitle='Epoch', ytitle='avg |params|', clear_after_each=1, history_len=40, dim_history=True)
params_plot.show()
tw.plt_loop()
def epoch_stats():
train_cli = tw.WatcherClient(port=0)
test_cli = tw.WatcherClient(port=1)
plot = tw.mpl.LinePlot()
train_loss = train_cli.create_stream(event_name="epoch",
expr='lambda v:(v.metrics.epoch_index, v.metrics.epoch_loss)')
plot.subscribe(train_loss, xtitle='Epoch', ytitle='Train Loss')
test_acc = test_cli.create_stream(event_name="epoch",
expr='lambda v:(v.metrics.epoch_index, v.metrics.epoch_accuracy)')
plot.subscribe(test_acc, xtitle='Epoch', ytitle='Test Accuracy', ylim=(0,1))
plot.show()
tw.plt_loop()
def batch_stats():
train_cli = tw.WatcherClient()
stream = train_cli.create_stream(event_name="batch",
expr='lambda v:(v.metrics.epochf, v.metrics.batch_loss)', throttle=0.75)
train_loss = tw.Visualizer(stream, clear_after_end=False, vis_type='mpl-line',
xtitle='Epoch', ytitle='Train Loss')
#train_acc = tw.Visualizer('lambda v:(v.metrics.epochf, v.metrics.epoch_loss)', event_name="batch",
# xtitle='Epoch', ytitle='Train Accuracy', clear_after_end=False, yrange=(0,1),
# vis=train_loss, vis_type='mpl-line')
train_loss.show()
tw.image_utils.plt_loop()
def text_stats():
train_cli = tw.WatcherClient()
stream = train_cli.create_stream(event_name="batch",
expr='lambda d:(d.x, d.metrics.batch_loss)')
trl = tw.Visualizer(stream, vis_type=None)
trl.show()
input('Paused...')
#epoch_stats()
#plot_weight()
#plot_grads1()
#img_in_class()
#text_stats()
batch_stats()

View File

@@ -0,0 +1,16 @@
from tensorwatch.saliency import saliency
from tensorwatch import image_utils, imagenet_utils, pytorch_utils
model = pytorch_utils.get_model('resnet50')
raw_input, input, target_class = pytorch_utils.image_class2tensor('../data/dogs.png', 240, #'../data/elephant.png', 101,
image_transform=imagenet_utils.get_image_transform(), image_convert_mode='RGB')
results = saliency.get_image_saliency_results(model, raw_input, input, target_class)
figure = saliency.get_image_saliency_plot(results)
image_utils.plt_loop()

View File

@@ -0,0 +1,10 @@
import torch
import torchvision.models
import tensorwatch as tw
vgg16_model = torchvision.models.vgg16()
drawing = tw.draw_model(vgg16_model, [1, 3, 224, 224])
drawing.save('abc')
input("Press any key")

10
test/pre_train/tsny.py Normal file
View File

@@ -0,0 +1,10 @@
import tensorwatch as tw
from regim import *
ds = DataUtils.mnist_datasets(linearize=True, train_test=False)
ds = DataUtils.sample_by_class(ds, k=5, shuffle=True, as_np=True, no_test=True)
comps = tw.get_tsne_components(ds)
print(comps)
plot = tw.Visualizer(comps, hover_images=ds[0], hover_image_reshape=(28,28), vis_type='tsne')
plot.show()

77
test/simple_log/cli_ij.py Normal file
View File

@@ -0,0 +1,77 @@
import tensorwatch as tw
import time
import math
from tensorwatch import utils
utils.set_debug_verbosity(4)
def mpl_line_plot():
cli = tw.WatcherClient()
p = tw.mpl.LinePlot(title='Demo')
s1 = cli.create_stream(event_name='ev_i', expr='map(lambda v:math.sqrt(v.val)*2, l)')
p.subscribe(s1, xtitle='Index', ytitle='sqrt(ev_i)')
p.show()
tw.plt_loop()
def mpl_history_plot():
cli = tw.WatcherClient()
p2 = tw.mpl.LinePlot(title='History Demo')
p2s1 = cli.create_stream(event_name='ev_j', expr='map(lambda v:(v.val, math.sqrt(v.val)*2), l)')
p2.subscribe(p2s1, xtitle='Index', ytitle='sqrt(ev_j)', clear_after_end=True, history_len=15)
p2.show()
tw.plt_loop()
def show_stream():
cli = tw.WatcherClient()
print("Subscribing to event ev_i...")
s1 = cli.create_stream(event_name="ev_i", expr='map(lambda v:math.sqrt(v.val), l)')
r1 = tw.TextVis(title='L1')
r1.subscribe(s1)
r1.show()
print("Subscribing to event ev_j...")
s2 = cli.create_stream(event_name="ev_j", expr='map(lambda v:v.val*v.val, l)')
r2 = tw.TextVis(title='L2')
r2.subscribe(s2)
r2.show()
print("Waiting for key...")
utils.wait_key()
# this is no longer directly supported
# TODO: create stream that allows enumeration from buffered values
#def read_stream():
# cli = tw.WatcherClient()
# with cli.create_stream(event_name="ev_i", expr='map(lambda v:(v.x, math.sqrt(v.val)), l)') as s1:
# for stream_item in s1:
# print(stream_item.value)
# print('done')
# utils.wait_key()
def plotly_line_graph():
cli = tw.WatcherClient()
s1 = cli.create_stream(event_name="ev_i", expr='map(lambda v:(v.x, math.sqrt(v.val)), l)')
p = tw.plotly.LinePlot()
p.subscribe(s1)
p.show()
utils.wait_key()
def plotly_history_graph():
cli = tw.WatcherClient()
p = tw.plotly.LinePlot(title='Demo')
s2 = cli.create_stream(event_name='ev_j', expr='map(lambda v:(v.x, v.val), l)')
p.subscribe(s2, ytitle='ev_j', history_len=15)
p.show()
utils.wait_key()
mpl_line_plot()
#mpl_history_plot()
#show_stream()
#plotly_line_graph()
#plotly_history_graph()

Some files were not shown because too many files changed in this diff.