initial
This commit is contained in:
Родитель
8d9d7490f0
Коммит
9e5f1f3355
|
@ -1,3 +1,6 @@
|
|||
data/
|
||||
!data/log-20180928-175828.dlclog
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
|
@ -102,3 +105,345 @@ venv.bak/
|
|||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
|
||||
## Ignore Visual Studio temporary files, build results, and
|
||||
## files generated by popular Visual Studio add-ons.
|
||||
##
|
||||
## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore
|
||||
|
||||
# User-specific files
|
||||
*.rsuser
|
||||
*.suo
|
||||
*.user
|
||||
*.userosscache
|
||||
*.sln.docstates
|
||||
|
||||
# User-specific files (MonoDevelop/Xamarin Studio)
|
||||
*.userprefs
|
||||
|
||||
# Build results
|
||||
[Dd]ebug/
|
||||
[Dd]ebugPublic/
|
||||
[Rr]elease/
|
||||
[Rr]eleases/
|
||||
x64/
|
||||
x86/
|
||||
[Aa][Rr][Mm]/
|
||||
[Aa][Rr][Mm]64/
|
||||
bld/
|
||||
[Bb]in/
|
||||
[Oo]bj/
|
||||
[Ll]og/
|
||||
|
||||
# Visual Studio 2015/2017 cache/options directory
|
||||
.vs/
|
||||
# Uncomment if you have tasks that create the project's static files in wwwroot
|
||||
#wwwroot/
|
||||
|
||||
# Visual Studio 2017 auto generated files
|
||||
Generated\ Files/
|
||||
|
||||
# MSTest test Results
|
||||
[Tt]est[Rr]esult*/
|
||||
[Bb]uild[Ll]og.*
|
||||
|
||||
# NUNIT
|
||||
*.VisualState.xml
|
||||
TestResult.xml
|
||||
|
||||
# Build Results of an ATL Project
|
||||
[Dd]ebugPS/
|
||||
[Rr]eleasePS/
|
||||
dlldata.c
|
||||
|
||||
# Benchmark Results
|
||||
BenchmarkDotNet.Artifacts/
|
||||
|
||||
# .NET Core
|
||||
project.lock.json
|
||||
project.fragment.lock.json
|
||||
artifacts/
|
||||
|
||||
# StyleCop
|
||||
StyleCopReport.xml
|
||||
|
||||
# Files built by Visual Studio
|
||||
*_i.c
|
||||
*_p.c
|
||||
*_h.h
|
||||
*.ilk
|
||||
*.meta
|
||||
*.obj
|
||||
*.iobj
|
||||
*.pch
|
||||
*.pdb
|
||||
*.ipdb
|
||||
*.pgc
|
||||
*.pgd
|
||||
*.rsp
|
||||
*.sbr
|
||||
*.tlb
|
||||
*.tli
|
||||
*.tlh
|
||||
*.tmp
|
||||
*.tmp_proj
|
||||
*_wpftmp.csproj
|
||||
*.log
|
||||
*.vspscc
|
||||
*.vssscc
|
||||
.builds
|
||||
*.pidb
|
||||
*.svclog
|
||||
*.scc
|
||||
|
||||
# Chutzpah Test files
|
||||
_Chutzpah*
|
||||
|
||||
# Visual C++ cache files
|
||||
ipch/
|
||||
*.aps
|
||||
*.ncb
|
||||
*.opendb
|
||||
*.opensdf
|
||||
*.sdf
|
||||
*.cachefile
|
||||
*.VC.db
|
||||
*.VC.VC.opendb
|
||||
|
||||
# Visual Studio profiler
|
||||
*.psess
|
||||
*.vsp
|
||||
*.vspx
|
||||
*.sap
|
||||
|
||||
# Visual Studio Trace Files
|
||||
*.e2e
|
||||
|
||||
# TFS 2012 Local Workspace
|
||||
$tf/
|
||||
|
||||
# Guidance Automation Toolkit
|
||||
*.gpState
|
||||
|
||||
# ReSharper is a .NET coding add-in
|
||||
_ReSharper*/
|
||||
*.[Rr]e[Ss]harper
|
||||
*.DotSettings.user
|
||||
|
||||
# JustCode is a .NET coding add-in
|
||||
.JustCode
|
||||
|
||||
# TeamCity is a build add-in
|
||||
_TeamCity*
|
||||
|
||||
# DotCover is a Code Coverage Tool
|
||||
*.dotCover
|
||||
|
||||
# AxoCover is a Code Coverage Tool
|
||||
.axoCover/*
|
||||
!.axoCover/settings.json
|
||||
|
||||
# Visual Studio code coverage results
|
||||
*.coverage
|
||||
*.coveragexml
|
||||
|
||||
# NCrunch
|
||||
_NCrunch_*
|
||||
.*crunch*.local.xml
|
||||
nCrunchTemp_*
|
||||
|
||||
# MightyMoose
|
||||
*.mm.*
|
||||
AutoTest.Net/
|
||||
|
||||
# Web workbench (sass)
|
||||
.sass-cache/
|
||||
|
||||
# Installshield output folder
|
||||
[Ee]xpress/
|
||||
|
||||
# DocProject is a documentation generator add-in
|
||||
DocProject/buildhelp/
|
||||
DocProject/Help/*.HxT
|
||||
DocProject/Help/*.HxC
|
||||
DocProject/Help/*.hhc
|
||||
DocProject/Help/*.hhk
|
||||
DocProject/Help/*.hhp
|
||||
DocProject/Help/Html2
|
||||
DocProject/Help/html
|
||||
|
||||
# Click-Once directory
|
||||
publish/
|
||||
|
||||
# Publish Web Output
|
||||
*.[Pp]ublish.xml
|
||||
*.azurePubxml
|
||||
# Note: Comment the next line if you want to checkin your web deploy settings,
|
||||
# but database connection strings (with potential passwords) will be unencrypted
|
||||
*.pubxml
|
||||
*.publishproj
|
||||
|
||||
# Microsoft Azure Web App publish settings. Comment the next line if you want to
|
||||
# checkin your Azure Web App publish settings, but sensitive information contained
|
||||
# in these scripts will be unencrypted
|
||||
PublishScripts/
|
||||
|
||||
# NuGet Packages
|
||||
*.nupkg
|
||||
# The packages folder can be ignored because of Package Restore
|
||||
**/[Pp]ackages/*
|
||||
# except build/, which is used as an MSBuild target.
|
||||
!**/[Pp]ackages/build/
|
||||
# Uncomment if necessary however generally it will be regenerated when needed
|
||||
#!**/[Pp]ackages/repositories.config
|
||||
# NuGet v3's project.json files produces more ignorable files
|
||||
*.nuget.props
|
||||
*.nuget.targets
|
||||
|
||||
# Microsoft Azure Build Output
|
||||
csx/
|
||||
*.build.csdef
|
||||
|
||||
# Microsoft Azure Emulator
|
||||
ecf/
|
||||
rcf/
|
||||
|
||||
# Windows Store app package directories and files
|
||||
AppPackages/
|
||||
BundleArtifacts/
|
||||
Package.StoreAssociation.xml
|
||||
_pkginfo.txt
|
||||
*.appx
|
||||
|
||||
# Visual Studio cache files
|
||||
# files ending in .cache can be ignored
|
||||
*.[Cc]ache
|
||||
# but keep track of directories ending in .cache
|
||||
!?*.[Cc]ache/
|
||||
|
||||
# Others
|
||||
ClientBin/
|
||||
~$*
|
||||
*~
|
||||
*.dbmdl
|
||||
*.dbproj.schemaview
|
||||
*.jfm
|
||||
*.pfx
|
||||
*.publishsettings
|
||||
orleans.codegen.cs
|
||||
|
||||
# Including strong name files can present a security risk
|
||||
# (https://github.com/github/gitignore/pull/2483#issue-259490424)
|
||||
#*.snk
|
||||
|
||||
# Since there are multiple workflows, uncomment next line to ignore bower_components
|
||||
# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
|
||||
#bower_components/
|
||||
# ASP.NET Core default setup: bower directory is configured as wwwroot/lib/ and bower restore is true
|
||||
**/wwwroot/lib/
|
||||
|
||||
# RIA/Silverlight projects
|
||||
Generated_Code/
|
||||
|
||||
# Backup & report files from converting an old project file
|
||||
# to a newer Visual Studio version. Backup files are not needed,
|
||||
# because we have git ;-)
|
||||
_UpgradeReport_Files/
|
||||
Backup*/
|
||||
UpgradeLog*.XML
|
||||
UpgradeLog*.htm
|
||||
ServiceFabricBackup/
|
||||
*.rptproj.bak
|
||||
|
||||
# SQL Server files
|
||||
*.mdf
|
||||
*.ldf
|
||||
*.ndf
|
||||
|
||||
# Business Intelligence projects
|
||||
*.rdl.data
|
||||
*.bim.layout
|
||||
*.bim_*.settings
|
||||
*.rptproj.rsuser
|
||||
|
||||
# Microsoft Fakes
|
||||
FakesAssemblies/
|
||||
|
||||
# GhostDoc plugin setting file
|
||||
*.GhostDoc.xml
|
||||
|
||||
# Node.js Tools for Visual Studio
|
||||
.ntvs_analysis.dat
|
||||
node_modules/
|
||||
|
||||
# Visual Studio 6 build log
|
||||
*.plg
|
||||
|
||||
# Visual Studio 6 workspace options file
|
||||
*.opt
|
||||
|
||||
# Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
|
||||
*.vbw
|
||||
|
||||
# Visual Studio LightSwitch build output
|
||||
**/*.HTMLClient/GeneratedArtifacts
|
||||
**/*.DesktopClient/GeneratedArtifacts
|
||||
**/*.DesktopClient/ModelManifest.xml
|
||||
**/*.Server/GeneratedArtifacts
|
||||
**/*.Server/ModelManifest.xml
|
||||
_Pvt_Extensions
|
||||
|
||||
# Paket dependency manager
|
||||
.paket/paket.exe
|
||||
paket-files/
|
||||
|
||||
# FAKE - F# Make
|
||||
.fake/
|
||||
|
||||
# JetBrains Rider
|
||||
.idea/
|
||||
*.sln.iml
|
||||
|
||||
# CodeRush personal settings
|
||||
.cr/personal
|
||||
|
||||
# Python Tools for Visual Studio (PTVS)
|
||||
__pycache__/
|
||||
*.pyc
|
||||
|
||||
# Cake - Uncomment if you are using it
|
||||
# tools/**
|
||||
# !tools/packages.config
|
||||
|
||||
# Tabs Studio
|
||||
*.tss
|
||||
|
||||
# Telerik's JustMock configuration file
|
||||
*.jmconfig
|
||||
|
||||
# BizTalk build output
|
||||
*.btp.cs
|
||||
*.btm.cs
|
||||
*.odx.cs
|
||||
*.xsd.cs
|
||||
|
||||
# OpenCover UI analysis results
|
||||
OpenCover/
|
||||
|
||||
# Azure Stream Analytics local run output
|
||||
ASALocalRun/
|
||||
|
||||
# MSBuild Binary and Structured Log
|
||||
*.binlog
|
||||
|
||||
# NVidia Nsight GPU debugger configuration file
|
||||
*.nvuser
|
||||
|
||||
# MFractors (Xamarin productivity tool) working folder
|
||||
.mfractor/
|
||||
|
||||
# Local History for Visual Studio
|
||||
.localhistory/
|
||||
|
||||
# BeatPulse healthcheck temp database
|
||||
healthchecksdb
|
|
@ -0,0 +1,15 @@
|
|||
[MASTER]
|
||||
ignore=model_graph,
|
||||
receptive_field,
|
||||
saliency
|
||||
|
||||
[TYPECHECK]
|
||||
ignored-modules=numpy,torch,matplotlib,pyplot,zmq
|
||||
|
||||
[MESSAGES CONTROL]
|
||||
disable=protected-access,
|
||||
broad-except,
|
||||
global-statement,
|
||||
fixme,
|
||||
C,
|
||||
R
|
|
@ -0,0 +1,5 @@
|
|||
# Authors
|
||||
|
||||
TensorWatch was originally conceived and created by [Shital Shah](https://www.shitalshah.com) during early 2019.
|
||||
|
||||
List of all contributors since our first release in May 2019 can be [found here](https://github.com/Microsoft/tensorwatch/graphs/contributors).
|
|
@ -0,0 +1,6 @@
|
|||
# What's new
|
||||
|
||||
Below is summarized list of important changes. This does not include minor/less important changes or bug fixes or documentation update. This list updated every few months. For complete detailed changes, please review [commit history](https://github.com/Microsoft/tensorwatch/commits/master).
|
||||
|
||||
### May, 2019
|
||||
* First release!
|
|
@ -0,0 +1,14 @@
|
|||
# Contributing
|
||||
|
||||
This project welcomes contributions and suggestions. Most contributions require you to
|
||||
agree to a Contributor License Agreement (CLA) declaring that you have the right to,
|
||||
and actually do, grant us the rights to use your contribution. For details, visit
|
||||
https://cla.microsoft.com.
|
||||
|
||||
When you submit a pull request, a CLA-bot will automatically determine whether you need
|
||||
to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the
|
||||
instructions provided by the bot. You will only need to do this once across all repositories using our CLA.
|
||||
|
||||
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
||||
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
|
||||
or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
|
|
@ -0,0 +1,18 @@
|
|||
# Read This First
|
||||
|
||||
## If you are reporting a bug
|
||||
* Make sure to write **all reproduction steps**
|
||||
* Include full error message in text form
|
||||
* Search issues for error message before filing issue
|
||||
* Attach screenshot if applicable
|
||||
* Include code to run if applicable
|
||||
|
||||
## If you have question
|
||||
* Add clear and concise title
|
||||
* Add OS, TensorWatch version, Python version if applicable
|
||||
* Include context on what you are trying to achieve
|
||||
* Include details of what you already did to find answers
|
||||
|
||||
**What's better than filing issue? Filing a pull request :).**
|
||||
|
||||
------------------------------------ (Remove above before filing the issue) ------------------------------------
|
21
LICENSE
21
LICENSE
|
@ -1,21 +0,0 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE
|
|
@ -0,0 +1,22 @@
|
|||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
MIT License
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
|
@ -0,0 +1,127 @@
|
|||
NOTICES AND INFORMATION
|
||||
Do Not Translate or Localize
|
||||
|
||||
This software incorporates material from third parties. Microsoft makes certain
|
||||
open source code available at http://3rdpartysource.microsoft.com, or you may
|
||||
send a check or money order for US $5.00, including the product name, the open
|
||||
source component name, and version number, to:
|
||||
|
||||
Source Code Compliance Team
|
||||
Microsoft Corporation
|
||||
One Microsoft Way
|
||||
Redmond, WA 98052
|
||||
USA
|
||||
|
||||
Notwithstanding any other terms, you may reverse engineer this software to the
|
||||
extent required to debug changes to any libraries licensed under the GNU Lesser
|
||||
General Public License.
|
||||
|
||||
|
||||
**Component.** https://github.com/Swall0w/torchstat
|
||||
|
||||
**Open Source License/Copyright Notice.**
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2018 Swall0w - Alan
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
|
||||
**Component.** https://github.com/waleedka/hiddenlayer
|
||||
|
||||
**Open Source License/Copyright Notice.**
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2018 Waleed Abdulla, Phil Ferriere
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
|
||||
**Component.** https://github.com/yulongwang12/visual-attribution
|
||||
|
||||
**Open Source License/Copyright Notice.**
|
||||
BSD 2-Clause License
|
||||
|
||||
Copyright (c) 2019, Yulong Wang
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
* Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
**Component.** https://github.com/marcotcr/lime
|
||||
|
||||
**Open Source License/Copyright Notice.**
|
||||
Copyright (c) 2016, Marco Tulio Correia Ribeiro
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
* Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
20
README.md
20
README.md
|
@ -1,14 +1,18 @@
|
|||
This package contains Python library for [tensorwatch](https://github.com/sytelus/tensorwatch).
|
||||
|
||||
|
||||
# Contributing
|
||||
|
||||
This project welcomes contributions and suggestions. Most contributions require you to agree to a
|
||||
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
|
||||
the rights to use your contribution. For details, visit https://cla.microsoft.com.
|
||||
This project welcomes contributions and suggestions. Most contributions require you to
|
||||
agree to a Contributor License Agreement (CLA) declaring that you have the right to,
|
||||
and actually do, grant us the rights to use your contribution. For details, visit
|
||||
https://cla.microsoft.com.
|
||||
|
||||
When you submit a pull request, a CLA-bot will automatically determine whether you need to provide
|
||||
a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the instructions
|
||||
provided by the bot. You will only need to do this once across all repos using our CLA.
|
||||
When you submit a pull request, a CLA-bot will automatically determine whether you need
|
||||
to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the
|
||||
instructions provided by the bot. You will only need to do this once across all repositories using our CLA.
|
||||
|
||||
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
||||
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
|
||||
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
|
||||
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
|
||||
or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
|
||||
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
# Support
|
||||
|
||||
We highly recommend to take a look at source code and contribute to the project. Also please consider [contributing](CONTRIBUTING.md) new features and fixes :).
|
||||
|
||||
* [Join TensorWatch Facebook Group](https://www.facebook.com/groups/378075159472803/)
|
||||
* [File GitHub Issue](https://github.com/Microsoft/tensorwatch/issues)
|
|
@ -0,0 +1,23 @@
|
|||
* Fix cell size issue
|
||||
* Refactor _plot* interface to accept all values, for ImagePlot only use last value
|
||||
* Refactor ImagePlot for arbitrary number of images with alpha, cmap
|
||||
* Change tw.open -> tw.create_viz
|
||||
* Make sure streams have names as key, each data point has index
|
||||
* Add tw.open_viz(stream_name, from_index)_
|
||||
* Add persist=device_name option for streams
|
||||
* Ability to use streams in standalone mode
|
||||
* tw.create_viz on server side
|
||||
* tw.log for server side
|
||||
* experiment with IPC channel
|
||||
* confusion matrix as in https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html
|
||||
* Speed up import
|
||||
* Do linting
|
||||
* live perf data
|
||||
* NaN tracing
|
||||
* PCA
|
||||
* Remove error if MNIST notebook is on and we run fruits
|
||||
* Remove 2nd image from fruits
|
||||
* clear exisitng streams when starting client
|
||||
* ImagePlotItem should accept numpy array or pillow or torch tensor
|
||||
* image plot getting refreshed at 12hz instead of 2 hz in MNIST
|
||||
* image plot doesn't title
|
|
@ -0,0 +1,11 @@
|
|||
conda install -c conda-forge jupyterlab nodejs
|
||||
conda install ipywidgets
|
||||
conda install -c plotly plotly-orca psutil
|
||||
|
||||
set NODE_OPTIONS=--max-old-space-size=4096
|
||||
jupyter labextension install @jupyter-widgets/jupyterlab-manager --no-build
|
||||
jupyter labextension install plotlywidget --no-build
|
||||
jupyter labextension install @jupyterlab/plotly-extension --no-build
|
||||
jupyter labextension install jupyterlab-chart-editor --no-build
|
||||
jupyter lab build
|
||||
set NODE_OPTIONS=
|
|
@ -0,0 +1,89 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import tensorwatch as tw"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from regim import *\n",
|
||||
"ds = DataUtils.mnist_datasets(linearize=True, train_test=False)\n",
|
||||
"ds = DataUtils.sample_by_class(ds, k=50, shuffle=True, as_np=True, no_test=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"scrolled": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"components = tw.get_tsne_components(ds)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"scrolled": false
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "4a8099b3e2094f2e89924e61b5ad4701",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(FigureWidget({\n",
|
||||
" 'data': [{'hoverinfo': 'text',\n",
|
||||
" 'line': {'color': 'rgb(31, 119,…"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"comp_stream = tw.ArrayStream(components)\n",
|
||||
"vis = tw.Visualizer(comp_stream, vis_type='tsne', \n",
|
||||
" hover_images=ds[0], hover_image_reshape=(28,28))\n",
|
||||
"vis.show()"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.6.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
|
@ -0,0 +1,722 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import tensorwatch as tw\n",
|
||||
"import torchvision.models"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"alexnet_model = torchvision.models.alexnet()\n",
|
||||
"vgg16_model = torchvision.models.vgg16()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"scrolled": false
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"image/svg+xml": [
|
||||
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\r\n",
|
||||
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n",
|
||||
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n",
|
||||
"<!-- Generated by graphviz version 2.38.0 (20140413.2041)\r\n",
|
||||
" -->\r\n",
|
||||
"<!-- Title: %3 Pages: 1 -->\r\n",
|
||||
"<svg width=\"387pt\" height=\"1403pt\"\r\n",
|
||||
" viewBox=\"0.00 0.00 387.00 1403.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n",
|
||||
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(72 1331)\">\r\n",
|
||||
"<title>%3</title>\r\n",
|
||||
"<polygon fill=\"#ffffff\" stroke=\"none\" points=\"-72,72 -72,-1331 315,-1331 315,72 -72,72\"/>\r\n",
|
||||
"<!-- AlexNet/Sequential[features]/MaxPool2d[2]/outputs/19 -->\r\n",
|
||||
"<g id=\"node1\" class=\"node\"><title>AlexNet/Sequential[features]/MaxPool2d[2]/outputs/19</title>\r\n",
|
||||
"<g id=\"a_node1\"><a xlink:title=\"{'kernel_shape': [3, 3], 'pads': [0, 0, 0, 0], 'strides': [2, 2]}\">\r\n",
|
||||
"<polygon fill=\"#8de5a1\" stroke=\"#7c96bc\" points=\"188,-1176 114,-1176 114,-1140 188,-1140 188,-1176\"/>\r\n",
|
||||
"<text text-anchor=\"start\" x=\"122\" y=\"-1155\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">MaxPool3x3</text>\r\n",
|
||||
"</a>\r\n",
|
||||
"</g>\r\n",
|
||||
"</g>\r\n",
|
||||
"<!-- 5377442926079053455 -->\r\n",
|
||||
"<g id=\"node16\" class=\"node\"><title>5377442926079053455</title>\r\n",
|
||||
"<g id=\"a_node16\"><a xlink:title=\"{'dilations': [1, 1], 'group': 1, 'kernel_shape': [5, 5], 'pads': [2, 2, 2, 2], 'strides': [1, 1]}\">\r\n",
|
||||
"<polygon fill=\"#a1c9f4\" stroke=\"#7c96bc\" points=\"197.5,-1093 104.5,-1093 104.5,-1057 197.5,-1057 197.5,-1093\"/>\r\n",
|
||||
"<text text-anchor=\"start\" x=\"113\" y=\"-1072\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">Conv5x5 > Relu</text>\r\n",
|
||||
"</a>\r\n",
|
||||
"</g>\r\n",
|
||||
"</g>\r\n",
|
||||
"<!-- AlexNet/Sequential[features]/MaxPool2d[2]/outputs/19->5377442926079053455 -->\r\n",
|
||||
"<g id=\"edge12\" class=\"edge\"><title>AlexNet/Sequential[features]/MaxPool2d[2]/outputs/19->5377442926079053455</title>\r\n",
|
||||
"<path fill=\"none\" stroke=\"#7c96bc\" d=\"M151,-1139.82C151,-1129.19 151,-1115.31 151,-1103.2\"/>\r\n",
|
||||
"<polygon fill=\"#7c96bc\" stroke=\"#7c96bc\" points=\"154.5,-1103.15 151,-1093.15 147.5,-1103.15 154.5,-1103.15\"/>\r\n",
|
||||
"<text text-anchor=\"middle\" x=\"182\" y=\"-1114\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\"> 1x64x27x27</text>\r\n",
|
||||
"</g>\r\n",
|
||||
"<!-- AlexNet/Sequential[features]/MaxPool2d[5]/outputs/22 -->\r\n",
|
||||
"<g id=\"node2\" class=\"node\"><title>AlexNet/Sequential[features]/MaxPool2d[5]/outputs/22</title>\r\n",
|
||||
"<g id=\"a_node2\"><a xlink:title=\"{'kernel_shape': [3, 3], 'pads': [0, 0, 0, 0], 'strides': [2, 2]}\">\r\n",
|
||||
"<polygon fill=\"#8de5a1\" stroke=\"#7c96bc\" points=\"188,-1010 114,-1010 114,-974 188,-974 188,-1010\"/>\r\n",
|
||||
"<text text-anchor=\"start\" x=\"122\" y=\"-989\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">MaxPool3x3</text>\r\n",
|
||||
"</a>\r\n",
|
||||
"</g>\r\n",
|
||||
"</g>\r\n",
|
||||
"<!-- 7417167147267928641 -->\r\n",
|
||||
"<g id=\"node19\" class=\"node\"><title>7417167147267928641</title>\r\n",
|
||||
"<g id=\"a_node19\"><a xlink:title=\"{'dilations': [1, 1], 'group': 1, 'kernel_shape': [3, 3], 'pads': [1, 1, 1, 1], 'strides': [1, 1]}\">\r\n",
|
||||
"<polygon fill=\"#a1c9f4\" stroke=\"#7c96bc\" points=\"197.5,-927 104.5,-927 104.5,-883 197.5,-883 197.5,-927\"/>\r\n",
|
||||
"<text text-anchor=\"start\" x=\"113\" y=\"-911\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">Conv3x3 > Relu</text>\r\n",
|
||||
"<text text-anchor=\"start\" x=\"182\" y=\"-890\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">x3</text>\r\n",
|
||||
"</a>\r\n",
|
||||
"</g>\r\n",
|
||||
"</g>\r\n",
|
||||
"<!-- AlexNet/Sequential[features]/MaxPool2d[5]/outputs/22->7417167147267928641 -->\r\n",
|
||||
"<g id=\"edge18\" class=\"edge\"><title>AlexNet/Sequential[features]/MaxPool2d[5]/outputs/22->7417167147267928641</title>\r\n",
|
||||
"<path fill=\"none\" stroke=\"#7c96bc\" d=\"M151,-973.799C151,-963.369 151,-949.742 151,-937.443\"/>\r\n",
|
||||
"<polygon fill=\"#7c96bc\" stroke=\"#7c96bc\" points=\"154.5,-937.09 151,-927.09 147.5,-937.09 154.5,-937.09\"/>\r\n",
|
||||
"<text text-anchor=\"middle\" x=\"185\" y=\"-948\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\"> 1x192x13x13</text>\r\n",
|
||||
"</g>\r\n",
|
||||
"<!-- AlexNet/Sequential[features]/MaxPool2d[12]/outputs/29 -->\r\n",
|
||||
"<g id=\"node3\" class=\"node\"><title>AlexNet/Sequential[features]/MaxPool2d[12]/outputs/29</title>\r\n",
|
||||
"<g id=\"a_node3\"><a xlink:title=\"{'kernel_shape': [3, 3], 'pads': [0, 0, 0, 0], 'strides': [2, 2]}\">\r\n",
|
||||
"<polygon fill=\"#8de5a1\" stroke=\"#7c96bc\" points=\"188,-836 114,-836 114,-800 188,-800 188,-836\"/>\r\n",
|
||||
"<text text-anchor=\"start\" x=\"122\" y=\"-815\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">MaxPool3x3</text>\r\n",
|
||||
"</a>\r\n",
|
||||
"</g>\r\n",
|
||||
"</g>\r\n",
|
||||
"<!-- AlexNet/outputs/31 -->\r\n",
|
||||
"<g id=\"node5\" class=\"node\"><title>AlexNet/outputs/31</title>\r\n",
|
||||
"<polygon fill=\"#d0bbff\" stroke=\"#7c96bc\" points=\"147,-753 93,-753 93,-717 147,-717 147,-753\"/>\r\n",
|
||||
"<text text-anchor=\"start\" x=\"105\" y=\"-732\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">Shape</text>\r\n",
|
||||
"</g>\r\n",
|
||||
"<!-- AlexNet/Sequential[features]/MaxPool2d[12]/outputs/29->AlexNet/outputs/31 -->\r\n",
|
||||
"<g id=\"edge1\" class=\"edge\"><title>AlexNet/Sequential[features]/MaxPool2d[12]/outputs/29->AlexNet/outputs/31</title>\r\n",
|
||||
"<path fill=\"none\" stroke=\"#7c96bc\" d=\"M129.203,-799.762C124.256,-794.629 119.675,-788.594 117,-782 114.648,-776.204 114.012,-769.66 114.263,-763.364\"/>\r\n",
|
||||
"<polygon fill=\"#7c96bc\" stroke=\"#7c96bc\" points=\"117.769,-763.501 115.377,-753.179 110.811,-762.74 117.769,-763.501\"/>\r\n",
|
||||
"<text text-anchor=\"middle\" x=\"145\" y=\"-774\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\"> 1x256x6x6</text>\r\n",
|
||||
"</g>\r\n",
|
||||
"<!-- AlexNet/outputs/37 -->\r\n",
|
||||
"<g id=\"node11\" class=\"node\"><title>AlexNet/outputs/37</title>\r\n",
|
||||
"<polygon fill=\"#fffea3\" stroke=\"#7c96bc\" points=\"194,-451 136,-451 136,-415 194,-415 194,-451\"/>\r\n",
|
||||
"<text text-anchor=\"start\" x=\"144\" y=\"-430\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">Reshape</text>\r\n",
|
||||
"</g>\r\n",
|
||||
"<!-- AlexNet/Sequential[features]/MaxPool2d[12]/outputs/29->AlexNet/outputs/37 -->\r\n",
|
||||
"<g id=\"edge2\" class=\"edge\"><title>AlexNet/Sequential[features]/MaxPool2d[12]/outputs/29->AlexNet/outputs/37</title>\r\n",
|
||||
"<path fill=\"none\" stroke=\"#7c96bc\" d=\"M163.591,-799.957C173.972,-784.197 187,-759.699 187,-736 187,-736 187,-736 187,-505 187,-489.781 182.331,-473.515 177.294,-460.398\"/>\r\n",
|
||||
"<polygon fill=\"#7c96bc\" stroke=\"#7c96bc\" points=\"180.489,-458.962 173.443,-451.05 174.016,-461.629 180.489,-458.962\"/>\r\n",
|
||||
"<text text-anchor=\"middle\" x=\"215\" y=\"-618\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\"> 1x256x6x6</text>\r\n",
|
||||
"</g>\r\n",
|
||||
"<!-- AlexNet/outputs/30 -->\r\n",
|
||||
"<g id=\"node4\" class=\"node\"><title>AlexNet/outputs/30</title>\r\n",
|
||||
"<g id=\"a_node4\"><a xlink:title=\"{'value': tensor(0)}\">\r\n",
|
||||
"<polygon fill=\"#ff9f9b\" stroke=\"#7c96bc\" points=\"75,-753 13,-753 13,-717 75,-717 75,-753\"/>\r\n",
|
||||
"<text text-anchor=\"start\" x=\"21\" y=\"-732\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">Constant</text>\r\n",
|
||||
"</a>\r\n",
|
||||
"</g>\r\n",
|
||||
"</g>\r\n",
|
||||
"<!-- AlexNet/outputs/32 -->\r\n",
|
||||
"<g id=\"node6\" class=\"node\"><title>AlexNet/outputs/32</title>\r\n",
|
||||
"<g id=\"a_node6\"><a xlink:title=\"{'axis': 0}\">\r\n",
|
||||
"<polygon fill=\"#debb9b\" stroke=\"#7c96bc\" points=\"147,-680 93,-680 93,-644 147,-644 147,-680\"/>\r\n",
|
||||
"<text text-anchor=\"start\" x=\"103\" y=\"-659\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">Gather</text>\r\n",
|
||||
"</a>\r\n",
|
||||
"</g>\r\n",
|
||||
"</g>\r\n",
|
||||
"<!-- AlexNet/outputs/30->AlexNet/outputs/32 -->\r\n",
|
||||
"<g id=\"edge3\" class=\"edge\"><title>AlexNet/outputs/30->AlexNet/outputs/32</title>\r\n",
|
||||
"<path fill=\"none\" stroke=\"#7c96bc\" d=\"M62.3975,-716.813C72.0128,-707.83 83.9339,-696.693 94.4323,-686.886\"/>\r\n",
|
||||
"<polygon fill=\"#7c96bc\" stroke=\"#7c96bc\" points=\"96.8539,-689.413 101.772,-680.029 92.0752,-684.298 96.8539,-689.413\"/>\r\n",
|
||||
"</g>\r\n",
|
||||
"<!-- AlexNet/outputs/31->AlexNet/outputs/32 -->\r\n",
|
||||
"<g id=\"edge4\" class=\"edge\"><title>AlexNet/outputs/31->AlexNet/outputs/32</title>\r\n",
|
||||
"<path fill=\"none\" stroke=\"#7c96bc\" d=\"M120,-716.813C120,-708.789 120,-699.047 120,-690.069\"/>\r\n",
|
||||
"<polygon fill=\"#7c96bc\" stroke=\"#7c96bc\" points=\"123.5,-690.029 120,-680.029 116.5,-690.029 123.5,-690.029\"/>\r\n",
|
||||
"</g>\r\n",
|
||||
"<!-- /outputs/34 -->\r\n",
|
||||
"<g id=\"node8\" class=\"node\"><title>/outputs/34</title>\r\n",
|
||||
"<g id=\"a_node8\"><a xlink:title=\"{'axes': [0]}\">\r\n",
|
||||
"<polygon fill=\"#bcd6fc\" stroke=\"#7c96bc\" points=\"158,-597 88,-597 88,-561 158,-561 158,-597\"/>\r\n",
|
||||
"<text text-anchor=\"start\" x=\"96\" y=\"-576\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">Unsqueeze</text>\r\n",
|
||||
"</a>\r\n",
|
||||
"</g>\r\n",
|
||||
"</g>\r\n",
|
||||
"<!-- AlexNet/outputs/32->/outputs/34 -->\r\n",
|
||||
"<g id=\"edge5\" class=\"edge\"><title>AlexNet/outputs/32->/outputs/34</title>\r\n",
|
||||
"<path fill=\"none\" stroke=\"#7c96bc\" d=\"M120.636,-643.822C121.03,-633.19 121.544,-619.306 121.992,-607.204\"/>\r\n",
|
||||
"<polygon fill=\"#7c96bc\" stroke=\"#7c96bc\" points=\"125.492,-607.276 122.365,-597.153 118.497,-607.017 125.492,-607.276\"/>\r\n",
|
||||
"</g>\r\n",
|
||||
"<!-- AlexNet/outputs/33 -->\r\n",
|
||||
"<g id=\"node7\" class=\"node\"><title>AlexNet/outputs/33</title>\r\n",
|
||||
"<g id=\"a_node7\"><a xlink:title=\"{'value': tensor(9216)}\">\r\n",
|
||||
"<polygon fill=\"#ff9f9b\" stroke=\"#7c96bc\" points=\"66,-680 4,-680 4,-644 66,-644 66,-680\"/>\r\n",
|
||||
"<text text-anchor=\"start\" x=\"12\" y=\"-659\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">Constant</text>\r\n",
|
||||
"</a>\r\n",
|
||||
"</g>\r\n",
|
||||
"</g>\r\n",
|
||||
"<!-- /outputs/35 -->\r\n",
|
||||
"<g id=\"node9\" class=\"node\"><title>/outputs/35</title>\r\n",
|
||||
"<g id=\"a_node9\"><a xlink:title=\"{'axes': [0]}\">\r\n",
|
||||
"<polygon fill=\"#bcd6fc\" stroke=\"#7c96bc\" points=\"70,-597 0,-597 0,-561 70,-561 70,-597\"/>\r\n",
|
||||
"<text text-anchor=\"start\" x=\"8\" y=\"-576\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">Unsqueeze</text>\r\n",
|
||||
"</a>\r\n",
|
||||
"</g>\r\n",
|
||||
"</g>\r\n",
|
||||
"<!-- AlexNet/outputs/33->/outputs/35 -->\r\n",
|
||||
"<g id=\"edge6\" class=\"edge\"><title>AlexNet/outputs/33->/outputs/35</title>\r\n",
|
||||
"<path fill=\"none\" stroke=\"#7c96bc\" d=\"M35,-643.822C35,-633.19 35,-619.306 35,-607.204\"/>\r\n",
|
||||
"<polygon fill=\"#7c96bc\" stroke=\"#7c96bc\" points=\"38.5001,-607.153 35,-597.153 31.5001,-607.153 38.5001,-607.153\"/>\r\n",
|
||||
"</g>\r\n",
|
||||
"<!-- /outputs/36 -->\r\n",
|
||||
"<g id=\"node10\" class=\"node\"><title>/outputs/36</title>\r\n",
|
||||
"<g id=\"a_node10\"><a xlink:title=\"{'axis': 0}\">\r\n",
|
||||
"<polygon fill=\"#d0bbff\" stroke=\"#7c96bc\" points=\"150,-524 96,-524 96,-488 150,-488 150,-524\"/>\r\n",
|
||||
"<text text-anchor=\"start\" x=\"105\" y=\"-503\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">Concat</text>\r\n",
|
||||
"</a>\r\n",
|
||||
"</g>\r\n",
|
||||
"</g>\r\n",
|
||||
"<!-- /outputs/34->/outputs/36 -->\r\n",
|
||||
"<g id=\"edge7\" class=\"edge\"><title>/outputs/34->/outputs/36</title>\r\n",
|
||||
"<path fill=\"none\" stroke=\"#7c96bc\" d=\"M123,-560.813C123,-552.789 123,-543.047 123,-534.069\"/>\r\n",
|
||||
"<polygon fill=\"#7c96bc\" stroke=\"#7c96bc\" points=\"126.5,-534.029 123,-524.029 119.5,-534.029 126.5,-534.029\"/>\r\n",
|
||||
"</g>\r\n",
|
||||
"<!-- /outputs/35->/outputs/36 -->\r\n",
|
||||
"<g id=\"edge8\" class=\"edge\"><title>/outputs/35->/outputs/36</title>\r\n",
|
||||
"<path fill=\"none\" stroke=\"#7c96bc\" d=\"M56.3023,-560.813C67.652,-551.656 81.7763,-540.26 94.1015,-530.316\"/>\r\n",
|
||||
"<polygon fill=\"#7c96bc\" stroke=\"#7c96bc\" points=\"96.3088,-533.032 101.894,-524.029 91.9133,-527.584 96.3088,-533.032\"/>\r\n",
|
||||
"</g>\r\n",
|
||||
"<!-- /outputs/36->AlexNet/outputs/37 -->\r\n",
|
||||
"<g id=\"edge9\" class=\"edge\"><title>/outputs/36->AlexNet/outputs/37</title>\r\n",
|
||||
"<path fill=\"none\" stroke=\"#7c96bc\" d=\"M133.167,-487.813C138.12,-479.441 144.179,-469.197 149.677,-459.903\"/>\r\n",
|
||||
"<polygon fill=\"#7c96bc\" stroke=\"#7c96bc\" points=\"152.848,-461.418 154.927,-451.029 146.823,-457.854 152.848,-461.418\"/>\r\n",
|
||||
"</g>\r\n",
|
||||
"<!-- AlexNet/Sequential[classifier]/Dropout[0]/outputs/38/39 -->\r\n",
|
||||
"<g id=\"node12\" class=\"node\"><title>AlexNet/Sequential[classifier]/Dropout[0]/outputs/38/39</title>\r\n",
|
||||
"<g id=\"a_node12\"><a xlink:title=\"{'ratio': 0.5}\">\r\n",
|
||||
"<polygon fill=\"#b9f2f0\" stroke=\"#7c96bc\" points=\"202.5,-368 127.5,-368 127.5,-332 202.5,-332 202.5,-368\"/>\r\n",
|
||||
"<text text-anchor=\"start\" x=\"136\" y=\"-347\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">Dropout 0.5</text>\r\n",
|
||||
"</a>\r\n",
|
||||
"</g>\r\n",
|
||||
"</g>\r\n",
|
||||
"<!-- AlexNet/outputs/37->AlexNet/Sequential[classifier]/Dropout[0]/outputs/38/39 -->\r\n",
|
||||
"<g id=\"edge10\" class=\"edge\"><title>AlexNet/outputs/37->AlexNet/Sequential[classifier]/Dropout[0]/outputs/38/39</title>\r\n",
|
||||
"<path fill=\"none\" stroke=\"#7c96bc\" d=\"M165,-414.822C165,-404.19 165,-390.306 165,-378.204\"/>\r\n",
|
||||
"<polygon fill=\"#7c96bc\" stroke=\"#7c96bc\" points=\"168.5,-378.153 165,-368.153 161.5,-378.153 168.5,-378.153\"/>\r\n",
|
||||
"<text text-anchor=\"middle\" x=\"185\" y=\"-389\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\"> 1x9216</text>\r\n",
|
||||
"</g>\r\n",
|
||||
"<!-- 10523010716743172207 -->\r\n",
|
||||
"<g id=\"node17\" class=\"node\"><title>10523010716743172207</title>\r\n",
|
||||
"<g id=\"a_node17\"><a xlink:title=\"{'alpha': 1.0, 'beta': 1.0, 'transB': 1}\">\r\n",
|
||||
"<polygon fill=\"#8de5a1\" stroke=\"#7c96bc\" points=\"205,-285 125,-285 125,-249 205,-249 205,-285\"/>\r\n",
|
||||
"<text text-anchor=\"start\" x=\"133\" y=\"-264\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">Linear > Relu</text>\r\n",
|
||||
"</a>\r\n",
|
||||
"</g>\r\n",
|
||||
"</g>\r\n",
|
||||
"<!-- AlexNet/Sequential[classifier]/Dropout[0]/outputs/38/39->10523010716743172207 -->\r\n",
|
||||
"<g id=\"edge14\" class=\"edge\"><title>AlexNet/Sequential[classifier]/Dropout[0]/outputs/38/39->10523010716743172207</title>\r\n",
|
||||
"<path fill=\"none\" stroke=\"#7c96bc\" d=\"M165,-331.822C165,-321.19 165,-307.306 165,-295.204\"/>\r\n",
|
||||
"<polygon fill=\"#7c96bc\" stroke=\"#7c96bc\" points=\"168.5,-295.153 165,-285.153 161.5,-295.153 168.5,-295.153\"/>\r\n",
|
||||
"<text text-anchor=\"middle\" x=\"185\" y=\"-306\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\"> 1x9216</text>\r\n",
|
||||
"</g>\r\n",
|
||||
"<!-- AlexNet/Sequential[classifier]/Dropout[3]/outputs/42/43 -->\r\n",
|
||||
"<g id=\"node13\" class=\"node\"><title>AlexNet/Sequential[classifier]/Dropout[3]/outputs/42/43</title>\r\n",
|
||||
"<g id=\"a_node13\"><a xlink:title=\"{'ratio': 0.5}\">\r\n",
|
||||
"<polygon fill=\"#b9f2f0\" stroke=\"#7c96bc\" points=\"202.5,-202 127.5,-202 127.5,-166 202.5,-166 202.5,-202\"/>\r\n",
|
||||
"<text text-anchor=\"start\" x=\"136\" y=\"-181\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">Dropout 0.5</text>\r\n",
|
||||
"</a>\r\n",
|
||||
"</g>\r\n",
|
||||
"</g>\r\n",
|
||||
"<!-- 4117491511057718684 -->\r\n",
|
||||
"<g id=\"node18\" class=\"node\"><title>4117491511057718684</title>\r\n",
|
||||
"<g id=\"a_node18\"><a xlink:title=\"{'alpha': 1.0, 'beta': 1.0, 'transB': 1}\">\r\n",
|
||||
"<polygon fill=\"#8de5a1\" stroke=\"#7c96bc\" points=\"205,-119 125,-119 125,-83 205,-83 205,-119\"/>\r\n",
|
||||
"<text text-anchor=\"start\" x=\"133\" y=\"-98\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">Linear > Relu</text>\r\n",
|
||||
"</a>\r\n",
|
||||
"</g>\r\n",
|
||||
"</g>\r\n",
|
||||
"<!-- AlexNet/Sequential[classifier]/Dropout[3]/outputs/42/43->4117491511057718684 -->\r\n",
|
||||
"<g id=\"edge16\" class=\"edge\"><title>AlexNet/Sequential[classifier]/Dropout[3]/outputs/42/43->4117491511057718684</title>\r\n",
|
||||
"<path fill=\"none\" stroke=\"#7c96bc\" d=\"M165,-165.822C165,-155.19 165,-141.306 165,-129.204\"/>\r\n",
|
||||
"<polygon fill=\"#7c96bc\" stroke=\"#7c96bc\" points=\"168.5,-129.153 165,-119.153 161.5,-129.153 168.5,-129.153\"/>\r\n",
|
||||
"<text text-anchor=\"middle\" x=\"185\" y=\"-140\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\"> 1x4096</text>\r\n",
|
||||
"</g>\r\n",
|
||||
"<!-- AlexNet/Sequential[classifier]/ReLU[5]/outputs/46 -->\r\n",
|
||||
"<g id=\"node14\" class=\"node\"><title>AlexNet/Sequential[classifier]/ReLU[5]/outputs/46</title>\r\n",
|
||||
"<g id=\"a_node14\"><a xlink:title=\"{'alpha': 1.0, 'beta': 1.0, 'transB': 1}\">\r\n",
|
||||
"<polygon fill=\"#4878d0\" stroke=\"#7c96bc\" points=\"192,-36 138,-36 138,-0 192,-0 192,-36\"/>\r\n",
|
||||
"<text text-anchor=\"start\" x=\"150\" y=\"-15\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">Linear</text>\r\n",
|
||||
"</a>\r\n",
|
||||
"</g>\r\n",
|
||||
"</g>\r\n",
|
||||
"<!-- 10377602221935690008 -->\r\n",
|
||||
"<g id=\"node15\" class=\"node\"><title>10377602221935690008</title>\r\n",
|
||||
"<g id=\"a_node15\"><a xlink:title=\"{'dilations': [1, 1], 'group': 1, 'kernel_shape': [11, 11], 'pads': [2, 2, 2, 2], 'strides': [4, 4]}\">\r\n",
|
||||
"<polygon fill=\"#a1c9f4\" stroke=\"#7c96bc\" points=\"203.5,-1259 98.5,-1259 98.5,-1223 203.5,-1223 203.5,-1259\"/>\r\n",
|
||||
"<text text-anchor=\"start\" x=\"107\" y=\"-1238\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\">Conv11x11 > Relu</text>\r\n",
|
||||
"</a>\r\n",
|
||||
"</g>\r\n",
|
||||
"</g>\r\n",
|
||||
"<!-- 10377602221935690008->AlexNet/Sequential[features]/MaxPool2d[2]/outputs/19 -->\r\n",
|
||||
"<g id=\"edge11\" class=\"edge\"><title>10377602221935690008->AlexNet/Sequential[features]/MaxPool2d[2]/outputs/19</title>\r\n",
|
||||
"<path fill=\"none\" stroke=\"#7c96bc\" d=\"M151,-1222.82C151,-1212.19 151,-1198.31 151,-1186.2\"/>\r\n",
|
||||
"<polygon fill=\"#7c96bc\" stroke=\"#7c96bc\" points=\"154.5,-1186.15 151,-1176.15 147.5,-1186.15 154.5,-1186.15\"/>\r\n",
|
||||
"<text text-anchor=\"middle\" x=\"182\" y=\"-1197\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\"> 1x64x55x55</text>\r\n",
|
||||
"</g>\r\n",
|
||||
"<!-- 5377442926079053455->AlexNet/Sequential[features]/MaxPool2d[5]/outputs/22 -->\r\n",
|
||||
"<g id=\"edge13\" class=\"edge\"><title>5377442926079053455->AlexNet/Sequential[features]/MaxPool2d[5]/outputs/22</title>\r\n",
|
||||
"<path fill=\"none\" stroke=\"#7c96bc\" d=\"M151,-1056.82C151,-1046.19 151,-1032.31 151,-1020.2\"/>\r\n",
|
||||
"<polygon fill=\"#7c96bc\" stroke=\"#7c96bc\" points=\"154.5,-1020.15 151,-1010.15 147.5,-1020.15 154.5,-1020.15\"/>\r\n",
|
||||
"<text text-anchor=\"middle\" x=\"185\" y=\"-1031\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\"> 1x192x27x27</text>\r\n",
|
||||
"</g>\r\n",
|
||||
"<!-- 10523010716743172207->AlexNet/Sequential[classifier]/Dropout[3]/outputs/42/43 -->\r\n",
|
||||
"<g id=\"edge15\" class=\"edge\"><title>10523010716743172207->AlexNet/Sequential[classifier]/Dropout[3]/outputs/42/43</title>\r\n",
|
||||
"<path fill=\"none\" stroke=\"#7c96bc\" d=\"M165,-248.822C165,-238.19 165,-224.306 165,-212.204\"/>\r\n",
|
||||
"<polygon fill=\"#7c96bc\" stroke=\"#7c96bc\" points=\"168.5,-212.153 165,-202.153 161.5,-212.153 168.5,-212.153\"/>\r\n",
|
||||
"<text text-anchor=\"middle\" x=\"185\" y=\"-223\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\"> 1x4096</text>\r\n",
|
||||
"</g>\r\n",
|
||||
"<!-- 4117491511057718684->AlexNet/Sequential[classifier]/ReLU[5]/outputs/46 -->\r\n",
|
||||
"<g id=\"edge17\" class=\"edge\"><title>4117491511057718684->AlexNet/Sequential[classifier]/ReLU[5]/outputs/46</title>\r\n",
|
||||
"<path fill=\"none\" stroke=\"#7c96bc\" d=\"M165,-82.822C165,-72.1903 165,-58.306 165,-46.2035\"/>\r\n",
|
||||
"<polygon fill=\"#7c96bc\" stroke=\"#7c96bc\" points=\"168.5,-46.1532 165,-36.1533 161.5,-46.1533 168.5,-46.1532\"/>\r\n",
|
||||
"<text text-anchor=\"middle\" x=\"185\" y=\"-57\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\"> 1x4096</text>\r\n",
|
||||
"</g>\r\n",
|
||||
"<!-- 7417167147267928641->AlexNet/Sequential[features]/MaxPool2d[12]/outputs/29 -->\r\n",
|
||||
"<g id=\"edge19\" class=\"edge\"><title>7417167147267928641->AlexNet/Sequential[features]/MaxPool2d[12]/outputs/29</title>\r\n",
|
||||
"<path fill=\"none\" stroke=\"#7c96bc\" d=\"M151,-882.989C151,-871.923 151,-858.219 151,-846.336\"/>\r\n",
|
||||
"<polygon fill=\"#7c96bc\" stroke=\"#7c96bc\" points=\"154.5,-846.062 151,-836.062 147.5,-846.062 154.5,-846.062\"/>\r\n",
|
||||
"<text text-anchor=\"middle\" x=\"185\" y=\"-857\" font-family=\"Verdana\" font-size=\"10.00\" fill=\"#202020\"> 1x256x13x13</text>\r\n",
|
||||
"</g>\r\n",
|
||||
"</g>\r\n",
|
||||
"</svg>\r\n"
|
||||
],
|
||||
"text/plain": [
|
||||
"<tensorwatch.model_graph.hiddenlayer.graph.Graph at 0x1d63302d5f8>"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"tw.draw_model(alexnet_model, [1, 3, 224, 224])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {
|
||||
"scrolled": false
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[MAdd]: Dropout is not supported!\n",
|
||||
"[Flops]: Dropout is not supported!\n",
|
||||
"[Memory]: Dropout is not supported!\n",
|
||||
"[MAdd]: Dropout is not supported!\n",
|
||||
"[Flops]: Dropout is not supported!\n",
|
||||
"[Memory]: Dropout is not supported!\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>module name</th>\n",
|
||||
" <th>input shape</th>\n",
|
||||
" <th>output shape</th>\n",
|
||||
" <th>params</th>\n",
|
||||
" <th>memory(MB)</th>\n",
|
||||
" <th>MAdd</th>\n",
|
||||
" <th>Flops</th>\n",
|
||||
" <th>MemRead(B)</th>\n",
|
||||
" <th>MemWrite(B)</th>\n",
|
||||
" <th>duration[%]</th>\n",
|
||||
" <th>MemR+W(B)</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>features.0</td>\n",
|
||||
" <td>3 224 224</td>\n",
|
||||
" <td>64 55 55</td>\n",
|
||||
" <td>23296.0</td>\n",
|
||||
" <td>0.74</td>\n",
|
||||
" <td>140,553,600.0</td>\n",
|
||||
" <td>70,470,400.0</td>\n",
|
||||
" <td>695296.0</td>\n",
|
||||
" <td>774400.0</td>\n",
|
||||
" <td>9.08%</td>\n",
|
||||
" <td>1469696.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>features.1</td>\n",
|
||||
" <td>64 55 55</td>\n",
|
||||
" <td>64 55 55</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.74</td>\n",
|
||||
" <td>193,600.0</td>\n",
|
||||
" <td>193,600.0</td>\n",
|
||||
" <td>774400.0</td>\n",
|
||||
" <td>774400.0</td>\n",
|
||||
" <td>2.59%</td>\n",
|
||||
" <td>1548800.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>features.2</td>\n",
|
||||
" <td>64 55 55</td>\n",
|
||||
" <td>64 27 27</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.18</td>\n",
|
||||
" <td>373,248.0</td>\n",
|
||||
" <td>193,600.0</td>\n",
|
||||
" <td>774400.0</td>\n",
|
||||
" <td>186624.0</td>\n",
|
||||
" <td>6.48%</td>\n",
|
||||
" <td>961024.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>features.3</td>\n",
|
||||
" <td>64 27 27</td>\n",
|
||||
" <td>192 27 27</td>\n",
|
||||
" <td>307392.0</td>\n",
|
||||
" <td>0.53</td>\n",
|
||||
" <td>447,897,600.0</td>\n",
|
||||
" <td>224,088,768.0</td>\n",
|
||||
" <td>1416192.0</td>\n",
|
||||
" <td>559872.0</td>\n",
|
||||
" <td>10.38%</td>\n",
|
||||
" <td>1976064.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>features.4</td>\n",
|
||||
" <td>192 27 27</td>\n",
|
||||
" <td>192 27 27</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.53</td>\n",
|
||||
" <td>139,968.0</td>\n",
|
||||
" <td>139,968.0</td>\n",
|
||||
" <td>559872.0</td>\n",
|
||||
" <td>559872.0</td>\n",
|
||||
" <td>1.30%</td>\n",
|
||||
" <td>1119744.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>5</th>\n",
|
||||
" <td>features.5</td>\n",
|
||||
" <td>192 27 27</td>\n",
|
||||
" <td>192 13 13</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.12</td>\n",
|
||||
" <td>259,584.0</td>\n",
|
||||
" <td>139,968.0</td>\n",
|
||||
" <td>559872.0</td>\n",
|
||||
" <td>129792.0</td>\n",
|
||||
" <td>3.89%</td>\n",
|
||||
" <td>689664.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>6</th>\n",
|
||||
" <td>features.6</td>\n",
|
||||
" <td>192 13 13</td>\n",
|
||||
" <td>384 13 13</td>\n",
|
||||
" <td>663936.0</td>\n",
|
||||
" <td>0.25</td>\n",
|
||||
" <td>224,280,576.0</td>\n",
|
||||
" <td>112,205,184.0</td>\n",
|
||||
" <td>2785536.0</td>\n",
|
||||
" <td>259584.0</td>\n",
|
||||
" <td>5.19%</td>\n",
|
||||
" <td>3045120.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>7</th>\n",
|
||||
" <td>features.7</td>\n",
|
||||
" <td>384 13 13</td>\n",
|
||||
" <td>384 13 13</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.25</td>\n",
|
||||
" <td>64,896.0</td>\n",
|
||||
" <td>64,896.0</td>\n",
|
||||
" <td>259584.0</td>\n",
|
||||
" <td>259584.0</td>\n",
|
||||
" <td>0.00%</td>\n",
|
||||
" <td>519168.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>8</th>\n",
|
||||
" <td>features.8</td>\n",
|
||||
" <td>384 13 13</td>\n",
|
||||
" <td>256 13 13</td>\n",
|
||||
" <td>884992.0</td>\n",
|
||||
" <td>0.17</td>\n",
|
||||
" <td>299,040,768.0</td>\n",
|
||||
" <td>149,563,648.0</td>\n",
|
||||
" <td>3799552.0</td>\n",
|
||||
" <td>173056.0</td>\n",
|
||||
" <td>10.37%</td>\n",
|
||||
" <td>3972608.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>9</th>\n",
|
||||
" <td>features.9</td>\n",
|
||||
" <td>256 13 13</td>\n",
|
||||
" <td>256 13 13</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.17</td>\n",
|
||||
" <td>43,264.0</td>\n",
|
||||
" <td>43,264.0</td>\n",
|
||||
" <td>173056.0</td>\n",
|
||||
" <td>173056.0</td>\n",
|
||||
" <td>1.30%</td>\n",
|
||||
" <td>346112.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>10</th>\n",
|
||||
" <td>features.10</td>\n",
|
||||
" <td>256 13 13</td>\n",
|
||||
" <td>256 13 13</td>\n",
|
||||
" <td>590080.0</td>\n",
|
||||
" <td>0.17</td>\n",
|
||||
" <td>199,360,512.0</td>\n",
|
||||
" <td>99,723,520.0</td>\n",
|
||||
" <td>2533376.0</td>\n",
|
||||
" <td>173056.0</td>\n",
|
||||
" <td>11.67%</td>\n",
|
||||
" <td>2706432.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>11</th>\n",
|
||||
" <td>features.11</td>\n",
|
||||
" <td>256 13 13</td>\n",
|
||||
" <td>256 13 13</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.17</td>\n",
|
||||
" <td>43,264.0</td>\n",
|
||||
" <td>43,264.0</td>\n",
|
||||
" <td>173056.0</td>\n",
|
||||
" <td>173056.0</td>\n",
|
||||
" <td>0.00%</td>\n",
|
||||
" <td>346112.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>12</th>\n",
|
||||
" <td>features.12</td>\n",
|
||||
" <td>256 13 13</td>\n",
|
||||
" <td>256 6 6</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.04</td>\n",
|
||||
" <td>73,728.0</td>\n",
|
||||
" <td>43,264.0</td>\n",
|
||||
" <td>173056.0</td>\n",
|
||||
" <td>36864.0</td>\n",
|
||||
" <td>1.30%</td>\n",
|
||||
" <td>209920.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>13</th>\n",
|
||||
" <td>classifier.0</td>\n",
|
||||
" <td>9216</td>\n",
|
||||
" <td>9216</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.04</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.00%</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>14</th>\n",
|
||||
" <td>classifier.1</td>\n",
|
||||
" <td>9216</td>\n",
|
||||
" <td>4096</td>\n",
|
||||
" <td>37752832.0</td>\n",
|
||||
" <td>0.02</td>\n",
|
||||
" <td>75,493,376.0</td>\n",
|
||||
" <td>37,748,736.0</td>\n",
|
||||
" <td>151048192.0</td>\n",
|
||||
" <td>16384.0</td>\n",
|
||||
" <td>22.82%</td>\n",
|
||||
" <td>151064576.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>15</th>\n",
|
||||
" <td>classifier.2</td>\n",
|
||||
" <td>4096</td>\n",
|
||||
" <td>4096</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.02</td>\n",
|
||||
" <td>4,096.0</td>\n",
|
||||
" <td>4,096.0</td>\n",
|
||||
" <td>16384.0</td>\n",
|
||||
" <td>16384.0</td>\n",
|
||||
" <td>0.00%</td>\n",
|
||||
" <td>32768.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>16</th>\n",
|
||||
" <td>classifier.3</td>\n",
|
||||
" <td>4096</td>\n",
|
||||
" <td>4096</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.02</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.00%</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>17</th>\n",
|
||||
" <td>classifier.4</td>\n",
|
||||
" <td>4096</td>\n",
|
||||
" <td>4096</td>\n",
|
||||
" <td>16781312.0</td>\n",
|
||||
" <td>0.02</td>\n",
|
||||
" <td>33,550,336.0</td>\n",
|
||||
" <td>16,777,216.0</td>\n",
|
||||
" <td>67141632.0</td>\n",
|
||||
" <td>16384.0</td>\n",
|
||||
" <td>9.76%</td>\n",
|
||||
" <td>67158016.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>18</th>\n",
|
||||
" <td>classifier.5</td>\n",
|
||||
" <td>4096</td>\n",
|
||||
" <td>4096</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.02</td>\n",
|
||||
" <td>4,096.0</td>\n",
|
||||
" <td>4,096.0</td>\n",
|
||||
" <td>16384.0</td>\n",
|
||||
" <td>16384.0</td>\n",
|
||||
" <td>1.30%</td>\n",
|
||||
" <td>32768.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>19</th>\n",
|
||||
" <td>classifier.6</td>\n",
|
||||
" <td>4096</td>\n",
|
||||
" <td>1000</td>\n",
|
||||
" <td>4097000.0</td>\n",
|
||||
" <td>0.00</td>\n",
|
||||
" <td>8,191,000.0</td>\n",
|
||||
" <td>4,096,000.0</td>\n",
|
||||
" <td>16404384.0</td>\n",
|
||||
" <td>4000.0</td>\n",
|
||||
" <td>2.59%</td>\n",
|
||||
" <td>16408384.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>total</th>\n",
|
||||
" <td></td>\n",
|
||||
" <td></td>\n",
|
||||
" <td></td>\n",
|
||||
" <td>61100840.0</td>\n",
|
||||
" <td>4.15</td>\n",
|
||||
" <td>1,429,567,512.0</td>\n",
|
||||
" <td>715,543,488.0</td>\n",
|
||||
" <td>16404384.0</td>\n",
|
||||
" <td>4000.0</td>\n",
|
||||
" <td>100.00%</td>\n",
|
||||
" <td>253606976.0</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" module name input shape output shape params memory(MB) MAdd Flops MemRead(B) MemWrite(B) duration[%] MemR+W(B)\n",
|
||||
"0 features.0 3 224 224 64 55 55 23296.0 0.74 140,553,600.0 70,470,400.0 695296.0 774400.0 9.08% 1469696.0\n",
|
||||
"1 features.1 64 55 55 64 55 55 0.0 0.74 193,600.0 193,600.0 774400.0 774400.0 2.59% 1548800.0\n",
|
||||
"2 features.2 64 55 55 64 27 27 0.0 0.18 373,248.0 193,600.0 774400.0 186624.0 6.48% 961024.0\n",
|
||||
"3 features.3 64 27 27 192 27 27 307392.0 0.53 447,897,600.0 224,088,768.0 1416192.0 559872.0 10.38% 1976064.0\n",
|
||||
"4 features.4 192 27 27 192 27 27 0.0 0.53 139,968.0 139,968.0 559872.0 559872.0 1.30% 1119744.0\n",
|
||||
"5 features.5 192 27 27 192 13 13 0.0 0.12 259,584.0 139,968.0 559872.0 129792.0 3.89% 689664.0\n",
|
||||
"6 features.6 192 13 13 384 13 13 663936.0 0.25 224,280,576.0 112,205,184.0 2785536.0 259584.0 5.19% 3045120.0\n",
|
||||
"7 features.7 384 13 13 384 13 13 0.0 0.25 64,896.0 64,896.0 259584.0 259584.0 0.00% 519168.0\n",
|
||||
"8 features.8 384 13 13 256 13 13 884992.0 0.17 299,040,768.0 149,563,648.0 3799552.0 173056.0 10.37% 3972608.0\n",
|
||||
"9 features.9 256 13 13 256 13 13 0.0 0.17 43,264.0 43,264.0 173056.0 173056.0 1.30% 346112.0\n",
|
||||
"10 features.10 256 13 13 256 13 13 590080.0 0.17 199,360,512.0 99,723,520.0 2533376.0 173056.0 11.67% 2706432.0\n",
|
||||
"11 features.11 256 13 13 256 13 13 0.0 0.17 43,264.0 43,264.0 173056.0 173056.0 0.00% 346112.0\n",
|
||||
"12 features.12 256 13 13 256 6 6 0.0 0.04 73,728.0 43,264.0 173056.0 36864.0 1.30% 209920.0\n",
|
||||
"13 classifier.0 9216 9216 0.0 0.04 0.0 0.0 0.0 0.0 0.00% 0.0\n",
|
||||
"14 classifier.1 9216 4096 37752832.0 0.02 75,493,376.0 37,748,736.0 151048192.0 16384.0 22.82% 151064576.0\n",
|
||||
"15 classifier.2 4096 4096 0.0 0.02 4,096.0 4,096.0 16384.0 16384.0 0.00% 32768.0\n",
|
||||
"16 classifier.3 4096 4096 0.0 0.02 0.0 0.0 0.0 0.0 0.00% 0.0\n",
|
||||
"17 classifier.4 4096 4096 16781312.0 0.02 33,550,336.0 16,777,216.0 67141632.0 16384.0 9.76% 67158016.0\n",
|
||||
"18 classifier.5 4096 4096 0.0 0.02 4,096.0 4,096.0 16384.0 16384.0 1.30% 32768.0\n",
|
||||
"19 classifier.6 4096 1000 4097000.0 0.00 8,191,000.0 4,096,000.0 16404384.0 4000.0 2.59% 16408384.0\n",
|
||||
"total 61100840.0 4.15 1,429,567,512.0 715,543,488.0 16404384.0 4000.0 100.00% 253606976.0"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"tw.model_stats(alexnet_model, [1, 3, 224, 224])"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.6.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
|
@ -0,0 +1,28 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import setuptools
|
||||
|
||||
with open("README.md", "r") as fh:
|
||||
long_description = fh.read()
|
||||
|
||||
setuptools.setup(
|
||||
name="tensorwatch",
|
||||
version="0.6.0",
|
||||
author="Shital Shah",
|
||||
author_email="shitals@microsoft.com",
|
||||
description="Interactive Realtime Debugging and Visualization for AI",
|
||||
long_description=long_description,
|
||||
long_description_content_type="text/markdown",
|
||||
url="https://github.com/sytelus/tensorwatch",
|
||||
packages=setuptools.find_packages(),
|
||||
license='MIT',
|
||||
classifiers=(
|
||||
"Programming Language :: Python :: 3",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: OS Independent",
|
||||
),
|
||||
install_requires=[
|
||||
'matplotlib', 'numpy', 'pyzmq', 'plotly', 'torchstat' # , 'receptivefield'
|
||||
]
|
||||
)
|
|
@ -0,0 +1,175 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003" DefaultTargets="Build">
|
||||
<PropertyGroup>
|
||||
<Configuration Condition=" '$(Configuration)' == '' ">Debug</Configuration>
|
||||
<SchemaVersion>2.0</SchemaVersion>
|
||||
<ProjectGuid>{cc8abc7f-ede1-4e13-b6b7-0041a5ec66a7}</ProjectGuid>
|
||||
<ProjectHome />
|
||||
<StartupFile>tensorwatch\v2\zmq_watcher_server.py</StartupFile>
|
||||
<SearchPath />
|
||||
<WorkingDirectory>.</WorkingDirectory>
|
||||
<OutputPath>.</OutputPath>
|
||||
<ProjectTypeGuids>{888888a0-9f3d-457c-b088-3a5042f75d52}</ProjectTypeGuids>
|
||||
<LaunchProvider>Standard Python launcher</LaunchProvider>
|
||||
<InterpreterId>Global|ContinuumAnalytics|Anaconda36-64</InterpreterId>
|
||||
</PropertyGroup>
|
||||
<PropertyGroup Condition="'$(Configuration)' == 'Debug'" />
|
||||
<PropertyGroup Condition="'$(Configuration)' == 'Release'" />
|
||||
<PropertyGroup>
|
||||
<VisualStudioVersion Condition=" '$(VisualStudioVersion)' == '' ">10.0</VisualStudioVersion>
|
||||
</PropertyGroup>
|
||||
<ItemGroup>
|
||||
<Compile Include="tensorwatch\embeddings\__init__.py" />
|
||||
<Compile Include="tensorwatch\stream_union.py">
|
||||
<SubType>Code</SubType>
|
||||
</Compile>
|
||||
<Compile Include="tensorwatch\file_stream.py">
|
||||
<SubType>Code</SubType>
|
||||
</Compile>
|
||||
<Compile Include="tensorwatch\filtered_stream.py" />
|
||||
<Compile Include="tensorwatch\stream.py" />
|
||||
<Compile Include="tensorwatch\stream_factory.py" />
|
||||
<Compile Include="tensorwatch\visualizer.py">
|
||||
<SubType>Code</SubType>
|
||||
</Compile>
|
||||
<Compile Include="tensorwatch\vis_base.py">
|
||||
<SubType>Code</SubType>
|
||||
</Compile>
|
||||
<Compile Include="tensorwatch\saliency\lime\lime_base.py" />
|
||||
<Compile Include="tensorwatch\saliency\lime\lime_image.py" />
|
||||
<Compile Include="tensorwatch\saliency\lime\wrappers\scikit_image.py" />
|
||||
<Compile Include="tensorwatch\saliency\lime\wrappers\__init__.py" />
|
||||
<Compile Include="tensorwatch\saliency\lime\__init__.py" />
|
||||
<Compile Include="tensorwatch\pytorch_utils.py">
|
||||
<SubType>Code</SubType>
|
||||
</Compile>
|
||||
<Compile Include="tensorwatch\array_stream.py">
|
||||
<SubType>Code</SubType>
|
||||
</Compile>
|
||||
<Compile Include="tensorwatch\data_utils.py">
|
||||
<SubType>Code</SubType>
|
||||
</Compile>
|
||||
<Compile Include="tensorwatch\imagenet_utils.py">
|
||||
<SubType>Code</SubType>
|
||||
</Compile>
|
||||
<Compile Include="tensorwatch\image_utils.py">
|
||||
<SubType>Code</SubType>
|
||||
</Compile>
|
||||
<Compile Include="tensorwatch\model_graph\hiddenlayer\ge.py" />
|
||||
<Compile Include="tensorwatch\model_graph\hiddenlayer\graph.py" />
|
||||
<Compile Include="tensorwatch\model_graph\hiddenlayer\pytorch_builder.py" />
|
||||
<Compile Include="tensorwatch\model_graph\hiddenlayer\tf_builder.py" />
|
||||
<Compile Include="tensorwatch\model_graph\hiddenlayer\transforms.py" />
|
||||
<Compile Include="tensorwatch\model_graph\hiddenlayer\__init__.py" />
|
||||
<Compile Include="tensorwatch\model_graph\torchstat_utils.py">
|
||||
<SubType>Code</SubType>
|
||||
</Compile>
|
||||
<Compile Include="tensorwatch\model_graph\__init__.py" />
|
||||
<Compile Include="tensorwatch\mpl\base_mpl_plot.py">
|
||||
<SubType>Code</SubType>
|
||||
</Compile>
|
||||
<Compile Include="tensorwatch\evaler.py" />
|
||||
<Compile Include="tensorwatch\evaler_utils.py">
|
||||
<SubType>Code</SubType>
|
||||
</Compile>
|
||||
<Compile Include="tensorwatch\mpl\image_plot.py">
|
||||
<SubType>Code</SubType>
|
||||
</Compile>
|
||||
<Compile Include="tensorwatch\lv_types.py" />
|
||||
<Compile Include="tensorwatch\mpl\line_plot.py" />
|
||||
<Compile Include="tensorwatch\mpl\__init__.py">
|
||||
<SubType>Code</SubType>
|
||||
</Compile>
|
||||
<Compile Include="tensorwatch\plotly\base_plotly_plot.py">
|
||||
<SubType>Code</SubType>
|
||||
</Compile>
|
||||
<Compile Include="tensorwatch\plotly\line_plot.py" />
|
||||
<Compile Include="tensorwatch\plotly\embeddings_plot.py">
|
||||
<SubType>Code</SubType>
|
||||
</Compile>
|
||||
<Compile Include="tensorwatch\plotly\__init__.py" />
|
||||
<Compile Include="tensorwatch\receptive_field\rf_utils.py">
|
||||
<SubType>Code</SubType>
|
||||
</Compile>
|
||||
<Compile Include="tensorwatch\repeated_timer.py">
|
||||
<SubType>Code</SubType>
|
||||
</Compile>
|
||||
<Compile Include="tensorwatch\embeddings\tsne_utils.py">
|
||||
<SubType>Code</SubType>
|
||||
</Compile>
|
||||
<Compile Include="tensorwatch\saliency\backprop.py" />
|
||||
<Compile Include="tensorwatch\saliency\deeplift.py" />
|
||||
<Compile Include="tensorwatch\saliency\epsilon_lrp.py" />
|
||||
<Compile Include="tensorwatch\saliency\gradcam.py" />
|
||||
<Compile Include="tensorwatch\saliency\inverter_util.py" />
|
||||
<Compile Include="tensorwatch\saliency\lime_image_explainer.py">
|
||||
<SubType>Code</SubType>
|
||||
</Compile>
|
||||
<Compile Include="tensorwatch\saliency\occlusion.py" />
|
||||
<Compile Include="tensorwatch\saliency\saliency.py">
|
||||
<SubType>Code</SubType>
|
||||
</Compile>
|
||||
<Compile Include="tensorwatch\saliency\__init__.py" />
|
||||
<Compile Include="tensorwatch\text_vis.py">
|
||||
<SubType>Code</SubType>
|
||||
</Compile>
|
||||
<Compile Include="tensorwatch\utils.py" />
|
||||
<Compile Include="tensorwatch\watcher_base.py" />
|
||||
<Compile Include="tensorwatch\zmq_mgmt_stream.py">
|
||||
<SubType>Code</SubType>
|
||||
</Compile>
|
||||
<Compile Include="tensorwatch\zmq_wrapper.py" />
|
||||
<Compile Include="tensorwatch\zmq_stream.py" />
|
||||
<Compile Include="tensorwatch\watcher_client.py" />
|
||||
<Compile Include="tensorwatch\watcher.py" />
|
||||
<Compile Include="tensorwatch\__init__.py" />
|
||||
<Compile Include="setup.py" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<Folder Include="tensorwatch" />
|
||||
<Folder Include="tensorwatch\saliency\lime\wrappers\" />
|
||||
<Folder Include="tensorwatch\model_graph\" />
|
||||
<Folder Include="tensorwatch\model_graph\hiddenlayer\" />
|
||||
<Folder Include="tensorwatch\model_graph\hiddenlayer\__pycache__\" />
|
||||
<Folder Include="tensorwatch\mpl\" />
|
||||
<Folder Include="tensorwatch\embeddings\" />
|
||||
<Folder Include="tensorwatch\saliency\lime\" />
|
||||
<Folder Include="tensorwatch\receptive_field\" />
|
||||
<Folder Include="tensorwatch\plotly\" />
|
||||
<Folder Include="tensorwatch\saliency\" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<InterpreterReference Include="Global|ContinuumAnalytics|Anaconda36-64" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<Content Include=".gitignore" />
|
||||
<Content Include=".pylintrc">
|
||||
<SubType>Code</SubType>
|
||||
</Content>
|
||||
<Content Include="CONTRIBUTING.md">
|
||||
<SubType>Code</SubType>
|
||||
</Content>
|
||||
<Content Include="install_jupyterlab.bat" />
|
||||
<Content Include="LICENSE.TXT">
|
||||
<SubType>Code</SubType>
|
||||
</Content>
|
||||
<Content Include="NOTICE.md">
|
||||
<SubType>Code</SubType>
|
||||
</Content>
|
||||
<Content Include="README.md" />
|
||||
<Content Include="tensorwatch\model_graph\hiddenlayer\README.md">
|
||||
<SubType>Code</SubType>
|
||||
</Content>
|
||||
<Content Include="tensorwatch\model_graph\hiddenlayer\__pycache__\ge.cpython-36.pyc" />
|
||||
<Content Include="tensorwatch\model_graph\hiddenlayer\__pycache__\graph.cpython-36.pyc" />
|
||||
<Content Include="tensorwatch\model_graph\hiddenlayer\__pycache__\pytorch_builder.cpython-36.pyc" />
|
||||
<Content Include="tensorwatch\model_graph\hiddenlayer\__pycache__\transforms.cpython-36.pyc" />
|
||||
<Content Include="tensorwatch\model_graph\hiddenlayer\__pycache__\__init__.cpython-36.pyc" />
|
||||
<Content Include="tensorwatch\saliency\README.md">
|
||||
<SubType>Code</SubType>
|
||||
</Content>
|
||||
<Content Include="TODO.md" />
|
||||
<Content Include="update_package.bat" />
|
||||
</ItemGroup>
|
||||
<Import Project="$(MSBuildExtensionsPath32)\Microsoft\VisualStudio\v$(VisualStudioVersion)\Python Tools\Microsoft.PythonTools.targets" />
|
||||
</Project>
|
|
@ -0,0 +1,23 @@
|
|||
|
||||
Microsoft Visual Studio Solution File, Format Version 12.00
|
||||
# Visual Studio 15
|
||||
VisualStudioVersion = 15.0.27428.2043
|
||||
MinimumVisualStudioVersion = 10.0.40219.1
|
||||
Project("{888888A0-9F3D-457C-B088-3A5042F75D52}") = "tensorwatch", "tensorwatch.pyproj", "{CC8ABC7F-EDE1-4E13-B6B7-0041A5EC66A7}"
|
||||
EndProject
|
||||
Global
|
||||
GlobalSection(SolutionConfigurationPlatforms) = preSolution
|
||||
Debug|Any CPU = Debug|Any CPU
|
||||
Release|Any CPU = Release|Any CPU
|
||||
EndGlobalSection
|
||||
GlobalSection(ProjectConfigurationPlatforms) = postSolution
|
||||
{CC8ABC7F-EDE1-4E13-B6B7-0041A5EC66A7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
|
||||
{CC8ABC7F-EDE1-4E13-B6B7-0041A5EC66A7}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||
EndGlobalSection
|
||||
GlobalSection(SolutionProperties) = preSolution
|
||||
HideSolutionNode = FALSE
|
||||
EndGlobalSection
|
||||
GlobalSection(ExtensibilityGlobals) = postSolution
|
||||
SolutionGuid = {99E7AEC7-2CDE-48C8-B98B-4E28E4F840B6}
|
||||
EndGlobalSection
|
||||
EndGlobal
|
|
@ -0,0 +1,34 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import Iterable, Sequence, Union
|
||||
|
||||
from .watcher_client import WatcherClient
|
||||
from .watcher import Watcher
|
||||
from .watcher_base import WatcherBase
|
||||
|
||||
from .text_vis import TextVis
|
||||
from .plotly import EmbeddingsPlot
|
||||
from .mpl import LinePlot, ImagePlot
|
||||
from .visualizer import Visualizer
|
||||
|
||||
from .stream import Stream
|
||||
from .array_stream import ArrayStream
|
||||
from .lv_types import ImagePlotItem, VisParams
|
||||
from . import utils
|
||||
|
||||
###### Import methods for tw namespace #########
|
||||
#from .receptive_field.rf_utils import plot_receptive_field, plot_grads_at
|
||||
from .embeddings.tsne_utils import get_tsne_components
|
||||
from .model_graph.torchstat_utils import model_stats
|
||||
from .image_utils import show_image, open_image, img2pyt, linear_to_2d, plt_loop
|
||||
from .data_utils import pyt_ds2list, sample_by_class, col2array, search_similar
|
||||
|
||||
|
||||
|
||||
def draw_model(model, input_shape=None, orientation='TB'): #orientation = 'LR' for landscpe
|
||||
from .model_graph.hiddenlayer import graph
|
||||
g = graph.build_graph(model, input_shape, orientation=orientation)
|
||||
return g
|
||||
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .stream import Stream
|
||||
from .lv_types import StreamItem
|
||||
import uuid
|
||||
|
||||
class ArrayStream(Stream):
|
||||
def __init__(self, array, stream_name:str=None, console_debug:bool=False):
|
||||
super(ArrayStream, self).__init__(stream_name=stream_name, console_debug=console_debug)
|
||||
|
||||
self.stream_name = stream_name
|
||||
self.array = array
|
||||
self.creator_id = str(uuid.uuid4())
|
||||
|
||||
|
||||
def load(self, from_stream:'Stream'=None):
|
||||
if self.array is not None:
|
||||
stream_item = StreamItem(item_index=0, value=self.array,
|
||||
stream_name=self.stream_name, creator_id=self.creator_id, stream_index=0)
|
||||
self.write(stream_item)
|
|
@ -0,0 +1,60 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import random
|
||||
import scipy.spatial.distance
|
||||
import heapq
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
def pyt_tensor2np(pyt_tensor):
|
||||
if pyt_tensor is None:
|
||||
return None
|
||||
if isinstance(pyt_tensor, torch.Tensor):
|
||||
n = pyt_tensor.data.cpu().numpy()
|
||||
if len(n.shape) == 1:
|
||||
return n[0]
|
||||
else:
|
||||
return n
|
||||
elif isinstance(pyt_tensor, np.ndarray):
|
||||
return pyt_tensor
|
||||
else:
|
||||
return np.array(pyt_tensor)
|
||||
|
||||
def pyt_tuple2np(pyt_tuple):
|
||||
return tuple((pyt_tensor2np(t) for t in pyt_tuple))
|
||||
|
||||
def pyt_ds2list(pyt_ds, count=None):
|
||||
count = count or len(pyt_ds)
|
||||
return [pyt_tuple2np(t) for t, c in zip(pyt_ds, range(count))]
|
||||
|
||||
def sample_by_class(data, n_samples, class_col=1, shuffle=True):
|
||||
if shuffle:
|
||||
random.shuffle(data)
|
||||
samples = {}
|
||||
for i, t in enumerate(data):
|
||||
cls = t[class_col]
|
||||
if cls not in samples:
|
||||
samples[cls] = []
|
||||
if len(samples[cls]) < n_samples:
|
||||
samples[cls].append(data[i])
|
||||
samples = sum(samples.values(), [])
|
||||
return samples
|
||||
|
||||
def col2array(dataset, col):
|
||||
return [row[col] for row in dataset]
|
||||
|
||||
def search_similar(inputs, compare_to, algorithm='euclidean', topk=5, invert_score=True):
|
||||
all_scores = scipy.spatial.distance.cdist(inputs, compare_to, algorithm)
|
||||
all_results = []
|
||||
for input_val, scores in zip(inputs, all_scores):
|
||||
result = []
|
||||
for i, (score, data) in enumerate(zip(scores, compare_to)):
|
||||
if invert_score:
|
||||
score = 1/(score + 1.0E-6)
|
||||
if len(result) < topk:
|
||||
heapq.heappush(result, (score, (i, input_val, data)))
|
||||
else:
|
||||
heapq.heappushpop(result, (score, (i, input_val, data)))
|
||||
all_results.append(result)
|
||||
return all_results
|
|
@ -0,0 +1,4 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from sklearn.manifold import TSNE
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
|
||||
def _standardize_data(data, col, whitten, flatten):
|
||||
if col is not None:
|
||||
data = data[col]
|
||||
|
||||
#TODO: enable auto flattening
|
||||
#if data is tensor then flatten it first
|
||||
#if flatten and len(data) > 0 and hasattr(data[0], 'shape') and \
|
||||
# utils.has_method(data[0], 'reshape'):
|
||||
|
||||
# data = [d.reshape((-1,)) for d in data]
|
||||
|
||||
if whitten:
|
||||
data = StandardScaler().fit_transform(data)
|
||||
return data
|
||||
|
||||
def get_tsne_components(data, features_col=0, labels_col=1, whitten=True, n_components=3, perplexity=20, flatten=True, for_plot=True):
|
||||
features = _standardize_data(data, features_col, whitten, flatten)
|
||||
tsne = TSNE(n_components=n_components, perplexity=perplexity)
|
||||
tsne_results = tsne.fit_transform(features)
|
||||
|
||||
if for_plot:
|
||||
comps = tsne_results.tolist()
|
||||
labels = data[labels_col]
|
||||
for i, item in enumerate(comps):
|
||||
item.append(None) # annotation
|
||||
item.append(str(labels[i])) # text
|
||||
item.append(labels[i]) # color
|
||||
return comps
|
||||
return tsne_results
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,123 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import threading
|
||||
import sys
|
||||
from collections.abc import Iterator
|
||||
from .lv_types import EventVars
|
||||
|
||||
# pylint: disable=unused-wildcard-import
|
||||
# pylint: disable=wildcard-import
|
||||
# pylint: disable=unused-import
|
||||
from functools import *
|
||||
from itertools import *
|
||||
from statistics import *
|
||||
import numpy as np
|
||||
from .evaler_utils import *
|
||||
|
||||
class Evaler:
|
||||
class EvalReturn:
|
||||
def __init__(self, result=None, is_valid=False, exception=None):
|
||||
self.result, self.exception, self.is_valid = \
|
||||
result, exception, is_valid
|
||||
def reset(self):
|
||||
self.result, self.exception, self.is_valid = \
|
||||
None, None, False
|
||||
|
||||
class PostableIterator:
|
||||
def __init__(self, eval_wait):
|
||||
self.eval_wait = eval_wait
|
||||
self.post_wait = threading.Event()
|
||||
self.event_vars, self.ended = None, None # define attributes in init
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.event_vars, self.ended = None, False
|
||||
self.post_wait.clear()
|
||||
|
||||
def abort(self):
|
||||
self.ended = True
|
||||
self.post_wait.set()
|
||||
|
||||
def post(self, event_vars:EventVars=None, ended=False):
|
||||
self.event_vars, self.ended = event_vars, ended
|
||||
self.post_wait.set()
|
||||
|
||||
def get_vals(self):
|
||||
while True:
|
||||
self.post_wait.wait()
|
||||
self.post_wait.clear()
|
||||
if self.ended:
|
||||
break
|
||||
else:
|
||||
yield self.event_vars
|
||||
# below will cause result=None, is_valid=False when
|
||||
# expression has reduce
|
||||
self.eval_wait.set()
|
||||
|
||||
def __init__(self, expr):
|
||||
self.eval_wait = threading.Event()
|
||||
self.reset_wait = threading.Event()
|
||||
self.g = Evaler.PostableIterator(self.eval_wait)
|
||||
self.expr = expr
|
||||
self.eval_return, self.continue_thread = None, None # define in __init__
|
||||
self.reset()
|
||||
|
||||
self.th = threading.Thread(target=self._runner, daemon=True, name='evaler')
|
||||
self.th.start()
|
||||
self.running = True
|
||||
|
||||
def reset(self):
|
||||
self.g.reset()
|
||||
self.eval_wait.clear()
|
||||
self.reset_wait.clear()
|
||||
self.eval_return = Evaler.EvalReturn()
|
||||
self.continue_thread = True
|
||||
|
||||
def _runner(self):
|
||||
while True:
|
||||
# this var will be used by eval
|
||||
l = self.g.get_vals() # pylint: disable=unused-variable
|
||||
try:
|
||||
result = eval(self.expr) # pylint: disable=eval-used
|
||||
if isinstance(result, Iterator):
|
||||
for item in result:
|
||||
self.eval_return = Evaler.EvalReturn(item, True)
|
||||
else:
|
||||
self.eval_return = Evaler.EvalReturn(result, True)
|
||||
except Exception as ex: # pylint: disable=broad-except
|
||||
print(ex, file=sys.stderr)
|
||||
self.eval_return = Evaler.EvalReturn(None, True, ex)
|
||||
self.eval_wait.set()
|
||||
self.reset_wait.wait()
|
||||
if not self.continue_thread:
|
||||
break
|
||||
self.reset()
|
||||
self.running = False
|
||||
utils.debug_log('eval runner ended!')
|
||||
|
||||
def abort(self):
|
||||
utils.debug_log('Evaler Aborted')
|
||||
self.continue_thread = False
|
||||
self.g.abort()
|
||||
self.eval_wait.set()
|
||||
self.reset_wait.set()
|
||||
|
||||
def post(self, event_vars:EventVars=None, ended=False, continue_thread=True):
|
||||
if not self.running:
|
||||
utils.debug_log('post was called when Evaler is not running')
|
||||
return None, False
|
||||
self.eval_return.reset()
|
||||
self.g.post(event_vars, ended)
|
||||
self.eval_wait.wait()
|
||||
self.eval_wait.clear()
|
||||
# save result before it would get reset
|
||||
eval_return = self.eval_return
|
||||
self.reset_wait.set()
|
||||
self.continue_thread = continue_thread
|
||||
if isinstance(eval_return.result, Iterator):
|
||||
eval_return.result = list(eval_return.result)
|
||||
return eval_return
|
||||
|
||||
def join(self):
|
||||
self.th.join()
|
|
@ -0,0 +1,109 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import torch
|
||||
import math
|
||||
import random
|
||||
from . import utils
|
||||
from .lv_types import ImagePlotItem
|
||||
from collections import OrderedDict
|
||||
from itertools import groupby, islice
|
||||
|
||||
def skip_mod(mod, g):
|
||||
for index, item in enumerate(g):
|
||||
if index % mod == 0:
|
||||
yield item
|
||||
|
||||
# sort keys, group by key, apply val function to each value in group, aggregate values
|
||||
def groupby2(l, key=lambda x:x, val=lambda x:x, agg=lambda x:x, sort=True):
|
||||
if sort:
|
||||
l = sorted(l, key=key)
|
||||
grp = ((k,v) for k,v in groupby(l, key=key))
|
||||
valx = ((k, (val(x) for x in v)) for k,v in grp)
|
||||
aggx = ((k, agg(v)) for k,v in valx)
|
||||
return aggx
|
||||
|
||||
# aggregate weights or biases, use p2v to transform tensor to scaler
|
||||
def agg_params(model, p2v, weight_or_bias=True):
|
||||
for i, (n, p) in enumerate(model.named_parameters()):
|
||||
if p.requires_grad:
|
||||
is_bias = 'bias' in n
|
||||
if (weight_or_bias and not is_bias) or (not weight_or_bias and is_bias):
|
||||
yield i, p2v(p), n
|
||||
|
||||
# use this for image to class problems
|
||||
def pyt_img_class_out_xform(item): # (net_input, target, in_weight, out_weight, net_output, loss)
|
||||
net_input = item[0].data.cpu().numpy()
|
||||
# turn log-probabilities in to (max log-probability, class ID)
|
||||
net_output = torch.max(item[4],0)
|
||||
# return image, text
|
||||
return ImagePlotItem((net_input,), title="T:{},Pb:{:.2f},pd:{:.2f},L:{:.2f}".\
|
||||
format(item[1], math.exp(net_output[0]), net_output[1], item[5]))
|
||||
|
||||
# use this for image to image translation problems
|
||||
def pyt_img_img_out_xform(item): # (net_input, target, in_weight, out_weight, net_output, loss)
|
||||
net_input = item[0].data.cpu().numpy()
|
||||
net_output = item[4].data.cpu().numpy()
|
||||
target = item[1].data.cpu().numpy()
|
||||
tar_weight = item[3].data.cpu().numpy() if item[3] is not None else None
|
||||
|
||||
# return in-image, text, out-image, target-image
|
||||
return ImagePlotItem((net_input, target, net_output, tar_weight),
|
||||
title="L:{:.2f}, S:{:.2f}, {:.2f}-{:.2f}, {:.2f}-{:.2f}".\
|
||||
format(item[5], net_input.std(), net_input.min(), net_input.max(), net_output.min(), net_output.max()))
|
||||
|
||||
def cols2rows(batch):
|
||||
in_weight = utils.fill_like(batch.in_weight, batch.input)
|
||||
tar_weight = utils.fill_like(batch.tar_weight, batch.input)
|
||||
losses = [l.mean() for l in batch.loss_all]
|
||||
targets = [t.item() if len(t.shape)==0 else t for t in batch.target]
|
||||
|
||||
return list(zip(batch.input, targets, in_weight, tar_weight,
|
||||
batch.output, losses))
|
||||
|
||||
def top(l, topk=1, order='dsc', group_key=None, out_xform=lambda x:x):
|
||||
min_result = OrderedDict()
|
||||
for event_vars in l:
|
||||
batch = cols2rows(event_vars.batch)
|
||||
# by default group items in batch by target value
|
||||
group_key = group_key or (lambda b: b[1]) #target
|
||||
by_class = groupby2(batch, group_key)
|
||||
|
||||
# pick the first values for each class after sorting by loss
|
||||
reverse, sf, ls_cmp = True, lambda b: b[5], False
|
||||
if order=='asc':
|
||||
reverse = False
|
||||
elif order=='rnd':
|
||||
ls_cmp, sf = True, lambda t: random.random()
|
||||
elif order=='dsc':
|
||||
pass
|
||||
else:
|
||||
raise ValueError('order parameter must be dsc, asc or rnd')
|
||||
|
||||
# sort grouped objects by sort function then
|
||||
# take first k values in each group
|
||||
# create (key, topk-sized list) tuples for each group
|
||||
s = ((k, list(islice(sorted(v, key=sf, reverse=reverse), topk))) \
|
||||
for k,v in by_class)
|
||||
|
||||
# for each group, maintain global k values for each keys
|
||||
changed = False
|
||||
for k,va in s:
|
||||
# get global k values for this key, if it doesn't exist
|
||||
# then put current in global min
|
||||
cur_min = min_result.get(k, None)
|
||||
if cur_min is None:
|
||||
min_result[k] = va
|
||||
changed = True
|
||||
else:
|
||||
# for each k value in this group, we will compare
|
||||
for i, (va_k, cur_k) in enumerate(zip(va, cur_min)):
|
||||
if ls_cmp or (reverse and cur_k[5] < va_k[5]) \
|
||||
or (not reverse and cur_k[5] > va_k[5]):
|
||||
cur_min[i] = va[i]
|
||||
changed = True
|
||||
if changed:
|
||||
# flatten each list in dictionary value
|
||||
yield (out_xform(t) for va in min_result.values() for t in va)
|
||||
|
||||
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .stream import Stream
|
||||
import pickle
|
||||
from typing import Any
|
||||
from . import utils
|
||||
|
||||
class FileStream(Stream):
|
||||
def __init__(self, for_write:bool, file_name:str, stream_name:str=None, console_debug:bool=False):
|
||||
super(FileStream, self).__init__(stream_name=stream_name or file_name, console_debug=console_debug)
|
||||
|
||||
self._file = open(file_name, 'wb' if for_write else 'rb')
|
||||
self.file_name = file_name
|
||||
self.for_write = for_write
|
||||
utils.debug_log('FileStream started', self.file_name, verbosity=1)
|
||||
|
||||
def close(self):
|
||||
if not self._file.closed:
|
||||
self._file.close()
|
||||
self._file = None
|
||||
utils.debug_log('FileStream is closed', self.file_name, verbosity=1)
|
||||
super(FileStream, self).close()
|
||||
|
||||
def write(self, val:Any, from_stream:'Stream'=None):
|
||||
if self.for_write:
|
||||
pickle.dump(val, self._file)
|
||||
super(FileStream, self).write(val)
|
||||
|
||||
def load(self, from_stream:'Stream'=None):
|
||||
if self.for_write:
|
||||
raise IOError('Cannot use load() call because FileSteam is opened with for_write=True')
|
||||
if self._file is not None:
|
||||
while not utils.is_eof(self._file):
|
||||
stream_item = pickle.load(self._file)
|
||||
self.write(stream_item)
|
||||
super(FileStream, self).load()
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .stream import Stream
|
||||
from typing import Callable, Any
|
||||
|
||||
class FilteredStream(Stream):
|
||||
def __init__(self, source_stream:Stream, filter_expr:Callable, stream_name:str=None,
|
||||
console_debug:bool=False)->None:
|
||||
|
||||
stream_name = stream_name or '{}|{}'.format(source_stream.stream_name, str(filter_expr))
|
||||
super(FilteredStream, self).__init__(stream_name=stream_name, console_debug=console_debug)
|
||||
self.subscribe(source_stream)
|
||||
self.filter_expr = filter_expr
|
||||
|
||||
def write(self, val:Any, from_stream:'Stream'=None):
|
||||
result, is_valid = self.filter_expr(val) \
|
||||
if self.filter_expr is not None \
|
||||
else (val, True)
|
||||
|
||||
if is_valid:
|
||||
return super(FilteredStream, self).write(result)
|
||||
# else ignore this call
|
||||
|
|
@ -0,0 +1,124 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
import numpy as np
|
||||
import math
|
||||
import time
|
||||
|
||||
def guess_image_dims(img):
|
||||
if len(img.shape) == 1:
|
||||
# assume 2D monochrome (MNIST)
|
||||
width = height = round(math.sqrt(img.shape[0]))
|
||||
if width*height != img.shape[0]:
|
||||
# assume 3 channels (CFAR, ImageNet)
|
||||
width = height = round(math.sqrt(img.shape[0] / 3))
|
||||
if width*height*3 != img.shape[0]:
|
||||
raise ValueError("Cannot guess image dimensions for linearized pixels")
|
||||
return (3, height, width)
|
||||
return (1, height, width)
|
||||
return img.shape
|
||||
|
||||
def to_imshow_array(img, width=None, height=None):
|
||||
# array from Pytorch has shape: [[channels,] height, width]
|
||||
# image needed for imshow needs: [height, width, channels]
|
||||
|
||||
if img is not None:
|
||||
if isinstance(img, Image.Image):
|
||||
img = np.array(img)
|
||||
if len(img.shape) >= 2:
|
||||
return img # img is already compatible to imshow
|
||||
|
||||
# force max 3 dimensions
|
||||
if len(img.shape) > 3:
|
||||
# TODO allow config
|
||||
# select first one in batch
|
||||
img = img[0:1,:,:]
|
||||
|
||||
if len(img.shape) == 1: # linearized pixels typically used for MLPs
|
||||
if not(width and height):
|
||||
# pylint: disable=unused-variable
|
||||
channels, height, width = guess_image_dims(img)
|
||||
img = img.reshape((-1, height, width))
|
||||
|
||||
if len(img.shape) == 3:
|
||||
if img.shape[0] == 1: # single channel images
|
||||
img = img.squeeze(0)
|
||||
else:
|
||||
img = np.swapaxes(img, 0, 2) # transpose H,W for imshow
|
||||
img = np.swapaxes(img, 0, 1)
|
||||
elif len(img.shape) == 2:
|
||||
img = np.swapaxes(img, 0, 1) # transpose H,W for imshow
|
||||
else: #zero dimensions
|
||||
img = None
|
||||
|
||||
return img
|
||||
|
||||
#width_dim=1 for imshow, 2 for pytorch arrays
|
||||
def stitch_horizontal(images, width_dim=1):
|
||||
return np.concatenate(images, axis=width_dim)
|
||||
|
||||
def _resize_image(img, size=None):
|
||||
if size is not None or (hasattr(img, 'shape') and len(img.shape) == 1):
|
||||
if size is None:
|
||||
# make guess for 1-dim tensors
|
||||
h = int(math.sqrt(img.shape[0]))
|
||||
w = int(img.shape[0] / h)
|
||||
size = h,w
|
||||
img = np.reshape(img, size)
|
||||
return img
|
||||
|
||||
def show_image(img, size=None, alpha=None, cmap=None,
|
||||
img2=None, size2=None, alpha2=None, cmap2=None, ax=None):
|
||||
img =_resize_image(img, size)
|
||||
img2 =_resize_image(img2, size2)
|
||||
|
||||
(ax or plt).imshow(img, alpha=alpha, cmap=cmap)
|
||||
|
||||
if img2 is not None:
|
||||
(ax or plt).imshow(img2, alpha=alpha2, cmap=cmap2)
|
||||
|
||||
return ax or plt.show()
|
||||
|
||||
# convert_mode param is mode: https://pillow.readthedocs.io/en/5.1.x/handbook/concepts.html#modes
|
||||
# use convert_mode='RGB' to force 3 channels
|
||||
def open_image(path, resize=None, resample=Image.ANTIALIAS, convert_mode=None):
|
||||
img = Image.open(path)
|
||||
if resize is not None:
|
||||
img = img.resize(resize, resample)
|
||||
if convert_mode is not None:
|
||||
img = img.convert(convert_mode)
|
||||
return img
|
||||
|
||||
def img2pyt(img, add_batch_dim=True, resize=None):
|
||||
ts = []
|
||||
if resize is not None:
|
||||
ts.append(transforms.RandomResizedCrop(resize))
|
||||
ts.append(transforms.ToTensor())
|
||||
img_pyt = transforms.Compose(ts)(img)
|
||||
if add_batch_dim:
|
||||
img_pyt.unsqueeze_(0)
|
||||
return img_pyt
|
||||
|
||||
def linear_to_2d(img, size=None):
|
||||
if size is not None or (hasattr(img, 'shape') and len(img.shape) == 1):
|
||||
if size is None:
|
||||
# make guess for 1-dim tensors
|
||||
h = int(math.sqrt(img.shape[0]))
|
||||
w = int(img.shape[0] / h)
|
||||
size = h,w
|
||||
img = np.reshape(img, size)
|
||||
return img
|
||||
|
||||
def stack_images(imgs):
|
||||
return np.hstack(imgs)
|
||||
|
||||
def plt_loop(sleep_time=1, plt_pause=0.01):
|
||||
plt.ion()
|
||||
plt.show(block=False)
|
||||
while(not plt.waitforbuttonpress(plt_pause)):
|
||||
#plt.draw()
|
||||
plt.pause(plt_pause)
|
||||
time.sleep(sleep_time)
|
|
@ -0,0 +1,64 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from torchvision import transforms
|
||||
from . import pytorch_utils
|
||||
import json
|
||||
|
||||
def get_image_transform():
|
||||
transf = transforms.Compose([ #TODO: cache these transforms?
|
||||
get_resize_transform(),
|
||||
transforms.ToTensor(),
|
||||
get_normalize_transform()
|
||||
])
|
||||
|
||||
return transf
|
||||
|
||||
def get_resize_transform():
|
||||
return transforms.Resize((224, 224))
|
||||
|
||||
def get_normalize_transform():
|
||||
return transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225])
|
||||
|
||||
def predict(model, images, image_transform=None, device=None):
|
||||
logits = pytorch_utils.batch_predict(model, images,
|
||||
input_transform=image_transform or get_image_transform(), device=device)
|
||||
probs = pytorch_utils.logits2probabilities(logits) #2-dim array, one column per class, one row per input
|
||||
return probs
|
||||
|
||||
_imagenet_labels = None
|
||||
def get_imagenet_labels():
|
||||
# pylint: disable=global-statement
|
||||
global _imagenet_labels
|
||||
_imagenet_labels = _imagenet_labels or ImagenetLabels()
|
||||
return _imagenet_labels
|
||||
|
||||
def probabilities2classes(probs, topk=5):
|
||||
labels = get_imagenet_labels()
|
||||
top_probs = probs.topk(topk)
|
||||
# return (probability, class_id, class_label, class_code)
|
||||
return tuple((p,c, labels.index2label_text(c), labels.index2label_code(c)) \
|
||||
for p, c in zip(top_probs[0][0].detach().numpy(), top_probs[1][0].detach().numpy()))
|
||||
|
||||
class ImagenetLabels:
|
||||
def __init__(self, json_path='../../data/imagenet_class_index.json'):
|
||||
self._idx2label = []
|
||||
self._idx2cls = []
|
||||
self._cls2label = {}
|
||||
self._cls2idx = {}
|
||||
with open(json_path, "r") as read_file:
|
||||
class_json = json.load(read_file)
|
||||
self._idx2label = [class_json[str(k)][1] for k in range(len(class_json))]
|
||||
self._idx2cls = [class_json[str(k)][0] for k in range(len(class_json))]
|
||||
self._cls2label = {class_json[str(k)][0]: class_json[str(k)][1] for k in range(len(class_json))}
|
||||
self._cls2idx = {class_json[str(k)][0]: k for k in range(len(class_json))}
|
||||
|
||||
def index2label_text(self, index):
|
||||
return self._idx2label[index]
|
||||
def index2label_code(self, index):
|
||||
return self._idx2cls[index]
|
||||
def label_code2label_text(self, label_code):
|
||||
return self._cls2label[label_code]
|
||||
def label_code2index(self, label_code):
|
||||
return self._cls2idx[label_code]
|
|
@ -0,0 +1,130 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import List, Callable, Any, Sequence
|
||||
from . import utils
|
||||
import uuid
|
||||
|
||||
|
||||
class EventVars:
|
||||
def __init__(self, globals_val, **vars_val):
|
||||
if globals_val is not None:
|
||||
for key in globals_val:
|
||||
setattr(self, key, globals_val[key])
|
||||
for key in vars_val:
|
||||
setattr(self, key, vars_val[key])
|
||||
|
||||
def __str__(self):
|
||||
sb = []
|
||||
for key in self.__dict__:
|
||||
val = self.__dict__[key]
|
||||
if utils.is_scalar(val):
|
||||
sb.append('{key}={value}'.format(key=key, value=val))
|
||||
else:
|
||||
sb.append('{key}="{value}"'.format(key=key, value=val))
|
||||
|
||||
return ', '.join(sb)
|
||||
|
||||
EventsVars = List[EventVars]
|
||||
|
||||
class StreamItem:
|
||||
def __init__(self, item_index:int, value:Any,
|
||||
stream_name:str, creator_id:str, stream_index:int,
|
||||
ended:bool=False, exception:Exception=None, stream_reset:bool=False):
|
||||
self.value = value
|
||||
self.exception = exception
|
||||
self.stream_name = stream_name
|
||||
self.item_index = item_index
|
||||
self.ended = ended
|
||||
self.creator_id = creator_id
|
||||
self.stream_index = stream_index
|
||||
self.stream_reset = stream_reset
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.__dict__)
|
||||
|
||||
EventEvalFunc = Callable[[EventsVars], StreamItem]
|
||||
|
||||
|
||||
class VisParams:
|
||||
def __init__(self, vis_type=None, host_vis=None,
|
||||
cell=None, title=None,
|
||||
clear_after_end=False, clear_after_each=False, history_len=1, dim_history=True, opacity=None,
|
||||
images=None, images_reshape=None, width=None, height=None, vis_args=None, stream_vis_args=None)->None:
|
||||
self.vis_type=vis_type
|
||||
self.host_vis=host_vis,
|
||||
self.cell=cell
|
||||
self.title=title
|
||||
self.clear_after_end=clear_after_end
|
||||
self.clear_after_each=clear_after_each
|
||||
self.history_len=history_len
|
||||
self.dim_history=dim_history
|
||||
self.opacity=opacity
|
||||
self.images=images
|
||||
self.images_reshape=images_reshape
|
||||
self.width=width
|
||||
self.height=height
|
||||
self.vis_args=vis_args or {}
|
||||
self.stream_vis_args=stream_vis_args or {}
|
||||
|
||||
class StreamOpenRequest:
|
||||
def __init__(self, stream_name:str, devices:Sequence[str]=None,
|
||||
event_name:str='')->None:
|
||||
self.stream_name = stream_name or str(uuid.uuid4())
|
||||
self.devices = devices
|
||||
self.event_name = event_name
|
||||
|
||||
|
||||
class StreamCreateRequest:
|
||||
def __init__(self, stream_name:str, devices:Sequence[str]=None, event_name:str='',
|
||||
expr:str=None, throttle:float=None, vis_params:VisParams=None):
|
||||
self.event_name = event_name
|
||||
self.expr = expr
|
||||
self.stream_name = stream_name or str(uuid.uuid4())
|
||||
self.devices = devices
|
||||
self.vis_params = vis_params
|
||||
|
||||
# max throughput n Lenovo P50 laptop for MNIST
|
||||
# text console -> 0.1s
|
||||
# matplotlib line graph -> 0.5s
|
||||
self.throttle = throttle
|
||||
|
||||
class ClientServerRequest:
|
||||
def __init__(self, req_type:str, req_data:Any):
|
||||
self.req_type = req_type
|
||||
self.req_data = req_data
|
||||
|
||||
class CliSrvReqTypes:
|
||||
create_stream = 'CreateStream'
|
||||
del_stream = 'DeleteStream'
|
||||
|
||||
class StreamPlot:
|
||||
def __init__(self, stream, title, clear_after_end,
|
||||
clear_after_each, history_len, dim_history, opacity,
|
||||
index, stream_vis_args, last_update):
|
||||
self.stream = stream
|
||||
self.title, self.opacity = title, opacity
|
||||
self.clear_after_end, self.clear_after_each = clear_after_end, clear_after_each
|
||||
self.history_len, self.dim_history = history_len, dim_history
|
||||
self.index, self.stream_vis_args, self.last_update = index, stream_vis_args, last_update
|
||||
|
||||
class ImagePlotItem:
|
||||
# images are numpy array of shape [[channels,] height, width]
|
||||
def __init__(self, images=None, title=None, alpha=None, cmap=None):
|
||||
if not isinstance(images, tuple):
|
||||
images = (images,)
|
||||
self.images, self.alpha, self.cmap, self.title = images, alpha, cmap, title
|
||||
|
||||
class DefaultPorts:
|
||||
PubSub = 40859
|
||||
CliSrv = 41459
|
||||
|
||||
class PublisherTopics:
|
||||
StreamItem = 'StreamItem'
|
||||
ServerMgmt = 'ServerMgmt'
|
||||
|
||||
class ServerMgmtMsg:
|
||||
EventServerStart = 'ServerStart'
|
||||
def __init__(self, event_name:str, event_args:Any=None):
|
||||
self.event_name = event_name
|
||||
self.event_args = event_args
|
|
@ -0,0 +1,3 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
# Credits
|
||||
|
||||
Code in this folder is adopted from,
|
||||
|
||||
* https://github.com/waleedka/hiddenlayer
|
|
@ -0,0 +1 @@
|
|||
|
|
@ -0,0 +1,169 @@
|
|||
"""
|
||||
HiddenLayer
|
||||
|
||||
Implementation graph expressions to find nodes in a graph based on a pattern.
|
||||
|
||||
Written by Waleed Abdulla
|
||||
Licensed under the MIT License
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
|
||||
|
||||
class GEParser():
|
||||
def __init__(self, text):
|
||||
self.index = 0
|
||||
self.text = text
|
||||
|
||||
def parse(self):
|
||||
return self.serial() or self.parallel() or self.expression()
|
||||
|
||||
def parallel(self):
|
||||
index = self.index
|
||||
expressions = []
|
||||
while len(expressions) == 0 or self.token("|"):
|
||||
e = self.expression()
|
||||
if not e:
|
||||
break
|
||||
expressions.append(e)
|
||||
if len(expressions) >= 2:
|
||||
return ParallelPattern(expressions)
|
||||
# No match. Reset index
|
||||
self.index = index
|
||||
|
||||
def serial(self):
|
||||
index = self.index
|
||||
expressions = []
|
||||
while len(expressions) == 0 or self.token(">"):
|
||||
e = self.expression()
|
||||
if not e:
|
||||
break
|
||||
expressions.append(e)
|
||||
|
||||
if len(expressions) >= 2:
|
||||
return SerialPattern(expressions)
|
||||
self.index = index
|
||||
|
||||
def expression(self):
|
||||
index = self.index
|
||||
|
||||
if self.token("("):
|
||||
e = self.serial() or self.parallel() or self.op()
|
||||
if e and self.token(")"):
|
||||
return e
|
||||
self.index = index
|
||||
e = self.op()
|
||||
return e
|
||||
|
||||
def op(self):
|
||||
t = self.re(r"\w+")
|
||||
if t:
|
||||
c = self.condition()
|
||||
return NodePattern(t, c)
|
||||
|
||||
def condition(self):
|
||||
# TODO: not implemented yet. This function is a placeholder
|
||||
index = self.index
|
||||
if self.token("["):
|
||||
c = self.token("1x1") or self.token("3x3")
|
||||
if c:
|
||||
if self.token("]"):
|
||||
return c
|
||||
self.index = index
|
||||
|
||||
def token(self, s):
|
||||
return self.re(r"\s*(" + re.escape(s) + r")\s*", 1)
|
||||
|
||||
def string(self, s):
|
||||
if s == self.text[self.index:self.index+len(s)]:
|
||||
self.index += len(s)
|
||||
return s
|
||||
|
||||
def re(self, regex, group=0):
|
||||
m = re.match(regex, self.text[self.index:])
|
||||
if m:
|
||||
self.index += len(m.group(0))
|
||||
return m.group(group)
|
||||
|
||||
|
||||
class NodePattern():
|
||||
def __init__(self, op, condition=None):
|
||||
self.op = op
|
||||
self.condition = condition # TODO: not implemented yet
|
||||
|
||||
def match(self, graph, node):
|
||||
if isinstance(node, list):
|
||||
return [], None
|
||||
if self.op == node.op:
|
||||
following = graph.outgoing(node)
|
||||
if len(following) == 1:
|
||||
following = following[0]
|
||||
return [node], following
|
||||
else:
|
||||
return [], None
|
||||
|
||||
|
||||
class SerialPattern():
|
||||
def __init__(self, patterns):
|
||||
self.patterns = patterns
|
||||
|
||||
def match(self, graph, node):
|
||||
all_matches = []
|
||||
for i, p in enumerate(self.patterns):
|
||||
matches, following = p.match(graph, node)
|
||||
if not matches:
|
||||
return [], None
|
||||
all_matches.extend(matches)
|
||||
if i < len(self.patterns) - 1:
|
||||
node = following # Might be more than one node
|
||||
return all_matches, following
|
||||
|
||||
|
||||
class ParallelPattern():
|
||||
def __init__(self, patterns):
|
||||
self.patterns = patterns
|
||||
|
||||
def match(self, graph, nodes):
|
||||
if not nodes:
|
||||
return [], None
|
||||
nodes = nodes if isinstance(nodes, list) else [nodes]
|
||||
# If a single node, assume we need to match with its siblings
|
||||
if len(nodes) == 1:
|
||||
nodes = graph.siblings(nodes[0])
|
||||
else:
|
||||
# Verify all nodes have the same parent or all have no parent
|
||||
parents = [graph.incoming(n) for n in nodes]
|
||||
matches = [set(p) == set(parents[0]) for p in parents[1:]]
|
||||
if not all(matches):
|
||||
return [], None
|
||||
|
||||
# TODO: If more nodes than patterns, we should consider
|
||||
# all permutations of the nodes
|
||||
if len(self.patterns) != len(nodes):
|
||||
return [], None
|
||||
|
||||
patterns = self.patterns.copy()
|
||||
nodes = nodes.copy()
|
||||
all_matches = []
|
||||
end_node = None
|
||||
for p in patterns:
|
||||
found = False
|
||||
for n in nodes:
|
||||
matches, following = p.match(graph, n)
|
||||
if matches:
|
||||
found = True
|
||||
nodes.remove(n)
|
||||
all_matches.extend(matches)
|
||||
# Verify all branches end in the same node
|
||||
if end_node:
|
||||
if end_node != following:
|
||||
return [], None
|
||||
else:
|
||||
end_node = following
|
||||
break
|
||||
if not found:
|
||||
return [], None
|
||||
return all_matches, end_node
|
||||
|
||||
|
|
@ -0,0 +1,402 @@
|
|||
"""
|
||||
HiddenLayer
|
||||
|
||||
Implementation of the Graph class. A framework independent directed graph to
|
||||
represent a neural network.
|
||||
|
||||
Written by Waleed Abdulla. Additions by Phil Ferriere.
|
||||
Licensed under the MIT License
|
||||
"""
|
||||
from __future__ import absolute_import, division, print_function
|
||||
import os
|
||||
import re
|
||||
from random import getrandbits
|
||||
import inspect
|
||||
import numpy as np
|
||||
import html
|
||||
|
||||
THEMES = {
|
||||
"basic": {
|
||||
"background_color": "#FFFFFF",
|
||||
"fill_color": "#E8E8E8",
|
||||
"outline_color": "#000000",
|
||||
"font_color": "#000000",
|
||||
"font_name": "Times",
|
||||
"font_size": "10",
|
||||
"margin": "0,0",
|
||||
"padding": "1.0,0.5",
|
||||
},
|
||||
"blue": {
|
||||
"shape": "box",
|
||||
"background_color": "#FFFFFF",
|
||||
"fill_color": "#BCD6FC",
|
||||
"outline_color": "#7C96BC",
|
||||
"font_color": "#202020",
|
||||
"font_name": "Verdana",
|
||||
"font_size": "10",
|
||||
"margin": "0,0",
|
||||
"padding": "1.0",
|
||||
"layer_overrides": { #TODO: change names of these keys to same as dot params
|
||||
"ConvRelu": {"shape":"box", "fillcolor":"#A1C9F4"},
|
||||
"Conv": {"shape":"box", "fillcolor":"#FAB0E4"},
|
||||
"MaxPool": {"shape":"box", "fillcolor":"#8DE5A1"},
|
||||
"Constant": {"shape":"box", "fillcolor":"#FF9F9B"},
|
||||
"Shape": {"shape":"box", "fillcolor":"#D0BBFF"},
|
||||
"Gather": {"shape":"box", "fillcolor":"#DEBB9B"},
|
||||
"Unsqeeze": {"shape":"box", "fillcolor":"#CFCFCF"},
|
||||
"Sqeeze": {"shape":"box", "fillcolor":"#FFFEA3"},
|
||||
"Dropout": {"shape":"box", "fillcolor":"#B9F2F0"},
|
||||
"LinearRelu": {"shape":"box", "fillcolor":"#8DE5A1"},
|
||||
"Linear": {"shape":"box", "fillcolor":"#4878D0"},
|
||||
"Concat": {"shape":"box", "fillcolor":"#D0BBFF"},
|
||||
"Reshape": {"shape":"box", "fillcolor":"#FFFEA3"},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
###########################################################################
|
||||
# Utility Functions
|
||||
###########################################################################
|
||||
|
||||
def detect_framework(value):
|
||||
# Get all base classes
|
||||
classes = inspect.getmro(value.__class__)
|
||||
for c in classes:
|
||||
if c.__module__.startswith("torch"):
|
||||
return "torch"
|
||||
elif c.__module__.startswith("tensorflow"):
|
||||
return "tensorflow"
|
||||
|
||||
|
||||
###########################################################################
|
||||
# Node
|
||||
###########################################################################
|
||||
|
||||
class Node():
|
||||
"""Represents a framework-agnostic neural network layer in a directed graph."""
|
||||
|
||||
def __init__(self, uid, name, op, output_shape=None, params=None, combo_params=None):
|
||||
"""
|
||||
uid: unique ID for the layer that doesn't repeat in the computation graph.
|
||||
name: Name to display
|
||||
op: Framework-agnostic operation name.
|
||||
"""
|
||||
self.id = uid
|
||||
self.name = name # TODO: clarify the use of op vs name vs title
|
||||
self.op = op
|
||||
self.repeat = 1
|
||||
if output_shape:
|
||||
assert isinstance(output_shape, (tuple, list)),\
|
||||
"output_shape must be a tuple or list but received {}".format(type(output_shape))
|
||||
self.output_shape = output_shape
|
||||
self.params = params if params else {}
|
||||
self._caption = ""
|
||||
self.combo_params = combo_params
|
||||
|
||||
@property
|
||||
def title(self):
|
||||
# Default
|
||||
title = self.name or self.op
|
||||
|
||||
if self.op == 'Dropout':
|
||||
if 'ratio' in self.params:
|
||||
title += ' ' + str(self.params['ratio'])
|
||||
|
||||
if "kernel_shape" in self.params:
|
||||
# Kernel
|
||||
kernel = self.params["kernel_shape"]
|
||||
title += "x".join(map(str, kernel))
|
||||
if "stride" in self.params:
|
||||
stride = self.params["stride"]
|
||||
if np.unique(stride).size == 1:
|
||||
stride = stride[0]
|
||||
if stride != 1:
|
||||
title += "/s{}".format(str(stride))
|
||||
# # Transposed
|
||||
# if node.transposed:
|
||||
# name = "Transposed" + name
|
||||
return title
|
||||
|
||||
@property
|
||||
def caption(self):
|
||||
if self._caption:
|
||||
return self._caption
|
||||
|
||||
caption = ""
|
||||
|
||||
# Stride
|
||||
# if "stride" in self.params:
|
||||
# stride = self.params["stride"]
|
||||
# if np.unique(stride).size == 1:
|
||||
# stride = stride[0]
|
||||
# if stride != 1:
|
||||
# caption += "/{}".format(str(stride))
|
||||
return caption
|
||||
|
||||
def __repr__(self):
|
||||
args = (self.op, self.name, self.id, self.title, self.repeat)
|
||||
f = "<Node: op: {}, name: {}, id: {}, title: {}, repeat: {}"
|
||||
if self.output_shape:
|
||||
args += (str(self.output_shape),)
|
||||
f += ", shape: {:}"
|
||||
if self.params:
|
||||
args += (str(self.params),)
|
||||
f += ", params: {:}"
|
||||
f += ">"
|
||||
return f.format(*args)
|
||||
|
||||
|
||||
###########################################################################
|
||||
# Graph
|
||||
###########################################################################
|
||||
|
||||
def build_graph(model=None, args=None, input_names=None,
|
||||
transforms="default", framework_transforms="default", orientation='TB'):
|
||||
# Initialize an empty graph
|
||||
g = Graph(orientation=orientation)
|
||||
|
||||
# Detect framwork
|
||||
framework = detect_framework(model)
|
||||
if framework == "torch":
|
||||
from .pytorch_builder import import_graph, FRAMEWORK_TRANSFORMS
|
||||
import_graph(g, model, args)
|
||||
elif framework == "tensorflow":
|
||||
from .tf_builder import import_graph, FRAMEWORK_TRANSFORMS
|
||||
import_graph(g, model)
|
||||
else:
|
||||
raise ValueError("`model` input param must be a PyTorch, TensorFlow, or Keras-with-TensorFlow-backend model.")
|
||||
|
||||
# Apply Transforms
|
||||
if framework_transforms:
|
||||
if framework_transforms == "default":
|
||||
framework_transforms = FRAMEWORK_TRANSFORMS
|
||||
for t in framework_transforms:
|
||||
g = t.apply(g)
|
||||
if transforms:
|
||||
if transforms == "default":
|
||||
from .transforms import SIMPLICITY_TRANSFORMS
|
||||
transforms = SIMPLICITY_TRANSFORMS
|
||||
for t in transforms:
|
||||
g = t.apply(g)
|
||||
return g
|
||||
|
||||
|
||||
class Graph():
|
||||
"""Tracks nodes and edges of a directed graph and supports basic operations on them."""
|
||||
|
||||
def __init__(self, model=None, args=None, input_names=None,
|
||||
transforms="default", framework_transforms="default",
|
||||
meaningful_ids=False, orientation='TB'):
|
||||
self.nodes = {}
|
||||
self.edges = []
|
||||
self.meaningful_ids = meaningful_ids # TODO
|
||||
self.theme = THEMES["blue"].copy()
|
||||
self.orientation = orientation
|
||||
|
||||
if model:
|
||||
# Detect framwork
|
||||
framework = detect_framework(model)
|
||||
if framework == "torch":
|
||||
from .pytorch_builder import import_graph, FRAMEWORK_TRANSFORMS
|
||||
import_graph(self, model, args)
|
||||
elif framework == "tensorflow":
|
||||
from .tf_builder import import_graph, FRAMEWORK_TRANSFORMS
|
||||
import_graph(self, model)
|
||||
|
||||
# Apply Transforms
|
||||
if framework_transforms:
|
||||
if framework_transforms == "default":
|
||||
framework_transforms = FRAMEWORK_TRANSFORMS
|
||||
for t in framework_transforms:
|
||||
t.apply(self)
|
||||
if transforms:
|
||||
if transforms == "default":
|
||||
from .transforms import SIMPLICITY_TRANSFORMS
|
||||
transforms = SIMPLICITY_TRANSFORMS
|
||||
for t in transforms:
|
||||
t.apply(self)
|
||||
|
||||
|
||||
def id(self, node):
|
||||
"""Returns a unique node identifier. If the node has an id
|
||||
attribute (preferred), it's used. Otherwise, the hash() is returned."""
|
||||
return node.id if hasattr(node, "id") else hash(node)
|
||||
|
||||
def add_node(self, node):
|
||||
id = self.id(node)
|
||||
# assert(id not in self.nodes)
|
||||
self.nodes[id] = node
|
||||
|
||||
def add_edge(self, node1, node2, label=None):
|
||||
# If the edge is already present, don't add it again.
|
||||
# TODO: If an edge exists with a different label, still don't add it again.
|
||||
edge = (self.id(node1), self.id(node2), label)
|
||||
if edge not in self.edges:
|
||||
self.edges.append(edge)
|
||||
|
||||
def add_edge_by_id(self, vid1, vid2, label=None):
|
||||
self.edges.append((vid1, vid2, label))
|
||||
|
||||
def outgoing(self, node):
|
||||
"""Returns nodes connecting out of the given node (or list of nodes)."""
|
||||
nodes = node if isinstance(node, list) else [node]
|
||||
node_ids = [self.id(n) for n in nodes]
|
||||
# Find edges outgoing from this group but not incoming to it
|
||||
outgoing = [self[e[1]] for e in self.edges
|
||||
if e[0] in node_ids and e[1] not in node_ids]
|
||||
return outgoing
|
||||
|
||||
def incoming(self, node):
|
||||
"""Returns nodes connecting to the given node (or list of nodes)."""
|
||||
nodes = node if isinstance(node, list) else [node]
|
||||
node_ids = [self.id(n) for n in nodes]
|
||||
# Find edges incoming to this group but not outgoing from it
|
||||
incoming = [self[e[0]] for e in self.edges
|
||||
if e[1] in node_ids and e[0] not in node_ids]
|
||||
return incoming
|
||||
|
||||
def siblings(self, node):
|
||||
"""Returns all nodes that share the same parent (incoming node) with
|
||||
the given node, including the node itself.
|
||||
"""
|
||||
incoming = self.incoming(node)
|
||||
# TODO: Not handling the case of multiple incoming nodes yet
|
||||
if len(incoming) == 1:
|
||||
incoming = incoming[0]
|
||||
siblings = self.outgoing(incoming)
|
||||
return siblings
|
||||
else:
|
||||
return [node]
|
||||
|
||||
def __getitem__(self, key):
|
||||
if isinstance(key, list):
|
||||
return [self.nodes.get(k) for k in key]
|
||||
else:
|
||||
return self.nodes.get(key)
|
||||
|
||||
def remove(self, nodes):
|
||||
"""Remove a node and its edges."""
|
||||
nodes = nodes if isinstance(nodes, list) else [nodes]
|
||||
for node in nodes:
|
||||
k = self.id(node)
|
||||
self.edges = list(filter(lambda e: e[0] != k and e[1] != k, self.edges))
|
||||
del self.nodes[k]
|
||||
|
||||
def replace(self, nodes, node):
|
||||
"""Replace nodes with node. Edges incoming to nodes[0] are connected to
|
||||
the new node, and nodes outgoing from nodes[-1] become outgoing from
|
||||
the new node."""
|
||||
nodes = nodes if isinstance(nodes, list) else [nodes]
|
||||
# Is the new node part of the replace nodes (i.e. want to collapse
|
||||
# a group of nodes into one of them)?
|
||||
collapse = self.id(node) in self.nodes
|
||||
# Add new node and edges
|
||||
if not collapse:
|
||||
self.add_node(node)
|
||||
for in_node in self.incoming(nodes):
|
||||
# TODO: check specifically for output_shape is not generic. Consider refactoring.
|
||||
self.add_edge(in_node, node, in_node.output_shape if hasattr(in_node, "output_shape") else None)
|
||||
for out_node in self.outgoing(nodes):
|
||||
self.add_edge(node, out_node, node.output_shape if hasattr(node, "output_shape") else None)
|
||||
# Remove the old nodes
|
||||
for n in nodes:
|
||||
if collapse and n == node:
|
||||
continue
|
||||
self.remove(n)
|
||||
|
||||
def search(self, pattern):
|
||||
"""Searches the graph for a sub-graph that matches the given pattern
|
||||
and returns the first match it finds.
|
||||
"""
|
||||
for node in self.nodes.values():
|
||||
match, following = pattern.match(self, node)
|
||||
if match:
|
||||
return match, following
|
||||
return [], None
|
||||
|
||||
|
||||
def sequence_id(self, sequence):
|
||||
"""Make up an ID for a sequence (list) of nodes.
|
||||
Note: `getrandbits()` is very uninformative as a "readable" ID. Here, we build a name
|
||||
such that when the mouse hovers over the drawn node in Jupyter, one can figure out
|
||||
which original nodes make up the sequence. This is actually quite useful.
|
||||
"""
|
||||
if self.meaningful_ids:
|
||||
# TODO: This might fail if the ID becomes too long
|
||||
return "><".join([node.id for node in sequence])
|
||||
else:
|
||||
return getrandbits(64)
|
||||
|
||||
def build_dot(self, orientation):
|
||||
"""Generate a GraphViz Dot graph.
|
||||
|
||||
Returns a GraphViz Digraph object.
|
||||
"""
|
||||
from graphviz import Digraph
|
||||
|
||||
# Build GraphViz Digraph
|
||||
dot = Digraph()
|
||||
dot.attr("graph",
|
||||
bgcolor=self.theme["background_color"],
|
||||
color=self.theme["outline_color"],
|
||||
fontsize=self.theme["font_size"],
|
||||
fontcolor=self.theme["font_color"],
|
||||
fontname=self.theme["font_name"],
|
||||
margin=self.theme["margin"],
|
||||
rankdir=orientation,
|
||||
pad=self.theme["padding"])
|
||||
dot.attr("node", shape="box",
|
||||
style="filled", margin="0,0",
|
||||
fillcolor=self.theme["fill_color"],
|
||||
color=self.theme["outline_color"],
|
||||
fontsize=self.theme["font_size"],
|
||||
fontcolor=self.theme["font_color"],
|
||||
fontname=self.theme["font_name"])
|
||||
dot.attr("edge", style="solid",
|
||||
color=self.theme["outline_color"],
|
||||
fontsize=self.theme["font_size"],
|
||||
fontcolor=self.theme["font_color"],
|
||||
fontname=self.theme["font_name"])
|
||||
|
||||
for k, n in self.nodes.items():
|
||||
label = "<tr><td cellpadding='6'>{}</td></tr>".format(n.title)
|
||||
if n.caption:
|
||||
label += "<tr><td>{}</td></tr>".format(n.caption)
|
||||
if n.repeat > 1:
|
||||
label += "<tr><td align='right' cellpadding='2'>x{}</td></tr>".format(n.repeat)
|
||||
label = "<<table border='0' cellborder='0' cellpadding='0'>{}</table>>".\
|
||||
format(label)
|
||||
|
||||
# figure out tooltip
|
||||
tooltips = set()
|
||||
if len(n.params):
|
||||
tooltips.update([str(n.params)])
|
||||
if n.combo_params:
|
||||
for params in n.combo_params:
|
||||
if len(params):
|
||||
tooltips.update([str(params)])
|
||||
|
||||
# figure out shape and color
|
||||
layer_overrides = self.theme.get('layer_overrides', {})
|
||||
op_overrides = layer_overrides.get(n.op or n.name, {})
|
||||
|
||||
dot.node(str(k), label, tooltip=html.escape(' '.join(tooltips)), **op_overrides)
|
||||
for a, b, label in self.edges:
|
||||
if isinstance(label, (list, tuple)):
|
||||
label = ' ' + "x".join([str(l or "?") for l in label])
|
||||
|
||||
dot.edge(str(a), str(b), label)
|
||||
return dot
|
||||
|
||||
def _repr_svg_(self):
|
||||
"""Allows Jupyter notebook to render the graph automatically."""
|
||||
return self.build_dot(self.orientation)._repr_svg_()
|
||||
|
||||
def save(self, path, format="pdf"):
|
||||
# TODO: assert on acceptable format values
|
||||
dot = self.build_dot(self.orientation)
|
||||
dot.format = format
|
||||
directory, file_name = os.path.split(path)
|
||||
# Remove extension from file name. dot.render() adds it.
|
||||
file_name = file_name.replace("." + format, "")
|
||||
dot.render(file_name, directory=directory, cleanup=True)
|
|
@ -0,0 +1,137 @@
|
|||
"""
|
||||
HiddenLayer
|
||||
|
||||
PyTorch graph importer.
|
||||
|
||||
Written by Waleed Abdulla
|
||||
Licensed under the MIT License
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
import re
|
||||
from .graph import Graph, Node
|
||||
from . import transforms as ht
|
||||
import torch
|
||||
from collections import abc
|
||||
import numpy as np
|
||||
|
||||
# PyTorch Graph Transforms
|
||||
FRAMEWORK_TRANSFORMS = [
|
||||
# Hide onnx: prefix
|
||||
ht.Rename(op=r"onnx::(.*)", to=r"\1"),
|
||||
# ONNX uses Gemm for linear layers (stands for General Matrix Multiplication).
|
||||
# It's an odd name that noone recognizes. Rename it.
|
||||
ht.Rename(op=r"Gemm", to=r"Linear"),
|
||||
# PyTorch layers that don't have an ONNX counterpart
|
||||
ht.Rename(op=r"aten::max\_pool2d\_with\_indices", to="MaxPool"),
|
||||
# Shorten op name
|
||||
ht.Rename(op=r"BatchNormalization", to="BatchNorm"),
|
||||
]
|
||||
|
||||
|
||||
def dump_pytorch_graph(graph):
|
||||
"""List all the nodes in a PyTorch graph."""
|
||||
f = "{:25} {:40} {} -> {}"
|
||||
print(f.format("kind", "scopeName", "inputs", "outputs"))
|
||||
for node in graph.nodes():
|
||||
print(f.format(node.kind(), node.scopeName(),
|
||||
[i.unique() for i in node.inputs()],
|
||||
[i.unique() for i in node.outputs()]
|
||||
))
|
||||
|
||||
|
||||
def pytorch_id(node):
|
||||
"""Returns a unique ID for a node."""
|
||||
# After ONNX simplification, the scopeName is not unique anymore
|
||||
# so append node outputs to guarantee uniqueness
|
||||
return node.scopeName() + "/outputs/" + "/".join([o.uniqueName() for o in node.outputs()])
|
||||
|
||||
|
||||
def get_shape(torch_node):
|
||||
"""Return the output shape of the given Pytorch node."""
|
||||
# Extract node output shape from the node string representation
|
||||
# This is a hack because there doesn't seem to be an official way to do it.
|
||||
# See my quesiton in the PyTorch forum:
|
||||
# https://discuss.pytorch.org/t/node-output-shape-from-trace-graph/24351/2
|
||||
# TODO: find a better way to extract output shape
|
||||
# TODO: Assuming the node has one output. Update if we encounter a multi-output node.
|
||||
m = re.match(r".*Float\(([\d\s\,]+)\).*", str(next(torch_node.outputs())))
|
||||
if m:
|
||||
shape = m.group(1)
|
||||
shape = shape.split(",")
|
||||
shape = tuple(map(int, shape))
|
||||
else:
|
||||
shape = None
|
||||
return shape
|
||||
|
||||
def calc_rf(model, input_shape):
|
||||
for n, p in model.named_parameters():
|
||||
if not p.requires_grad:
|
||||
continue;
|
||||
if 'bias' in n:
|
||||
p.data.fill_(0)
|
||||
elif 'weight' in n:
|
||||
p.data.fill_(1)
|
||||
|
||||
input = torch.ones(input_shape, requires_grad=True)
|
||||
output = model(input)
|
||||
out_shape = output.size()
|
||||
ndims = len(out_shape)
|
||||
grad = torch.zeros(out_shape)
|
||||
l_tmp=[]
|
||||
for i in xrange(ndims):
|
||||
if i==0 or i==1:#batch or channel
|
||||
l_tmp.append(0)
|
||||
else:
|
||||
l_tmp.append(out_shape[i]/2)
|
||||
|
||||
grad[tuple(l_tmp)] = 1
|
||||
output.backward(gradient=grad)
|
||||
grad_np = img_.grad[0,0].data.numpy()
|
||||
idx_nonzeros = np.where(grad_np!=0)
|
||||
RF=[np.max(idx)-np.min(idx)+1 for idx in idx_nonzeros]
|
||||
|
||||
return RF
|
||||
|
||||
def import_graph(hl_graph, model, args, input_names=None, verbose=False):
|
||||
# TODO: add input names to graph
|
||||
|
||||
if args is None:
|
||||
args = [1, 3, 224, 224] # assume ImageNet default
|
||||
|
||||
# if args is not Tensor but is array like then convert it to torch tensor
|
||||
if not isinstance(args, torch.Tensor) and \
|
||||
hasattr(args, "__len__") and hasattr(args, '__getitem__') and \
|
||||
not isinstance(args, (str, abc.ByteString)):
|
||||
args = torch.ones(args)
|
||||
|
||||
# Run the Pytorch graph to get a trace and generate a graph from it
|
||||
trace, out = torch.jit.get_trace_graph(model, args)
|
||||
torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)
|
||||
torch_graph = trace.graph()
|
||||
|
||||
# Dump list of nodes (DEBUG only)
|
||||
if verbose:
|
||||
dump_pytorch_graph(torch_graph)
|
||||
|
||||
# Loop through nodes and build HL graph
|
||||
for torch_node in torch_graph.nodes():
|
||||
# Op
|
||||
op = torch_node.kind()
|
||||
# Parameters
|
||||
params = {k: torch_node[k] for k in torch_node.attributeNames()}
|
||||
# Inputs/outputs
|
||||
# TODO: inputs = [i.unique() for i in node.inputs()]
|
||||
outputs = [o.unique() for o in torch_node.outputs()]
|
||||
# Get output shape
|
||||
shape = get_shape(torch_node)
|
||||
# Add HL node
|
||||
hl_node = Node(uid=pytorch_id(torch_node), name=None, op=op,
|
||||
output_shape=shape, params=params)
|
||||
hl_graph.add_node(hl_node)
|
||||
# Add edges
|
||||
for target_torch_node in torch_graph.nodes():
|
||||
target_inputs = [i.unique() for i in target_torch_node.inputs()]
|
||||
if set(outputs) & set(target_inputs):
|
||||
hl_graph.add_edge_by_id(pytorch_id(torch_node), pytorch_id(target_torch_node), shape)
|
||||
return hl_graph
|
|
@ -0,0 +1,142 @@
|
|||
"""
|
||||
HiddenLayer
|
||||
|
||||
TensorFlow graph importer.
|
||||
|
||||
Written by Phil Ferriere. Edits by Waleed Abdulla.
|
||||
Licensed under the MIT License
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
import logging
|
||||
import tensorflow as tf
|
||||
from .graph import Graph, Node
|
||||
from . import transforms as ht
|
||||
|
||||
|
||||
FRAMEWORK_TRANSFORMS = [
|
||||
# Rename VariableV2 op to Variable. Same for anything V2, V3, ...etc.
|
||||
ht.Rename(op=r"(\w+)V\d", to=r"\1"),
|
||||
ht.Prune("Const"),
|
||||
ht.Prune("PlaceholderWithDefault"),
|
||||
ht.Prune("Variable"),
|
||||
ht.Prune("VarIsInitializedOp"),
|
||||
ht.Prune("VarHandleOp"),
|
||||
ht.Prune("ReadVariableOp"),
|
||||
ht.PruneBranch("Assign"),
|
||||
ht.PruneBranch("AssignSub"),
|
||||
ht.PruneBranch("AssignAdd"),
|
||||
ht.PruneBranch("AssignVariableOp"),
|
||||
ht.Prune("ApplyMomentum"),
|
||||
ht.Prune("ApplyAdam"),
|
||||
ht.FoldId(r"^(gradients)/.*", "NoOp"), # Fold to NoOp then delete in the next step
|
||||
ht.Prune("NoOp"),
|
||||
ht.Rename(op=r"DepthwiseConv2dNative", to="SeparableConv"),
|
||||
ht.Rename(op=r"Conv2D", to="Conv"),
|
||||
ht.Rename(op=r"FusedBatchNorm", to="BatchNorm"),
|
||||
ht.Rename(op=r"MatMul", to="Linear"),
|
||||
ht.Fold("Conv > BiasAdd", "__first__"),
|
||||
ht.Fold("Linear > BiasAdd", "__first__"),
|
||||
ht.Fold("Shape > StridedSlice > Pack > Reshape", "__last__"),
|
||||
ht.FoldId(r"(.+)/dropout/.*", "Dropout"),
|
||||
ht.FoldId(r"(softmax_cross\_entropy)\_with\_logits.*", "SoftmaxCrossEntropy"),
|
||||
]
|
||||
|
||||
|
||||
def dump_tf_graph(tfgraph, tfgraphdef):
|
||||
"""List all the nodes in a TF graph.
|
||||
tfgraph: A TF Graph object.
|
||||
tfgraphdef: A TF GraphDef object.
|
||||
"""
|
||||
print("Nodes ({})".format(len(tfgraphdef.node)))
|
||||
f = "{:15} {:59} {:20} {}"
|
||||
print(f.format("kind", "scopeName", "shape", "inputs"))
|
||||
for node in tfgraphdef.node:
|
||||
scopename = node.name
|
||||
kind = node.op
|
||||
inputs = node.input
|
||||
shape = tf.graph_util.tensor_shape_from_node_def_name(tfgraph, scopename)
|
||||
print(f.format(kind, scopename, str(shape), inputs))
|
||||
|
||||
|
||||
def import_graph(hl_graph, tf_graph, output=None, verbose=False):
|
||||
"""Convert TF graph to directed graph
|
||||
tfgraph: A TF Graph object.
|
||||
output: Name of the output node (string).
|
||||
verbose: Set to True for debug print output
|
||||
"""
|
||||
# Get clean(er) list of nodes
|
||||
graph_def = tf_graph.as_graph_def(add_shapes=True)
|
||||
graph_def = tf.graph_util.remove_training_nodes(graph_def)
|
||||
|
||||
# Dump list of TF nodes (DEBUG only)
|
||||
if verbose:
|
||||
dump_tf_graph(tf_graph, graph_def)
|
||||
|
||||
# Loop through nodes and build the matching directed graph
|
||||
for tf_node in graph_def.node:
|
||||
# Read node details
|
||||
try:
|
||||
op, uid, name, shape, params = import_node(tf_node, tf_graph, verbose)
|
||||
except:
|
||||
if verbose:
|
||||
logging.exception("Failed to read node {}".format(tf_node))
|
||||
continue
|
||||
|
||||
# Add node
|
||||
hl_node = Node(uid=uid, name=name, op=op, output_shape=shape, params=params)
|
||||
hl_graph.add_node(hl_node)
|
||||
|
||||
# Add edges
|
||||
for target_node in graph_def.node:
|
||||
target_inputs = target_node.input
|
||||
if uid in target_node.input:
|
||||
hl_graph.add_edge_by_id(uid, target_node.name, shape)
|
||||
return hl_graph
|
||||
|
||||
|
||||
def import_node(tf_node, tf_graph, verbose=False):
|
||||
# Operation type and name
|
||||
op = tf_node.op
|
||||
uid = tf_node.name
|
||||
name = None
|
||||
|
||||
# Shape
|
||||
shape = None
|
||||
if tf_node.op != "NoOp":
|
||||
try:
|
||||
shape = tf.graph_util.tensor_shape_from_node_def_name(tf_graph, tf_node.name)
|
||||
# Is the shape is known, convert to a list
|
||||
if shape.ndims is not None:
|
||||
shape = shape.as_list()
|
||||
except:
|
||||
if verbose:
|
||||
logging.exception("Error reading shape of {}".format(tf_node.name))
|
||||
|
||||
# Parameters
|
||||
# At this stage, we really only care about two parameters:
|
||||
# 1/ the kernel size used by convolution layers
|
||||
# 2/ the stride used by convolutional and pooling layers (TODO: not fully working yet)
|
||||
|
||||
# 1/ The kernel size is actually not stored in the convolution tensor but in its weight input.
|
||||
# The weights input has the shape [shape=[kernel, kernel, in_channels, filters]]
|
||||
# So we must fish for it
|
||||
params = {}
|
||||
if op == "Conv2D" or op == "DepthwiseConv2dNative":
|
||||
kernel_shape = tf.graph_util.tensor_shape_from_node_def_name(tf_graph, tf_node.input[1])
|
||||
kernel_shape = [int(a) for a in kernel_shape]
|
||||
params["kernel_shape"] = kernel_shape[0:2]
|
||||
if 'strides' in tf_node.attr.keys():
|
||||
strides = [int(a) for a in tf_node.attr['strides'].list.i]
|
||||
params["stride"] = strides[1:3]
|
||||
elif op == "MaxPool" or op == "AvgPool":
|
||||
# 2/ the stride used by pooling layers
|
||||
# See https://stackoverflow.com/questions/44124942/how-to-access-values-in-protos-in-tensorflow
|
||||
if 'ksize' in tf_node.attr.keys():
|
||||
kernel_shape = [int(a) for a in tf_node.attr['ksize'].list.i]
|
||||
params["kernel_shape"] = kernel_shape[1:3]
|
||||
if 'strides' in tf_node.attr.keys():
|
||||
strides = [int(a) for a in tf_node.attr['strides'].list.i]
|
||||
params["stride"] = strides[1:3]
|
||||
|
||||
return op, uid, name, shape, params
|
|
@ -0,0 +1,212 @@
|
|||
"""
|
||||
HiddenLayer
|
||||
|
||||
Transforms that apply to and modify graph nodes.
|
||||
|
||||
Written by Waleed Abdulla
|
||||
Licensed under the MIT License
|
||||
"""
|
||||
|
||||
import re
|
||||
import copy
|
||||
from .graph import Node
|
||||
from . import ge
|
||||
|
||||
|
||||
|
||||
###########################################################################
|
||||
# Transforms
|
||||
###########################################################################
|
||||
|
||||
def _concate_params(matches):
|
||||
combo_params = [match.params for match in matches]
|
||||
combo_params += [cb for match in matches if match.combo_params for cb in match.combo_params]
|
||||
return combo_params
|
||||
|
||||
class Fold():
|
||||
def __init__(self, pattern, op, name=None):
|
||||
# TODO: validate that op and name are valid
|
||||
self.pattern = ge.GEParser(pattern).parse()
|
||||
self.op = op
|
||||
self.name = name
|
||||
|
||||
def apply(self, graph):
|
||||
# Copy the graph. Don't change the original.
|
||||
graph = copy.deepcopy(graph)
|
||||
|
||||
while True:
|
||||
matches, _ = graph.search(self.pattern)
|
||||
if not matches:
|
||||
break
|
||||
|
||||
# Replace pattern with new node
|
||||
if self.op == "__first__":
|
||||
combo = matches[0]
|
||||
elif self.op == "__last__":
|
||||
combo = matches[-1]
|
||||
else:
|
||||
combo = Node(uid=graph.sequence_id(matches),
|
||||
name=self.name or " > ".join([l.title for l in matches]),
|
||||
op=self.op or self.pattern,
|
||||
output_shape=matches[-1].output_shape,
|
||||
combo_params=_concate_params(matches))
|
||||
combo._caption = "/".join(filter(None, [l.caption for l in matches]))
|
||||
graph.replace(matches, combo)
|
||||
return graph
|
||||
|
||||
|
||||
class FoldId():
|
||||
def __init__(self, id_regex, op, name=None):
|
||||
# TODO: validate op and name are valid
|
||||
self.id_regex = re.compile(id_regex)
|
||||
self.op = op
|
||||
self.name = name
|
||||
|
||||
def apply(self, graph):
|
||||
# Copy the graph. Don't change the original.
|
||||
graph = copy.deepcopy(graph)
|
||||
|
||||
# Group nodes by the first matching group of the regex
|
||||
groups = {}
|
||||
for node in graph.nodes.values():
|
||||
m = self.id_regex.match(node.id)
|
||||
if not m:
|
||||
continue
|
||||
|
||||
assert m.groups(), "Regular expression must have a matching group to avoid folding unrelated nodes."
|
||||
key = m.group(1)
|
||||
if key not in groups:
|
||||
groups[key] = []
|
||||
groups[key].append(node)
|
||||
|
||||
# Fold each group of nodes together
|
||||
for key, nodes in groups.items():
|
||||
# Replace with a new node
|
||||
# TODO: Find last node in the sub-graph and get the output shape from it
|
||||
combo = Node(uid=key,
|
||||
name=self.name,
|
||||
op=self.op,
|
||||
combo_params=_concate_params(nodes))
|
||||
graph.replace(nodes, combo)
|
||||
return graph
|
||||
|
||||
|
||||
class Prune():
|
||||
def __init__(self, pattern):
|
||||
self.pattern = ge.GEParser(pattern).parse()
|
||||
|
||||
def apply(self, graph):
|
||||
# Copy the graph. Don't change the original.
|
||||
graph = copy.deepcopy(graph)
|
||||
|
||||
while True:
|
||||
matches, _ = graph.search(self.pattern)
|
||||
if not matches:
|
||||
break
|
||||
# Remove found nodes
|
||||
graph.remove(matches)
|
||||
return graph
|
||||
|
||||
|
||||
class PruneBranch():
|
||||
def __init__(self, pattern):
|
||||
self.pattern = ge.GEParser(pattern).parse()
|
||||
|
||||
def tag(self, node, tag, graph, conditional=False):
|
||||
# Return if the node is already tagged
|
||||
if hasattr(node, "__tag__") and node.__tag__ == "tag":
|
||||
return
|
||||
# If conditional, then tag the node if and only if all its
|
||||
# outgoing nodes already have the same tag.
|
||||
if conditional:
|
||||
# Are all outgoing nodes already tagged?
|
||||
outgoing = graph.outgoing(node)
|
||||
tagged = filter(lambda n: hasattr(n, "__tag__") and n.__tag__ == tag,
|
||||
outgoing)
|
||||
if len(list(tagged)) != len(outgoing):
|
||||
# Not all outgoing are tagged
|
||||
return
|
||||
# Tag the node
|
||||
node.__tag__ = tag
|
||||
# Tag incoming nodes
|
||||
for n in graph.incoming(node):
|
||||
self.tag(n, tag, graph, conditional=True)
|
||||
|
||||
def apply(self, graph):
|
||||
# Copy the graph. Don't change the original.
|
||||
graph = copy.deepcopy(graph)
|
||||
|
||||
while True:
|
||||
matches, _ = graph.search(self.pattern)
|
||||
if not matches:
|
||||
break
|
||||
# Tag found nodes and their incoming branches
|
||||
for n in matches:
|
||||
self.tag(n, "delete", graph)
|
||||
# Find all tagged nodes and delete them
|
||||
tagged = [n for n in graph.nodes.values()
|
||||
if hasattr(n, "__tag__") and n.__tag__ == "delete"]
|
||||
graph.remove(tagged)
|
||||
return graph
|
||||
|
||||
|
||||
class FoldDuplicates():
|
||||
def apply(self, graph):
|
||||
# Copy the graph. Don't change the original.
|
||||
graph = copy.deepcopy(graph)
|
||||
|
||||
matches = True
|
||||
while matches:
|
||||
for node in graph.nodes.values():
|
||||
pattern = ge.SerialPattern([ge.NodePattern(node.op), ge.NodePattern(node.op)])
|
||||
matches, _ = pattern.match(graph, node)
|
||||
if matches:
|
||||
# Use op and name from the first node, and output_shape from the last
|
||||
combo = Node(uid=graph.sequence_id(matches),
|
||||
name=node.name,
|
||||
op=node.op,
|
||||
output_shape=matches[-1].output_shape,
|
||||
combo_params=_concate_params(matches))
|
||||
combo._caption = node.caption
|
||||
combo.repeat = sum([n.repeat for n in matches])
|
||||
graph.replace(matches, combo)
|
||||
break
|
||||
return graph
|
||||
|
||||
|
||||
class Rename():
|
||||
def __init__(self, op=None, name=None, to=None):
|
||||
assert op or name, "Either op or name must be provided"
|
||||
assert not(op and name), "Either op or name should be provided, but not both"
|
||||
assert bool(to), "The to parameter is required"
|
||||
self.to = to
|
||||
self.op = re.compile(op) if op else None
|
||||
self.name = re.compile(name) if name else None
|
||||
|
||||
def apply(self, graph):
|
||||
# Copy the graph. Don't change the original.
|
||||
graph = copy.deepcopy(graph)
|
||||
|
||||
for node in graph.nodes.values():
|
||||
if self.op:
|
||||
node.op = self.op.sub(self.to, node.op)
|
||||
# TODO: name is not tested yet
|
||||
if self.name:
|
||||
node.name = self.name.sub(self.to, node.name)
|
||||
return graph
|
||||
|
||||
|
||||
# Transforms to simplify graphs by folding layers that tend to be
|
||||
# used together often, such as Conv/BN/Relu.
|
||||
# These transforms are used AFTER the framework specific transforms
|
||||
# that map TF and PyTorch graphs to a common representation.
|
||||
SIMPLICITY_TRANSFORMS = [
|
||||
Fold("Conv > Conv > BatchNorm > Relu", "ConvConvBnRelu"),
|
||||
Fold("Conv > BatchNorm > Relu", "ConvBnRelu"),
|
||||
Fold("Conv > BatchNorm", "ConvBn"),
|
||||
Fold("Conv > Relu", "ConvRelu"),
|
||||
Fold("Linear > Relu", "LinearRelu"),
|
||||
# Fold("ConvBnRelu > MaxPool", "ConvBnReluMaxpool"),
|
||||
# Fold("ConvRelu > MaxPool", "ConvReluMaxpool"),
|
||||
FoldDuplicates(),
|
||||
]
|
|
@ -0,0 +1,89 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import torchstat
|
||||
import pandas as pd
|
||||
|
||||
def model_stats(model, input_shape):
|
||||
if len(input_shape) > 3:
|
||||
input_shape = input_shape[1:4]
|
||||
ms = torchstat.statistics.ModelStat(model, input_shape, 1)
|
||||
collected_nodes = ms._analyze_model()
|
||||
return _report_format(collected_nodes)
|
||||
|
||||
def _round_value(value, binary=False):
|
||||
divisor = 1024. if binary else 1000.
|
||||
|
||||
if value // divisor**4 > 0:
|
||||
return str(round(value / divisor**4, 2)) + 'T'
|
||||
elif value // divisor**3 > 0:
|
||||
return str(round(value / divisor**3, 2)) + 'G'
|
||||
elif value // divisor**2 > 0:
|
||||
return str(round(value / divisor**2, 2)) + 'M'
|
||||
elif value // divisor > 0:
|
||||
return str(round(value / divisor, 2)) + 'K'
|
||||
return str(value)
|
||||
|
||||
|
||||
def _report_format(collected_nodes):
|
||||
pd.set_option('display.width', 1000)
|
||||
pd.set_option('display.max_rows', 10000)
|
||||
pd.set_option('display.max_columns', 10000)
|
||||
|
||||
data = list()
|
||||
for node in collected_nodes:
|
||||
name = node.name
|
||||
input_shape = ' '.join(['{:>3d}'] * len(node.input_shape)).format(
|
||||
*[e for e in node.input_shape])
|
||||
output_shape = ' '.join(['{:>3d}'] * len(node.output_shape)).format(
|
||||
*[e for e in node.output_shape])
|
||||
parameter_quantity = node.parameter_quantity
|
||||
inference_memory = node.inference_memory
|
||||
MAdd = node.MAdd
|
||||
Flops = node.Flops
|
||||
mread, mwrite = [i for i in node.Memory]
|
||||
duration = node.duration
|
||||
data.append([name, input_shape, output_shape, parameter_quantity,
|
||||
inference_memory, MAdd, duration, Flops, mread,
|
||||
mwrite])
|
||||
df = pd.DataFrame(data)
|
||||
df.columns = ['module name', 'input shape', 'output shape',
|
||||
'params', 'memory(MB)',
|
||||
'MAdd', 'duration', 'Flops', 'MemRead(B)', 'MemWrite(B)']
|
||||
df['duration[%]'] = df['duration'] / (df['duration'].sum() + 1e-7)
|
||||
df['MemR+W(B)'] = df['MemRead(B)'] + df['MemWrite(B)']
|
||||
total_parameters_quantity = df['params'].sum()
|
||||
total_memory = df['memory(MB)'].sum()
|
||||
total_operation_quantity = df['MAdd'].sum()
|
||||
total_flops = df['Flops'].sum()
|
||||
total_duration = df['duration[%]'].sum()
|
||||
total_mread = df['MemRead(B)'].sum()
|
||||
total_mwrite = df['MemWrite(B)'].sum()
|
||||
total_memrw = df['MemR+W(B)'].sum()
|
||||
del df['duration']
|
||||
|
||||
# Add Total row
|
||||
total_df = pd.Series([total_parameters_quantity, total_memory,
|
||||
total_operation_quantity, total_flops,
|
||||
total_duration, mread, mwrite, total_memrw],
|
||||
index=['params', 'memory(MB)', 'MAdd', 'Flops', 'duration[%]',
|
||||
'MemRead(B)', 'MemWrite(B)', 'MemR+W(B)'],
|
||||
name='total')
|
||||
df = df.append(total_df)
|
||||
|
||||
df = df.fillna(' ')
|
||||
df['memory(MB)'] = df['memory(MB)'].apply(
|
||||
lambda x: '{:.2f}'.format(x))
|
||||
df['duration[%]'] = df['duration[%]'].apply(lambda x: '{:.2%}'.format(x))
|
||||
df['MAdd'] = df['MAdd'].apply(lambda x: '{:,}'.format(x))
|
||||
df['Flops'] = df['Flops'].apply(lambda x: '{:,}'.format(x))
|
||||
|
||||
#summary = "Total params: {:,}\n".format(total_parameters_quantity)
|
||||
|
||||
#summary += "-" * len(str(df).split('\n')[0])
|
||||
#summary += '\n'
|
||||
#summary += "Total memory: {:.2f}MB\n".format(total_memory)
|
||||
#summary += "Total MAdd: {}MAdd\n".format(_round_value(total_operation_quantity))
|
||||
#summary += "Total Flops: {}Flops\n".format(_round_value(total_flops))
|
||||
#summary += "Total MemR+W: {}B\n".format(_round_value(total_memrw, True))
|
||||
return df
|
|
@ -0,0 +1,8 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# import matplotlib before anything else
|
||||
# because of VS debugger issue for multiprocessing
|
||||
# https://github.com/Microsoft/ptvsd/issues/1041
|
||||
from .line_plot import LinePlot
|
||||
from .image_plot import ImagePlot
|
|
@ -0,0 +1,149 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
#from IPython import get_ipython, display
|
||||
#if get_ipython():
|
||||
# get_ipython().magic('matplotlib notebook')
|
||||
|
||||
#import matplotlib
|
||||
#if os.name == 'posix' and "DISPLAY" not in os.environ:
|
||||
# matplotlib.use('Agg') # Must be before importing matplotlib.pyplot or pylab!
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.animation import FuncAnimation
|
||||
#from ipywidgets.widgets.interaction import show_inline_matplotlib_plots
|
||||
#from ipykernel.pylab.backend_inline import flush_figures
|
||||
|
||||
from ..vis_base import VisBase
|
||||
|
||||
import sys, traceback, logging
|
||||
from abc import abstractmethod
|
||||
from .. import utils
|
||||
from IPython import get_ipython #, display
|
||||
import ipywidgets as widgets
|
||||
|
||||
class BaseMplPlot(VisBase):
|
||||
def __init__(self, cell:widgets.Box=None, title:str=None, show_legend:bool=None, stream_name:str=None, console_debug:bool=False, **vis_args):
|
||||
super(BaseMplPlot, self).__init__(widgets.Output(), cell, title, show_legend, stream_name=stream_name, console_debug=console_debug, **vis_args)
|
||||
|
||||
self._fig_init_done = False
|
||||
self.show_legend = show_legend
|
||||
# graph objects
|
||||
self.figure = None
|
||||
self._ax_main = None
|
||||
# matplotlib animation
|
||||
self.animation = None
|
||||
self.anim_interval = None
|
||||
#print(matplotlib.get_backend())
|
||||
#display.display(self.cell)
|
||||
|
||||
# anim_interval in seconds
|
||||
def init_fig(self, anim_interval:float=1.0):
|
||||
"""(for derived class) Initializes matplotlib figure"""
|
||||
if self._fig_init_done:
|
||||
return False
|
||||
|
||||
# create figure and animation
|
||||
self.figure = plt.figure(figsize=(8, 3))
|
||||
self.anim_interval = anim_interval
|
||||
|
||||
plt.set_cmap('Dark2')
|
||||
plt.rcParams['image.cmap']='Dark2'
|
||||
|
||||
self._fig_init_done = True
|
||||
return True
|
||||
|
||||
def get_main_axis(self):
|
||||
# if we don't yet have main axis, create one
|
||||
if not self._ax_main:
|
||||
# by default assign one subplot to whole graph
|
||||
self._ax_main = self.figure.add_subplot(111)
|
||||
self._ax_main.grid(True)
|
||||
# change the color of the top and right spines to opaque gray
|
||||
self._ax_main.spines['right'].set_color((.8,.8,.8))
|
||||
self._ax_main.spines['top'].set_color((.8,.8,.8))
|
||||
if self.title is not None:
|
||||
title = self._ax_main.set_title(self.title)
|
||||
title.set_weight('bold')
|
||||
return self._ax_main
|
||||
|
||||
def _on_update(self, frame): # pylint: disable=unused-argument
|
||||
try:
|
||||
self._update_stream_plots()
|
||||
except Exception as ex:
|
||||
# when exception occurs here, animation will stop and there
|
||||
# will be no further plot updates
|
||||
# TODO: may be we don't need all of below but none of them
|
||||
# are popping up exception in Jupyter Notebook because these
|
||||
# exceptions occur in background?
|
||||
self.last_ex = ex
|
||||
print(ex)
|
||||
logging.fatal(ex, exc_info=True)
|
||||
traceback.print_exc(file=sys.stdout)
|
||||
|
||||
def show(self, blocking=False):
|
||||
if not self.is_shown and self.anim_interval:
|
||||
self.animation = FuncAnimation(self.figure, self._on_update, interval=self.anim_interval*1000.0)
|
||||
super(BaseMplPlot, self).show(blocking)
|
||||
|
||||
def _post_update_stream_plot(self, stream_vis):
|
||||
utils.debug_log("Plot updated", stream_vis.stream.stream_name, verbosity=5)
|
||||
|
||||
if self.layout_dirty:
|
||||
# do not do tight_layout() call on every update
|
||||
# that would jumble up the graphs! it should only called
|
||||
# once each time there is change in layout
|
||||
self.figure.tight_layout()
|
||||
self.layout_dirty = False
|
||||
|
||||
# below forces redraw and it was helpful to
|
||||
# repaint even if there was error in interval loop
|
||||
# but it does work in native UX and not in Jupyter Notebook
|
||||
#self.figure.canvas.draw()
|
||||
#self.figure.canvas.flush_events()
|
||||
|
||||
if self._use_hbox and get_ipython():
|
||||
self.widget.clear_output(wait=True)
|
||||
with self.widget:
|
||||
plt.show(self.figure)
|
||||
|
||||
# everything else that doesn't work
|
||||
#self.figure.show()
|
||||
#display.clear_output(wait=True)
|
||||
#display.display(self.figure)
|
||||
#flush_figures()
|
||||
#plt.show()
|
||||
#show_inline_matplotlib_plots()
|
||||
#elif not get_ipython():
|
||||
# self.figure.canvas.draw()
|
||||
|
||||
def _post_add_subscription(self, stream_vis, **stream_vis_args):
|
||||
# make sure figure is initialized
|
||||
self.init_fig()
|
||||
self.init_stream_plot(stream_vis, **stream_vis_args)
|
||||
|
||||
# redo the legend
|
||||
#self.figure.legend(loc='center right', bbox_to_anchor=(1.5, 0.5))
|
||||
if self.show_legend:
|
||||
self.figure.legend(loc='lower right')
|
||||
plt.subplots_adjust(hspace=0.6)
|
||||
|
||||
def _show_widget_native(self, blocking:bool):
|
||||
#plt.ion()
|
||||
#plt.show()
|
||||
return plt.show(block=blocking)
|
||||
|
||||
def _show_widget_notebook(self):
|
||||
# no need to return anything because %matplotlib notebook will
|
||||
# detect spawning of figure and paint it
|
||||
# if self.figure is returned then you will see two of them
|
||||
return None
|
||||
#plt.show()
|
||||
#return self.figure
|
||||
|
||||
def _can_update_stream_plots(self):
|
||||
return False # we run interval timer which will flush the key
|
||||
|
||||
@abstractmethod
|
||||
def init_stream_plot(self, stream_vis, **stream_vis_args):
|
||||
"""(for derived class) Create new plot info for this stream"""
|
||||
pass
|
|
@ -0,0 +1,91 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .base_mpl_plot import BaseMplPlot
|
||||
from .. import utils, image_utils
|
||||
import numpy as np
|
||||
import skimage.transform
|
||||
#from IPython import get_ipython
|
||||
|
||||
class ImagePlot(BaseMplPlot):
|
||||
def init_stream_plot(self, stream_vis,
|
||||
rows=2, cols=5, img_width=None, img_height=None, img_channels=None,
|
||||
colormap=None, viz_img_scale=None, **stream_vis_args):
|
||||
stream_vis.rows, stream_vis.cols = rows, cols
|
||||
stream_vis.img_channels, stream_vis.colormap = img_channels, colormap
|
||||
stream_vis.img_width, stream_vis.img_height = img_width, img_height
|
||||
stream_vis.viz_img_scale = viz_img_scale
|
||||
# subplots holding each image
|
||||
stream_vis.axs = [[None for _ in range(cols)] for _ in range(rows)]
|
||||
# axis image
|
||||
stream_vis.ax_imgs = [[None for _ in range(cols)] for _ in range(rows)]
|
||||
|
||||
def clear_plot(self, stream_vis, clear_history):
|
||||
for row in range(stream_vis.rows):
|
||||
for col in range(stream_vis.cols):
|
||||
img = stream_vis.ax_imgs[row][col]
|
||||
if img:
|
||||
x, y = img.get_size()
|
||||
img.set_data(np.zeros((x, y)))
|
||||
|
||||
def _show_stream_items(self, stream_vis, stream_items):
|
||||
# as we repaint each image plot, select last if multiple events were pending
|
||||
stream_item = None
|
||||
for er in reversed(stream_items):
|
||||
if not(er.ended or er.value is None):
|
||||
stream_item = er
|
||||
break
|
||||
if stream_item is None:
|
||||
return False
|
||||
|
||||
row, col, i = 0, 0, 0
|
||||
dirty = False
|
||||
# stream_item.value is expected to be ImagePlotItems
|
||||
for image_list in stream_item.value:
|
||||
# convert to imshow compatible, stitch images
|
||||
images = [image_utils.to_imshow_array(img, stream_vis.img_width, stream_vis.img_height) \
|
||||
for img in image_list.images if img is not None]
|
||||
img_viz = image_utils.stitch_horizontal(images, width_dim=1)
|
||||
|
||||
# resize if requested
|
||||
if stream_vis.viz_img_scale is not None:
|
||||
img_viz = skimage.transform.rescale(img_viz,
|
||||
(stream_vis.viz_img_scale, stream_vis.viz_img_scale), mode='reflect', preserve_range=False)
|
||||
|
||||
# create subplot if it doesn't exist
|
||||
ax = stream_vis.axs[row][col]
|
||||
if ax is None:
|
||||
ax = stream_vis.axs[row][col] = \
|
||||
self.figure.add_subplot(stream_vis.rows, stream_vis.cols, i+1)
|
||||
ax.set_xticks([])
|
||||
ax.set_yticks([])
|
||||
|
||||
cmap = image_list.cmap or ('Greys' if stream_vis.colormap is None and \
|
||||
len(img_viz.shape) == 2 else stream_vis.colormap)
|
||||
|
||||
stream_vis.ax_imgs[row][col] = ax.imshow(img_viz, interpolation="none", cmap=cmap, alpha=image_list.alpha)
|
||||
dirty = True
|
||||
|
||||
# set title
|
||||
title = image_list.title
|
||||
if len(title) > 12: #wordwrap if too long
|
||||
title = utils.wrap_string(title) if len(title) > 24 else title
|
||||
fontsize = 8
|
||||
else:
|
||||
fontsize = 12
|
||||
ax.set_title(title, fontsize=fontsize) #'fontweight': 'light'
|
||||
|
||||
#ax.autoscale_view() # not needed
|
||||
col = col + 1
|
||||
if col >= stream_vis.cols:
|
||||
col = 0
|
||||
row = row + 1
|
||||
if row >= stream_vis.rows:
|
||||
break
|
||||
i += 1
|
||||
|
||||
return dirty
|
||||
|
||||
|
||||
def has_legend(self):
|
||||
return self.show_legend or False
|
|
@ -0,0 +1,148 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .base_mpl_plot import BaseMplPlot
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
from .. import utils
|
||||
import numpy as np
|
||||
from ..lv_types import EventVars
|
||||
import ipywidgets as widgets
|
||||
|
||||
class LinePlot(BaseMplPlot):
|
||||
def __init__(self, cell:widgets.Box=None, title=None, show_legend:bool=True, stream_name:str=None, console_debug:bool=False, is_3d:bool=False, **vis_args):
|
||||
super(LinePlot, self).__init__(cell, title, show_legend, stream_name=stream_name, console_debug=console_debug, **vis_args)
|
||||
self.is_3d = is_3d #TODO: not implemented for mpl
|
||||
|
||||
def init_stream_plot(self, stream_vis,
|
||||
xtitle='', ytitle='', color=None, xrange=None, yrange=None, **stream_vis_args):
|
||||
stream_vis.xylabel_refs = [] # annotation references
|
||||
|
||||
# add main subplot
|
||||
if len(self._stream_vises) == 0:
|
||||
stream_vis.ax = self.get_main_axis()
|
||||
else:
|
||||
stream_vis.ax = self.get_main_axis().twinx()
|
||||
|
||||
#TODO: improve color selection
|
||||
color = color or plt.cm.Dark2((len(self._stream_vises)%8)/8) # pylint: disable=no-member
|
||||
|
||||
# add default line in subplot
|
||||
stream_vis.line = matplotlib.lines.Line2D([], [],
|
||||
label=stream_vis.title or ytitle or str(stream_vis.index), color=color) #, linewidth=3
|
||||
if stream_vis.opacity is not None:
|
||||
stream_vis.line.set_alpha(stream_vis.opacity)
|
||||
stream_vis.ax.add_line(stream_vis.line)
|
||||
|
||||
# if more than 2 y-axis then place additional outside
|
||||
if len(self._stream_vises) > 1:
|
||||
pos = (len(self._stream_vises)) * 30
|
||||
stream_vis.ax.spines['right'].set_position(('outward', pos))
|
||||
|
||||
stream_vis.ax.set_xlabel(xtitle)
|
||||
stream_vis.ax.set_ylabel(ytitle)
|
||||
stream_vis.ax.yaxis.label.set_color(color)
|
||||
stream_vis.ax.yaxis.label.set_style('italic')
|
||||
stream_vis.ax.xaxis.label.set_style('italic')
|
||||
if xrange is not None:
|
||||
stream_vis.ax.set_xlim(*xrange)
|
||||
if yrange is not None:
|
||||
stream_vis.ax.set_ylim(*yrange)
|
||||
|
||||
def clear_plot(self, stream_vis, clear_history):
|
||||
lines = stream_vis.ax.get_lines()
|
||||
# if we need to keep history
|
||||
if stream_vis.history_len > 1:
|
||||
# make sure we have history len - 1 lines
|
||||
lines_keep = 0 if clear_history else stream_vis.history_len-1
|
||||
while len(lines) > lines_keep:
|
||||
lines.pop(0).remove()
|
||||
# dim old lines
|
||||
if stream_vis.dim_history and len(lines) > 0:
|
||||
alphas = np.linspace(0.05, 1, len(lines))
|
||||
for line, opacity in zip(lines, alphas):
|
||||
line.set_alpha(opacity)
|
||||
line.set_linewidth(1)
|
||||
# add new line
|
||||
line = matplotlib.lines.Line2D([], [], linewidth=3)
|
||||
stream_vis.ax.add_line(line)
|
||||
else: #clear current line
|
||||
lines[-1].set_data([], [])
|
||||
|
||||
# remove annotations
|
||||
for label_info in stream_vis.xylabel_refs:
|
||||
label_info.set_visible(False)
|
||||
label_info.remove()
|
||||
stream_vis.xylabel_refs.clear()
|
||||
|
||||
def _show_stream_items(self, stream_vis, stream_items):
|
||||
vals = self._extract_vals(stream_items)
|
||||
if not len(vals):
|
||||
return False
|
||||
|
||||
line = stream_vis.ax.get_lines()[-1]
|
||||
xdata, ydata = line.get_data()
|
||||
zdata, anndata, txtdata, clrdata = [], [], [], []
|
||||
|
||||
unpacker = lambda a0=None,a1=None,a2=None,a3=None,a4=None,a5=None, *_:(a0,a1,a2,a3,a4,a5)
|
||||
|
||||
# add each value in trace data
|
||||
# each value is of the form:
|
||||
# 2D graphs:
|
||||
# y
|
||||
# x [, y [, annotation [, text [, color]]]]
|
||||
# y
|
||||
# x [, y [, z, [annotation [, text [, color]]]]]
|
||||
for val in vals:
|
||||
# set defaults
|
||||
x, y, z = len(xdata), None, None
|
||||
ann, txt, clr = None, None, None
|
||||
|
||||
# if val turns out to be array-like, extract x,y
|
||||
val_l = utils.is_scaler_array(val)
|
||||
if val_l >= 0:
|
||||
if self.is_3d:
|
||||
x, y, z, ann, txt, clr = unpacker(*val)
|
||||
else:
|
||||
x, y, ann, txt, clr, _ = unpacker(*val)
|
||||
elif isinstance(val, EventVars):
|
||||
x = val.x if hasattr(val, 'x') else x
|
||||
y = val.y if hasattr(val, 'y') else y
|
||||
z = val.z if hasattr(val, 'z') else z
|
||||
ann = val.ann if hasattr(val, 'ann') else ann
|
||||
txt = val.txt if hasattr(val, 'txt') else txt
|
||||
clr = val.clr if hasattr(val, 'clr') else clr
|
||||
|
||||
if y is None:
|
||||
y = next(iter(val.__dict__.values()))
|
||||
else:
|
||||
y = val
|
||||
|
||||
if ann is not None:
|
||||
ann = str(ann)
|
||||
if txt is not None:
|
||||
txt = str(txt)
|
||||
|
||||
xdata.append(x)
|
||||
ydata.append(y)
|
||||
zdata.append(z)
|
||||
if (txt):
|
||||
txtdata.append(txt)
|
||||
if clr:
|
||||
clrdata.append(clr)
|
||||
if ann: #TODO: yref should be y2 for different y axis
|
||||
anndata.append(dict(x=x, y=y, xref='x', yref='y', text=ann, showarrow=False))
|
||||
|
||||
line.set_data(xdata, ydata)
|
||||
for ann in anndata:
|
||||
stream_vis.xylabel_refs.append(stream_vis.ax.text( \
|
||||
ann['x'], ann['y'], ann['text']))
|
||||
|
||||
stream_vis.ax.relim()
|
||||
stream_vis.ax.autoscale_view()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .line_plot import LinePlot
|
||||
from .embeddings_plot import EmbeddingsPlot
|
||||
#from .vis_base import *
|
|
@ -0,0 +1,100 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import plotly
|
||||
import plotly.graph_objs as go
|
||||
import ipywidgets as widgets
|
||||
|
||||
from ..vis_base import VisBase
|
||||
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from .. import utils
|
||||
|
||||
|
||||
class BasePlotlyPlot(VisBase):
|
||||
def __init__(self, cell:widgets.Box=None, title=None, show_legend:bool=None, stream_name:str=None, console_debug:bool=False, **vis_args):
|
||||
super(BasePlotlyPlot, self).__init__(go.FigureWidget(), cell, title, show_legend,
|
||||
stream_name=stream_name, console_debug=console_debug, **vis_args)
|
||||
|
||||
self.widget.layout.title = title
|
||||
self.widget.layout.showlegend = show_legend if show_legend is not None else True
|
||||
|
||||
def _add_trace(self, stream_vis):
|
||||
stream_vis.trace_index = len(self.widget.data)
|
||||
trace = self._create_trace(stream_vis)
|
||||
if stream_vis.opacity is not None:
|
||||
trace.opacity = stream_vis.opacity
|
||||
self.widget.add_trace(trace)
|
||||
|
||||
def _add_trace_with_history(self, stream_vis):
|
||||
# if history buffer isn't full
|
||||
if stream_vis.history_len > len(stream_vis.trace_history):
|
||||
self._add_trace(stream_vis)
|
||||
stream_vis.trace_history.append(len(self.widget.data)-1)
|
||||
stream_vis.cur_history_index = len(stream_vis.trace_history)-1
|
||||
#if stream_vis.cur_history_index:
|
||||
# self.widget.data[trace_index].showlegend = False
|
||||
else:
|
||||
# rotate trace
|
||||
stream_vis.cur_history_index = (stream_vis.cur_history_index + 1) % stream_vis.history_len
|
||||
stream_vis.trace_index = stream_vis.trace_history[stream_vis.cur_history_index]
|
||||
self.clear_plot(stream_vis, False)
|
||||
self.widget.data[stream_vis.trace_index].opacity = stream_vis.opacity or 1
|
||||
|
||||
cur_history_len = len(stream_vis.trace_history)
|
||||
if stream_vis.dim_history and cur_history_len > 1:
|
||||
max_opacity = stream_vis.opacity or 1
|
||||
min_alpha, max_alpha, dimmed_len = max_opacity*0.05, max_opacity*0.8, cur_history_len-1
|
||||
alphas = list(utils.frange(max_alpha, min_alpha, steps=dimmed_len))
|
||||
for i, thi in enumerate(range(stream_vis.cur_history_index+1,
|
||||
stream_vis.cur_history_index+cur_history_len)):
|
||||
trace_index = stream_vis.trace_history[thi % cur_history_len]
|
||||
self.widget.data[trace_index].opacity = alphas[i]
|
||||
|
||||
@staticmethod
|
||||
def get_pallet_color(i:int):
|
||||
return plotly.colors.DEFAULT_PLOTLY_COLORS[i % len(plotly.colors.DEFAULT_PLOTLY_COLORS)]
|
||||
|
||||
@staticmethod
|
||||
def _get_axis_common_props(title:str, axis_range:tuple):
|
||||
props = {'showline':True, 'showgrid': True,
|
||||
'showticklabels': True, 'ticks':'inside'}
|
||||
if title:
|
||||
props['title'] = title
|
||||
if axis_range:
|
||||
props['range'] = list(axis_range)
|
||||
return props
|
||||
|
||||
def _can_update_stream_plots(self):
|
||||
return time.time() - self.q_last_processed > 0.5 # make configurable
|
||||
|
||||
def _post_add_subscription(self, stream_vis, **stream_vis_args):
|
||||
stream_vis.trace_history, stream_vis.cur_history_index = [], None
|
||||
self._add_trace_with_history(stream_vis)
|
||||
self._setup_layout(stream_vis)
|
||||
|
||||
if not self.widget.layout.title:
|
||||
self.widget.layout.title = stream_vis.title
|
||||
# TODO: better way for below?
|
||||
if stream_vis.history_len > 1:
|
||||
self.widget.layout.showlegend = False
|
||||
|
||||
def _show_widget_native(self, blocking:bool):
|
||||
pass
|
||||
#TODO: save image, spawn browser?
|
||||
|
||||
def _show_widget_notebook(self):
|
||||
#plotly.offline.iplot(self.widget)
|
||||
return None
|
||||
|
||||
def _post_update_stream_plot(self, stream_vis):
|
||||
# not needed for plotly as FigureWidget stays upto date
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _setup_layout(self, stream_vis):
|
||||
pass
|
||||
@abstractmethod
|
||||
def _create_trace(self, stream_vis):
|
||||
pass
|
|
@ -0,0 +1,82 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from ipywidgets import Output #, Layout
|
||||
from IPython.display import display, clear_output
|
||||
import numpy as np
|
||||
from .line_plot import LinePlot
|
||||
import time
|
||||
from .. import utils
|
||||
import ipywidgets as widgets
|
||||
|
||||
class EmbeddingsPlot(LinePlot):
|
||||
def __init__(self, cell:widgets.Box=None, title=None, show_legend:bool=False, stream_name:str=None, console_debug:bool=False,
|
||||
is_3d:bool=True, hover_images=None, hover_image_reshape=None, **vis_args):
|
||||
utils.set_default(vis_args, 'height', '8in')
|
||||
super(EmbeddingsPlot, self).__init__(cell, title, show_legend,
|
||||
stream_name=stream_name, console_debug=console_debug, is_3d=is_3d, **vis_args)
|
||||
if hover_images is not None:
|
||||
plt.ioff()
|
||||
self.image_output = Output()
|
||||
self.image_figure = plt.figure(figsize=(2,2))
|
||||
self.image_ax = self.image_figure.add_subplot(111)
|
||||
self.cell.children += (self.image_output,)
|
||||
plt.ion()
|
||||
self.hover_images, self.hover_image_reshape = hover_images, hover_image_reshape
|
||||
self.last_ind, self.last_ind_time = -1, 0
|
||||
|
||||
def hover_fn(self, trace, points, state): # pylint: disable=unused-argument
|
||||
if not points:
|
||||
return
|
||||
ind = points.point_inds[0]
|
||||
if ind == self.last_ind or ind > len(self.hover_images) or ind < 0:
|
||||
return
|
||||
|
||||
if self.last_ind == -1:
|
||||
self.last_ind, self.last_ind_time = ind, time.time()
|
||||
else:
|
||||
elapsed = time.time() - self.last_ind_time
|
||||
if elapsed < 0.3:
|
||||
self.last_ind, self.last_ind_time = ind, time.time()
|
||||
if elapsed < 1:
|
||||
return
|
||||
# else too much time since update
|
||||
# else we have stable ind
|
||||
|
||||
with self.image_output:
|
||||
plt.ioff()
|
||||
|
||||
if self.hover_image_reshape:
|
||||
img = np.reshape(self.hover_images[ind], self.hover_image_reshape)
|
||||
else:
|
||||
img = self.hover_images[ind]
|
||||
if img is not None:
|
||||
clear_output(wait=True)
|
||||
self.image_ax.imshow(img)
|
||||
display(self.image_figure)
|
||||
plt.ion()
|
||||
|
||||
return None
|
||||
|
||||
def _create_trace(self, stream_vis):
|
||||
stream_vis.stream_vis_args.clear() #TODO remove this
|
||||
utils.set_default(stream_vis.stream_vis_args, 'draw_line', False)
|
||||
utils.set_default(stream_vis.stream_vis_args, 'draw_marker', True)
|
||||
utils.set_default(stream_vis.stream_vis_args, 'draw_marker_text', True)
|
||||
utils.set_default(stream_vis.stream_vis_args, 'hoverinfo', 'text')
|
||||
utils.set_default(stream_vis.stream_vis_args, 'marker', {})
|
||||
|
||||
marker = stream_vis.stream_vis_args['marker']
|
||||
utils.set_default(marker, 'size', 6)
|
||||
utils.set_default(marker, 'colorscale', 'Jet')
|
||||
utils.set_default(marker, 'showscale', False)
|
||||
utils.set_default(marker, 'opacity', 0.8)
|
||||
|
||||
return super(EmbeddingsPlot, self)._create_trace(stream_vis)
|
||||
|
||||
def subscribe(self, stream, **stream_vis_args):
|
||||
super(EmbeddingsPlot, self).subscribe(stream)
|
||||
stream_vis = self._stream_vises[stream.stream_name]
|
||||
if stream_vis.index == 0 and self.hover_images is not None:
|
||||
self.widget.data[stream_vis.trace_index].on_hover(self.hover_fn)
|
|
@ -0,0 +1,186 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import plotly.graph_objs as go
|
||||
from .base_plotly_plot import BasePlotlyPlot
|
||||
from ..lv_types import EventVars
|
||||
from .. import utils
|
||||
import ipywidgets as widgets
|
||||
|
||||
class LinePlot(BasePlotlyPlot):
|
||||
def __init__(self, cell:widgets.Box=None, title=None, show_legend:bool=True, stream_name:str=None, console_debug:bool=False,
|
||||
is_3d:bool=False, **vis_args):
|
||||
super(LinePlot, self).__init__(cell, title, show_legend, stream_name=stream_name, console_debug=console_debug, **vis_args)
|
||||
self.is_3d = is_3d
|
||||
|
||||
def _setup_layout(self, stream_vis):
|
||||
# handle multiple y axis
|
||||
yaxis = 'yaxis' + (str(stream_vis.index + 1) if stream_vis.separate_yaxis else '')
|
||||
|
||||
xaxis = 'xaxis' + str(stream_vis.index+1)
|
||||
axis_props = BasePlotlyPlot._get_axis_common_props(stream_vis.xtitle, stream_vis.xrange)
|
||||
#axis_props['rangeslider'] = dict(visible = True)
|
||||
self.widget.layout[xaxis] = axis_props
|
||||
|
||||
# handle multiple Y-Axis plots
|
||||
color = self.widget.data[stream_vis.trace_index].line.color
|
||||
yaxis = 'yaxis' + (str(stream_vis.index + 1) if stream_vis.separate_yaxis else '')
|
||||
axis_props = BasePlotlyPlot._get_axis_common_props(stream_vis.ytitle, stream_vis.yrange)
|
||||
axis_props['linecolor'] = color
|
||||
axis_props['tickfont']=axis_props['titlefont'] = dict(color=color)
|
||||
if stream_vis.index > 0 and stream_vis.separate_yaxis:
|
||||
axis_props['overlaying'] = 'y'
|
||||
axis_props['side'] = 'right'
|
||||
if stream_vis.index > 1:
|
||||
self.widget.layout.xaxis = dict(domain=[0, 1 - 0.085*(stream_vis.index-1)])
|
||||
axis_props['anchor'] = 'free'
|
||||
axis_props['position'] = 1 - 0.085*(stream_vis.index-2)
|
||||
self.widget.layout[yaxis] = axis_props
|
||||
|
||||
if self.is_3d:
|
||||
zaxis = 'zaxis' #+ str(stream_vis.index+1)
|
||||
axis_props = BasePlotlyPlot._get_axis_common_props(stream_vis.ztitle, stream_vis.zrange)
|
||||
self.widget.layout.scene[zaxis] = axis_props
|
||||
self.widget.layout.margin = dict(l=0, r=0, b=0, t=0)
|
||||
self.widget.layout.hoverdistance = 1
|
||||
|
||||
def _create_2d_trace(self, stream_vis, mode, hoverinfo, marker, line):
|
||||
yaxis = 'y' + (str(stream_vis.index + 1) if stream_vis.separate_yaxis else '')
|
||||
|
||||
trace = go.Scatter(x=[], y=[], mode=mode, name=stream_vis.title or stream_vis.ytitle, yaxis=yaxis, hoverinfo=hoverinfo,
|
||||
line=line, marker=marker)
|
||||
return trace
|
||||
|
||||
def _create_3d_trace(self, stream_vis, mode, hoverinfo, marker, line):
|
||||
trace = go.Scatter3d(x=[], y=[], z=[], mode=mode, name=stream_vis.title or stream_vis.ytitle, hoverinfo=hoverinfo,
|
||||
line=line, marker=marker)
|
||||
return trace
|
||||
|
||||
|
||||
def _create_trace(self, stream_vis):
|
||||
stream_vis.separate_yaxis = stream_vis.stream_vis_args.get('separate_yaxis', True)
|
||||
stream_vis.xtitle = stream_vis.stream_vis_args.get('xtitle',None)
|
||||
stream_vis.ytitle = stream_vis.stream_vis_args.get('ytitle',None)
|
||||
stream_vis.ztitle = stream_vis.stream_vis_args.get('ztitle',None)
|
||||
stream_vis.color = stream_vis.stream_vis_args.get('color',None)
|
||||
stream_vis.xrange = stream_vis.stream_vis_args.get('xrange',None)
|
||||
stream_vis.yrange = stream_vis.stream_vis_args.get('yrange',None)
|
||||
stream_vis.zrange = stream_vis.stream_vis_args.get('zrange',None)
|
||||
draw_line = stream_vis.stream_vis_args.get('draw_line',True)
|
||||
draw_marker = stream_vis.stream_vis_args.get('draw_marker',True)
|
||||
draw_marker_text = stream_vis.stream_vis_args.get('draw_marker_text',False)
|
||||
hoverinfo = stream_vis.stream_vis_args.get('hoverinfo',None)
|
||||
marker = stream_vis.stream_vis_args.get('marker',{})
|
||||
line = stream_vis.stream_vis_args.get('line',{})
|
||||
utils.set_default(line, 'color', stream_vis.color or BasePlotlyPlot.get_pallet_color(stream_vis.index))
|
||||
|
||||
mode = 'lines' if draw_line else ''
|
||||
if draw_marker:
|
||||
mode = ('' if mode=='' else mode+'+') + 'markers'
|
||||
if draw_marker_text:
|
||||
mode = ('' if mode=='' else mode+'+') + 'text'
|
||||
|
||||
if self.is_3d:
|
||||
return self._create_3d_trace(stream_vis, mode, hoverinfo, marker, line)
|
||||
else:
|
||||
return self._create_2d_trace(stream_vis, mode, hoverinfo, marker, line)
|
||||
|
||||
def _show_stream_items(self, stream_vis, stream_items):
|
||||
vals = self._extract_vals(stream_items)
|
||||
if not len(vals):
|
||||
return False
|
||||
|
||||
# get trace data
|
||||
trace = self.widget.data[stream_vis.trace_index]
|
||||
xdata, ydata, zdata, anndata, txtdata, clrdata = list(trace.x), list(trace.y), [], [], [], []
|
||||
if self.is_3d:
|
||||
zdata = list(trace.z)
|
||||
|
||||
unpacker = lambda a0=None,a1=None,a2=None,a3=None,a4=None,a5=None, *_:(a0,a1,a2,a3,a4,a5)
|
||||
|
||||
# add each value in trace data
|
||||
# each value is of the form:
|
||||
# 2D graphs:
|
||||
# y
|
||||
# x [, y [, annotation [, text [, color]]]]
|
||||
# y
|
||||
# x [, y [, z, [annotation [, text [, color]]]]]
|
||||
for val in vals:
|
||||
# set defaults
|
||||
x, y, z = len(xdata), None, None
|
||||
ann, txt, clr = None, None, None
|
||||
|
||||
# if val turns out to be array-like, extract x,y
|
||||
val_l = utils.is_scaler_array(val)
|
||||
if val_l >= 0:
|
||||
if self.is_3d:
|
||||
x, y, z, ann, txt, clr = unpacker(*val)
|
||||
else:
|
||||
x, y, ann, txt, clr, _ = unpacker(*val)
|
||||
elif isinstance(val, EventVars):
|
||||
x = val.x if hasattr(val, 'x') else x
|
||||
y = val.y if hasattr(val, 'y') else y
|
||||
z = val.z if hasattr(val, 'z') else z
|
||||
ann = val.ann if hasattr(val, 'ann') else ann
|
||||
txt = val.txt if hasattr(val, 'txt') else txt
|
||||
clr = val.clr if hasattr(val, 'clr') else clr
|
||||
|
||||
if y is None:
|
||||
y = next(iter(val.__dict__.values()))
|
||||
else:
|
||||
y = val
|
||||
|
||||
if ann is not None:
|
||||
ann = str(ann)
|
||||
if txt is not None:
|
||||
txt = str(txt)
|
||||
|
||||
xdata.append(x)
|
||||
ydata.append(y)
|
||||
zdata.append(z)
|
||||
if txt is not None:
|
||||
txtdata.append(txt)
|
||||
if clr is not None:
|
||||
clrdata.append(clr)
|
||||
if ann: #TODO: yref should be y2 for different y axis
|
||||
anndata.append(dict(x=x, y=y, xref='x', yref='y', text=ann, showarrow=False))
|
||||
|
||||
self.widget.data[stream_vis.trace_index].x = xdata
|
||||
self.widget.data[stream_vis.trace_index].y = ydata
|
||||
if self.is_3d:
|
||||
self.widget.data[stream_vis.trace_index].z = zdata
|
||||
|
||||
# add text
|
||||
if len(txtdata):
|
||||
exisitng = self.widget.data[stream_vis.trace_index].text
|
||||
exisitng = list(exisitng) if utils.is_array_like(exisitng) else []
|
||||
exisitng += txtdata
|
||||
self.widget.data[stream_vis.trace_index].text = exisitng
|
||||
|
||||
# add annotation
|
||||
if len(anndata):
|
||||
existing = list(self.widget.layout.annotations)
|
||||
existing += anndata
|
||||
self.widget.layout.annotations = existing
|
||||
|
||||
# add color
|
||||
if len(clrdata):
|
||||
exisitng = self.widget.data[stream_vis.trace_index].marker.color
|
||||
exisitng = list(exisitng) if utils.is_array_like(exisitng) else []
|
||||
exisitng += clrdata
|
||||
self.widget.data[stream_vis.trace_index].marker.color = exisitng
|
||||
|
||||
return True
|
||||
|
||||
def clear_plot(self, stream_vis, clear_history):
|
||||
traces = range(len(stream_vis.trace_history)) if clear_history else (stream_vis.trace_index,)
|
||||
for i in traces:
|
||||
stream_vis.trace_index = i
|
||||
|
||||
self.widget.data[stream_vis.trace_index].x = []
|
||||
self.widget.data[stream_vis.trace_index].y = []
|
||||
if self.is_3d:
|
||||
self.widget.data[stream_vis.trace_index].z = []
|
||||
self.widget.data[stream_vis.trace_index].text = ""
|
||||
# TODO: avoid removing annotations for other streams
|
||||
self.widget.layout.annotations = []
|
|
@ -0,0 +1,55 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from torchvision import models, transforms
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from . import utils, image_utils
|
||||
import os
|
||||
|
||||
def get_model(model_name):
|
||||
model = models.__dict__[model_name](pretrained=True)
|
||||
return model
|
||||
|
||||
def tensors2batch(tensors, preprocess_transform=None):
|
||||
if preprocess_transform:
|
||||
tensors = tuple(preprocess_transform(i) for i in tensors)
|
||||
if not utils.is_array_like(tensors):
|
||||
tensors = tuple(tensors)
|
||||
return torch.stack(tensors, dim=0)
|
||||
|
||||
def int2tensor(val):
|
||||
return torch.LongTensor([val])
|
||||
|
||||
def image_class2tensor(image_path, class_index=None, image_convert_mode=None,
|
||||
image_transform=None):
|
||||
|
||||
raw_input = image_utils.open_image(os.path.abspath(image_path), convert_mode=image_convert_mode)
|
||||
if image_transform:
|
||||
input_x = image_transform(raw_input)
|
||||
else:
|
||||
input_x = transforms.ToTensor()(raw_input)
|
||||
input_x = input_x.unsqueeze(0) #convert to batch of 1
|
||||
target_class = int2tensor(class_index) if class_index is not None else None
|
||||
return raw_input, input_x, target_class
|
||||
|
||||
def batch_predict(model, inputs, input_transform=None, device=None):
|
||||
if input_transform:
|
||||
batch = torch.stack(tuple(input_transform(i) for i in inputs), dim=0)
|
||||
else:
|
||||
batch = torch.stack(inputs, dim=0)
|
||||
|
||||
device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model.eval()
|
||||
model.to(device)
|
||||
batch = batch.to(device)
|
||||
|
||||
outputs = model(batch)
|
||||
|
||||
return outputs
|
||||
|
||||
def logits2probabilities(logits):
|
||||
return F.softmax(logits, dim=1)
|
||||
|
||||
def tensor2numpy(t):
|
||||
return t.detach().cpu().numpy()
|
|
@ -0,0 +1,29 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
#from receptivefield.pytorch import PytorchReceptiveField
|
||||
##from receptivefield.image import get_default_image
|
||||
#import numpy as np
|
||||
|
||||
#def _get_rf(model, sample_pil_img):
|
||||
# # define model functions
|
||||
# def model_fn():
|
||||
# model.eval()
|
||||
# return model
|
||||
|
||||
# input_shape = np.array(sample_pil_img).shape
|
||||
|
||||
# rf = PytorchReceptiveField(model_fn)
|
||||
# rf_params = rf.compute(input_shape=input_shape)
|
||||
# return rf, rf_params
|
||||
|
||||
#def plot_receptive_field(model, sample_pil_img, layout=(2, 2), figsize=(6, 6)):
|
||||
# rf, rf_params = _get_rf(model, sample_pil_img) # pylint: disable=unused-variable
|
||||
# return rf.plot_rf_grids(
|
||||
# custom_image=sample_pil_img,
|
||||
# figsize=figsize,
|
||||
# layout=layout)
|
||||
|
||||
#def plot_grads_at(model, sample_pil_img, feature_map_index=0, point=(8,8), figsize=(6, 6)):
|
||||
# rf, rf_params = _get_rf(model, sample_pil_img) # pylint: disable=unused-variable
|
||||
# return rf.plot_gradient_at(fm_id=feature_map_index, point=point, image=None, figsize=figsize)
|
|
@ -0,0 +1,71 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import threading
|
||||
import time
|
||||
import weakref
|
||||
|
||||
class RepeatedTimer:
|
||||
class State:
|
||||
Stopped=0
|
||||
Paused=1
|
||||
Running=2
|
||||
|
||||
def __init__(self, secs, callback, count=None):
|
||||
self.secs = secs
|
||||
self.callback = weakref.WeakMethod(callback) if callback else None
|
||||
self._thread = None
|
||||
self._state = RepeatedTimer.State.Stopped
|
||||
self.pause_wait = threading.Event()
|
||||
self.pause_wait.set()
|
||||
self._continue_thread = False
|
||||
self.count = count
|
||||
|
||||
def start(self):
|
||||
self._continue_thread = True
|
||||
self.pause_wait.set()
|
||||
if self._thread is None or not self._thread.isAlive():
|
||||
self._thread = threading.Thread(target=self._runner, name='RepeatedTimer', daemon=True)
|
||||
self._thread.start()
|
||||
self._state = RepeatedTimer.State.Running
|
||||
|
||||
def stop(self, block=False):
|
||||
self.pause_wait.set()
|
||||
self._continue_thread = False
|
||||
if block and not (self._thread is None or not self._thread.isAlive()):
|
||||
self._thread.join()
|
||||
self._state = RepeatedTimer.State.Stopped
|
||||
|
||||
def get_state(self):
|
||||
return self._state
|
||||
|
||||
|
||||
def pause(self):
|
||||
if self._state == RepeatedTimer.State.Running:
|
||||
self.pause_wait.clear()
|
||||
self._state = RepeatedTimer.State.Paused
|
||||
# else nothing to do
|
||||
def unpause(self):
|
||||
if self._state == RepeatedTimer.State.Paused:
|
||||
self.pause_wait.set()
|
||||
if self._state == RepeatedTimer.State.Paused:
|
||||
self._state = RepeatedTimer.State.Running
|
||||
# else nothing to do
|
||||
|
||||
def _runner(self):
|
||||
while (self._continue_thread):
|
||||
if self.count:
|
||||
self.count -= 0
|
||||
if not self.count:
|
||||
self._continue_thread = False
|
||||
|
||||
if self._continue_thread:
|
||||
self.pause_wait.wait()
|
||||
if self.callback and self.callback():
|
||||
self.callback()()
|
||||
|
||||
if self._continue_thread:
|
||||
time.sleep(self.secs)
|
||||
|
||||
self._thread = None
|
||||
self._state = RepeatedTimer.State.Stopped
|
|
@ -0,0 +1,5 @@
|
|||
# Credits
|
||||
Code in this folder is adopted from
|
||||
|
||||
* https://github.com/yulongwang12/visual-attribution
|
||||
* https://github.com/marcotcr/lime
|
|
@ -0,0 +1,3 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
|
@ -0,0 +1,153 @@
|
|||
import numpy as np
|
||||
from torch.autograd import Variable, Function
|
||||
import torch
|
||||
import types
|
||||
|
||||
|
||||
class VanillaGradExplainer(object):
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
|
||||
def _backprop(self, inp, ind):
|
||||
inp.requires_grad = True
|
||||
if inp.grad is not None:
|
||||
inp.grad.zero_()
|
||||
if ind.grad is not None:
|
||||
ind.grad.zero_()
|
||||
self.model.eval()
|
||||
self.model.zero_grad()
|
||||
|
||||
output = self.model(inp)
|
||||
if ind is None:
|
||||
ind = output.max(1)[1]
|
||||
grad_out = output.clone()
|
||||
grad_out.fill_(0.0)
|
||||
grad_out.scatter_(1, ind.unsqueeze(0).t(), 1.0)
|
||||
output.backward(grad_out)
|
||||
return inp.grad
|
||||
|
||||
def explain(self, inp, ind=None, raw_inp=None):
|
||||
return self._backprop(inp, ind)
|
||||
|
||||
|
||||
class GradxInputExplainer(VanillaGradExplainer):
|
||||
def __init__(self, model):
|
||||
super(GradxInputExplainer, self).__init__(model)
|
||||
|
||||
def explain(self, inp, ind=None, raw_inp=None):
|
||||
grad = self._backprop(inp, ind)
|
||||
return inp * grad
|
||||
|
||||
|
||||
class SaliencyExplainer(VanillaGradExplainer):
|
||||
def __init__(self, model):
|
||||
super(SaliencyExplainer, self).__init__(model)
|
||||
|
||||
def explain(self, inp, ind=None, raw_inp=None):
|
||||
grad = self._backprop(inp, ind)
|
||||
return grad.abs()
|
||||
|
||||
|
||||
class IntegrateGradExplainer(VanillaGradExplainer):
|
||||
def __init__(self, model, steps=100):
|
||||
super(IntegrateGradExplainer, self).__init__(model)
|
||||
self.steps = steps
|
||||
|
||||
def explain(self, inp, ind=None, raw_inp=None):
|
||||
grad = 0
|
||||
inp_data = inp.clone()
|
||||
|
||||
for alpha in np.arange(1 / self.steps, 1.0, 1 / self.steps):
|
||||
new_inp = Variable(inp_data * alpha, requires_grad=True)
|
||||
g = self._backprop(new_inp, ind)
|
||||
grad += g
|
||||
|
||||
return grad * inp_data / self.steps
|
||||
|
||||
|
||||
class DeconvExplainer(VanillaGradExplainer):
|
||||
def __init__(self, model):
|
||||
super(DeconvExplainer, self).__init__(model)
|
||||
self._override_backward()
|
||||
|
||||
def _override_backward(self):
|
||||
class _ReLU(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input):
|
||||
output = torch.clamp(input, min=0)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
grad_inp = torch.clamp(grad_output, min=0)
|
||||
return grad_inp
|
||||
|
||||
def new_forward(self, x):
|
||||
return _ReLU.apply(x)
|
||||
|
||||
def replace(m):
|
||||
if m.__class__.__name__ == 'ReLU':
|
||||
m.forward = types.MethodType(new_forward, m)
|
||||
|
||||
self.model.apply(replace)
|
||||
|
||||
|
||||
class GuidedBackpropExplainer(VanillaGradExplainer):
|
||||
def __init__(self, model):
|
||||
super(GuidedBackpropExplainer, self).__init__(model)
|
||||
self._override_backward()
|
||||
|
||||
def _override_backward(self):
|
||||
class _ReLU(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input):
|
||||
output = torch.clamp(input, min=0)
|
||||
ctx.save_for_backward(output)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
output, = ctx.saved_tensors
|
||||
mask1 = (output > 0).float()
|
||||
mask2 = (grad_output > 0).float()
|
||||
grad_inp = mask1 * mask2 * grad_output
|
||||
grad_output.copy_(grad_inp)
|
||||
return grad_output
|
||||
|
||||
def new_forward(self, x):
|
||||
return _ReLU.apply(x)
|
||||
|
||||
def replace(m):
|
||||
if m.__class__.__name__ == 'ReLU':
|
||||
m.forward = types.MethodType(new_forward, m)
|
||||
|
||||
self.model.apply(replace)
|
||||
|
||||
|
||||
# modified from https://github.com/PAIR-code/saliency/blob/master/saliency/base.py#L80
|
||||
class SmoothGradExplainer(object):
|
||||
def __init__(self, model, base_explainer=None, stdev_spread=0.15,
|
||||
nsamples=25, magnitude=True):
|
||||
self.base_explainer = base_explainer or VanillaGradExplainer(model)
|
||||
self.stdev_spread = stdev_spread
|
||||
self.nsamples = nsamples
|
||||
self.magnitude = magnitude
|
||||
|
||||
def explain(self, inp, ind=None, raw_inp=None):
|
||||
stdev = self.stdev_spread * (inp.max() - inp.min())
|
||||
|
||||
total_gradients = 0
|
||||
|
||||
for i in range(self.nsamples):
|
||||
noise = torch.randn_like(inp) * stdev
|
||||
|
||||
noisy_inp = inp + noise
|
||||
noisy_inp.retain_grad()
|
||||
grad = self.base_explainer.explain(noisy_inp, ind)
|
||||
|
||||
if self.magnitude:
|
||||
total_gradients += grad ** 2
|
||||
else:
|
||||
total_gradients += grad
|
||||
|
||||
return total_gradients / self.nsamples
|
|
@ -0,0 +1,80 @@
|
|||
from .backprop import GradxInputExplainer
|
||||
import types
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Variable
|
||||
|
||||
# Based on formulation in DeepExplain, https://arxiv.org/abs/1711.06104
|
||||
# https://github.com/marcoancona/DeepExplain/blob/master/deepexplain/tensorflow/methods.py#L221-L272
|
||||
class DeepLIFTRescaleExplainer(GradxInputExplainer):
|
||||
def __init__(self, model):
|
||||
super(DeepLIFTRescaleExplainer, self).__init__(model)
|
||||
self._prepare_reference()
|
||||
self.baseline_inp = None
|
||||
self._override_backward()
|
||||
|
||||
def _prepare_reference(self):
|
||||
def init_refs(m):
|
||||
name = m.__class__.__name__
|
||||
if name.find('ReLU') != -1:
|
||||
m.ref_inp_list = []
|
||||
m.ref_out_list = []
|
||||
|
||||
def ref_forward(self, x):
|
||||
self.ref_inp_list.append(x.data.clone())
|
||||
out = F.relu(x)
|
||||
self.ref_out_list.append(out.data.clone())
|
||||
return out
|
||||
|
||||
def ref_replace(m):
|
||||
name = m.__class__.__name__
|
||||
if name.find('ReLU') != -1:
|
||||
m.forward = types.MethodType(ref_forward, m)
|
||||
|
||||
self.model.apply(init_refs)
|
||||
self.model.apply(ref_replace)
|
||||
|
||||
def _reset_preference(self):
|
||||
def reset_refs(m):
|
||||
name = m.__class__.__name__
|
||||
if name.find('ReLU') != -1:
|
||||
m.ref_inp_list = []
|
||||
m.ref_out_list = []
|
||||
|
||||
self.model.apply(reset_refs)
|
||||
|
||||
def _baseline_forward(self, inp):
|
||||
if self.baseline_inp is None:
|
||||
self.baseline_inp = inp.data.clone()
|
||||
self.baseline_inp.fill_(0.0)
|
||||
self.baseline_inp = Variable(self.baseline_inp)
|
||||
else:
|
||||
self.baseline_inp.fill_(0.0)
|
||||
# get ref
|
||||
_ = self.model(self.baseline_inp)
|
||||
|
||||
def _override_backward(self):
|
||||
def new_backward(self, grad_out):
|
||||
ref_inp, inp = self.ref_inp_list
|
||||
ref_out, out = self.ref_out_list
|
||||
delta_out = out - ref_out
|
||||
delta_in = inp - ref_inp
|
||||
g1 = (delta_in.abs() > 1e-5).float() * grad_out * \
|
||||
delta_out / delta_in
|
||||
mask = ((ref_inp + inp) > 0).float()
|
||||
g2 = (delta_in.abs() <= 1e-5).float() * 0.5 * mask * grad_out
|
||||
|
||||
return g1 + g2
|
||||
|
||||
def backward_replace(m):
|
||||
name = m.__class__.__name__
|
||||
if name.find('ReLU') != -1:
|
||||
m.backward = types.MethodType(new_backward, m)
|
||||
|
||||
self.model.apply(backward_replace)
|
||||
|
||||
def explain(self, inp, ind=None, raw_inp=None):
|
||||
self._reset_preference()
|
||||
self._baseline_forward(inp)
|
||||
g = super(DeepLIFTRescaleExplainer, self).explain(inp, ind)
|
||||
|
||||
return g
|
|
@ -0,0 +1,242 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
|
||||
from .inverter_util import RelevancePropagator
|
||||
|
||||
|
||||
class EpsilonLrp(object):
|
||||
def __init__(self, model):
|
||||
self.model = InnvestigateModel(model)
|
||||
|
||||
def explain(self, inp, ind=None, raw_inp=None):
|
||||
predicitions, saliency = self.model.innvestigate(inp, ind)
|
||||
return saliency
|
||||
|
||||
|
||||
class InnvestigateModel(torch.nn.Module):
|
||||
"""
|
||||
ATTENTION:
|
||||
Currently, innvestigating a network only works if all
|
||||
layers that have to be inverted are specified explicitly
|
||||
and registered as a module. If., for example,
|
||||
only the functional max_poolnd is used, the inversion will not work.
|
||||
"""
|
||||
|
||||
def __init__(self, the_model, lrp_exponent=1, beta=.5, epsilon=1e-6,
|
||||
method="e-rule"):
|
||||
"""
|
||||
Model wrapper for pytorch models to 'innvestigate' them
|
||||
with layer-wise relevance propagation (LRP) as introduced by Bach et. al
|
||||
(https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0130140).
|
||||
Given a class level probability produced by the model under consideration,
|
||||
the LRP algorithm attributes this probability to the nodes in each layer.
|
||||
This allows for visualizing the relevance of input pixels on the resulting
|
||||
class probability.
|
||||
|
||||
Args:
|
||||
the_model: Pytorch model, e.g. a pytorch.nn.Sequential consisting of
|
||||
different layers. Not all layers are supported yet.
|
||||
lrp_exponent: Exponent for rescaling the importance values per node
|
||||
in a layer when using the e-rule method.
|
||||
beta: Beta value allows for placing more (large beta) emphasis on
|
||||
nodes that positively contribute to the activation of a given node
|
||||
in the subsequent layer. Low beta value allows for placing more emphasis
|
||||
on inhibitory neurons in a layer. Only relevant for method 'b-rule'.
|
||||
epsilon: Stabilizing term to avoid numerical instabilities if the norm (denominator
|
||||
for distributing the relevance) is close to zero.
|
||||
method: Different rules for the LRP algorithm, b-rule allows for placing
|
||||
more or less focus on positive / negative contributions, whereas
|
||||
the e-rule treats them equally. For more information,
|
||||
see the paper linked above.
|
||||
"""
|
||||
super(InnvestigateModel, self).__init__()
|
||||
self.model = the_model
|
||||
self.device = torch.device("cpu", 0)
|
||||
self.prediction = None
|
||||
self.r_values_per_layer = None
|
||||
self.only_max_score = None
|
||||
# Initialize the 'Relevance Propagator' with the chosen rule.
|
||||
# This will be used to back-propagate the relevance values
|
||||
# through the layers in the innvestigate method.
|
||||
self.inverter = RelevancePropagator(lrp_exponent=lrp_exponent,
|
||||
beta=beta, method=method, epsilon=epsilon,
|
||||
device=self.device)
|
||||
|
||||
# Parsing the individual model layers
|
||||
self.register_hooks(self.model)
|
||||
if method == "b-rule" and float(beta) in (-1., 0):
|
||||
which = "positive" if beta == -1 else "negative"
|
||||
which_opp = "negative" if beta == -1 else "positive"
|
||||
print("WARNING: With the chosen beta value, "
|
||||
"only " + which + " contributions "
|
||||
"will be taken into account.\nHence, "
|
||||
"if in any layer only " + which_opp +
|
||||
" contributions exist, the "
|
||||
"overall relevance will not be conserved.\n")
|
||||
|
||||
def cuda(self, device=None):
|
||||
self.device = torch.device("cuda", device)
|
||||
self.inverter.device = self.device
|
||||
return super(InnvestigateModel, self).cuda(device)
|
||||
|
||||
def cpu(self):
|
||||
self.device = torch.device("cpu", 0)
|
||||
self.inverter.device = self.device
|
||||
return super(InnvestigateModel, self).cpu()
|
||||
|
||||
def register_hooks(self, parent_module):
|
||||
"""
|
||||
Recursively unrolls a model and registers the required
|
||||
hooks to save all the necessary values for LRP in the forward pass.
|
||||
|
||||
Args:
|
||||
parent_module: Model to unroll and register hooks for.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
"""
|
||||
for mod in parent_module.children():
|
||||
if list(mod.children()):
|
||||
self.register_hooks(mod)
|
||||
continue
|
||||
mod.register_forward_hook(
|
||||
self.inverter.get_layer_fwd_hook(mod))
|
||||
if isinstance(mod, torch.nn.ReLU):
|
||||
mod.register_backward_hook(
|
||||
self.relu_hook_function
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def relu_hook_function(module, grad_in, grad_out):
|
||||
"""
|
||||
If there is a negative gradient, change it to zero.
|
||||
"""
|
||||
return (torch.clamp(grad_in[0], min=0.0),)
|
||||
|
||||
def __call__(self, in_tensor):
|
||||
"""
|
||||
The innvestigate wrapper returns the same prediction as the
|
||||
original model, but wraps the model call method in the evaluate
|
||||
method to save the last prediction.
|
||||
|
||||
Args:
|
||||
in_tensor: Model input to pass through the pytorch model.
|
||||
|
||||
Returns:
|
||||
Model output.
|
||||
"""
|
||||
return self.evaluate(in_tensor)
|
||||
|
||||
def evaluate(self, in_tensor):
|
||||
"""
|
||||
Evaluates the model on a new input. The registered forward hooks will
|
||||
save all the data that is necessary to compute the relevance per neuron per layer.
|
||||
|
||||
Args:
|
||||
in_tensor: New input for which to predict an output.
|
||||
|
||||
Returns:
|
||||
Model prediction
|
||||
"""
|
||||
# Reset module list. In case the structure changes dynamically,
|
||||
# the module list is tracked for every forward pass.
|
||||
self.inverter.reset_module_list()
|
||||
self.prediction = self.model(in_tensor)
|
||||
return self.prediction
|
||||
|
||||
def get_r_values_per_layer(self):
|
||||
if self.r_values_per_layer is None:
|
||||
print("No relevances have been calculated yet, returning None in"
|
||||
" get_r_values_per_layer.")
|
||||
return self.r_values_per_layer
|
||||
|
||||
def innvestigate(self, in_tensor=None, rel_for_class=None):
|
||||
"""
|
||||
Method for 'innvestigating' the model with the LRP rule chosen at
|
||||
the initialization of the InnvestigateModel.
|
||||
Args:
|
||||
in_tensor: Input for which to evaluate the LRP algorithm.
|
||||
If input is None, the last evaluation is used.
|
||||
If no evaluation has been performed since initialization,
|
||||
an error is raised.
|
||||
rel_for_class (int): Index of the class for which the relevance
|
||||
distribution is to be analyzed. If None, the 'winning' class
|
||||
is used for indexing.
|
||||
|
||||
Returns:
|
||||
Model output and relevances of nodes in the input layer.
|
||||
In order to get relevance distributions in other layers, use
|
||||
the get_r_values_per_layer method.
|
||||
"""
|
||||
if self.r_values_per_layer is not None:
|
||||
for elt in self.r_values_per_layer:
|
||||
del elt
|
||||
self.r_values_per_layer = None
|
||||
|
||||
with torch.no_grad():
|
||||
# Check if innvestigation can be performed.
|
||||
if in_tensor is None and self.prediction is None:
|
||||
raise RuntimeError("Model needs to be evaluated at least "
|
||||
"once before an innvestigation can be "
|
||||
"performed. Please evaluate model first "
|
||||
"or call innvestigate with a new input to "
|
||||
"evaluate.")
|
||||
|
||||
# Evaluate the model anew if a new input is supplied.
|
||||
if in_tensor is not None:
|
||||
self.evaluate(in_tensor)
|
||||
|
||||
# If no class index is specified, analyze for class
|
||||
# with highest prediction.
|
||||
if rel_for_class is None:
|
||||
# Default behaviour is innvestigating the output
|
||||
# on an arg-max-basis, if no class is specified.
|
||||
org_shape = self.prediction.size()
|
||||
# Make sure shape is just a 1D vector per batch example.
|
||||
self.prediction = self.prediction.view(org_shape[0], -1)
|
||||
max_v, _ = torch.max(self.prediction, dim=1, keepdim=True)
|
||||
only_max_score = torch.zeros_like(self.prediction).to(self.device)
|
||||
only_max_score[max_v == self.prediction] = self.prediction[max_v == self.prediction]
|
||||
relevance_tensor = only_max_score.view(org_shape)
|
||||
self.prediction.view(org_shape)
|
||||
|
||||
else:
|
||||
org_shape = self.prediction.size()
|
||||
self.prediction = self.prediction.view(org_shape[0], -1)
|
||||
only_max_score = torch.zeros_like(self.prediction).to(self.device)
|
||||
only_max_score[:, rel_for_class] += self.prediction[:, rel_for_class]
|
||||
relevance_tensor = only_max_score.view(org_shape)
|
||||
self.prediction.view(org_shape)
|
||||
|
||||
# We have to iterate through the model backwards.
|
||||
# The module list is computed for every forward pass
|
||||
# by the model inverter.
|
||||
rev_model = self.inverter.module_list[::-1]
|
||||
relevance = relevance_tensor.detach()
|
||||
del relevance_tensor
|
||||
# List to save relevance distributions per layer
|
||||
r_values_per_layer = [relevance]
|
||||
for layer in rev_model:
|
||||
# Compute layer specific backwards-propagation of relevance values
|
||||
relevance = self.inverter.compute_propagated_relevance(layer, relevance)
|
||||
r_values_per_layer.append(relevance.cpu())
|
||||
|
||||
self.r_values_per_layer = r_values_per_layer
|
||||
|
||||
del relevance
|
||||
if self.device.type == "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
return self.prediction, r_values_per_layer[-1]
|
||||
|
||||
def forward(self, in_tensor):
|
||||
return self.model.forward(in_tensor)
|
||||
|
||||
def extra_repr(self):
|
||||
r"""Set the extra representation of the module
|
||||
|
||||
To print customized extra information, you should re-implement
|
||||
this method in your own modules. Both single-line and multi-line
|
||||
strings are acceptable.
|
||||
"""
|
||||
return self.model.extra_repr()
|
|
@ -0,0 +1,56 @@
|
|||
import torch
|
||||
from .backprop import VanillaGradExplainer
|
||||
|
||||
|
||||
def _get_layer(model, key_list):
|
||||
a = model
|
||||
for key in key_list:
|
||||
a = a._modules[key]
|
||||
return a
|
||||
|
||||
class GradCAMExplainer(VanillaGradExplainer):
|
||||
def __init__(self, model, target_layer_name_keys=None, use_inp=False):
|
||||
super(GradCAMExplainer, self).__init__(model)
|
||||
self.target_layer = _get_layer(model, target_layer_name_keys)
|
||||
self.use_inp = use_inp
|
||||
self.intermediate_act = []
|
||||
self.intermediate_grad = []
|
||||
self._register_forward_backward_hook()
|
||||
|
||||
def _register_forward_backward_hook(self):
|
||||
def forward_hook_input(m, i, o):
|
||||
self.intermediate_act.append(i[0].data.clone())
|
||||
|
||||
def forward_hook_output(m, i, o):
|
||||
self.intermediate_act.append(o.data.clone())
|
||||
|
||||
def backward_hook(m, grad_i, grad_o):
|
||||
self.intermediate_grad.append(grad_o[0].data.clone())
|
||||
|
||||
if self.use_inp:
|
||||
self.target_layer.register_forward_hook(forward_hook_input)
|
||||
else:
|
||||
self.target_layer.register_forward_hook(forward_hook_output)
|
||||
|
||||
self.target_layer.register_backward_hook(backward_hook)
|
||||
|
||||
def _reset_intermediate_lists(self):
|
||||
self.intermediate_act = []
|
||||
self.intermediate_grad = []
|
||||
|
||||
def explain(self, inp, ind=None, raw_inp=None):
|
||||
self._reset_intermediate_lists()
|
||||
|
||||
_ = super(GradCAMExplainer, self)._backprop(inp, ind)
|
||||
|
||||
grad = self.intermediate_grad[0]
|
||||
act = self.intermediate_act[0]
|
||||
|
||||
weights = grad.sum(-1).sum(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
cam = weights * act
|
||||
cam = cam.sum(1).unsqueeze(1)
|
||||
|
||||
cam = torch.clamp(cam, min=0)
|
||||
|
||||
return cam
|
||||
|
|
@ -0,0 +1,526 @@
|
|||
import torch
|
||||
import torch.nn
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Flatten(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(Flatten, self).__init__()
|
||||
|
||||
def forward(self, in_tensor):
|
||||
return in_tensor.view((in_tensor.size()[0], -1))
|
||||
|
||||
|
||||
def module_tracker(fwd_hook_func):
|
||||
"""
|
||||
Wrapper for tracking the layers throughout the forward pass.
|
||||
|
||||
Args:
|
||||
fwd_hook_func: Forward hook function to be wrapped.
|
||||
|
||||
Returns:
|
||||
Wrapped method.
|
||||
|
||||
"""
|
||||
def hook_wrapper(relevance_propagator_instance, layer, *args):
|
||||
relevance_propagator_instance.module_list.append(layer)
|
||||
return fwd_hook_func(relevance_propagator_instance, layer, *args)
|
||||
|
||||
return hook_wrapper
|
||||
|
||||
|
||||
class RelevancePropagator:
|
||||
"""
|
||||
Class for computing the relevance propagation and supplying
|
||||
the necessary forward hooks for all layers.
|
||||
"""
|
||||
|
||||
# All layers that do not require any specific forward hooks.
|
||||
# This is due to the fact that they are all one-to-one
|
||||
# mappings and hence no normalization is needed (each node only
|
||||
# influences exactly one other node -> relevance conservation
|
||||
# ensures that relevance is just inherited in a one-to-one manner, too).
|
||||
allowed_pass_layers = (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d,
|
||||
torch.nn.BatchNorm3d,
|
||||
torch.nn.ReLU, torch.nn.ELU, Flatten,
|
||||
torch.nn.Dropout, torch.nn.Dropout2d,
|
||||
torch.nn.Dropout3d,
|
||||
torch.nn.Softmax,
|
||||
torch.nn.LogSoftmax)
|
||||
# Implemented rules for relevance propagation.
|
||||
available_methods = ["e-rule", "b-rule"]
|
||||
|
||||
def __init__(self, lrp_exponent, beta, method, epsilon, device):
|
||||
|
||||
self.device = device
|
||||
self.layer = None
|
||||
self.p = lrp_exponent
|
||||
self.beta = beta
|
||||
self.eps = epsilon
|
||||
self.warned_log_softmax = False
|
||||
self.module_list = []
|
||||
if method not in self.available_methods:
|
||||
raise NotImplementedError("Only methods available are: " +
|
||||
str(self.available_methods))
|
||||
self.method = method
|
||||
|
||||
def reset_module_list(self):
|
||||
"""
|
||||
The module list is reset for every evaluation, in change the order or number
|
||||
of layers changes dynamically.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
"""
|
||||
self.module_list = []
|
||||
# Try to free memory
|
||||
if self.device.type == "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def compute_propagated_relevance(self, layer, relevance):
|
||||
"""
|
||||
This method computes the backward pass for the incoming relevance
|
||||
for the specified layer.
|
||||
|
||||
Args:
|
||||
layer: Layer to be reverted.
|
||||
relevance: Incoming relevance from higher up in the network.
|
||||
|
||||
Returns:
|
||||
The
|
||||
|
||||
"""
|
||||
|
||||
if isinstance(layer,
|
||||
(torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d)):
|
||||
return self.max_pool_nd_inverse(layer, relevance).detach()
|
||||
|
||||
elif isinstance(layer,
|
||||
(torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)):
|
||||
return self.conv_nd_inverse(layer, relevance).detach()
|
||||
|
||||
elif isinstance(layer, torch.nn.LogSoftmax):
|
||||
# Only layer that does not conserve relevance. Mainly used
|
||||
# to make probability out of the log values. Should probably
|
||||
# be changed to pure passing and the user should make sure
|
||||
# the layer outputs are sensible (0 would be 100% class probability,
|
||||
# but no relevance could be passed on).
|
||||
if relevance.sum() < 0:
|
||||
relevance[relevance == 0] = -1e6
|
||||
relevance = relevance.exp()
|
||||
if not self.warned_log_softmax:
|
||||
print("WARNING: LogSoftmax layer was "
|
||||
"turned into probabilities.")
|
||||
self.warned_log_softmax = True
|
||||
return relevance
|
||||
elif isinstance(layer, self.allowed_pass_layers):
|
||||
# The above layers are one-to-one mappings of input to
|
||||
# output nodes. All the relevance in the output will come
|
||||
# entirely from the input node. Given the conservation
|
||||
# of relevance, the input is as relevant as the output.
|
||||
return relevance
|
||||
|
||||
elif isinstance(layer, torch.nn.Linear):
|
||||
return self.linear_inverse(layer, relevance).detach()
|
||||
else:
|
||||
raise NotImplementedError("The network contains layers that"
|
||||
" are currently not supported {0:s}".format(str(layer)))
|
||||
|
||||
def get_layer_fwd_hook(self, layer):
|
||||
"""
|
||||
Each layer might need to save very specific data during the forward
|
||||
pass in order to allow for relevance propagation in the backward
|
||||
pass. For example, for max_pooling, we need to store the
|
||||
indices of the max values. In convolutional layers, we need to calculate
|
||||
the normalizations, to ensure the overall amount of relevance is conserved.
|
||||
|
||||
Args:
|
||||
layer: Layer instance for which forward hook is needed.
|
||||
|
||||
Returns:
|
||||
Layer-specific forward hook.
|
||||
|
||||
"""
|
||||
|
||||
if isinstance(layer,
|
||||
(torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d)):
|
||||
return self.max_pool_nd_fwd_hook
|
||||
|
||||
if isinstance(layer,
|
||||
(torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)):
|
||||
return self.conv_nd_fwd_hook
|
||||
|
||||
if isinstance(layer, self.allowed_pass_layers):
|
||||
return self.silent_pass # No hook needed.
|
||||
|
||||
if isinstance(layer, torch.nn.Linear):
|
||||
return self.linear_fwd_hook
|
||||
|
||||
else:
|
||||
raise NotImplementedError("The network contains layers that"
|
||||
" are currently not supported {0:s}".format(str(layer)))
|
||||
|
||||
@staticmethod
|
||||
def get_conv_method(conv_module):
|
||||
"""
|
||||
Get dimension-specific convolution.
|
||||
The forward pass and inversion are made in a
|
||||
'dimensionality-agnostic' manner and are the same for
|
||||
all nd instances of the layer, except for the functional
|
||||
that needs to be used.
|
||||
|
||||
Args:
|
||||
conv_module: instance of convolutional layer.
|
||||
|
||||
Returns:
|
||||
The correct functional used in the convolutional layer.
|
||||
|
||||
"""
|
||||
|
||||
conv_func_mapper = {
|
||||
torch.nn.Conv1d: F.conv1d,
|
||||
torch.nn.Conv2d: F.conv2d,
|
||||
torch.nn.Conv3d: F.conv3d
|
||||
}
|
||||
return conv_func_mapper[type(conv_module)]
|
||||
|
||||
@staticmethod
|
||||
def get_inv_conv_method(conv_module):
|
||||
"""
|
||||
Get dimension-specific convolution inversion layer.
|
||||
The forward pass and inversion are made in a
|
||||
'dimensionality-agnostic' manner and are the same for
|
||||
all nd instances of the layer, except for the functional
|
||||
that needs to be used.
|
||||
|
||||
Args:
|
||||
conv_module: instance of convolutional layer.
|
||||
|
||||
Returns:
|
||||
The correct functional used for inverting the convolutional layer.
|
||||
|
||||
"""
|
||||
|
||||
conv_func_mapper = {
|
||||
torch.nn.Conv1d: F.conv_transpose1d,
|
||||
torch.nn.Conv2d: F.conv_transpose2d,
|
||||
torch.nn.Conv3d: F.conv_transpose3d
|
||||
}
|
||||
return conv_func_mapper[type(conv_module)]
|
||||
|
||||
@module_tracker
|
||||
def silent_pass(self, m, in_tensor: torch.Tensor,
|
||||
out_tensor: torch.Tensor):
|
||||
# Placeholder forward hook for layers that do not need
|
||||
# to store any specific data. Still useful for module tracking.
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def get_inv_max_pool_method(max_pool_instance):
|
||||
"""
|
||||
Get dimension-specific max_pooling layer.
|
||||
The forward pass and inversion are made in a
|
||||
'dimensionality-agnostic' manner and are the same for
|
||||
all nd instances of the layer, except for the functional
|
||||
that needs to be used.
|
||||
|
||||
Args:
|
||||
max_pool_instance: instance of max_pool layer.
|
||||
|
||||
Returns:
|
||||
The correct functional used in the max_pooling layer.
|
||||
|
||||
"""
|
||||
|
||||
conv_func_mapper = {
|
||||
torch.nn.MaxPool1d: F.max_unpool1d,
|
||||
torch.nn.MaxPool2d: F.max_unpool2d,
|
||||
torch.nn.MaxPool3d: F.max_unpool3d
|
||||
}
|
||||
return conv_func_mapper[type(max_pool_instance)]
|
||||
|
||||
def linear_inverse(self, m, relevance_in):
|
||||
|
||||
if self.method == "e-rule":
|
||||
m.in_tensor = m.in_tensor.pow(self.p)
|
||||
w = m.weight.pow(self.p)
|
||||
norm = F.linear(m.in_tensor, w, bias=None)
|
||||
|
||||
norm = norm + torch.sign(norm) * self.eps
|
||||
relevance_in[norm == 0] = 0
|
||||
norm[norm == 0] = 1
|
||||
relevance_out = F.linear(relevance_in / norm,
|
||||
w.t(), bias=None)
|
||||
relevance_out *= m.in_tensor
|
||||
del m.in_tensor, norm, w, relevance_in
|
||||
return relevance_out
|
||||
|
||||
if self.method == "b-rule":
|
||||
out_c, in_c = m.weight.size()
|
||||
w = m.weight.repeat((4, 1))
|
||||
# First and third channel repetition only contain the positive weights
|
||||
w[:out_c][w[:out_c] < 0] = 0
|
||||
w[2 * out_c:3 * out_c][w[2 * out_c:3 * out_c] < 0] = 0
|
||||
# Second and fourth channel repetition with only the negative weights
|
||||
w[1 * out_c:2 * out_c][w[1 * out_c:2 * out_c] > 0] = 0
|
||||
w[-out_c:][w[-out_c:] > 0] = 0
|
||||
|
||||
# Repeat across channel dimension (pytorch always has channels first)
|
||||
m.in_tensor = m.in_tensor.repeat((1, 4))
|
||||
m.in_tensor[:, :in_c][m.in_tensor[:, :in_c] < 0] = 0
|
||||
m.in_tensor[:, -in_c:][m.in_tensor[:, -in_c:] < 0] = 0
|
||||
m.in_tensor[:, 1 * in_c:3 * in_c][m.in_tensor[:, 1 * in_c:3 * in_c] > 0] = 0
|
||||
|
||||
# Normalize such that the sum of the individual importance values
|
||||
# of the input neurons divided by the norm
|
||||
# yields 1 for an output neuron j if divided by norm (v_ij in paper).
|
||||
# Norm layer just sums the importance values of the inputs
|
||||
# contributing to output j for each j. This will then serve as the normalization
|
||||
# such that the contributions of the neurons sum to 1 in order to
|
||||
# properly split up the relevance of j amongst its roots.
|
||||
|
||||
norm_shape = m.out_shape
|
||||
norm_shape[1] *= 4
|
||||
norm = torch.zeros(norm_shape).to(self.device)
|
||||
|
||||
for i in range(4):
|
||||
norm[:, out_c * i:(i + 1) * out_c] = F.linear(
|
||||
m.in_tensor[:, in_c * i:(i + 1) * in_c], w[out_c * i:(i + 1) * out_c], bias=None)
|
||||
|
||||
# Double number of output channels for positive and negative norm per
|
||||
# channel.
|
||||
norm_shape[1] = norm_shape[1] // 2
|
||||
new_norm = torch.zeros(norm_shape).to(self.device)
|
||||
new_norm[:, :out_c] = norm[:, :out_c] + norm[:, out_c:2 * out_c]
|
||||
new_norm[:, out_c:] = norm[:, 2 * out_c:3 * out_c] + norm[:, 3 * out_c:]
|
||||
norm = new_norm
|
||||
|
||||
# Some 'rare' neurons only receive either
|
||||
# only positive or only negative inputs.
|
||||
# Conservation of relevance does not hold, if we also
|
||||
# rescale those neurons by (1+beta) or -beta.
|
||||
# Therefore, catch those first and scale norm by
|
||||
# the according value, such that it cancels in the fraction.
|
||||
|
||||
# First, however, avoid NaNs.
|
||||
mask = norm == 0
|
||||
# Set the norm to anything non-zero, e.g. 1.
|
||||
# The actual inputs are zero at this point anyways, that
|
||||
# is why norm is zero in the first place.
|
||||
norm[mask] = 1
|
||||
# The norm in the b-rule has shape (N, 2*out_c, *spatial_dims).
|
||||
# The first out_c block corresponds to the positive norms,
|
||||
# the second out_c block corresponds to the negative norms.
|
||||
# We find the rare neurons by choosing those nodes per channel
|
||||
# in which either the positive norm ([:, :out_c]) is zero, or
|
||||
# the negative norm ([:, :out_c]) is zero.
|
||||
rare_neurons = (mask[:, :out_c] + mask[:, out_c:])
|
||||
|
||||
# Also, catch new possibilities for norm == zero to avoid NaN..
|
||||
# The actual value of norm again does not really matter, since
|
||||
# the pre-factor will be zero in this case.
|
||||
|
||||
norm[:, :out_c][rare_neurons] *= 1 if self.beta == -1 else 1 + self.beta
|
||||
norm[:, out_c:][rare_neurons] *= 1 if self.beta == 0 else -self.beta
|
||||
# Add stabilizer term to norm to avoid numerical instabilities.
|
||||
norm += self.eps * torch.sign(norm)
|
||||
input_relevance = relevance_in.squeeze(dim=-1).repeat(1, 4)
|
||||
input_relevance[:, :2*out_c] *= (1+self.beta)/norm[:, :out_c].repeat(1, 2)
|
||||
input_relevance[:, 2*out_c:] *= -self.beta/norm[:, out_c:].repeat(1, 2)
|
||||
inv_w = w.t()
|
||||
relevance_out = torch.zeros_like(m.in_tensor)
|
||||
for i in range(4):
|
||||
relevance_out[:, i*in_c:(i+1)*in_c] = F.linear(
|
||||
input_relevance[:, i*out_c:(i+1)*out_c],
|
||||
weight=inv_w[:, i*out_c:(i+1)*out_c], bias=None)
|
||||
|
||||
relevance_out *= m.in_tensor
|
||||
|
||||
sum_weights = torch.zeros([in_c, in_c * 4, 1]).to(self.device)
|
||||
for i in range(in_c):
|
||||
sum_weights[i, i::in_c] = 1
|
||||
relevance_out = F.conv1d(relevance_out[:, :, None], weight=sum_weights, bias=None)
|
||||
|
||||
del sum_weights, input_relevance, norm, rare_neurons, \
|
||||
mask, new_norm, m.in_tensor, w, inv_w
|
||||
|
||||
return relevance_out
|
||||
|
||||
@module_tracker
|
||||
def linear_fwd_hook(self, m, in_tensor: torch.Tensor,
|
||||
out_tensor: torch.Tensor):
|
||||
|
||||
setattr(m, "in_tensor", in_tensor[0])
|
||||
setattr(m, "out_shape", list(out_tensor.size()))
|
||||
return
|
||||
|
||||
def max_pool_nd_inverse(self, layer_instance, relevance_in):
|
||||
|
||||
# In case the output had been reshaped for a linear layer,
|
||||
# make sure the relevance is put into the same shape as before.
|
||||
relevance_in = relevance_in.view(layer_instance.out_shape)
|
||||
|
||||
invert_pool = self.get_inv_max_pool_method(layer_instance)
|
||||
inverted = invert_pool(relevance_in, layer_instance.indices,
|
||||
layer_instance.kernel_size, layer_instance.stride,
|
||||
layer_instance.padding, output_size=layer_instance.in_shape)
|
||||
del layer_instance.indices
|
||||
|
||||
return inverted
|
||||
|
||||
@module_tracker
|
||||
def max_pool_nd_fwd_hook(self, m, in_tensor: torch.Tensor,
|
||||
out_tensor: torch.Tensor):
|
||||
# Ignore unused for pylint
|
||||
_ = self
|
||||
|
||||
# Save the return indices value to make sure
|
||||
tmp_return_indices = bool(m.return_indices)
|
||||
m.return_indices = True
|
||||
_, indices = m.forward(in_tensor[0])
|
||||
m.return_indices = tmp_return_indices
|
||||
setattr(m, "indices", indices)
|
||||
setattr(m, 'out_shape', out_tensor.size())
|
||||
setattr(m, 'in_shape', in_tensor[0].size())
|
||||
|
||||
def conv_nd_inverse(self, m, relevance_in):
|
||||
|
||||
# In case the output had been reshaped for a linear layer,
|
||||
# make sure the relevance is put into the same shape as before.
|
||||
relevance_in = relevance_in.view(m.out_shape)
|
||||
|
||||
# Get required values from layer
|
||||
inv_conv_nd = self.get_inv_conv_method(m)
|
||||
conv_nd = self.get_conv_method(m)
|
||||
|
||||
if self.method == "e-rule":
|
||||
with torch.no_grad():
|
||||
m.in_tensor = m.in_tensor.pow(self.p).detach()
|
||||
w = m.weight.pow(self.p).detach()
|
||||
norm = conv_nd(m.in_tensor, weight=w, bias=None,
|
||||
stride=m.stride, padding=m.padding,
|
||||
groups=m.groups)
|
||||
|
||||
norm = norm + torch.sign(norm) * self.eps
|
||||
relevance_in[norm == 0] = 0
|
||||
norm[norm == 0] = 1
|
||||
relevance_out = inv_conv_nd(relevance_in/norm,
|
||||
weight=w, bias=None,
|
||||
padding=m.padding, stride=m.stride,
|
||||
groups=m.groups)
|
||||
relevance_out *= m.in_tensor
|
||||
del m.in_tensor, norm, w
|
||||
return relevance_out
|
||||
|
||||
if self.method == "b-rule":
|
||||
with torch.no_grad():
|
||||
w = m.weight
|
||||
|
||||
out_c, in_c = m.out_channels, m.in_channels
|
||||
repeats = np.array(np.ones_like(w.size()).flatten(), dtype=int)
|
||||
repeats[0] *= 4
|
||||
w = w.repeat(tuple(repeats))
|
||||
# First and third channel repetition only contain the positive weights
|
||||
w[:out_c][w[:out_c] < 0] = 0
|
||||
w[2 * out_c:3 * out_c][w[2 * out_c:3 * out_c] < 0] = 0
|
||||
# Second and fourth channel repetition with only the negative weights
|
||||
w[1 * out_c:2 * out_c][w[1 * out_c:2 * out_c] > 0] = 0
|
||||
w[-out_c:][w[-out_c:] > 0] = 0
|
||||
repeats = np.array(np.ones_like(m.in_tensor.size()).flatten(), dtype=int)
|
||||
repeats[1] *= 4
|
||||
# Repeat across channel dimension (pytorch always has channels first)
|
||||
m.in_tensor = m.in_tensor.repeat(tuple(repeats))
|
||||
m.in_tensor[:, :in_c][m.in_tensor[:, :in_c] < 0] = 0
|
||||
m.in_tensor[:, -in_c:][m.in_tensor[:, -in_c:] < 0] = 0
|
||||
m.in_tensor[:, 1 * in_c:3 * in_c][m.in_tensor[:, 1 * in_c:3 * in_c] > 0] = 0
|
||||
groups = 4
|
||||
|
||||
# Normalize such that the sum of the individual importance values
|
||||
# of the input neurons divided by the norm
|
||||
# yields 1 for an output neuron j if divided by norm (v_ij in paper).
|
||||
# Norm layer just sums the importance values of the inputs
|
||||
# contributing to output j for each j. This will then serve as the normalization
|
||||
# such that the contributions of the neurons sum to 1 in order to
|
||||
# properly split up the relevance of j amongst its roots.
|
||||
norm = conv_nd(m.in_tensor, weight=w, bias=None, stride=m.stride,
|
||||
padding=m.padding, dilation=m.dilation, groups=groups * m.groups)
|
||||
# Double number of output channels for positive and negative norm per
|
||||
# channel. Using list with out_tensor.size() allows for ND generalization
|
||||
new_shape = m.out_shape
|
||||
new_shape[1] *= 2
|
||||
new_norm = torch.zeros(new_shape).to(self.device)
|
||||
new_norm[:, :out_c] = norm[:, :out_c] + norm[:, out_c:2 * out_c]
|
||||
new_norm[:, out_c:] = norm[:, 2 * out_c:3 * out_c] + norm[:, 3 * out_c:]
|
||||
norm = new_norm
|
||||
# Some 'rare' neurons only receive either
|
||||
# only positive or only negative inputs.
|
||||
# Conservation of relevance does not hold, if we also
|
||||
# rescale those neurons by (1+beta) or -beta.
|
||||
# Therefore, catch those first and scale norm by
|
||||
# the according value, such that it cancels in the fraction.
|
||||
|
||||
# First, however, avoid NaNs.
|
||||
mask = norm == 0
|
||||
# Set the norm to anything non-zero, e.g. 1.
|
||||
# The actual inputs are zero at this point anyways, that
|
||||
# is why norm is zero in the first place.
|
||||
norm[mask] = 1
|
||||
# The norm in the b-rule has shape (N, 2*out_c, *spatial_dims).
|
||||
# The first out_c block corresponds to the positive norms,
|
||||
# the second out_c block corresponds to the negative norms.
|
||||
# We find the rare neurons by choosing those nodes per channel
|
||||
# in which either the positive norm ([:, :out_c]) is zero, or
|
||||
# the negative norm ([:, :out_c]) is zero.
|
||||
rare_neurons = (mask[:, :out_c] + mask[:, out_c:])
|
||||
|
||||
# Also, catch new possibilities for norm == zero to avoid NaN..
|
||||
# The actual value of norm again does not really matter, since
|
||||
# the pre-factor will be zero in this case.
|
||||
|
||||
norm[:, :out_c][rare_neurons] *= 1 if self.beta == -1 else 1 + self.beta
|
||||
norm[:, out_c:][rare_neurons] *= 1 if self.beta == 0 else -self.beta
|
||||
# Add stabilizer term to norm to avoid numerical instabilities.
|
||||
norm += self.eps * torch.sign(norm)
|
||||
spatial_dims = [1] * len(relevance_in.size()[2:])
|
||||
|
||||
input_relevance = relevance_in.repeat(1, 4, *spatial_dims)
|
||||
input_relevance[:, :2*out_c] *= (1+self.beta)/norm[:, :out_c].repeat(1, 2, *spatial_dims)
|
||||
input_relevance[:, 2*out_c:] *= -self.beta/norm[:, out_c:].repeat(1, 2, *spatial_dims)
|
||||
# Each of the positive / negative entries needs its own
|
||||
# convolution. TODO: Can this be done in groups, too?
|
||||
|
||||
relevance_out = torch.zeros_like(m.in_tensor)
|
||||
# Weird code to make up for loss of size due to stride
|
||||
tmp_result = result = None
|
||||
for i in range(4):
|
||||
tmp_result = inv_conv_nd(
|
||||
input_relevance[:, i*out_c:(i+1)*out_c],
|
||||
weight=w[i*out_c:(i+1)*out_c],
|
||||
bias=None, padding=m.padding, stride=m.stride,
|
||||
groups=m.groups)
|
||||
result = torch.zeros_like(relevance_out[:, i*in_c:(i+1)*in_c])
|
||||
tmp_size = tmp_result.size()
|
||||
slice_list = [slice(0, l) for l in tmp_size]
|
||||
result[slice_list] += tmp_result
|
||||
relevance_out[:, i*in_c:(i+1)*in_c] = result
|
||||
relevance_out *= m.in_tensor
|
||||
|
||||
sum_weights = torch.zeros([in_c, in_c * 4, *spatial_dims]).to(self.device)
|
||||
for i in range(m.in_channels):
|
||||
sum_weights[i, i::in_c] = 1
|
||||
relevance_out = conv_nd(relevance_out, weight=sum_weights, bias=None)
|
||||
|
||||
del sum_weights, m.in_tensor, result, mask, rare_neurons, norm, \
|
||||
new_norm, input_relevance, tmp_result, w
|
||||
|
||||
return relevance_out
|
||||
|
||||
@module_tracker
|
||||
def conv_nd_fwd_hook(self, m, in_tensor: torch.Tensor,
|
||||
out_tensor: torch.Tensor):
|
||||
|
||||
setattr(m, "in_tensor", in_tensor[0])
|
||||
setattr(m, 'out_shape', list(out_tensor.size()))
|
||||
return
|
|
@ -0,0 +1,179 @@
|
|||
"""
|
||||
Contains abstract functionality for learning locally linear sparse model.
|
||||
"""
|
||||
from __future__ import print_function
|
||||
import numpy as np
|
||||
from sklearn.linear_model import Ridge, lars_path
|
||||
from sklearn.utils import check_random_state
|
||||
|
||||
|
||||
class LimeBase(object):
|
||||
"""Class for learning a locally linear sparse model from perturbed data"""
|
||||
def __init__(self,
|
||||
kernel_fn,
|
||||
verbose=False,
|
||||
random_state=None):
|
||||
"""Init function
|
||||
|
||||
Args:
|
||||
kernel_fn: function that transforms an array of distances into an
|
||||
array of proximity values (floats).
|
||||
verbose: if true, print local prediction values from linear model.
|
||||
random_state: an integer or numpy.RandomState that will be used to
|
||||
generate random numbers. If None, the random state will be
|
||||
initialized using the internal numpy seed.
|
||||
"""
|
||||
self.kernel_fn = kernel_fn
|
||||
self.verbose = verbose
|
||||
self.random_state = check_random_state(random_state)
|
||||
|
||||
@staticmethod
|
||||
def generate_lars_path(weighted_data, weighted_labels):
|
||||
"""Generates the lars path for weighted data.
|
||||
|
||||
Args:
|
||||
weighted_data: data that has been weighted by kernel
|
||||
weighted_label: labels, weighted by kernel
|
||||
|
||||
Returns:
|
||||
(alphas, coefs), both are arrays corresponding to the
|
||||
regularization parameter and coefficients, respectively
|
||||
"""
|
||||
x_vector = weighted_data
|
||||
alphas, _, coefs = lars_path(x_vector,
|
||||
weighted_labels,
|
||||
method='lasso',
|
||||
verbose=False)
|
||||
return alphas, coefs
|
||||
|
||||
def forward_selection(self, data, labels, weights, num_features):
|
||||
"""Iteratively adds features to the model"""
|
||||
clf = Ridge(alpha=0, fit_intercept=True, random_state=self.random_state)
|
||||
used_features = []
|
||||
for _ in range(min(num_features, data.shape[1])):
|
||||
max_ = -100000000
|
||||
best = 0
|
||||
for feature in range(data.shape[1]):
|
||||
if feature in used_features:
|
||||
continue
|
||||
clf.fit(data[:, used_features + [feature]], labels,
|
||||
sample_weight=weights)
|
||||
score = clf.score(data[:, used_features + [feature]],
|
||||
labels,
|
||||
sample_weight=weights)
|
||||
if score > max_:
|
||||
best = feature
|
||||
max_ = score
|
||||
used_features.append(best)
|
||||
return np.array(used_features)
|
||||
|
||||
def feature_selection(self, data, labels, weights, num_features, method):
|
||||
"""Selects features for the model. see explain_instance_with_data to
|
||||
understand the parameters."""
|
||||
if method == 'none':
|
||||
return np.array(range(data.shape[1]))
|
||||
elif method == 'forward_selection':
|
||||
return self.forward_selection(data, labels, weights, num_features)
|
||||
elif method == 'highest_weights':
|
||||
clf = Ridge(alpha=0, fit_intercept=True,
|
||||
random_state=self.random_state)
|
||||
clf.fit(data, labels, sample_weight=weights)
|
||||
feature_weights = sorted(zip(range(data.shape[0]),
|
||||
clf.coef_ * data[0]),
|
||||
key=lambda x: np.abs(x[1]),
|
||||
reverse=True)
|
||||
return np.array([x[0] for x in feature_weights[:num_features]])
|
||||
elif method == 'lasso_path':
|
||||
weighted_data = ((data - np.average(data, axis=0, weights=weights))
|
||||
* np.sqrt(weights[:, np.newaxis]))
|
||||
weighted_labels = ((labels - np.average(labels, weights=weights))
|
||||
* np.sqrt(weights))
|
||||
nonzero = range(weighted_data.shape[1])
|
||||
_, coefs = self.generate_lars_path(weighted_data,
|
||||
weighted_labels)
|
||||
for i in range(len(coefs.T) - 1, 0, -1):
|
||||
nonzero = coefs.T[i].nonzero()[0]
|
||||
if len(nonzero) <= num_features:
|
||||
break
|
||||
used_features = nonzero
|
||||
return used_features
|
||||
elif method == 'auto':
|
||||
if num_features <= 6:
|
||||
n_method = 'forward_selection'
|
||||
else:
|
||||
n_method = 'highest_weights'
|
||||
return self.feature_selection(data, labels, weights,
|
||||
num_features, n_method)
|
||||
|
||||
def explain_instance_with_data(self,
|
||||
neighborhood_data,
|
||||
neighborhood_labels,
|
||||
distances,
|
||||
label,
|
||||
num_features,
|
||||
feature_selection='auto',
|
||||
model_regressor=None):
|
||||
"""Takes perturbed data, labels and distances, returns explanation.
|
||||
|
||||
Args:
|
||||
neighborhood_data: perturbed data, 2d array. first element is
|
||||
assumed to be the original data point.
|
||||
neighborhood_labels: corresponding perturbed labels. should have as
|
||||
many columns as the number of possible labels.
|
||||
distances: distances to original data point.
|
||||
label: label for which we want an explanation
|
||||
num_features: maximum number of features in explanation
|
||||
feature_selection: how to select num_features. options are:
|
||||
'forward_selection': iteratively add features to the model.
|
||||
This is costly when num_features is high
|
||||
'highest_weights': selects the features that have the highest
|
||||
product of absolute weight * original data point when
|
||||
learning with all the features
|
||||
'lasso_path': chooses features based on the lasso
|
||||
regularization path
|
||||
'none': uses all features, ignores num_features
|
||||
'auto': uses forward_selection if num_features <= 6, and
|
||||
'highest_weights' otherwise.
|
||||
model_regressor: sklearn regressor to use in explanation.
|
||||
Defaults to Ridge regression if None. Must have
|
||||
model_regressor.coef_ and 'sample_weight' as a parameter
|
||||
to model_regressor.fit()
|
||||
|
||||
Returns:
|
||||
(intercept, exp, score, local_pred):
|
||||
intercept is a float.
|
||||
exp is a sorted list of tuples, where each tuple (x,y) corresponds
|
||||
to the feature id (x) and the local weight (y). The list is sorted
|
||||
by decreasing absolute value of y.
|
||||
score is the R^2 value of the returned explanation
|
||||
local_pred is the prediction of the explanation model on the original instance
|
||||
"""
|
||||
|
||||
weights = self.kernel_fn(distances)
|
||||
labels_column = neighborhood_labels[:, label]
|
||||
used_features = self.feature_selection(neighborhood_data,
|
||||
labels_column,
|
||||
weights,
|
||||
num_features,
|
||||
feature_selection)
|
||||
|
||||
if model_regressor is None:
|
||||
model_regressor = Ridge(alpha=1, fit_intercept=True,
|
||||
random_state=self.random_state)
|
||||
easy_model = model_regressor
|
||||
easy_model.fit(neighborhood_data[:, used_features],
|
||||
labels_column, sample_weight=weights)
|
||||
prediction_score = easy_model.score(
|
||||
neighborhood_data[:, used_features],
|
||||
labels_column, sample_weight=weights)
|
||||
|
||||
local_pred = easy_model.predict(neighborhood_data[0, used_features].reshape(1, -1))
|
||||
|
||||
if self.verbose:
|
||||
print('Intercept', easy_model.intercept_)
|
||||
print('Prediction_local', local_pred,)
|
||||
print('Right:', neighborhood_labels[0, label])
|
||||
return (easy_model.intercept_,
|
||||
sorted(zip(used_features, easy_model.coef_),
|
||||
key=lambda x: np.abs(x[1]), reverse=True),
|
||||
prediction_score, local_pred)
|
|
@ -0,0 +1,261 @@
|
|||
"""
|
||||
Functions for explaining classifiers that use Image data.
|
||||
"""
|
||||
import copy
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import sklearn
|
||||
import sklearn.preprocessing
|
||||
from sklearn.utils import check_random_state
|
||||
from skimage.color import gray2rgb
|
||||
|
||||
from . import lime_base
|
||||
from .wrappers.scikit_image import SegmentationAlgorithm
|
||||
|
||||
|
||||
class ImageExplanation(object):
|
||||
def __init__(self, image, segments):
|
||||
"""Init function.
|
||||
|
||||
Args:
|
||||
image: 3d numpy array
|
||||
segments: 2d numpy array, with the output from skimage.segmentation
|
||||
"""
|
||||
self.image = image
|
||||
self.segments = segments
|
||||
self.intercept = {}
|
||||
self.local_exp = {}
|
||||
self.local_pred = None
|
||||
|
||||
def get_image_and_mask(self, label, positive_only=True, hide_rest=False,
|
||||
num_features=5, min_weight=0.):
|
||||
"""Init function.
|
||||
|
||||
Args:
|
||||
label: label to explain
|
||||
positive_only: if True, only take superpixels that contribute to
|
||||
the prediction of the label. Otherwise, use the top
|
||||
num_features superpixels, which can be positive or negative
|
||||
towards the label
|
||||
hide_rest: if True, make the non-explanation part of the return
|
||||
image gray
|
||||
num_features: number of superpixels to include in explanation
|
||||
min_weight: TODO
|
||||
|
||||
Returns:
|
||||
(image, mask), where image is a 3d numpy array and mask is a 2d
|
||||
numpy array that can be used with
|
||||
skimage.segmentation.mark_boundaries
|
||||
"""
|
||||
if label not in self.local_exp:
|
||||
raise KeyError('Label not in explanation')
|
||||
segments = self.segments
|
||||
image = self.image
|
||||
exp = self.local_exp[label]
|
||||
mask = np.zeros(segments.shape, segments.dtype)
|
||||
if hide_rest:
|
||||
temp = np.zeros(self.image.shape)
|
||||
else:
|
||||
temp = self.image.copy()
|
||||
if positive_only:
|
||||
fs = [x[0] for x in exp
|
||||
if x[1] > 0 and x[1] > min_weight][:num_features]
|
||||
for f in fs:
|
||||
temp[segments == f] = image[segments == f].copy()
|
||||
mask[segments == f] = 1
|
||||
return temp, mask
|
||||
else:
|
||||
for f, w in exp[:num_features]:
|
||||
if np.abs(w) < min_weight:
|
||||
continue
|
||||
c = 0 if w < 0 else 1
|
||||
mask[segments == f] = 1 if w < 0 else 2
|
||||
temp[segments == f] = image[segments == f].copy()
|
||||
temp[segments == f, c] = np.max(image)
|
||||
for cp in [0, 1, 2]:
|
||||
if c == cp:
|
||||
continue
|
||||
# temp[segments == f, cp] *= 0.5
|
||||
return temp, mask
|
||||
|
||||
|
||||
class LimeImageExplainer(object):
|
||||
"""Explains predictions on Image (i.e. matrix) data.
|
||||
For numerical features, perturb them by sampling from a Normal(0,1) and
|
||||
doing the inverse operation of mean-centering and scaling, according to the
|
||||
means and stds in the training data. For categorical features, perturb by
|
||||
sampling according to the training distribution, and making a binary
|
||||
feature that is 1 when the value is the same as the instance being
|
||||
explained."""
|
||||
|
||||
def __init__(self, kernel_width=.25, kernel=None, verbose=False,
|
||||
feature_selection='auto', random_state=None):
|
||||
"""Init function.
|
||||
|
||||
Args:
|
||||
kernel_width: kernel width for the exponential kernel.
|
||||
If None, defaults to sqrt(number of columns) * 0.75.
|
||||
kernel: similarity kernel that takes euclidean distances and kernel
|
||||
width as input and outputs weights in (0,1). If None, defaults to
|
||||
an exponential kernel.
|
||||
verbose: if true, print local prediction values from linear model
|
||||
feature_selection: feature selection method. can be
|
||||
'forward_selection', 'lasso_path', 'none' or 'auto'.
|
||||
See function 'explain_instance_with_data' in lime_base.py for
|
||||
details on what each of the options does.
|
||||
random_state: an integer or numpy.RandomState that will be used to
|
||||
generate random numbers. If None, the random state will be
|
||||
initialized using the internal numpy seed.
|
||||
"""
|
||||
kernel_width = float(kernel_width)
|
||||
|
||||
if kernel is None:
|
||||
def kernel(d, kernel_width):
|
||||
return np.sqrt(np.exp(-(d ** 2) / kernel_width ** 2))
|
||||
|
||||
kernel_fn = partial(kernel, kernel_width=kernel_width)
|
||||
|
||||
self.random_state = check_random_state(random_state)
|
||||
self.feature_selection = feature_selection
|
||||
self.base = lime_base.LimeBase(kernel_fn, verbose, random_state=self.random_state)
|
||||
|
||||
def explain_instance(self, image, classifier_fn, labels=(1,),
|
||||
hide_color=None,
|
||||
top_labels=5, num_features=100000, num_samples=1000,
|
||||
batch_size=10,
|
||||
segmentation_fn=None,
|
||||
distance_metric='cosine',
|
||||
model_regressor=None,
|
||||
random_seed=None):
|
||||
"""Generates explanations for a prediction.
|
||||
|
||||
First, we generate neighborhood data by randomly perturbing features
|
||||
from the instance (see __data_inverse). We then learn locally weighted
|
||||
linear models on this neighborhood data to explain each of the classes
|
||||
in an interpretable way (see lime_base.py).
|
||||
|
||||
Args:
|
||||
image: 3 dimension RGB image. If this is only two dimensional,
|
||||
we will assume it's a grayscale image and call gray2rgb.
|
||||
classifier_fn: classifier prediction probability function, which
|
||||
takes a numpy array and outputs prediction probabilities. For
|
||||
ScikitClassifiers , this is classifier.predict_proba.
|
||||
labels: iterable with labels to be explained.
|
||||
hide_color: TODO
|
||||
top_labels: if not None, ignore labels and produce explanations for
|
||||
the K labels with highest prediction probabilities, where K is
|
||||
this parameter.
|
||||
num_features: maximum number of features present in explanation
|
||||
num_samples: size of the neighborhood to learn the linear model
|
||||
batch_size: TODO
|
||||
distance_metric: the distance metric to use for weights.
|
||||
model_regressor: sklearn regressor to use in explanation. Defaults
|
||||
to Ridge regression in LimeBase. Must have model_regressor.coef_
|
||||
and 'sample_weight' as a parameter to model_regressor.fit()
|
||||
segmentation_fn: SegmentationAlgorithm, wrapped skimage
|
||||
segmentation function
|
||||
random_seed: integer used as random seed for the segmentation
|
||||
algorithm. If None, a random integer, between 0 and 1000,
|
||||
will be generated using the internal random number generator.
|
||||
|
||||
Returns:
|
||||
An Explanation object (see explanation.py) with the corresponding
|
||||
explanations.
|
||||
"""
|
||||
if len(image.shape) == 2:
|
||||
image = gray2rgb(image)
|
||||
if random_seed is None:
|
||||
random_seed = self.random_state.randint(0, high=1000)
|
||||
|
||||
if segmentation_fn is None:
|
||||
segmentation_fn = SegmentationAlgorithm('quickshift', kernel_size=4,
|
||||
max_dist=200, ratio=0.2,
|
||||
random_seed=random_seed)
|
||||
try:
|
||||
segments = segmentation_fn(image)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
|
||||
fudged_image = image.copy()
|
||||
if hide_color is None:
|
||||
for x in np.unique(segments):
|
||||
fudged_image[segments == x] = (
|
||||
np.mean(image[segments == x][:, 0]),
|
||||
np.mean(image[segments == x][:, 1]),
|
||||
np.mean(image[segments == x][:, 2]))
|
||||
else:
|
||||
fudged_image[:] = hide_color
|
||||
|
||||
top = labels
|
||||
|
||||
data, labels = self.data_labels(image, fudged_image, segments,
|
||||
classifier_fn, num_samples,
|
||||
batch_size=batch_size)
|
||||
|
||||
distances = sklearn.metrics.pairwise_distances(
|
||||
data,
|
||||
data[0].reshape(1, -1),
|
||||
metric=distance_metric
|
||||
).ravel()
|
||||
|
||||
ret_exp = ImageExplanation(image, segments)
|
||||
if top_labels:
|
||||
top = np.argsort(labels[0])[-top_labels:]
|
||||
ret_exp.top_labels = list(top)
|
||||
ret_exp.top_labels.reverse()
|
||||
for label in top:
|
||||
(ret_exp.intercept[label],
|
||||
ret_exp.local_exp[label],
|
||||
ret_exp.score, ret_exp.local_pred) = self.base.explain_instance_with_data(
|
||||
data, labels, distances, label, num_features,
|
||||
model_regressor=model_regressor,
|
||||
feature_selection=self.feature_selection)
|
||||
return ret_exp
|
||||
|
||||
def data_labels(self,
|
||||
image,
|
||||
fudged_image,
|
||||
segments,
|
||||
classifier_fn,
|
||||
num_samples,
|
||||
batch_size=10):
|
||||
"""Generates images and predictions in the neighborhood of this image.
|
||||
|
||||
Args:
|
||||
image: 3d numpy array, the image
|
||||
fudged_image: 3d numpy array, image to replace original image when
|
||||
superpixel is turned off
|
||||
segments: segmentation of the image
|
||||
classifier_fn: function that takes a list of images and returns a
|
||||
matrix of prediction probabilities
|
||||
num_samples: size of the neighborhood to learn the linear model
|
||||
batch_size: classifier_fn will be called on batches of this size.
|
||||
|
||||
Returns:
|
||||
A tuple (data, labels), where:
|
||||
data: dense num_samples * num_superpixels
|
||||
labels: prediction probabilities matrix
|
||||
"""
|
||||
n_features = np.unique(segments).shape[0]
|
||||
data = self.random_state.randint(0, 2, num_samples * n_features)\
|
||||
.reshape((num_samples, n_features))
|
||||
labels = []
|
||||
data[0, :] = 1
|
||||
imgs = []
|
||||
for row in data:
|
||||
temp = copy.deepcopy(image)
|
||||
zeros = np.where(row == 0)[0]
|
||||
mask = np.zeros(segments.shape).astype(bool)
|
||||
for z in zeros:
|
||||
mask[segments == z] = True
|
||||
temp[mask] = fudged_image[mask]
|
||||
imgs.append(temp)
|
||||
if len(imgs) == batch_size:
|
||||
preds = classifier_fn(np.array(imgs))
|
||||
labels.extend(preds)
|
||||
imgs = []
|
||||
if len(imgs) > 0:
|
||||
preds = classifier_fn(np.array(imgs))
|
||||
labels.extend(preds)
|
||||
return data, np.array(labels)
|
|
@ -0,0 +1,117 @@
|
|||
import types
|
||||
from lime.utils.generic_utils import has_arg
|
||||
from skimage.segmentation import felzenszwalb, slic, quickshift
|
||||
|
||||
|
||||
class BaseWrapper(object):
|
||||
"""Base class for LIME Scikit-Image wrapper
|
||||
|
||||
|
||||
Args:
|
||||
target_fn: callable function or class instance
|
||||
target_params: dict, parameters to pass to the target_fn
|
||||
|
||||
|
||||
'target_params' takes parameters required to instanciate the
|
||||
desired Scikit-Image class/model
|
||||
"""
|
||||
|
||||
def __init__(self, target_fn=None, **target_params):
|
||||
self.target_fn = target_fn
|
||||
self.target_params = target_params
|
||||
|
||||
self.target_fn = target_fn
|
||||
self.target_params = target_params
|
||||
|
||||
def _check_params(self, parameters):
|
||||
"""Checks for mistakes in 'parameters'
|
||||
|
||||
Args :
|
||||
parameters: dict, parameters to be checked
|
||||
|
||||
Raises :
|
||||
ValueError: if any parameter is not a valid argument for the target function
|
||||
or the target function is not defined
|
||||
TypeError: if argument parameters is not iterable
|
||||
"""
|
||||
a_valid_fn = []
|
||||
if self.target_fn is None:
|
||||
if callable(self):
|
||||
a_valid_fn.append(self.__call__)
|
||||
else:
|
||||
raise TypeError('invalid argument: tested object is not callable,\
|
||||
please provide a valid target_fn')
|
||||
elif isinstance(self.target_fn, types.FunctionType) \
|
||||
or isinstance(self.target_fn, types.MethodType):
|
||||
a_valid_fn.append(self.target_fn)
|
||||
else:
|
||||
a_valid_fn.append(self.target_fn.__call__)
|
||||
|
||||
if not isinstance(parameters, str):
|
||||
for p in parameters:
|
||||
for fn in a_valid_fn:
|
||||
if has_arg(fn, p):
|
||||
pass
|
||||
else:
|
||||
raise ValueError('{} is not a valid parameter'.format(p))
|
||||
else:
|
||||
raise TypeError('invalid argument: list or dictionnary expected')
|
||||
|
||||
def set_params(self, **params):
|
||||
"""Sets the parameters of this estimator.
|
||||
Args:
|
||||
**params: Dictionary of parameter names mapped to their values.
|
||||
|
||||
Raises :
|
||||
ValueError: if any parameter is not a valid argument
|
||||
for the target function
|
||||
"""
|
||||
self._check_params(params)
|
||||
self.target_params = params
|
||||
|
||||
def filter_params(self, fn, override=None):
|
||||
"""Filters `target_params` and return those in `fn`'s arguments.
|
||||
Args:
|
||||
fn : arbitrary function
|
||||
override: dict, values to override target_params
|
||||
Returns:
|
||||
result : dict, dictionary containing variables
|
||||
in both target_params and fn's arguments.
|
||||
"""
|
||||
override = override or {}
|
||||
result = {}
|
||||
for name, value in self.target_params.items():
|
||||
if has_arg(fn, name):
|
||||
result.update({name: value})
|
||||
result.update(override)
|
||||
return result
|
||||
|
||||
|
||||
class SegmentationAlgorithm(BaseWrapper):
|
||||
""" Define the image segmentation function based on Scikit-Image
|
||||
implementation and a set of provided parameters
|
||||
|
||||
Args:
|
||||
algo_type: string, segmentation algorithm among the following:
|
||||
'quickshift', 'slic', 'felzenszwalb'
|
||||
target_params: dict, algorithm parameters (valid model paramters
|
||||
as define in Scikit-Image documentation)
|
||||
"""
|
||||
|
||||
def __init__(self, algo_type, **target_params):
|
||||
self.algo_type = algo_type
|
||||
if (self.algo_type == 'quickshift'):
|
||||
BaseWrapper.__init__(self, quickshift, **target_params)
|
||||
kwargs = self.filter_params(quickshift)
|
||||
self.set_params(**kwargs)
|
||||
elif (self.algo_type == 'felzenszwalb'):
|
||||
BaseWrapper.__init__(self, felzenszwalb, **target_params)
|
||||
kwargs = self.filter_params(felzenszwalb)
|
||||
self.set_params(**kwargs)
|
||||
elif (self.algo_type == 'slic'):
|
||||
BaseWrapper.__init__(self, slic, **target_params)
|
||||
kwargs = self.filter_params(slic)
|
||||
self.set_params(**kwargs)
|
||||
|
||||
def __call__(self, *args):
|
||||
return self.target_fn(args[0], **self.target_params)
|
|
@ -0,0 +1,55 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from skimage.segmentation import mark_boundaries
|
||||
from torchvision import transforms
|
||||
import torch
|
||||
from .lime import lime_image
|
||||
import numpy as np
|
||||
from .. import imagenet_utils, pytorch_utils, utils
|
||||
|
||||
class LimeImageExplainer:
|
||||
def __init__(self, model, predict_fn):
|
||||
self.model = model
|
||||
self.predict_fn = predict_fn
|
||||
|
||||
def preprocess_input(self, inp):
|
||||
return inp
|
||||
def preprocess_label(self, label):
|
||||
return label
|
||||
|
||||
def explain(self, inp, ind=None, raw_inp=None, top_labels=5, hide_color=0, num_samples=1000,
|
||||
positive_only=True, num_features=5, hide_rest=True, pixel_val_max=255.0):
|
||||
explainer = lime_image.LimeImageExplainer()
|
||||
explanation = explainer.explain_instance(self.preprocess_input(raw_inp), self.predict_fn,
|
||||
top_labels=5, hide_color=0, num_samples=1000)
|
||||
|
||||
temp, mask = explanation.get_image_and_mask(self.preprocess_label(ind) or explanation.top_labels[0],
|
||||
positive_only=True, num_features=5, hide_rest=True)
|
||||
|
||||
img = mark_boundaries(temp/pixel_val_max, mask)
|
||||
img = torch.from_numpy(img)
|
||||
img = torch.transpose(img, 0, 2)
|
||||
img = torch.transpose(img, 1, 2)
|
||||
return img.unsqueeze(0)
|
||||
|
||||
class LimeImagenetExplainer(LimeImageExplainer):
|
||||
def __init__(self, model, predict_fn=None):
|
||||
super(LimeImagenetExplainer, self).__init__(model, predict_fn or self._imagenet_predict)
|
||||
|
||||
def _preprocess_transform(self):
|
||||
transf = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
imagenet_utils.get_normalize_transform()
|
||||
])
|
||||
|
||||
return transf
|
||||
|
||||
def preprocess_input(self, inp):
|
||||
return np.array(imagenet_utils.get_resize_transform()(inp))
|
||||
def preprocess_label(self, label):
|
||||
return label.item() if label is not None and utils.has_method(label, 'item') else label
|
||||
|
||||
def _imagenet_predict(self, images):
|
||||
probs = imagenet_utils.predict(self.model, images, image_transform=self._preprocess_transform())
|
||||
return pytorch_utils.tensor2numpy(probs)
|
|
@ -0,0 +1,51 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
from torch.autograd import Variable
|
||||
from skimage.util import view_as_windows
|
||||
|
||||
# modified from https://github.com/marcoancona/DeepExplain/blob/master/deepexplain/tensorflow/methods.py#L291-L342
|
||||
# note the different dim order in pytorch (NCHW) and tensorflow (NHWC)
|
||||
|
||||
class OcclusionExplainer:
|
||||
def __init__(self, model, window_shape=10, step=1):
|
||||
self.model = model
|
||||
self.window_shape = window_shape
|
||||
self.step = step
|
||||
|
||||
def explain(self, inp, ind=None, raw_inp=None):
|
||||
self.model.eval()
|
||||
with torch.no_grad():
|
||||
return OcclusionExplainer._occlusion(inp, self.model, self.window_shape, self.step)
|
||||
|
||||
@staticmethod
|
||||
def _occlusion(inp, model, window_shape, step=None):
|
||||
if type(window_shape) == int:
|
||||
window_shape = (window_shape, window_shape, 3)
|
||||
|
||||
if step is None:
|
||||
step = 1
|
||||
n, c, h, w = inp.data.size()
|
||||
total_dim = c * h * w
|
||||
index_matrix = np.arange(total_dim).reshape(h, w, c)
|
||||
idx_patches = view_as_windows(index_matrix, window_shape, step).reshape(
|
||||
(-1,) + window_shape)
|
||||
heatmap = np.zeros((n, h, w, c), dtype=np.float32).reshape((-1), total_dim)
|
||||
weights = np.zeros_like(heatmap)
|
||||
|
||||
inp_data = inp.data.clone()
|
||||
new_inp = Variable(inp_data)
|
||||
eval0 = model(new_inp)
|
||||
pred_id = eval0.max(1)[1].data[0]
|
||||
|
||||
for i, p in enumerate(idx_patches):
|
||||
mask = np.ones((h, w, c)).flatten()
|
||||
mask[p.flatten()] = 0
|
||||
th_mask = torch.from_numpy(mask.reshape(1, h, w, c).transpose(0, 3, 1, 2)).float().cuda()
|
||||
masked_xs = Variable(th_mask * inp_data)
|
||||
delta = (eval0[0, pred_id] - model(masked_xs)[0, pred_id]).data.cpu().numpy()
|
||||
delta_aggregated = np.sum(delta.reshape(n, -1), -1, keepdims=True)
|
||||
heatmap[:, p.flatten()] += delta_aggregated
|
||||
weights[:, p.flatten()] += p.size
|
||||
|
||||
attribution = np.reshape(heatmap / (weights + 1e-10), (n, h, w, c)).transpose(0, 3, 1, 2)
|
||||
return torch.from_numpy(attribution)
|
|
@ -0,0 +1,112 @@
|
|||
from .gradcam import GradCAMExplainer
|
||||
from .backprop import VanillaGradExplainer, GradxInputExplainer, SaliencyExplainer, \
|
||||
IntegrateGradExplainer, DeconvExplainer, GuidedBackpropExplainer, SmoothGradExplainer
|
||||
from .deeplift import DeepLIFTRescaleExplainer
|
||||
from .occlusion import OcclusionExplainer
|
||||
from .epsilon_lrp import EpsilonLrp
|
||||
from .lime_image_explainer import LimeImageExplainer, LimeImagenetExplainer
|
||||
import skimage.transform
|
||||
import torch
|
||||
import matplotlib.pyplot as plt
|
||||
import math
|
||||
from .. import image_utils
|
||||
|
||||
class ImageSaliencyResult:
|
||||
def __init__(self, raw_image, saliency, title, saliency_alpha=0.4, saliency_cmap='jet'):
|
||||
self.raw_image, self.saliency, self.title = raw_image, saliency, title
|
||||
self.saliency_alpha, self.saliency_cmap = saliency_alpha, saliency_cmap
|
||||
|
||||
def _get_explainer(explainer_name, model, layer_path=None):
|
||||
if explainer_name == 'gradcam':
|
||||
return GradCAMExplainer(model, target_layer_name_keys=layer_path, use_inp=True)
|
||||
if explainer_name == 'vanilla_grad':
|
||||
return VanillaGradExplainer(model)
|
||||
if explainer_name == 'grad_x_input':
|
||||
return GradxInputExplainer(model)
|
||||
if explainer_name == 'saliency':
|
||||
return SaliencyExplainer(model)
|
||||
if explainer_name == 'integrate_grad':
|
||||
return IntegrateGradExplainer(model)
|
||||
if explainer_name == 'deconv':
|
||||
return DeconvExplainer(model)
|
||||
if explainer_name == 'guided_backprop':
|
||||
return GuidedBackpropExplainer(model)
|
||||
if explainer_name == 'smooth_grad':
|
||||
return SmoothGradExplainer(model)
|
||||
if explainer_name == 'deeplift':
|
||||
return DeepLIFTRescaleExplainer(model)
|
||||
if explainer_name == 'occlusion':
|
||||
return OcclusionExplainer(model)
|
||||
if explainer_name == 'lrp':
|
||||
return EpsilonLrp(model)
|
||||
if explainer_name == 'lime_imagenet':
|
||||
return LimeImagenetExplainer(model)
|
||||
|
||||
raise ValueError('Explainer {} is not recognized'.format(explainer_name))
|
||||
|
||||
def _get_layer_path(model):
|
||||
if model.__class__.__name__ == 'VGG':
|
||||
return ['features', '30'] # pool5
|
||||
elif model.__class__.__name__ == 'GoogleNet':
|
||||
return ['pool5']
|
||||
elif model.__class__.__name__ == 'ResNet':
|
||||
return ['avgpool'] #layer4
|
||||
elif model.__class__.__name__ == 'Inception3':
|
||||
return ['Mixed_7c', 'branch_pool'] # ['conv2d_94'], 'mixed10'
|
||||
else: #unknown network
|
||||
return None
|
||||
|
||||
def get_saliency(model, raw_input, input, label, method='integrate_grad', layer_path=None):
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
model.to(device)
|
||||
input = input.to(device)
|
||||
if label is not None:
|
||||
label = label.to(device)
|
||||
|
||||
if input.grad is not None:
|
||||
input.grad.zero_()
|
||||
if label is not None and label.grad is not None:
|
||||
label.grad.zero_()
|
||||
model.eval()
|
||||
model.zero_grad()
|
||||
|
||||
layer_path = layer_path or _get_layer_path(model)
|
||||
|
||||
exp = _get_explainer(method, model, layer_path)
|
||||
saliency = exp.explain(input, label, raw_input)
|
||||
|
||||
saliency = saliency.abs().sum(dim=1)[0].squeeze()
|
||||
saliency -= saliency.min()
|
||||
saliency /= (saliency.max() + 1e-20)
|
||||
|
||||
return saliency.detach().cpu().numpy()
|
||||
|
||||
def get_image_saliency_results(model, raw_image, input, label,
|
||||
methods=['lime_imagenet', 'gradcam', 'smooth_grad',
|
||||
'guided_backprop', 'deeplift', 'grad_x_input'],
|
||||
layer_path=None):
|
||||
results = []
|
||||
for method in methods:
|
||||
sal = get_saliency(model, raw_image, input, label, method=method)
|
||||
results.append(ImageSaliencyResult(raw_image, sal, method))
|
||||
return results
|
||||
|
||||
def get_image_saliency_plot(image_saliency_results, cols = 2, figsize = None):
|
||||
rows = math.ceil(len(image_saliency_results) / cols)
|
||||
figsize=figsize or (8, 3 * rows)
|
||||
figure = plt.figure(figsize=figsize) #figsize=(8, 3)
|
||||
|
||||
for i, r in enumerate(image_saliency_results):
|
||||
ax = figure.add_subplot(rows, cols, i+1)
|
||||
ax.set_xticks([])
|
||||
ax.set_yticks([])
|
||||
ax.set_title(r.title, fontdict={'fontsize': 24}) #'fontweight': 'light'
|
||||
|
||||
#upsampler = nn.Upsample(size=(raw_image.height, raw_image.width), mode='bilinear')
|
||||
saliency_upsampled = skimage.transform.resize(r.saliency,
|
||||
(r.raw_image.height, r.raw_image.width))
|
||||
|
||||
image_utils.show_image(r.raw_image, img2=saliency_upsampled,
|
||||
alpha2=r.saliency_alpha, cmap2=r.saliency_cmap, ax=ax)
|
||||
return figure
|
|
@ -0,0 +1,52 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import weakref, uuid
|
||||
from typing import Any
|
||||
from . import utils
|
||||
|
||||
class Stream:
|
||||
def __init__(self, stream_name:str=None, console_debug:bool=False):
|
||||
self._subscribers = weakref.WeakSet()
|
||||
self._subscribed_to = weakref.WeakSet()
|
||||
self.held_refs = set() # on some rare occasion we might want stream to hold references of other streams
|
||||
self.closed = False
|
||||
self.console_debug = console_debug
|
||||
self.stream_name = stream_name or str(uuid.uuid4()) # useful to use as key and avoid circular references
|
||||
|
||||
def subscribe(self, stream:'Stream'): # notify other stream
|
||||
utils.debug_log('{} added {} as subscription'.format(self.stream_name, stream.stream_name))
|
||||
stream._subscribers.add(self)
|
||||
self._subscribed_to.add(stream)
|
||||
|
||||
def unsubscribe(self, stream:'Stream'):
|
||||
utils.debug_log('{} removed {} as subscription'.format(self.stream_name, stream.stream_name))
|
||||
stream._subscribers.discard(self)
|
||||
self._subscribed_to.discard(stream)
|
||||
self.held_refs.discard(stream)
|
||||
#stream.held_refs.discard(self) # not needed as only subscriber should hold ref
|
||||
|
||||
def write(self, val:Any, from_stream:'Stream'=None):
|
||||
if self.console_debug:
|
||||
print(self.stream_name, val)
|
||||
|
||||
for subscriber in self._subscribers:
|
||||
subscriber.write(val, from_stream=self)
|
||||
|
||||
def load(self, from_stream:'Stream'=None):
|
||||
for subscribed_to in self._subscribed_to:
|
||||
subscribed_to.load(from_stream=self)
|
||||
|
||||
def close(self):
|
||||
if not self.closed:
|
||||
for subscribed_to in self._subscribed_to:
|
||||
subscribed_to._subscribers.discard(self)
|
||||
self._subscribed_to.clear()
|
||||
self.closed = True
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exception_type, exception_value, traceback):
|
||||
self.close()
|
||||
|
|
@ -0,0 +1,84 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import Dict, Sequence
|
||||
from .zmq_stream import ZmqStream
|
||||
from .file_stream import FileStream
|
||||
from .stream import Stream
|
||||
from .stream_union import StreamUnion
|
||||
|
||||
class StreamFactory:
|
||||
r"""Allows to create shared stream such as file and ZMQ streams
|
||||
"""
|
||||
|
||||
def __init__(self)->None:
|
||||
self.closed = None
|
||||
self._streams:Dict[str, Stream] = None
|
||||
self._reset()
|
||||
|
||||
def _reset(self):
|
||||
self._streams:Dict[str, Stream] = {}
|
||||
self.closed = False
|
||||
|
||||
def close(self):
|
||||
if not self.closed:
|
||||
for stream in self._streams.values():
|
||||
stream.close()
|
||||
self._reset()
|
||||
self.closed = True
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
def __exit__(self, exception_type, exception_value, traceback):
|
||||
self.close()
|
||||
|
||||
def get_streams(self, stream_types:Sequence[str], for_write:bool=None)->Stream:
|
||||
streams = [self._create_stream_by_string(stream_type, for_write) for stream_type in stream_types]
|
||||
return streams
|
||||
|
||||
def get_combined_stream(self, stream_types:Sequence[str], for_write:bool=None)->Stream:
|
||||
streams = [self._create_stream_by_string(stream_type, for_write) for stream_type in stream_types]
|
||||
if len(streams) == 1:
|
||||
return self._streams[0]
|
||||
else:
|
||||
# we create new union of child but this is not necessory
|
||||
return StreamUnion(streams, for_write=for_write)
|
||||
|
||||
def _create_stream_by_string(self, stream_spec:str, for_write:bool)->Stream:
|
||||
parts = stream_spec.split(':', 1) if stream_spec is not None else ['']
|
||||
stream_type = parts[0]
|
||||
stream_args = parts[1] if len(parts) > 1 else None
|
||||
|
||||
if stream_type == 'tcp':
|
||||
port = int(stream_args or 0)
|
||||
stream_name = '{}:{}:{}'.format(stream_type, port, for_write)
|
||||
if stream_name not in self._streams:
|
||||
self._streams[stream_name] = ZmqStream(for_write=for_write,
|
||||
port=port, stream_name=stream_name, block_until_connected=False)
|
||||
# else we already have this stream
|
||||
return self._streams[stream_name]
|
||||
|
||||
|
||||
if stream_args is None: # file name specified without 'file:' prefix
|
||||
stream_args = stream_type
|
||||
stream_type = 'file'
|
||||
if len(stream_type) == 1: # windows drive letter
|
||||
stream_type = 'file'
|
||||
stream_args = stream_spec
|
||||
|
||||
if stream_type == 'file':
|
||||
if stream_args is None:
|
||||
raise ValueError('File name must be specified for stream type "file"')
|
||||
stream_name = '{}:{}:{}'.format(stream_type, stream_args, for_write)
|
||||
if stream_name not in self._streams:
|
||||
self._streams[stream_name] = FileStream(for_write=for_write,
|
||||
file_name=stream_args, stream_name=stream_name)
|
||||
# else we already have this stream
|
||||
return self._streams[stream_name]
|
||||
|
||||
if stream_type == '':
|
||||
return Stream()
|
||||
|
||||
raise ValueError('stream_type "{}" has unknown type'.format(stream_type))
|
||||
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .stream import Stream
|
||||
from typing import Iterator
|
||||
|
||||
class StreamUnion(Stream):
|
||||
def __init__(self, child_streams:Iterator[Stream], for_write:bool, stream_name:str=None, console_debug:bool=False) -> None:
|
||||
super(StreamUnion, self).__init__(stream_name=stream_name, console_debug=console_debug)
|
||||
|
||||
# save references, child streams does away only if parent goes away
|
||||
self.child_streams = child_streams
|
||||
|
||||
# when someone does write to us, we write to all our listeners
|
||||
if for_write:
|
||||
for child_stream in child_streams:
|
||||
child_stream.subscribe(self)
|
||||
else:
|
||||
# union of all child streams
|
||||
for child_stream in child_streams:
|
||||
self.subscribe(child_stream)
|
|
@ -0,0 +1,82 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from . import utils
|
||||
import pandas as pd
|
||||
import ipywidgets as widgets
|
||||
from IPython import get_ipython #, display
|
||||
|
||||
from .vis_base import VisBase
|
||||
|
||||
class TextVis(VisBase):
|
||||
def __init__(self, cell:widgets.Box=None, title:str=None, show_legend:bool=None,
|
||||
stream_name:str=None, console_debug:bool=False, **vis_args):
|
||||
super(TextVis, self).__init__(widgets.HTML(), cell, title, show_legend,
|
||||
stream_name=stream_name, console_debug=console_debug, **vis_args)
|
||||
self.df = pd.DataFrame([])
|
||||
|
||||
def _get_column_prefix(self, stream_vis, i):
|
||||
return '[S.{}]:{}'.format(stream_vis.index, i)
|
||||
|
||||
def _get_title(self, stream_vis):
|
||||
title = stream_vis.title or 'Stream ' + str(len(self._stream_vises))
|
||||
return title
|
||||
|
||||
# this will be called from _show_stream_items
|
||||
def _append(self, stream_vis, vals):
|
||||
if vals is None:
|
||||
self.df = self.df.append(pd.Series({self._get_column_prefix(stream_vis, 0) : None}),
|
||||
sort=False, ignore_index=True)
|
||||
return
|
||||
for val in vals:
|
||||
if val is None or utils.is_scalar(val):
|
||||
self.df = self.df.append(pd.Series({self._get_column_prefix(stream_vis, 0) : val}),
|
||||
sort=False, ignore_index=True)
|
||||
elif utils.is_array_like(val):
|
||||
val_dict = {}
|
||||
for i,val_i in enumerate(val):
|
||||
val_dict[self._get_column_prefix(stream_vis, i)] = val_i
|
||||
self.df = self.df.append(pd.Series(val_dict), sort=False, ignore_index=True)
|
||||
else:
|
||||
self.df = self.df.append(pd.Series(val.__dict__), sort=False, ignore_index=True)
|
||||
|
||||
def _post_add_subscription(self, stream_vis, **stream_vis_args):
|
||||
only_summary = stream_vis_args.get('only_summary', False)
|
||||
stream_vis.text = self._get_title(stream_vis)
|
||||
stream_vis.only_summary = only_summary
|
||||
|
||||
def clear_plot(self, stream_vis, clear_history):
|
||||
self.df = self.df.iloc[0:0]
|
||||
|
||||
def _show_stream_items(self, stream_vis, stream_items):
|
||||
for stream_item in stream_items:
|
||||
if stream_item.ended:
|
||||
self.df = self.df.append(pd.Series({'Ended':True}),
|
||||
sort=False, ignore_index=True)
|
||||
else:
|
||||
vals = self._extract_vals((stream_item,))
|
||||
self._append(stream_vis, vals)
|
||||
return True
|
||||
|
||||
def _post_update_stream_plot(self, stream_vis):
|
||||
if get_ipython():
|
||||
if not stream_vis.only_summary:
|
||||
self.widget.value = self.df.to_html(classes=['output_html', 'rendered_html'])
|
||||
else:
|
||||
self.widget.value = self.df.describe().to_html(classes=['output_html', 'rendered_html'])
|
||||
# below doesn't work because of threading issue
|
||||
#self.widget.clear_output(wait=True)
|
||||
#with self.widget:
|
||||
# display.display(self.df)
|
||||
else:
|
||||
last_recs = self.df.iloc[[-1]].to_dict('records')
|
||||
if len(last_recs) == 1:
|
||||
print(last_recs[0])
|
||||
else:
|
||||
print(last_recs)
|
||||
|
||||
def _show_widget_native(self, blocking:bool):
|
||||
return None # we will be using console
|
||||
|
||||
def _show_widget_notebook(self):
|
||||
return self.widget
|
|
@ -0,0 +1,352 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import numpy as np #pip install numpy
|
||||
import math
|
||||
import time
|
||||
import sys
|
||||
import os
|
||||
import inspect
|
||||
import re
|
||||
import uuid
|
||||
from collections import abc
|
||||
import textwrap
|
||||
|
||||
from functools import wraps
|
||||
import gc
|
||||
import timeit
|
||||
def MeasureTime(f):
|
||||
@wraps(f)
|
||||
def _wrapper(*args, **kwargs):
|
||||
gcold = gc.isenabled()
|
||||
gc.disable()
|
||||
start_time = timeit.default_timer()
|
||||
try:
|
||||
result = f(*args, **kwargs)
|
||||
finally:
|
||||
elapsed = timeit.default_timer() - start_time
|
||||
if gcold:
|
||||
gc.enable()
|
||||
print('Function "{}": {}s'.format(f.__name__, elapsed))
|
||||
return result
|
||||
return _wrapper
|
||||
class MeasureBlockTime:
|
||||
def __init__(self, name="(block)", no_print = False, disable_gc = True, format_str=":.2f"):
|
||||
self.name = name
|
||||
self.no_print = no_print
|
||||
self.disable_gc = disable_gc
|
||||
self.format_str = format_str
|
||||
self.gcold = None
|
||||
self.start_time = None
|
||||
self.elapsed = None
|
||||
|
||||
def __enter__(self):
|
||||
if self.disable_gc:
|
||||
self.gcold = gc.isenabled()
|
||||
gc.disable()
|
||||
self.start_time = timeit.default_timer()
|
||||
return self
|
||||
def __exit__(self,ty,val,tb):
|
||||
self.elapsed = timeit.default_timer() - self.start_time
|
||||
if self.disable_gc and self.gcold:
|
||||
gc.enable()
|
||||
if not self.no_print:
|
||||
print(('{}: {' + self.format_str + '}s').format(self.name, self.elapsed))
|
||||
return False #re-raise any exceptions
|
||||
def getTime():
|
||||
return timeit.default_timer()
|
||||
def getElapsedTime(start_time):
|
||||
return timeit.default_timer() - start_time
|
||||
def string_to_uint8_array(bstr):
|
||||
return np.fromstring(bstr, np.uint8)
|
||||
|
||||
def string_to_float_array(bstr):
|
||||
return np.fromstring(bstr, np.float32)
|
||||
|
||||
def list_to_2d_float_array(flst, width, height):
|
||||
return np.reshape(np.asarray(flst, np.float32), (height, width))
|
||||
|
||||
def get_pfm_array(response):
|
||||
return list_to_2d_float_array(response.image_data_float, response.width, response.height)
|
||||
|
||||
# creates same list as len of seq filled with val - if val is already not a list of same size
|
||||
def fill_like(val, seq):
|
||||
l = len(seq)
|
||||
if is_array_like(val) and len(val) == l:
|
||||
return val
|
||||
return [val] * len(seq)
|
||||
|
||||
def is_array_like(obj, string_is_array=False, tuple_is_array=True):
|
||||
result = hasattr(obj, "__len__") and hasattr(obj, '__getitem__')
|
||||
if result and not string_is_array and isinstance(obj, (str, abc.ByteString)):
|
||||
result = False
|
||||
if result and not tuple_is_array and isinstance(obj, tuple):
|
||||
result = False
|
||||
return result
|
||||
|
||||
def is_scalar(x):
|
||||
return x is None or np.isscalar(x)
|
||||
|
||||
def is_scaler_array(x): #detects (x,y) or [x, y]
|
||||
if is_array_like(x):
|
||||
if len(x) > 0:
|
||||
return len(x) if is_scalar(x[0]) else -1
|
||||
else:
|
||||
return 0
|
||||
else:
|
||||
return -1
|
||||
|
||||
def get_public_fields(obj):
|
||||
return [attr for attr in dir(obj)
|
||||
if not (attr.startswith("_")
|
||||
or inspect.isbuiltin(attr)
|
||||
or inspect.isfunction(attr)
|
||||
or inspect.ismethod(attr))]
|
||||
|
||||
def set_default(dictionary, key, default_val, replace_none=True):
|
||||
if key not in dictionary or (replace_none and dictionary[key] is None):
|
||||
dictionary[key] = default_val
|
||||
|
||||
def to_array_like(val):
|
||||
if is_array_like(val):
|
||||
return val
|
||||
return [val]
|
||||
|
||||
def to_dict(obj):
|
||||
return dict([attr, getattr(obj, attr)] for attr in get_public_fields(obj))
|
||||
|
||||
|
||||
def to_str(obj):
|
||||
return str(to_dict(obj))
|
||||
|
||||
|
||||
def write_file(filename, bstr):
|
||||
with open(filename, 'wb') as afile:
|
||||
afile.write(bstr)
|
||||
|
||||
# helper method for converting getOrientation to roll/pitch/yaw
|
||||
# https:#en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles
|
||||
|
||||
def has_method(o, name):
|
||||
return callable(getattr(o, name, None))
|
||||
|
||||
def to_eularian_angles(q):
|
||||
z = q.z_val
|
||||
y = q.y_val
|
||||
x = q.x_val
|
||||
w = q.w_val
|
||||
ysqr = y * y
|
||||
|
||||
# roll (x-axis rotation)
|
||||
t0 = +2.0 * (w*x + y*z)
|
||||
t1 = +1.0 - 2.0*(x*x + ysqr)
|
||||
roll = math.atan2(t0, t1)
|
||||
|
||||
# pitch (y-axis rotation)
|
||||
t2 = +2.0 * (w*y - z*x)
|
||||
if (t2 > 1.0):
|
||||
t2 = 1
|
||||
if (t2 < -1.0):
|
||||
t2 = -1.0
|
||||
pitch = math.asin(t2)
|
||||
|
||||
# yaw (z-axis rotation)
|
||||
t3 = +2.0 * (w*z + x*y)
|
||||
t4 = +1.0 - 2.0 * (ysqr + z*z)
|
||||
yaw = math.atan2(t3, t4)
|
||||
|
||||
return (pitch, roll, yaw)
|
||||
|
||||
|
||||
# TODO: sync with AirSim utils.py
|
||||
|
||||
def wait_key(message = ''):
|
||||
''' Wait for a key press on the console and return it. '''
|
||||
if message != '':
|
||||
print (message)
|
||||
|
||||
result = None
|
||||
if os.name == 'nt':
|
||||
import msvcrt
|
||||
result = msvcrt.getch()
|
||||
else:
|
||||
# pylint: disable=import-error
|
||||
import termios # pylint: disable=import-error
|
||||
fd = sys.stdin.fileno()
|
||||
|
||||
oldterm = termios.tcgetattr(fd)
|
||||
newattr = termios.tcgetattr(fd)
|
||||
newattr[3] = newattr[3] & ~termios.ICANON & ~termios.ECHO
|
||||
termios.tcsetattr(fd, termios.TCSANOW, newattr)
|
||||
|
||||
try:
|
||||
result = sys.stdin.read(1)
|
||||
except IOError:
|
||||
pass
|
||||
finally:
|
||||
termios.tcsetattr(fd, termios.TCSAFLUSH, oldterm)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def read_pfm(file):
|
||||
""" Read a pfm file """
|
||||
file = open(file, 'rb')
|
||||
|
||||
color = None
|
||||
width = None
|
||||
height = None
|
||||
scale = None
|
||||
endian = None
|
||||
|
||||
header = file.readline().rstrip()
|
||||
header = str(bytes.decode(header, encoding='utf-8'))
|
||||
if header == 'PF':
|
||||
color = True
|
||||
elif header == 'Pf':
|
||||
color = False
|
||||
else:
|
||||
raise Exception('Not a PFM file.')
|
||||
|
||||
temp_str = str(bytes.decode(file.readline(), encoding='utf-8'))
|
||||
dim_match = re.match(r'^(\d+)\s(\d+)\s$', temp_str)
|
||||
if dim_match:
|
||||
width, height = map(int, dim_match.groups())
|
||||
else:
|
||||
raise Exception('Malformed PFM header.')
|
||||
|
||||
scale = float(file.readline().rstrip())
|
||||
if scale < 0: # little-endian
|
||||
endian = '<'
|
||||
scale = -scale
|
||||
else:
|
||||
endian = '>' # big-endian
|
||||
|
||||
data = np.fromfile(file, endian + 'f')
|
||||
shape = (height, width, 3) if color else (height, width)
|
||||
|
||||
data = np.reshape(data, shape)
|
||||
# DEY: I don't know why this was there.
|
||||
#data = np.flipud(data)
|
||||
file.close()
|
||||
|
||||
return data, scale
|
||||
|
||||
|
||||
def write_pfm(file, image, scale=1):
|
||||
""" Write a pfm file """
|
||||
file = open(file, 'wb')
|
||||
|
||||
color = None
|
||||
|
||||
if image.dtype.name != 'float32':
|
||||
raise Exception('Image dtype must be float32.')
|
||||
|
||||
image = np.flipud(image)
|
||||
|
||||
if len(image.shape) == 3 and image.shape[2] == 3: # color image
|
||||
color = True
|
||||
elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale
|
||||
color = False
|
||||
else:
|
||||
raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')
|
||||
|
||||
file.write('PF\n'.encode('utf-8') if color else 'Pf\n'.encode('utf-8'))
|
||||
temp_str = '%d %d\n' % (image.shape[1], image.shape[0])
|
||||
file.write(temp_str.encode('utf-8'))
|
||||
|
||||
endian = image.dtype.byteorder
|
||||
|
||||
if endian == '<' or endian == '=' and sys.byteorder == 'little':
|
||||
scale = -scale
|
||||
|
||||
temp_str = '%f\n' % scale
|
||||
file.write(temp_str.encode('utf-8'))
|
||||
|
||||
image.tofile(file)
|
||||
|
||||
|
||||
def write_png(filename, image):
|
||||
""" image must be numpy array H X W X channels
|
||||
"""
|
||||
import zlib, struct
|
||||
|
||||
buf = image.flatten().tobytes()
|
||||
width = image.shape[1]
|
||||
height = image.shape[0]
|
||||
|
||||
# reverse the vertical line order and add null bytes at the start
|
||||
width_byte_4 = width * 4
|
||||
raw_data = b''.join(b'\x00' + buf[span:span + width_byte_4]
|
||||
for span in range((height - 1) * width_byte_4, -1, - width_byte_4))
|
||||
|
||||
def png_pack(png_tag, data):
|
||||
chunk_head = png_tag + data
|
||||
return (struct.pack("!I", len(data)) +
|
||||
chunk_head +
|
||||
struct.pack("!I", 0xFFFFFFFF & zlib.crc32(chunk_head)))
|
||||
|
||||
png_bytes = b''.join([
|
||||
b'\x89PNG\r\n\x1a\n',
|
||||
png_pack(b'IHDR', struct.pack("!2I5B", width, height, 8, 6, 0, 0, 0)),
|
||||
png_pack(b'IDAT', zlib.compress(raw_data, 9)),
|
||||
png_pack(b'IEND', b'')])
|
||||
|
||||
write_file(filename, png_bytes)
|
||||
|
||||
def add_windows_ctrl_c():
|
||||
def handler(a,b=None): # pylint: disable=unused-argument
|
||||
sys.exit(1)
|
||||
add_windows_ctrl_c.is_handler_installed = \
|
||||
vars(add_windows_ctrl_c).setdefault('is_handler_installed',False)
|
||||
if sys.platform == "win32" and not add_windows_ctrl_c.is_handler_installed:
|
||||
if sys.stdin is not None and sys.stdin.isatty():
|
||||
#this is Console based application
|
||||
import win32api
|
||||
win32api.SetConsoleCtrlHandler(handler, True)
|
||||
#else do not install handler for non-console applications
|
||||
add_windows_ctrl_c.is_handler_installed = True
|
||||
|
||||
_utils_debug_verbosity=0
|
||||
_utils_start_time = time.time()
|
||||
def set_debug_verbosity(verbosity=0):
|
||||
global _utils_debug_verbosity # pylink: disable=global-statement
|
||||
_utils_debug_verbosity = verbosity
|
||||
def debug_log(msg, param=None, verbosity=3):
|
||||
global _utils_debug_verbosity # pylink: disable=global-statement
|
||||
if _utils_debug_verbosity is not None and _utils_debug_verbosity >= verbosity:
|
||||
print("[Debug][{}]: {} : {} : t={:.2f}".format(verbosity, msg, param, time.time()-_utils_start_time))
|
||||
|
||||
def get_uuid(is_hex = False):
|
||||
return str(uuid.uuid4()) if not is_hex else uuid.uuid4().hex
|
||||
|
||||
def is_uuid4(s, is_hex=False):
|
||||
try:
|
||||
val = uuid.UUID(s, version=4)
|
||||
return val.hex == s if is_hex else str(val) == s
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def frange(start, stop=None, step=None, steps=None):
|
||||
if stop is None:
|
||||
start, stop = 0, start
|
||||
if steps is None:
|
||||
if step is None:
|
||||
step = 1
|
||||
steps = int((stop-start)/step)
|
||||
else:
|
||||
if step is not None:
|
||||
raise ValueError("Both step and steps cannot be specified")
|
||||
step = (stop-start)/steps
|
||||
for _ in range(steps):
|
||||
yield start
|
||||
start += step
|
||||
|
||||
def wrap_string(s, chars_per_line=12):
|
||||
return "\n".join(textwrap.wrap(s, chars_per_line))
|
||||
|
||||
def is_eof(f):
|
||||
s = f.read(1)
|
||||
if s != b'': # restore position
|
||||
f.seek(-1, os.SEEK_CUR)
|
||||
return s == b''
|
|
@ -0,0 +1,174 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import sys, time, threading, queue, functools
|
||||
from typing import Any
|
||||
from types import MethodType
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
from .lv_types import StreamPlot, StreamItem
|
||||
from . import utils
|
||||
from .stream import Stream
|
||||
|
||||
from IPython import get_ipython, display
|
||||
import ipywidgets as widgets
|
||||
|
||||
class VisBase(Stream, metaclass=ABCMeta):
|
||||
def __init__(self, widget, cell:widgets.Box, title:str, show_legend:bool, stream_name:str=None, console_debug:bool=False, **vis_args):
|
||||
super(VisBase, self).__init__(stream_name=stream_name, console_debug=console_debug)
|
||||
|
||||
self.lock = threading.Lock()
|
||||
self._use_hbox = True
|
||||
utils.set_default(vis_args, 'cell_width', '100%')
|
||||
|
||||
self.widget = widget
|
||||
|
||||
self.cell = cell or widgets.HBox(layout=widgets.Layout(\
|
||||
width=vis_args['cell_width'])) if self._use_hbox else None
|
||||
if self._use_hbox:
|
||||
self.cell.children += (self.widget,)
|
||||
self._stream_vises = {}
|
||||
self.is_shown = cell is not None
|
||||
self.title = title
|
||||
self.last_ex = None
|
||||
self.layout_dirty = False
|
||||
self.q_last_processed = 0
|
||||
|
||||
def subscribe(self, stream:Stream, title=None, clear_after_end=False, clear_after_each=False,
|
||||
show:bool=False, history_len=1, dim_history=True, opacity=None, **stream_vis_args):
|
||||
# in this ovedrride we don't call base class method
|
||||
with self.lock:
|
||||
self.layout_dirty = True
|
||||
|
||||
stream_vis = StreamPlot(stream, title, clear_after_end,
|
||||
clear_after_each, history_len, dim_history, opacity,
|
||||
len(self._stream_vises), stream_vis_args, 0)
|
||||
stream_vis._clear_pending = False
|
||||
stream_vis._pending_items = queue.Queue()
|
||||
self._stream_vises[stream.stream_name] = stream_vis
|
||||
|
||||
self._post_add_subscription(stream_vis, **stream_vis_args)
|
||||
|
||||
super(VisBase, self).subscribe(stream)
|
||||
|
||||
if show or (show is None and not self.is_shown):
|
||||
return self.show()
|
||||
|
||||
def show(self, blocking:bool=False):
|
||||
self.is_shown = True
|
||||
if get_ipython():
|
||||
if self._use_hbox:
|
||||
display.display(self.cell) # this method doesn't need returns
|
||||
#return self.cell
|
||||
else:
|
||||
return self._show_widget_notebook()
|
||||
else:
|
||||
return self._show_widget_native(blocking)
|
||||
|
||||
def write(self, val:Any, from_stream:'Stream'=None):
|
||||
# let the base class know about new item, this will notify any subscribers
|
||||
super(VisBase, self).write(val)
|
||||
|
||||
stream_vis:StreamPlot = None
|
||||
if from_stream:
|
||||
stream_vis = self._stream_vises.get(from_stream.stream_name, None)
|
||||
|
||||
if not stream_vis: # select the first one we have
|
||||
stream_vis = next(iter(self._stream_vises.values()))
|
||||
|
||||
VisBase.write_stream_plot(self, stream_vis, val)
|
||||
|
||||
@staticmethod
|
||||
def write_stream_plot(vis, stream_vis:StreamPlot, stream_item:StreamItem):
|
||||
with vis.lock: # this could be from separate thread!
|
||||
#if stream_vis is None:
|
||||
# utils.debug_log('stream_vis not specified in VisBase.write')
|
||||
# stream_vis = next(iter(vis._stream_vises.values())) # use first as default
|
||||
utils.debug_log("Stream received: {}".format(stream_item.stream_name), verbosity=5)
|
||||
stream_vis._pending_items.put(stream_item)
|
||||
|
||||
# if we accumulated enough of pending items then let's process them
|
||||
if vis._can_update_stream_plots():
|
||||
vis._update_stream_plots()
|
||||
|
||||
def _extract_results(self, stream_vis):
|
||||
stream_items, clear_current, clear_history = [], False, False
|
||||
while not stream_vis._pending_items.empty():
|
||||
stream_item = stream_vis._pending_items.get()
|
||||
if stream_item.stream_reset:
|
||||
utils.debug_log("Stream reset", stream_item.stream_name)
|
||||
stream_items.clear() # no need to process these events
|
||||
clear_current, clear_history = True, True
|
||||
else:
|
||||
# check if there was an exception
|
||||
if stream_item.exception is not None:
|
||||
#TODO: need better handling here?
|
||||
print(stream_item.exception, file=sys.stderr)
|
||||
raise stream_item.exception
|
||||
|
||||
# state management for _clear_pending
|
||||
# if we need to clear plot before putting in data, do so
|
||||
if stream_vis._clear_pending:
|
||||
stream_items.clear()
|
||||
clear_current = True
|
||||
stream_vis._clear_pending = False
|
||||
if stream_vis.clear_after_each or (stream_item.ended and stream_vis.clear_after_end):
|
||||
stream_vis._clear_pending = True
|
||||
|
||||
stream_items.append(stream_item)
|
||||
|
||||
return stream_items, clear_current, clear_history
|
||||
|
||||
def _extract_vals(self, stream_items):
|
||||
vals = []
|
||||
for stream_item in stream_items:
|
||||
if stream_item.ended or stream_item.value is None:
|
||||
pass # no values to add
|
||||
else:
|
||||
if utils.is_array_like(stream_item.value, tuple_is_array=False):
|
||||
vals.extend(stream_item.value)
|
||||
else:
|
||||
vals.append(stream_item.value)
|
||||
return vals
|
||||
|
||||
@abstractmethod
|
||||
def clear_plot(self, stream_vis, clear_history):
|
||||
"""(for derived class) Clears the data in specified plot before new data is redrawn"""
|
||||
pass
|
||||
@abstractmethod
|
||||
def _show_stream_items(self, stream_vis, stream_items):
|
||||
"""(for derived class) Plot the data in given axes"""
|
||||
pass
|
||||
@abstractmethod
|
||||
def _post_add_subscription(self, stream_vis, **stream_vis_args):
|
||||
pass
|
||||
|
||||
# typically we want to batch up items for performance
|
||||
def _can_update_stream_plots(self):
|
||||
return True
|
||||
|
||||
@abstractmethod
|
||||
def _post_update_stream_plot(self, stream_vis):
|
||||
pass
|
||||
|
||||
def _update_stream_plots(self):
|
||||
with self.lock:
|
||||
self.q_last_processed = time.time()
|
||||
for stream_vis in self._stream_vises.values():
|
||||
stream_items, clear_current, clear_history = self._extract_results(stream_vis)
|
||||
|
||||
if clear_current:
|
||||
self.clear_plot(stream_vis, clear_history)
|
||||
|
||||
# if we have something to render
|
||||
dirty = self._show_stream_items(stream_vis, stream_items)
|
||||
if dirty:
|
||||
self._post_update_stream_plot(stream_vis)
|
||||
stream_vis.last_update = time.time()
|
||||
|
||||
@abstractmethod
|
||||
def _show_widget_native(self, blocking:bool):
|
||||
pass
|
||||
@abstractmethod
|
||||
def _show_widget_notebook(self):
|
||||
pass
|
|
@ -0,0 +1,72 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .stream import Stream
|
||||
from .vis_base import VisBase
|
||||
import ipywidgets as widgets
|
||||
|
||||
class Visualizer:
|
||||
def __init__(self, stream:Stream, vis_type:str=None, host:'Visualizer'=None,
|
||||
cell:'Visualizer'=None, title:str=None,
|
||||
clear_after_end=False, clear_after_each=False, history_len=1, dim_history=True, opacity=None,
|
||||
|
||||
rows=2, cols=5, img_width=None, img_height=None, img_channels=None,
|
||||
colormap=None, viz_img_scale=None,
|
||||
|
||||
# these image params are for hover on point for t-sne
|
||||
hover_images=None, hover_image_reshape=None, cell_width:str=None, cell_height:str=None,
|
||||
|
||||
only_summary=False, separate_yaxis=True, xtitle=None, ytitle=None, ztitle=None, color=None,
|
||||
xrange=None, yrange=None, zrange=None, draw_line=True, draw_marker=False,
|
||||
|
||||
vis_args={}, stream_vis_args={})->None:
|
||||
|
||||
cell = cell._host_base.cell if cell is not None else None
|
||||
|
||||
if host:
|
||||
self._host_base = host._host_base
|
||||
else:
|
||||
self._host_base = self._get_vis_base(vis_type, cell, title, hover_images=hover_images, hover_image_reshape=hover_image_reshape,
|
||||
cell_width=cell_width, cell_height=cell_height,
|
||||
**vis_args)
|
||||
|
||||
self._host_base.subscribe(stream, show=False, clear_after_end=clear_after_end, clear_after_each=clear_after_each,
|
||||
history_len=history_len, dim_history=dim_history, opacity=opacity,
|
||||
only_summary=only_summary if 'summary' != vis_type else True,
|
||||
separate_yaxis=separate_yaxis, xtitle=xtitle, ytitle=ytitle, ztitle=ztitle, color=color,
|
||||
xrange=xrange, yrange=yrange, zrange=zrange,
|
||||
draw_line=draw_line if vis_type is not None and 'scatter' in vis_type else True,
|
||||
draw_marker=draw_marker,
|
||||
rows=rows, cols=cols, img_width=img_width, img_height=img_height, img_channels=img_channels,
|
||||
colormap=colormap, viz_img_scale=viz_img_scale,
|
||||
**stream_vis_args)
|
||||
|
||||
stream.load()
|
||||
|
||||
def show(self):
|
||||
return self._host_base.show()
|
||||
|
||||
def _get_vis_base(self, vis_type, cell:widgets.Box, title, hover_images=None, hover_image_reshape=None, cell_width=None, cell_height=None, **vis_args)->VisBase:
|
||||
if vis_type is None:
|
||||
from .text_vis import TextVis
|
||||
return TextVis(cell=cell, title=title, **vis_args)
|
||||
if vis_type in ['text', 'summary']:
|
||||
from .text_vis import TextVis
|
||||
return TextVis(cell=cell, title=title, **vis_args)
|
||||
if vis_type in ['plotly-line', 'scatter', 'plotly-scatter',
|
||||
'line3d', 'scatter3d', 'mesh3d']:
|
||||
from . import plotly
|
||||
return plotly.LinePlot(cell=cell, title=title,
|
||||
is_3d=vis_type in ['line3d', 'scatter3d', 'mesh3d'], **vis_args)
|
||||
if vis_type in ['image', 'mpl-image']:
|
||||
from . import mpl
|
||||
return mpl.ImagePlot(cell=cell, title=title, cell_width=cell_width, cell_height=cell_height, **vis_args)
|
||||
if vis_type in ['line', 'mpl-line', 'mpl-scatter']:
|
||||
from . import mpl
|
||||
return mpl.LinePlot(cell=cell, title=title, **vis_args)
|
||||
if vis_type in ['tsne', 'embeddings', 'tsne2d', 'embeddings2d']:
|
||||
from . import plotly
|
||||
return plotly.EmbeddingsPlot(cell=cell, title=title, is_3d='2d' not in vis_type,
|
||||
hover_images=hover_images, hover_image_reshape=hover_image_reshape, **vis_args)
|
||||
else:
|
||||
raise ValueError('Render vis_type parameter has invalid value: "{}"'.format(vis_type))
|
|
@ -0,0 +1,92 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import uuid
|
||||
from typing import Sequence
|
||||
from .zmq_wrapper import ZmqWrapper
|
||||
from .watcher_base import WatcherBase
|
||||
from .lv_types import CliSrvReqTypes
|
||||
from .lv_types import DefaultPorts, PublisherTopics, ServerMgmtMsg
|
||||
from . import utils
|
||||
import threading, time
|
||||
|
||||
class Watcher(WatcherBase):
|
||||
def __init__(self, filename:str=None, port:int=0, srv_name:str=None):
|
||||
super(Watcher, self).__init__()
|
||||
|
||||
self.port = port
|
||||
self.filename = filename
|
||||
|
||||
# used to detect server restarts
|
||||
self.srv_name = srv_name or str(uuid.uuid4())
|
||||
|
||||
# define vars in __init__
|
||||
self._clisrv = None
|
||||
self._zmq_stream_pub = None
|
||||
self._file = None
|
||||
self._th = None
|
||||
|
||||
self._open_devices()
|
||||
|
||||
def _open_devices(self):
|
||||
if self.port is not None:
|
||||
self._clisrv = ZmqWrapper.ClientServer(port=DefaultPorts.CliSrv+self.port,
|
||||
is_server=True, callback=self._clisrv_callback)
|
||||
|
||||
# notify existing listeners of our ID
|
||||
self._zmq_stream_pub = self._stream_factory.get_streams(stream_types=['tcp:'+str(self.port)], for_write=True)[0]
|
||||
|
||||
# ZMQ quirk: we must wait a bit after opening port and before sending message
|
||||
# TODO: can we do better?
|
||||
self._th = threading.Thread(target=self._send_server_start)
|
||||
self._th.start()
|
||||
if self.filename is not None:
|
||||
self._file = self._stream_factory.get_streams(stream_types=['file:'+self.filename], for_write=True)[0]
|
||||
|
||||
def _send_server_start(self):
|
||||
time.sleep(2)
|
||||
self._zmq_stream_pub.write(ServerMgmtMsg(event_name=ServerMgmtMsg.EventServerStart,
|
||||
event_args=self.srv_name), topic=PublisherTopics.ServerMgmt)
|
||||
|
||||
def default_devices(self)->Sequence[str]: # overriden
|
||||
devices = []
|
||||
if self.port is not None:
|
||||
devices.append('tcp:' + str(self.port))
|
||||
if self.filename is not None:
|
||||
devices.append('file:' + self.filename)
|
||||
return devices
|
||||
|
||||
def close(self):
|
||||
if not self.closed:
|
||||
if self._clisrv is not None:
|
||||
self._clisrv.close()
|
||||
if self._zmq_stream_pub is not None:
|
||||
self._zmq_stream_pub.close()
|
||||
if self._file is not None:
|
||||
self._file.close()
|
||||
utils.debug_log("Watcher is closed", verbosity=1)
|
||||
super(Watcher, self).close()
|
||||
|
||||
def _reset(self):
|
||||
self._clisrv = None
|
||||
self._zmq_stream_pub = None
|
||||
self._file = None
|
||||
self._th = None
|
||||
utils.debug_log("Watcher reset", verbosity=1)
|
||||
super(Watcher, self)._reset()
|
||||
|
||||
def _clisrv_callback(self, clisrv, clisrv_req): # pylint: disable=unused-argument
|
||||
utils.debug_log("Received client request", clisrv_req.req_type)
|
||||
|
||||
# request = create stream
|
||||
if clisrv_req.req_type == CliSrvReqTypes.create_stream:
|
||||
stream_req = clisrv_req.req_data
|
||||
self.create_stream(stream_name=stream_req.stream_name, devices=stream_req.devices,
|
||||
event_name=stream_req.event_name, expr=stream_req.expr, throttle=stream_req.throttle,
|
||||
vis_params=stream_req.vis_params)
|
||||
return None # ignore return as we can't send back stream obj
|
||||
elif clisrv_req.req_type == CliSrvReqTypes.del_stream:
|
||||
stream_name = clisrv_req.req_data
|
||||
return self.del_stream(stream_name)
|
||||
else:
|
||||
raise ValueError('ClientServer Request Type {} is not recognized'.format(clisrv_req))
|
|
@ -0,0 +1,218 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import Dict, Any, Sequence
|
||||
from .lv_types import EventVars, StreamItem, StreamCreateRequest, VisParams
|
||||
from .evaler import Evaler
|
||||
from .stream import Stream
|
||||
from .stream_factory import StreamFactory
|
||||
from .filtered_stream import FilteredStream
|
||||
import uuid
|
||||
import time
|
||||
from . import utils
|
||||
|
||||
|
||||
class WatcherBase:
|
||||
class StreamInfo:
|
||||
def __init__(self, req:StreamCreateRequest, evaler:Evaler, stream:Stream,
|
||||
index:int, disabled=False, last_sent:float=None)->None:
|
||||
r"""Holds togaher stream_req, stream and evaler
|
||||
"""
|
||||
self.req, self.evaler, self.stream = req, evaler, stream
|
||||
self.index, self.disabled, self.last_sent = index, disabled, last_sent
|
||||
self.item_count = 0 # creator of StreamItem needs to set to set item num
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.closed = None
|
||||
self._reset()
|
||||
|
||||
def _reset(self):
|
||||
# for each event, store (stream_name, stream_info)
|
||||
self._stream_infos:Dict[str, Dict[str, WatcherBase.StreamInfo]] = {}
|
||||
|
||||
self._global_vars:Dict[str, Any] = {}
|
||||
self._stream_count = 0
|
||||
|
||||
# factory streams are shared per watcher instance
|
||||
self._stream_factory = StreamFactory()
|
||||
|
||||
# each StreamItem should be stamped by its creator
|
||||
self.creator_id = str(uuid.uuid4())
|
||||
self.closed = False
|
||||
|
||||
def close(self):
|
||||
if not self.closed:
|
||||
# close all the streams
|
||||
for stream_infos in self._stream_infos.values(): # per event
|
||||
for stream_info in stream_infos.values():
|
||||
stream_info.stream.close()
|
||||
self._stream_factory.close()
|
||||
self._reset() # clean variables
|
||||
self.closed = True
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
def __exit__(self, exception_type, exception_value, traceback):
|
||||
self.close()
|
||||
|
||||
def default_devices(self)->Sequence[str]:
|
||||
return None
|
||||
|
||||
def open_stream(self, stream_name:str=None, devices:Sequence[str]=None,
|
||||
event_name:str='')->Stream:
|
||||
r"""Opens stream from specified devices or returns one by name if
|
||||
it was created before.
|
||||
"""
|
||||
# TODO: what if devices were specified AND stream exist in cache?
|
||||
|
||||
# create devices is any
|
||||
devices = devices or self.default_devices()
|
||||
device_streams = None
|
||||
if devices is not None:
|
||||
# we open devices in read-only mode
|
||||
device_streams = self._stream_factory.get_streams(stream_types=devices,
|
||||
for_write=False)
|
||||
# if no devices then open stream by name from cache
|
||||
if device_streams is None:
|
||||
if stream_name is None:
|
||||
raise ValueError('Both device and stream_name cannot be None')
|
||||
|
||||
# first search by event
|
||||
stream_infos = self._stream_infos.get(event_name, None)
|
||||
if stream_infos is None:
|
||||
raise ValueError('Requested event was not found: ' + event_name)
|
||||
# then search by stream name
|
||||
stream_info = stream_infos.get(stream_name, None)
|
||||
if stream_info is None:
|
||||
raise ValueError('Requested stream was not found: ' + stream_name)
|
||||
return stream_info.stream
|
||||
|
||||
# if we have device, first create stream and then attach device to it
|
||||
stream = Stream(stream_name=stream_name)
|
||||
for device_stream in device_streams:
|
||||
# each device may have multiple streams so let's filter it
|
||||
filtered_stream = FilteredStream(source_stream=device_stream,
|
||||
filter_expr=((lambda steam_item: (steam_item, steam_item.stream_name == stream_name)) \
|
||||
if stream_name is not None \
|
||||
else None))
|
||||
stream.subscribe(filtered_stream)
|
||||
stream.held_refs.add(filtered_stream) # otherwise filtered stream will be destroyed by gc
|
||||
return stream
|
||||
|
||||
def create_stream(self, stream_name:str=None, devices:Sequence[str]=None, event_name:str='',
|
||||
expr=None, throttle:float=None, vis_params:VisParams=None)->Stream:
|
||||
|
||||
r"""Create stream with or without expression and attach to devices where
|
||||
it will be written to.
|
||||
"""
|
||||
|
||||
stream_name = stream_name or str(uuid.uuid4())
|
||||
|
||||
# we allow few shortcuts, so modify expression if needed
|
||||
expr = expr
|
||||
if expr=='' or expr=='x':
|
||||
expr = 'map(lambda x:x, l)'
|
||||
elif expr.strip().startswith('lambda '):
|
||||
expr = 'map({}, l)'.format(expr)
|
||||
# else no rewrites
|
||||
|
||||
# if no expression specified then we don't create evaler
|
||||
evaler = Evaler(expr) if expr is not None else None
|
||||
|
||||
# get stream infos for this event
|
||||
stream_infos = self._stream_infos.get(event_name, None)
|
||||
# if first for this event, create dictionary
|
||||
if stream_infos is None:
|
||||
stream_infos = self._stream_infos[event_name] = {}
|
||||
|
||||
stream_info = stream_infos.get(stream_name, None)
|
||||
if not stream_info:
|
||||
utils.debug_log("Creating stream", stream_name)
|
||||
stream = Stream(stream_name=stream_name)
|
||||
devices = devices or self.default_devices()
|
||||
if devices is not None:
|
||||
# attached devices are opened in write-only mode
|
||||
device_streams = self._stream_factory.get_streams(stream_types=devices,
|
||||
for_write=True)
|
||||
for device_stream in device_streams:
|
||||
device_stream.subscribe(stream)
|
||||
stream_req = StreamCreateRequest(stream_name=stream_name, devices=devices, event_name=event_name,
|
||||
expr=expr, throttle=throttle, vis_params=vis_params)
|
||||
stream_info = stream_infos[stream_name] = WatcherBase.StreamInfo(
|
||||
stream_req, evaler, stream, self._stream_count)
|
||||
self._stream_count += 1
|
||||
else:
|
||||
# TODO: throw error?
|
||||
utils.debug_log("Stream already exist, not creating again", stream_name)
|
||||
|
||||
return stream_info.stream
|
||||
|
||||
def set_globals(self, **global_vars):
|
||||
self._global_vars.update(global_vars)
|
||||
|
||||
def observe(self, event_name:str='', **obs_vars) -> None:
|
||||
# get stream requests for this event
|
||||
stream_infos = self._stream_infos.get(event_name, {})
|
||||
|
||||
# TODO: remove list() call - currently needed because of error dictionary
|
||||
# can't be changed - happens when multiple clients gets started
|
||||
for stream_info in list(stream_infos.values()):
|
||||
if stream_info.disabled or stream_info.evaler is None:
|
||||
continue
|
||||
|
||||
# apply throttle
|
||||
if stream_info.req.throttle is None or stream_info.last_sent is None or \
|
||||
time.time() - stream_info.last_sent >= stream_info.req.throttle:
|
||||
stream_info.last_sent = time.time()
|
||||
|
||||
events_vars = EventVars(self._global_vars, **obs_vars)
|
||||
self._eval_wrie(stream_info, events_vars)
|
||||
else:
|
||||
utils.debug_log("Throttled", event_name, verbosity=5)
|
||||
|
||||
def _eval_wrie(self, stream_info:'WatcherBase.StreamInfo', event_vars:EventVars):
|
||||
eval_return = stream_info.evaler.post(event_vars)
|
||||
if eval_return.is_valid:
|
||||
event_name = stream_info.req.event_name
|
||||
stream_item = StreamItem(stream_info.item_count,
|
||||
eval_return.result, stream_info.req.stream_name, self.creator_id, stream_info.index,
|
||||
exception=eval_return.exception)
|
||||
stream_info.stream.write(stream_item)
|
||||
stream_info.item_count += 1
|
||||
utils.debug_log("eval_return sent", event_name, verbosity=5)
|
||||
else:
|
||||
utils.debug_log("Invalid eval_return not sent", verbosity=5)
|
||||
|
||||
def end_event(self, event_name:str='', disable_streams=False) -> None:
|
||||
stream_infos = self._stream_infos.get(event_name, {})
|
||||
for stream_info in stream_infos.values():
|
||||
if not stream_info.disabled:
|
||||
self._end_stream_req(stream_info, disable_streams)
|
||||
|
||||
def _end_stream_req(self, stream_info:'WatcherBase.StreamInfo', disable_stream:bool):
|
||||
eval_return = stream_info.evaler.post(ended=True,
|
||||
continue_thread=not disable_stream)
|
||||
# TODO: check eval_return.is_valid ?
|
||||
# event_name = stream_info.req.event_name
|
||||
if disable_stream:
|
||||
stream_info.disabled = True
|
||||
utils.debug_log("{} stream disabled".format(stream_info.req.stream_name), verbosity=1)
|
||||
|
||||
stream_item = StreamItem(item_index=stream_info.item_count,
|
||||
value=eval_return.result, stream_name=stream_info.req.stream_name,
|
||||
creator_id=self.creator_id, stream_index=stream_info.index,
|
||||
exception=eval_return.exception, ended=True)
|
||||
stream_info.stream.write(stream_item)
|
||||
stream_info.item_count += 1
|
||||
|
||||
def del_stream(self, stream_name:str) -> None:
|
||||
utils.debug_log("deleting stream", stream_name)
|
||||
for stream_infos in self._stream_infos.values(): # per event
|
||||
stream_info = stream_infos.get(stream_name, None)
|
||||
if stream_info:
|
||||
stream_info.disabled = True
|
||||
stream_info.evaler.abort()
|
||||
return True
|
||||
#TODO: to enable delete we need to protect iteration in set_vars
|
||||
#del stream_reqs[stream_info.req.stream_name]
|
||||
return False
|
|
@ -0,0 +1,87 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import Any, Dict, Sequence, List
|
||||
from .zmq_wrapper import ZmqWrapper
|
||||
from .lv_types import CliSrvReqTypes, ClientServerRequest, DefaultPorts
|
||||
from .lv_types import VisParams, PublisherTopics, ServerMgmtMsg, StreamCreateRequest
|
||||
from .stream import Stream
|
||||
from .zmq_mgmt_stream import ZmqMgmtStream
|
||||
from . import utils
|
||||
from .watcher_base import WatcherBase
|
||||
|
||||
class WatcherClient(WatcherBase):
|
||||
r"""Extends watcher to add methods so calls for create and delete stream can be sent to server.
|
||||
"""
|
||||
def __init__(self, filename:str=None, port:int=0):
|
||||
super(WatcherClient, self).__init__()
|
||||
self.port = port
|
||||
self.filename = filename
|
||||
|
||||
# define vars in __init__
|
||||
self._clisrv = None # client-server sockets allows to send create/del stream requests
|
||||
self._zmq_srvmgmt_sub = None
|
||||
self._file = None
|
||||
|
||||
self._open()
|
||||
|
||||
def _reset(self):
|
||||
self._clisrv = None
|
||||
self._zmq_srvmgmt_sub = None
|
||||
self._file = None
|
||||
utils.debug_log("WatcherClient reset", verbosity=1)
|
||||
super(WatcherClient, self)._reset()
|
||||
|
||||
def _open(self):
|
||||
if self.port is not None:
|
||||
self._clisrv = ZmqWrapper.ClientServer(port=DefaultPorts.CliSrv+self.port,
|
||||
is_server=False)
|
||||
# create subscription where we will receive server management events
|
||||
self._zmq_srvmgmt_sub = ZmqMgmtStream(clisrv=self._clisrv, for_write=False, port=self.port,
|
||||
stream_name='zmq_srvmgmt_sub:'+str(self.port)+':False')
|
||||
if self.filename is not None:
|
||||
self._file = self._stream_factory.get_streams(stream_types=['file:'+self.filename], for_write=False)[0]
|
||||
|
||||
def close(self):
|
||||
if not self.closed:
|
||||
self._zmq_srvmgmt_sub.close()
|
||||
self._clisrv.close()
|
||||
utils.debug_log("WatcherClient is closed", verbosity=1)
|
||||
super(WatcherClient, self).close()
|
||||
|
||||
def default_devices(self)->Sequence[str]: # overriden
|
||||
devices = []
|
||||
if self.port is not None:
|
||||
devices.append('tcp:' + str(self.port))
|
||||
if self.filename is not None:
|
||||
devices.append('file:' + self.filename)
|
||||
return devices
|
||||
|
||||
# override to send request to server, instead of underlying WatcherBase base class
|
||||
def create_stream(self, stream_name:str=None, devices:Sequence[str]=None, event_name:str='',
|
||||
expr=None, throttle:float=1, vis_params:VisParams=None)->Stream: # overriden
|
||||
|
||||
stream_req = StreamCreateRequest(stream_name=stream_name, devices=devices or self.default_devices(),
|
||||
event_name=event_name, expr=expr, throttle=throttle, vis_params=vis_params)
|
||||
|
||||
self._zmq_srvmgmt_sub.add_stream_req(stream_req)
|
||||
|
||||
if stream_req.devices is not None:
|
||||
stream = self.open_stream(stream_name=stream_req.stream_name,
|
||||
devices=stream_req.devices, event_name=stream_req.event_name)
|
||||
else: # we cannot return remote streams that are not backed by a device
|
||||
stream = None
|
||||
return stream
|
||||
|
||||
# override to set devices default to tcp
|
||||
def open_stream(self, stream_name:str=None, devices:Sequence[str]=None,
|
||||
event_name:str='')->Stream: # overriden
|
||||
|
||||
return super(WatcherClient, self).open_stream(stream_name=stream_name, devices=devices,
|
||||
event_name=event_name)
|
||||
|
||||
|
||||
# override to send request to server
|
||||
def del_stream(self, stream_name:str) -> None:
|
||||
self._zmq_srvmgmt_sub.del_stream(stream_req)
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import Any
|
||||
from .zmq_stream import ZmqStream
|
||||
from .lv_types import PublisherTopics, ServerMgmtMsg, StreamCreateRequest
|
||||
from .zmq_wrapper import ZmqWrapper
|
||||
from .lv_types import CliSrvReqTypes, ClientServerRequest
|
||||
from . import utils
|
||||
|
||||
class ZmqMgmtStream(ZmqStream):
|
||||
# default topic is mgmt
|
||||
def __init__(self, clisrv:ZmqWrapper.ClientServer, for_write:bool, port:int=0, topic=PublisherTopics.ServerMgmt, block_until_connected=True,
|
||||
stream_name:str=None, console_debug:bool=False):
|
||||
super(ZmqMgmtStream, self).__init__(for_write=for_write, port=port, topic=topic,
|
||||
block_until_connected=block_until_connected, stream_name=stream_name, console_debug=console_debug)
|
||||
|
||||
self._clisrv = clisrv
|
||||
self._stream_reqs:Dict[str,StreamCreateRequest] = {}
|
||||
|
||||
def write(self, mgmt_msg:Any, from_stream:'Stream'=None):
|
||||
r"""Handles server management events.
|
||||
"""
|
||||
utils.debug_log("Received - SeverMgmtevent", mgmt_msg)
|
||||
# if server was restarted then send create stream requests again
|
||||
if mgmt_msg.event_name == ServerMgmtMsg.EventServerStart:
|
||||
for stream_req in self._stream_reqs.values():
|
||||
self._send_create_stream(stream_req)
|
||||
|
||||
super(ZmqMgmtStream, self).write(mgmt_msg)
|
||||
|
||||
def add_stream_req(self, stream_req:StreamCreateRequest)->None:
|
||||
self._send_create_stream(stream_req)
|
||||
|
||||
# save this for later for resend if server restarts
|
||||
self._stream_reqs[stream_req.stream_name] = stream_req
|
||||
|
||||
# override to send request to server
|
||||
def del_stream(self, stream_name:str) -> None:
|
||||
clisrv_req = ClientServerRequest(CliSrvReqTypes.del_stream, stream_name)
|
||||
self._clisrv.send_obj(clisrv_req)
|
||||
self._stream_reqs.pop(stream_name, None)
|
||||
|
||||
def _send_create_stream(self, stream_req):
|
||||
utils.debug_log("sending create streamreq...")
|
||||
clisrv_req = ClientServerRequest(CliSrvReqTypes.create_stream, stream_req)
|
||||
self._clisrv.send_obj(clisrv_req)
|
||||
utils.debug_log("sent create streamreq")
|
||||
|
||||
def close(self):
|
||||
if not self.closed:
|
||||
self._stream_reqs = {}
|
||||
self._clisrv = None
|
||||
utils.debug_log('ZmqMgmtStream is closed', verbosity=1)
|
||||
super(ZmqMgmtStream, self).close()
|
|
@ -0,0 +1,49 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import Any
|
||||
from .zmq_wrapper import ZmqWrapper
|
||||
from .stream import Stream
|
||||
from .lv_types import DefaultPorts, PublisherTopics
|
||||
from . import utils
|
||||
|
||||
# on writes send data on ZMQ transport
|
||||
class ZmqStream(Stream):
|
||||
def __init__(self, for_write:bool, port:int=0, topic=PublisherTopics.StreamItem, block_until_connected=True,
|
||||
stream_name:str=None, console_debug:bool=False):
|
||||
super(ZmqStream, self).__init__(stream_name=stream_name, console_debug=console_debug)
|
||||
|
||||
self.for_write = for_write
|
||||
self._zmq = None
|
||||
|
||||
self.topic = topic
|
||||
self._open(for_write, port, block_until_connected)
|
||||
utils.debug_log('ZmqStream started', verbosity=1)
|
||||
|
||||
def _open(self, for_write:bool, port:int, block_until_connected:bool):
|
||||
if for_write:
|
||||
self._zmq = ZmqWrapper.Publication(port=DefaultPorts.PubSub+port,
|
||||
block_until_connected=block_until_connected)
|
||||
else:
|
||||
self._zmq = ZmqWrapper.Subscription(port=DefaultPorts.PubSub+port,
|
||||
topic=self.topic, callback=self._on_subscription_item)
|
||||
|
||||
def close(self):
|
||||
if not self.closed:
|
||||
self._zmq.close()
|
||||
self._zmq = None
|
||||
utils.debug_log('ZmqStream is closed', verbosity=1)
|
||||
super(ZmqStream, self).close()
|
||||
|
||||
def _on_subscription_item(self, val:Any):
|
||||
utils.debug_log('Received subscription item', verbosity=5)
|
||||
self.write(val)
|
||||
|
||||
def write(self, val:Any, from_stream:'Stream'=None, topic=None):
|
||||
super(ZmqStream, self).write(val)
|
||||
if self.for_write:
|
||||
topic = topic or self.topic
|
||||
utils.debug_log('Sent subscription item', verbosity=5)
|
||||
self._zmq.send_obj(val, topic)
|
||||
# else if this was opened for read then we have subscription and
|
||||
# we shouldn't be calling send_obj
|
|
@ -0,0 +1,290 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import zmq
|
||||
import errno
|
||||
import pickle
|
||||
from zmq.eventloop import ioloop, zmqstream
|
||||
import zmq.utils.monitor
|
||||
import functools, sys, logging, traceback
|
||||
from threading import Thread, Event
|
||||
from . import utils
|
||||
import weakref
|
||||
|
||||
class ZmqWrapper:
|
||||
|
||||
_thread:Thread = None
|
||||
_ioloop:ioloop.IOLoop = None
|
||||
_start_event:Event = None
|
||||
_ioloop_block:Event = None # indicates if there is any blocking IOLoop call in progress
|
||||
|
||||
@staticmethod
|
||||
def initialize():
|
||||
# create thread that will wait on IO Loop
|
||||
if ZmqWrapper._thread is None:
|
||||
ZmqWrapper._thread = Thread(target=ZmqWrapper._run_io_loop, name='ZMQIOLoop', daemon=True)
|
||||
ZmqWrapper._start_event = Event()
|
||||
ZmqWrapper._ioloop_block = Event()
|
||||
ZmqWrapper._ioloop_block.set() # no blocking call in progress right now
|
||||
ZmqWrapper._thread.start()
|
||||
# this is needed to make sure IO Loop has enough time to start
|
||||
ZmqWrapper._start_event.wait()
|
||||
|
||||
@staticmethod
|
||||
def close():
|
||||
# terminate the IO Loop
|
||||
if ZmqWrapper._thread is not None:
|
||||
ZmqWrapper._ioloop_block.set() # free any blocking call
|
||||
ZmqWrapper._ioloop.add_callback(ZmqWrapper._ioloop.stop)
|
||||
ZmqWrapper._thread = None
|
||||
ZmqWrapper._ioloop = None
|
||||
print("ZMQ IOLoop is now closed")
|
||||
|
||||
@staticmethod
|
||||
def get_timer(secs, callback, start=True):
|
||||
utils.debug_log("Adding PeriodicCallback", secs)
|
||||
pc = ioloop.PeriodicCallback(callback, secs * 1e3)
|
||||
if (start):
|
||||
pc.start()
|
||||
return pc
|
||||
|
||||
@staticmethod
|
||||
def _run_io_loop():
|
||||
if 'asyncio' in sys.modules:
|
||||
# tornado may be using asyncio,
|
||||
# ensure an eventloop exists for this thread
|
||||
import asyncio
|
||||
asyncio.set_event_loop(asyncio.new_event_loop())
|
||||
|
||||
ZmqWrapper._ioloop = ioloop.IOLoop()
|
||||
ZmqWrapper._ioloop.make_current()
|
||||
while ZmqWrapper._thread is not None:
|
||||
try:
|
||||
ZmqWrapper._start_event.set()
|
||||
utils.debug_log("starting ioloop...")
|
||||
ZmqWrapper._ioloop.start()
|
||||
except zmq.ZMQError as ex:
|
||||
if ex.errno == errno.EINTR:
|
||||
print("Cannot start IOLoop! ZMQError: {}".format(ex), file=sys.stderr)
|
||||
continue
|
||||
else:
|
||||
raise
|
||||
|
||||
# Utility method to run given function on IOLoop
|
||||
# this is blocking method if has_rresult=True
|
||||
# This can be called from any thread.
|
||||
@staticmethod
|
||||
def _io_loop_call(has_result, f, *kargs, **kwargs):
|
||||
class Result:
|
||||
def __init__(self, val=None):
|
||||
self.val = val
|
||||
|
||||
def wrapper(f, r, *kargs, **kwargs):
|
||||
try:
|
||||
r.val = f(*kargs, **kwargs)
|
||||
ZmqWrapper._ioloop_block.set()
|
||||
except Exception as ex:
|
||||
print(ex)
|
||||
logging.fatal(ex, exc_info=True)
|
||||
traceback.print_exc(file=sys.stdout)
|
||||
|
||||
# We will add callback in IO Loop and then wait for that
|
||||
# call back to be completed
|
||||
# If result is expected then we wait other wise fire and forget
|
||||
if has_result:
|
||||
if not ZmqWrapper._ioloop_block.is_set():
|
||||
# TODO: better way to raise this error?
|
||||
print('Previous blocking call on IOLoop is not yet complete!')
|
||||
ZmqWrapper._ioloop_block.clear()
|
||||
r = Result()
|
||||
f_wrapped = functools.partial(wrapper, f, r, *kargs, **kwargs)
|
||||
ZmqWrapper._ioloop.add_callback(f_wrapped)
|
||||
utils.debug_log("Waiting for call on ioloop", f, verbosity=5)
|
||||
ZmqWrapper._ioloop_block.wait()
|
||||
utils.debug_log("Call on ioloop done", f, verbosity=5)
|
||||
return r.val
|
||||
else:
|
||||
f_wrapped = functools.partial(f, *kargs, **kwargs)
|
||||
ZmqWrapper._ioloop.add_callback(f_wrapped)
|
||||
|
||||
class Publication:
|
||||
def __init__(self, port, host="*", block_until_connected=True):
|
||||
# define vars
|
||||
self._socket = None
|
||||
self._mon_socket = None
|
||||
self._mon_stream = None
|
||||
|
||||
ZmqWrapper.initialize()
|
||||
utils.debug_log('Creating Publication', port, verbosity=1)
|
||||
# make sure the call blocks until connection is made
|
||||
ZmqWrapper._io_loop_call(block_until_connected, self._start_srv, port, host)
|
||||
|
||||
def _start_srv(self, port, host):
|
||||
context = zmq.Context()
|
||||
self._socket = context.socket(zmq.PUB)
|
||||
utils.debug_log('Binding socket', (host, port), verbosity=5)
|
||||
self._socket.bind("tcp://%s:%d" % (host, port))
|
||||
utils.debug_log('Bound socket', (host, port), verbosity=5)
|
||||
self._mon_socket = self._socket.get_monitor_socket(zmq.EVENT_CONNECTED | zmq.EVENT_DISCONNECTED)
|
||||
self._mon_stream = zmqstream.ZMQStream(self._mon_socket)
|
||||
self._mon_stream.on_recv(self._on_mon)
|
||||
|
||||
def close(self):
|
||||
if self._socket:
|
||||
ZmqWrapper._io_loop_call(False, self._socket.close)
|
||||
|
||||
# we need this wrapper method as self._socket might not be there yet
|
||||
def _send_multipart(self, parts):
|
||||
#utils.debug_log('_send_multipart', parts, verbosity=6)
|
||||
return self._socket.send_multipart(parts)
|
||||
|
||||
def send_obj(self, obj, topic=""):
|
||||
ZmqWrapper._io_loop_call(False, self._send_multipart,
|
||||
[topic.encode(), pickle.dumps(obj)])
|
||||
|
||||
def _on_mon(self, msg):
|
||||
ev = zmq.utils.monitor.parse_monitor_message(msg)
|
||||
event = ev['event']
|
||||
endpoint = ev['endpoint']
|
||||
if event == zmq.EVENT_CONNECTED:
|
||||
utils.debug_log("Subscriber connect event", endpoint, verbosity=1)
|
||||
elif event == zmq.EVENT_DISCONNECTED:
|
||||
utils.debug_log("Subscriber disconnect event", endpoint, verbosity=1)
|
||||
|
||||
|
||||
class Subscription:
|
||||
# subscribe to topic, call callback when object is received on topic
|
||||
def __init__(self, port, topic="", callback=None, host="localhost"):
|
||||
self._socket = None
|
||||
self._stream = None
|
||||
self.topic = None
|
||||
|
||||
ZmqWrapper.initialize()
|
||||
utils.debug_log('Creating Subscription', port, verbosity=1)
|
||||
ZmqWrapper._io_loop_call(False, self._add_sub,
|
||||
port, topic=topic, callback=callback, host=host)
|
||||
|
||||
def close(self):
|
||||
if self._socket:
|
||||
ZmqWrapper._io_loop_call(False, self._socket.close)
|
||||
|
||||
def _add_sub(self, port, topic, callback, host):
|
||||
def callback_wrapper(weak_callback, msg):
|
||||
[topic, obj_s] = msg # pylint: disable=unused-variable
|
||||
try:
|
||||
if weak_callback and weak_callback():
|
||||
weak_callback()(pickle.loads(obj_s))
|
||||
except Exception as ex:
|
||||
print(ex, file=sys.stderr) # TODO: standardize this
|
||||
raise
|
||||
|
||||
# connect to stream socket
|
||||
context = zmq.Context()
|
||||
self.topic = topic.encode()
|
||||
self._socket = context.socket(zmq.SUB)
|
||||
|
||||
utils.debug_log("Subscriber connecting...", (host, port), verbosity=1)
|
||||
self._socket.connect("tcp://%s:%d" % (host, port))
|
||||
utils.debug_log("Subscriber connected!", (host, port), verbosity=1)
|
||||
|
||||
# setup socket filtering
|
||||
if topic != "":
|
||||
self._socket.setsockopt(zmq.SUBSCRIBE, self.topic)
|
||||
|
||||
# if callback is specified then create a stream and set it
|
||||
# for on_recv event - this would require running ioloop
|
||||
if callback is not None:
|
||||
self._stream = zmqstream.ZMQStream(self._socket)
|
||||
wr_cb = weakref.WeakMethod(callback)
|
||||
wrapper = functools.partial(callback_wrapper, wr_cb)
|
||||
self._stream.on_recv(wrapper)
|
||||
#else use receive_obj
|
||||
|
||||
def _receive_obj(self):
|
||||
[topic, obj_s] = self._socket.recv_multipart() # pylint: disable=unbalanced-tuple-unpacking
|
||||
if topic != self.topic:
|
||||
raise ValueError("Expected topic: %s, Received topic: %s" % (topic, self.topic))
|
||||
return pickle.loads(obj_s)
|
||||
|
||||
def receive_obj(self):
|
||||
return ZmqWrapper._io_loop_call(True, self._receive_obj)
|
||||
|
||||
def _get_socket_identity(self):
|
||||
ep_id = self._socket.getsockopt(zmq.LAST_ENDPOINT)
|
||||
return ep_id
|
||||
|
||||
def get_socket_identity(self):
|
||||
return ZmqWrapper._io_loop_call(True, self._get_socket_identity)
|
||||
|
||||
|
||||
class ClientServer:
|
||||
def __init__(self, port, is_server, callback=None, host=None):
|
||||
self._socket = None
|
||||
self._stream = None
|
||||
|
||||
ZmqWrapper.initialize()
|
||||
utils.debug_log('Creating ClientServer', (is_server, port), verbosity=1)
|
||||
|
||||
# make sure call blocks until connection is made
|
||||
# otherwise variables would not be available
|
||||
ZmqWrapper._io_loop_call(True, self._connect,
|
||||
port, is_server, callback, host)
|
||||
|
||||
def close(self):
|
||||
if self._socket:
|
||||
ZmqWrapper._io_loop_call(False, self._socket.close)
|
||||
|
||||
def _connect(self, port, is_server, callback, host):
|
||||
def callback_wrapper(callback, msg):
|
||||
utils.debug_log("Server received request...", verbosity=6)
|
||||
|
||||
[obj_s] = msg
|
||||
try:
|
||||
ret = callback(self, pickle.loads(obj_s))
|
||||
# we must send reply to complete the cycle
|
||||
self._socket.send_multipart([pickle.dumps((ret, None))])
|
||||
except Exception as ex:
|
||||
print("ClientServer call raised exception: ", ex, file=sys.stderr)
|
||||
# we must send reply to complete the cycle
|
||||
self._socket.send_multipart([pickle.dumps((None, ex))])
|
||||
|
||||
utils.debug_log("Server sent response", verbosity=6)
|
||||
|
||||
context = zmq.Context()
|
||||
if is_server:
|
||||
host = host or "127.0.0.1"
|
||||
self._socket = context.socket(zmq.REP)
|
||||
utils.debug_log('Binding socket', (host, port), verbosity=5)
|
||||
self._socket.bind("tcp://%s:%d" % (host, port))
|
||||
utils.debug_log('Bound socket', (host, port), verbosity=5)
|
||||
else:
|
||||
host = host or "localhost"
|
||||
self._socket = context.socket(zmq.REQ)
|
||||
self._socket.setsockopt(zmq.REQ_CORRELATE, 1)
|
||||
self._socket.setsockopt(zmq.REQ_RELAXED, 1)
|
||||
|
||||
utils.debug_log("Client connecting...", verbosity=1)
|
||||
self._socket.connect("tcp://%s:%d" % (host, port))
|
||||
utils.debug_log("Client connected!", verbosity=1)
|
||||
|
||||
if callback is not None:
|
||||
self._stream = zmqstream.ZMQStream(self._socket)
|
||||
wrapper = functools.partial(callback_wrapper, callback)
|
||||
self._stream.on_recv(wrapper)
|
||||
#else use receive_obj
|
||||
|
||||
def send_obj(self, obj):
|
||||
ZmqWrapper._io_loop_call(False, self._socket.send_multipart,
|
||||
[pickle.dumps(obj)])
|
||||
|
||||
def receive_obj(self):
|
||||
# pylint: disable=unpacking-non-sequence
|
||||
[obj_s] = ZmqWrapper._io_loop_call(True, self._socket.recv_multipart)
|
||||
return pickle.loads(obj_s)
|
||||
|
||||
def request(self, req_obj):
|
||||
utils.debug_log("Client sending request...", verbosity=6)
|
||||
self.send_obj(req_obj)
|
||||
r = self.receive_obj()
|
||||
utils.debug_log("Client received response", verbosity=6)
|
||||
return r
|
|
@ -0,0 +1,18 @@
|
|||
import tensorwatch as tw
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import time
|
||||
import torchvision.datasets as datasets
|
||||
|
||||
fruits_ds = datasets.ImageFolder(r'D:\datasets\fruits-360\Training')
|
||||
mnist_ds = datasets.MNIST('../data', train=True, download=True)
|
||||
|
||||
images = [tw.ImagePlotItem(fruits_ds[i][0], title=str(i)) for i in range(5)] + \
|
||||
[tw.ImagePlotItem(mnist_ds[i][0], title=str(i)) for i in range(5)]
|
||||
|
||||
stream = tw.ArrayStream(images)
|
||||
|
||||
img_plot = tw.Visualizer(stream, vis_type='image', viz_img_scale=3)
|
||||
img_plot.show()
|
||||
|
||||
tw.image_utils.plt_loop()
|
|
@ -0,0 +1,15 @@
|
|||
import tensorwatch as tw
|
||||
import objgraph, time #pip install objgraph
|
||||
|
||||
cli = tw.WatcherClient()
|
||||
time.sleep(10)
|
||||
del cli
|
||||
|
||||
import gc
|
||||
gc.collect()
|
||||
|
||||
import time
|
||||
time.sleep(2)
|
||||
|
||||
objgraph.show_backrefs(objgraph.by_type('WatcherClient'), refcounts=True, filename='b.png')
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
import tensorwatch as tw
|
||||
from tensorwatch import evaler
|
||||
|
||||
|
||||
e = evaler.Evaler('reduce(lambda x,y: x+y, map(lambda x:x**2, filter(lambda x: x%2==0, l)))')
|
||||
for i in range(5):
|
||||
eval_return = e.post(i)
|
||||
print(i, eval_return)
|
||||
eval_return = e.post(ended=True)
|
||||
print(i, eval_return)
|
|
@ -0,0 +1,15 @@
|
|||
import ipywidgets as widgets
|
||||
from IPython import get_ipython
|
||||
|
||||
class PrinterX:
|
||||
def __init__(self):
|
||||
self.w = w=widgets.HTML()
|
||||
def show(self):
|
||||
return self.w
|
||||
def write(self,s):
|
||||
self.w.value = s
|
||||
|
||||
print("Running from within ipython?", get_ipython() is not None)
|
||||
p=PrinterX()
|
||||
p.show()
|
||||
p.write('ffffffffff')
|
|
@ -0,0 +1,20 @@
|
|||
from tensorwatch.watcher_base import WatcherBase
|
||||
from tensorwatch.mpl.line_plot import LinePlot
|
||||
from tensorwatch.image_utils import plt_loop
|
||||
from tensorwatch.stream import Stream
|
||||
from tensorwatch.lv_types import StreamItem
|
||||
|
||||
|
||||
def main():
|
||||
watcher = WatcherBase()
|
||||
line_plot = LinePlot()
|
||||
stream = watcher.create_stream(expr='lambda vars:vars.x')
|
||||
line_plot.subscribe(stream)
|
||||
line_plot.show()
|
||||
|
||||
for i in range(5):
|
||||
watcher.observe(x=(i, i*i))
|
||||
plt_loop()
|
||||
|
||||
main()
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
import pandas as pd
|
||||
|
||||
class SomeThing:
|
||||
def __init__(self, x, y):
|
||||
self.x, self.y = x, y
|
||||
|
||||
things = [SomeThing(1,2), SomeThing(3,4), SomeThing(4,5)]
|
||||
|
||||
df = pd.DataFrame([t.__dict__ for t in things ])
|
||||
|
||||
print(df.iloc[[-1]].to_dict('records')[0])
|
||||
#print(df.to_html())
|
||||
print(df.style.render())
|
|
@ -0,0 +1,19 @@
|
|||
from tensorwatch.watcher_base import WatcherBase
|
||||
from tensorwatch.plotly.line_plot import LinePlot
|
||||
from tensorwatch.image_utils import plt_loop
|
||||
from tensorwatch.stream import Stream
|
||||
from tensorwatch.lv_types import StreamItem
|
||||
|
||||
|
||||
def main():
|
||||
watcher = WatcherBase()
|
||||
line_plot = LinePlot()
|
||||
stream = watcher.create_stream(expr='lambda vars:vars.x')
|
||||
line_plot.subscribe(stream)
|
||||
line_plot.show()
|
||||
|
||||
for i in range(5):
|
||||
watcher.observe(x=(i, i*i))
|
||||
|
||||
main()
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
from tensorwatch.stream import Stream
|
||||
|
||||
|
||||
s1 = Stream(stream_name='s1', console_debug=True)
|
||||
s2 = Stream(stream_name='s2', console_debug=True)
|
||||
s3 = Stream(stream_name='s3', console_debug=True)
|
||||
|
||||
s1.subscribe(s2)
|
||||
s2.subscribe(s3)
|
||||
|
||||
s3.write('S3 wrote this')
|
||||
s2.write('S2 wrote this')
|
||||
s1.write('S1 wrote this')
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
import threading
|
||||
import time
|
||||
import sys
|
||||
|
||||
def handler(a,b=None):
|
||||
sys.exit(1)
|
||||
def install_handler():
|
||||
if sys.platform == "win32":
|
||||
if sys.stdin is not None and sys.stdin.isatty():
|
||||
#this is Console based application
|
||||
import win32api
|
||||
win32api.SetConsoleCtrlHandler(handler, True)
|
||||
|
||||
|
||||
def work():
|
||||
time.sleep(10000)
|
||||
t = threading.Thread(target=work, name='ThreadTest')
|
||||
t.daemon = True
|
||||
t.start()
|
||||
while(True):
|
||||
t.join(0.1) #100ms ~ typical human response
|
||||
# you will get KeyboardIntrupt exception
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
from tensorwatch.watcher_base import WatcherBase
|
||||
from tensorwatch.stream import Stream
|
||||
|
||||
def main():
|
||||
watcher = WatcherBase()
|
||||
console_pub = Stream(stream_name = 'S1', console_debug=True)
|
||||
stream = watcher.create_stream(expr='lambda vars:vars.x**2')
|
||||
console_pub.subscribe(stream)
|
||||
|
||||
for i in range(5):
|
||||
watcher.observe(x=i)
|
||||
|
||||
main()
|
||||
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
import tensorwatch as tw
|
||||
import time
|
||||
import math
|
||||
from tensorwatch import utils
|
||||
|
||||
utils.set_debug_verbosity(4)
|
||||
|
||||
def dlc_show_rand_outputs():
|
||||
cli = cli_train = tw.WatcherClient()
|
||||
imgs = cli.create_stream(event_name='batch',
|
||||
expr="top(l, out_xform=pyt_img_img_out_xform, group_key=lambda x:'', topk=10, order='rnd')",
|
||||
throttle=1)
|
||||
img_plot = tw.mpl.ImagePlot()
|
||||
img_plot.show(imgs, img_width=39, img_height=69, viz_img_scale=10)
|
||||
|
||||
utils.wait_key()
|
||||
|
||||
def img2img_rnd():
|
||||
cli_train = tw.WatcherClient()
|
||||
cli = tw.WatcherClient()
|
||||
|
||||
imgs = cli_train.create_stream(event_name='batch',
|
||||
expr="top(l, out_xform=pyt_img_img_out_xform, group_key=lambda x:'', topk=2, order='rnd')",
|
||||
throttle=1)
|
||||
img_plot = tw.mpl.ImagePlot()
|
||||
img_plot.show(imgs, img_width=100, img_height=100, viz_img_scale=3, cols=1)
|
||||
|
||||
utils.wait_key()
|
||||
|
||||
dlc_show_rand_outputs()
|
||||
img2img_rnd()
|
|
@ -0,0 +1,28 @@
|
|||
from tensorwatch.watcher_base import WatcherBase
|
||||
from tensorwatch.stream import Stream
|
||||
from tensorwatch.file_stream import FileStream
|
||||
from tensorwatch.mpl.line_plot import LinePlot
|
||||
from tensorwatch.image_utils import plt_loop
|
||||
import tensorwatch as tw
|
||||
|
||||
def file_write():
|
||||
watcher = WatcherBase()
|
||||
stream = watcher.create_stream(expr='lambda vars:(vars.x, vars.x**2)',
|
||||
devices=[r'c:\temp\obs.txt'])
|
||||
|
||||
for i in range(5):
|
||||
watcher.observe(x=i)
|
||||
|
||||
def file_read():
|
||||
watcher = WatcherBase()
|
||||
stream = watcher.open_stream(devices=[r'c:\temp\obs.txt'])
|
||||
vis = tw.Visualizer(stream, vis_type='mpl-line')
|
||||
vis.show()
|
||||
plt_loop()
|
||||
|
||||
def main():
|
||||
file_write()
|
||||
file_read()
|
||||
|
||||
main()
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
from matplotlib import pyplot as plt
|
||||
from matplotlib.animation import FuncAnimation
|
||||
from random import randrange
|
||||
from threading import Thread
|
||||
import time
|
||||
|
||||
class LiveGraph:
|
||||
def __init__(self):
|
||||
self.x_data, self.y_data = [], []
|
||||
self.figure = plt.figure()
|
||||
self.line, = plt.plot(self.x_data, self.y_data)
|
||||
self.animation = FuncAnimation(self.figure, self.update, interval=1000)
|
||||
self.th = Thread(target=self.thread_f, name='LiveGraph', daemon=True)
|
||||
self.th.start()
|
||||
|
||||
def update(self, frame):
|
||||
self.line.set_data(self.x_data, self.y_data)
|
||||
self.figure.gca().relim()
|
||||
self.figure.gca().autoscale_view()
|
||||
return self.line,
|
||||
|
||||
def show(self):
|
||||
plt.show()
|
||||
|
||||
def thread_f(self):
|
||||
x = 0
|
||||
while True:
|
||||
self.x_data.append(x)
|
||||
x += 1
|
||||
self.y_data.append(randrange(0, 100))
|
||||
time.sleep(1)
|
|
@ -0,0 +1,110 @@
|
|||
import tensorwatch as tw
|
||||
import time
|
||||
import math
|
||||
from tensorwatch import utils
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
utils.set_debug_verbosity(4)
|
||||
|
||||
def img_in_class():
|
||||
cli_train = tw.WatcherClient()
|
||||
|
||||
imgs = cli_train.create_stream(event_name='batch',
|
||||
expr="top(l, out_xform=pyt_img_class_out_xform, order='rnd')", throttle=1)
|
||||
img_plot = tw.mpl.ImagePlot()
|
||||
img_plot.subscribe(imgs, viz_img_scale=3)
|
||||
img_plot.show()
|
||||
|
||||
tw.image_utils.plt_loop()
|
||||
|
||||
def show_find_lr():
|
||||
cli_train = tw.WatcherClient()
|
||||
plot = tw.mpl.LinePlot()
|
||||
|
||||
train_batch_loss = cli_train.create_stream(event_name='batch',
|
||||
expr='lambda d:(d.tt.scheduler.get_lr()[0], d.metrics.batch_loss)')
|
||||
plot.subscribe(train_batch_loss, xtitle='Epoch', ytitle='Loss')
|
||||
|
||||
utils.wait_key()
|
||||
|
||||
def plot_grads():
|
||||
train_cli = tw.WatcherClient()
|
||||
grads = train_cli.create_stream(event_name='batch',
|
||||
expr='lambda d:agg_params(d.model, lambda p: p.grad.abs().mean().item())', throttle=1)
|
||||
p = tw.plotly.LinePlot('Demo')
|
||||
p.subscribe(grads, xtitle='Epoch', ytitle='Gradients', history_len=30, new_on_eval=True)
|
||||
utils.wait_key()
|
||||
|
||||
|
||||
def plot_grads1():
|
||||
train_cli = tw.WatcherClient()
|
||||
|
||||
grads = train_cli.create_stream(event_name='batch',
|
||||
expr='lambda d:agg_params(d.model, lambda p: p.grad.abs().mean().item())', throttle=1)
|
||||
grad_plot = tw.mpl.LinePlot()
|
||||
grad_plot.subscribe(grads, xtitle='Epoch', ytitle='Gradients', clear_after_each=1, history_len=40, dim_history=True)
|
||||
grad_plot.show()
|
||||
|
||||
tw.plt_loop()
|
||||
|
||||
def plot_weight():
|
||||
train_cli = tw.WatcherClient()
|
||||
|
||||
params = train_cli.create_stream(event_name='batch',
|
||||
expr='lambda d:agg_params(d.model, lambda p: p.abs().mean().item())', throttle=1)
|
||||
params_plot = tw.mpl.LinePlot()
|
||||
params_plot.subscribe(params, xtitle='Epoch', ytitle='avg |params|', clear_after_each=1, history_len=40, dim_history=True)
|
||||
params_plot.show()
|
||||
|
||||
tw.plt_loop()
|
||||
|
||||
def epoch_stats():
|
||||
train_cli = tw.WatcherClient(port=0)
|
||||
test_cli = tw.WatcherClient(port=1)
|
||||
|
||||
plot = tw.mpl.LinePlot()
|
||||
|
||||
train_loss = train_cli.create_stream(event_name="epoch",
|
||||
expr='lambda v:(v.metrics.epoch_index, v.metrics.epoch_loss)')
|
||||
plot.subscribe(train_loss, xtitle='Epoch', ytitle='Train Loss')
|
||||
|
||||
test_acc = test_cli.create_stream(event_name="epoch",
|
||||
expr='lambda v:(v.metrics.epoch_index, v.metrics.epoch_accuracy)')
|
||||
plot.subscribe(test_acc, xtitle='Epoch', ytitle='Test Accuracy', ylim=(0,1))
|
||||
|
||||
plot.show()
|
||||
tw.plt_loop()
|
||||
|
||||
|
||||
def batch_stats():
|
||||
train_cli = tw.WatcherClient()
|
||||
stream = train_cli.create_stream(event_name="batch",
|
||||
expr='lambda v:(v.metrics.epochf, v.metrics.batch_loss)', throttle=0.75)
|
||||
|
||||
train_loss = tw.Visualizer(stream, clear_after_end=False, vis_type='mpl-line',
|
||||
xtitle='Epoch', ytitle='Train Loss')
|
||||
|
||||
#train_acc = tw.Visualizer('lambda v:(v.metrics.epochf, v.metrics.epoch_loss)', event_name="batch",
|
||||
# xtitle='Epoch', ytitle='Train Accuracy', clear_after_end=False, yrange=(0,1),
|
||||
# vis=train_loss, vis_type='mpl-line')
|
||||
|
||||
train_loss.show()
|
||||
tw.image_utils.plt_loop()
|
||||
|
||||
def text_stats():
|
||||
train_cli = tw.WatcherClient()
|
||||
stream = train_cli.create_stream(event_name="batch",
|
||||
expr='lambda d:(d.x, d.metrics.batch_loss)')
|
||||
|
||||
trl = tw.Visualizer(stream, vis_type=None)
|
||||
trl.show()
|
||||
input('Paused...')
|
||||
|
||||
|
||||
|
||||
#epoch_stats()
|
||||
#plot_weight()
|
||||
#plot_grads1()
|
||||
#img_in_class()
|
||||
#text_stats()
|
||||
batch_stats()
|
|
@ -0,0 +1,16 @@
|
|||
from tensorwatch.saliency import saliency
|
||||
from tensorwatch import image_utils, imagenet_utils, pytorch_utils
|
||||
|
||||
model = pytorch_utils.get_model('resnet50')
|
||||
raw_input, input, target_class = pytorch_utils.image_class2tensor('../data/dogs.png', 240, #'../data/elephant.png', 101,
|
||||
image_transform=imagenet_utils.get_image_transform(), image_convert_mode='RGB')
|
||||
|
||||
results = saliency.get_image_saliency_results(model, raw_input, input, target_class)
|
||||
figure = saliency.get_image_saliency_plot(results)
|
||||
|
||||
image_utils.plt_loop()
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
import torch
|
||||
import torchvision.models
|
||||
import tensorwatch as tw
|
||||
|
||||
vgg16_model = torchvision.models.vgg16()
|
||||
|
||||
drawing = tw.draw_model(vgg16_model, [1, 3, 224, 224])
|
||||
drawing.save('abc')
|
||||
|
||||
input("Press any key")
|
|
@ -0,0 +1,10 @@
|
|||
import tensorwatch as tw
|
||||
|
||||
from regim import *
|
||||
ds = DataUtils.mnist_datasets(linearize=True, train_test=False)
|
||||
ds = DataUtils.sample_by_class(ds, k=5, shuffle=True, as_np=True, no_test=True)
|
||||
|
||||
comps = tw.get_tsne_components(ds)
|
||||
print(comps)
|
||||
plot = tw.Visualizer(comps, hover_images=ds[0], hover_image_reshape=(28,28), vis_type='tsne')
|
||||
plot.show()
|
|
@ -0,0 +1,77 @@
|
|||
import tensorwatch as tw
|
||||
import time
|
||||
import math
|
||||
from tensorwatch import utils
|
||||
utils.set_debug_verbosity(4)
|
||||
|
||||
def mpl_line_plot():
|
||||
cli = tw.WatcherClient()
|
||||
p = tw.mpl.LinePlot(title='Demo')
|
||||
s1 = cli.create_stream(event_name='ev_i', expr='map(lambda v:math.sqrt(v.val)*2, l)')
|
||||
p.subscribe(s1, xtitle='Index', ytitle='sqrt(ev_i)')
|
||||
p.show()
|
||||
tw.plt_loop()
|
||||
|
||||
def mpl_history_plot():
|
||||
cli = tw.WatcherClient()
|
||||
p2 = tw.mpl.LinePlot(title='History Demo')
|
||||
p2s1 = cli.create_stream(event_name='ev_j', expr='map(lambda v:(v.val, math.sqrt(v.val)*2), l)')
|
||||
p2.subscribe(p2s1, xtitle='Index', ytitle='sqrt(ev_j)', clear_after_end=True, history_len=15)
|
||||
p2.show()
|
||||
tw.plt_loop()
|
||||
|
||||
def show_stream():
|
||||
cli = tw.WatcherClient()
|
||||
|
||||
print("Subscribing to event ev_i...")
|
||||
s1 = cli.create_stream(event_name="ev_i", expr='map(lambda v:math.sqrt(v.val), l)')
|
||||
r1 = tw.TextVis(title='L1')
|
||||
r1.subscribe(s1)
|
||||
r1.show()
|
||||
|
||||
print("Subscribing to event ev_j...")
|
||||
s2 = cli.create_stream(event_name="ev_j", expr='map(lambda v:v.val*v.val, l)')
|
||||
r2 = tw.TextVis(title='L2')
|
||||
r2.subscribe(s2)
|
||||
|
||||
r2.show()
|
||||
|
||||
print("Waiting for key...")
|
||||
|
||||
utils.wait_key()
|
||||
|
||||
# this no longer directly supported
|
||||
# TODO: create stream that allows enumeration from buffered values
|
||||
#def read_stream():
|
||||
# cli = tw.WatcherClient()
|
||||
|
||||
# with cli.create_stream(event_name="ev_i", expr='map(lambda v:(v.x, math.sqrt(v.val)), l)') as s1:
|
||||
# for stream_item in s1:
|
||||
# print(stream_item.value)
|
||||
# print('done')
|
||||
# utils.wait_key()
|
||||
|
||||
def plotly_line_graph():
|
||||
cli = tw.WatcherClient()
|
||||
s1 = cli.create_stream(event_name="ev_i", expr='map(lambda v:(v.x, math.sqrt(v.val)), l)')
|
||||
|
||||
p = tw.plotly.LinePlot()
|
||||
p.subscribe(s1)
|
||||
p.show()
|
||||
|
||||
utils.wait_key()
|
||||
|
||||
def plotly_history_graph():
|
||||
cli = tw.WatcherClient()
|
||||
p = tw.plotly.LinePlot(title='Demo')
|
||||
s2 = cli.create_stream(event_name='ev_j', expr='map(lambda v:(v.x, v.val), l)')
|
||||
p.subscribe(s2, ytitle='ev_j', history_len=15)
|
||||
p.show()
|
||||
utils.wait_key()
|
||||
|
||||
|
||||
mpl_line_plot()
|
||||
#mpl_history_plot()
|
||||
#show_stream()
|
||||
#plotly_line_graph()
|
||||
#plotly_history_graph()
|
Некоторые файлы не были показаны из-за слишком большого количества измененных файлов Показать больше
Загрузка…
Ссылка в новой задаче