This commit is contained in:
dotnet-bot 2018-10-01 11:22:59 -07:00 коммит произвёл Immo Landwerth
Коммит 73207ecc9c
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 962A13C9167CE951
1462 изменённых файлов: 520857 добавлений и 0 удалений

262
.gitignore поставляемый Normal file
Просмотреть файл

@ -0,0 +1,262 @@
## Ignore Visual Studio temporary files, build results, and
## files generated by popular Visual Studio add-ons.
# User-specific files
*.suo
*.user
*.userosscache
*.sln.docstates
# User-specific files (MonoDevelop/Xamarin Studio)
*.userprefs
# Build results
[Dd]ebug/
[Dd]ebugPublic/
[Rr]elease/
[Rr]eleases/
x64/
x86/
bld/
[Bb]in/
[Oo]bj/
[Ll]og/
# Visual Studio 2015 cache/options directory
.vs/
# Uncomment if you have tasks that create the project's static files in wwwroot
#wwwroot/
# MSTest test Results
[Tt]est[Rr]esult*/
[Bb]uild[Ll]og.*
# NUNIT
*.VisualState.xml
TestResult.xml
# Build Results of an ATL Project
[Dd]ebugPS/
[Rr]eleasePS/
dlldata.c
# DNX
project.lock.json
project.fragment.lock.json
artifacts/
*_i.c
*_p.c
*_i.h
*.ilk
*.meta
*.obj
*.pch
*.pdb
*.pgc
*.pgd
*.rsp
*.sbr
*.tlb
*.tli
*.tlh
*.tmp
*.tmp_proj
*.log
*.vspscc
*.vssscc
.builds
*.pidb
*.svclog
*.scc
# Chutzpah Test files
_Chutzpah*
# Visual C++ cache files
ipch/
*.aps
*.ncb
*.opendb
*.opensdf
*.sdf
*.cachefile
*.VC.db
*.VC.VC.opendb
# Visual Studio profiler
*.psess
*.vsp
*.vspx
*.sap
# TFS 2012 Local Workspace
$tf/
# Guidance Automation Toolkit
*.gpState
# ReSharper is a .NET coding add-in
_ReSharper*/
*.[Rr]e[Ss]harper
*.DotSettings.user
# JustCode is a .NET coding add-in
.JustCode
# TeamCity is a build add-in
_TeamCity*
# DotCover is a Code Coverage Tool
*.dotCover
# NCrunch
_NCrunch_*
.*crunch*.local.xml
nCrunchTemp_*
# MightyMoose
*.mm.*
AutoTest.Net/
# Web workbench (sass)
.sass-cache/
# Installshield output folder
[Ee]xpress/
# DocProject is a documentation generator add-in
DocProject/buildhelp/
DocProject/Help/*.HxT
DocProject/Help/*.HxC
DocProject/Help/*.hhc
DocProject/Help/*.hhk
DocProject/Help/*.hhp
DocProject/Help/Html2
DocProject/Help/html
# Click-Once directory
publish/
# Publish Web Output
*.[Pp]ublish.xml
*.azurePubxml
# TODO: Comment the next line if you want to checkin your web deploy settings
# but database connection strings (with potential passwords) will be unencrypted
*.pubxml
*.publishproj
# Microsoft Azure Web App publish settings. Comment the next line if you want to
# checkin your Azure Web App publish settings, but sensitive information contained
# in these scripts will be unencrypted
PublishScripts/
# NuGet Packages
*.nupkg
# The packages folder can be ignored because of Package Restore
**/packages/*
# except build/, which is used as an MSBuild target.
!**/packages/build/
# Uncomment if necessary however generally it will be regenerated when needed
#!**/packages/repositories.config
# NuGet v3's project.json files produces more ignoreable files
*.nuget.props
*.nuget.targets
# Microsoft Azure Build Output
csx/
*.build.csdef
# Microsoft Azure Emulator
ecf/
rcf/
# Windows Store app package directories and files
AppPackages/
BundleArtifacts/
Package.StoreAssociation.xml
_pkginfo.txt
# Visual Studio cache files
# files ending in .cache can be ignored
*.[Cc]ache
# but keep track of directories ending in .cache
!*.[Cc]ache/
# Others
ClientBin/
~$*
*~
*.dbmdl
*.dbproj.schemaview
*.pfx
*.publishsettings
node_modules/
orleans.codegen.cs
# Since there are multiple workflows, uncomment next line to ignore bower_components
# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
#bower_components/
# RIA/Silverlight projects
Generated_Code/
# Backup & report files from converting an old project file
# to a newer Visual Studio version. Backup files are not needed,
# because we have git ;-)
_UpgradeReport_Files/
Backup*/
UpgradeLog*.XML
UpgradeLog*.htm
# SQL Server files
*.mdf
*.ldf
# Business Intelligence projects
*.rdl.data
*.bim.layout
*.bim_*.settings
# Microsoft Fakes
FakesAssemblies/
# GhostDoc plugin setting file
*.GhostDoc.xml
# Node.js Tools for Visual Studio
.ntvs_analysis.dat
# Visual Studio 6 build log
*.plg
# Visual Studio 6 workspace options file
*.opt
# Visual Studio LightSwitch build output
**/*.HTMLClient/GeneratedArtifacts
**/*.DesktopClient/GeneratedArtifacts
**/*.DesktopClient/ModelManifest.xml
**/*.Server/GeneratedArtifacts
**/*.Server/ModelManifest.xml
_Pvt_Extensions
# Paket dependency manager
.paket/paket.exe
paket-files/
# FAKE - F# Make
.fake/
# JetBrains Rider
.idea/
*.sln.iml
# Jekyll, DocFX and Sandcastle
docs/obj/
InferNet_Copy_Temp/
apiguide-tmp/
_site/
#code generated by iron python examples
IronPythonWrapper/InferNetExamples/InferNetExamples/GeneratedSource/

15
CONTRIBUTING.md Normal file
Просмотреть файл

@ -0,0 +1,15 @@
# Contributing
Welcome, and thank you for your interest in contributing to Infer.NET!
Most contributions require you to agree to a
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
the rights to use your contribution. For details, visit https://cla.microsoft.com.
When you submit a pull request, a CLA-bot will automatically determine whether you need to provide
a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the instructions
provided by the bot. You will only need to do this once across all repos using our CLA.
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.

358
CodeAnalysis.ruleset Normal file
Просмотреть файл

@ -0,0 +1,358 @@
<?xml version="1.0" encoding="utf-8"?>
<RuleSet Name="Infer.NET Code Analysis Rules" Description="Rule set for Infer.NET" ToolsVersion="11.0">
<IncludeAll Action="Warning" />
<Rules AnalyzerId="Microsoft.Analyzers.NativeCodeAnalysis" RuleNamespace="Microsoft.Rules.Native">
<Rule Id="C26100" Action="None" />
<Rule Id="C26101" Action="None" />
<Rule Id="C26105" Action="None" />
<Rule Id="C26110" Action="None" />
<Rule Id="C26111" Action="None" />
<Rule Id="C26112" Action="None" />
<Rule Id="C26115" Action="None" />
<Rule Id="C26116" Action="None" />
<Rule Id="C26117" Action="None" />
<Rule Id="C26130" Action="None" />
<Rule Id="C26135" Action="None" />
<Rule Id="C26140" Action="None" />
<Rule Id="C26160" Action="None" />
<Rule Id="C26165" Action="None" />
<Rule Id="C26166" Action="None" />
<Rule Id="C26167" Action="None" />
<Rule Id="C28020" Action="None" />
<Rule Id="C28021" Action="None" />
<Rule Id="C28022" Action="None" />
<Rule Id="C28023" Action="None" />
<Rule Id="C28024" Action="None" />
<Rule Id="C28039" Action="None" />
<Rule Id="C28101" Action="None" />
<Rule Id="C28103" Action="None" />
<Rule Id="C28104" Action="None" />
<Rule Id="C28105" Action="None" />
<Rule Id="C28106" Action="None" />
<Rule Id="C28107" Action="None" />
<Rule Id="C28108" Action="None" />
<Rule Id="C28109" Action="None" />
<Rule Id="C28110" Action="None" />
<Rule Id="C28111" Action="None" />
<Rule Id="C28112" Action="None" />
<Rule Id="C28113" Action="None" />
<Rule Id="C28114" Action="None" />
<Rule Id="C28120" Action="None" />
<Rule Id="C28121" Action="None" />
<Rule Id="C28122" Action="None" />
<Rule Id="C28123" Action="None" />
<Rule Id="C28124" Action="None" />
<Rule Id="C28125" Action="None" />
<Rule Id="C28126" Action="None" />
<Rule Id="C28127" Action="None" />
<Rule Id="C28128" Action="None" />
<Rule Id="C28129" Action="None" />
<Rule Id="C28131" Action="None" />
<Rule Id="C28132" Action="None" />
<Rule Id="C28133" Action="None" />
<Rule Id="C28134" Action="None" />
<Rule Id="C28135" Action="None" />
<Rule Id="C28137" Action="None" />
<Rule Id="C28138" Action="None" />
<Rule Id="C28139" Action="None" />
<Rule Id="C28141" Action="None" />
<Rule Id="C28143" Action="None" />
<Rule Id="C28144" Action="None" />
<Rule Id="C28145" Action="None" />
<Rule Id="C28146" Action="None" />
<Rule Id="C28147" Action="None" />
<Rule Id="C28150" Action="None" />
<Rule Id="C28151" Action="None" />
<Rule Id="C28152" Action="None" />
<Rule Id="C28153" Action="None" />
<Rule Id="C28156" Action="None" />
<Rule Id="C28157" Action="None" />
<Rule Id="C28158" Action="None" />
<Rule Id="C28159" Action="None" />
<Rule Id="C28160" Action="None" />
<Rule Id="C28161" Action="None" />
<Rule Id="C28162" Action="None" />
<Rule Id="C28163" Action="None" />
<Rule Id="C28164" Action="None" />
<Rule Id="C28165" Action="None" />
<Rule Id="C28166" Action="None" />
<Rule Id="C28167" Action="None" />
<Rule Id="C28168" Action="None" />
<Rule Id="C28169" Action="None" />
<Rule Id="C28170" Action="None" />
<Rule Id="C28171" Action="None" />
<Rule Id="C28172" Action="None" />
<Rule Id="C28173" Action="None" />
<Rule Id="C28175" Action="None" />
<Rule Id="C28176" Action="None" />
<Rule Id="C28177" Action="None" />
<Rule Id="C28182" Action="None" />
<Rule Id="C28183" Action="None" />
<Rule Id="C28193" Action="None" />
<Rule Id="C28194" Action="None" />
<Rule Id="C28195" Action="None" />
<Rule Id="C28196" Action="None" />
<Rule Id="C28197" Action="None" />
<Rule Id="C28198" Action="None" />
<Rule Id="C28199" Action="None" />
<Rule Id="C28202" Action="None" />
<Rule Id="C28203" Action="None" />
<Rule Id="C28204" Action="None" />
<Rule Id="C28205" Action="None" />
<Rule Id="C28206" Action="None" />
<Rule Id="C28207" Action="None" />
<Rule Id="C28208" Action="None" />
<Rule Id="C28209" Action="None" />
<Rule Id="C28210" Action="None" />
<Rule Id="C28211" Action="None" />
<Rule Id="C28212" Action="None" />
<Rule Id="C28213" Action="None" />
<Rule Id="C28214" Action="None" />
<Rule Id="C28215" Action="None" />
<Rule Id="C28216" Action="None" />
<Rule Id="C28217" Action="None" />
<Rule Id="C28218" Action="None" />
<Rule Id="C28219" Action="None" />
<Rule Id="C28220" Action="None" />
<Rule Id="C28221" Action="None" />
<Rule Id="C28222" Action="None" />
<Rule Id="C28223" Action="None" />
<Rule Id="C28224" Action="None" />
<Rule Id="C28225" Action="None" />
<Rule Id="C28226" Action="None" />
<Rule Id="C28227" Action="None" />
<Rule Id="C28228" Action="None" />
<Rule Id="C28229" Action="None" />
<Rule Id="C28230" Action="None" />
<Rule Id="C28231" Action="None" />
<Rule Id="C28232" Action="None" />
<Rule Id="C28233" Action="None" />
<Rule Id="C28234" Action="None" />
<Rule Id="C28235" Action="None" />
<Rule Id="C28236" Action="None" />
<Rule Id="C28237" Action="None" />
<Rule Id="C28238" Action="None" />
<Rule Id="C28239" Action="None" />
<Rule Id="C28240" Action="None" />
<Rule Id="C28241" Action="None" />
<Rule Id="C28243" Action="None" />
<Rule Id="C28244" Action="None" />
<Rule Id="C28245" Action="None" />
<Rule Id="C28246" Action="None" />
<Rule Id="C28250" Action="None" />
<Rule Id="C28251" Action="None" />
<Rule Id="C28252" Action="None" />
<Rule Id="C28253" Action="None" />
<Rule Id="C28254" Action="None" />
<Rule Id="C28260" Action="None" />
<Rule Id="C28262" Action="None" />
<Rule Id="C28263" Action="None" />
<Rule Id="C28266" Action="None" />
<Rule Id="C28267" Action="None" />
<Rule Id="C28268" Action="None" />
<Rule Id="C28272" Action="None" />
<Rule Id="C28273" Action="None" />
<Rule Id="C28275" Action="None" />
<Rule Id="C28278" Action="None" />
<Rule Id="C28279" Action="None" />
<Rule Id="C28280" Action="None" />
<Rule Id="C28282" Action="None" />
<Rule Id="C28283" Action="None" />
<Rule Id="C28284" Action="None" />
<Rule Id="C28285" Action="None" />
<Rule Id="C28286" Action="None" />
<Rule Id="C28287" Action="None" />
<Rule Id="C28288" Action="None" />
<Rule Id="C28289" Action="None" />
<Rule Id="C28290" Action="None" />
<Rule Id="C28291" Action="None" />
<Rule Id="C28300" Action="None" />
<Rule Id="C28301" Action="None" />
<Rule Id="C28302" Action="None" />
<Rule Id="C28303" Action="None" />
<Rule Id="C28304" Action="None" />
<Rule Id="C28305" Action="None" />
<Rule Id="C28306" Action="None" />
<Rule Id="C28307" Action="None" />
<Rule Id="C28308" Action="None" />
<Rule Id="C28309" Action="None" />
<Rule Id="C28350" Action="None" />
<Rule Id="C28351" Action="None" />
<Rule Id="C28601" Action="None" />
<Rule Id="C28602" Action="None" />
<Rule Id="C28604" Action="None" />
<Rule Id="C28615" Action="None" />
<Rule Id="C28616" Action="None" />
<Rule Id="C28617" Action="None" />
<Rule Id="C28623" Action="None" />
<Rule Id="C28624" Action="None" />
<Rule Id="C28625" Action="None" />
<Rule Id="C28636" Action="None" />
<Rule Id="C28637" Action="None" />
<Rule Id="C28638" Action="None" />
<Rule Id="C28639" Action="None" />
<Rule Id="C28640" Action="None" />
<Rule Id="C28645" Action="None" />
<Rule Id="C28648" Action="None" />
<Rule Id="C28649" Action="None" />
<Rule Id="C28650" Action="None" />
<Rule Id="C28714" Action="None" />
<Rule Id="C28715" Action="None" />
<Rule Id="C28716" Action="None" />
<Rule Id="C28717" Action="None" />
<Rule Id="C28719" Action="None" />
<Rule Id="C28720" Action="None" />
<Rule Id="C28721" Action="None" />
<Rule Id="C28726" Action="None" />
<Rule Id="C28727" Action="None" />
<Rule Id="C28730" Action="None" />
<Rule Id="C28735" Action="None" />
<Rule Id="C28736" Action="None" />
<Rule Id="C28750" Action="None" />
<Rule Id="C28751" Action="None" />
<Rule Id="C6001" Action="None" />
<Rule Id="C6011" Action="None" />
<Rule Id="C6014" Action="None" />
<Rule Id="C6029" Action="None" />
<Rule Id="C6031" Action="None" />
<Rule Id="C6053" Action="None" />
<Rule Id="C6054" Action="None" />
<Rule Id="C6059" Action="None" />
<Rule Id="C6063" Action="None" />
<Rule Id="C6064" Action="None" />
<Rule Id="C6066" Action="None" />
<Rule Id="C6067" Action="None" />
<Rule Id="C6101" Action="None" />
<Rule Id="C6200" Action="None" />
<Rule Id="C6201" Action="None" />
<Rule Id="C6211" Action="None" />
<Rule Id="C6214" Action="None" />
<Rule Id="C6215" Action="None" />
<Rule Id="C6216" Action="None" />
<Rule Id="C6217" Action="None" />
<Rule Id="C6219" Action="None" />
<Rule Id="C6220" Action="None" />
<Rule Id="C6221" Action="None" />
<Rule Id="C6225" Action="None" />
<Rule Id="C6226" Action="None" />
<Rule Id="C6230" Action="None" />
<Rule Id="C6235" Action="None" />
<Rule Id="C6236" Action="None" />
<Rule Id="C6237" Action="None" />
<Rule Id="C6239" Action="None" />
<Rule Id="C6240" Action="None" />
<Rule Id="C6242" Action="None" />
<Rule Id="C6244" Action="None" />
<Rule Id="C6246" Action="None" />
<Rule Id="C6248" Action="None" />
<Rule Id="C6250" Action="None" />
<Rule Id="C6255" Action="None" />
<Rule Id="C6258" Action="None" />
<Rule Id="C6259" Action="None" />
<Rule Id="C6260" Action="None" />
<Rule Id="C6262" Action="None" />
<Rule Id="C6263" Action="None" />
<Rule Id="C6268" Action="None" />
<Rule Id="C6269" Action="None" />
<Rule Id="C6270" Action="None" />
<Rule Id="C6271" Action="None" />
<Rule Id="C6272" Action="None" />
<Rule Id="C6273" Action="None" />
<Rule Id="C6274" Action="None" />
<Rule Id="C6276" Action="None" />
<Rule Id="C6277" Action="None" />
<Rule Id="C6278" Action="None" />
<Rule Id="C6279" Action="None" />
<Rule Id="C6280" Action="None" />
<Rule Id="C6281" Action="None" />
<Rule Id="C6282" Action="None" />
<Rule Id="C6283" Action="None" />
<Rule Id="C6284" Action="None" />
<Rule Id="C6285" Action="None" />
<Rule Id="C6286" Action="None" />
<Rule Id="C6287" Action="None" />
<Rule Id="C6288" Action="None" />
<Rule Id="C6289" Action="None" />
<Rule Id="C6290" Action="None" />
<Rule Id="C6291" Action="None" />
<Rule Id="C6292" Action="None" />
<Rule Id="C6293" Action="None" />
<Rule Id="C6294" Action="None" />
<Rule Id="C6295" Action="None" />
<Rule Id="C6296" Action="None" />
<Rule Id="C6297" Action="None" />
<Rule Id="C6298" Action="None" />
<Rule Id="C6299" Action="None" />
<Rule Id="C6302" Action="None" />
<Rule Id="C6303" Action="None" />
<Rule Id="C6305" Action="None" />
<Rule Id="C6306" Action="None" />
<Rule Id="C6308" Action="None" />
<Rule Id="C6310" Action="None" />
<Rule Id="C6312" Action="None" />
<Rule Id="C6313" Action="None" />
<Rule Id="C6314" Action="None" />
<Rule Id="C6315" Action="None" />
<Rule Id="C6316" Action="None" />
<Rule Id="C6317" Action="None" />
<Rule Id="C6318" Action="None" />
<Rule Id="C6319" Action="None" />
<Rule Id="C6320" Action="None" />
<Rule Id="C6322" Action="None" />
<Rule Id="C6323" Action="None" />
<Rule Id="C6324" Action="None" />
<Rule Id="C6326" Action="None" />
<Rule Id="C6328" Action="None" />
<Rule Id="C6329" Action="None" />
<Rule Id="C6330" Action="None" />
<Rule Id="C6331" Action="None" />
<Rule Id="C6332" Action="None" />
<Rule Id="C6333" Action="None" />
<Rule Id="C6334" Action="None" />
<Rule Id="C6335" Action="None" />
<Rule Id="C6336" Action="None" />
<Rule Id="C6340" Action="None" />
<Rule Id="C6381" Action="None" />
<Rule Id="C6383" Action="None" />
<Rule Id="C6384" Action="None" />
<Rule Id="C6385" Action="None" />
<Rule Id="C6386" Action="None" />
<Rule Id="C6387" Action="None" />
<Rule Id="C6388" Action="None" />
<Rule Id="C6400" Action="None" />
<Rule Id="C6401" Action="None" />
<Rule Id="C6411" Action="None" />
<Rule Id="C6412" Action="None" />
<Rule Id="C6500" Action="None" />
<Rule Id="C6501" Action="None" />
<Rule Id="C6503" Action="None" />
<Rule Id="C6504" Action="None" />
<Rule Id="C6505" Action="None" />
<Rule Id="C6506" Action="None" />
<Rule Id="C6508" Action="None" />
<Rule Id="C6509" Action="None" />
<Rule Id="C6510" Action="None" />
<Rule Id="C6511" Action="None" />
<Rule Id="C6513" Action="None" />
<Rule Id="C6514" Action="None" />
<Rule Id="C6515" Action="None" />
<Rule Id="C6516" Action="None" />
<Rule Id="C6517" Action="None" />
<Rule Id="C6518" Action="None" />
<Rule Id="C6522" Action="None" />
<Rule Id="C6525" Action="None" />
<Rule Id="C6527" Action="None" />
<Rule Id="C6530" Action="None" />
<Rule Id="C6540" Action="None" />
<Rule Id="C6551" Action="None" />
<Rule Id="C6552" Action="None" />
<Rule Id="C6701" Action="None" />
<Rule Id="C6702" Action="None" />
<Rule Id="C6703" Action="None" />
<Rule Id="C6704" Action="None" />
<Rule Id="C6705" Action="None" />
<Rule Id="C6706" Action="None" />
<Rule Id="C6707" Action="None" />
<Rule Id="C6995" Action="None" />
</Rules>
</RuleSet>

98
CodeCoverage.runsettings Normal file
Просмотреть файл

@ -0,0 +1,98 @@
<?xml version="1.0" encoding="utf-8"?>
<RunSettings>
<DataCollectionRunSettings>
<DataCollectors>
<DataCollector friendlyName="Code Coverage" uri="datacollector://Microsoft/CodeCoverage/2.0" assemblyQualifiedName="Microsoft.VisualStudio.Coverage.DynamicCoverageDataCollector, Microsoft.VisualStudio.TraceCollector, Version=11.0.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a">
<Configuration>
<CodeCoverage>
<ModulePaths>
<!--
About include/exclude lists:
Empty "Include" clauses imply all; empty "Exclude" clauses imply none.
Each element in the list is a regular expression (ECMAScript syntax).
An item must first match at least one entry in the include list to be included.
Included items must then not match any entries in the exclude list to remain included.
It is considered an error to exclude all items from instrumentation as no data would be collected.
-->
<Include>
<!-- Include modules of interest, by their name / path -->
<ModulePath>.*microsoft.ml.probabilistic.compiler.dll</ModulePath>
<ModulePath>.*microsoft.ml.probabilistic.dll</ModulePath>
<ModulePath>.*microsoft.ml.probabilistic.learners.dll</ModulePath>
<ModulePath>.*microsoft.ml.probabilistic.learners.classifier.dll</ModulePath>
<ModulePath>.*microsoft.ml.probabilistic.learners.recommender.dll</ModulePath>
<ModulePath>.*learner.exe</ModulePath>
<ModulePath>.*microsoft.ml.probabilistic.learners.runners.dll</ModulePath>
</Include>
<Exclude>
<!-- Do not specify any excludes. Anything not included will get excluded -->
</Exclude>
</ModulePaths>
<UseVerifiableInstrumentation>True</UseVerifiableInstrumentation>
<AllowLowIntegrityProcesses>True</AllowLowIntegrityProcesses>
<CollectFromChildProcesses>True</CollectFromChildProcesses>
<CollectAspDotNet>False</CollectAspDotNet>
<!--
Additional paths to search for symbol files. Symbols must be found for modules to be instrumented.
If symbols are alongside the binaries, they are automatically picked up. Otherwise specify the here.
Note that searching for symbols increases code coverage runtime. So keep this small and local.
<SymbolSearchPaths>
<Path>C:\Users\User\Documents\Visual Studio 11\Projects\ProjectX\bin\Debug</Path>
<Path>\\mybuildshare\builds\ProjectX</Path>
</SymbolSearchPaths>
-->
<Functions>
<Exclude>
<Function>^std::.*</Function>
<Function>^ATL::.*</Function>
<Function>.*::__GetTestMethodInfo.*</Function>
<Function>^Microsoft::VisualStudio::CppCodeCoverageFramework::.*</Function>
<Function>^Microsoft::VisualStudio::CppUnitTestFramework::.*</Function>
<Function>.*::YOU_CAN_ONLY_DESIGNATE_ONE_.*</Function>
</Exclude>
</Functions>
<Attributes>
<Exclude>
<Attribute>^System.Diagnostics.DebuggerHiddenAttribute$</Attribute>
<Attribute>^System.Diagnostics.DebuggerNonUserCodeAttribute$</Attribute>
<Attribute>^System.Runtime.CompilerServices.CompilerGeneratedAttribute$</Attribute>
<Attribute>^System.CodeDom.Compiler.GeneratedCodeAttribute$</Attribute>
<Attribute>^System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverageAttribute$</Attribute>
</Exclude>
</Attributes>
<Sources>
<Exclude>
<Source>.*\\atlmfc\\.*</Source>
<Source>.*\\vctools\\.*</Source>
<Source>.*\\public\\sdk\\.*</Source>
<Source>.*\\microsoft sdks\\.*</Source>
<Source>.*\\vc\\include\\.*</Source>
</Exclude>
</Sources>
<CompanyNames>
<Exclude>
<!--
<CompanyName>.*microsoft.*</CompanyName>
-->
</Exclude>
</CompanyNames>
<PublicKeyTokens>
<Exclude>
<PublicKeyToken>^B77A5C561934E089$</PublicKeyToken>
<PublicKeyToken>^B03F5F7F11D50A3A$</PublicKeyToken>
<PublicKeyToken>^31BF3856AD364E35$</PublicKeyToken>
<PublicKeyToken>^89845DCD8080CC91$</PublicKeyToken>
<PublicKeyToken>^71E9BCE111E9429C$</PublicKeyToken>
<PublicKeyToken>^8F50407C4E9E73B6$</PublicKeyToken>
<PublicKeyToken>^E361AF139669C375$</PublicKeyToken>
</Exclude>
</PublicKeyTokens>
</CodeCoverage>
</Configuration>
</DataCollector>
</DataCollectors>
</DataCollectionRunSettings>
</RunSettings>

430
Infer2.sln Normal file
Просмотреть файл

@ -0,0 +1,430 @@
Microsoft Visual Studio Solution File, Format Version 12.00
# Visual Studio 15
VisualStudioVersion = 15.0.27428.2005
MinimumVisualStudioVersion = 10.0.40219.1
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{A181C943-2E01-454D-9008-2E3C53AA09CC}"
ProjectSection(SolutionItems) = preProject
Infer2.snk = Infer2.snk
TestRunConfig1.testrunconfig = TestRunConfig1.testrunconfig
EndProjectSection
ProjectSection(FolderStartupServices) = postProject
{B4F97281-0DBD-4835-9ED8-7DFB966E87FF} = {B4F97281-0DBD-4835-9ED8-7DFB966E87FF}
EndProjectSection
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TestApp", "test\TestApp\TestApp.csproj", "{2A61553C-A089-4310-AA63-BC130C1AEE6E}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tests", "test\Tests\Tests.csproj", "{CEFE65E8-29F8-4B2E-809B-96D7DA924430}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tutorials", "src\Tutorials\Tutorials.csproj", "{04787AC6-CAC0-4D77-8C4A-2EA66C3CBE3C}"
ProjectSection(ProjectDependencies) = postProject
{FB669026-E1C8-417F-962B-D8235A3DA2D3} = {FB669026-E1C8-417F-962B-D8235A3DA2D3}
{6DF146DD-6CE4-40B9-9B89-5E7DFAA54B14} = {6DF146DD-6CE4-40B9-9B89-5E7DFAA54B14}
{DF2DB4DE-A48C-4309-B867-CDE55512BB72} = {DF2DB4DE-A48C-4309-B867-CDE55512BB72}
EndProjectSection
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Runtime", "src\Runtime\Runtime.csproj", "{DF2DB4DE-A48C-4309-B867-CDE55512BB72}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Compiler", "src\Compiler\Compiler.csproj", "{6DF146DD-6CE4-40B9-9B89-5E7DFAA54B14}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TestPublic", "test\TestPublic\TestPublic.csproj", "{2D65E625-54CE-4946-98E2-22427B47F186}"
EndProject
Project("{6EC3EE1D-3C4E-46DD-8F32-0CC8E7565705}") = "FSharpWrapper", "src\FSharpWrapper\FSharpWrapper.fsproj", "{F8F34F1F-F222-4792-9BF8-D6AB3018E6F5}"
EndProject
Project("{6EC3EE1D-3C4E-46DD-8F32-0CC8E7565705}") = "TestFSharp", "test\TestFSharp\TestFSharp.fsproj", "{945BA80B-A7F7-450B-AE93-9FA2F7617442}"
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Examples", "Examples", "{DC5F5BC4-CDB0-41F7-8B03-CD4C38C8DEB2}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "ClickThroughModel", "src\Examples\ClickThroughModel\ClickThroughModel.csproj", "{33D86EA2-2161-4EF0-8F17-59602296273C}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "ClinicalTrial", "src\Examples\ClinicalTrial\ClinicalTrial.csproj", "{B517BBF2-60E6-4C69-885A-AE5C014D877B}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "InferNET101", "src\Examples\InferNET101\InferNET101.csproj", "{52D174E7-2407-4FC1-9DDA-4D9D14F18618}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "MontyHall", "src\Examples\MontyHall\MontyHall.csproj", "{6139CF19-0190-4ED5-AEE3-D3CE7458E517}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "MotifFinder", "src\Examples\MotifFinder\MotifFinder.csproj", "{D2A7B5F5-8D33-45AC-9776-07C23F5859BB}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Image_Classifier", "src\Examples\ImageClassifier\Image_Classifier.csproj", "{87D09BD4-119E-49C1-B0B4-86DF962A00EE}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "LDA", "src\Examples\LDA\LDA.csproj", "{6FF3E672-378C-4D61-B4CA-A5A5E01C2563}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Csoft", "src\Csoft\Csoft.csproj", "{5B669C82-B04C-4DD6-8CE6-47D025D98777}"
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Learners", "Learners", "{2964BB90-4E6D-49ED-AA35-645D94337C76}"
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Runners", "Runners", "{3DB795A6-5FE8-447C-89B9-9E608285C6F8}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "CommandLine", "src\Learners\Runners\CommandLine\CommandLine.csproj", "{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Common", "src\Learners\Runners\Common\Common.csproj", "{25D28099-E338-4543-B1DE-261439654CA6}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Evaluator", "src\Learners\Runners\Evaluator\Evaluator.csproj", "{040FA938-BE24-4391-86BA-D04B331A787A}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Classifier", "src\Learners\Classifier\Classifier.csproj", "{07E9E91D-6593-4FF9-A266-270ED5241C98}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "ClassifierModels", "src\Learners\ClassifierModels\ClassifierModels.csproj", "{3207C90B-DB71-4293-8974-0795863C076B}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Core", "src\Learners\Core\Core.csproj", "{29A1A83E-51DC-4409-B29B-CDE7D65A1EF8}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Recommender", "src\Learners\Recommender\Recommender.csproj", "{5AB7D09F-5F98-465E-AB9D-07014F1DBC3F}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "RecommenderModels", "src\Learners\RecommenderModels\RecommenderModels.csproj", "{8D4D5502-4321-46A5-975B-C20BD745FC06}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TestApp", "test\Learners\TestApp\TestApp.csproj", "{E2409457-2BA1-47C8-B53B-CE712896FE6E}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LearnersTests", "test\Learners\LearnersTests\LearnersTests.csproj", "{7A774F1F-31D6-4D7F-90D5-9C4F387D2EEE}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Visualizers.Windows", "src\Visualizers\Windows\Visualizers.Windows.csproj", "{FB669026-E1C8-417F-962B-D8235A3DA2D3}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
DebugCore|Any CPU = DebugCore|Any CPU
DebugFull|Any CPU = DebugFull|Any CPU
Release|Any CPU = Release|Any CPU
ReleaseCore|Any CPU = ReleaseCore|Any CPU
ReleaseFull|Any CPU = ReleaseFull|Any CPU
EndGlobalSection
GlobalSection(ProjectConfigurationPlatforms) = postSolution
{2A61553C-A089-4310-AA63-BC130C1AEE6E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{2A61553C-A089-4310-AA63-BC130C1AEE6E}.Debug|Any CPU.Build.0 = Debug|Any CPU
{2A61553C-A089-4310-AA63-BC130C1AEE6E}.DebugCore|Any CPU.ActiveCfg = DebugCore|Any CPU
{2A61553C-A089-4310-AA63-BC130C1AEE6E}.DebugCore|Any CPU.Build.0 = DebugCore|Any CPU
{2A61553C-A089-4310-AA63-BC130C1AEE6E}.DebugFull|Any CPU.ActiveCfg = DebugFull|Any CPU
{2A61553C-A089-4310-AA63-BC130C1AEE6E}.DebugFull|Any CPU.Build.0 = DebugFull|Any CPU
{2A61553C-A089-4310-AA63-BC130C1AEE6E}.Release|Any CPU.ActiveCfg = Release|Any CPU
{2A61553C-A089-4310-AA63-BC130C1AEE6E}.Release|Any CPU.Build.0 = Release|Any CPU
{2A61553C-A089-4310-AA63-BC130C1AEE6E}.ReleaseCore|Any CPU.ActiveCfg = ReleaseCore|Any CPU
{2A61553C-A089-4310-AA63-BC130C1AEE6E}.ReleaseCore|Any CPU.Build.0 = ReleaseCore|Any CPU
{2A61553C-A089-4310-AA63-BC130C1AEE6E}.ReleaseFull|Any CPU.ActiveCfg = ReleaseFull|Any CPU
{2A61553C-A089-4310-AA63-BC130C1AEE6E}.ReleaseFull|Any CPU.Build.0 = ReleaseFull|Any CPU
{CEFE65E8-29F8-4B2E-809B-96D7DA924430}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{CEFE65E8-29F8-4B2E-809B-96D7DA924430}.Debug|Any CPU.Build.0 = Debug|Any CPU
{CEFE65E8-29F8-4B2E-809B-96D7DA924430}.DebugCore|Any CPU.ActiveCfg = DebugCore|Any CPU
{CEFE65E8-29F8-4B2E-809B-96D7DA924430}.DebugCore|Any CPU.Build.0 = DebugCore|Any CPU
{CEFE65E8-29F8-4B2E-809B-96D7DA924430}.DebugFull|Any CPU.ActiveCfg = DebugFull|Any CPU
{CEFE65E8-29F8-4B2E-809B-96D7DA924430}.DebugFull|Any CPU.Build.0 = DebugFull|Any CPU
{CEFE65E8-29F8-4B2E-809B-96D7DA924430}.Release|Any CPU.ActiveCfg = Release|Any CPU
{CEFE65E8-29F8-4B2E-809B-96D7DA924430}.Release|Any CPU.Build.0 = Release|Any CPU
{CEFE65E8-29F8-4B2E-809B-96D7DA924430}.ReleaseCore|Any CPU.ActiveCfg = ReleaseCore|Any CPU
{CEFE65E8-29F8-4B2E-809B-96D7DA924430}.ReleaseCore|Any CPU.Build.0 = ReleaseCore|Any CPU
{CEFE65E8-29F8-4B2E-809B-96D7DA924430}.ReleaseFull|Any CPU.ActiveCfg = ReleaseFull|Any CPU
{CEFE65E8-29F8-4B2E-809B-96D7DA924430}.ReleaseFull|Any CPU.Build.0 = ReleaseFull|Any CPU
{04787AC6-CAC0-4D77-8C4A-2EA66C3CBE3C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{04787AC6-CAC0-4D77-8C4A-2EA66C3CBE3C}.Debug|Any CPU.Build.0 = Debug|Any CPU
{04787AC6-CAC0-4D77-8C4A-2EA66C3CBE3C}.DebugCore|Any CPU.ActiveCfg = DebugCore|Any CPU
{04787AC6-CAC0-4D77-8C4A-2EA66C3CBE3C}.DebugCore|Any CPU.Build.0 = DebugCore|Any CPU
{04787AC6-CAC0-4D77-8C4A-2EA66C3CBE3C}.DebugFull|Any CPU.ActiveCfg = DebugFull|Any CPU
{04787AC6-CAC0-4D77-8C4A-2EA66C3CBE3C}.DebugFull|Any CPU.Build.0 = DebugFull|Any CPU
{04787AC6-CAC0-4D77-8C4A-2EA66C3CBE3C}.Release|Any CPU.ActiveCfg = Release|Any CPU
{04787AC6-CAC0-4D77-8C4A-2EA66C3CBE3C}.Release|Any CPU.Build.0 = Release|Any CPU
{04787AC6-CAC0-4D77-8C4A-2EA66C3CBE3C}.ReleaseCore|Any CPU.ActiveCfg = ReleaseCore|Any CPU
{04787AC6-CAC0-4D77-8C4A-2EA66C3CBE3C}.ReleaseCore|Any CPU.Build.0 = ReleaseCore|Any CPU
{04787AC6-CAC0-4D77-8C4A-2EA66C3CBE3C}.ReleaseFull|Any CPU.ActiveCfg = ReleaseFull|Any CPU
{04787AC6-CAC0-4D77-8C4A-2EA66C3CBE3C}.ReleaseFull|Any CPU.Build.0 = ReleaseFull|Any CPU
{DF2DB4DE-A48C-4309-B867-CDE55512BB72}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{DF2DB4DE-A48C-4309-B867-CDE55512BB72}.Debug|Any CPU.Build.0 = Debug|Any CPU
{DF2DB4DE-A48C-4309-B867-CDE55512BB72}.DebugCore|Any CPU.ActiveCfg = Debug|Any CPU
{DF2DB4DE-A48C-4309-B867-CDE55512BB72}.DebugCore|Any CPU.Build.0 = Debug|Any CPU
{DF2DB4DE-A48C-4309-B867-CDE55512BB72}.DebugFull|Any CPU.ActiveCfg = Debug|Any CPU
{DF2DB4DE-A48C-4309-B867-CDE55512BB72}.DebugFull|Any CPU.Build.0 = Debug|Any CPU
{DF2DB4DE-A48C-4309-B867-CDE55512BB72}.Release|Any CPU.ActiveCfg = Release|Any CPU
{DF2DB4DE-A48C-4309-B867-CDE55512BB72}.Release|Any CPU.Build.0 = Release|Any CPU
{DF2DB4DE-A48C-4309-B867-CDE55512BB72}.ReleaseCore|Any CPU.ActiveCfg = Release|Any CPU
{DF2DB4DE-A48C-4309-B867-CDE55512BB72}.ReleaseCore|Any CPU.Build.0 = Release|Any CPU
{DF2DB4DE-A48C-4309-B867-CDE55512BB72}.ReleaseFull|Any CPU.ActiveCfg = Release|Any CPU
{DF2DB4DE-A48C-4309-B867-CDE55512BB72}.ReleaseFull|Any CPU.Build.0 = Release|Any CPU
{6DF146DD-6CE4-40B9-9B89-5E7DFAA54B14}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{6DF146DD-6CE4-40B9-9B89-5E7DFAA54B14}.Debug|Any CPU.Build.0 = Debug|Any CPU
{6DF146DD-6CE4-40B9-9B89-5E7DFAA54B14}.DebugCore|Any CPU.ActiveCfg = Debug|Any CPU
{6DF146DD-6CE4-40B9-9B89-5E7DFAA54B14}.DebugCore|Any CPU.Build.0 = Debug|Any CPU
{6DF146DD-6CE4-40B9-9B89-5E7DFAA54B14}.DebugFull|Any CPU.ActiveCfg = Debug|Any CPU
{6DF146DD-6CE4-40B9-9B89-5E7DFAA54B14}.DebugFull|Any CPU.Build.0 = Debug|Any CPU
{6DF146DD-6CE4-40B9-9B89-5E7DFAA54B14}.Release|Any CPU.ActiveCfg = Release|Any CPU
{6DF146DD-6CE4-40B9-9B89-5E7DFAA54B14}.Release|Any CPU.Build.0 = Release|Any CPU
{6DF146DD-6CE4-40B9-9B89-5E7DFAA54B14}.ReleaseCore|Any CPU.ActiveCfg = Release|Any CPU
{6DF146DD-6CE4-40B9-9B89-5E7DFAA54B14}.ReleaseCore|Any CPU.Build.0 = Release|Any CPU
{6DF146DD-6CE4-40B9-9B89-5E7DFAA54B14}.ReleaseFull|Any CPU.ActiveCfg = Release|Any CPU
{6DF146DD-6CE4-40B9-9B89-5E7DFAA54B14}.ReleaseFull|Any CPU.Build.0 = Release|Any CPU
{2D65E625-54CE-4946-98E2-22427B47F186}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{2D65E625-54CE-4946-98E2-22427B47F186}.Debug|Any CPU.Build.0 = Debug|Any CPU
{2D65E625-54CE-4946-98E2-22427B47F186}.DebugCore|Any CPU.ActiveCfg = DebugCore|Any CPU
{2D65E625-54CE-4946-98E2-22427B47F186}.DebugCore|Any CPU.Build.0 = DebugCore|Any CPU
{2D65E625-54CE-4946-98E2-22427B47F186}.DebugFull|Any CPU.ActiveCfg = DebugFull|Any CPU
{2D65E625-54CE-4946-98E2-22427B47F186}.DebugFull|Any CPU.Build.0 = DebugFull|Any CPU
{2D65E625-54CE-4946-98E2-22427B47F186}.Release|Any CPU.ActiveCfg = Release|Any CPU
{2D65E625-54CE-4946-98E2-22427B47F186}.Release|Any CPU.Build.0 = Release|Any CPU
{2D65E625-54CE-4946-98E2-22427B47F186}.ReleaseCore|Any CPU.ActiveCfg = ReleaseCore|Any CPU
{2D65E625-54CE-4946-98E2-22427B47F186}.ReleaseCore|Any CPU.Build.0 = ReleaseCore|Any CPU
{2D65E625-54CE-4946-98E2-22427B47F186}.ReleaseFull|Any CPU.ActiveCfg = ReleaseFull|Any CPU
{2D65E625-54CE-4946-98E2-22427B47F186}.ReleaseFull|Any CPU.Build.0 = ReleaseFull|Any CPU
{F8F34F1F-F222-4792-9BF8-D6AB3018E6F5}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{F8F34F1F-F222-4792-9BF8-D6AB3018E6F5}.Debug|Any CPU.Build.0 = Debug|Any CPU
{F8F34F1F-F222-4792-9BF8-D6AB3018E6F5}.DebugCore|Any CPU.ActiveCfg = Debug|Any CPU
{F8F34F1F-F222-4792-9BF8-D6AB3018E6F5}.DebugCore|Any CPU.Build.0 = Debug|Any CPU
{F8F34F1F-F222-4792-9BF8-D6AB3018E6F5}.DebugFull|Any CPU.ActiveCfg = Debug|Any CPU
{F8F34F1F-F222-4792-9BF8-D6AB3018E6F5}.DebugFull|Any CPU.Build.0 = Debug|Any CPU
{F8F34F1F-F222-4792-9BF8-D6AB3018E6F5}.Release|Any CPU.ActiveCfg = Release|Any CPU
{F8F34F1F-F222-4792-9BF8-D6AB3018E6F5}.Release|Any CPU.Build.0 = Release|Any CPU
{F8F34F1F-F222-4792-9BF8-D6AB3018E6F5}.ReleaseCore|Any CPU.ActiveCfg = Release|Any CPU
{F8F34F1F-F222-4792-9BF8-D6AB3018E6F5}.ReleaseCore|Any CPU.Build.0 = Release|Any CPU
{F8F34F1F-F222-4792-9BF8-D6AB3018E6F5}.ReleaseFull|Any CPU.ActiveCfg = Release|Any CPU
{F8F34F1F-F222-4792-9BF8-D6AB3018E6F5}.ReleaseFull|Any CPU.Build.0 = Release|Any CPU
{945BA80B-A7F7-450B-AE93-9FA2F7617442}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{945BA80B-A7F7-450B-AE93-9FA2F7617442}.Debug|Any CPU.Build.0 = Debug|Any CPU
{945BA80B-A7F7-450B-AE93-9FA2F7617442}.DebugCore|Any CPU.ActiveCfg = DebugCore|Any CPU
{945BA80B-A7F7-450B-AE93-9FA2F7617442}.DebugCore|Any CPU.Build.0 = DebugCore|Any CPU
{945BA80B-A7F7-450B-AE93-9FA2F7617442}.DebugFull|Any CPU.ActiveCfg = DebugFull|Any CPU
{945BA80B-A7F7-450B-AE93-9FA2F7617442}.DebugFull|Any CPU.Build.0 = DebugFull|Any CPU
{945BA80B-A7F7-450B-AE93-9FA2F7617442}.Release|Any CPU.ActiveCfg = Release|Any CPU
{945BA80B-A7F7-450B-AE93-9FA2F7617442}.Release|Any CPU.Build.0 = Release|Any CPU
{945BA80B-A7F7-450B-AE93-9FA2F7617442}.ReleaseCore|Any CPU.ActiveCfg = ReleaseCore|Any CPU
{945BA80B-A7F7-450B-AE93-9FA2F7617442}.ReleaseCore|Any CPU.Build.0 = ReleaseCore|Any CPU
{945BA80B-A7F7-450B-AE93-9FA2F7617442}.ReleaseFull|Any CPU.ActiveCfg = ReleaseFull|Any CPU
{945BA80B-A7F7-450B-AE93-9FA2F7617442}.ReleaseFull|Any CPU.Build.0 = ReleaseFull|Any CPU
{33D86EA2-2161-4EF0-8F17-59602296273C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{33D86EA2-2161-4EF0-8F17-59602296273C}.Debug|Any CPU.Build.0 = Debug|Any CPU
{33D86EA2-2161-4EF0-8F17-59602296273C}.DebugCore|Any CPU.ActiveCfg = DebugCore|Any CPU
{33D86EA2-2161-4EF0-8F17-59602296273C}.DebugFull|Any CPU.ActiveCfg = DebugFull|Any CPU
{33D86EA2-2161-4EF0-8F17-59602296273C}.DebugFull|Any CPU.Build.0 = DebugFull|Any CPU
{33D86EA2-2161-4EF0-8F17-59602296273C}.Release|Any CPU.ActiveCfg = Release|Any CPU
{33D86EA2-2161-4EF0-8F17-59602296273C}.Release|Any CPU.Build.0 = Release|Any CPU
{33D86EA2-2161-4EF0-8F17-59602296273C}.ReleaseCore|Any CPU.ActiveCfg = ReleaseCore|Any CPU
{33D86EA2-2161-4EF0-8F17-59602296273C}.ReleaseFull|Any CPU.ActiveCfg = ReleaseFull|Any CPU
{33D86EA2-2161-4EF0-8F17-59602296273C}.ReleaseFull|Any CPU.Build.0 = ReleaseFull|Any CPU
{B517BBF2-60E6-4C69-885A-AE5C014D877B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{B517BBF2-60E6-4C69-885A-AE5C014D877B}.Debug|Any CPU.Build.0 = Debug|Any CPU
{B517BBF2-60E6-4C69-885A-AE5C014D877B}.DebugCore|Any CPU.ActiveCfg = DebugCore|Any CPU
{B517BBF2-60E6-4C69-885A-AE5C014D877B}.DebugFull|Any CPU.ActiveCfg = DebugFull|Any CPU
{B517BBF2-60E6-4C69-885A-AE5C014D877B}.DebugFull|Any CPU.Build.0 = DebugFull|Any CPU
{B517BBF2-60E6-4C69-885A-AE5C014D877B}.Release|Any CPU.ActiveCfg = Release|Any CPU
{B517BBF2-60E6-4C69-885A-AE5C014D877B}.Release|Any CPU.Build.0 = Release|Any CPU
{B517BBF2-60E6-4C69-885A-AE5C014D877B}.ReleaseCore|Any CPU.ActiveCfg = ReleaseCore|Any CPU
{B517BBF2-60E6-4C69-885A-AE5C014D877B}.ReleaseFull|Any CPU.ActiveCfg = ReleaseFull|Any CPU
{B517BBF2-60E6-4C69-885A-AE5C014D877B}.ReleaseFull|Any CPU.Build.0 = ReleaseFull|Any CPU
{52D174E7-2407-4FC1-9DDA-4D9D14F18618}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{52D174E7-2407-4FC1-9DDA-4D9D14F18618}.Debug|Any CPU.Build.0 = Debug|Any CPU
{52D174E7-2407-4FC1-9DDA-4D9D14F18618}.DebugCore|Any CPU.ActiveCfg = DebugCore|Any CPU
{52D174E7-2407-4FC1-9DDA-4D9D14F18618}.DebugFull|Any CPU.ActiveCfg = DebugFull|Any CPU
{52D174E7-2407-4FC1-9DDA-4D9D14F18618}.DebugFull|Any CPU.Build.0 = DebugFull|Any CPU
{52D174E7-2407-4FC1-9DDA-4D9D14F18618}.Release|Any CPU.ActiveCfg = Release|Any CPU
{52D174E7-2407-4FC1-9DDA-4D9D14F18618}.Release|Any CPU.Build.0 = Release|Any CPU
{52D174E7-2407-4FC1-9DDA-4D9D14F18618}.ReleaseCore|Any CPU.ActiveCfg = ReleaseCore|Any CPU
{52D174E7-2407-4FC1-9DDA-4D9D14F18618}.ReleaseFull|Any CPU.ActiveCfg = ReleaseFull|Any CPU
{52D174E7-2407-4FC1-9DDA-4D9D14F18618}.ReleaseFull|Any CPU.Build.0 = ReleaseFull|Any CPU
{6139CF19-0190-4ED5-AEE3-D3CE7458E517}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{6139CF19-0190-4ED5-AEE3-D3CE7458E517}.Debug|Any CPU.Build.0 = Debug|Any CPU
{6139CF19-0190-4ED5-AEE3-D3CE7458E517}.DebugCore|Any CPU.ActiveCfg = DebugCore|Any CPU
{6139CF19-0190-4ED5-AEE3-D3CE7458E517}.DebugFull|Any CPU.ActiveCfg = DebugFull|Any CPU
{6139CF19-0190-4ED5-AEE3-D3CE7458E517}.DebugFull|Any CPU.Build.0 = DebugFull|Any CPU
{6139CF19-0190-4ED5-AEE3-D3CE7458E517}.Release|Any CPU.ActiveCfg = Release|Any CPU
{6139CF19-0190-4ED5-AEE3-D3CE7458E517}.Release|Any CPU.Build.0 = Release|Any CPU
{6139CF19-0190-4ED5-AEE3-D3CE7458E517}.ReleaseCore|Any CPU.ActiveCfg = ReleaseCore|Any CPU
{6139CF19-0190-4ED5-AEE3-D3CE7458E517}.ReleaseFull|Any CPU.ActiveCfg = ReleaseFull|Any CPU
{6139CF19-0190-4ED5-AEE3-D3CE7458E517}.ReleaseFull|Any CPU.Build.0 = ReleaseFull|Any CPU
{D2A7B5F5-8D33-45AC-9776-07C23F5859BB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{D2A7B5F5-8D33-45AC-9776-07C23F5859BB}.Debug|Any CPU.Build.0 = Debug|Any CPU
{D2A7B5F5-8D33-45AC-9776-07C23F5859BB}.DebugCore|Any CPU.ActiveCfg = DebugCore|Any CPU
{D2A7B5F5-8D33-45AC-9776-07C23F5859BB}.DebugFull|Any CPU.ActiveCfg = DebugFull|Any CPU
{D2A7B5F5-8D33-45AC-9776-07C23F5859BB}.DebugFull|Any CPU.Build.0 = DebugFull|Any CPU
{D2A7B5F5-8D33-45AC-9776-07C23F5859BB}.Release|Any CPU.ActiveCfg = Release|Any CPU
{D2A7B5F5-8D33-45AC-9776-07C23F5859BB}.Release|Any CPU.Build.0 = Release|Any CPU
{D2A7B5F5-8D33-45AC-9776-07C23F5859BB}.ReleaseCore|Any CPU.ActiveCfg = ReleaseCore|Any CPU
{D2A7B5F5-8D33-45AC-9776-07C23F5859BB}.ReleaseFull|Any CPU.ActiveCfg = ReleaseFull|Any CPU
{D2A7B5F5-8D33-45AC-9776-07C23F5859BB}.ReleaseFull|Any CPU.Build.0 = ReleaseFull|Any CPU
{87D09BD4-119E-49C1-B0B4-86DF962A00EE}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{87D09BD4-119E-49C1-B0B4-86DF962A00EE}.Debug|Any CPU.Build.0 = Debug|Any CPU
{87D09BD4-119E-49C1-B0B4-86DF962A00EE}.DebugCore|Any CPU.ActiveCfg = DebugCore|Any CPU
{87D09BD4-119E-49C1-B0B4-86DF962A00EE}.DebugFull|Any CPU.ActiveCfg = DebugFull|Any CPU
{87D09BD4-119E-49C1-B0B4-86DF962A00EE}.DebugFull|Any CPU.Build.0 = DebugFull|Any CPU
{87D09BD4-119E-49C1-B0B4-86DF962A00EE}.Release|Any CPU.ActiveCfg = Release|Any CPU
{87D09BD4-119E-49C1-B0B4-86DF962A00EE}.Release|Any CPU.Build.0 = Release|Any CPU
{87D09BD4-119E-49C1-B0B4-86DF962A00EE}.ReleaseCore|Any CPU.ActiveCfg = ReleaseCore|Any CPU
{87D09BD4-119E-49C1-B0B4-86DF962A00EE}.ReleaseFull|Any CPU.ActiveCfg = ReleaseFull|Any CPU
{87D09BD4-119E-49C1-B0B4-86DF962A00EE}.ReleaseFull|Any CPU.Build.0 = ReleaseFull|Any CPU
{6FF3E672-378C-4D61-B4CA-A5A5E01C2563}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{6FF3E672-378C-4D61-B4CA-A5A5E01C2563}.Debug|Any CPU.Build.0 = Debug|Any CPU
{6FF3E672-378C-4D61-B4CA-A5A5E01C2563}.DebugCore|Any CPU.ActiveCfg = DebugCore|Any CPU
{6FF3E672-378C-4D61-B4CA-A5A5E01C2563}.DebugFull|Any CPU.ActiveCfg = DebugFull|Any CPU
{6FF3E672-378C-4D61-B4CA-A5A5E01C2563}.DebugFull|Any CPU.Build.0 = DebugFull|Any CPU
{6FF3E672-378C-4D61-B4CA-A5A5E01C2563}.Release|Any CPU.ActiveCfg = Release|Any CPU
{6FF3E672-378C-4D61-B4CA-A5A5E01C2563}.Release|Any CPU.Build.0 = Release|Any CPU
{6FF3E672-378C-4D61-B4CA-A5A5E01C2563}.ReleaseCore|Any CPU.ActiveCfg = ReleaseCore|Any CPU
{6FF3E672-378C-4D61-B4CA-A5A5E01C2563}.ReleaseFull|Any CPU.ActiveCfg = ReleaseFull|Any CPU
{6FF3E672-378C-4D61-B4CA-A5A5E01C2563}.ReleaseFull|Any CPU.Build.0 = ReleaseFull|Any CPU
{5B669C82-B04C-4DD6-8CE6-47D025D98777}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{5B669C82-B04C-4DD6-8CE6-47D025D98777}.Debug|Any CPU.Build.0 = Debug|Any CPU
{5B669C82-B04C-4DD6-8CE6-47D025D98777}.DebugCore|Any CPU.ActiveCfg = Debug|Any CPU
{5B669C82-B04C-4DD6-8CE6-47D025D98777}.DebugCore|Any CPU.Build.0 = Debug|Any CPU
{5B669C82-B04C-4DD6-8CE6-47D025D98777}.DebugFull|Any CPU.ActiveCfg = Debug|Any CPU
{5B669C82-B04C-4DD6-8CE6-47D025D98777}.DebugFull|Any CPU.Build.0 = Debug|Any CPU
{5B669C82-B04C-4DD6-8CE6-47D025D98777}.Release|Any CPU.ActiveCfg = Release|Any CPU
{5B669C82-B04C-4DD6-8CE6-47D025D98777}.Release|Any CPU.Build.0 = Release|Any CPU
{5B669C82-B04C-4DD6-8CE6-47D025D98777}.ReleaseCore|Any CPU.ActiveCfg = Release|Any CPU
{5B669C82-B04C-4DD6-8CE6-47D025D98777}.ReleaseCore|Any CPU.Build.0 = Release|Any CPU
{5B669C82-B04C-4DD6-8CE6-47D025D98777}.ReleaseFull|Any CPU.ActiveCfg = Release|Any CPU
{5B669C82-B04C-4DD6-8CE6-47D025D98777}.ReleaseFull|Any CPU.Build.0 = Release|Any CPU
{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96}.Debug|Any CPU.Build.0 = Debug|Any CPU
{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96}.DebugCore|Any CPU.ActiveCfg = DebugCore|Any CPU
{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96}.DebugCore|Any CPU.Build.0 = DebugCore|Any CPU
{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96}.DebugFull|Any CPU.ActiveCfg = DebugFull|Any CPU
{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96}.DebugFull|Any CPU.Build.0 = DebugFull|Any CPU
{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96}.Release|Any CPU.ActiveCfg = Release|Any CPU
{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96}.Release|Any CPU.Build.0 = Release|Any CPU
{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96}.ReleaseCore|Any CPU.ActiveCfg = ReleaseCore|Any CPU
{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96}.ReleaseCore|Any CPU.Build.0 = ReleaseCore|Any CPU
{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96}.ReleaseFull|Any CPU.ActiveCfg = ReleaseFull|Any CPU
{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96}.ReleaseFull|Any CPU.Build.0 = ReleaseFull|Any CPU
{25D28099-E338-4543-B1DE-261439654CA6}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{25D28099-E338-4543-B1DE-261439654CA6}.Debug|Any CPU.Build.0 = Debug|Any CPU
{25D28099-E338-4543-B1DE-261439654CA6}.DebugCore|Any CPU.ActiveCfg = Debug|Any CPU
{25D28099-E338-4543-B1DE-261439654CA6}.DebugCore|Any CPU.Build.0 = Debug|Any CPU
{25D28099-E338-4543-B1DE-261439654CA6}.DebugFull|Any CPU.ActiveCfg = Debug|Any CPU
{25D28099-E338-4543-B1DE-261439654CA6}.DebugFull|Any CPU.Build.0 = Debug|Any CPU
{25D28099-E338-4543-B1DE-261439654CA6}.Release|Any CPU.ActiveCfg = Release|Any CPU
{25D28099-E338-4543-B1DE-261439654CA6}.Release|Any CPU.Build.0 = Release|Any CPU
{25D28099-E338-4543-B1DE-261439654CA6}.ReleaseCore|Any CPU.ActiveCfg = Release|Any CPU
{25D28099-E338-4543-B1DE-261439654CA6}.ReleaseCore|Any CPU.Build.0 = Release|Any CPU
{25D28099-E338-4543-B1DE-261439654CA6}.ReleaseFull|Any CPU.ActiveCfg = Release|Any CPU
{25D28099-E338-4543-B1DE-261439654CA6}.ReleaseFull|Any CPU.Build.0 = Release|Any CPU
{040FA938-BE24-4391-86BA-D04B331A787A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{040FA938-BE24-4391-86BA-D04B331A787A}.Debug|Any CPU.Build.0 = Debug|Any CPU
{040FA938-BE24-4391-86BA-D04B331A787A}.DebugCore|Any CPU.ActiveCfg = DebugCore|Any CPU
{040FA938-BE24-4391-86BA-D04B331A787A}.DebugCore|Any CPU.Build.0 = DebugCore|Any CPU
{040FA938-BE24-4391-86BA-D04B331A787A}.DebugFull|Any CPU.ActiveCfg = DebugFull|Any CPU
{040FA938-BE24-4391-86BA-D04B331A787A}.DebugFull|Any CPU.Build.0 = DebugFull|Any CPU
{040FA938-BE24-4391-86BA-D04B331A787A}.Release|Any CPU.ActiveCfg = Release|Any CPU
{040FA938-BE24-4391-86BA-D04B331A787A}.Release|Any CPU.Build.0 = Release|Any CPU
{040FA938-BE24-4391-86BA-D04B331A787A}.ReleaseCore|Any CPU.ActiveCfg = ReleaseCore|Any CPU
{040FA938-BE24-4391-86BA-D04B331A787A}.ReleaseCore|Any CPU.Build.0 = ReleaseCore|Any CPU
{040FA938-BE24-4391-86BA-D04B331A787A}.ReleaseFull|Any CPU.ActiveCfg = ReleaseFull|Any CPU
{040FA938-BE24-4391-86BA-D04B331A787A}.ReleaseFull|Any CPU.Build.0 = ReleaseFull|Any CPU
{07E9E91D-6593-4FF9-A266-270ED5241C98}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{07E9E91D-6593-4FF9-A266-270ED5241C98}.Debug|Any CPU.Build.0 = Debug|Any CPU
{07E9E91D-6593-4FF9-A266-270ED5241C98}.DebugCore|Any CPU.ActiveCfg = Debug|Any CPU
{07E9E91D-6593-4FF9-A266-270ED5241C98}.DebugCore|Any CPU.Build.0 = Debug|Any CPU
{07E9E91D-6593-4FF9-A266-270ED5241C98}.DebugFull|Any CPU.ActiveCfg = Debug|Any CPU
{07E9E91D-6593-4FF9-A266-270ED5241C98}.DebugFull|Any CPU.Build.0 = Debug|Any CPU
{07E9E91D-6593-4FF9-A266-270ED5241C98}.Release|Any CPU.ActiveCfg = Release|Any CPU
{07E9E91D-6593-4FF9-A266-270ED5241C98}.Release|Any CPU.Build.0 = Release|Any CPU
{07E9E91D-6593-4FF9-A266-270ED5241C98}.ReleaseCore|Any CPU.ActiveCfg = Release|Any CPU
{07E9E91D-6593-4FF9-A266-270ED5241C98}.ReleaseCore|Any CPU.Build.0 = Release|Any CPU
{07E9E91D-6593-4FF9-A266-270ED5241C98}.ReleaseFull|Any CPU.ActiveCfg = Release|Any CPU
{07E9E91D-6593-4FF9-A266-270ED5241C98}.ReleaseFull|Any CPU.Build.0 = Release|Any CPU
{3207C90B-DB71-4293-8974-0795863C076B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{3207C90B-DB71-4293-8974-0795863C076B}.Debug|Any CPU.Build.0 = Debug|Any CPU
{3207C90B-DB71-4293-8974-0795863C076B}.DebugCore|Any CPU.ActiveCfg = DebugCore|Any CPU
{3207C90B-DB71-4293-8974-0795863C076B}.DebugCore|Any CPU.Build.0 = DebugCore|Any CPU
{3207C90B-DB71-4293-8974-0795863C076B}.DebugFull|Any CPU.ActiveCfg = DebugFull|Any CPU
{3207C90B-DB71-4293-8974-0795863C076B}.DebugFull|Any CPU.Build.0 = DebugFull|Any CPU
{3207C90B-DB71-4293-8974-0795863C076B}.Release|Any CPU.ActiveCfg = Release|Any CPU
{3207C90B-DB71-4293-8974-0795863C076B}.Release|Any CPU.Build.0 = Release|Any CPU
{3207C90B-DB71-4293-8974-0795863C076B}.ReleaseCore|Any CPU.ActiveCfg = ReleaseCore|Any CPU
{3207C90B-DB71-4293-8974-0795863C076B}.ReleaseCore|Any CPU.Build.0 = ReleaseCore|Any CPU
{3207C90B-DB71-4293-8974-0795863C076B}.ReleaseFull|Any CPU.ActiveCfg = ReleaseFull|Any CPU
{3207C90B-DB71-4293-8974-0795863C076B}.ReleaseFull|Any CPU.Build.0 = ReleaseFull|Any CPU
{29A1A83E-51DC-4409-B29B-CDE7D65A1EF8}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{29A1A83E-51DC-4409-B29B-CDE7D65A1EF8}.Debug|Any CPU.Build.0 = Debug|Any CPU
{29A1A83E-51DC-4409-B29B-CDE7D65A1EF8}.DebugCore|Any CPU.ActiveCfg = Debug|Any CPU
{29A1A83E-51DC-4409-B29B-CDE7D65A1EF8}.DebugCore|Any CPU.Build.0 = Debug|Any CPU
{29A1A83E-51DC-4409-B29B-CDE7D65A1EF8}.DebugFull|Any CPU.ActiveCfg = Debug|Any CPU
{29A1A83E-51DC-4409-B29B-CDE7D65A1EF8}.DebugFull|Any CPU.Build.0 = Debug|Any CPU
{29A1A83E-51DC-4409-B29B-CDE7D65A1EF8}.Release|Any CPU.ActiveCfg = Release|Any CPU
{29A1A83E-51DC-4409-B29B-CDE7D65A1EF8}.Release|Any CPU.Build.0 = Release|Any CPU
{29A1A83E-51DC-4409-B29B-CDE7D65A1EF8}.ReleaseCore|Any CPU.ActiveCfg = Release|Any CPU
{29A1A83E-51DC-4409-B29B-CDE7D65A1EF8}.ReleaseCore|Any CPU.Build.0 = Release|Any CPU
{29A1A83E-51DC-4409-B29B-CDE7D65A1EF8}.ReleaseFull|Any CPU.ActiveCfg = Release|Any CPU
{29A1A83E-51DC-4409-B29B-CDE7D65A1EF8}.ReleaseFull|Any CPU.Build.0 = Release|Any CPU
{5AB7D09F-5F98-465E-AB9D-07014F1DBC3F}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{5AB7D09F-5F98-465E-AB9D-07014F1DBC3F}.Debug|Any CPU.Build.0 = Debug|Any CPU
{5AB7D09F-5F98-465E-AB9D-07014F1DBC3F}.DebugCore|Any CPU.ActiveCfg = Debug|Any CPU
{5AB7D09F-5F98-465E-AB9D-07014F1DBC3F}.DebugCore|Any CPU.Build.0 = Debug|Any CPU
{5AB7D09F-5F98-465E-AB9D-07014F1DBC3F}.DebugFull|Any CPU.ActiveCfg = Debug|Any CPU
{5AB7D09F-5F98-465E-AB9D-07014F1DBC3F}.DebugFull|Any CPU.Build.0 = Debug|Any CPU
{5AB7D09F-5F98-465E-AB9D-07014F1DBC3F}.Release|Any CPU.ActiveCfg = Release|Any CPU
{5AB7D09F-5F98-465E-AB9D-07014F1DBC3F}.Release|Any CPU.Build.0 = Release|Any CPU
{5AB7D09F-5F98-465E-AB9D-07014F1DBC3F}.ReleaseCore|Any CPU.ActiveCfg = Release|Any CPU
{5AB7D09F-5F98-465E-AB9D-07014F1DBC3F}.ReleaseCore|Any CPU.Build.0 = Release|Any CPU
{5AB7D09F-5F98-465E-AB9D-07014F1DBC3F}.ReleaseFull|Any CPU.ActiveCfg = Release|Any CPU
{5AB7D09F-5F98-465E-AB9D-07014F1DBC3F}.ReleaseFull|Any CPU.Build.0 = Release|Any CPU
{8D4D5502-4321-46A5-975B-C20BD745FC06}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{8D4D5502-4321-46A5-975B-C20BD745FC06}.Debug|Any CPU.Build.0 = Debug|Any CPU
{8D4D5502-4321-46A5-975B-C20BD745FC06}.DebugCore|Any CPU.ActiveCfg = DebugCore|Any CPU
{8D4D5502-4321-46A5-975B-C20BD745FC06}.DebugCore|Any CPU.Build.0 = DebugCore|Any CPU
{8D4D5502-4321-46A5-975B-C20BD745FC06}.DebugFull|Any CPU.ActiveCfg = DebugFull|Any CPU
{8D4D5502-4321-46A5-975B-C20BD745FC06}.DebugFull|Any CPU.Build.0 = DebugFull|Any CPU
{8D4D5502-4321-46A5-975B-C20BD745FC06}.Release|Any CPU.ActiveCfg = Release|Any CPU
{8D4D5502-4321-46A5-975B-C20BD745FC06}.Release|Any CPU.Build.0 = Release|Any CPU
{8D4D5502-4321-46A5-975B-C20BD745FC06}.ReleaseCore|Any CPU.ActiveCfg = ReleaseCore|Any CPU
{8D4D5502-4321-46A5-975B-C20BD745FC06}.ReleaseCore|Any CPU.Build.0 = ReleaseCore|Any CPU
{8D4D5502-4321-46A5-975B-C20BD745FC06}.ReleaseFull|Any CPU.ActiveCfg = ReleaseFull|Any CPU
{8D4D5502-4321-46A5-975B-C20BD745FC06}.ReleaseFull|Any CPU.Build.0 = ReleaseFull|Any CPU
{E2409457-2BA1-47C8-B53B-CE712896FE6E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{E2409457-2BA1-47C8-B53B-CE712896FE6E}.Debug|Any CPU.Build.0 = Debug|Any CPU
{E2409457-2BA1-47C8-B53B-CE712896FE6E}.DebugCore|Any CPU.ActiveCfg = DebugCore|Any CPU
{E2409457-2BA1-47C8-B53B-CE712896FE6E}.DebugCore|Any CPU.Build.0 = DebugCore|Any CPU
{E2409457-2BA1-47C8-B53B-CE712896FE6E}.DebugFull|Any CPU.ActiveCfg = DebugFull|Any CPU
{E2409457-2BA1-47C8-B53B-CE712896FE6E}.DebugFull|Any CPU.Build.0 = DebugFull|Any CPU
{E2409457-2BA1-47C8-B53B-CE712896FE6E}.Release|Any CPU.ActiveCfg = Release|Any CPU
{E2409457-2BA1-47C8-B53B-CE712896FE6E}.Release|Any CPU.Build.0 = Release|Any CPU
{E2409457-2BA1-47C8-B53B-CE712896FE6E}.ReleaseCore|Any CPU.ActiveCfg = ReleaseCore|Any CPU
{E2409457-2BA1-47C8-B53B-CE712896FE6E}.ReleaseCore|Any CPU.Build.0 = ReleaseCore|Any CPU
{E2409457-2BA1-47C8-B53B-CE712896FE6E}.ReleaseFull|Any CPU.ActiveCfg = ReleaseFull|Any CPU
{E2409457-2BA1-47C8-B53B-CE712896FE6E}.ReleaseFull|Any CPU.Build.0 = ReleaseFull|Any CPU
{7A774F1F-31D6-4D7F-90D5-9C4F387D2EEE}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{7A774F1F-31D6-4D7F-90D5-9C4F387D2EEE}.Debug|Any CPU.Build.0 = Debug|Any CPU
{7A774F1F-31D6-4D7F-90D5-9C4F387D2EEE}.DebugCore|Any CPU.ActiveCfg = DebugCore|Any CPU
{7A774F1F-31D6-4D7F-90D5-9C4F387D2EEE}.DebugCore|Any CPU.Build.0 = DebugCore|Any CPU
{7A774F1F-31D6-4D7F-90D5-9C4F387D2EEE}.DebugFull|Any CPU.ActiveCfg = DebugFull|Any CPU
{7A774F1F-31D6-4D7F-90D5-9C4F387D2EEE}.DebugFull|Any CPU.Build.0 = DebugFull|Any CPU
{7A774F1F-31D6-4D7F-90D5-9C4F387D2EEE}.Release|Any CPU.ActiveCfg = Release|Any CPU
{7A774F1F-31D6-4D7F-90D5-9C4F387D2EEE}.Release|Any CPU.Build.0 = Release|Any CPU
{7A774F1F-31D6-4D7F-90D5-9C4F387D2EEE}.ReleaseCore|Any CPU.ActiveCfg = ReleaseCore|Any CPU
{7A774F1F-31D6-4D7F-90D5-9C4F387D2EEE}.ReleaseCore|Any CPU.Build.0 = ReleaseCore|Any CPU
{7A774F1F-31D6-4D7F-90D5-9C4F387D2EEE}.ReleaseFull|Any CPU.ActiveCfg = ReleaseFull|Any CPU
{7A774F1F-31D6-4D7F-90D5-9C4F387D2EEE}.ReleaseFull|Any CPU.Build.0 = ReleaseFull|Any CPU
{FB669026-E1C8-417F-962B-D8235A3DA2D3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{FB669026-E1C8-417F-962B-D8235A3DA2D3}.Debug|Any CPU.Build.0 = Debug|Any CPU
{FB669026-E1C8-417F-962B-D8235A3DA2D3}.DebugCore|Any CPU.ActiveCfg = Debug|Any CPU
{FB669026-E1C8-417F-962B-D8235A3DA2D3}.DebugFull|Any CPU.ActiveCfg = Debug|Any CPU
{FB669026-E1C8-417F-962B-D8235A3DA2D3}.DebugFull|Any CPU.Build.0 = Debug|Any CPU
{FB669026-E1C8-417F-962B-D8235A3DA2D3}.Release|Any CPU.ActiveCfg = Release|Any CPU
{FB669026-E1C8-417F-962B-D8235A3DA2D3}.Release|Any CPU.Build.0 = Release|Any CPU
{FB669026-E1C8-417F-962B-D8235A3DA2D3}.ReleaseCore|Any CPU.ActiveCfg = Release|Any CPU
{FB669026-E1C8-417F-962B-D8235A3DA2D3}.ReleaseFull|Any CPU.ActiveCfg = Release|Any CPU
{FB669026-E1C8-417F-962B-D8235A3DA2D3}.ReleaseFull|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
EndGlobalSection
GlobalSection(NestedProjects) = preSolution
{33D86EA2-2161-4EF0-8F17-59602296273C} = {DC5F5BC4-CDB0-41F7-8B03-CD4C38C8DEB2}
{B517BBF2-60E6-4C69-885A-AE5C014D877B} = {DC5F5BC4-CDB0-41F7-8B03-CD4C38C8DEB2}
{52D174E7-2407-4FC1-9DDA-4D9D14F18618} = {DC5F5BC4-CDB0-41F7-8B03-CD4C38C8DEB2}
{6139CF19-0190-4ED5-AEE3-D3CE7458E517} = {DC5F5BC4-CDB0-41F7-8B03-CD4C38C8DEB2}
{D2A7B5F5-8D33-45AC-9776-07C23F5859BB} = {DC5F5BC4-CDB0-41F7-8B03-CD4C38C8DEB2}
{87D09BD4-119E-49C1-B0B4-86DF962A00EE} = {DC5F5BC4-CDB0-41F7-8B03-CD4C38C8DEB2}
{6FF3E672-378C-4D61-B4CA-A5A5E01C2563} = {DC5F5BC4-CDB0-41F7-8B03-CD4C38C8DEB2}
{3DB795A6-5FE8-447C-89B9-9E608285C6F8} = {2964BB90-4E6D-49ED-AA35-645D94337C76}
{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96} = {3DB795A6-5FE8-447C-89B9-9E608285C6F8}
{25D28099-E338-4543-B1DE-261439654CA6} = {3DB795A6-5FE8-447C-89B9-9E608285C6F8}
{040FA938-BE24-4391-86BA-D04B331A787A} = {3DB795A6-5FE8-447C-89B9-9E608285C6F8}
{07E9E91D-6593-4FF9-A266-270ED5241C98} = {2964BB90-4E6D-49ED-AA35-645D94337C76}
{3207C90B-DB71-4293-8974-0795863C076B} = {2964BB90-4E6D-49ED-AA35-645D94337C76}
{29A1A83E-51DC-4409-B29B-CDE7D65A1EF8} = {2964BB90-4E6D-49ED-AA35-645D94337C76}
{5AB7D09F-5F98-465E-AB9D-07014F1DBC3F} = {2964BB90-4E6D-49ED-AA35-645D94337C76}
{8D4D5502-4321-46A5-975B-C20BD745FC06} = {2964BB90-4E6D-49ED-AA35-645D94337C76}
{E2409457-2BA1-47C8-B53B-CE712896FE6E} = {2964BB90-4E6D-49ED-AA35-645D94337C76}
{7A774F1F-31D6-4D7F-90D5-9C4F387D2EEE} = {2964BB90-4E6D-49ED-AA35-645D94337C76}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {160F773C-9CF5-4F8D-B45A-1112A1BC5E16}
EndGlobalSection
GlobalSection(TestCaseManagementSettings) = postSolution
CategoryFile = Infer2.vsmdi
EndGlobalSection
GlobalSection(TextTemplating) = postSolution
TextTemplating = 1
EndGlobalSection
EndGlobal

Двоичные данные
Infer2.snk Normal file

Двоичный файл не отображается.

10
LICENSE.txt Normal file
Просмотреть файл

@ -0,0 +1,10 @@
The MIT License (MIT)
Copyright (c) 2018 .NET Foundation
All rights reserved
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the Software), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

164
README.md Normal file
Просмотреть файл

@ -0,0 +1,164 @@
**Infer&#46;NET** is a framework for running Bayesian inference in graphical models. It can also be used for probabilistic programming.
One can use Infer&#46;NET to solve many different kinds of machine learning problems - from standard problems like [classification](https://microsoft.github.io/Infer.NET/userguide/docs/Infer.NET%20Learners%20-%20Bayes%20Point%20Machine%20classifiers),
[recommendation](https://microsoft.github.io/Infer.NET/userguide/docs/Infer.NET%20Learners%20%20Matchbox%20recommender) or [clustering](https://microsoft.github.io/Infer.NET/userguide/docs/Mixture%20of%20Gaussians%20tutorial) through to [customised solutions to domain-specific problems](https://microsoft.github.io/Infer.NET/userguide/docs/Click%20through%20model%20sample).
**Infer&#46;NET** has been used in a wide variety of domains including information retrieval, bioinformatics, epidemiology, vision,
and many others.
# Contents
- [Structure of repository](#structure-of-repository)
- [Build and test](#build-and-test)
- [Windows](##windows)
- [Linux](##linux)
# Structure of repository
* The Visual Studio solution `Infer2.sln` in the root of the repository contains all Infer&#46;NET components, unit tests and sample programs from the folders described below.
* `src/`
* `Compiler` contains the Infer&#46;NET Compiler project which takes model descriptions written using the Infer&#46;NET API, and converts them into inference code. The project also contains utility methods for visualization of the generated code.
* `Csoft` is an experimental feature that allows to express probabilistic models in a subset of the C# language. You can find many unit tests of `Csoft` models in the `Tests` project marked with `Category: CsoftModel` trait.
* `Examples` contains C# projects that illustrate how to use Infer&#46;NET to solve a variety of different problems.
* `ClickThroughModel` - a web search example of converting a sequence of clicks by the user into inferences about the relevance of documents.
* `ClinicalTrial` - the clinical trial tutorial example with an interactive user interface.`
* `InferNET101` - samples from Infer&#46;NET 101 introduction to the basics of Microsoft Infer&#46;NET programming.
* `ImageClassifier` - an image search example of classifying tagged images.
* `LDA` - this example provides Infer&#46;NET implementations of the popular LDA model for topic modeling. The implementations pay special attention to scalability with respect to vocabulary size, and with respect to the number of documents. As such, they provide good examples for how to scale Infer&#46;NET models in general.
* `MontyHall` - an Infer&#46;NET implementation of the Monty Hall problem, along with a graphical user interface.
* `FSharpWrapper` is a wrapper project that hides some of the generic constructs in the Infer&#46;NET API allowing simpler calls to the Infer&#46;NET API from standard F#.
* `IronPythonWrapper` contains wrapper for calling Infer&#46;NET from the [IronPython](https://ironpython.net/) programming language and tests for the wrapper. Please refer to [README.md](IronPythonWrapper/README.md) for more information.
* `Learners` folder contains Visual Studio projects for complete machine learning applications including classification and recommendation. You can read more about Learners [here](https://microsoft.github.io/Infer.NET/userguide/docs/Infer.NET%20Learners.md).
* `Runtime` - is a C# project with classes and methods needed to execute the inference code.
* `Tutorials` contains [Examples Browser](https://microsoft.github.io/Infer.NET/userguide/docs/The%20Example%20Browser.md) project with simple examples that provide a step-by-step introduction to Infer.NET.
* `test/`
* `TestApp` contains C# console application for quick invocation and debugging of variouse Infer&#46;NET components.
* `TestFSharp` is an F# console project for smoke testing of Infer&#46;NET F# wrapper.
* `TestPublic` contains scenario tests for tutorial code. These tests are a part of the PR and nightly builds.
* `Tests` - main unit test project containing thousands of tests. These tests are a part of the PR and nightly builds. The folder `Tests\Vibes` contains MATLab scripts that compare Infer&#46;NET to the [VIBES](https://vibes.sourceforge.net/) package. Running them requires `Vibes2_0.jar` (can be obtained on the [VIBES](https://vibes.sourceforge.net/) website) to be present in the same folder.
* `Learners` folder contains the unit tests and the test application for `Learners` (see above).
* `docs` folder contains the scripts for bulding API documentation and for updating https://microsoft.github.io/Infer.NET. Please refer to [README.md](docs/README.md) for more details.
# Build and Test
Infer&#46;NET is cross platform and supports .NET Framework 4.5.2 and Mono 5.0. Unit tests are written using the [XUnit](https://xunit.github.io/) framework.
## Windows
### Prerequisites
**Visual Studio 2017.**
If you don't have Visual Studio 2017, you can install the free [Visual Studio 2015/2017 Community](http://www.visualstudio.com/en-us/products/visual-studio-community-vs.aspx).
### Build and test
You can load `Infer2.sln` solution located in the root of repository into Visual Studio and build all libraries and samples.
**NB!** The solution has a number of build configurations that allows building either for all supported frameworks simultaneously or only for a specific one, but in order for Visual Studio to behave correctly, the solution needs to be closed and re-opened after switching between such configurations.
Unit tests are available in `Test Explorer` window. Normally you should see tests from 3 projects: `Tests`, `PublicTests` and `LearnersTest`. Note, that some of the tests are categorized, and those falling in the `OpenBug` or `BadTest` categories are not supposed to succeed.
## Linux
Almost all components of Infer&#46;NET run on Mono and/or .net core 2.0 except some visualizations in `Compiler` project and sample applications that use WPF.
### Prerequisites
1. **[Mono and MSBuild](https://www.mono-project.com/download/stable/#download-lin)** (version 5.0 and higher)
1. **[.NET Core 2.0 SDK](https://www.microsoft.com/net/download/linux-package-manager/ubuntu18-04/sdk-2.1.202)**
1. **[NuGet](https://docs.microsoft.com/en-us/nuget/install-nuget-client-tools)** package manager
### Build
1. Restore required NuGet packages after cloning the repository. Execute the following command in the root directory of the repository:
```bash
msbuild /p:MonoSupport=true /restore Infer2.sln
```
2. Then build the entire solution (or individual projects) using the following commands:
```bash
msbuild /p:MonoSupport=true Infer2.sln
```
or
```bash
msbuild /p:MonoSupport=true src/Runtime/Runtime.csproj
```
These commands set the `MonoSupport` property to `true`. It excludes code that uses WPF from build.
### Run unit tests
In order to run unit tests, build the test project and execute one of the following commands:
```bash
mono ~/.nuget/packages/xunit.runner.console/2.3.1/tools/net452/xunit.console.exe <path to net452 assembly with tests> <filter>
```
```bash
dotnet ~/.nuget/packages/xunit.runner.console/2.3.1/tools/netcoreapp2.0/xunit.console.dll <path to netcoreapp2.0 assembly with tests> <filter>
```
There are three test assemblies in the solution:
- **Infer.Tests.dll** in the folder `test/Tests`.
- **TestPublic.dll** in the folder `test/TestPublic`.
- **Infer.Learners.Tests.dll** in the folder `test/Learners/LearnersTests`.
Depending on the build configuration, the assemblies will be located in the `bin/Debug` or `bin/Release` subdirectories
of the test project.
`<filter>` is a rule to chose what tests will be run. You can specify them
using `-trait Category=<category>` and `-notrait Category=<category>` parts
of `<filter>`. The former selects tests of
the given category, while the latter selects test that don't belong to the given
category. These can be combined: several `-trait` options mean that _at least one_ of the listed traits has to be present, while several `-notrait` options mean that _none_ of such traits can be present on the filtered tests.
Runner executes tests in parallel by default. However, some test category must be run
sequentially. Such categories are:
- _Performance_
- _DistributedTest_
- _CsoftModel_
- _ModifiesGlobals_
Add the `-parallel none` argument to run them.
_CompilerOptionsTest_ is a category for long running tests, so, for quick
testing you must filter these out by `-notrait`.
_BadTest_ is a category of tests that must fail.
_OpenBug_ is a category of tests that can fail.
An example of quick testing of `Infer.Tests.dll` in `Debug` configuration after changing working directory to
the `Tests` project looks like:
```bash
mono ~/.nuget/packages/xunit.runner.console/2.3.1/tools/net452/xunit.console.exe bin/Debug/net452/Infer.Tests.dll -notrait Category=OpenBug -notrait Category=BadTest -notrait Category=CompilerOptionsTest -notrait Category=CsoftModel -notrait Category=ModifiesGlobals -notrait Category=DistributedTest -notrait Category=Performance
mono ~/.nuget/packages/xunit.runner.console/2.3.1/tools/net452/xunit.console.exe bin/Debug/net452/Infer.Tests.dll -trait Category=CsoftModel -trait Category=ModifiesGlobals -trait Category=DistributedTests -trait Category=Performance -notrait Category=OpenBug -notrait Category=BadTest -notrait Category=CompilerOptionsTest -parallel none
```
To run the same set of tests on .net core:
```bash
dotnet ~/.nuget/packages/xunit.runner.console/2.3.1/tools/netcoreapp2.0/xunit.console.dll bin/Debug/netcoreapp2.0/Infer.Tests.dll -notrait Category=OpenBug -notrait Category=BadTest -notrait Category=CompilerOptionsTest -notrait Category=CsoftModel -notrait Category=ModifiesGlobals -notrait Category=DistributedTest -notrait Category=Performance
dotnet ~/.nuget/packages/xunit.runner.console/2.3.1/tools/netcoreapp2.0/xunit.console.dll bin/Debug/netcoreapp2.0/Infer.Tests.dll -trait Category=CsoftModel -trait Category=ModifiesGlobals -trait Category=DistributedTests -trait Category=Performance -notrait Category=OpenBug -notrait Category=BadTest -notrait Category=CompilerOptionsTest -parallel none
```
Helper scripts `monotest.sh` and `netcoretest.sh` for running unit tests on Mono and .net core respectively are located in the `test` folder.

52
Settings.StyleCop Normal file
Просмотреть файл

@ -0,0 +1,52 @@
<StyleCopSettings Version="105">
<GlobalSettings>
<CollectionProperty Name="RecognizedWords">
<Value>determinization</Value>
<Value>determinize</Value>
<Value>determinizing</Value>
<Value>determinized</Value>
<Value>Kleene</Value>
<Value>memoization</Value>
<Value>non-normalizable</Value>
<Value>non-simplifiable</Value>
<Value>normalizable</Value>
<Value>nucleobase</Value>
<Value>nucleobases</Value>
<Value>quantile</Value>
<Value>recompute</Value>
<Value>recomputing</Value>
<Value>semiring</Value>
<Value>simplifiable</Value>
<Value>Tarjan's</Value>
<Value>trie</Value>
<Value>unnormalized</Value>
</CollectionProperty>
</GlobalSettings>
<Analyzers>
<Analyzer AnalyzerId="StyleCop.CSharp.DocumentationRules">
<Rules>
<Rule Name="FileMustHaveHeader">
<RuleSettings>
<BooleanProperty Name="Enabled">False</BooleanProperty>
</RuleSettings>
</Rule>
<Rule Name="FileHeaderMustContainFileName">
<RuleSettings>
<BooleanProperty Name="Enabled">False</BooleanProperty>
</RuleSettings>
</Rule>
<Rule Name="FileHeaderFileNameDocumentationMustMatchFileName">
<RuleSettings>
<BooleanProperty Name="Enabled">False</BooleanProperty>
</RuleSettings>
</Rule>
<Rule Name="FileHeaderFileNameDocumentationMustMatchTypeName">
<RuleSettings>
<BooleanProperty Name="Enabled">False</BooleanProperty>
</RuleSettings>
</Rule>
</Rules>
<AnalyzerSettings />
</Analyzer>
</Analyzers>
</StyleCopSettings>

Просмотреть файл

@ -0,0 +1,19 @@
<?xml version="1.0" encoding="utf-8"?>
<TestRunConfiguration name="TestRunConfig1" id="8de3ef11-ed64-4bd5-bbd8-d35dd5a63e94" xmlns="http://microsoft.com/schemas/VisualStudio/TeamTest/2010">
<Description>This is a default test run configuration for a local test run.</Description>
<CodeCoverage keyFile="Infer2.snk" />
<Timeouts testTimeout="300000" />
<TestTypeSpecific>
<WebTestRunConfiguration testTypeId="4e7599fa-5ecb-43e9-a887-cd63cf72d207">
<Browser name="Internet Explorer 6.0">
<Headers>
<Header name="User-Agent" value="Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1)" />
<Header name="Accept" value="*/*" />
<Header name="Accept-Language" value="{{$IEAcceptLanguage}}" />
<Header name="Accept-Encoding" value="GZIP" />
</Headers>
</Browser>
<Network Name="LAN" BandwidthInKbps="0" />
</WebTestRunConfiguration>
</TestTypeSpecific>
</TestRunConfiguration>

Просмотреть файл

@ -0,0 +1,27 @@
<?xml version="1.0"?>
<package>
<metadata>
<id>Microsoft.ML.Probabilistic.Compiler</id>
<version>$version$</version>
<authors>Microsoft</authors>
<owners>Microsoft</owners>
<licenseUrl>https://github.com/dotnet/infer/blob/master/LICENSE.txt</licenseUrl>
<projectUrl>https://dotnet.github.io/infer</projectUrl>
<iconUrl>https://raw.githubusercontent.com/dotnet/infer/master/docs/images/infernet.png</iconUrl>
<requireLicenseAcceptance>false</requireLicenseAcceptance>
<description>Infer.NET is a framework for running Bayesian inference in graphical models. It can also be used for probabilistic programming. This package contains the Infer.NET Compiler, which takes model descriptions written using the Infer.NET API and converts them into inference code.</description>
<tags>Infer.NET machine learning Bayesian inference probabilistic</tags>
<dependencies>
<dependency id="Microsoft.ML.Probabilistic" version="$version$" />
<dependency id="Microsoft.CodeAnalysis.CSharp" version="2.0.0" />
<dependency id="System.Reflection.Emit" version="4.3.0" />
<dependency id="System.Reflection.Emit.Lightweight" version="4.3.0" />
<dependency id="System.CodeDom" version="4.4.0" />
</dependencies>
</metadata>
<files>
<file src="$bin$\netstandard2.0\Microsoft.ML.Probabilistic.Compiler.dll" target="lib\netstandard2.0" />
<file src="$bin$\netstandard2.0\Microsoft.ML.Probabilistic.Compiler.xml" target="lib\netstandard2.0" />
<file src="$bin$\netstandard2.0\Microsoft.ML.Probabilistic.Compiler.pdb" target="lib\netstandard2.0" />
</files>
</package>

Просмотреть файл

@ -0,0 +1,29 @@
<?xml version="1.0"?>
<package>
<metadata>
<id>Microsoft.ML.Probabilistic.Learners</id>
<version>$version$</version>
<authors>Microsoft</authors>
<owners>Microsoft</owners>
<licenseUrl>https://github.com/dotnet/infer/blob/master/LICENSE.txt</licenseUrl>
<projectUrl>https://dotnet.github.io/infer</projectUrl>
<iconUrl>https://raw.githubusercontent.com/dotnet/infer/master/docs/images/infernet.png</iconUrl>
<requireLicenseAcceptance>false</requireLicenseAcceptance>
<description>Infer.NET is a framework for running Bayesian inference in graphical models. It can also be used for probabilistic programming. This package contains complete machine learning applications including a classifier and a recommender system.</description>
<tags>Infer.NET machine learning Bayesian inference probabilistic</tags>
<dependencies>
<dependency id="Microsoft.ML.Probabilistic" version="$version$" />
</dependencies>
</metadata>
<files>
<file src="$bin$\netstandard2.0\Microsoft.ML.Probabilistic.Learners.dll" target="lib\netstandard2.0" />
<file src="$bin$\netstandard2.0\Microsoft.ML.Probabilistic.Learners.xml" target="lib\netstandard2.0" />
<file src="$bin$\netstandard2.0\Microsoft.ML.Probabilistic.Learners.pdb" target="lib\netstandard2.0" />
<file src="$bin$\netstandard2.0\Microsoft.ML.Probabilistic.Learners.Classifier.dll" target="lib\netstandard2.0" />
<file src="$bin$\netstandard2.0\Microsoft.ML.Probabilistic.Learners.Classifier.xml" target="lib\netstandard2.0" />
<file src="$bin$\netstandard2.0\Microsoft.ML.Probabilistic.Learners.Classifier.pdb" target="lib\netstandard2.0" />
<file src="$bin$\netstandard2.0\Microsoft.ML.Probabilistic.Learners.Recommender.dll" target="lib\netstandard2.0" />
<file src="$bin$\netstandard2.0\Microsoft.ML.Probabilistic.Learners.Recommender.xml" target="lib\netstandard2.0" />
<file src="$bin$\netstandard2.0\Microsoft.ML.Probabilistic.Learners.Recommender.pdb" target="lib\netstandard2.0" />
</files>
</package>

Просмотреть файл

@ -0,0 +1,27 @@
<?xml version="1.0"?>
<package>
<metadata>
<id>Microsoft.ML.Probabilistic.Visualizers.Windows</id>
<version>$version$</version>
<authors>Microsoft</authors>
<owners>Microsoft</owners>
<licenseUrl>https://github.com/dotnet/infer/blob/master/LICENSE.txt</licenseUrl>
<projectUrl>https://dotnet.github.io/infer</projectUrl>
<iconUrl>https://raw.githubusercontent.com/dotnet/infer/master/docs/images/infernet.png</iconUrl>
<requireLicenseAcceptance>false</requireLicenseAcceptance>
<description>Infer.NET is a framework for running Bayesian inference in graphical models. It can also be used for probabilistic programming. This package contains visualization tools for exploring and analyzing models on Windows platform.</description>
<tags>Infer.NET machine learning Bayesian inference probabilistic</tags>
<dependencies>
<dependency id="Microsoft.ML.Probabilistic" version="$version$" />
<dependency id="Microsoft.ML.Probabilistic.Compiler" version="$version$" />
<dependency id="Microsoft.Msagl" version="1.1.1" />
<dependency id="Microsoft.Msagl.Drawing" version="1.1.1" />
<dependency id="Microsoft.Msagl.GraphViewerGdi" version="1.1.1" />
</dependencies>
</metadata>
<files>
<file src="$bin$\net461\Microsoft.ML.Probabilistic.Compiler.Visualizers.Windows.dll" target="lib\net461" />
<file src="$bin$\net461\Microsoft.ML.Probabilistic.Compiler.Visualizers.Windows.xml" target="lib\net461" />
<file src="$bin$\net461\Microsoft.ML.Probabilistic.Compiler.Visualizers.Windows.pdb" target="lib\net461" />
</files>
</package>

Просмотреть файл

@ -0,0 +1,20 @@
<?xml version="1.0"?>
<package>
<metadata>
<id>Microsoft.ML.Probabilistic</id>
<version>$version$</version>
<authors>Microsoft</authors>
<owners>Microsoft</owners>
<licenseUrl>https://github.com/dotnet/infer/blob/master/LICENSE.txt</licenseUrl>
<projectUrl>https://dotnet.github.io/infer</projectUrl>
<iconUrl>https://raw.githubusercontent.com/dotnet/infer/master/docs/images/infernet.png</iconUrl>
<requireLicenseAcceptance>false</requireLicenseAcceptance>
<description>Infer.NET is a framework for running Bayesian inference in graphical models. It can also be used for probabilistic programming. This package contains classes and methods needed to execute the inference code.</description>
<tags>Infer.NET machine learning Bayesian inference probabilistic</tags>
</metadata>
<files>
<file src="$bin$\netstandard2.0\Microsoft.ML.Probabilistic.dll" target="lib\netstandard2.0" />
<file src="$bin$\netstandard2.0\Microsoft.ML.Probabilistic.xml" target="lib\netstandard2.0" />
<file src="$bin$\netstandard2.0\Microsoft.ML.Probabilistic.pdb" target="lib\netstandard2.0" />
</files>
</package>

48
build/copyassemblies.sh Normal file
Просмотреть файл

@ -0,0 +1,48 @@
# Licensed to the .NET Foundation under one or more agreements.
# The .NET Foundation licenses this file to you under the MIT license.
# See the LICENSE file in the project root for more information.
if [ $# -lt 2 ]; then
echo Usage: $0 "<output folder>" "<build configuration>"
exit 1
fi
# First argument is target folder.
out=$1
# Second argument is build configuration
configuration=$2
rm -r "${out}"
mkdir "${out}"
mkdir "${out}/netstandard2.0"
dst="${out}/netstandard2.0"
src="bin/${configuration}/netstandard2.0"
cp ../src/Runtime/${src}/Microsoft.ML.Probabilistic.dll ${dst}
cp ../src/Runtime/${src}/Microsoft.ML.Probabilistic.pdb ${dst}
cp ../src/Runtime/${src}/Microsoft.ML.Probabilistic.xml ${dst}
cp ../src/Compiler/${src}/Microsoft.ML.Probabilistic.Compiler.dll ${dst}
cp ../src/Compiler/${src}/Microsoft.ML.Probabilistic.Compiler.pdb ${dst}
cp ../src/Compiler/${src}/Microsoft.ML.Probabilistic.Compiler.xml ${dst}
cp ../src/Learners/Core/${src}/Microsoft.ML.Probabilistic.Learners.dll ${dst}
cp ../src/Learners/Core/${src}/Microsoft.ML.Probabilistic.Learners.pdb ${dst}
cp ../src/Learners/Core/${src}/Microsoft.ML.Probabilistic.Learners.xml ${dst}
cp ../src/Learners/Classifier/${src}/Microsoft.ML.Probabilistic.Learners.Classifier.dll ${dst}
cp ../src/Learners/Classifier/${src}/Microsoft.ML.Probabilistic.Learners.Classifier.pdb ${dst}
cp ../src/Learners/Classifier/${src}/Microsoft.ML.Probabilistic.Learners.Classifier.xml ${dst}
cp ../src/Learners/Recommender/${src}/Microsoft.ML.Probabilistic.Learners.Recommender.dll ${dst}
cp ../src/Learners/Recommender/${src}/Microsoft.ML.Probabilistic.Learners.Recommender.pdb ${dst}
cp ../src/Learners/Recommender/${src}/Microsoft.ML.Probabilistic.Learners.Recommender.xml ${dst}
mkdir "${out}/net461"
dst="${out}/net461"
src="bin/${configuration}/net461"
cp ../src/Visualizers/Windows/${src}/Microsoft.ML.Probabilistic.Compiler.Visualizers.Windows.dll ${dst}
cp ../src/Visualizers/Windows/${src}/Microsoft.ML.Probabilistic.Compiler.Visualizers.Windows.pdb ${dst}
cp ../src/Visualizers/Windows/${src}/Microsoft.ML.Probabilistic.Compiler.Visualizers.Windows.xml ${dst}

33
build/dotnet-core-pr.yml Normal file
Просмотреть файл

@ -0,0 +1,33 @@
# Build and test Infer.NET using .NET Core
resources:
- repo: self
clean: true
variables:
buildConfiguration: 'Release'
steps:
- task: DotNetCoreInstaller@0
inputs:
packageType: 'sdk'
version: '2.1.202'
- script: |
dotnet build --configuration $(buildConfiguration)Core Infer2.sln
displayName: Build Solution
- task: Bash@3
inputs:
filePath: test/netcoretest.sh
workingDirectory: test
arguments: |
Tests/bin/$(buildConfiguration)Core/netcoreapp2.0/Microsoft.ML.Probabilistic.Tests.dll
Learners/LearnersTests/bin/$(buildConfiguration)Core/netcoreapp2.0/Microsoft.ML.Probabilistic.Learners.Tests.dll
TestPublic/bin/$(buildConfiguration)Core/netcoreapp2.0/TestPublic.dll
displayName: Run Tests
continueOnError: true
- task: PublishTestResults@2
inputs:
testRunner: XUnit
testResultsFiles: 'test/*core-tests.xml'

13
build/updateversion.sh Normal file
Просмотреть файл

@ -0,0 +1,13 @@
# Licensed to the .NET Foundation under one or more agreements.
# The .NET Foundation licenses this file to you under the MIT license.
# See the LICENSE file in the project root for more information.
if [ $# -lt 1 ]; then
echo Usage: $0 "<version>"
exit 1
fi
for f in SharedAssemblyFileVersion.cs SharedAssemblyFileVersion.fs
do
sed -i "s/\(Assembly\(File\)\?Version(\"\)[0-9]\+.[0-9]\+.[0-9]\+.[0-9]\+/\1$1/" ../src/Shared/$f
done

46
build/windows-nightly.yml Normal file
Просмотреть файл

@ -0,0 +1,46 @@
# Nightly build for Windows. Produces NuGet packages
name: 0.3.$(Date:yyMM).$(Date:dd)$(Rev:rr)
resources:
- repo: self
clean: true
trigger: none # disable CI build
steps:
- script: echo $(Build.BuildNumber)
- task: Bash@3
inputs:
filePath: build/updateversion.sh
workingDirectory: build
arguments: $(Build.BuildNumber)
- template: windows.yml
parameters:
buildConfiguration: 'Release'
testSuite: 'nightly'
testPlatform: 'all'
testRunner: 'vstest'
- task: Bash@3
inputs:
filePath: build/copyassemblies.sh
arguments: ../bin Release
workingDirectory: build
- task: NuGetCommand@2
inputs:
command: pack
packagesToPack: build/*.nuspec
includeSymbols: true
buildProperties: version=$(Build.BuildNumber);bin=../bin
- task: CopyFiles@2
inputs:
sourceFolder: bin
targetFolder: $(Build.ArtifactStagingDirectory)
- task: PublishBuildArtifacts@1
inputs:
artifactName: 'Everything'

14
build/windows-pr.yml Normal file
Просмотреть файл

@ -0,0 +1,14 @@
# Build and test Infer.NET using .NET Core
resources:
- repo: self
clean: true
trigger: none # disable CI build
steps:
- template: windows.yml
parameters:
buildConfiguration: 'Release'
testSuite: 'fast'
testPlatform: 'x64'
testRunner: 'all'

72
build/windows.yml Normal file
Просмотреть файл

@ -0,0 +1,72 @@
# Template for Windows environment (.NET 4.6.1)
parameters:
buildConfiguration: 'Release'
testRunner: 'vstest' # or 'all'; TODO: add support for dotnet
testPlatform: 'x64' # or 'x86' or 'all'
testSuite: 'fast' # or 'all'; TODO: add Nightly
steps:
- task: DotNetCoreInstaller@0
inputs:
packageType: 'sdk'
version: '2.1.202'
- task: NuGetToolInstaller@0
inputs:
versionSpec: '4.7.0'
- task: NuGetCommand@2
inputs:
command: 'restore'
restoreSolution: '**/*.sln'
- task: MSBuild@1
inputs:
solution: '**/*.sln'
clean: true
configuration: ${{ parameters.buildConfiguration }}
- ${{ if or(eq(parameters.testRunner, 'vstest'), eq(parameters.testRunner, 'all')) }}:
- ${{ if or(eq(parameters.testSuite, 'fast'), eq(parameters.testSuite, 'all')) }}:
- ${{ if or(eq(parameters.testPlatform, 'x64'), eq(parameters.testPlatform, 'all')) }}:
# Fast test suite on vstest x64
- task: VSTest@2
displayName: Unit tests x64 (sequential)
inputs:
testSelector: 'testAssemblies'
testAssemblyVer2: test/Tests/bin/*/net461/Microsoft.ML.Probabilistic.Tests.dll
testFiltercriteria: '(Platform!=x86)&(((Category=CsoftModel)|(Category=ModifiesGlobals)|(Category=DistributedTests)|(Category=Performance))&(Category!=OpenBug)&(Category!=BadTest)&(Category!=CompilerOptionsTest))'
runInParallel: false
runSettingsFile: test.runsettings
- task: VSTest@2
displayName: Unit tests x64
inputs:
testSelector: 'testAssemblies'
testAssemblyVer2: |
test/Tests/bin/*/net461/Microsoft.ML.Probabilistic.Tests.dll
test/Learners/LearnersTests/bin/*/net461/Microsoft.ML.Probabilistic.Learners.Tests.dll
test/TestPublic/bin/*/net461/TestPublic.dll
testFiltercriteria: '(Platform!=x86)&(Category!=OpenBug)&(Category!=BadTest)&(Category!=CompilerOptionsTest)&(Category!=CsoftModel)&(Category!=ModifiesGlobals)&(Category!=DistributedTest)&(Category!=Performance)'
runInParallel: true
runSettingsFile: test.runsettings
- ${{ if or(eq(parameters.testPlatform, 'x32'), eq(parameters.testPlatform, 'all')) }}:
# Fast test suite on vstest x86
- task: VSTest@2
displayName: Unit tests x86 (sequential)
inputs:
testSelector: 'testAssemblies'
testAssemblyVer2: test/Tests/bin/*/net461/Microsoft.ML.Probabilistic.Tests.dll
testFiltercriteria: '(Platform!=x64)&(((Category=CsoftModel)|(Category=ModifiesGlobals)|(Category=DistributedTests)|(Category=Performance))&(Category!=OpenBug)&(Category!=BadTest)&(Category!=CompilerOptionsTest))'
runInParallel: false
runSettingsFile: x86.runsettings
- task: VSTest@2
displayName: Unit tests x86
inputs:
testSelector: 'testAssemblies'
testAssemblyVer2: |
test/Tests/bin/*/net461/Microsoft.ML.Probabilistic.Tests.dll
test/Learners/LearnersTests/bin/*/net461/Microsoft.ML.Probabilistic.Learners.Tests.dll
test/TestPublic/bin/*/net461/TestPublic.dll
testFiltercriteria: '(Platform!=x64)&(Category!=OpenBug)&(Category!=BadTest)&(Category!=CompilerOptionsTest)&(Category!=CsoftModel)&(Category!=ModifiesGlobals)&(Category!=DistributedTest)&(Category!=Performance)'
runInParallel: true
runSettingsFile: x86.runsettings

Просмотреть файл

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="utf-8" ?>
<configuration>
<startup>
<supportedRuntime version="v4.0" sku=".NETFramework,Version=v4.5" />
</startup>
</configuration>

Просмотреть файл

@ -0,0 +1,67 @@
<?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="12.0" DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<Import Project="$(MSBuildExtensionsPath)\$(MSBuildToolsVersion)\Microsoft.Common.props" Condition="Exists('$(MSBuildExtensionsPath)\$(MSBuildToolsVersion)\Microsoft.Common.props')" />
<PropertyGroup>
<Configuration Condition=" '$(Configuration)' == '' ">Debug</Configuration>
<Platform Condition=" '$(Platform)' == '' ">AnyCPU</Platform>
<ProjectGuid>{35309643-E029-441E-B237-499212668A5D}</ProjectGuid>
<OutputType>Exe</OutputType>
<AppDesignerFolder>Properties</AppDesignerFolder>
<RootNamespace>Microsoft.ML.Probabilistic.Tools.PrepareSource</RootNamespace>
<AssemblyName>PrepareSource</AssemblyName>
<TargetFrameworkVersion>v4.5</TargetFrameworkVersion>
<FileAlignment>512</FileAlignment>
<SccProjectName>SAK</SccProjectName>
<SccLocalPath>SAK</SccLocalPath>
<SccAuxPath>SAK</SccAuxPath>
<SccProvider>SAK</SccProvider>
</PropertyGroup>
<PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Debug|AnyCPU' ">
<PlatformTarget>AnyCPU</PlatformTarget>
<DebugSymbols>true</DebugSymbols>
<DebugType>full</DebugType>
<Optimize>false</Optimize>
<OutputPath>bin\Debug\</OutputPath>
<DefineConstants>DEBUG;TRACE</DefineConstants>
<ErrorReport>prompt</ErrorReport>
<WarningLevel>4</WarningLevel>
</PropertyGroup>
<PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Release|AnyCPU' ">
<PlatformTarget>AnyCPU</PlatformTarget>
<DebugType>pdbonly</DebugType>
<Optimize>true</Optimize>
<OutputPath>bin\Release\</OutputPath>
<DefineConstants>TRACE</DefineConstants>
<ErrorReport>prompt</ErrorReport>
<WarningLevel>4</WarningLevel>
</PropertyGroup>
<PropertyGroup>
<SharedVersionOutputDirectory>..\..\src\Shared</SharedVersionOutputDirectory>
</PropertyGroup>
<ItemGroup>
<Reference Include="System" />
<Reference Include="System.Core" />
<Reference Include="System.Xml.Linq" />
<Reference Include="System.Data.DataSetExtensions" />
<Reference Include="Microsoft.CSharp" />
<Reference Include="System.Data" />
<Reference Include="System.Xml" />
</ItemGroup>
<ItemGroup>
<Compile Include="..\..\src\Shared\SharedAssemblyFileVersion.cs" />
<Compile Include="..\..\src\Shared\SharedAssemblyInfo.cs" />
<Compile Include="Program.cs" />
<Compile Include="Properties\AssemblyInfo.cs" />
</ItemGroup>
<ItemGroup>
<None Include="App.config" />
</ItemGroup>
<Import Project="$(MSBuildToolsPath)\Microsoft.CSharp.targets" />
<!-- To modify your build process, add your task inside one of the targets below and uncomment it.
Other similar extension points exist, see Microsoft.Common.targets.
<Target Name="BeforeBuild">
</Target>
<Target Name="AfterBuild">
</Target>
-->
</Project>

Просмотреть файл

@ -0,0 +1,25 @@

Microsoft Visual Studio Solution File, Format Version 12.00
# Visual Studio 15
VisualStudioVersion = 15.0.27703.2035
MinimumVisualStudioVersion = 10.0.40219.1
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "PrepareSource", "PrepareSource.csproj", "{35309643-E029-441E-B237-499212668A5D}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Release|Any CPU = Release|Any CPU
EndGlobalSection
GlobalSection(ProjectConfigurationPlatforms) = postSolution
{35309643-E029-441E-B237-499212668A5D}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{35309643-E029-441E-B237-499212668A5D}.Debug|Any CPU.Build.0 = Debug|Any CPU
{35309643-E029-441E-B237-499212668A5D}.Release|Any CPU.ActiveCfg = Release|Any CPU
{35309643-E029-441E-B237-499212668A5D}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {4CF38F85-5998-43B6-ACE4-A747D6EDC1A8}
EndGlobalSection
EndGlobal

Просмотреть файл

@ -0,0 +1,121 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Probabilistic.Tools.PrepareSource
{
using System;
using System.Collections;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Xml.Linq;
using System.Xml.XPath;
internal static class Program
{
private static void Main(string[] args)
{
try
{
MainWrapped(args);
}
catch (Exception e)
{
Error(e.Message);
}
}
private static void MainWrapped(string[] args)
{
if (args.Length != 1)
{
Error("Usage: {0} <src_folder>", Environment.GetCommandLineArgs()[0]);
}
string sourceFolder = args[0];
string destinationFolder = args[0]; // This tool now works in place
if (!Directory.Exists(sourceFolder))
{
Error("Unknown directory: {0}", sourceFolder);
}
var loadedDocFiles = new Dictionary<string, XDocument>();
foreach (string sourceFileName in Directory.EnumerateFiles(sourceFolder, "*.cs", SearchOption.AllDirectories))
{
string temporaryFile = Path.GetRandomFileName();
string destinationFileName = Path.Combine(destinationFolder, temporaryFile);
ProcessFile(sourceFileName, destinationFileName, loadedDocFiles);
File.Delete(sourceFileName);
File.Move(destinationFileName, sourceFileName);
}
}
private static void Error(string format, params object[] args)
{
Console.WriteLine(format, args);
Environment.Exit(1);
}
private static void ProcessFile(string sourceFileName, string destinationFileName, Dictionary<string, XDocument> loadedDocFiles)
{
using (var reader = new StreamReader(sourceFileName))
using (var writer = new StreamWriter(destinationFileName))
{
string line;
int lineNumber = 0;
while ((line = reader.ReadLine()) != null)
{
++lineNumber;
//// For simplicity this code assumes that the line with <include> has no other content on it.
//// This is currently the case for our codebase.
int indexOfDocStart = line.IndexOf("/// <include", StringComparison.InvariantCulture);
if (indexOfDocStart == -1)
{
// Not a line with an include directive
writer.WriteLine(line);
continue;
}
string includeString = line.Substring(indexOfDocStart + "/// ".Length);
var includeDoc = XDocument.Parse(includeString);
XAttribute fileAttribute = includeDoc.Root.Attribute("file");
XAttribute pathAttribute = includeDoc.Root.Attribute("path");
if (fileAttribute == null || pathAttribute == null)
{
Error("An ill-formed include directive at {0}:{1}", sourceFileName, lineNumber);
}
string fullDocFileName = Path.GetFullPath(Path.Combine(Path.GetDirectoryName(sourceFileName), fileAttribute.Value));
XDocument docFile;
if (!loadedDocFiles.TryGetValue(fullDocFileName, out docFile))
{
docFile = XDocument.Load(fullDocFileName);
loadedDocFiles.Add(fullDocFileName, docFile);
}
XElement[] docElements = ((IEnumerable)docFile.XPathEvaluate(pathAttribute.Value)).Cast<XElement>().ToArray();
if (docElements.Length == 0)
{
Console.WriteLine("WARNING: nothing to include for the include directive at {0}:{1}", sourceFileName, lineNumber);
}
else
{
foreach (XElement docElement in docElements)
{
string[] docElementStringLines = docElement.ToString().Split(new[] { Environment.NewLine }, StringSplitOptions.None);
string indentation = new string(' ', indexOfDocStart);
foreach (string docElementStringLine in docElementStringLines)
{
writer.WriteLine("{0}/// {1}", indentation, docElementStringLine);
}
}
}
}
}
}
}
}

Просмотреть файл

@ -0,0 +1,21 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
// General Information about an assembly is controlled through the following
// set of attributes. Change these attribute values to modify the information
// associated with an assembly.
[assembly: AssemblyTitle("PrepareSource")]
[assembly: AssemblyDescription("")]
// Setting ComVisible to false makes the types in this assembly not visible
// to COM components. If you need to access a type in this assembly from
// COM, set the ComVisible attribute to true on that type.
[assembly: ComVisible(false)]
// The following GUID is for the ID of the typelib if this project is exposed to COM
[assembly: Guid("714d6270-33e8-433b-a0cb-f607a3ed95ca")]

17
docs/README.md Normal file
Просмотреть файл

@ -0,0 +1,17 @@
## Build API documentation
The API documentation is generated by using DocFX.
The last version of the API documentation is in the `apiguide/` folder. If you want to create a new version of the API documentation:
1. Commit your changes before start the build process.
2. Run the powershell script `docs/_build/makeApiDocs.ps1`. `makeApiDocs.ps1` script should create the folder `apiguide-tmp/` with the API documentation.
`makeApiDocs.ps1` script should:
* copy a part of repository to the temporary folder `InferNet_Copy_Temp`,
* build `PrepareSource.csproj` project,
* run `PrepareSource.exe` for the `InferNet_Copy_Temp` folder,
* install NuGet package `docfx.console`,
* run `docfx`,
* remove temporary folder `InferNet_Copy_Temp`,
* switch to gh-pages,
* commit and push the new changes.

99
docs/_build/makeApiDocs.ps1 поставляемый Normal file
Просмотреть файл

@ -0,0 +1,99 @@
<#
.SYNOPSIS
Makes API documentation for current version of Infer.NET.
.DESCRIPTION
Builds PrepareSource.csproj, creates API documentation using docfx for Infer2 to docs/apiguide/ folder.
#>
# Licensed to the .NET Foundation under one or more agreements.
# The .NET Foundation licenses this file to you under the MIT license.
# See the LICENSE file in the project root for more information.
$scriptDir = Split-Path -Path $MyInvocation.MyCommand.Definition -Parent
$sourceDirectory = [IO.Path]::GetFullPath((join-path $scriptDir '../../'))
$destinationDirectory = [IO.Path]::GetFullPath((join-path $scriptDir '../../InferNet_Copy_Temp/'))
$excludes = @("InferNet_Copy_Temp", "docs", "packages", "_site", "build", "apiguide-tmp", ".git")
Write-Host "Copy subfolders to InferNet_Copy_Temp directory"
Get-ChildItem $sourceDirectory -Directory |
Where-Object{$_.Name -notin $excludes} |
Copy-Item -Destination $destinationDirectory -Recurse -Force
Write-Host "Copy root files to InferNet_Copy_Temp directory"
Get-ChildItem -Path $sourceDirectory -Include "*.*" | Copy-Item -Destination $destinationDirectory -Force
Write-Host "Build PrepareSource project"
if ([Environment]::Is64BitOperatingSystem) {
$pfiles = ${env:PROGRAMFILES(X86)}
} else {
$pfiles = $env:PROGRAMFILES
}
$msBuildExe = Resolve-Path -Path "${pfiles}\Microsoft Visual Studio\*\*\MSBuild\15.0\bin\msbuild.exe" -ErrorAction SilentlyContinue
if (!($msBuildExe)) {
$msBuildExe = Resolve-Path -Path "~/../../usr/bin/msbuild" -ErrorAction SilentlyContinue
$useMono = "mono "
if (!($msBuildExe)) {
Write-Error -Message ('ERROR: Falied to locate MSBuild at' + $msBuildExe)
exit 1
}
}
if ($msbuildExe.GetType() -Eq [object[]]) {
$msbuildExe = $msbuildExe | Select -index 0
}
$projPath = [IO.Path]::GetFullPath((join-path $scriptDir '../PrepareSource/PrepareSource.csproj'))
if (!(Test-Path $projPath)) {
Write-Error -Message ('ERROR: Failed to locate PrepareSource project file at ' + $projPath)
exit 1
}
$BuildArgs = @{
FilePath = $msBuildExe
ArgumentList = $projPath, "/t:rebuild", "/p:Configuration=Release", "/v:minimal"
}
Start-Process @BuildArgs -NoNewWindow -Wait
Write-Host "Run PrepareSource for InferNet_Copy_Temp folder"
$prepareSourcePath = [IO.Path]::GetFullPath((join-path $scriptDir '../PrepareSource/bin/Release/PrepareSource.exe'))
$prepareSourceCmd = "& $useMono ""$prepareSourcePath"" ""$destinationDirectory"""
Invoke-Expression $prepareSourceCmd
Write-Host "Install nuget package docfx.console"
Install-Package -Name docfx.console -provider Nuget -Source https://nuget.org/api/v2 -RequiredVersion 2.38.0 -Destination $scriptDir\..\..\packages -Force
Write-Host "Run docfx"
$docFXPath = [IO.Path]::GetFullPath((join-path $scriptDir '../../packages/docfx.console.2.38.0/tools/docfx.exe'))
$docFxJsonPath = "$scriptDir/../docfx.json"
$docFxCmd = "& $useMono ""$docFXPath"" ""$docFxJsonPath"""
Invoke-Expression $docFxCmd
if ((Test-Path $destinationDirectory)) {
Write-Host "Remove temp repository"
Remove-Item -Path $destinationDirectory -Recurse -Force
}
$apiguideTmp = "./apiguide-tmp"
if (!(Test-Path $apiguideTmp)) {
Write-Host "Couldn't find the folder \apiguide-tmp."
exit 1
} else {
Write-Host "Switch to gh-pages. All uncommited changes will be stashed."
Try {
git stash
git checkout gh-pages
$apiguidePath = "./apiguide"
git pull origin gh-pages
if ((Test-Path $apiguidePath)) {
Remove-Item $apiguidePath -Force -Recurse
} else {
Write-Host "apiguide folder is not found."
}
Rename-Item -path ./apiguide-tmp -newName $apiguidePath
git add --all
git commit -m "Update API Documentation"
# git push origin gh-pages
}
Catch {
Write-Host $Error
}
}

61
docs/docfx.json Normal file
Просмотреть файл

@ -0,0 +1,61 @@
{
"metadata": [
{
"src": [
{
"files": [ "src/Compiler/Compiler.csproj", "src/Runtime/Runtime.csproj", "src/Learners/Recommender/Recommender.csproj", "src/Learners/Classifier/Classifier.csproj", "src/Learners/Core/Core.csproj" ],
"exclude": [ "**/bin/**", "**/obj/**" ],
"src": "../InferNet_Copy_Temp/"
}
],
"dest": "obj/api",
"properties": {
"TargetFramework": "netstandard2.0"
}
}
],
"build": {
"content": [
{
"files": [ "**.yml" ],
"src": "obj/api",
"dest": "api"
},
{
"files": [ "index.md" ],
"src": ".",
"dest": "api"
},
{
"files": [ "toc.yml" ]
}
],
"resource": [
{
"files": [ "images/**" ],
"exclude": [ "_site/**", "**/obj/**", "**.meta" ]
}
],
"xrefService": [ "https://xref.docs.microsoft.com/query?uid={uid}" ],
"postProcessors": [ "ExtractSearchIndex" ],
"globalMetadata": {
"_appTitle": "Infer.NET API Guide",
"_appFooter": "<span>Copyright © .NET Foundation. All rights reserved.</span>",
"_appLogoPath": "images/infernet.png",
"_disableContribution": true,
"_appFaviconPath": "favicon.ico"
},
"fileMetadata": {
"priority": {
"**.md": 2.5,
"api/**.md": 3
}
},
"markdownEngineName": "markdig",
"dest": "../apiguide-tmp",
"template": [
"default",
"template"
]
}
}

Двоичные данные
docs/favicon.ico Normal file

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 9.4 KiB

Двоичные данные
docs/images/infernet.png Normal file

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 5.3 KiB

22
docs/index.md Normal file
Просмотреть файл

@ -0,0 +1,22 @@
# Infer.NET code documentation
Infer.NET is a framework for doing inference on graphical models. The user defines a graphical model by specifying the factors and variables of the model. Infer.NET analyses how everything fits together and creates a schedule for doing inference on the model. The model can then be queried for marginal distributions.
## Namespaces
| Namespace | Description |
|---|---|
| [Microsoft.ML.Probabilistic](xref:Microsoft.ML.Probabilistic) | Infer.NET root namespace |
| [Microsoft.ML.Probabilistic.Collections](xref:Microsoft.ML.Probabilistic.Collections) | Infer.NET collections |
| [Microsoft.ML.Probabilistic.Distributions](xref:Microsoft.ML.Probabilistic.Distributions) | Infer.NET distributions |
| [Microsoft.ML.Probabilistic.Distributions.Automata](xref:Microsoft.ML.Probabilistic.Distributions.Automata) | Infer.NET automata |
| [Microsoft.ML.Probabilistic.Distributions.Kernels](xref:Microsoft.ML.Probabilistic.Distributions.Kernels) | Infer.NET Gaussian Process kernels |
| [Microsoft.ML.Probabilistic.Factors](xref:Microsoft.ML.Probabilistic.Factors) | Infer.NET factors and message operator methods |
| [Microsoft.ML.Probabilistic.Learners](xref:Microsoft.ML.Probabilistic.Learners) | Infer.NET learners |
| [Microsoft.ML.Probabilistic.Learners.Mappings](xref:Microsoft.ML.Probabilistic.Learners.Mappings) | Infer.NET learner input data mappings |
| [Microsoft.ML.Probabilistic.Math](xref:Microsoft.ML.Probabilistic.Math) | Infer.NET maths |
| [Microsoft.ML.Probabilistic.Models](xref:Microsoft.ML.Probabilistic.Models) | Infer.NET model description classes |
| [Microsoft.ML.Probabilistic.Compiler.Transforms](xref:Microsoft.ML.Probabilistic.Compiler.Transforms) | Infer.NET compiler transforms |
| [Microsoft.ML.Probabilistic.Utilities](xref:Microsoft.ML.Probabilistic.Utilities) | Infer.NET utilities |
| [Microsoft.ML.Probabilistic.Compiler.Transforms](xref:Microsoft.ML.Probabilistic.Compiler.Transforms) | Infer.NET compiler transform framework |
| [Microsoft.ML.Probabilistic.Compiler.CodeModel](xref:Microsoft.ML.Probabilistic.Compiler.CodeModel) | Infer.NET Compiler code model interfaces |

79
docs/template/styles/main.css поставляемый Normal file
Просмотреть файл

@ -0,0 +1,79 @@
/* Licensed to the .NET Foundation under one or more agreements.
The .NET Foundation licenses this file to you under the MIT license.
See the LICENSE file in the project root for more information. */
body {
font-size: 15px;
}
.container {
width: 100%;
}
.sidetoc {
width: 280px;
}
.sidefilter {
width: 280px;
}
.article.grid-right {
margin-left: 290px;
}
.toc .level1 > li {
font-weight: normal;
}
@media only screen and (max-width: 768px) {
.article.grid-right {
margin-left: 0;
}
.sidetoc {
width: 100%;
}
.sidefilter {
width: 100%;
}
.article {
margin-top: 30px !important;
}
.article {
margin-top: 120px;
margin-bottom: 115px;
}
}
a.navbar-brand {
pointer-events: none;
cursor: default;
text-decoration: none;
}
@media (min-width: 992px) {
.sidetoc {
width: 280px;
}
.sidefilter {
width: 280px;
}
.article.grid-right {
margin-left: 290px;
}
}
@media (min-width: 1200px) {
.sidetoc {
width: 350px;
}
.sidefilter {
width: 350px;
}
.article.grid-right {
margin-left: 360px;
}
}
@media (min-width: 1200px) {
.sidetoc {
width: 450px;
}
.sidefilter {
width: 450px;
}
.article.grid-right {
margin-left: 460px;
}
}

6
docs/toc.yml Normal file
Просмотреть файл

@ -0,0 +1,6 @@
- name: Home Page
href: ../index.html
- name: User Guide
href: ../userguide/index.html
- name: API Documentation
href: api/index.html

Просмотреть файл

@ -0,0 +1,68 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
<GenerateAssemblyInfo>false</GenerateAssemblyInfo>
<SignAssembly>true</SignAssembly>
<AssemblyOriginatorKeyFile>..\..\Infer2.snk</AssemblyOriginatorKeyFile>
<AssemblyName>$(AssemblyNamePrefix)Microsoft.ML.Probabilistic.Compiler</AssemblyName>
<WarningLevel>4</WarningLevel>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
<WarningsAsErrors />
<DefineConstants>TRACE;SUPPRESS_XMLDOC_WARNINGS, SUPPRESS_UNREACHABLE_CODE_WARNINGS, SUPPRESS_AMBIGUOUS_REFERENCE_WARNINGS</DefineConstants>
<NoWarn>1591</NoWarn>
<RootNamespace>Microsoft.ML.Probabilistic</RootNamespace>
<RoslynSupport>true</RoslynSupport>
<CodeDomSupport>true</CodeDomSupport>
<Configurations>Debug;Release</Configurations>
</PropertyGroup>
<PropertyGroup Condition=" '$(RoslynSupport)' == 'true'">
<DefineConstants>$(DefineConstants);ROSLYN</DefineConstants>
</PropertyGroup>
<PropertyGroup Condition=" '$(CodeDomSupport)' == 'true'">
<DefineConstants>$(DefineConstants);CODEDOM</DefineConstants>
</PropertyGroup>
<PropertyGroup>
<DefineConstants>$(DefineConstants);NETCORE;NETSTANDARD;NETSTANDARD2_0</DefineConstants>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
<DebugType>full</DebugType>
<DebugSymbols>true</DebugSymbols>
<DefineConstants>$(DefineConstants);DEBUG</DefineConstants>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'">
<DebugType>pdbonly</DebugType>
<Optimize>true</Optimize>
</PropertyGroup>
<PropertyGroup>
<DocumentationFile>bin\$(Configuration)\$(TargetFramework)\$(AssemblyNamePrefix)Microsoft.ML.Probabilistic.Compiler.xml</DocumentationFile>
</PropertyGroup>
<ItemGroup>
<ProjectReference Include="..\Runtime\Runtime.csproj" />
</ItemGroup>
<ItemGroup>
<PackageReference Condition="$(DefineConstants.Contains('ROSLYN'))" Include="Microsoft.CodeAnalysis.CSharp" version="2.0.0" />
<PackageReference Include="System.Reflection.Emit" Version="4.3.0" />
<PackageReference Include="System.Reflection.Emit.Lightweight" Version="4.3.0" />
<PackageReference Condition="$(DefineConstants.Contains('CODEDOM'))" Include="System.CodeDom" Version="4.4.0" />
</ItemGroup>
<ItemGroup>
<None Remove="Infer\Infer.ico" />
<EmbeddedResource Include="Infer\Infer.ico">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</EmbeddedResource>
</ItemGroup>
<ItemGroup>
<Compile Include="..\Shared\SharedAssemblyFileVersion.cs" />
<Compile Include="..\Shared\SharedAssemblyInfo.cs" />
</ItemGroup>
</Project>

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -0,0 +1,690 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Globalization;
using System.Diagnostics;
using System.Reflection;
using System.Reflection.Emit;
using Microsoft.ML.Probabilistic.Utilities;
namespace Microsoft.ML.Probabilistic.Compiler.Reflection
{
#if SUPPRESS_XMLDOC_WARNINGS
#pragma warning disable 1591
#endif
public delegate object Converter(object toConvert);
public struct Conversion
{
public Converter Converter;
/// <summary>
/// The number of subclass edges between the two types.
/// </summary>
/// <remarks>This is only valid if converter is null, i.e. no conversion is needed.
/// If the two types are the same, then SubclassCount == 0.
/// If one is a direct subclass of the other, SubclassCount == 1.
/// If one is a subclass of a subclass of the other, SubclassCount == 2, and so on.
/// </remarks>
public int SubclassCount;
/// <summary>
/// True if the conversion is explicit.
/// </summary>
/// <remarks>Must be false if converter is null.
/// An implicit conversion must always succeed and does not lose information.
/// Otherwise, it is explicit.
/// </remarks>
public bool IsExplicit;
public override string ToString()
{
return (IsExplicit ? "Explicit" : "Implicit") + " " + SubclassCount + ((Converter == null) ? "" : (" " + Converter.Method.Name));
}
/// <summary>
/// True if A is a more specific conversion than B.
/// </summary>
/// <param name="a"></param>
/// <param name="b"></param>
/// <returns>True if A is a more specific conversion than B.</returns>
/// <remarks>The following criteria are applied in order:
/// 1. A null conversion versus a non-null conversion.
/// 2. Among null conversions, the one crossing fewer subclass links.
/// 3. An implicit conversion versus an explicit conversion.
/// </remarks>
public static bool operator <(Conversion a, Conversion b)
{
//if (a == null || b == null) return false;
if (a.Converter == null)
{
if (b.Converter == null) return a.SubclassCount < b.SubclassCount;
else return true;
}
if (b.Converter == null) return false;
return !a.IsExplicit && b.IsExplicit;
}
public static bool operator >(Conversion a, Conversion b)
{
return (b < a);
}
/// <summary>
/// Returns a numerical weight such that (a.GetWeight() &lt; b.GetWeight()) iff (a &lt; b)
/// </summary>
/// <returns></returns>
public float GetWeight()
{
if (Converter == null)
{
return SubclassCount;
}
else
{
int maxSubclassCount = 100000;
return maxSubclassCount*(IsExplicit ? 2 : 1);
}
}
public static float GetWeight(IEnumerable<Conversion> array)
{
float weight = 0.0F;
int maxLength = 1000;
int index = maxLength;
foreach (Conversion c in array) weight += c.GetWeight()*(index--);
return weight;
}
public static Converter GetPrimitiveConverter(Type fromType, Type toType)
{
Debug.Assert(toType.IsPrimitive);
TypeCode typeCode = Type.GetTypeCode(toType);
string name = typeCode.ToString();
Type[] actuals = new Type[] {fromType};
MethodInfo method = typeof (Convert).GetMethod("To" + name, actuals);
// CreateDelegate doesn't work since method is not of type Converter:
//(Converter)Delegate.CreateDelegate(typeof(Converter), method);
return delegate (object value)
{
return Util.Invoke(method, null, value);
};
}
/// <summary>
/// Get a Conversion structure to a primitive type.
/// </summary>
/// <param name="fromType">Any type.</param>
/// <param name="toType">A primitive type.</param>
/// <param name="info"></param>
/// <returns>false if no conversion exists.</returns>
public static bool TryGetPrimitiveConversion(Type fromType, Type toType, out Conversion info)
{
Debug.Assert(toType.IsPrimitive);
info = new Conversion();
// not needed: info.isExplicit = false;
TypeCode fromTypeCode = Type.GetTypeCode(fromType);
TypeCode toTypeCode = Type.GetTypeCode(toType);
if (fromTypeCode == toTypeCode)
{
// fromType has the same TypeCode but is not assignable to toType.
info.SubclassCount = 1000;
return true;
}
info.Converter = GetPrimitiveConverter(fromType, toType);
// from now on, explicit is the default
info.IsExplicit = true;
// string and DateTime are not actually primitive types, but we leave them in anyway.
if (toTypeCode == TypeCode.String)
{
// anything converts to string
return true;
}
// string and object convert to anything, but not implicitly
bool ok = (fromTypeCode == TypeCode.String || fromTypeCode == TypeCode.Object);
if (fromTypeCode == TypeCode.DateTime || (toTypeCode == TypeCode.DateTime && !ok))
{
// DateTime can only be converted to itself or string
return false;
}
// The primitive conversions are listed here:
// ms-help://MS.VSCC.v80/MS.MSDN.v80/MS.NETDEVFX.v20.en/cpref2/html/T_System_Convert.htm
// The implicit conversions are listed here:
// ms-help://MS.VSCC.v80/MS.MSDN.v80/MS.NETDEVFX.v20.en/cpref10/html/M_System_Reflection_Binder_ChangeType_1_f31c470b.htm
// Conversion from signed to unsigned cannot be implicit, since it may fail.
switch (fromTypeCode)
{
case TypeCode.Char:
switch (toTypeCode)
{
case TypeCode.SByte:
case TypeCode.Byte:
case TypeCode.Int16:
case TypeCode.Int32:
case TypeCode.Int64:
case TypeCode.UInt16:
case TypeCode.UInt32:
case TypeCode.UInt64:
info.IsExplicit = false;
break;
case TypeCode.Boolean:
case TypeCode.Single:
case TypeCode.Double:
case TypeCode.Decimal:
return false;
}
break;
case TypeCode.Byte: // unsigned
switch (toTypeCode)
{
case TypeCode.Char:
case TypeCode.Int16:
case TypeCode.Int32:
case TypeCode.Int64:
case TypeCode.UInt16:
case TypeCode.UInt32:
case TypeCode.UInt64:
case TypeCode.Single:
case TypeCode.Double:
info.IsExplicit = false;
break;
}
break;
case TypeCode.SByte: // signed
switch (toTypeCode)
{
case TypeCode.Int16:
case TypeCode.Int32:
case TypeCode.Int64:
case TypeCode.Single:
case TypeCode.Double:
info.IsExplicit = false;
break;
}
break;
case TypeCode.UInt16:
switch (toTypeCode)
{
case TypeCode.UInt32:
case TypeCode.Int32:
case TypeCode.UInt64:
case TypeCode.Int64:
case TypeCode.Single:
case TypeCode.Double:
info.IsExplicit = false;
break;
}
break;
case TypeCode.Int16:
switch (toTypeCode)
{
case TypeCode.Int32:
case TypeCode.Int64:
case TypeCode.Single:
case TypeCode.Double:
info.IsExplicit = false;
break;
}
break;
case TypeCode.UInt32:
switch (toTypeCode)
{
case TypeCode.UInt64:
case TypeCode.Int64:
case TypeCode.Single:
case TypeCode.Double:
info.IsExplicit = false;
break;
}
break;
case TypeCode.Int32:
switch (toTypeCode)
{
case TypeCode.Int64:
case TypeCode.Single:
case TypeCode.Double:
info.IsExplicit = false;
break;
}
break;
case TypeCode.UInt64:
switch (toTypeCode)
{
case TypeCode.Single:
case TypeCode.Double:
info.IsExplicit = false;
break;
}
break;
case TypeCode.Int64:
switch (toTypeCode)
{
case TypeCode.Single:
case TypeCode.Double:
info.IsExplicit = false;
break;
}
break;
case TypeCode.Single:
switch (toTypeCode)
{
case TypeCode.Double:
info.IsExplicit = false;
break;
}
break;
}
if (info.IsExplicit && fromType.IsPrimitive)
{
// wrap the converter with a compatibility check
Converter conv = info.Converter;
Converter back = GetPrimitiveConverter(toType, fromType);
info.Converter = delegate(object fromValue)
{
object toValue = conv(fromValue);
object backValue = back(toValue);
if (!backValue.Equals(fromValue))
throw new ArgumentException("The value " + fromValue + " does not convert to " + toValue.GetType().Name);
return toValue;
};
}
return true;
}
/// <summary>
/// Change array rank and convert elements.
/// </summary>
/// <param name="fromArray"></param>
/// <param name="toRank">Can be smaller, larger, or equal to fromArray.Rank.</param>
/// <param name="toElementType"></param>
/// <param name="conv"></param>
/// <returns>A new array of rank toRank with the same contents as fromArray.</returns>
public static Array ChangeRank(Array fromArray, int toRank, Type toElementType, Converter conv)
{
int fromRank = fromArray.Rank;
int[] lengths = new int[toRank];
int minRank = System.Math.Min(fromRank, toRank);
// if fromRank == 3 and toRank == 1 then lengths = fromArray[1:2] (exclude fromArray[0])
for (int i = 0; i < minRank; i++) lengths[i] = fromArray.GetLength(fromRank - minRank + i);
for (int i = minRank; i < toRank; i++) lengths[i] = 1;
Array toArray = Array.CreateInstance(toElementType, lengths);
if (toArray.Length != fromArray.Length)
{
throw new ArgumentException("The input array has true rank greater than " + toRank.ToString(CultureInfo.InvariantCulture));
}
int[] fromIndex = new int[fromRank];
int[] toIndex = new int[toRank];
if (fromRank == 1)
{
for (int i = 0; i < lengths[0]; i++)
{
object item = fromArray.GetValue(i);
object value;
if (conv != null) value = conv(item);
else value = item;
toIndex[toRank - 1] = i;
toArray.SetValue(value, toIndex);
}
}
else if (toRank == 1)
{
for (int i = 0; i < lengths[0]; i++)
{
// here we assume index[>0]=0
fromIndex[fromRank - 1] = i;
object item = fromArray.GetValue(fromIndex);
object value;
if (conv != null) value = conv(item);
else value = item;
toArray.SetValue(value, i);
}
}
else
{
throw new NotImplementedException();
}
return toArray;
}
public static bool IsNullable(Type type)
{
return !type.IsValueType || (type.IsGenericType && type.GetGenericTypeDefinition().Equals(typeof (Nullable<>)));
}
// must be kept in sync with Binding.TypesAssignableFrom
public static bool IsAssignableFrom(Type toType, Type fromType, out int subclassCount)
{
bool isObject = false;
subclassCount = 0;
for (Type baseType = fromType; baseType != null; baseType = baseType.BaseType)
{
if (baseType.Equals(typeof (object)))
{
isObject = true;
break;
}
if (baseType.Equals(toType)) return true;
subclassCount++;
}
Type[] faces = fromType.GetInterfaces();
foreach (Type face in faces)
{
if (face.Equals(toType)) return true;
subclassCount++;
}
// array covariance (C# 2.0 specification, sec 20.5.9)
if (fromType.IsArray && fromType.GetArrayRank() == 1 && toType.IsGenericType && toType.GetGenericTypeDefinition().Equals(typeof (IList<>)))
{
Type fromElementType = fromType.GetElementType();
Type toElementType = toType.GetGenericArguments()[0];
int elementSubclassCount;
bool ok = IsAssignableFrom(toElementType, fromElementType, out elementSubclassCount);
subclassCount += elementSubclassCount;
return ok;
}
if (isObject && toType.Equals(typeof (object))) return true;
return false;
}
/// <summary>
/// Get a type converter.
/// </summary>
/// <param name="fromType">non-null. May contain type parameters. Use typeof(Nullable) to convert from a null value.</param>
/// <param name="toType">non-null. May contain type parameters. May be typeof(void), for which no conversion is needed.</param>
/// <param name="info"></param>
/// <returns>null if no converter was found.</returns>
public static bool TryGetConversion(Type fromType, Type toType, out Conversion info)
{
info = new Conversion();
if (fromType == typeof (Nullable)) return IsNullable(toType);
int subclassCount;
if (IsAssignableFrom(toType, fromType, out subclassCount))
{
//(toType.IsAssignableFrom(fromType)) {
// toType is a superclass or an interface of fromType
info.SubclassCount = subclassCount;
return true;
}
if (toType == typeof (void))
{
return true;
}
if (fromType.Equals(typeof (object)))
{
info.IsExplicit = true;
info.Converter = delegate(object value) { return ChangeType(value, toType); };
return true;
}
// string -> enum conversion
if (toType.IsEnum && fromType.Equals(typeof (string)))
{
info.IsExplicit = true;
info.Converter = delegate(object fromString) { return Enum.Parse(toType, (string) fromString); };
return true;
}
if (typeof (Delegate).IsAssignableFrom(toType))
{
// DelegateGroup or ComCallback -> Delegate conversion
if (typeof(CanGetDelegate).IsAssignableFrom(fromType))
{
info.IsExplicit = true;
info.Converter = delegate(object dg)
{
Delegate result = ((CanGetDelegate) dg).GetDelegate(toType);
if (result == null) throw new ArgumentException(String.Format("The {0} has no match for the signature of {1}", fromType.ToString(), toType.ToString()));
return result;
};
return true;
}
}
// Matlab array up-conversion
if (toType.IsArray)
{
int toRank = toType.GetArrayRank();
Type toElementType = toType.GetElementType();
if (fromType.IsArray)
{
int fromRank = fromType.GetArrayRank();
Type fromElementType = fromType.GetElementType();
Conversion elementConversion;
if (!TryGetConversion(fromElementType, toElementType, out elementConversion))
return false;
return TryGetArrayConversion(fromRank, toRank, toElementType, elementConversion, out info);
}
else if (fromType.Equals(typeof (System.Reflection.Missing)))
{
// convert to zero-length array
int[] lengths = new int[toRank];
for (int i = 0; i < toRank; i++) lengths[i] = 0;
info.SubclassCount = 1;
info.Converter = delegate(object missing) { return Array.CreateInstance(toElementType, lengths); };
return true;
}
else
{
// convert a scalar to an array of given rank
Conversion elementConversion;
if (!TryGetConversion(fromType, toElementType, out elementConversion))
return false;
return TryGetArrayConversion(0, toRank, toElementType, elementConversion, out info);
}
}
// check for custom conversions
MemberInfo[] implicits = fromType.FindMembers(MemberTypes.Method, BindingFlags.Public | BindingFlags.Static | BindingFlags.InvokeMethod, Type.FilterName,
"op_Implicit");
foreach (MemberInfo member in implicits)
{
MethodInfo method = (MethodInfo) member;
if (method.ReturnType == toType)
{
info.SubclassCount = 1000;
info.Converter = delegate (object value)
{
return Util.Invoke(method, null, value);
};
return true;
}
}
MemberInfo[] explicits = fromType.FindMembers(MemberTypes.Method, BindingFlags.Public | BindingFlags.Static | BindingFlags.InvokeMethod, Type.FilterName,
"op_Explicit");
foreach (MemberInfo member in explicits)
{
MethodInfo method = (MethodInfo) member;
if (method.ReturnType == toType)
{
info.IsExplicit = true;
info.Converter = delegate (object value)
{
return Util.Invoke(method, null, value);
};
return true;
}
}
// lastly try the IConvertible interface
if (toType.IsPrimitive)
{
return TryGetPrimitiveConversion(fromType, toType, out info);
}
return false;
}
public static object ChangeType(object value, Type toType)
{
Conversion info;
Type type = value.GetType();
// prevent an infinite loop
if (type.Equals(typeof (object)))
{
throw new ArgumentException("Cannot convert from " + type.Name + " to " + toType.Name);
}
if (!TryGetConversion(type, toType, out info))
{
throw new ArgumentException("Cannot convert from " + type.Name + " to " + toType.Name);
}
Converter c = info.Converter;
if (c != null) value = c(value);
return value;
}
public static bool TryGetArrayConversion(int fromRank, int toRank, Type toElementType, Conversion elementConversion, out Conversion info)
{
Converter conv = elementConversion.Converter;
info = elementConversion; // assumes info is a struct
if (fromRank == 0)
{
// convert a scalar to an array of given rank
int[] lengths = new int[toRank];
for (int i = 0; i < toRank; i++) lengths[i] = 1;
int[] index = new int[toRank];
info.SubclassCount = 1000;
info.Converter = delegate(object item)
{
object value = item;
if (conv != null) value = conv(item);
Array a = Array.CreateInstance(toElementType, lengths);
a.SetValue(value, index);
return a;
};
return true;
}
else if (toRank == fromRank)
{
if (conv == null) return true;
info.Converter = delegate(object fromArray) { return ChangeRank((Array) fromArray, toRank, toElementType, conv); };
return true;
}
else if (toRank == 1 || fromRank == 1)
{
info.SubclassCount = 1000;
info.Converter = delegate(object fromArray) { return ChangeRank((Array) fromArray, toRank, toElementType, conv); };
return true;
}
return false;
}
private class Pair
{
public object first, second;
public Pair(object first, object second)
{
this.first = first;
this.second = second;
}
}
/// <summary>
/// Convert a weakly-typed delegate into a strongly-typed delegate.
/// </summary>
/// <param name="delegateType">The desired delegate type.</param>
/// <param name="inner">A delegate with parameters (object[] args).
/// The
/// return type can be any type convertible to the return type of delegateType, or void if
/// the delegateType is void.</param>
/// <returns>A delegate of type delegateType. The arguments of this delegate will be
/// passed as (object[]) args to the innerMethod.</returns>
public static Delegate ConvertDelegate(Type delegateType, Delegate inner)
{
// This code is based on:
// http://blogs.msdn.com/joelpob/archive/2005/07/01/434728.aspx
object target = inner;
string methodName = "DynamicMethod"; //delegateType.ToString();
MethodInfo signature = delegateType.GetMethod("Invoke");
Type returnType = signature.ReturnType;
Type innerReturnType = inner.Method.ReturnType;
Conversion conv;
if (!Conversion.TryGetConversion(innerReturnType, returnType, out conv))
{
throw new ArgumentException("Return type of the innerMethod (" + innerReturnType.Name + ") cannot be converted to the delegate return type (" + returnType.Name +
")");
}
Converter c = conv.Converter;
if (c != null)
{
target = new Pair(inner, c);
}
Type[] formals = Invoker.GetParameterTypes(signature);
Type[] formalsWithTarget = formals;
if (target != null)
{
formalsWithTarget = new Type[1 + formals.Length];
formalsWithTarget[0] = target.GetType();
formals.CopyTo(formalsWithTarget, 1);
}
DynamicMethod method = new DynamicMethod(methodName, returnType, formalsWithTarget, typeof (Conversion));
ILGenerator il = method.GetILGenerator();
// put the delegate parameters into an object[]
LocalBuilder args = il.DeclareLocal(typeof (object[]));
il.Emit(OpCodes.Ldc_I4, formals.Length);
il.Emit(OpCodes.Newarr, typeof (object));
il.Emit(OpCodes.Stloc, args);
int offset = (target == null) ? 0 : 1;
for (int i = 0; i < formals.Length; i++)
{
// args[i] = (arg i+1)
il.Emit(OpCodes.Ldloc, args);
il.Emit(OpCodes.Ldc_I4, i);
il.Emit(OpCodes.Ldarg, i + offset);
// box if necessary
if (formals[i].IsValueType)
{
il.Emit(OpCodes.Box, formals[i]);
}
il.Emit(OpCodes.Stelem_Ref);
}
// push the result converter on the stack
if (c != null)
{
il.Emit(OpCodes.Ldarg_0);
il.Emit(OpCodes.Ldfld, typeof (Pair).GetField("second"));
}
// call the inner delegate
il.Emit(OpCodes.Ldarg_0);
if (c != null)
{
il.Emit(OpCodes.Ldfld, typeof (Pair).GetField("first"));
}
il.Emit(OpCodes.Ldloc, args);
il.Emit(OpCodes.Call, inner.GetType().GetMethod("Invoke"));
// handle the return value
if (innerReturnType != typeof (void))
{
if (returnType == typeof (void))
{
il.Emit(OpCodes.Pop);
}
else
{
// converter object is already on the stack
// calling c.Method directly does not work (access exception)
//il.Emit(OpCodes.Call, c.Method);
il.Emit(OpCodes.Call, c.GetType().GetMethod("Invoke"));
// Converter always returns object, so unbox if necessary
if (returnType.IsValueType)
{
il.Emit(OpCodes.Unbox_Any, returnType);
}
}
}
il.Emit(OpCodes.Ret);
return method.CreateDelegate(delegateType, target);
}
public static void EmitTryInvoke(ILGenerator il, Delegate d)
{
//il.BeginExceptionBlock();
//il.BeginCatchBlock(typeof(TargetInvocationException));
//il.Emit(OpCodes.Call, typeof(TargetInvocationException).GetMethod("get_InnerException"));
//il.Emit(OpCodes.Throw);
//il.EndExceptionBlock();
}
}
#if SUPPRESS_XMLDOC_WARNINGS
#pragma warning restore 1591
#endif
}

Просмотреть файл

@ -0,0 +1,96 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Text;
using System.Reflection;
namespace Microsoft.ML.Probabilistic.Compiler.Reflection
{
#if SUPPRESS_XMLDOC_WARNINGS
#pragma warning disable 1591
#endif
public interface CanGetDelegate
{
Delegate GetDelegate(Type type);
}
public class DelegateGroup : CanGetDelegate
{
public MethodBase[] methods;
public object target;
public DelegateGroup()
{
}
public DelegateGroup(Type type, string methodName, BindingFlags flags, object target)
{
//methods = type.GetMethods(flags);
//methods = Array.FindAll<MethodInfo>(methods, delegate(MethodInfo method) { return method.Name == methodName; });
MemberInfo[] members = type.FindMembers(MemberTypes.Method, flags, Type.FilterName, methodName);
methods = Array.ConvertAll<MemberInfo, MethodBase>(members, delegate(MemberInfo info) { return (MethodBase) info; });
this.target = target;
}
public object DynamicInvoke(params object[] args)
{
return Invoker.Invoke(methods, target, args);
}
public Delegate GetDelegate(Type type)
{
// find a method which is compatible with the delegate type
// compatibility is defined at:
// ms-help://MS.VSCC.v80/MS.MSDN.v80/MS.NETDEVFX.v20.en/cpref2/html/M_System_Delegate_CreateDelegate_2_1ee8f399.htm
foreach (MethodBase method in methods)
{
Delegate result = Delegate.CreateDelegate(type, target, (MethodInfo) method, false);
if (result != null) return result;
}
return null;
}
public DelegateGroup MakeGenericMethod(params Type[] types)
{
List<MethodBase> newmethods = new List<MethodBase>();
foreach (MethodBase method in methods)
{
MethodInfo info = method as MethodInfo;
if (info != null && info.IsGenericMethodDefinition)
{
try
{
MethodInfo rmethod = info.MakeGenericMethod(types);
newmethods.Add(rmethod);
}
catch (ArgumentException)
{
}
}
}
DelegateGroup result = new DelegateGroup();
result.methods = newmethods.ToArray();
result.target = target;
return result;
}
public override string ToString()
{
StringBuilder s = new StringBuilder();
for (int i = 0; i < methods.Length; i++)
{
if (i > 0) s.AppendLine();
s.Append(methods[i].ToString());
}
return s.ToString();
}
}
#if SUPPRESS_XMLDOC_WARNINGS
#pragma warning restore 1591
#endif
}

Просмотреть файл

@ -0,0 +1,689 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Text;
using System.Reflection;
using Microsoft.ML.Probabilistic.Utilities;
namespace Microsoft.ML.Probabilistic.Compiler.Reflection
{
#if SUPPRESS_XMLDOC_WARNINGS
#pragma warning disable 1591
#endif
// another name might be GenericActivator
/// <summary>
/// Static methods to dynamically invoke methods and access fields of an object.
/// </summary>
public static class Invoker
{
/// <summary>
/// Get a type by name.
/// </summary>
/// <param name="typeName">The name of a type in the System library or any loaded assembly.</param>
/// <returns>The Type object.</returns>
public static Type GetLoadedType(string typeName)
{
Type type = Type.GetType(typeName);
if (type == null)
{
// search through loaded assemblies
AppDomain app = AppDomain.CurrentDomain;
Assembly[] assemblies = app.GetAssemblies();
foreach (Assembly assembly in assemblies)
{
type = assembly.GetType(typeName);
if (type != null) break;
}
if (type == null) throw new TypeLoadException("The type '" + typeName + "' does not exist (perhaps you need a qualifier?)");
}
return type;
}
/// <summary>
/// Invoke the static member which best matches the argument types.
/// </summary>
/// <param name="type"></param>
/// <param name="methodName"></param>
/// <param name="args"></param>
/// <returns></returns>
public static object InvokeStatic(Type type, string methodName, params object[] args)
{
if (methodName == "new")
{
if (args.Length == 0) return Activator.CreateInstance(type);
else
{
// when there are arguments, we may need to perform conversions.
ConstructorInfo[] ctors = type.GetConstructors();
if (ctors.Length == 0) throw new MissingMethodException(type + " has no constructors");
return Invoke(ctors, null, args);
}
}
else
{
return InvokeMember(type, methodName,
BindingFlags.Public | BindingFlags.Static | BindingFlags.GetField | BindingFlags.GetProperty | BindingFlags.InvokeMethod |
BindingFlags.FlattenHierarchy, null, args);
}
}
/// <summary>
/// Get an element from an array or collection, or invoke a delegate.
/// </summary>
/// <param name="target">An array, collection, delegate, or method group.</param>
/// <param name="args">Indices for the collection or arguments for the delegate. May be null.</param>
/// <returns>The collection element or return value of the delegate. If args == null, the target itself.</returns>
public static object GetValue(object target, params object[] args)
{
if (args == null) return target;
Delegate d = target as Delegate;
if (d != null)
{
return InvokeMember(d.GetType(), "Invoke", BindingFlags.Public | BindingFlags.Instance | BindingFlags.InvokeMethod, d, args);
}
DelegateGroup dg = target as DelegateGroup;
if (dg != null)
{
return dg.DynamicInvoke(args);
}
if (args.Length == 0) return target;
Array array = target as Array;
if (array != null)
{
int[] index = Array.ConvertAll<object, int>(args, delegate(object o) { return Convert.ToInt32(o); });
return array.GetValue(index);
}
if (target == null) return target;
//throw new ArgumentNullException("The field/property value is null");
// collection
return InvokeMember(target.GetType(), "Item", BindingFlags.Public | BindingFlags.Instance | BindingFlags.GetProperty, target, args);
}
/// <summary>
/// Set an element in an array or collection.
/// </summary>
/// <param name="target">An array or collection.</param>
/// <param name="args">Indices followed by the value to set. Length > 0.</param>
public static void SetValue(object target, params object[] args)
{
Array array = target as Array;
if (array != null)
{
object value = args[args.Length - 1];
Conversion conv;
Type elementType = array.GetType().GetElementType();
if (!Conversion.TryGetConversion(value.GetType(), elementType, out conv))
throw new ArgumentException("The value (" + args[0] + ") does not match the element type (" + elementType.Name + ")");
if (conv.Converter != null) value = conv.Converter(value);
int[] index = new int[args.Length - 1];
for (int i = 0; i < index.Length; i++) index[i] = Convert.ToInt32(args[i]);
array.SetValue(value, index);
return;
}
// collection
InvokeMember(target.GetType(), "Item", BindingFlags.Public | BindingFlags.Instance | BindingFlags.SetProperty, target, args);
}
/// <summary>
/// Invoke the member which best matches the argument types.
/// </summary>
/// <param name="type">The type defining the member.</param>
/// <param name="memberName">The name of a field, property, event, instance method, or static method.</param>
/// <param name="flags"></param>
/// <param name="target">The object whose member to invoke. Ignored for a static field or property.
/// If target is non-null, then it is provided as the first argument to a static method.</param>
/// <param name="args">Can be empty or null. Empty means a function call with no arguments.
/// null means get the member itself.</param>
/// <returns>The result of the invocation. For SetField/SetProperty the result is null.</returns>
/// <exception cref="MissingMemberException"></exception>
/// <exception cref="ArgumentException"></exception>
/// <remarks><para>
/// This routine is patterned after Type.InvokeMember.
/// flags must specify Instance, Static, or both.
/// </para><para>
/// If flags contains CreateInstance, then name is ignored and a constructor is invoked.
/// </para><para>
/// If memberName names a field/property and flags contains SetField/SetProperty,
/// then the field/property's value is changed to args[args.Length-1].
/// If args.Length > 1, the field/property is indexed by args[0:(args.Length-2)].
/// </para><para>
/// If memberName names a field/property and flags contains GetField/GetProperty,
/// then the field/property's value is returned.
/// If args != null and the field/property is a delegate, then it is invoked with args.
/// Otherwise if args != null, the field/property is indexed by args.
/// </para><para>
/// If memberName names an event and flags contains GetField,
/// then the event's EventInfo is returned.
/// If args != null, then the event is raised with args.
/// </para><para>
/// If memberName names a method and flags contains InvokeMethod,
/// then it is invoked with args. A static method is invoked with target and args.
/// If args == null, then the result is a DelegateGroup containing all overloads of the method.
/// </para><para>
/// Other flag values are implemented as in Type.InvokeMember.
/// In each case, overloading is resolved by matching the argument types, possibly with conversions.
/// </para><para>
/// If a matching member is not found, the interfaces of the type are also searched.
/// As a last resort, if the memberName is op_Equality or op_Inequality, then a default implementation
/// is provided (as in C#).
/// </para></remarks>
public static object InvokeMember(Type type, string memberName, BindingFlags flags, object target, params object[] args)
{
if ((flags & BindingFlags.GetField) == BindingFlags.GetField)
{
FieldInfo field = type.GetField(memberName, flags);
if (field != null)
{
object value = field.GetValue(target);
value = GetValue(value, args);
return value;
}
EventInfo evnt = type.GetEvent(memberName, flags);
if (evnt != null)
{
if (args == null) return evnt;
//MethodInfo raiser = evnt.GetRaiseMethod();
//return raiser.Invoke(target, args);
//return type.InvokeMember(memberName, BindingFlags.Public | BindingFlags.Instance | BindingFlags.GetField | BindingFlags.GetProperty | BindingFlags.InvokeMethod, null, target, args);
// http://forums.microsoft.com/MSDN/ShowPost.aspx?PostID=130529&SiteID=1
return new NotSupportedException("Raising events is not supported by Reflection");
}
}
if ((flags & BindingFlags.SetField) == BindingFlags.SetField)
{
FieldInfo field = type.GetField(memberName, flags);
if (field != null)
{
if (args == null || args.Length == 0) throw new ArgumentException("No value was provided to set");
if (args.Length == 1)
{
object value = args[0];
Conversion conv;
if (!Conversion.TryGetConversion(value == null ? null : value.GetType(), field.FieldType, out conv))
throw new ArgumentException("The value (" + args[0] + ") does not match the field type (" + field.FieldType.Name + ")");
if (conv.Converter != null) value = conv.Converter(value);
field.SetValue(target, value);
}
else
{
SetValue(field.GetValue(target), args);
}
return null;
}
}
if ((flags & BindingFlags.GetProperty) == BindingFlags.GetProperty)
{
PropertyInfo[] props = type.GetProperties(flags);
props = Array.FindAll<PropertyInfo>(props, delegate(PropertyInfo p) { return p.Name == memberName; });
if (props.Length > 0)
{
PropertyInfo prop = props[0];
int rank = prop.GetIndexParameters().Length;
if (rank > 0)
{
//if(rank != args.Length) throw new ArgumentException("Not enough arguments for indexed property");
// indexed property
memberName = "get_" + memberName;
//MethodInfo method = type.GetMethod(memberName);
return InvokeMember(type, memberName, flags | BindingFlags.InvokeMethod, target, args);
}
// not indexed property
object value = prop.GetValue(target, null);
value = GetValue(value, args);
return value;
}
}
if ((flags & BindingFlags.SetProperty) == BindingFlags.SetProperty)
{
PropertyInfo[] props = type.GetProperties(flags);
props = Array.FindAll<PropertyInfo>(props, delegate(PropertyInfo p) { return p.Name == memberName; });
if (props.Length > 0)
{
if (args == null || args.Length == 0) throw new ArgumentException("No value was provided to set");
//PropertyInfo prop = type.GetProperty(memberName, flags);
PropertyInfo prop = props[0];
int rank = prop.GetIndexParameters().Length;
if (rank > 0 || args.Length == 1)
{
// indexed property
memberName = "set_" + memberName;
//MethodInfo method = type.GetMethod(memberName);
return InvokeMember(type, memberName, flags | BindingFlags.InvokeMethod, target, args);
}
// not indexed property
SetValue(prop.GetValue(target, null), args);
return null;
}
}
if ((flags & BindingFlags.InvokeMethod) == BindingFlags.InvokeMethod)
{
// must explicitly search through BaseTypes because static methods are not inherited
Type baseType = type;
while (baseType != null)
{
MethodInfo[] methods = baseType.GetMethods(flags);
methods = Array.FindAll(methods, delegate(MethodInfo method) { return method.Name == memberName; });
if (methods.Length > 0)
{
if (args == null)
{
// even if there is only one method, we can't create a delegate because we don't know the
// desired delegate type.
DelegateGroup dg = new DelegateGroup();
dg.target = target;
dg.methods = methods;
return dg;
}
return Invoke(methods, target, args);
}
baseType = baseType.BaseType;
}
}
// search through interfaces
Type[] faces = type.GetInterfaces();
foreach (Type face in faces)
{
try
{
return InvokeMember(face, memberName, flags, target, args);
}
catch (MissingMemberException)
{
}
}
// default operator implementations
if ((flags & BindingFlags.InvokeMethod) == BindingFlags.InvokeMethod &&
(flags & BindingFlags.Static) == BindingFlags.Static)
{
if (memberName == "op_Equality" || memberName == "op_Inequality" ||
memberName == "op_GreaterThan" || memberName == "op_LessThan" ||
memberName == "op_GreaterThanOrEqual" || memberName == "op_LessThanOrEqual" ||
memberName == "op_Subtraction" || memberName == "op_Addition" ||
memberName == "op_BooleanOr" || memberName == "op_BooleanAnd" || memberName == "op_BooleanNot" ||
memberName == "op_UnaryNegation")
{
return InvokeMember(typeof (Invoker), memberName, BindingFlags.Public | BindingFlags.Static | BindingFlags.InvokeMethod, target, args);
}
}
throw new MissingMemberException(type.ToString() + " has no member named " + memberName + " under the binding flags " + flags);
}
// used by InvokeMember
public static bool op_Equality(object a, object b)
{
if (a == null)
return (a == b);
else
return a.Equals(b);
}
public static bool op_Inequality(object a, object b)
{
return !op_Equality(a, b);
}
public static int op_UnaryNegation(int a)
{
return -a;
}
public static int op_Addition(int a, int b)
{
return (a + b);
}
public static int op_Subtraction(int a, int b)
{
return (a - b);
}
public static bool op_GreaterThan(int a, int b)
{
return (a > b);
}
public static bool op_GreaterThanOrEqual(int a, int b)
{
return (a >= b);
}
public static bool op_LessThan(int a, int b)
{
return (a < b);
}
public static bool op_LessThanOrEqual(int a, int b)
{
return (a <= b);
}
public static bool op_BooleanOr(bool a, bool b)
{
return (a || b);
}
public static bool op_BooleanAnd(bool a, bool b)
{
return (a && b);
}
public static bool op_BooleanNot(bool b)
{
return !b;
}
/// <summary>
/// Gets the types of the objects in the specified array.
/// </summary>
/// <param name="args">An array of objects whose types to determine. args[i] can be null, whose type is assumed to be typeof(Nullable).</param>
/// <returns>An array of Type objects representing the types of the corresponding elements in args. </returns>
/// <remarks>This method is the same as Type.GetTypeArray except it allows null values.</remarks>
public static Type[] GetTypeArray(object[] args)
{
Type[] actuals = new Type[args.Length];
for (int i = 0; i < args.Length; i++)
{
if (args[i] == null) actuals[i] = typeof (Nullable);
else actuals[i] = args[i].GetType();
}
return actuals;
}
public static string PrintTypes(Type[] types)
{
StringBuilder s = new StringBuilder();
for (int i = 0; i < types.Length; i++)
{
if (i > 0) s.Append(",");
if (types[i] != null) s.Append(types[i].Name);
}
return s.ToString();
}
public static Type[] GetParameterTypes(MethodBase method)
{
ParameterInfo[] parameters = method.GetParameters();
Type[] formals = new Type[parameters.Length];
int i = 0;
foreach (ParameterInfo param in parameters) formals[i++] = param.ParameterType;
return formals;
}
public static int GenericParameterCount(MethodBase method)
{
MethodInfo info = method as MethodInfo;
if (info == null) return 0;
Type[] args = info.GetGenericArguments();
int count = 0;
foreach (Type arg in args)
{
if (arg.IsGenericParameter) count++;
}
return count;
}
public static int GenericParameterCount(Type type)
{
Type[] args = type.GetGenericArguments();
int count = 0;
foreach (Type arg in args)
{
if (arg.IsGenericParameter) count++;
}
return count;
}
private static T[] AddFirst<T>(T[] array, T item)
{
T[] result = new T[array.Length + 1];
result[0] = item;
array.CopyTo(result, 1);
return result;
}
private static T[] ButFirst<T>(T[] array)
{
T[] result = new T[array.Length - 1];
Array.Copy(array, 1, result, 0, result.Length);
return result;
}
/// <summary>
/// Invoke the method which best matches the argument types.
/// </summary>
/// <param name="methods">A non-empty list of methods, exactly one of which will be invoked. Can include both static and instance methods.</param>
/// <param name="target">The instance for an instance method, or if non-null, the first argument of a static method.</param>
/// <param name="args">The remaining arguments of the method.</param>
/// <returns>The return value of the method.</returns>
public static object Invoke(MethodBase[] methods, object target, params object[] args)
{
if (methods.Length == 0) throw new ArgumentException("The method list is empty");
Binding binding;
Exception exception;
Type[] actuals = GetTypeArray(args);
MethodBase method = GetBestMethod(methods, (target == null) ? null : target.GetType(), actuals, ConversionOptions.AllConversions, out binding, out exception);
if (method == null) throw exception;
if (method.IsStatic && target != null)
{
args = AddFirst(args, target);
target = null;
}
else if (!method.IsStatic && !method.IsConstructor && target == null)
{
if (args.Length == 0) throw new ArgumentException("The target is null");
target = args[0];
args = ButFirst(args);
}
// apply argument conversions
binding.ConvertAll(args);
//for (int i = 0; i < args.Length; i++) {
// Converter conv = binding.Conversions[i].Converter;
// if (conv != null) args[i] = conv(args[i]);
//}
object result = Util.Invoke(method, target, args);
if (!method.IsConstructor && ((MethodInfo)method).ReturnType == typeof(void))
return Missing.Value;
else
return result;
}
/// <summary>
/// Invoke a generic method by inferring type parameters from the method arguments.
/// </summary>
/// <param name="method"></param>
/// <param name="target"></param>
/// <param name="args"></param>
/// <returns></returns>
public static object Invoke(MethodBase method, object target, params object[] args)
{
MethodBase[] methods = new MethodBase[] {method};
return Invoke(methods, target, args);
}
/// <summary>
/// Find the method which best matches the given arguments.
/// </summary>
/// <param name="type"></param>
/// <param name="memberName"></param>
/// <param name="flags"></param>
/// <param name="targetType">The type of <c>this</c>, for instance methods. If looking for a static method, use null.</param>
/// <param name="argTypes">Types. argTypes.Length == number of method parameters. argTypes[i] may be null to allow any type, or typeof(Nullable) to mean "any nullable type".</param>
/// <param name="exception">Exception created on failure</param>
/// <returns>null on failure.</returns>
/// <exception cref="ArgumentException">The best matching type parameters did not satisfy the constraints of the generic method.</exception>
/// <exception cref="MissingMethodException">No match was found.</exception>
public static MethodBase GetBestMethod(Type type, string memberName, BindingFlags flags, Type targetType, Type[] argTypes, out Exception exception)
{
MethodInfo[] methods = type.GetMethods(flags);
methods = Array.FindAll<MethodInfo>(methods, delegate(MethodInfo method) { return method.Name == memberName; });
if (methods.Length == 0)
{
exception = new MissingMethodException(type.Name + " does not have any methods named " + memberName + " under the binding flags " + flags);
return null;
}
else return GetBestMethod(methods, targetType, argTypes, out exception);
}
/// <summary>
/// Find the method which best matches the given arguments.
/// </summary>
/// <param name="methods"></param>
/// <param name="targetType">The type of <c>this</c>, for instance methods. If looking for a static method, use null.</param>
/// <param name="argTypes">Types. argTypes.Length == number of method parameters. argTypes[i] may be null to allow any type, or typeof(Nullable) to mean "any nullable type".</param>
/// <param name="exception">Exception created on failure</param>
/// <returns>A non-null MethodBase.</returns>
/// <exception cref="ArgumentException">The best matching type parameters did not satisfy the constraints of the generic method.</exception>
/// <exception cref="MissingMethodException">No match was found.</exception>
public static MethodBase GetBestMethod(MethodBase[] methods, Type targetType, Type[] argTypes, out Exception exception)
{
Binding binding;
return GetBestMethod(methods, targetType, argTypes, ConversionOptions.NoConversions, out binding, out exception);
}
/// <summary>
/// Find the method which best matches the given arguments.
/// </summary>
/// <param name="methods">Methods to search through</param>
/// <param name="targetType">The type of <c>this</c>, for instance methods. If looking for a static method, use null.</param>
/// <param name="argTypes">Types. argTypes.Length == number of method parameters. argTypes[i] may be null to allow any type, or typeof(Nullable) to mean "any nullable type".</param>
/// <param name="conversionOptions">Specifies which conversions are allowed</param>
/// <param name="binding">Modified to contain the generic type arguments and argument conversions needed for calling the method</param>
/// <param name="exception">Exception created on failure</param>
/// <returns>A non-null MethodBase.</returns>
/// <exception cref="ArgumentException">The best matching type parameters did not satisfy the constraints of the generic method.</exception>
/// <exception cref="MissingMethodException">No match was found.</exception>
public static MethodBase GetBestMethod(MethodBase[] methods, Type targetType, Type[] argTypes, ConversionOptions conversionOptions, out Binding binding,
out Exception exception)
{
Type[] instance_actuals = argTypes;
Type[] static_actuals = argTypes;
if (targetType != null)
{
static_actuals = AddFirst(argTypes, targetType);
}
else if (argTypes.Length > 0)
{
// targetType == null
instance_actuals = ButFirst(argTypes);
}
exception = new MissingMethodException("The arguments (" + PrintTypes(argTypes) + ") do not match any overload of " + methods[0].Name);
binding = null;
MethodBase bestMethod = null;
foreach (MethodBase method in methods)
{
Binding b = Binding.GetBestBinding(method, (method.IsStatic || method.IsConstructor) ? static_actuals : instance_actuals, conversionOptions, out exception);
if (b != null && b < binding)
{
binding = b;
bestMethod = method;
}
}
if (bestMethod == null) return null;
// If the method is generic, specialize on the inferred type parameters.
return binding.Bind(bestMethod);
}
#region Cloning
public class DoNotCloneAttribute : Attribute
{
}
public class DoNotCloneItemsAttribute : Attribute
{
}
public static bool HasAttribute(MemberInfo member, Type attributeType)
{
object[] attrs = member.GetCustomAttributes(true);
foreach (object attr in attrs)
{
if (attr.GetType().Equals(attributeType)) return true;
}
return false;
}
/// <summary>
/// Clone an object by reflection on its fields.
/// </summary>
/// <param name="o"></param>
/// <returns></returns>
public static object Clone(object o)
{
return Clone(o, true);
}
public static object Clone(object o, bool cloneFields)
{
bool debug = true;
Type type = o.GetType();
if (type.IsPrimitive)
{
return o; // no need to clone
}
else if (type.IsArray && !type.GetElementType().IsPrimitive)
{
if (type.GetArrayRank() == 1)
{
Array array = (Array) o;
int length = array.Length;
Array result = (Array) Activator.CreateInstance(type, length); // or array.Clone();
for (int i = 0; i < length; i++)
{
object value = array.GetValue(i);
if (cloneFields) value = Clone(value);
result.SetValue(value, i);
}
return result;
}
else throw new NotImplementedException("Cannot clone an array of rank > 1");
}
else if (o is ICloneable)
{
// if it has a Clone() method, use it
return ((ICloneable) o).Clone();
}
else
{
// if it has a copy constructor, use it
try
{
return Activator.CreateInstance(type, o);
}
catch (MissingMethodException)
{
}
if (debug) Console.WriteLine("(Cloning a " + type);
// make an empty instance and clone each field.
object clone = Activator.CreateInstance(type);
FieldInfo[] fields = type.GetFields(BindingFlags.Public | BindingFlags.Instance);
foreach (FieldInfo field in fields)
{
object value = field.GetValue(o);
// must be operator != here, not the Equals method
if (field.GetValue(clone) != value)
{
if (debug) Console.WriteLine(field.Name);
if (cloneFields && !HasAttribute(field, typeof (DoNotCloneAttribute)))
{
value = Clone(value, !HasAttribute(field, typeof (DoNotCloneItemsAttribute)));
}
field.SetValue(clone, value);
}
}
// what about events?
if (debug) Console.WriteLine(")");
return clone;
}
}
#endregion
}
#if SUPPRESS_XMLDOC_WARNINGS
#pragma warning restore 1591
#endif
}

Просмотреть файл

@ -0,0 +1,123 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Text;
namespace Microsoft.ML.Probabilistic.Compiler.Graphs
{
internal class BreadthFirstSearch<NodeType> : GraphSearcher<NodeType>
{
public BreadthFirstSearch(Converter<NodeType, IEnumerable<NodeType>> successors,
CanCreateNodeData<NodeType> data)
: base(successors, data)
{
Initialize();
}
public BreadthFirstSearch(IGraph<NodeType> graph)
: base(graph)
{
Initialize();
}
public BreadthFirstSearch(IDirectedGraph<NodeType> graph)
: base(graph)
{
Initialize();
}
/// <summary>
/// A queue of nodes for breadth-first search.
/// </summary>
protected IList<NodeType> SearchQueue;
protected void Initialize()
{
SearchQueue = new QueueAsList<NodeType>();
}
public override void Clear()
{
base.Clear();
SearchQueue.Clear();
}
public override void SearchFrom(NodeType start)
{
SearchQueue.Add(start);
DoSearch();
}
public override void SearchFrom(IEnumerable<NodeType> startNodes)
{
foreach (NodeType node in startNodes)
{
SearchQueue.Add(node);
}
DoSearch();
}
protected void DoSearch()
{
while (SearchQueue.Count > 0)
{
NodeType node = SearchQueue[0];
SearchQueue.RemoveAt(0);
switch (IsVisited[node])
{
case VisitState.Unvisited:
OnDiscoverNode(node);
break;
case VisitState.Discovered:
break;
case VisitState.Visiting:
throw new Exception("BUG: start is Visiting and on the SearchQueue.");
case VisitState.Finished:
// this happens if we SearchFrom a Finished node.
continue;
}
// node was previously Unvisited or Discovered
IsVisited[node] = VisitState.Visiting;
foreach (NodeType target in Successors(node))
{
Edge<NodeType> edge = new Edge<NodeType>(node, target);
OnDiscoverEdge(edge);
VisitState targetIsVisited = IsVisited[target];
switch (targetIsVisited)
{
case VisitState.Unvisited:
IsVisited[target] = VisitState.Discovered;
SearchQueue.Add(target);
OnDiscoverNode(target);
// tree edge
OnTreeEdge(edge);
break;
case VisitState.Visiting:
// back edge
OnBackEdge(edge);
break;
case VisitState.Discovered:
// cross edge
OnCrossEdge(edge);
break;
case VisitState.Finished:
// cross edge
OnCrossEdge(edge);
break;
}
}
IsVisited[node] = VisitState.Finished;
OnFinishNode(node);
if (stopped)
{
SearchQueue.Clear();
stopped = false;
break;
}
}
}
}
}

Просмотреть файл

@ -0,0 +1,108 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Linq;
namespace Microsoft.ML.Probabilistic.Compiler.Graphs
{
/// <summary>
/// Find all maximal cliques of an undirected graph
/// </summary>
/// <typeparam name="NodeType">The node type</typeparam>
/// <remarks>
/// Cliques are found using the Bron-Kerbosch algorithm.
/// </remarks>
internal class CliqueFinder<NodeType>
{
Stack<NodeType> clique;
HashSet<NodeType> exclude;
Func<NodeType, IEnumerable<NodeType>> neighbors;
Action<Stack<NodeType>> action;
/// <summary>
/// Create a new CliqueFinder
/// </summary>
/// <param name="neighbors">A dictionary giving the set of neighbors of any node.</param>
public CliqueFinder(Func<NodeType, IEnumerable<NodeType>> neighbors)
{
this.clique = new Stack<NodeType>();
this.exclude = new HashSet<NodeType>();
this.neighbors = neighbors;
}
/// <summary>
/// Find all maximal cliques in an undirected graph using the Bron-Kerbosch algorithm
/// </summary>
/// <param name="candidates">The set of nodes.</param>
/// <param name="action">Called with each clique as it is found.</param>
/// <remarks>
/// The graph must not have self-loops.
/// </remarks>
public void ForEachClique(ICollection<NodeType> candidates, Action<Stack<NodeType>> action)
{
this.action = action;
ForEachClique(candidates, exclude);
}
/// <summary>
/// Find all maximal cliques in an undirected graph using the Bron-Kerbosch algorithm
/// </summary>
/// <param name="candidates">The set of nodes.</param>
/// <param name="exclude">An empty workspace used by recursive calls.</param>
/// <remarks>
/// The graph must not have self-loops.
/// </remarks>
private void ForEachClique(ICollection<NodeType> candidates, HashSet<NodeType> exclude)
{
HashSet<NodeType> visitedCandidates = new HashSet<NodeType>();
NodeType pivot = default(NodeType);
pivot = candidates.FirstOrDefault();
//int maxNeighbors = 0;
//foreach (NodeType i in candidates)
//{
// int numNeighbors = neighbors(i).Count;
// if (numNeighbors > maxNeighbors)
// {
// maxNeighbors = numNeighbors;
// pivot = i;
// }
//}
//foreach (NodeType i in exclude)
//{
// int numNeighbors = neighbors(i).Count;
// if (numNeighbors > maxNeighbors)
// {
// maxNeighbors = numNeighbors;
// pivot = i;
// }
//}
foreach (NodeType i in candidates)
{
IEnumerable<NodeType> nbrs = neighbors(i);
// skip neighbors of pivot
if (nbrs.Contains(pivot))
continue;
HashSet<NodeType> matches = new HashSet<NodeType>();
HashSet<NodeType> excludeMatches = new HashSet<NodeType>();
foreach (NodeType neighbor in nbrs)
{
bool isVisited = visitedCandidates.Contains(neighbor);
if (candidates.Contains(neighbor) && !isVisited)
matches.Add(neighbor);
if (exclude.Contains(neighbor) || isVisited)
excludeMatches.Add(neighbor);
}
clique.Push(i);
if (matches.Count > 0)
ForEachClique(matches, excludeMatches);
else if (excludeMatches.Count == 0)
action(clique);
clique.Pop();
visitedCandidates.Add(i);
}
}
}
}

Просмотреть файл

@ -0,0 +1,79 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Collections;
using System.Collections.Generic;
namespace Microsoft.ML.Probabilistic.Compiler.Graphs
{
// This class makes it easy to create custom List wrappers, by overriding
// the appropriate methods.
internal class CollectionWrapper<T, ListType> : ICollection<T>
where ListType : ICollection<T>
{
protected ListType list;
protected CollectionWrapper()
{
}
public CollectionWrapper(ListType list)
{
this.list = list;
}
#region IEnumerable methods
public virtual IEnumerator<T> GetEnumerator()
{
return list.GetEnumerator();
}
IEnumerator IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
#endregion
#region ICollection methods
public virtual int Count
{
get { return list.Count; }
}
public virtual bool IsReadOnly
{
get { return list.IsReadOnly; }
}
public virtual void CopyTo(T[] array, int index)
{
list.CopyTo(array, index);
}
public virtual void Add(T item)
{
list.Add(item);
}
public virtual void Clear()
{
list.Clear();
}
public virtual bool Contains(T item)
{
return list.Contains(item);
}
public virtual bool Remove(T item)
{
return list.Remove(item);
}
#endregion
}
}

Просмотреть файл

@ -0,0 +1,341 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using Microsoft.ML.Probabilistic.Collections;
namespace Microsoft.ML.Probabilistic.Compiler.Graphs
{
/// <summary>
/// Find all elementary cycles of a directed graph
/// </summary>
/// <typeparam name="NodeType">The node type</typeparam>
/// <remarks><para>
/// The cycles are described by firing actions according to the pattern:
/// BeginCycle, AddNode, AddNode, ..., AddNode, EndCycle, BeginCycle, ..., EndCycle.
/// The nodes in a cycle will appear in order of the directed edges between them.
/// </para><para>
/// The algorithm comes from:
/// "Finding all the Elementary Circuits of a Directed Graph"
/// Donald B. Johnson
/// SIAM Journal on Computing (1975)
/// http://dutta.csc.ncsu.edu/csc791_spring07/wrap/circuits_johnson.pdf
/// The runtime is O((n+e)(c+1)) where n is the number of nodes, e is the number of edges, and c
/// is the number of cycles in the graph.
/// </para></remarks>
internal class CycleFinder<NodeType>
{
private IDirectedGraph<NodeType> graph;
private IndexedProperty<NodeType, bool> isBlocked;
private IndexedProperty<NodeType, Set<NodeType>> blockedSources;
private Stack<NodeType> stack = new Stack<NodeType>();
private Set<NodeType> excluded = new Set<NodeType>();
public event Action<NodeType> AddNode;
public event Action BeginCycle , EndCycle;
public CycleFinder(IDirectedGraph<NodeType> graph)
{
this.graph = graph;
CanCreateNodeData<NodeType> data = (CanCreateNodeData<NodeType>) graph;
isBlocked = data.CreateNodeData<bool>(false);
blockedSources = data.CreateNodeData<Set<NodeType>>(null);
}
public void Search()
{
foreach (NodeType node in graph.Nodes)
{
SearchFrom(node, node);
// we have already found all cycles containing this node, so we exclude it from future searches
excluded.Add(node);
foreach (NodeType node2 in graph.Nodes)
{
isBlocked[node2] = false;
blockedSources[node2] = null;
}
}
}
/// <summary>
/// Find cycles containing root
/// </summary>
/// <param name="node"></param>
/// <param name="root"></param>
/// <returns></returns>
private bool SearchFrom(NodeType node, NodeType root)
{
bool foundCycle = false;
stack.Push(node);
isBlocked[node] = true;
foreach (NodeType target in graph.TargetsOf(node))
{
if (excluded.Contains(target)) continue;
if (target.Equals(root))
{
foundCycle = true;
OnBeginCycle();
Stack<NodeType> temp = new Stack<NodeType>();
foreach (NodeType nodeOnStack in stack)
{
temp.Push(nodeOnStack);
}
foreach (NodeType nodeOnStack in temp)
{
OnAddNode(nodeOnStack);
}
OnEndCycle();
}
else if (!isBlocked[target])
{
// recursive call
if (SearchFrom(target, root)) foundCycle = true;
}
}
// at this point, we could always set isBlocked[node]=false,
// but as an optimization we leave it set if no cycle was discovered,
// to prevent repeated searching of the same paths.
if (foundCycle) Unblock(node);
else
{
// at this point, all targets are blocked
foreach (NodeType target in graph.TargetsOf(node))
{
if (excluded.Contains(target)) continue;
Set<NodeType> blockedSourcesOfTarget = blockedSources[target];
if (blockedSourcesOfTarget == null)
{
blockedSourcesOfTarget = new Set<NodeType>();
blockedSources[target] = blockedSourcesOfTarget;
}
blockedSourcesOfTarget.Add(node);
}
}
stack.Pop();
return foundCycle;
}
private void Unblock(NodeType node)
{
isBlocked[node] = false;
Set<NodeType> blockedSourcesOfNode = blockedSources[node];
if (blockedSourcesOfNode != null)
{
blockedSources[node] = null;
foreach (NodeType source in blockedSourcesOfNode)
{
if (isBlocked[source]) Unblock(source);
}
}
}
public void OnAddNode(NodeType node)
{
if (AddNode != null) AddNode(node);
}
public void OnBeginCycle()
{
if (BeginCycle != null) BeginCycle();
}
public void OnEndCycle()
{
if (EndCycle != null) EndCycle();
}
}
/// <summary>
/// Find all elementary cycles of a directed graph
/// </summary>
/// <typeparam name="NodeType">The node type</typeparam>
/// <typeparam name="EdgeType">The edge type</typeparam>
/// <remarks><para>
/// The cycles are described by firing actions according to the pattern:
/// BeginCycle, AddEdge, AddEdge, ..., AddEdge, EndCycle, BeginCycle, ..., EndCycle.
/// The edges in a cycle will appear in order of their directions.
/// </para><para>
/// The algorithm comes from:
/// "Finding all the Elementary Circuits of a Directed Graph"
/// Donald B. Johnson
/// SIAM Journal on Computing (1975)
/// http://dutta.csc.ncsu.edu/csc791_spring07/wrap/circuits_johnson.pdf
/// The runtime is O((n+e)(c+1)) where n is the number of nodes, e is the number of edges, and c
/// is the number of cycles in the graph.
/// </para></remarks>
internal class CycleFinder<NodeType, EdgeType>
{
public event Action<EdgeType> AddEdge;
public event Action BeginCycle , EndCycle;
private IDirectedGraph<NodeType, EdgeType> graph;
private IndexedProperty<NodeType, bool> isBlocked;
private IndexedProperty<NodeType, Set<NodeType>> blockedSources;
private Set<NodeType> excluded = new Set<NodeType>();
protected Stack<StackFrame> SearchStack = new Stack<StackFrame>();
private NodeType root;
protected class StackFrame
{
public NodeType Node;
public IEnumerator<EdgeType> EdgesOut;
public EdgeType TreeEdge;
public bool foundCycle;
public StackFrame(NodeType node, IEnumerator<EdgeType> edgesOut, EdgeType treeEdge)
{
this.Node = node;
this.EdgesOut = edgesOut;
this.TreeEdge = treeEdge;
}
public override string ToString()
{
return Node.ToString();
}
}
public CycleFinder(IDirectedGraph<NodeType, EdgeType> graph)
{
this.graph = graph;
CanCreateNodeData<NodeType> data = (CanCreateNodeData<NodeType>) graph;
isBlocked = data.CreateNodeData<bool>(false);
blockedSources = data.CreateNodeData<Set<NodeType>>(null);
}
public void Search()
{
foreach (NodeType node in graph.Nodes)
{
SearchFrom(node);
// we have already found all cycles containing this node, so we exclude it from future searches
excluded.Add(node);
foreach (NodeType node2 in graph.Nodes)
{
isBlocked[node2] = false;
blockedSources[node2] = null;
}
}
}
/// <summary>
/// Find cycles containing root
/// </summary>
/// <param name="node"></param>
/// <returns></returns>
private void SearchFrom(NodeType node)
{
root = node;
Push(node, default(EdgeType));
DoSearch();
}
protected void DoSearch()
{
while (SearchStack.Count > 0)
{
StackFrame frame = SearchStack.Peek();
if (!PushNextChild(frame))
{
// all children have been visited, so we can remove ourselves from the stack.
NodeType node = frame.Node;
// at this point, we could always set isBlocked[node]=false,
// but as an optimization we leave it set if no cycle was discovered,
// to prevent repeated searching of the same paths.
if (frame.foundCycle) Unblock(node);
else
{
// at this point, all targets are blocked
foreach (NodeType target in graph.TargetsOf(node))
{
if (excluded.Contains(target)) continue;
Set<NodeType> blockedSourcesOfTarget = blockedSources[target];
if (blockedSourcesOfTarget == null)
{
blockedSourcesOfTarget = new Set<NodeType>();
blockedSources[target] = blockedSourcesOfTarget;
}
blockedSourcesOfTarget.Add(node);
}
}
SearchStack.Pop();
if (frame.foundCycle && SearchStack.Count > 0)
{
SearchStack.Peek().foundCycle = true;
}
}
}
}
protected void Push(NodeType node, EdgeType treeEdge)
{
SearchStack.Push(new StackFrame(node, graph.EdgesOutOf(node).GetEnumerator(), treeEdge));
}
protected bool PushNextChild(StackFrame frame)
{
NodeType node = frame.Node;
isBlocked[node] = true;
while (frame.EdgesOut.MoveNext())
{
EdgeType edge = frame.EdgesOut.Current;
NodeType target = graph.TargetOf(edge);
if (excluded.Contains(target)) continue;
if (target.Equals(root))
{
frame.foundCycle = true;
OnBeginCycle();
Stack<EdgeType> temp = new Stack<EdgeType>();
foreach (StackFrame frame2 in SearchStack)
{
temp.Push(frame2.TreeEdge);
}
// the last TreeEdge is a dummy
temp.Pop();
foreach (EdgeType edgeOnStack in temp)
{
OnAddEdge(edgeOnStack);
}
OnAddEdge(edge);
OnEndCycle();
}
else if (!isBlocked[target])
{
// recursive call
Push(target, edge);
return true;
}
}
return false;
}
private void Unblock(NodeType node)
{
isBlocked[node] = false;
Set<NodeType> blockedSourcesOfNode = blockedSources[node];
if (blockedSourcesOfNode != null)
{
blockedSources[node] = null;
foreach (NodeType source in blockedSourcesOfNode)
{
if (isBlocked[source]) Unblock(source);
}
}
}
public void OnAddEdge(EdgeType edge)
{
if (AddEdge != null) AddEdge(edge);
}
public void OnBeginCycle()
{
if (BeginCycle != null) BeginCycle();
}
public void OnEndCycle()
{
if (EndCycle != null) EndCycle();
}
}
}

Просмотреть файл

@ -0,0 +1,330 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using Microsoft.ML.Probabilistic.Collections;
namespace Microsoft.ML.Probabilistic.Compiler.Graphs
{
/// <summary>
/// Order the nodes to best satisfy cyclic dependencies.
/// </summary>
/// <remarks><p>
/// The algorithm is essentially a topological sort, modified to deal with cycles.
/// In case of a directed cycle, we search for a node which can execute before some of its parents.
/// This judgement is made by the canExecute predicate.
/// </p><p>
/// Algorithm:
/// For each target node, we run bfs backward to collect a list of ancestors and their finishing times.
/// These ancestor nodes are placed in a priority queue according to their finishing time.
/// We then examine each node on the queue to determine if its input requirements are satisfied.
/// If so, the node is scheduled. If not, we put it aside to wait until one of its unscheduled parents is scheduled.
/// </p><p>
/// SourcesOf and CreateNodeData are the only graph methods used. The Nodes property is not used.
/// </p></remarks>
internal class CyclicDependencySort<Node, Cost>
where Cost : IComparable<Cost>
{
private Converter<Node, IEnumerable<Node>> successors;
private BreadthFirstSearch<Node> bfs;
/// <summary>
/// Indicates if the node has been placed on the schedule.
/// </summary>
/// <remarks>
/// Unlike WasScheduledLastIteration, this information changes throughout the scheduling process.
/// </remarks>
public IndexedProperty<Node, bool> IsScheduled;
#if false
/// <summary>
/// Indicates if the node was scheduled on the previous iteration, i.e. its results are available for the current iteration.
/// </summary>
/// <remarks>
/// This information is fixed on entry to the sorter and is not modified.
/// </remarks>
public IndexedProperty<Node, bool> WasScheduledLastIteration;
#endif
private int visitCount;
/// <summary>
/// Indicates if scheduling should stop. Takes the latest node to be scheduled.
/// </summary>
public Converter<Node, bool> StopScheduling;
private bool done;
private Cost threshold;
/// <summary>
/// Cost at which nodes will not be scheduled.
/// </summary>
public Cost Threshold
{
get { return threshold; }
set
{
threshold = value;
ApplyThreshold = true;
}
}
/// <summary>
/// The maximum cost of a scheduled node.
/// </summary>
public Cost MaxScheduledCost;
/// <summary>
/// Indicates that only nodes whose cost is less than Threshold will be scheduled.
/// </summary>
public bool ApplyThreshold;
public CostUpdater updateCost;
/// <summary>
/// Called to update the cost of scheduling a node.
/// </summary>
/// <param name="node"></param>
/// <param name="isScheduled"></param>
/// <param name="cost">The previous cost, which may be modified in place for efficiency. May be null.</param>
/// <returns></returns>
public delegate Cost CostUpdater(Node node, IndexedProperty<Node, bool> isScheduled, Cost cost);
/// <summary>
/// Called just before IsScheduled[node] is set to true.
/// </summary>
public Func<Node, bool> addToSchedule;
/// <summary>
/// Queue of nodes waiting to be scheduled.
/// </summary>
private PriorityQueue<QueueEntry> queue;
public IndexedProperty<Node, QueueEntry> EntryOfNode;
#if false
/// <summary>
/// Represents the undesirability of scheduling a node.
/// </summary>
public struct Badness : IComparable<Badness>
{
public int ClosenessToTarget;
public bool OutOfOrderTrigger;
public bool OutOfOrderTriggee;
public bool ChildOfStaleParent;
public int CompareTo(Badness that)
{
if(this < that) return -1;
else if(that < this) return 1;
else return 0;
}
public static bool operator<(Badness a, Badness b)
{
return (a.Badness < b.Badness) || (a.ClosenessToTarget < b.ClosenessToTarget);
}
}
#endif
public class QueueEntry : IComparable<QueueEntry>
{
public Node Node;
public Cost Cost;
/// <summary>
/// Used to break ties between nodes of the same scheduling cost.
/// Tries to schedule close nodes last.
/// </summary>
public int ClosenessToTarget;
public int QueuePosition;
public int CompareTo(QueueEntry that)
{
int costCompare = this.Cost.CompareTo(that.Cost);
if (costCompare != 0) return costCompare;
else return Comparer<int>.Default.Compare(this.ClosenessToTarget, that.ClosenessToTarget);
}
public static readonly EntryComparer Comparer = new EntryComparer();
public class EntryComparer : IComparer<QueueEntry>
{
public int Compare(QueueEntry x, QueueEntry y)
{
return x.CompareTo(y);
}
}
public override string ToString()
{
return $"Pos={QueuePosition},Cost={Cost},Node={Node}";
}
}
#if false
public class KeyValueComparer<KeyType, ValueType> : IComparer<KeyValuePair<KeyType, ValueType>>
{
public IComparer<KeyType> KeyComparer;
public int Compare(KeyValuePair<KeyType, ValueType> x, KeyValuePair<KeyType, ValueType> y)
{
return KeyComparer.Compare(x.Key, y.Key);
}
public KeyValueComparer(IComparer<KeyType> keyComparer)
{
this.KeyComparer = keyComparer;
}
}
#endif
public CyclicDependencySort(IDirectedGraph<Node> dependencyGraph, CostUpdater updateCost)
: this(dependencyGraph.SourcesOf, dependencyGraph.TargetsOf,
(CanCreateNodeData<Node>) dependencyGraph, updateCost)
{
}
public CyclicDependencySort(
Converter<Node, IEnumerable<Node>> predecessors,
Converter<Node, IEnumerable<Node>> successors,
CanCreateNodeData<Node> data,
CostUpdater updateCost)
{
this.successors = successors;
bfs = new BreadthFirstSearch<Node>(predecessors, data);
this.updateCost = updateCost;
queue = new PriorityQueue<QueueEntry>(QueueEntry.Comparer);
queue.Moved += delegate(QueueEntry entry, int pos) { entry.QueuePosition = pos; };
EntryOfNode = data.CreateNodeData<QueueEntry>(null);
visitCount = 0;
bfs.FinishNode += delegate(Node node)
{
QueueEntry entry = new QueueEntry();
entry.ClosenessToTarget = ++visitCount;
entry.Node = node;
entry.Cost = Threshold;
EntryOfNode[node] = entry;
queue.Add(entry);
UpdateEntry(entry);
};
IsScheduled = data.CreateNodeData<bool>(false);
}
public void DrainQueue(Action<QueueEntry> action)
{
while (queue.Count > 0)
{
QueueEntry entry = queue.ExtractMinimum();
action(entry);
}
}
public void Reschedule(Node node)
{
//Console.WriteLine("rescheduling "+node);
IsScheduled[node] = false;
QueueEntry entry = EntryOfNode[node];
if (entry == null)
{
bfs.IsVisited[node] = VisitState.Unvisited;
bfs.SearchFrom(node);
}
else if (entry.QueuePosition < 0)
{
queue.Add(entry);
UpdateEntry(entry);
}
}
public void Clear()
{
queue.Clear();
bfs.Clear();
visitCount = 0;
IsScheduled.Clear();
MaxScheduledCost = default(Cost);
}
public void MarkScheduled(IEnumerable<Node> targets)
{
foreach (Node node in targets)
{
bfs.IsVisited[node] = VisitState.Finished;
IsScheduled[node] = true;
}
}
public void AddRange(IEnumerable<Node> targets)
{
// bfs will add all ancestors to the queue.
bfs.SearchFrom(targets);
done = false;
while (queue.Count > 0)
{
QueueEntry entry = queue[0];
// make sure the cost of this entry is up-to-date (inefficient, but safe)
UpdateEntry(entry);
if (entry.QueuePosition == 0)
{
if (ApplyThreshold && (entry.Cost.CompareTo(Threshold) >= 0))
{
return;
}
queue.ExtractMinimum();
Node node = entry.Node;
//Console.WriteLine("scheduling " + node);
if (addToSchedule != null)
{
if (!addToSchedule(node))
{
// put back on the queue
queue.Add(entry);
continue;
}
}
if (entry.Cost.CompareTo(MaxScheduledCost) > 0) MaxScheduledCost = entry.Cost;
IsScheduled[node] = true;
if (StopScheduling != null && StopScheduling(node))
{
done = true;
return;
}
//Console.WriteLine("updating targets:");
foreach (Node target in successors(node))
{
if (!IsScheduled[target])
{
UpdateCost(target);
}
}
//Console.WriteLine("done with targets");
}
}
done = true;
}
public void UpdateCost(Node node)
{
// the node may not have an entry if it was never visited. In that case, ignore it.
QueueEntry entry = EntryOfNode[node];
if (entry != null && entry.QueuePosition >= 0)
{
UpdateEntry(entry);
}
}
public void UpdateEntry(QueueEntry entry)
{
Node node = entry.Node;
if (IsScheduled[node]) throw new Exception("node " + node + " was already scheduled");
entry.Cost = updateCost(node, IsScheduled, entry.Cost);
queue.Changed(entry.QueuePosition);
}
public bool IncompleteSchedule
{
get { return !done; }
}
}
}

Просмотреть файл

@ -0,0 +1,300 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
namespace Microsoft.ML.Probabilistic.Compiler.Graphs
{
internal class DepthFirstSearch<NodeType> : GraphSearcher<NodeType>
{
public DepthFirstSearch(Converter<NodeType, IEnumerable<NodeType>> successors,
CanCreateNodeData<NodeType> data)
: base(successors, data)
{
Initialize();
}
public DepthFirstSearch(IGraph<NodeType> graph)
: base(graph)
{
Initialize();
}
public DepthFirstSearch(IDirectedGraph<NodeType> graph)
: base(graph)
{
Initialize();
}
protected struct StackFrame
{
public NodeType Node;
public IEnumerator<NodeType> Successors;
public StackFrame(NodeType node, IEnumerator<NodeType> successors)
{
this.Node = node;
this.Successors = successors;
}
}
/// <summary>
/// A stack of nodes for depth-first search.
/// </summary>
protected Stack<StackFrame> SearchStack;
protected void Initialize()
{
SearchStack = new Stack<StackFrame>();
}
public override void Clear()
{
base.Clear();
SearchStack.Clear();
}
public override void SearchFrom(NodeType start)
{
if (IsVisited[start] == VisitState.Unvisited)
{
Push(start);
DoSearch();
}
}
public override void SearchFrom(IEnumerable<NodeType> startNodes)
{
foreach (NodeType node in startNodes)
{
SearchFrom(node);
}
}
public void ForEachStackNode(Action<NodeType> action)
{
foreach (StackFrame frame in SearchStack)
{
action(frame.Node);
}
}
protected void DoSearch()
{
while (SearchStack.Count > 0)
{
StackFrame frame = SearchStack.Peek();
if (!PushNextChild(frame))
{
// all children have been visited, so we can remove ourselves from the stack.
NodeType node = frame.Node;
IsVisited[node] = VisitState.Finished;
SearchStack.Pop();
OnFinishNode(node);
if (SearchStack.Count > 0)
OnFinishTreeEdge(new Edge<NodeType>(SearchStack.Peek().Node, node));
if (stopped)
{
SearchStack.Clear();
stopped = false;
}
}
}
}
protected void Push(NodeType node)
{
SearchStack.Push(new StackFrame(node, Successors(node).GetEnumerator()));
}
protected bool PushNextChild(StackFrame frame)
{
NodeType node = frame.Node;
switch (IsVisited[node])
{
case VisitState.Unvisited:
OnDiscoverNode(node);
break;
case VisitState.Discovered:
break;
case VisitState.Visiting:
break;
case VisitState.Finished:
// this happens if we SearchFrom a Finished node.
return false;
}
// node was previously Unvisited or Discovered
IsVisited[node] = VisitState.Visiting;
while (frame.Successors.MoveNext())
{
NodeType target = frame.Successors.Current;
Edge<NodeType> edge = new Edge<NodeType>(node, target);
OnDiscoverEdge(edge);
VisitState targetIsVisited = IsVisited[target];
switch (targetIsVisited)
{
case VisitState.Unvisited:
IsVisited[target] = VisitState.Discovered;
Push(target);
OnDiscoverNode(target);
// tree edge
OnTreeEdge(edge);
return true;
case VisitState.Visiting:
// back edge
OnBackEdge(edge);
break;
case VisitState.Discovered:
// cross edge
OnCrossEdge(edge);
break;
case VisitState.Finished:
// cross edge
OnCrossEdge(edge);
break;
}
if (stopped) return false;
}
return false;
}
}
internal class DepthFirstSearch<NodeType, EdgeType> : GraphSearcher<NodeType, EdgeType>
{
protected IDirectedGraph<NodeType, EdgeType> graph;
public DepthFirstSearch(IDirectedGraph<NodeType, EdgeType> graph)
{
this.graph = graph;
CreateNodeData(graph);
Initialize();
}
protected struct StackFrame
{
public NodeType Node;
public IEnumerator<EdgeType> EdgesOut;
public EdgeType TreeEdge;
public StackFrame(NodeType node, IEnumerator<EdgeType> edgesOut, EdgeType treeEdge)
{
this.Node = node;
this.EdgesOut = edgesOut;
this.TreeEdge = treeEdge;
}
public override string ToString()
{
return Node.ToString();
}
}
/// <summary>
/// A stack of nodes for depth-first search.
/// </summary>
protected Stack<StackFrame> SearchStack;
protected void Initialize()
{
SearchStack = new Stack<StackFrame>();
}
public override void Clear()
{
base.Clear();
SearchStack.Clear();
}
public override void SearchFrom(NodeType start)
{
if (IsVisited[start] == VisitState.Unvisited)
{
Push(start, default(EdgeType));
DoSearch();
}
}
public override void SearchFrom(IEnumerable<NodeType> startNodes)
{
foreach (NodeType node in startNodes)
{
SearchFrom(node);
}
}
protected void DoSearch()
{
while (SearchStack.Count > 0)
{
StackFrame frame = SearchStack.Peek();
if (!PushNextChild(frame))
{
// all children have been visited, so we can remove ourselves from the stack.
NodeType node = frame.Node;
IsVisited[node] = VisitState.Finished;
SearchStack.Pop();
OnFinishNode(node);
if (SearchStack.Count > 0)
OnFinishTreeEdge(frame.TreeEdge);
}
}
}
protected void Push(NodeType node, EdgeType treeEdge)
{
SearchStack.Push(new StackFrame(node, graph.EdgesOutOf(node).GetEnumerator(), treeEdge));
}
protected bool PushNextChild(StackFrame frame)
{
NodeType node = frame.Node;
switch (IsVisited[node])
{
case VisitState.Unvisited:
OnDiscoverNode(node);
break;
case VisitState.Discovered:
break;
case VisitState.Visiting:
break;
case VisitState.Finished:
// this happens if we SearchFrom a Finished node.
return false;
}
// node was previously Unvisited or Discovered
IsVisited[node] = VisitState.Visiting;
while (frame.EdgesOut.MoveNext())
{
EdgeType edge = frame.EdgesOut.Current;
NodeType target = graph.TargetOf(edge);
OnDiscoverEdge(edge);
VisitState targetIsVisited = IsVisited[target];
switch (targetIsVisited)
{
case VisitState.Unvisited:
IsVisited[target] = VisitState.Discovered;
Push(target, edge);
OnDiscoverNode(target);
// tree edge
OnTreeEdge(edge);
return true;
case VisitState.Visiting:
// back edge
OnBackEdge(edge);
break;
case VisitState.Discovered:
// cross edge
OnCrossEdge(edge);
break;
case VisitState.Finished:
// cross edge
OnCrossEdge(edge);
break;
}
}
return false;
}
}
}

Просмотреть файл

@ -0,0 +1,77 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
namespace Microsoft.ML.Probabilistic.Compiler.Graphs
{
/// <summary>
/// Computes the distance to all nodes reachable from a starting node.
/// </summary>
/// <typeparam name="NodeType">The node type.</typeparam>
/// <remarks>
/// The distances are returned via the SetDistance action. Nodes which are unreachable from the
/// starting node will have no distance set.
/// </remarks>
internal class DistanceSearch<NodeType>
{
protected Converter<NodeType, IEnumerable<NodeType>> Successors;
protected CanCreateNodeData<NodeType> Data;
public event Action<NodeType, int> SetDistance;
private readonly BreadthFirstSearch<NodeType> bfs;
private int parentCount, childCount, distance;
public DistanceSearch(IGraph<NodeType> graph)
: this(graph.NeighborsOf, (CanCreateNodeData<NodeType>) graph)
{
}
public DistanceSearch(IDirectedGraph<NodeType> graph)
: this(graph.TargetsOf, (CanCreateNodeData<NodeType>) graph)
{
}
public DistanceSearch(Converter<NodeType, IEnumerable<NodeType>> successors,
CanCreateNodeData<NodeType> data)
{
this.Successors = successors;
this.Data = data;
bfs = new BreadthFirstSearch<NodeType>(Successors, Data);
bfs.DiscoverNode += delegate(NodeType node)
{
OnSetDistance(node, distance);
if (distance == 0) distance++;
else childCount++;
};
bfs.FinishNode += delegate(NodeType node)
{
if (--parentCount == 0)
{
parentCount = childCount;
childCount = 0;
distance++;
}
};
}
public void SearchFrom(NodeType start)
{
// we can compute the distance from the starting node to any node using constant additional storage.
// at any given time, the bfs queue contains a set of parent nodes whose distance is <c>distance</c>
// and a set of child nodes whose distance is <c>distance+1</c>. When the parent nodes are
// exhausted, the children become parents and we reset childCount to 0, incrementing distance.
parentCount = 1;
childCount = 0;
distance = 0;
bfs.SearchFrom(start);
bfs.Clear();
}
public void OnSetDistance(NodeType node, int distance)
{
if (SetDistance != null) SetDistance(node, distance);
}
}
}

Просмотреть файл

@ -0,0 +1,80 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using Microsoft.ML.Probabilistic.Utilities;
namespace Microsoft.ML.Probabilistic.Compiler.Graphs
{
/// <summary>
/// An edge with stored endpoints.
/// </summary>
/// <typeparam name="NodeType">The type of a node handle.</typeparam>
/// <remarks>This is a commonly-used interface for an edge object which stores its endpoints.
/// Edge handles are not required to implement it.
/// </remarks>
internal interface IEdge<NodeType>
{
NodeType Source { get; }
NodeType Target { get; }
}
internal interface IMutableEdge<NodeType> : IEdge<NodeType>
{
new NodeType Source { get; set; }
new NodeType Target { get; set; }
}
/// <summary>
/// A basic edge object.
/// </summary>
/// <typeparam name="NodeType">The type of a node handle.</typeparam>
internal struct Edge<NodeType> : IEdge<NodeType>
{
public NodeType Source, Target;
public Edge(NodeType source, NodeType target)
{
this.Source = source;
this.Target = target;
}
public static Edge<NodeType> New(NodeType source, NodeType target)
{
return new Edge<NodeType>(source, target);
}
public override string ToString()
{
return String.Format("({0},{1})", Source, Target);
}
public override bool Equals(object obj)
{
if (!(obj is Edge<NodeType>))
return false;
Edge<NodeType> that = (Edge<NodeType>)obj;
return Source.Equals(that.Source) && Target.Equals(that.Target);
}
public override int GetHashCode()
{
return Hash.Combine(Source.GetHashCode(), Target.GetHashCode());
}
#region IEdge<NodeType> Members
NodeType IEdge<NodeType>.Source
{
get { return Source; }
}
NodeType IEdge<NodeType>.Target
{
get { return Target; }
}
#endregion
}
}

Просмотреть файл

@ -0,0 +1,932 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections;
using System.Collections.Generic;
using System.Text;
using Microsoft.ML.Probabilistic.Compiler.Reflection;
using Microsoft.ML.Probabilistic.Compiler;
namespace Microsoft.ML.Probabilistic.Compiler.Graphs
{
internal class NodeDataDictionary<NodeType> : CanCreateNodeData<NodeType>
{
public IndexedProperty<NodeType, T> CreateNodeData<T>()
{
return new IndexedProperty<NodeType, T>(new Dictionary<NodeType, T>());
}
public IndexedProperty<NodeType, T> CreateNodeData<T>(T defaultValue)
{
return new IndexedProperty<NodeType, T>(new Dictionary<NodeType, T>(), defaultValue);
}
}
/// <summary>
/// A directed graph of HasTargets objects.
/// </summary>
/// <remarks><p>
/// Abstractly, a directed graph is a collection of node pairs (node1,node2).
/// Each pair is called an edge. An edge from a node to itself (a
/// self-loop) is allowed. Duplicate edges are allowed. Edges to nodes
/// outside of the graph are allowed.
/// </p><p>
/// A node is added via <c>g.Nodes.Add(node)</c> and an edge is added via
/// <c>g.AddEdge(node1,node2)</c>.
/// </p><p>
/// This implementation supports node labels, which can be any object.
/// A labeled node is added via <c>g.Nodes.WithLabel(label).Add(node)</c>.
/// </p><p>
/// The graph is implemented by an adjacency list which can be singly or
/// doubly-linked.
/// The list is distributed among the nodes of the graph, which hold their
/// child nodes and possibly also their parent nodes (in the
/// doubly-linked case).
/// Nodes which implement the HasSources interface will be doubly-linked,
/// and other nodes will be singly-linked.
/// Thus the graph can be part doubly-linked and part singly-linked.
/// Doubly-linked nodes are more efficient to remove from the graph.
/// </p></remarks>
internal class Graph<NodeType> : NodeDataDictionary<NodeType>,
IMutableDirectedGraph<NodeType>, ILabeledGraph<NodeType, object>,
CanCreateNodeData<NodeType>
where NodeType : HasTargets<NodeType>
{
// nodes must be unique across labels
protected LabeledSet<NodeType, object> nodes;
public Func<NodeType> NodeFactory;
public Graph()
{
nodes = new LabeledSet<NodeType, object>("");
}
public Graph(Func<NodeType> nodeFactory) : this()
{
this.NodeFactory = nodeFactory;
}
#region ILabeledGraph methods
public ILabeledCollection<NodeType, object> Nodes
{
get { return this.nodes; }
}
#endregion
#region IGraph methods
ICollection<NodeType> IGraph<NodeType>.Nodes
{
get { return this.nodes; }
}
public NodeType AddNode()
{
NodeType node = NodeFactory();
nodes.Add(node);
return node;
}
/// <summary>
/// Add a directed edge from node to target.
/// </summary>
/// <param name="source">The source node.</param>
/// <param name="target">The target node.</param>
/// <remarks>The two nodes need not be in the graph, and will not be added to the graph.</remarks>
public void AddEdge(NodeType source, NodeType target)
{
source.Targets.Add(target);
HasSources<NodeType> intoNode = target as HasSources<NodeType>;
if (intoNode != null) intoNode.Sources.Add(source);
}
public bool ContainsEdge(NodeType source, NodeType target)
{
return source.Targets.Contains(target);
}
/// <summary>
/// Remove a directed edge from source to target.
/// </summary>
/// <param name="source">The source node.</param>
/// <param name="target">The target node.</param>
/// <remarks>If there are multiple edges from source to target, only one is removed.</remarks>
public virtual bool RemoveEdge(NodeType source, NodeType target)
{
HasSources<NodeType> intoNode = target as HasSources<NodeType>;
if (intoNode != null) intoNode.Sources.Remove(source);
return source.Targets.Remove(target);
}
#if zero
// Remove all edges with a given label connected to a node, but do not
// remove the node. This includes all edges which refer to the node
// from another node.
// If node.InwardLabels is null, then inward references will not be
// be cleared.
public Graph ClearNodeEdges(HasChildNodesLabeled node, object label)
{
foreach(Node toNode in node.GetOutwardEdges(label)) {
HasParentNodesLabeled intoNode = toNode as HasParentNodesLabeled;
if(intoNode != null)
intoNode.RemoveInwardEdge(node,label);
}
node.ClearOutwardEdges(label);
HasParentNodesLabeled inNode = node as HasParentNodesLabeled;
if(inNode != null) {
// doubly-linked case
foreach(Node fromNode in inNode.GetInwardEdges(label)) {
fromNode.RemoveOutwardEdge(node,label);
}
inNode.ClearInwardEdges(label);
} else {
// singly-linked case
// must search all nodes in the graph
foreach(Node fromNode in nodes) {
fromNode.RemoveOutwardEdge(node,label);
}
}
return this;
}
#endif
public virtual void ClearEdgesOutOf(NodeType source)
{
foreach (NodeType target in source.Targets)
{
HasSources<NodeType> intoNode = target as HasSources<NodeType>;
if (intoNode != null) intoNode.Sources.Remove(source);
}
source.Targets.Clear();
}
public virtual void ClearEdgesInto(NodeType target)
{
HasSources<NodeType> inNode = target as HasSources<NodeType>;
if (inNode != null)
{
// doubly-linked case
foreach (NodeType source in inNode.Sources)
{
source.Targets.Remove(target);
}
inNode.Sources.Clear();
}
else
{
// singly-linked case
// must search all nodes in the graph
foreach (NodeType source in nodes)
{
while (source.Targets.Contains(target))
{
source.Targets.Remove(target);
}
}
}
}
/// <summary>
/// Remove all edges connected to a node.
/// </summary>
/// <param name="node"></param>
/// <remarks>
/// The node itself is not removed. In the singly-linked case, the graph is scanned for
/// all nodes which link to <paramref name="node"/> and these links are cut.
/// In this case, there may still be links from outside the graph.
/// </remarks>
public virtual void ClearEdgesOf(NodeType node)
{
ClearEdgesOutOf(node);
ClearEdgesInto(node);
}
/// <summary>
/// Remove all edges in the graph.
/// </summary>
/// <remarks>In the singly-linked case, there may still be links from outside the graph.
/// </remarks>
public virtual void ClearEdges()
{
foreach (NodeType node in nodes)
{
node.Targets.Clear();
HasSources<NodeType> inNode = node as HasSources<NodeType>;
if (inNode != null) inNode.Sources.Clear();
}
}
#endregion
public virtual int EdgeCount()
{
int count = 0;
foreach (NodeType node in nodes)
{
count += node.Targets.Count;
}
return count;
}
public virtual int NeighborCount(NodeType node)
{
return TargetCount(node) + SourceCount(node);
}
public virtual int TargetCount(NodeType source)
{
return source.Targets.Count;
}
public virtual int SourceCount(NodeType target)
{
HasSources<NodeType> inNode = target as HasSources<NodeType>;
if (inNode != null)
{
return inNode.Sources.Count;
}
else
{
// singly-linked case
// must search all nodes in the graph
int count = 0;
foreach (NodeType source in nodes)
{
if (source.Targets.Contains(target))
{
count++;
}
}
return count;
}
}
public IEnumerable<NodeType> NeighborsOf(NodeType node)
{
foreach (NodeType target in TargetsOf(node))
{
yield return target;
}
foreach (NodeType source in SourcesOf(node))
{
yield return source;
}
}
public IEnumerable<NodeType> TargetsOf(NodeType source)
{
return source.Targets;
}
public IEnumerable<NodeType> SourcesOf(NodeType target)
{
HasSources<NodeType> inNode = target as HasSources<NodeType>;
if (inNode != null)
{
return inNode.Sources;
}
else
{
// singly-linked case
// must search all nodes in the graph
return new SourceEnumerator(this, target);
}
}
internal class SourceEnumerator : IEnumerable<NodeType>
{
private Graph<NodeType> graph;
private NodeType target;
public IEnumerator<NodeType> GetEnumerator()
{
foreach (NodeType source in graph.nodes)
{
if (source.Targets.Contains(target))
{
yield return source;
}
}
}
IEnumerator IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
public SourceEnumerator(Graph<NodeType> graph, NodeType target)
{
this.graph = graph;
this.target = target;
}
}
// This works for graphs because nodes are unique.
public virtual object LabelOf(NodeType node)
{
foreach (object label in nodes.Labels)
{
if (Nodes.WithLabel(label).Contains(node)) return label;
}
return null;
}
public override string ToString()
{
StringBuilder s = new StringBuilder();
foreach (object label in nodes.Labels)
{
s.Append(label).AppendLine(":");
foreach (NodeType node in nodes.WithLabel(label))
{
foreach (NodeType target in node.Targets)
{
s.AppendLine(String.Format(" {0} -> {1}", node, target));
}
HasSources<NodeType> intoNode = node as HasSources<NodeType>;
if (intoNode != null)
{
foreach (NodeType source in intoNode.Sources)
{
s.AppendLine(String.Format(" {0} <- {1}", node, source));
}
}
}
}
return s.ToString();
}
/// <summary>
/// Add all nodes from another graph.
/// </summary>
/// <param name="that"></param>
/// <remarks>Nodes are added as references, i.e. they are not cloned. No new edges are created.</remarks>
public virtual void Add(Graph<NodeType> that)
{
foreach (object label in that.Nodes.Labels)
{
foreach (NodeType node in that.Nodes.WithLabel(label))
{
Nodes.WithLabel(label).Add(node);
}
}
}
public virtual void Clear()
{
ClearEdges();
Nodes.Clear();
}
public virtual bool RemoveNodeAndEdges(NodeType node)
{
ClearEdgesOf(node);
return Nodes.Remove(node);
}
/// <summary>
/// Provides node and edge data for an existing graph.
/// </summary>
/// <typeparam name="NodeDataType"></typeparam>
/// <typeparam name="EdgeDataType"></typeparam>
internal class Data<NodeDataType, EdgeDataType>
{
public IDictionary<NodeType, NodeDataType> Nodes;
public IDictionary<Edge<NodeType>, EdgeDataType> Edges;
public NodeDataType this[NodeType node]
{
get { return Nodes[node]; }
set { Nodes[node] = value; }
}
public EdgeDataType this[NodeType source, NodeType target]
{
get { return Edges[new Edge<NodeType>(source, target)]; }
set { Edges[new Edge<NodeType>(source, target)] = value; }
}
public Data()
{
Nodes = new Dictionary<NodeType, NodeDataType>();
Edges = new Dictionary<Edge<NodeType>, EdgeDataType>();
}
}
}
//---------------------------------------------------------------------------------------------------
//---------------------------------------------------------------------------------------------------
/// <summary>
///
/// </summary>
/// <typeparam name="NodeType"></typeparam>
/// <typeparam name="EdgeType"></typeparam>
/// <param name="source"></param>
/// <param name="target"></param>
/// <returns></returns>
internal delegate EdgeType EdgeFactory<NodeType, EdgeType>(NodeType source, NodeType target);
/// <summary>
/// A directed graph with explicit node and edge objects.
/// </summary>
/// <typeparam name="NodeType"></typeparam>
/// <typeparam name="EdgeType"></typeparam>
internal class Graph<NodeType, EdgeType> : Graph<NodeType>,
IMutableDirectedGraph<NodeType, EdgeType>, IMultigraph<NodeType, EdgeType>,
CanCreateEdgeData<EdgeType>
where NodeType : HasTargets<NodeType>, HasOutEdges<EdgeType>
where EdgeType : IEdge<NodeType>
{
public Graph()
: base()
{
}
public Graph(EdgeFactory<NodeType, EdgeType> edgeFactory)
: base()
{
this.EdgeFactory = edgeFactory;
}
/// <summary>
/// A delegate to create edge objects.
/// </summary>
/// <remarks>
/// The delegate only creates the edge object; it does not register the edge with the endpoints.
/// </remarks>
public EdgeFactory<NodeType, EdgeType> EdgeFactory;
public new EdgeType AddEdge(NodeType source, NodeType target)
{
EdgeType edge = EdgeFactory(source, target);
AddEdge(edge);
return edge;
}
public void AddEdge(EdgeType edge)
{
edge.Source.OutEdges.Add(edge);
HasInEdges<EdgeType> intoTarget = edge.Target as HasInEdges<EdgeType>;
if (intoTarget != null) intoTarget.InEdges.Add(edge);
}
public NodeType SourceOf(EdgeType edge)
{
return edge.Source;
}
public NodeType TargetOf(EdgeType edge)
{
return edge.Target;
}
public IEnumerable<EdgeType> EdgesOutOf(NodeType source)
{
return source.OutEdges;
}
public IEnumerable<EdgeType> EdgesInto(NodeType target)
{
return ((HasInEdges<EdgeType>) target).InEdges;
}
public IEnumerable<EdgeType> Edges
{
get
{
List<EdgeType> list = new List<EdgeType>();
foreach (NodeType node in Nodes)
{
foreach (EdgeType edge in node.OutEdges)
{
list.Add(edge);
}
}
return list;
}
}
public EdgeType GetEdge(NodeType source, NodeType target)
{
EdgeType result;
if (TryGetEdge(source, target, out result))
{
return result;
}
else
{
throw new EdgeNotFoundException(source, target);
}
}
public virtual bool TryGetEdge(NodeType source, NodeType target, out EdgeType edge)
{
bool found = false;
edge = default(EdgeType);
foreach (EdgeType anEdge in source.OutEdges)
{
if (anEdge.Target.Equals(target))
{
if (found) throw new AmbiguousEdgeException(source, target);
found = true;
edge = anEdge;
}
}
return found;
}
/// <summary>
/// Get an edge handle.
/// </summary>
/// <param name="source">A node handle.</param>
/// <param name="target">A node handle.</param>
/// <returns>An edge handle if an edge exists.</returns>
/// <exception cref="EdgeNotFoundException">If there is no edge from source to target.</exception>
public EdgeType GetAnyEdge(NodeType source, NodeType target)
{
EdgeType result;
if (AnyEdge(source, target, out result))
{
return result;
}
else
{
throw new EdgeNotFoundException(source, target);
}
}
public virtual bool AnyEdge(NodeType source, NodeType target, out EdgeType edge)
{
foreach (EdgeType anEdge in source.OutEdges)
{
if (anEdge.Target.Equals(target))
{
edge = anEdge;
return true;
}
}
edge = default(EdgeType);
return false;
}
public IEnumerable<EdgeType> EdgesOf(NodeType node)
{
List<EdgeType> edges = new List<EdgeType>();
edges.AddRange(node.OutEdges);
edges.AddRange(((HasInEdges<EdgeType>) node).InEdges);
return edges;
}
public virtual int EdgeCount(NodeType source, NodeType target)
{
int count = 0;
foreach (EdgeType edge in source.OutEdges)
{
if (edge.Target.Equals(target))
{
count++;
}
}
return count;
}
public IEnumerable<EdgeType> EdgesLinking(NodeType source, NodeType target)
{
List<EdgeType> list = new List<EdgeType>();
foreach (EdgeType edge in source.OutEdges)
{
if (edge.Target.Equals(target))
{
list.Add(edge);
}
}
return list;
}
public virtual bool RemoveEdge(EdgeType edge)
{
bool removed = edge.Source.OutEdges.Remove(edge);
HasInEdges<EdgeType> target = edge.Target as HasInEdges<EdgeType>;
if (target != null) removed = removed && target.InEdges.Remove(edge);
return removed;
}
public override bool RemoveEdge(NodeType source, NodeType target)
{
EdgeType edge;
if (AnyEdge(source, target, out edge))
{
return RemoveEdge(edge);
}
else
{
return false;
}
}
public override void ClearEdgesOutOf(NodeType source)
{
foreach (EdgeType edge in source.OutEdges)
{
HasInEdges<EdgeType> intoNode = edge.Target as HasInEdges<EdgeType>;
if (intoNode != null) intoNode.InEdges.Remove(edge);
}
source.OutEdges.Clear();
}
public override void ClearEdgesInto(NodeType target)
{
HasInEdges<EdgeType> inNode = target as HasInEdges<EdgeType>;
if (inNode != null)
{
// doubly-linked case
foreach (EdgeType edge in inNode.InEdges)
{
edge.Source.OutEdges.Remove(edge);
}
inNode.InEdges.Clear();
}
else
{
// singly-linked case
// must search all nodes in the graph
foreach (NodeType source in nodes)
{
EdgeType edge;
while (AnyEdge(source, target, out edge))
{
source.OutEdges.Remove(edge);
}
}
}
}
public override void ClearEdges()
{
foreach (NodeType node in nodes)
{
node.OutEdges.Clear();
HasInEdges<EdgeType> inNode = node as HasInEdges<EdgeType>;
if (inNode != null) inNode.InEdges.Clear();
}
}
public override int EdgeCount()
{
int count = 0;
foreach (NodeType node in nodes)
{
count += node.OutEdges.Count;
}
return count;
}
public override int TargetCount(NodeType source)
{
return source.OutEdges.Count;
}
public override int SourceCount(NodeType target)
{
HasInEdges<EdgeType> inNode = target as HasInEdges<EdgeType>;
if (inNode != null)
{
return inNode.InEdges.Count;
}
else
{
// singly-linked case
// must search all nodes in the graph
int count = 0;
foreach (NodeType source in nodes)
{
if (source.Targets.Contains(target))
{
count++;
}
}
return count;
}
}
/// <summary>
/// Copy edge data to another source and target.
/// </summary>
/// <param name="edge"></param>
/// <param name="source"></param>
/// <param name="target"></param>
/// <returns>A new edge with the same data as edge but between source and target.</returns>
public EdgeType CopyEdge(EdgeType edge, NodeType source, NodeType target)
{
if (edge is ICloneable && edge is IMutableEdge<NodeType>)
{
EdgeType newEdge = (EdgeType) ((ICloneable) edge).Clone();
IMutableEdge<NodeType> mutEdge = (IMutableEdge<NodeType>) newEdge;
mutEdge.Source = source;
mutEdge.Target = target;
AddEdge(newEdge);
return newEdge;
}
else
{
return AddEdge(source, target);
}
}
/// <summary>
/// Copy edges from one node to another.
/// </summary>
/// <param name="node"></param>
/// <param name="node2"></param>
/// <remarks>
/// For every edge (node,x) or (x,node), an edge (node2,x) or (x,node2) is created, with the
/// same label.
/// The existing edges of <paramref name="node2"/> are left unchanged.
/// </remarks>
public void CopyEdges(NodeType node, NodeType node2)
{
foreach (EdgeType edge in node.OutEdges)
{
CopyEdge(edge, node2, edge.Target);
}
HasInEdges<EdgeType> inNode = node as HasInEdges<EdgeType>;
if (inNode != null)
{
foreach (EdgeType edge in inNode.InEdges)
{
CopyEdge(edge, edge.Source, node2);
}
}
}
/// <summary>
/// Copy constructor.
/// </summary>
/// <param name="g"></param>
/// <remarks>Clones all nodes in <paramref name="g"/>, preserving edges between them
/// and to nodes outside the graph.</remarks>
public Graph(Graph<NodeType, EdgeType> g)
{
EdgeFactory = g.EdgeFactory;
// first clone the nodes
Dictionary<NodeType, NodeType> newNodes = new Dictionary<NodeType, NodeType>();
foreach (object label in g.Nodes.Labels)
{
foreach (NodeType node in g.Nodes.WithLabel(label))
{
// clone the node contents, but not its neighbors
NodeType newNode = (NodeType) Invoker.Clone(node);
newNode.OutEdges.Clear();
if (newNode is HasInEdges<EdgeType>)
((HasInEdges<EdgeType>) newNode).InEdges.Clear();
Nodes.WithLabel(label).Add(newNode);
newNodes[node] = newNode;
}
}
// now clone the edges
foreach (NodeType node in g.Nodes)
{
NodeType newNode = newNodes[node];
foreach (EdgeType edge in node.OutEdges)
{
NodeType target;
if (!newNodes.TryGetValue(edge.Target, out target))
{
// edge to a node outside the graph
target = edge.Target;
}
CopyEdge(edge, newNode, target);
}
HasInEdges<EdgeType> inNode = node as HasInEdges<EdgeType>;
if (inNode != null)
{
foreach (EdgeType edge in inNode.InEdges)
{
if (newNodes.ContainsKey(edge.Source)) continue;
// edge from a node outside the graph
CopyEdge(edge, edge.Source, newNode);
}
}
}
}
public virtual object Clone()
{
return new Graph<NodeType, EdgeType>(this);
}
/// <summary>
/// Check that parent and child edges match.
/// </summary>
[System.Diagnostics.ConditionalAttribute("DEBUG")]
public void CheckValid()
{
foreach (NodeType node in Nodes)
{
// check out edges of node
foreach (EdgeType edge in node.OutEdges)
{
HasInEdges<EdgeType> inTarget = edge.Target as HasInEdges<EdgeType>;
if (inTarget == null) continue;
if (!inTarget.InEdges.Contains(edge))
{
throw new InferCompilerException(node + " -> " + edge.Target + " has no backward pointer");
}
}
// check in edges of node
HasInEdges<EdgeType> inNode = node as HasInEdges<EdgeType>;
if (inNode == null) continue;
foreach (EdgeType edge in inNode.InEdges)
{
if (!edge.Source.OutEdges.Contains(edge))
{
throw new InferCompilerException(node + " <- " + edge.Source + " has no backward pointer");
}
}
}
}
public override string ToString()
{
StringBuilder s = new StringBuilder();
foreach (object label in nodes.Labels)
{
s.Append(label).AppendLine(":");
foreach (NodeType source in nodes.WithLabel(label))
{
foreach (EdgeType edge in source.OutEdges)
{
s.Append(" ").AppendLine(edge.ToString());
}
#if false
HasInEdges<EdgeType> target = source as HasInEdges<EdgeType>;
if (target != null) {
foreach (EdgeType edge in target.InEdges) {
s.Append(" ").AppendLine(edge.ToString());
}
}
#endif
}
}
return s.ToString();
}
public IndexedProperty<EdgeType, T> CreateEdgeData<T>()
{
return new IndexedProperty<EdgeType, T>(new Dictionary<EdgeType, T>(), default(T));
}
public IndexedProperty<EdgeType, T> CreateEdgeData<T>(T defaultValue)
{
return new IndexedProperty<EdgeType, T>(new Dictionary<EdgeType, T>(), defaultValue);
}
/// <summary>
/// Provides node and edge data for an existing graph.
/// </summary>
/// <typeparam name="NodeDataType"></typeparam>
/// <typeparam name="EdgeDataType"></typeparam>
public new class Data<NodeDataType, EdgeDataType>
{
public IDictionary<NodeType, NodeDataType> Nodes;
public IDictionary<EdgeType, EdgeDataType> Edges;
public NodeDataType this[NodeType node]
{
get { return Nodes[node]; }
set { Nodes[node] = value; }
}
public EdgeDataType this[EdgeType edge]
{
get { return Edges[edge]; }
set { Edges[edge] = value; }
}
#if false
public EdgeDataType this[NodeType source, NodeType target, EdgeType label]
{
get
{
return Edges[new EdgeIndexType(source, target, label)];
}
set
{
Edges[new EdgeIndexType(source, target, label)] = value;
}
}
#endif
public Data()
{
Nodes = new Dictionary<NodeType, NodeDataType>();
Edges = new Dictionary<EdgeType, EdgeDataType>();
}
}
}
}

Просмотреть файл

@ -0,0 +1,135 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
namespace Microsoft.ML.Probabilistic.Compiler.Graphs
{
/// <summary>
/// Labels for depth-first search.
/// </summary>
internal enum VisitState
{
Unvisited,
Discovered,
Visiting,
Finished
};
/// <summary>
/// Performs depth-first search or breadth-first search on a graph.
/// </summary>
/// <typeparam name="NodeType"></typeparam>
/// <typeparam name="EdgeType"></typeparam>
internal abstract class GraphSearcher<NodeType, EdgeType>
{
public event Action<NodeType> DiscoverNode , FinishNode;
public event Action<EdgeType> DiscoverEdge , TreeEdge , BackEdge , CrossEdge , FinishTreeEdge;
public IndexedProperty<NodeType, VisitState> IsVisited;
protected bool stopped;
public void Stop()
{
stopped = true;
}
protected void CreateNodeData(IGraph<NodeType> graph)
{
if (graph is CanCreateNodeData<NodeType>)
{
IsVisited = ((CanCreateNodeData<NodeType>) graph).CreateNodeData<VisitState>(VisitState.Unvisited);
}
else
{
IsVisited = new IndexedProperty<NodeType, VisitState>(new Dictionary<NodeType, VisitState>(), VisitState.Unvisited);
}
}
public virtual void Clear()
{
IsVisited.Clear();
}
public void ClearActions()
{
DiscoverNode = null;
FinishNode = null;
DiscoverEdge = null;
TreeEdge = null;
BackEdge = null;
CrossEdge = null;
FinishTreeEdge = null;
}
public abstract void SearchFrom(NodeType start);
public abstract void SearchFrom(IEnumerable<NodeType> startNodes);
protected void OnDiscoverNode(NodeType node)
{
if (DiscoverNode != null) DiscoverNode(node);
}
protected void OnFinishNode(NodeType node)
{
if (FinishNode != null) FinishNode(node);
}
protected void OnTreeEdge(EdgeType edge)
{
if (TreeEdge != null) TreeEdge(edge);
}
protected void OnBackEdge(EdgeType edge)
{
if (BackEdge != null) BackEdge(edge);
}
protected void OnCrossEdge(EdgeType edge)
{
if (CrossEdge != null) CrossEdge(edge);
}
protected void OnDiscoverEdge(EdgeType edge)
{
if (DiscoverEdge != null) DiscoverEdge(edge);
}
protected void OnFinishTreeEdge(EdgeType edge)
{
if (FinishTreeEdge != null) FinishTreeEdge(edge);
}
}
internal abstract class GraphSearcher<NodeType> : GraphSearcher<NodeType, Edge<NodeType>>
{
protected Converter<NodeType, IEnumerable<NodeType>> Successors;
protected GraphSearcher(Converter<NodeType, IEnumerable<NodeType>> successors,
IndexedProperty<NodeType, VisitState> isVisited)
{
this.Successors = successors;
this.IsVisited = isVisited;
}
protected GraphSearcher(Converter<NodeType, IEnumerable<NodeType>> successors,
CanCreateNodeData<NodeType> data)
{
this.Successors = successors;
this.IsVisited = data.CreateNodeData<VisitState>(VisitState.Unvisited);
}
protected GraphSearcher(IGraph<NodeType> graph)
{
Successors = graph.NeighborsOf;
CreateNodeData(graph);
}
protected GraphSearcher(IDirectedGraph<NodeType> graph)
{
Successors = graph.TargetsOf;
CreateNodeData(graph);
}
}
}

Просмотреть файл

@ -0,0 +1,479 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Diagnostics;
using Microsoft.ML.Probabilistic.Collections;
using Microsoft.ML.Probabilistic.Utilities;
using NodeIndex = System.Int32;
using EdgeIndex = System.Int32;
using Microsoft.ML.Probabilistic.Compiler;
namespace Microsoft.ML.Probabilistic.Compiler.Graphs
{
/// <summary>
/// Represents a graph in which each node may belong to a group, and each group may belong to another group.
/// Group membership must be acyclic.
/// Nodes and groups are identified by integers.
/// An integer n identifies a group if two conditions hold:
/// 1. n is at least firstGroup and less than lastGroup.
/// 2. n is at least g.Nodes.Count, or g.NeighborCount(n) == 0.
/// </summary>
internal class GroupGraph : //IMutableDirectedGraph<int, int>, IMultigraph<int, int>,
CanCreateNodeData<int>, CanCreateEdgeData<int>
{
internal readonly IndexedGraph g;
// maps a node or group index into a group index (or -1 if no group)
internal readonly IList<NodeIndex> groupOf;
/// <summary>
/// firstGroup is at most g.Nodes.Count.
/// </summary>
internal readonly int firstGroup, lastGroup;
/// <summary>
/// The set of edges whose target is in the group and source is not. Indexed by (group - firstGroup). A null indicates a non-group.
/// </summary>
internal readonly IList<ICollection<EdgeIndex>> edgesIntoGroup, edgesOutOfGroup;
private DepthFirstSearch<NodeIndex> dfsScheduleWithGroups;
private List<NodeIndex> groupSchedule;
/// <summary>
/// Caller must subsequently call BuildGroupEdges
/// </summary>
/// <param name="g"></param>
/// <param name="groupOf"></param>
/// <param name="firstGroup"></param>
internal GroupGraph(IndexedGraph g, IList<NodeIndex> groupOf, int firstGroup)
{
this.g = g;
this.groupOf = groupOf;
if (firstGroup > g.Nodes.Count) throw new ArgumentException("firstGroup > g.Nodes.Count");
this.firstGroup = firstGroup;
this.lastGroup = groupOf.Count;
this.edgesIntoGroup = new List<ICollection<EdgeIndex>>();
this.edgesOutOfGroup = new List<ICollection<EdgeIndex>>();
}
public IndexedProperty<int, T> CreateEdgeData<T>(T defaultValue)
{
throw new NotImplementedException();
}
public IndexedProperty<EdgeIndex, T> CreateNodeData<T>(T defaultValue)
{
return MakeIndexedProperty.FromArray(new T[lastGroup], defaultValue);
}
public bool IsGroup(NodeIndex node)
{
return (node >= firstGroup) && (node >= g.Nodes.Count || g.NeighborCount(node) == 0);
}
public IEnumerable<NodeIndex> EdgesInto(NodeIndex node)
{
if (IsGroup(node))
{
// node is actually a group
int groupIndex = node - firstGroup;
var edges = edgesIntoGroup[groupIndex];
if (edges != null)
return edges;
else
return new NodeIndex[0];
}
else
{
return g.EdgesInto(node);
}
}
public IEnumerable<NodeIndex> EdgesOutOf(NodeIndex node)
{
if (IsGroup(node))
{
// node is actually a group
int groupIndex = node - firstGroup;
var edges = edgesOutOfGroup[groupIndex];
if (edges != null)
return edges;
else
return new NodeIndex[0];
}
else
{
return g.EdgesOutOf(node);
}
}
public NodeIndex TargetOf(EdgeIndex edge)
{
Set<NodeIndex> groups = GetGroupSet(g.SourceOf(edge));
return GetLargestGroupExcluding(g.TargetOf(edge), groups);
}
// as we traverse the graph, we will always stay at the highest level of abstraction (the largest group) that we can.
public IEnumerable<NodeIndex> SourcesOf(NodeIndex node)
{
Set<NodeIndex> groups = GetGroupSet(node);
foreach (EdgeIndex edge in EdgesInto(node))
{
NodeIndex source = g.SourceOf(edge);
if (source != node)
yield return GetLargestGroupExcluding(source, groups);
}
}
// as we traverse the graph, we will always stay at the highest level of abstraction (the largest group) that we can.
public IEnumerable<NodeIndex> TargetsOf(NodeIndex node)
{
Set<NodeIndex> groups = GetGroupSet(node);
foreach (EdgeIndex edge in EdgesOutOf(node))
{
NodeIndex target = g.TargetOf(edge);
if (target != node)
yield return GetLargestGroupExcluding(target, groups);
}
}
public void BuildGroupEdges()
{
int numGroups = lastGroup - firstGroup;
while (edgesIntoGroup.Count < numGroups)
{
edgesIntoGroup.Add(new Set<EdgeIndex>());
edgesOutOfGroup.Add(new Set<EdgeIndex>());
}
foreach (NodeIndex target in g.Nodes)
{
Set<NodeIndex> targetGroups = GetGroupSet(target);
foreach (EdgeIndex edge in g.EdgesInto(target))
{
NodeIndex source = g.SourceOf(edge);
Set<NodeIndex> sourceGroups = GetGroupSet(source);
foreach (NodeIndex sourceGroup in sourceGroups)
{
if (!targetGroups.Contains(sourceGroup))
edgesOutOfGroup[sourceGroup - firstGroup].Add(edge);
}
foreach (NodeIndex targetGroup in targetGroups)
{
if (!sourceGroups.Contains(targetGroup))
edgesIntoGroup[targetGroup - firstGroup].Add(edge);
}
}
}
}
private IEnumerable<EdgeIndex> GetAllEdges(NodeIndex source, NodeIndex target)
{
if (IsGroup(source))
{
foreach (var edge in edgesOutOfGroup[source - firstGroup])
{
NodeIndex target2 = g.TargetOf(edge);
if (target2 == target ||
GetGroups(target2).Any(group => group == target))
yield return edge;
}
}
else if (IsGroup(target))
{
foreach (var edge in edgesIntoGroup[target - firstGroup])
{
NodeIndex source2 = g.SourceOf(edge);
if (source2 == source ||
GetGroups(source2).Any(group => group == source))
yield return edge;
}
}
else yield return g.GetAnyEdge(source, target);
}
private EdgeIndex GetAnyEdge(NodeIndex source, NodeIndex target)
{
foreach(var edge in GetAllEdges(source, target))
{
return edge;
}
throw new EdgeNotFoundException(source, target);
}
private void CheckGroupEdges()
{
for (int node = 0; node < lastGroup; node++)
{
if (groupOf[node] != -1 && !IsGroup(groupOf[node])) throw new Exception("!IsGroup(groupOf[node])");
}
int numGroups = lastGroup - firstGroup;
for (int groupIndex = 0; groupIndex < numGroups; groupIndex++)
{
NodeIndex group = firstGroup + groupIndex;
if (edgesIntoGroup[groupIndex] == null) continue;
foreach(EdgeIndex edge in edgesIntoGroup[groupIndex])
{
NodeIndex source = g.SourceOf(edge);
Set<NodeIndex> sourceGroups = GetGroupSet(source);
NodeIndex target = g.TargetOf(edge);
Set<NodeIndex> targetGroups = GetGroupSet(target);
if (!targetGroups.Contains(group)) throw new Exception("!targetGroups.Contains(group)");
if (sourceGroups.Contains(group)) throw new Exception("sourceGroups.Contains(group)");
}
}
}
/// <summary>
/// Merge two groups that have the same parent group.
/// </summary>
/// <param name="group">Group that will receive all nodes in group2.</param>
/// <param name="group2">Group that will be empty on exit.</param>
public void MergeGroups(NodeIndex group, NodeIndex group2)
{
if (!IsGroup(group)) throw new ArgumentException($"!IsGroup(group)");
if (!IsGroup(group2)) throw new ArgumentException($"!IsGroup(group2)");
if (groupOf[group] != groupOf[group2])
throw new ArgumentException("groups do not have the same parent group");
if (group == group2)
return;
for (int node = 0; node < groupOf.Count; node++)
{
if (groupOf[node] == group2)
{
groupOf[node] = group;
}
}
NodeIndex groupIndex = group - firstGroup;
NodeIndex group2Index = group2 - firstGroup;
edgesIntoGroup[groupIndex].AddRange(edgesIntoGroup[group2Index]);
edgesIntoGroup[group2Index] = null;
edgesOutOfGroup[groupIndex].AddRange(edgesOutOfGroup[group2Index]);
edgesOutOfGroup[group2Index] = null;
// remove edges between the two groups
var edgesToRemove = edgesIntoGroup[groupIndex].Where(edgesOutOfGroup[groupIndex].Contains).ToList();
foreach (var edge in edgesToRemove)
{
edgesIntoGroup[groupIndex].Remove(edge);
edgesOutOfGroup[groupIndex].Remove(edge);
}
//CheckGroupEdges();
}
/// <summary>
/// result does not include node
/// </summary>
/// <param name="node">Node in dg</param>
/// <returns></returns>
public Set<NodeIndex> GetGroupSet(NodeIndex node)
{
return Set<NodeIndex>.FromEnumerable(GetGroups(node));
}
/// <summary>
/// Returns true if node is in group.
/// </summary>
/// <param name="node">Node in g</param>
/// <param name="group">Group in g</param>
/// <returns></returns>
public bool InGroup(NodeIndex node, NodeIndex group)
{
return GetGroups(node).Any(g => g == group);
}
public IEnumerable<NodeIndex> GetGroups(NodeIndex node)
{
NodeIndex group = node;
while (true)
{
group = groupOf[group];
if (group == -1)
break;
yield return group;
}
}
/// <summary>
/// Get the largest group of node (including node itself) that is not in the set.
/// </summary>
/// <param name="node">Node in g</param>
/// <param name="set">Groups in g</param>
/// <param name="mustBeInGroup">If true and set is not empty, result must be contained in some group in the set</param>
/// <returns></returns>
public NodeIndex GetLargestGroupExcluding(NodeIndex node, Set<NodeIndex> set, bool mustBeInGroup = false)
{
if (mustBeInGroup && set.Count == 0)
return node;
else
return GetLargestGroupExcluding(node, set.Contains, mustBeInGroup);
}
/// <summary>
/// Get the largest group of node (including node itself) that belongs to group.
/// </summary>
/// <param name="node">Node in g</param>
/// <param name="group">Group in g, or -1 (to get the largest group of node)</param>
/// <returns></returns>
public NodeIndex GetLargestGroupInsideGroup(NodeIndex node, NodeIndex group)
{
return GetLargestGroupExcluding(node, group.Equals, mustBeInGroup: (group != -1));
}
/// <summary>
/// Get the largest group of node (including node itself) that does not satisfy a predicate.
/// </summary>
/// <param name="node">Node in g</param>
/// <param name="predicate">Accepts groups in g</param>
/// <param name="mustBeInGroup">If true, result be in a group (i.e. groupOf[result] != -1)</param>
/// <returns></returns>
public NodeIndex GetLargestGroupExcluding(NodeIndex node, Predicate<NodeIndex> predicate, bool mustBeInGroup = false)
{
if (groupOf == null)
return node;
// group is in g
NodeIndex group = node;
while (true)
{
NodeIndex nextGroup = groupOf[group];
if (nextGroup == -1)
return mustBeInGroup ? -1 : group;
if (predicate(nextGroup))
return group;
group = nextGroup;
}
}
/// <summary>
/// Get the smallest group of node (including node itself) that satisfies the predicate, or -1 if none
/// </summary>
/// <param name="node">Node in g</param>
/// <param name="predicate">Accepts group in g</param>
/// <returns></returns>
public NodeIndex GetSmallestGroup(NodeIndex node, Predicate<NodeIndex> predicate)
{
NodeIndex group = node;
while (true)
{
if (predicate(group))
return group;
if (groupOf == null)
return -1;
NodeIndex nextGroup = groupOf[group];
if (nextGroup == -1)
return nextGroup;
group = nextGroup;
}
}
public List<NodeIndex> GetScheduleWithGroups(Converter<NodeIndex, IEnumerable<NodeIndex>> predecessors)
{
// toposort the forward edges to get a schedule
// algorithm: create a schedule for each group and then stitch them together
List<NodeIndex> schedule = new List<NodeIndex>();
if (dfsScheduleWithGroups == null)
{
// used by SearchFrom
dfsScheduleWithGroups = new DepthFirstSearch<NodeIndex>(predecessors, this);
dfsScheduleWithGroups.BackEdge += delegate (Edge<NodeIndex> edge)
{
List<NodeIndex> cycle = new List<NodeIndex>();
cycle.Add(edge.Target);
bool found = false;
dfsScheduleWithGroups.ForEachStackNode(delegate (NodeIndex node)
{
if (node == edge.Target)
found = true;
if (!found)
cycle.Add(node);
});
//cycle.Reverse();
NodeIndex source = cycle[cycle.Count - 1];
Debug.Write("cycle: ");
Debug.Write(IsGroup(source) ? $"[{source}] " : $"{source}");
foreach (var target in cycle)
{
foreach (EdgeIndex edge2 in GetAllEdges(source, target))
{
if (IsGroup(source))
{
Debug.Write($"{g.SourceOf(edge2)}");
}
Debug.Write($"->{g.TargetOf(edge2)} ");
}
if(IsGroup(target))
{
Debug.Write($"[{target}] ");
}
source = target;
}
Debug.WriteLine("");
throw new InferCompilerException("Cycle of forward edges");
};
dfsScheduleWithGroups.FinishNode += node => groupSchedule.Add(node);
}
else
{
dfsScheduleWithGroups.Clear();
}
// the top-level schedule. will only contain nodes/groups that are not in groups.
List<NodeIndex> topSchedule = new List<NodeIndex>();
Dictionary<NodeIndex, List<NodeIndex>> scheduleOfGroup = new Dictionary<EdgeIndex, List<NodeIndex>>();
scheduleOfGroup[-1] = topSchedule;
schedule.Clear();
// build a schedule by visiting each node and placing all predecessors on the schedule.
// predecessors are added by DFS FinishNode.
foreach (NodeIndex node in g.Nodes)
{
if(!IsGroup(node))
SearchFrom(node, scheduleOfGroup);
}
// The top-level schedule may contain references to groups, whose schedule is contained in scheduleOfGroup.
// Insert the group schedules into the top-level schedule to get one combined schedule.
ForEachLeafNode(topSchedule, scheduleOfGroup, schedule.Add);
return schedule;
}
/// <summary>
/// Invoke DFS on a graph node or group.
/// If the node/group is in a group, then search from that group first (to ensure all predecessors of the group are scheduled),
/// then search from the node, placing the node's results (which must all be within the group) on that group's schedule.
/// </summary>
/// <param name="node">Node in g</param>
/// <param name="scheduleOfGroup">Holds schedule of each group (modified on exit)</param>
private void SearchFrom(NodeIndex node, Dictionary<NodeIndex, List<NodeIndex>> scheduleOfGroup)
{
// first search from all groups of this node
NodeIndex group = groupOf[node];
if (group != -1)
SearchFrom(group, scheduleOfGroup);
if (!scheduleOfGroup.TryGetValue(group, out groupSchedule))
{
groupSchedule = new List<NodeIndex>();
scheduleOfGroup[group] = groupSchedule;
}
dfsScheduleWithGroups.SearchFrom(node);
}
/// <summary>
/// Invoke action on all nodes in order, except for groups which are expanded recursively using scheduleOfGroup
/// </summary>
/// <param name="nodes">Nodes in dg</param>
/// <param name="scheduleOfGroup">Maps groups in dg to nodes in dg</param>
/// <param name="action">Accepts a node in dg</param>
private void ForEachLeafNode(List<NodeIndex> nodes, Dictionary<NodeIndex, List<NodeIndex>> scheduleOfGroup, Action<NodeIndex> action)
{
foreach (NodeIndex node in nodes)
{
if(IsGroup(node))
{
List<NodeIndex> groupSchedule = scheduleOfGroup[node];
ForEachLeafNode(groupSchedule, scheduleOfGroup, action);
}
else
action(node);
}
}
}
}

Просмотреть файл

@ -0,0 +1,270 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
namespace Microsoft.ML.Probabilistic.Compiler.Graphs
{
/// <summary>
/// A graph of nodes and edges.
/// </summary>
/// <typeparam name="NodeType">The type of a node handle.</typeparam>
/// <remarks><p>
/// This interface is intended for use by generic algorithms which can operate on any graph type.
/// Nodes are treated as opaque handles, generated by the graph, which only the graph knows how to resolve.
/// Some graph types might represent nodes as objects with edge data; other graphs might use integer indices or even strings.
/// </p><p>
/// This interface can be used for both directed and undirected graphs, though the methods only provide
/// undirected information (i.e. all neighbors rather than sources vs. targets).
/// See <see cref="IMutableGraph&lt;NodeType&gt;"/> for methods to modify a graph, and <see cref="IDirectedGraph&lt;NodeType&gt;"/>
/// to get directed edge information.
/// In this interface, edges are represented implicitly as (source, target) pairs.
/// See <see cref="IGraph&lt;NodeType,EdgeType&gt;"/> for methods using explicit edge handles.
/// </p></remarks>
internal interface IGraph<NodeType>
{
/// <summary>
/// Collection of node handles.
/// </summary>
/// <remarks>
/// The methods Add(NodeType) and Contains(NodeType) which are defined by <see cref="ICollection&lt;NodeType&gt;"/>
/// are not necessarily supported (in order to use them, you would have to have a valid node handle, which
/// implies the node is already in the graph). The methods Remove(NodeType) and Clear() are only supported
/// by mutable graphs, and they are only guaranteed to remove node handles from the Nodes collection;
/// the edges connected to those nodes may still exist in the graph.
/// </remarks>
ICollection<NodeType> Nodes { get; }
int EdgeCount();
/// <summary>
/// Number of adjacent nodes.
/// </summary>
/// <param name="node">A node handle.</param>
/// <returns>The number of nodes adjacent to <paramref name="node"/>.</returns>
/// <remarks>For a directed graph, this is the number of sources (parents) plus the number of targets (children) of <paramref name="node"/></remarks>
int NeighborCount(NodeType node);
// if NeighborsOf were an ICollection, it would subsume NeighborCount and ClearEdgesOf.
// however, these would usually be slower than direct method calls on the graph.
/// <summary>
/// Adjacent nodes.
/// </summary>
/// <param name="node">A node handle.</param>
/// <returns>The nodes adjacent to <paramref name="node"/></returns>
/// <remarks>For a directed graph, this is the sources (parents) and targets (children) of <paramref name="node"/>, in any order.</remarks>
IEnumerable<NodeType> NeighborsOf(NodeType node);
/// <summary>
/// Test for an edge.
/// </summary>
/// <param name="source">A node handle.</param>
/// <param name="target">A node handle.</param>
/// <returns>True if an edge exists from <paramref name="source"/> to <paramref name="target"/></returns>
/// <remarks>
/// For some graph types, source and target are not required to be nodes in the graph.
/// That is, ContainsEdge(source,target) can be true even if Nodes.Contains(source) is false.
/// </remarks>
bool ContainsEdge(NodeType source, NodeType target);
}
internal interface IMutableGraph<NodeType> : IGraph<NodeType>
{
NodeType AddNode();
bool RemoveNodeAndEdges(NodeType node);
void Clear();
void AddEdge(NodeType source, NodeType target);
bool RemoveEdge(NodeType source, NodeType target);
void ClearEdges();
void ClearEdgesOf(NodeType node);
}
internal interface IDirectedGraph<NodeType> : IGraph<NodeType>
{
int TargetCount(NodeType source);
int SourceCount(NodeType target);
IEnumerable<NodeType> TargetsOf(NodeType source);
IEnumerable<NodeType> SourcesOf(NodeType target);
}
internal interface IMutableDirectedGraph<NodeType> : IDirectedGraph<NodeType>, IMutableGraph<NodeType>
{
void ClearEdgesOutOf(NodeType source);
void ClearEdgesInto(NodeType target);
}
/// <summary>
/// A graph with explicit edge handles.
/// </summary>
/// <typeparam name="NodeType">The type of a node handle.</typeparam>
/// <typeparam name="EdgeType">The type of an edge handle.</typeparam>
/// <remarks><p>
/// This interface is intended for use by generic algorithms which can operate on any graph type.
/// Nodes and edges are treated as opaque handles, generated by the graph, which only the graph knows how to resolve.
/// As a consequence, the methods AddEdge(EdgeType) and ContainsEdge(EdgeType) are not necessarily supported, because
/// in order to use them, you would have to have a valid edge handle, which implies the edge already exists in the graph.
/// </p><p>
/// This interface can be used for both directed and undirected graphs, though the methods only provide
/// undirected information (i.e. all edges rather than in-edges vs. out-edges).
/// See <see cref="IMutableGraph&lt;NodeType,EdgeType&gt;"/> for methods to modify a graph, and <see cref="IDirectedGraph&lt;NodeType,EdgeType&gt;"/>
/// to get directed edge information.
/// </p></remarks>
internal interface IGraph<NodeType, EdgeType> : IGraph<NodeType>
{
IEnumerable<EdgeType> Edges { get; }
/// <summary>
/// Get an edge handle.
/// </summary>
/// <param name="source">A node handle.</param>
/// <param name="target">A node handle.</param>
/// <returns>An edge handle if the edge exists and is unique.</returns>
/// <exception cref="EdgeNotFoundException">If there is no edge from source to target.</exception>
/// <exception cref="AmbiguousEdgeException">If there is more than one edge from source to target.</exception>
EdgeType GetEdge(NodeType source, NodeType target);
/// <summary>
/// Get an edge handle.
/// </summary>
/// <param name="source">A node handle.</param>
/// <param name="target">A node handle.</param>
/// <param name="edge">An edge handle if the edge exists and is unique, otherwise <c>default(EdgeType)</c>.</param>
/// <returns>True if there is an edge from source to target.</returns>
/// <exception cref="AmbiguousEdgeException">If there is more than one edge from source to target.</exception>
/// <remarks>This method combines the functionality of ContainsEdge(source,target) and GetEdge(source,target).</remarks>
bool TryGetEdge(NodeType source, NodeType target, out EdgeType edge);
/// <summary>
/// All edges connected to a node.
/// </summary>
/// <param name="node">A node handle.</param>
/// <returns>All edges connected to <paramref name="node"/>.</returns>
IEnumerable<EdgeType> EdgesOf(NodeType node);
}
/// <summary>
/// A graph which may have parallel edges.
/// </summary>
/// <typeparam name="NodeType">The type of a node handle.</typeparam>
/// <typeparam name="EdgeType">The type of an edge handle.</typeparam>
internal interface IMultigraph<NodeType, EdgeType> : IGraph<NodeType, EdgeType>
{
/// <summary>
/// Count the edges between nodes.
/// </summary>
/// <param name="source">A node handle.</param>
/// <param name="target">A node handle.</param>
/// <returns>The number of edges from <paramref name="source"/> to <paramref name="target"/>.</returns>
int EdgeCount(NodeType source, NodeType target);
/// <summary>
/// Get edge handles.
/// </summary>
/// <param name="source">A node handle.</param>
/// <param name="target">A node handle.</param>
/// <returns>All edges from source to target.</returns>
IEnumerable<EdgeType> EdgesLinking(NodeType source, NodeType target);
/// <summary>
/// Get an edge handle.
/// </summary>
/// <param name="source">A node handle.</param>
/// <param name="target">A node handle.</param>
/// <param name="edge">An edge handle if an edge exists, otherwise <c>default(EdgeType)</c>.</param>
/// <returns>True if there is an edge from source to target.</returns>
/// <remarks>This method combines the functionality of ContainsEdge(source,target) and GetAnyEdge(source,target).</remarks>
bool AnyEdge(NodeType source, NodeType target, out EdgeType edge);
}
internal interface IMutableGraph<NodeType, EdgeType> : IGraph<NodeType, EdgeType>, IMutableGraph<NodeType>
{
new EdgeType AddEdge(NodeType source, NodeType target);
bool RemoveEdge(EdgeType edge);
}
internal interface IDirectedGraph<NodeType, EdgeType> : IGraph<NodeType, EdgeType>, IDirectedGraph<NodeType>
{
NodeType SourceOf(EdgeType edge);
NodeType TargetOf(EdgeType edge);
IEnumerable<EdgeType> EdgesOutOf(NodeType source);
IEnumerable<EdgeType> EdgesInto(NodeType target);
}
internal interface IMutableDirectedGraph<NodeType, EdgeType> : IDirectedGraph<NodeType, EdgeType>, IMutableGraph<NodeType, EdgeType>, IMutableDirectedGraph<NodeType>
{
}
internal class EdgeNotFoundException : Exception
{
public EdgeNotFoundException()
{
}
public EdgeNotFoundException(object source, object target)
: base("no edge from " + source + " to " + target)
{
}
}
internal class AmbiguousEdgeException : Exception
{
public AmbiguousEdgeException()
{
}
public AmbiguousEdgeException(object source, object target)
: base("ambiguous edge from " + source + " to " + target)
{
}
}
// note that "labeled graph" means something different in graph theory:
// http://mathworld.wolfram.com/LabeledGraph.html
internal interface ILabeledGraph<NodeType, LabelType> : IGraph<NodeType>
{
new ILabeledCollection<NodeType, LabelType> Nodes { get; }
}
internal interface ILabeledEdgeGraph<NodeType, EdgeType> : IGraph<NodeType>
{
void AddEdge(NodeType fromNode, NodeType toNode, EdgeType label);
void RemoveEdge(NodeType fromNode, NodeType toNode, EdgeType label);
void ClearEdges(EdgeType label);
}
/// <summary>
/// An interface for attaching data to node handles.
/// </summary>
/// <typeparam name="NodeType">The type of a node handle.</typeparam>
/// <remarks>
/// This interface allows general graph algorithms to attach data to graph nodes.
/// </remarks>
internal interface CanCreateNodeData<NodeType>
{
/// <summary>
/// Create a mapping from node handles to data.
/// </summary>
/// <typeparam name="T">The type of data to store.</typeparam>
/// <returns>A mapping initialized to defaultValue.</returns>
IndexedProperty<NodeType, T> CreateNodeData<T>(T defaultValue);
}
/// <summary>
/// An interface for attaching data to edge handles.
/// </summary>
/// <typeparam name="EdgeType">The type of an edge handle.</typeparam>
/// <remarks>
/// This interface allows general graph algorithms to attach data to graph edges.
/// </remarks>
internal interface CanCreateEdgeData<EdgeType>
{
/// <summary>
/// Create a mapping from edge handles to data.
/// </summary>
/// <typeparam name="T">The type of data to store.</typeparam>
/// <returns>A mapping initialized to defaultValue.</returns>
IndexedProperty<EdgeType, T> CreateEdgeData<T>(T defaultValue);
}
}

Просмотреть файл

@ -0,0 +1,346 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Globalization;
using System.Text;
using System.Linq;
namespace Microsoft.ML.Probabilistic.Compiler.Graphs
{
internal class IndexedGraph : IMutableDirectedGraph<int, int>, IMultigraph<int, int>,
CanCreateNodeData<int>, CanCreateEdgeData<int>
{
protected readonly List<List<int>> inEdges;
protected readonly List<List<int>> outEdges;
protected readonly List<Edge<int>> edges;
public bool IsReadOnly;
public bool NodeCountIsConstant;
public IndexedGraph()
{
inEdges = new List<List<int>>();
outEdges = new List<List<int>>();
edges = new List<Edge<int>>();
}
public IndexedGraph(int nodeCount)
{
inEdges = new List<List<int>>(nodeCount);
outEdges = new List<List<int>>(nodeCount);
for (int i = 0; i < nodeCount; i++)
{
AddNode();
}
edges = new List<Edge<int>>();
NodeCountIsConstant = true;
}
public int SourceOf(int edge)
{
return edges[edge].Source;
}
public int TargetOf(int edge)
{
return edges[edge].Target;
}
public IEnumerable<int> EdgesOutOf(int source)
{
return outEdges[source];
}
public IEnumerable<int> EdgesInto(int target)
{
return inEdges[target];
}
public IEnumerable<int> Edges
{
get { return new Range(0, edges.Count); }
}
public int GetEdge(int source, int target)
{
int edge;
if (TryGetEdge(source, target, out edge)) return edge;
else throw new EdgeNotFoundException(source, target);
}
public bool TryGetEdge(int source, int target, out int edge)
{
edge = 0;
bool found = false;
foreach (int outEdge in EdgesOutOf(source))
{
if (edges[outEdge].Target == target)
{
if (found) throw new AmbiguousEdgeException(source, target);
edge = outEdge;
found = true;
}
}
return found;
}
public int EdgeCount(int source, int target)
{
return EdgesOutOf(source).Count(HasTarget(target));
//return Enumerable.Count(EdgesOutOf(source), HasTarget(target));
}
public IEnumerable<int> EdgesLinking(int source, int target)
{
foreach (int edge in EdgesOutOf(source))
{
if (edges[edge].Target == target) yield return edge;
}
}
/// <summary>
/// Get an edge handle.
/// </summary>
/// <param name="source">A node handle.</param>
/// <param name="target">A node handle.</param>
/// <returns>An edge handle if an edge exists.</returns>
/// <exception cref="EdgeNotFoundException">If there is no edge from source to target.</exception>
public int GetAnyEdge(int source, int target)
{
int result;
if (AnyEdge(source, target, out result))
{
return result;
}
else
{
throw new EdgeNotFoundException(source, target);
}
}
public bool AnyEdge(int source, int target, out int edge)
{
edge = 0;
foreach (int tryEdge in EdgesOutOf(source))
{
if (edges[tryEdge].Target == target)
{
edge = tryEdge;
return true;
}
}
return false;
}
public IEnumerable<int> EdgesOf(int node)
{
List<int> result = new List<int>(EdgesInto(node));
result.AddRange(EdgesOutOf(node));
return result;
// This version does not deal properly with self-loops (returns same edge twice)
//return Enumerable.Join(EdgesInto(node), EdgesOutOf(node));
}
public ICollection<int> Nodes
{
get { return new Range(0, inEdges.Count); }
}
public int EdgeCount()
{
return edges.Count;
}
public int NeighborCount(int node)
{
return NeighborsOf(node).Count();
}
public IEnumerable<int> NeighborsOf(int node)
{
return SourcesOf(node).Concat(TargetsOf(node));
}
public bool ContainsEdge(int source, int target)
{
//return Enumerable.Exists(EdgesOutOf(source), HasTarget(target));
// same as above, but inlined
foreach (int edge in outEdges[source])
{
if (edges[edge].Target == target) return true;
}
return false;
}
public Func<int, bool> HasTarget(int target)
{
return delegate(int edge) { return (edges[edge].Target == target); };
}
public int TargetCount(int source)
{
return TargetsOf(source).Count();
}
public int SourceCount(int target)
{
return SourcesOf(target).Count();
}
public IEnumerable<int> TargetsOf(int source)
{
foreach (int edge in EdgesOutOf(source))
{
yield return edges[edge].Target;
}
}
public IEnumerable<int> SourcesOf(int target)
{
foreach (int edge in EdgesInto(target))
{
yield return edges[edge].Source;
}
}
public int AddEdge(int source, int target)
{
if (IsReadOnly) throw new NotSupportedException("Graph is read only");
int edge = edges.Count;
edges.Add(new Edge<int>(source, target));
if (outEdges[source] == null) outEdges[source] = new List<int>();
outEdges[source].Add(edge);
if (inEdges[target] == null) inEdges[target] = new List<int>();
inEdges[target].Add(edge);
return edge;
}
#if SUPPRESS_UNREACHABLE_CODE_WARNINGS
#pragma warning disable 162
#endif
public bool RemoveEdge(int edge)
{
if (IsReadOnly) throw new NotSupportedException("Graph is read only");
throw new Exception("The method or operation is not implemented.");
Edge<int> edgeStruct = edges[edge];
outEdges[edgeStruct.Source].Remove(edge);
inEdges[edgeStruct.Target].Remove(edge);
// FIXME edge remains in edges array.
}
#if SUPPRESS_UNREACHABLE_CODE_WARNINGS
#pragma warning restore 162
#endif
public int AddNode()
{
if (IsReadOnly) throw new NotSupportedException("Graph is read only");
if (NodeCountIsConstant) throw new NotSupportedException("The graph size cannot be changed.");
int node = inEdges.Count;
inEdges.Add(new List<int>());
outEdges.Add(new List<int>());
return node;
}
public bool RemoveNodeAndEdges(int node)
{
if (IsReadOnly) throw new NotSupportedException("Graph is read only");
if (NodeCountIsConstant) throw new NotSupportedException("The graph size cannot be changed.");
throw new Exception("The method or operation is not implemented.");
}
public void Clear()
{
if (IsReadOnly) throw new NotSupportedException("Graph is read only");
if (NodeCountIsConstant) throw new NotSupportedException("The graph size cannot be changed.");
throw new Exception("The method or operation is not implemented.");
}
void IMutableGraph<int>.AddEdge(int source, int target)
{
if (IsReadOnly) throw new NotSupportedException("Graph is read only");
AddEdge(source, target);
}
public bool RemoveEdge(int source, int target)
{
if (IsReadOnly) throw new NotSupportedException("Graph is read only");
throw new Exception("The method or operation is not implemented.");
}
public void ClearEdges()
{
if (IsReadOnly) throw new NotSupportedException("Graph is read only");
throw new Exception("The method or operation is not implemented.");
}
public void ClearEdgesOf(int node)
{
if (IsReadOnly) throw new NotSupportedException("Graph is read only");
throw new Exception("The method or operation is not implemented.");
}
public void ClearEdgesOutOf(int source)
{
if (IsReadOnly) throw new NotSupportedException("Graph is read only");
throw new Exception("The method or operation is not implemented.");
}
public void ClearEdgesInto(int target)
{
if (IsReadOnly) throw new NotSupportedException("Graph is read only");
throw new Exception("The method or operation is not implemented.");
}
public IndexedProperty<int, T> CreateNodeData<T>(T defaultValue = default(T))
{
if (NodeCountIsConstant)
{
T[] data = new T[inEdges.Count];
IndexedProperty<int, T> prop = MakeIndexedProperty.FromArray<T>(data, defaultValue);
prop.Clear();
return prop;
}
else
{
return new IndexedProperty<int, T>(new Dictionary<int, T>(), defaultValue);
}
}
public IndexedProperty<int, T> CreateEdgeData<T>(T defaultValue = default(T))
{
if (IsReadOnly)
{
T[] data = new T[edges.Count];
IndexedProperty<int, T> prop = MakeIndexedProperty.FromArray<T>(data, defaultValue);
prop.Clear();
return prop;
}
else
{
return new IndexedProperty<int, T>(new Dictionary<int, T>(), defaultValue);
}
}
public override string ToString()
{
StringBuilder s = new StringBuilder();
foreach (int node in Nodes)
{
s.AppendFormat("{0} -> ", node);
bool first = true;
foreach (int target in TargetsOf(node))
{
if (!first) s.Append(" ");
else first = false;
s.Append(target.ToString(CultureInfo.InvariantCulture));
}
s.AppendLine();
}
return s.ToString();
}
}
}

Просмотреть файл

@ -0,0 +1,138 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Probabilistic.Collections;
namespace Microsoft.ML.Probabilistic.Compiler.Graphs
{
/// <summary>
/// Create a graph from a collection of nodes and corresponding adjacency data.
/// </summary>
/// <typeparam name="Node">The type of a node handle.</typeparam>
/// <typeparam name="NodeInfo">The type of a node data structure.</typeparam>
/// <remarks><para>
/// This class creates a directed graph object from a collection of node handles and data objects.
/// The data object is assumed to hold the adjacency information for the handles, in the form
/// of a delegate sourcesOfNode(data) which returns a collection of node handles.
/// Each node is given an integer index.
/// Using these integers you can attach additional data to the graph via CreateNodeData which returns an array.
/// </para><para>
/// In the simplest case, NodeInfo can be the same as Node in which case this class just stores
/// a mapping from nodes to integers and vice versa.
/// (The delegate infoOfNode would simply return its argument.)
/// More generally, NodeInfo can store cached information about the node.
/// </para></remarks>
internal class IndexedGraphWrapper<Node, NodeInfo> : IDirectedGraph<int>, CanCreateNodeData<int>
{
public NodeInfo[] info;
/// <summary>Provides an index (into info[]) for each node.</summary>
public IndexedProperty<Node, int> indexOfNode;
protected Converter<NodeInfo, ICollection<Node>> sourcesOfNode;
public ICollection<int>[] targetsOfNode;
public IndexedGraphWrapper(ICollection<Node> nodes, Converter<Node, NodeInfo> infoOfNode,
Converter<NodeInfo, ICollection<Node>> sourcesOfNode)
: this(nodes, infoOfNode, sourcesOfNode,
new IndexedProperty<Node, int>(new Dictionary<Node, int>()))
{
}
public IndexedGraphWrapper(ICollection<Node> nodes,
Converter<Node, NodeInfo> infoOfNode, Converter<NodeInfo, ICollection<Node>> sourcesOfNode,
IndexedProperty<Node, int> indexOfNode)
{
info = new NodeInfo[nodes.Count];
this.indexOfNode = indexOfNode;
this.sourcesOfNode = sourcesOfNode;
//new IndexedProperty<Node,int>(new Dictionary<Node,int>());
int i = 0;
foreach (Node node in nodes)
{
info[i] = infoOfNode(node);
indexOfNode[node] = i;
i++;
}
targetsOfNode = new ICollection<int>[info.Length];
}
public int TargetCount(int source)
{
ICollection<int> targets = targetsOfNode[source];
return (targets != null) ? targets.Count : 0;
}
public int SourceCount(int target)
{
return sourcesOfNode(info[target]).Count;
}
public IEnumerable<int> TargetsOf(int source)
{
ICollection<int> targets = targetsOfNode[source];
if (targets != null)
{
foreach (int target in targets)
{
yield return target;
}
}
}
public IEnumerable<int> SourcesOf(int target)
{
foreach (Node sourceNode in sourcesOfNode(info[target]))
{
int source = indexOfNode[sourceNode];
ICollection<int> targets = targetsOfNode[source];
if (targets == null) targets = new Set<int>();
targets.Add(target);
targetsOfNode[source] = targets;
yield return source;
}
}
public ICollection<int> Nodes
{
get { return new Range(0, info.Length); }
}
public int EdgeCount()
{
throw new Exception("The method or operation is not implemented.");
}
public int NeighborCount(int node)
{
return SourceCount(node) + TargetCount(node);
}
public IEnumerable<int> NeighborsOf(int node)
{
return SourcesOf(node).Concat(TargetsOf(node));
}
public bool ContainsEdge(int source, int target)
{
foreach (Node targetNode in sourcesOfNode(info[target]))
{
if (indexOfNode[targetNode] == target) return true;
}
return false;
//return info[target].Sources.Contains(nodeOf(info[source]));
}
public IndexedProperty<int, T> CreateNodeData<T>(T defaultValue)
{
T[] data = new T[info.Length];
IndexedProperty<int, T> prop = MakeIndexedProperty.FromArray<T>(data, defaultValue);
prop.Clear();
return prop;
}
}
}

Просмотреть файл

@ -0,0 +1,102 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
namespace Microsoft.ML.Probabilistic.Compiler.Graphs
{
internal struct IndexedProperty<KeyType, ValueType>
{
/// <summary>
/// Delegate for retrieving data at an index.
/// </summary>
public Converter<KeyType, ValueType> Get;
/// <summary>
/// Delegate for setting data at an index.
/// </summary>
public Action<KeyType, ValueType> Set;
/// <summary>
/// Delegate for clearing the mapping.
/// </summary>
/// <remarks>
/// If the mapping has a default value, this sets all data to that value.
/// Otherwise the mapping is undefined at all values.
/// </remarks>
public Action Clear;
/// <summary>
/// Get or set data at an index.
/// </summary>
/// <param name="key"></param>
/// <returns></returns>
public ValueType this[KeyType key]
{
get { return Get(key); }
set { Set(key, value); }
}
public IndexedProperty(Converter<KeyType, ValueType> getter, Action<KeyType, ValueType> setter, Action clearer)
{
Get = getter;
Set = setter;
Clear = clearer;
}
public IndexedProperty(IDictionary<KeyType, ValueType> dictionary, ValueType defaultValue = default(ValueType))
{
Get = delegate(KeyType key)
{
ValueType value;
bool containsKey = dictionary.TryGetValue(key, out value);
if (!containsKey) return defaultValue;
else return value;
};
Set = delegate(KeyType key, ValueType value) { dictionary[key] = value; };
Clear = dictionary.Clear;
}
}
internal static class MakeIndexedProperty
{
public static IndexedProperty<int, T> FromArray<T>(T[] array, T defaultValue = default(T))
{
return new IndexedProperty<int, T>(
delegate(int key) { return array[key]; },
delegate(int key, T value) { array[key] = value; },
delegate()
{
if (ReferenceEquals(defaultValue, null) || defaultValue.Equals(default(T)))
Array.Clear(array, 0, array.Length);
else
{
for (int i = 0; i < array.Length; i++)
{
array[i] = defaultValue;
}
}
});
}
public static IndexedProperty<T, bool> FromSet<T>(ICollection<T> set)
{
return new IndexedProperty<T, bool>(
delegate (T key)
{ return set.Contains(key); },
delegate (T key, bool value)
{
if (value)
set.Add(key);
else
set.Remove(key);
},
delegate ()
{
set.Clear();
});
}
}
}

Просмотреть файл

@ -0,0 +1,101 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections;
using System.Collections.Generic;
namespace Microsoft.ML.Probabilistic.Compiler.Graphs
{
// The resulting ICollection is read-only.
internal class JoinCollections<T> : ICollection<T>
{
private ICollection<ICollection<T>> collections;
public JoinCollections(ICollection<ICollection<T>> collections)
{
this.collections = collections;
}
public JoinCollections(params ICollection<T>[] collections)
{
this.collections = collections;
}
#region IEnumerable<T> methods
public IEnumerator<T> GetEnumerator()
{
foreach (ICollection<T> collection in collections)
{
foreach (T value in collection)
{
yield return value;
}
}
}
IEnumerator IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
#endregion
#region ICollection<T> methods
public int Count
{
get
{
int count = 0;
foreach (ICollection<T> collection in collections)
{
count += collection.Count;
}
return count;
}
}
public bool IsReadOnly
{
get { return true; }
}
public void Add(T item)
{
throw new NotSupportedException();
}
public void Clear()
{
throw new NotSupportedException();
}
public bool Contains(T item)
{
foreach (ICollection<T> collection in collections)
{
if (collection.Contains(item)) return true;
}
return false;
}
public void CopyTo(T[] array, int index)
{
foreach (ICollection<T> collection in collections)
{
collection.CopyTo(array, index);
index += collection.Count;
}
}
public bool Remove(T item)
{
throw new NotSupportedException();
}
#endregion
}
}

Просмотреть файл

@ -0,0 +1,233 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections;
using System.Collections.Generic;
using System.Runtime.Serialization;
using System.Text;
namespace Microsoft.ML.Probabilistic.Compiler.Graphs
{
/// <summary>
/// A collection which is the union of labeled subcollections.
/// </summary>
/// <typeparam name="ItemType"></typeparam>
/// <typeparam name="LabelType"></typeparam>
/// <remarks><p>
/// This interface corresponds to an inverted index, where each label maps to a subset of items.
/// This is unlike an IDictionary, where each key maps to a single item.
/// When the labels are sparse, an inverted index is more efficient than attaching a label to each item and searching
/// through the items.
/// The subcollections may overlap.
/// </p><p>
/// Add(ItemType) is equivalent to using the default label, which may be default(LabelType) or some
/// other value such as an empty string ("").
/// The other ICollection methods, such as Count and Clear, apply to all items regardless of label.
/// </p></remarks>
internal interface ILabeledCollection<ItemType, LabelType> : ICollection<ItemType>
{
/// <summary>
/// The labels of the subcollections.
/// </summary>
/// <remarks>
/// Labels must be unique. The returned collection must not be modified.
/// </remarks>
ICollection<LabelType> Labels { get; }
/// <summary>
/// Get a subcollection.
/// </summary>
/// <param name="label">The label of an existing subcollection or a subcollection to be created.</param>
/// <returns>A subcollection of items.</returns>
/// <remarks>If the subcollection already exists, it is returned. Otherwise, a new subcollection is created and returned.
/// The result is mutable. Some LabeledCollection classes may not allow certain labels.
/// </remarks>
/// <exception cref="InvalidLabelException">If the label is not allowed by the collection.</exception>
ICollection<ItemType> WithLabel(LabelType label);
}
internal class InvalidLabelException : Exception
{
public object Label;
public InvalidLabelException(object label)
{
Label = label;
}
public InvalidLabelException()
{
}
// This constructor is needed for serialization.
protected InvalidLabelException(SerializationInfo info, StreamingContext context) : base(info, context)
{
}
}
// This is thrown if you try to Add to a list which is full.
// It is more informative than NotSupportedException.
internal class ListOverflowException : Exception
{
public ListOverflowException()
{
}
// This constructor is needed for serialization.
protected ListOverflowException(SerializationInfo info, StreamingContext context) : base(info, context)
{
}
}
/// <summary>
/// A default implementation of ILabeledCollection.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <typeparam name="LabelType"></typeparam>
/// <remarks>
/// This base class implements all of the ICollection methods in terms of the
/// two ILabeledCollection methods.
/// This makes it easy to create new ILabeledCollection classes, just by
/// implementing the two ILabeledCollection methods.
/// It assumes that the subcollections do not overlap.
/// </remarks>
internal abstract class LabeledCollection<T, LabelType> : ILabeledCollection<T, LabelType>
{
public abstract ICollection<LabelType> Labels { get; }
public abstract ICollection<T> WithLabel(LabelType label);
public virtual LabelType DefaultLabel
{
get { return default(LabelType); }
}
#if zero
//ICollection ILabeledCollection<T>.Labels {
public virtual ICollection Labels {
get { throw new NotSupportedException(); }
}
//ICollection<T> ILabeledCollection<T>.WithLabel(object label)
public virtual ICollection<T> WithLabel(object label)
{
throw new NotSupportedException();
}
#endif
#region IEnumerable methods
public virtual IEnumerator<T> GetEnumerator()
{
foreach (LabelType label in Labels)
{
foreach (T value in WithLabel(label))
{
yield return value;
}
}
}
IEnumerator IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
#endregion
#region ICollection methods
public virtual int Count
{
get
{
int count = 0;
foreach (LabelType label in Labels)
{
count += WithLabel(label).Count;
}
return count;
}
}
public virtual bool IsReadOnly
{
get { return false; }
}
public virtual void Add(T item)
{
WithLabel(DefaultLabel).Add(item);
}
public virtual void Clear()
{
foreach (LabelType label in Labels)
{
WithLabel(label).Clear();
}
}
public virtual bool Contains(T item)
{
foreach (LabelType label in Labels)
{
if (WithLabel(label).Contains(item))
{
return true;
}
}
return false;
}
public virtual void CopyTo(T[] array, int index)
{
foreach (LabelType label in Labels)
{
ICollection<T> list = WithLabel(label);
list.CopyTo(array, index);
index += list.Count;
}
}
public virtual bool Remove(T item)
{
foreach (LabelType label in Labels)
{
ICollection<T> list = WithLabel(label);
if (list.Contains(item))
{
// remove only one instance
return list.Remove(item);
}
}
return false;
}
#endregion
public override string ToString()
{
StringBuilder s = new StringBuilder();
foreach (LabelType label in Labels)
{
if (!label.Equals(DefaultLabel))
{
s.Append(String.Format("<{0}>", label));
}
int count = 0;
foreach (T value in WithLabel(label))
{
if (count > 0) s.Append(" ");
count++;
s.Append(value);
}
if (!label.Equals(DefaultLabel))
{
s.Append(String.Format("</{0}>", label));
}
}
return s.ToString();
}
}
}

Просмотреть файл

@ -0,0 +1,50 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using Microsoft.ML.Probabilistic.Compiler.Graphs;
using System.Collections.Generic;
namespace Microsoft.ML.Probabilistic.Compiler.Graphs
{
internal delegate ICollection<T> CollectionWrapperFactory<T>(ICollection<T> list);
/// <summary>
/// A base class for LabeledCollection wrapper classes.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <typeparam name="LabelType"></typeparam>
/// <typeparam name="ListType"></typeparam>
/// <remarks>
/// This class makes it easy to write decorators for LabeledCollections.
/// </remarks>
internal class LabeledCollectionWrapper<T, LabelType, ListType> : CollectionWrapper<T, ListType>, ILabeledCollection<T, LabelType>
where ListType : ILabeledCollection<T, LabelType>
{
protected CollectionWrapperFactory<T> factory;
protected LabeledCollectionWrapper()
{
}
public LabeledCollectionWrapper(ListType list, CollectionWrapperFactory<T> factory)
: base(list)
{
this.factory = factory;
}
#region ILabeledCollection methods
public ICollection<LabelType> Labels
{
get { return list.Labels; }
}
public ICollection<T> WithLabel(LabelType label)
{
return factory(list.WithLabel(label));
}
#endregion
}
}

Просмотреть файл

@ -0,0 +1,81 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Collections.Generic;
namespace Microsoft.ML.Probabilistic.Compiler.Graphs
{
/// <summary>
/// A list which is a union of labeled sublists.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <typeparam name="LabelType"></typeparam>
/// <remarks>
/// This class can be used to represent an inverted index, where each label maps to a list of items.
/// It is implemented by a Dictionary of List objects, so the labels can have any type.
/// </remarks>
internal class LabeledList<T, LabelType> : LabeledCollection<T, LabelType>
{
private Dictionary<LabelType, ICollection<T>> dictionary;
public LabelType defaultLabel;
public override LabelType DefaultLabel
{
get { return defaultLabel; }
}
public LabeledList()
{
dictionary = new Dictionary<LabelType, ICollection<T>>();
}
public LabeledList(LabelType defaultLabel) : this()
{
this.defaultLabel = defaultLabel;
}
/// <summary>
/// Copy constructor.
/// </summary>
/// <param name="list"></param>
public LabeledList(LabeledList<T, LabelType> list) : this()
{
// copy the elements into this list
foreach (LabelType label in list.Labels)
{
foreach (T item in list.WithLabel(label))
{
WithLabel(label).Add(item);
}
}
}
#region LabeledCollection methods
public override ICollection<LabelType> Labels
{
get { return dictionary.Keys; }
}
public override ICollection<T> WithLabel(LabelType label)
{
ICollection<T> list;
if (! dictionary.TryGetValue(label, out list))
{
// first time using this label
list = new List<T>();
dictionary[label] = list;
}
return list;
}
#endregion
// this is more efficient than the default implementation
public override void Clear()
{
dictionary.Clear();
}
}
}

Просмотреть файл

@ -0,0 +1,75 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using Microsoft.ML.Probabilistic.Compiler.Graphs;
using System.Collections.Generic;
namespace Microsoft.ML.Probabilistic.Compiler.Graphs
{
/// <summary>
/// A set which is the union of labeled subsets.
/// </summary>
/// <typeparam name="ItemType"></typeparam>
/// <typeparam name="LabelType"></typeparam>
/// <remarks>
/// It is implemented as a LabeledList where Add is overridden to prevent adding duplicates.
/// </remarks>
internal class LabeledSet<ItemType, LabelType> : LabeledSetWrapper<ItemType, LabelType, LabeledList<ItemType, LabelType>>
{
public LabeledSet()
: base(new LabeledList<ItemType, LabelType>())
{
}
public LabeledSet(LabelType defaultLabel)
: base(new LabeledList<ItemType, LabelType>(defaultLabel))
{
}
}
internal class LabeledSetWrapper<NodeType, LabelType, ListType> : LabeledCollectionWrapper<NodeType, LabelType, ListType>
where ListType : ILabeledCollection<NodeType, LabelType>
{
public LabeledSetWrapper(ListType list)
: base(list, null)
{
factory = delegate(ICollection<NodeType> sublist) { return new NodeListWrapper(sublist, this); };
}
// These routines are adapted from NodeListWrapper
#region ICollection methods
// value can be a duplicate node, but it won't be added again.
public override void Add(NodeType node)
{
if (!list.Contains(node)) list.Add(node);
}
#endregion
// Wraps a node list to ensure that nodes are unique, and that
// graph edges are removed when a node is removed.
private class NodeListWrapper : CollectionWrapper<NodeType, ICollection<NodeType>>
{
private LabeledSetWrapper<NodeType, LabelType, ListType> set;
public NodeListWrapper(ICollection<NodeType> list, LabeledSetWrapper<NodeType, LabelType, ListType> set)
: base(list)
{
this.set = set;
}
#region ICollection methods
// value can be a duplicate node, but it won't be added again.
public override void Add(NodeType item)
{
if (!set.Contains(item)) list.Add(item);
}
#endregion
}
}
}

Просмотреть файл

@ -0,0 +1,295 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
// Reference: Jianxiu Hao and James B. Orlin,
// "A faster algorithm for finding the minimum cut in a directed graph",
// Journal of Algorithms 17: 424--446 (1994).
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Probabilistic.Collections;
namespace Microsoft.ML.Probabilistic.Compiler.Graphs
{
/// <summary>
/// Finds a minimum edge cut to separate a set of sources from a set of sinks
/// </summary>
/// <typeparam name="NodeType">Node type</typeparam>
/// <typeparam name="EdgeType">Edge type</typeparam>
/// <remarks><para>
/// A node can be both a source and a sink. In this case, edges into the node are considered toward the sink
/// and edges out of the node are considered from the source. Thus the algorithm will cut all paths from the
/// node back to itself.
/// </para><para>
/// By modifying IsSinkEdge, certain edges can be labelled as "sink edges". These edges are treated as if
/// their target was a sink node. Making a node both a source and sink is equivalent to making it a source and
/// labelling its inward edges as sink edges.
/// </para><para>
/// The implementation uses the preflow-push algorithm, modified to return the minimum cut,
/// as described by Jianxiu Hao and James B. Orlin,
/// "A faster algorithm for finding the minimum cut in a directed graph",
/// Journal of Algorithms 17: 424--446 (1994).
/// </para><para>
/// This algorithm can sometimes be very slow for certain choices of the capacities.
/// This seems to be caused by loss of precision in the float calculations, e.g.
/// when a edge with capacity 1e-8 pushes flow into an edge with capacity 1e+8.
/// </para></remarks>
internal class MinCut<NodeType, EdgeType>
{
protected IDirectedGraph<NodeType, EdgeType> graph;
protected Func<EdgeType, float> capacity;
/// <summary>
/// Capacity of an edge in the reverse direction (default 0)
/// </summary>
public Func<EdgeType, float> reverseCapacity;
/// <summary>
/// The set of source nodes
/// </summary>
public Set<NodeType> Sources = new Set<NodeType>();
/// <summary>
/// The set of sink nodes
/// </summary>
public Set<NodeType> Sinks = new Set<NodeType>();
/// <summary>
/// The set of sink edges
/// </summary>
public Func<EdgeType, bool> IsSinkEdge;
private IndexedProperty<NodeType, int> distanceToSink; // also called "height"
/// <summary>
/// A cache of the nodes at a given distanceToSink
/// </summary>
private Dictionary<int, Set<NodeType>> nodesAtDistance = new Dictionary<int, Set<NodeType>>();
/// <summary>
/// The flow in the direction of the edge (always between -reverseCapacity and the capacity)
/// </summary>
private IndexedProperty<EdgeType, float> flow;
/// <summary>
/// A cache of (inward flow - outward flow)
/// </summary>
private IndexedProperty<NodeType, float> excess;
/// <summary>
/// The nodes on the source side of the cut
/// </summary>
private Set<NodeType> sourceGroup = new Set<NodeType>();
/// <summary>
/// All nodes that are not in the sourceGroup, not a sink, and have inward flow > outward flow
/// </summary>
private Set<NodeType> activeNodes = new Set<NodeType>();
public MinCut(IDirectedGraph<NodeType, EdgeType> graph, Func<EdgeType, float> capacity)
{
this.graph = graph;
this.capacity = capacity;
reverseCapacity = e => 0f;
distanceToSink = ((CanCreateNodeData<NodeType>) graph).CreateNodeData<int>(1);
flow = ((CanCreateEdgeData<EdgeType>) graph).CreateEdgeData<float>(0f);
excess = ((CanCreateNodeData<NodeType>) graph).CreateNodeData<float>(0f);
IsSinkEdge = e => false;
}
private void Initialize()
{
distanceToSink.Clear();
flow.Clear();
excess.Clear();
activeNodes.Clear();
sourceGroup.Clear();
sourceGroup.AddRange(Sources);
nodesAtDistance[0] = Set<NodeType>.FromEnumerable(Sinks);
foreach (NodeType sink in Sinks) distanceToSink[sink] = 0;
Set<NodeType> nodesAtDistance1 = new Set<NodeType>();
foreach (NodeType node in graph.Nodes)
{
if (!Sources.Contains(node) && !Sinks.Contains(node)) nodesAtDistance1.Add(node);
}
nodesAtDistance[1] = nodesAtDistance1;
foreach (NodeType source in Sources)
{
if (!Sinks.Contains(source)) distanceToSink[source] = int.MaxValue;
foreach (EdgeType edge in graph.EdgesOutOf(source))
{
float f = capacity(edge);
flow[edge] = f;
if (!IsSinkEdge(edge))
{
NodeType target = graph.TargetOf(edge);
if (!Sources.Contains(target) && !Sinks.Contains(target))
{
excess[target] += f;
activeNodes.Add(target);
}
}
}
foreach (EdgeType edge in graph.EdgesInto(source))
{
if (!IsSinkEdge(edge))
{
float f = reverseCapacity(edge);
flow[edge] = -f;
NodeType target = graph.SourceOf(edge);
if (!Sources.Contains(target) && !Sinks.Contains(target))
{
excess[target] += f;
activeNodes.Add(target);
}
}
}
}
}
/// <summary>
/// Compute the min cut and return all nodes connected to any source
/// </summary>
/// <returns>The set of all nodes connected to any source after removing the min cut edges</returns>
public Set<NodeType> GetSourceGroup()
{
Initialize();
while (activeNodes.Count > 0)
{
//Console.WriteLine(StringUtil.CollectionToString(activeNodes," "));
// select an active node
NodeType node = activeNodes.First();
if (float.IsNaN(excess[node])) throw new Exception("encountered NaN");
if (excess[node] > 0f)
{
// discharge this node
// find an admissible edge
// an edge is admissible if neither endpoint is in the sourceGroup, distanceToSink(source)=distanceToSink(target)+1, and residual capacity>0
foreach (EdgeType edge in graph.EdgesOutOf(node))
{
NodeType target = graph.TargetOf(edge);
bool isSinkEdge = IsSinkEdge(edge);
int distTarget = isSinkEdge ? 0 : distanceToSink[target];
if (distanceToSink[node] != distTarget + 1) continue;
// target cannot be in sourceGroup since their distances are huge
float cap = capacity(edge);
float residual;
if (cap == flow[edge]) residual = 0f; // avoid infinity-infinity
else residual = cap - flow[edge];
if (residual <= 0f) continue;
// push flow along this edge
float push = System.Math.Min(excess[node], residual);
flow[edge] += push;
if (excess[node] == push) excess[node] = 0f; // avoid infinity-infinity
else excess[node] -= push;
if (!isSinkEdge && !Sinks.Contains(target))
{
excess[target] += push;
if (excess[target] > 0f) activeNodes.Add(target);
}
if (excess[node] <= 0f) break;
}
}
if (excess[node] > 0f)
{
// push flow back to inward edges
foreach (EdgeType edge in graph.EdgesInto(node))
{
NodeType target = graph.SourceOf(edge);
if (distanceToSink[node] != distanceToSink[target] + 1) continue;
// target cannot be in sourceGroup since their distances are huge,
// unless it is both a source and sink.
if (Sources.Contains(target)) continue;
float revCap = reverseCapacity(edge);
float residual;
if (-revCap == flow[edge]) residual = 0f;
else residual = revCap + flow[edge];
if (residual <= 0f) continue;
// push flow backward along this edge
float push = System.Math.Min(excess[node], residual);
if (flow[edge] == push) flow[edge] = 0f;
else flow[edge] -= push;
if (excess[node] == push) excess[node] = 0f;
else excess[node] -= push;
excess[target] += push;
if (!Sinks.Contains(target) && excess[target] > 0f) activeNodes.Add(target);
if (excess[node] <= 0f) break;
}
}
if (excess[node] > 0f)
{
// relabel
int dist = distanceToSink[node];
int count = nodesAtDistance[dist].Count;
if (count == 1)
{
// cut at this node
foreach (KeyValuePair<int, Set<NodeType>> entry in nodesAtDistance)
{
if (entry.Key >= dist)
{
foreach (NodeType node2 in entry.Value)
{
sourceGroup.Add(node2);
distanceToSink[node2] = int.MaxValue;
activeNodes.Remove(node2);
}
entry.Value.Clear();
}
}
}
else
{
// update the distance to sink
nodesAtDistance[dist].Remove(node);
dist = int.MaxValue;
foreach (EdgeType edge in graph.EdgesOutOf(node))
{
NodeType target = graph.TargetOf(edge);
if (sourceGroup.Contains(target)) continue;
float cap = capacity(edge);
float residual;
if (cap == flow[edge]) residual = 0f; // for infinity case
else residual = cap - flow[edge];
if (residual <= 0f) continue;
int distTarget = IsSinkEdge(edge) ? 0 : distanceToSink[target];
dist = System.Math.Min(dist, distTarget + 1);
}
foreach (EdgeType edge in graph.EdgesInto(node))
{
NodeType target = graph.SourceOf(edge);
if (sourceGroup.Contains(target)) continue;
float revCap = reverseCapacity(edge);
float residual;
if (-revCap == flow[edge]) residual = 0f;
else residual = revCap + flow[edge];
if (residual <= 0f) continue;
dist = System.Math.Min(dist, distanceToSink[target] + 1);
}
distanceToSink[node] = dist;
if (dist == int.MaxValue)
{
sourceGroup.Add(node);
// node is no longer active
}
else
{
Set<NodeType> nodes;
if (!nodesAtDistance.TryGetValue(dist, out nodes))
{
nodes = new Set<NodeType>();
nodesAtDistance[dist] = nodes;
}
nodes.Add(node);
continue; // node remains active
}
}
}
activeNodes.Remove(node);
}
return sourceGroup;
}
}
}

202
src/Compiler/Graphs/Node.cs Normal file
Просмотреть файл

@ -0,0 +1,202 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Collections.Generic;
namespace Microsoft.ML.Probabilistic.Compiler.Graphs
{
internal interface HasTargets<T>
{
ICollection<T> Targets { get; }
}
internal interface HasSources<T>
{
ICollection<T> Sources { get; }
}
internal interface HasSourcesAndTargets<T> : HasSources<T>, HasTargets<T>
{
}
/// <summary>
/// Stores a list of source nodes and target nodes. (Does not store edges.)
/// </summary>
/// <typeparam name="T">Node type to link to (usually a DirectedNode itself).</typeparam>
internal class DirectedNode<T> : HasSourcesAndTargets<T>
{
protected List<T> targets, sources;
public ICollection<T> Targets
{
get { return targets; }
}
public ICollection<T> Sources
{
get { return sources; }
}
public ICollection<T> Neighbors
{
get { return new JoinCollections<T>(sources, targets); }
}
public DirectedNode()
{
targets = new List<T>();
sources = new List<T>();
}
public DirectedNode(HasSourcesAndTargets<T> node)
{
// clone the collections, but not the referenced nodes
targets = new List<T>(node.Targets);
sources = new List<T>(node.Sources);
}
}
internal interface HasOutEdges<E>
{
ICollection<E> OutEdges { get; }
}
internal interface HasInEdges<E>
{
ICollection<E> InEdges { get; }
}
internal interface HasInAndOutEdges<E> : HasInEdges<E>, HasOutEdges<E>
{
}
/// <summary>
/// Stores a list of OutEdges and InEdges.
/// </summary>
/// <typeparam name="T">Node type used by the edges.</typeparam>
/// <typeparam name="E">Edge type to store.</typeparam>
internal class DirectedNode<T, E> : HasSourcesAndTargets<T>, HasInAndOutEdges<E>
where E : IEdge<T>
{
protected List<E> outEdges, inEdges;
public ICollection<E> OutEdges
{
get { return outEdges; }
}
public ICollection<E> InEdges
{
get { return inEdges; }
}
public ICollection<T> Targets
{
get { return outEdges.ConvertAll<T>(delegate(E edge) { return edge.Target; }).AsReadOnly(); }
}
public ICollection<T> Sources
{
get { return inEdges.ConvertAll<T>(delegate(E edge) { return edge.Source; }).AsReadOnly(); }
}
public DirectedNode()
{
outEdges = new List<E>();
inEdges = new List<E>();
}
public DirectedNode(HasInAndOutEdges<E> node)
{
// clone the collections, but not the referenced edges
outEdges = new List<E>(node.OutEdges);
inEdges = new List<E>(node.InEdges);
}
}
/// <summary>
/// Directed graph node holding data of type T
/// </summary>
/// <typeparam name="T">The data type</typeparam>
internal class BasicNode<T> : DirectedNode<BasicNode<T>>
{
public T Data;
public BasicNode(T data)
: base()
{
Data = data;
}
/// <summary>
/// Copy constructor.
/// </summary>
/// <param name="node"></param>
public BasicNode(BasicNode<T> node)
: base(node)
{
Data = node.Data;
}
public override string ToString()
{
if (object.ReferenceEquals(Data, null)) return "(Node)";
return Data.ToString();
}
}
internal class BasicNode : DirectedNode<BasicNode>
{
public object Data;
public BasicNode(object data)
: base()
{
Data = data;
}
/// <summary>
/// Copy constructor.
/// </summary>
/// <param name="node"></param>
public BasicNode(BasicNode node)
: base(node)
{
Data = node.Data;
}
public override string ToString()
{
if (Data == null) return "(Node)";
return Data.ToString();
}
}
internal class BasicEdgeNode : DirectedNode<BasicEdgeNode, Edge<BasicEdgeNode>>
{
public object Data;
public BasicEdgeNode(object data)
: base()
{
Data = data;
}
/// <summary>
/// Copy constructor.
/// </summary>
/// <param name="node"></param>
public BasicEdgeNode(BasicEdgeNode node)
: base(node)
{
Data = node.Data;
}
public override string ToString()
{
if (Data == null) return "(Node)";
return Data.ToString();
}
}
}

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -0,0 +1,346 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using Microsoft.ML.Probabilistic.Collections;
namespace Microsoft.ML.Probabilistic.Compiler.Graphs
{
/// <summary>
/// Find all elementary paths of a directed graph
/// </summary>
/// <typeparam name="NodeType">The node type</typeparam>
/// <remarks><para>
/// The paths are described by firing actions according to the pattern:
/// BeginPath, AddNode, AddNode, ..., AddNode, EndPath, BeginPath, ..., EndPath.
/// The node on a path will appear in order of the directed edges between them.
/// </para><para>
/// Only paths which cannot be made longer are returned, i.e. sub-paths of an elementary path are not returned.
/// </para></remarks>
internal class PathFinder<NodeType>
{
private IDirectedGraph<NodeType> graph;
private IndexedProperty<NodeType, bool> isBlocked;
private Stack<NodeType> stack = new Stack<NodeType>();
public event Action<NodeType> AddNode;
public event Action BeginPath , EndPath;
public PathFinder(IDirectedGraph<NodeType> graph)
{
this.graph = graph;
CanCreateNodeData<NodeType> data = (CanCreateNodeData<NodeType>) graph;
isBlocked = data.CreateNodeData<bool>(false);
}
/// <summary>
/// Find all paths starting with the given node
/// </summary>
/// <param name="node">Starting node</param>
public void SearchFrom(NodeType node)
{
stack.Push(node);
isBlocked[node] = true;
bool canExtend = false;
foreach (NodeType target in graph.TargetsOf(node))
{
if (isBlocked[target]) continue;
// recursive call
SearchFrom(target);
canExtend = true;
}
isBlocked[node] = false;
if (!canExtend && stack.Count > 1)
{
// path has multiple nodes and cannot be extended
OnBeginPath();
Stack<NodeType> temp = new Stack<NodeType>();
foreach (NodeType nodeOnStack in stack)
{
temp.Push(nodeOnStack);
}
foreach (NodeType nodeOnStack in temp)
{
OnAddNode(nodeOnStack);
}
OnEndPath();
}
stack.Pop();
}
public void OnAddNode(NodeType node)
{
if (AddNode != null) AddNode(node);
}
public void OnBeginPath()
{
if (BeginPath != null) BeginPath();
}
public void OnEndPath()
{
if (EndPath != null) EndPath();
}
}
/// <summary>
/// Find all elementary paths of a directed graph
/// </summary>
/// <typeparam name="NodeType">The node type</typeparam>
/// <typeparam name="EdgeType">The edge type</typeparam>
/// <remarks><para>
/// The paths are described by firing actions according to the pattern:
/// BeginPath, AddEdge, AddEdge, ..., AddEdge, EndPath, BeginPath, ..., EndPath.
/// The edges on a path will appear in order of their directions.
/// </para><para>
/// Only paths which cannot be made longer are returned, i.e. sub-paths of an elementary path are not returned.
/// </para></remarks>
internal class PathFinder<NodeType, EdgeType>
{
private Converter<NodeType, IEnumerable<EdgeType>> edgesOutOf;
private Converter<EdgeType, NodeType> targetOf;
private IndexedProperty<NodeType, bool> isBlocked;
private Stack<EdgeType> stack = new Stack<EdgeType>();
public event Action<EdgeType> AddEdge;
public event Action BeginPath , EndPath;
public PathFinder(Converter<NodeType, IEnumerable<EdgeType>> edgesOutOf, Converter<EdgeType, NodeType> targetOf, CanCreateNodeData<NodeType> data)
{
this.edgesOutOf = edgesOutOf;
this.targetOf = targetOf;
isBlocked = data.CreateNodeData<bool>(false);
}
public PathFinder(IDirectedGraph<NodeType, EdgeType> graph)
{
CanCreateNodeData<NodeType> data = (CanCreateNodeData<NodeType>) graph;
isBlocked = data.CreateNodeData<bool>(false);
targetOf = graph.TargetOf;
edgesOutOf = graph.EdgesOutOf;
}
/// <summary>
/// Find all paths starting with the given node
/// </summary>
/// <param name="node">Starting node</param>
public void SearchFrom(NodeType node)
{
isBlocked[node] = true;
bool foundPath = false;
foreach (EdgeType edge in edgesOutOf(node))
{
NodeType target = targetOf(edge);
if (isBlocked[target]) continue;
// recursive call
stack.Push(edge);
SearchFrom(target);
stack.Pop();
foundPath = true;
}
isBlocked[node] = false;
if (!foundPath && stack.Count > 0)
{
OnBeginPath();
Stack<EdgeType> temp = new Stack<EdgeType>();
foreach (EdgeType edgeOnStack in stack)
{
temp.Push(edgeOnStack);
}
foreach (EdgeType edgeOnStack in temp)
{
OnAddEdge(edgeOnStack);
}
OnEndPath();
}
}
public void OnAddEdge(EdgeType edge)
{
if (AddEdge != null) AddEdge(edge);
}
public void OnBeginPath()
{
if (BeginPath != null) BeginPath();
}
public void OnEndPath()
{
if (EndPath != null) EndPath();
}
}
/// <summary>
/// Find all nodes on any path from a source set to a sink set
/// </summary>
/// <typeparam name="NodeType"></typeparam>
internal class NodeOnPathFinder<NodeType>
{
private Converter<NodeType, IEnumerable<NodeType>> successors;
public IndexedProperty<NodeType, bool> isBlocked;
private IndexedProperty<NodeType, Set<NodeType>> blockedSources;
private IndexedProperty<NodeType, bool> onPath;
private Predicate<NodeType> isSink;
public NodeOnPathFinder(
Converter<NodeType, IEnumerable<NodeType>> successors,
CanCreateNodeData<NodeType> data,
IndexedProperty<NodeType, bool> onPath,
Predicate<NodeType> isSink)
{
this.successors = successors;
isBlocked = data.CreateNodeData<bool>(false);
blockedSources = data.CreateNodeData<Set<NodeType>>(null);
this.onPath = onPath;
this.isSink = isSink;
}
public void Clear()
{
isBlocked.Clear();
blockedSources.Clear();
onPath.Clear();
}
public void SearchFrom(NodeType node)
{
if (onPath[node]) return;
if (isSink(node))
{
onPath[node] = true;
}
isBlocked[node] = true;
foreach (NodeType target in successors(node))
{
if (isBlocked[target]) continue;
// recursive call
SearchFrom(target);
if (onPath[target])
{
onPath[node] = true;
}
}
if (onPath[node]) Unblock(node);
else
{
// at this point, all targets are blocked
foreach (NodeType target in successors(node))
{
Set<NodeType> blockedSourcesOfTarget = blockedSources[target];
if (blockedSourcesOfTarget == null)
{
blockedSourcesOfTarget = new Set<NodeType>();
blockedSources[target] = blockedSourcesOfTarget;
}
blockedSourcesOfTarget.Add(node);
}
}
}
private void Unblock(NodeType node)
{
isBlocked[node] = false;
Set<NodeType> blockedSourcesOfNode = blockedSources[node];
if (blockedSourcesOfNode != null)
{
blockedSources[node] = null;
foreach (NodeType source in blockedSourcesOfNode)
{
if (isBlocked[source]) Unblock(source);
}
}
}
}
/// <summary>
/// Find all edges on any path from a source set to a sink set
/// </summary>
/// <typeparam name="NodeType"></typeparam>
/// <typeparam name="EdgeType"></typeparam>
internal class EdgeOnPathFinder<NodeType, EdgeType>
{
private Converter<NodeType, IEnumerable<EdgeType>> edgesOutOf;
private Converter<EdgeType, NodeType> targetOf;
public IndexedProperty<NodeType, bool> isBlocked;
private IndexedProperty<NodeType, Set<NodeType>> blockedSources;
private IndexedProperty<EdgeType, bool> onPath;
private Predicate<NodeType> isSink;
public EdgeOnPathFinder(
Converter<NodeType, IEnumerable<EdgeType>> edgesOutOf,
Converter<EdgeType, NodeType> targetOf,
CanCreateNodeData<NodeType> data,
IndexedProperty<EdgeType, bool> onPath,
Predicate<NodeType> isSink)
{
this.edgesOutOf = edgesOutOf;
this.targetOf = targetOf;
isBlocked = data.CreateNodeData<bool>(false);
blockedSources = data.CreateNodeData<Set<NodeType>>(null);
this.onPath = onPath;
this.isSink = isSink;
}
public void Clear()
{
isBlocked.Clear();
blockedSources.Clear();
onPath.Clear();
}
public void SearchFrom(NodeType node)
{
if (isSink(node))
return;
isBlocked[node] = true;
bool foundPath = false;
foreach (var edge in this.edgesOutOf(node))
{
NodeType target = this.targetOf(edge);
if (isBlocked[target])
continue;
// recursive call
SearchFrom(target);
if (!isBlocked[target])
{
onPath[edge] = true;
foundPath = true;
}
}
if (foundPath)
Unblock(node);
else
{
// at this point, all targets are blocked
foreach (var edge in this.edgesOutOf(node))
{
NodeType target = this.targetOf(edge);
Set<NodeType> blockedSourcesOfTarget = blockedSources[target];
if (blockedSourcesOfTarget == null)
{
blockedSourcesOfTarget = new Set<NodeType>();
blockedSources[target] = blockedSourcesOfTarget;
}
blockedSourcesOfTarget.Add(node);
}
}
}
private void Unblock(NodeType node)
{
isBlocked[node] = false;
Set<NodeType> blockedSourcesOfNode = blockedSources[node];
if (blockedSourcesOfNode != null)
{
blockedSources[node] = null;
foreach (NodeType source in blockedSourcesOfNode)
{
if (isBlocked[source])
Unblock(source);
}
}
}
}
}

Просмотреть файл

@ -0,0 +1,100 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
namespace Microsoft.ML.Probabilistic.Compiler.Graphs
{
/// <summary>
/// Searches for a pair of pseudo-peripheral nodes in a graph.
/// </summary>
/// <remarks>
/// The nodes <c>(start,end)</c> form a pseudo-peripheral pair
/// if <c>end</c> is the furthest node from <c>start</c> and <c>start</c> is the
/// furthest node from <c>end</c>.
/// If the distance is maximal over all such pairs, then
/// the nodes are peripheral.
/// This class does not guarantee that the pair is peripheral, only pseudo-peripheral.
/// In a directed graph, distance from <c>start</c> is measured forward and distance from
/// <c>end</c> is measured backward.
/// </remarks>
internal class PseudoPeripheralSearch<NodeType>
{
private Converter<NodeType, IEnumerable<NodeType>> Successors;
private Converter<NodeType, IEnumerable<NodeType>> Predecessors;
private CanCreateNodeData<NodeType> Data;
public PseudoPeripheralSearch(IGraph<NodeType> graph)
: this(graph.NeighborsOf, graph.NeighborsOf, (CanCreateNodeData<NodeType>) graph)
{
}
public PseudoPeripheralSearch(IDirectedGraph<NodeType> graph)
: this(graph.TargetsOf, graph.SourcesOf, (CanCreateNodeData<NodeType>) graph)
{
}
public PseudoPeripheralSearch(Converter<NodeType, IEnumerable<NodeType>> successors,
Converter<NodeType, IEnumerable<NodeType>> predecessors,
CanCreateNodeData<NodeType> data)
{
this.Successors = successors;
this.Predecessors = predecessors;
this.Data = data;
}
/// <summary>
/// Find a pseudo-peripheral pair.
/// </summary>
/// <param name="start">On entry, holds an initial seed for the search. On return, holds the start node of the pair.</param>
/// <param name="end">On return, holds the end node of the pair.</param>
/// <remarks>Regardless of the seed node provided, a pseudo-peripheral pair will always be found.
/// However the seed can affect the quality of the pair (i.e. how distant they are).</remarks>
public void SearchFrom(ref NodeType start, out NodeType end)
{
NodeType startNode = start, endNode = start;
int maxDistance = 0;
bool firstTime = true;
// this loop will always terminate because maxDistance only increases.
while (true)
{
DistanceSearch<NodeType> distanceForward = new DistanceSearch<NodeType>(Successors, Data);
bool endMoved = false;
distanceForward.SetDistance += delegate(NodeType node, int distance)
{
//Console.WriteLine("forward: "+node+distance);
if (distance > maxDistance)
{
maxDistance = distance;
endNode = node;
endMoved = true;
}
};
distanceForward.SearchFrom(startNode);
if (firstTime) firstTime = false;
else if (!endMoved) break;
distanceForward = null;
DistanceSearch<NodeType> distanceBackward = new DistanceSearch<NodeType>(Predecessors, Data);
bool startMoved = false;
distanceBackward.SetDistance += delegate(NodeType node, int distance)
{
//Console.WriteLine("backward: "+node+distance);
if (distance > maxDistance)
{
maxDistance = distance;
startNode = node;
startMoved = true;
}
};
distanceBackward.SearchFrom(endNode);
if (!startMoved) break;
distanceBackward = null;
}
start = startNode;
end = endNode;
}
}
}

Просмотреть файл

@ -0,0 +1,239 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using Microsoft.ML.Probabilistic.Utilities;
namespace Microsoft.ML.Probabilistic.Compiler.Graphs
{
/// <summary>
/// Represents a collection of integers in sequence from LowerBound to UpperBound.
/// </summary>
internal class Range : IList<int>
{
protected int Start, Count;
public int LowerBound
{
get { return Start; }
set { Start = value; }
}
public int UpperBound
{
get { return Start + Count - 1; }
set { Count = value - Start + 1; }
}
public Range(int start, int count)
{
this.Start = start;
this.Count = count;
}
public void Add(int item)
{
throw new Exception("The method or operation is not implemented.");
}
public void Clear()
{
throw new Exception("The method or operation is not implemented.");
}
public bool Contains(int item)
{
return (item >= LowerBound) && (item <= UpperBound);
}
public void CopyTo(int[] array, int arrayIndex)
{
for (int i = 0; i < Count; i++)
{
array[arrayIndex + i] = Start + i;
}
}
int ICollection<int>.Count
{
get { return Count; }
}
public bool IsReadOnly
{
get { return true; }
}
public bool Remove(int item)
{
throw new Exception("The method or operation is not implemented.");
}
public IEnumerator<int> GetEnumerator()
{
for (int i = 0; i < Count; i++)
{
yield return Start + i;
}
}
System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
public int IndexOf(int item)
{
if (Contains(item)) return item - LowerBound;
else return -1;
}
public void Insert(int index, int item)
{
throw new Exception("The method or operation is not implemented.");
}
public void RemoveAt(int index)
{
throw new Exception("The method or operation is not implemented.");
}
public int this[int index]
{
get { return Start + index; }
set { throw new Exception("The method or operation is not implemented."); }
}
}
/// <summary>
/// Represents a multidimensional grid of integer points.
/// </summary>
internal class MultiRange : IList<int[]>
{
public int[] LowerBounds, Lengths, Strides;
public MultiRange(params int[] lengths)
: this(new int[lengths.Length], lengths)
{
}
public MultiRange(int[] lowerBounds, int[] lengths)
{
LowerBounds = lowerBounds;
Lengths = lengths;
Strides = StringUtil.ArrayStrides(lengths);
}
public static MultiRange ArrayIndices(Array array)
{
return new MultiRange(StringUtil.ArrayLowerBounds(array), StringUtil.ArrayDimensions(array));
}
public int Rank
{
get { return Lengths.Length; }
}
public int GetUpperBound(int dim)
{
return LowerBounds[dim] + Lengths[dim] - 1;
}
public int IndexOf(int[] item)
{
if (!Contains(item)) return -1;
int index = 0;
for (int i = 0; i < Rank; i++)
{
index += (item[i] - LowerBounds[i])*Strides[i];
}
return index;
}
public void Insert(int index, int[] item)
{
throw new Exception("The method or operation is not implemented.");
}
public void RemoveAt(int index)
{
throw new Exception("The method or operation is not implemented.");
}
public int[] this[int index]
{
get
{
int[] value = new int[Strides.Length];
StringUtil.LinearIndexToMultidimensionalIndex(index, Strides, value, LowerBounds);
return value;
}
set { throw new Exception("The method or operation is not implemented."); }
}
public void Add(int[] item)
{
throw new Exception("The method or operation is not implemented.");
}
public void Clear()
{
throw new Exception("The method or operation is not implemented.");
}
public bool Contains(int[] item)
{
for (int i = 0; i < Rank; i++)
{
if (!(item[i] >= LowerBounds[i] && item[i] < LowerBounds[i] + Lengths[i])) return false;
}
return true;
}
public void CopyTo(int[][] array, int arrayIndex)
{
throw new Exception("The method or operation is not implemented.");
}
public int Count
{
get
{
int count = 1;
for (int i = 0; i < Rank; i++)
{
count *= Lengths[i];
}
return count;
}
}
public bool IsReadOnly
{
get { return true; }
}
public bool Remove(int[] item)
{
throw new Exception("The method or operation is not implemented.");
}
public IEnumerator<int[]> GetEnumerator()
{
int count = Count;
int[] index = new int[Rank];
for (int i = 0; i < count; i++)
{
StringUtil.LinearIndexToMultidimensionalIndex(i, Strides, index, LowerBounds);
yield return index;
}
}
System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
}
}

Просмотреть файл

@ -0,0 +1,238 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using Microsoft.ML.Probabilistic.Utilities;
namespace Microsoft.ML.Probabilistic.Compiler.Graphs
{
/// <summary>
/// Wraps a stack to look like a list where you can only access index 0.
/// </summary>
/// <typeparam name="T"></typeparam>
internal class StackAsList<T> : IList<T>
{
public Stack<T> Stack;
public StackAsList(Stack<T> stack)
{
this.Stack = stack;
}
public StackAsList()
{
this.Stack = new Stack<T>();
}
public int IndexOf(T item)
{
throw new NotSupportedException();
}
public void Insert(int index, T item)
{
if (index == 0)
{
Add(item);
}
else
{
throw new NotSupportedException();
}
}
public void RemoveAt(int index)
{
if (index == 0)
{
Stack.Pop();
}
else
{
throw new NotSupportedException();
}
}
public T this[int index]
{
get
{
if (index == 0)
{
return Stack.Peek();
}
else
{
throw new NotSupportedException();
}
}
set { throw new NotSupportedException(); }
}
public void Add(T item)
{
Stack.Push(item);
}
public void Clear()
{
Stack.Clear();
}
public bool Contains(T item)
{
return Stack.Contains(item);
}
public void CopyTo(T[] array, int arrayIndex)
{
Stack.CopyTo(array, arrayIndex);
}
public int Count
{
get { return Stack.Count; }
}
public bool IsReadOnly
{
get { return false; }
}
public bool Remove(T item)
{
throw new NotSupportedException();
}
public IEnumerator<T> GetEnumerator()
{
return Stack.GetEnumerator();
}
System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
{
return ((System.Collections.IEnumerable) Stack).GetEnumerator();
}
public override string ToString()
{
return StringUtil.EnumerableToString(this, Environment.NewLine);
}
}
/// <summary>
/// Wraps a Queue to look like a list where you can only access index 0.
/// </summary>
/// <typeparam name="T"></typeparam>
internal class QueueAsList<T> : IList<T>
{
public Queue<T> Queue;
public QueueAsList(Queue<T> queue)
{
this.Queue = queue;
}
public QueueAsList()
{
this.Queue = new Queue<T>();
}
public int IndexOf(T item)
{
throw new NotSupportedException();
}
public void Insert(int index, T item)
{
if (index == 0)
{
Add(item);
}
else
{
throw new Exception("The method or operation is not implemented.");
}
}
public void RemoveAt(int index)
{
if (index == 0)
{
Queue.Dequeue();
}
else
{
throw new Exception("The method or operation is not implemented.");
}
}
public T this[int index]
{
get
{
if (index == 0)
{
return Queue.Peek();
}
else
{
throw new Exception("The method or operation is not implemented.");
}
}
set { throw new Exception("The method or operation is not implemented."); }
}
public void Add(T item)
{
Queue.Enqueue(item);
}
public void Clear()
{
Queue.Clear();
}
public bool Contains(T item)
{
return Queue.Contains(item);
}
public void CopyTo(T[] array, int arrayIndex)
{
Queue.CopyTo(array, arrayIndex);
}
public int Count
{
get { return Queue.Count; }
}
public bool IsReadOnly
{
get { return false; }
}
public bool Remove(T item)
{
throw new Exception("The method or operation is not implemented.");
}
public IEnumerator<T> GetEnumerator()
{
return Queue.GetEnumerator();
}
System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
{
return ((System.Collections.IEnumerable) Queue).GetEnumerator();
}
public override string ToString()
{
return StringUtil.EnumerableToString(this, Environment.NewLine);
}
}
}

Просмотреть файл

@ -0,0 +1,454 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
// Reference:
// [1] "Introduction to Algorithms" by Cormen, Leiserson, and Rivest (1994)
// [2] "An Improved Algorithm for Finding the Strongly Connected Components of a Directed Graph"
// David J. Pearce, Technical Report, 2005
// http://www.mcs.vuw.ac.nz/~djp/files/P05.ps
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
namespace Microsoft.ML.Probabilistic.Compiler.Graphs
{
/// <summary>
/// Find the strongly connected components of a directed graph.
/// </summary>
/// <typeparam name="NodeType">The node type.</typeparam>
/// <remarks><para>
/// A strongly connected component is a maximal set of nodes that can all reach each other by a directed path.
/// Every directed graph has a unique partition of nodes into strongly connected components.
/// The components form a DAG, since there cannot be any directed cycle among the components.
/// </para><para>
/// Given a graph and a set of start nodes, this class enumerates the strongly connected components
/// that are reachable from the start nodes.
/// The components are described by firing actions according to the pattern:
/// BeginComponent, AddNode, AddNode, ..., AddNode, EndComponent, BeginComponent, ..., EndComponent.
/// The nodes in a component will appear in an arbitrary order.
/// The components will appear in topological order, i.e. there are no edges from
/// a later component to an earlier component.
/// </para><para>
/// The implementation uses depth first search in each direction (Kosaraju's algorithm), as described by Cormen, Leiserson, and Rivest.
/// </para></remarks>
internal class StrongComponents<NodeType>
{
protected DepthFirstSearch<NodeType> dfsBackward, dfsForward;
protected Stack<NodeType> finished;
public event Action<NodeType> AddNode;
public event Action BeginComponent , EndComponent;
public StrongComponents(IDirectedGraph<NodeType> graph)
: this(graph.TargetsOf, graph.SourcesOf, (CanCreateNodeData<NodeType>) graph)
{
}
public StrongComponents(Converter<NodeType, IEnumerable<NodeType>> successors,
Converter<NodeType, IEnumerable<NodeType>> predecessors, CanCreateNodeData<NodeType> data)
{
dfsForward = new DepthFirstSearch<NodeType>(successors, data);
finished = new Stack<NodeType>();
dfsForward.FinishNode += node => finished.Push(node);
dfsBackward =
new DepthFirstSearch<NodeType>(node => predecessors(node).Where(source => (dfsForward.IsVisited[source] != VisitState.Unvisited)), data);
dfsBackward.DiscoverNode += OnAddNode;
}
public void SearchFrom(IEnumerable<NodeType> starts)
{
dfsForward.SearchFrom(starts);
ProcessFinished();
}
public void SearchFrom(NodeType start)
{
dfsForward.SearchFrom(start);
ProcessFinished();
}
protected void ProcessFinished()
{
while (finished.Count > 0)
{
NodeType start = finished.Pop();
if (dfsBackward.IsVisited[start] == VisitState.Unvisited)
{
OnBeginComponent();
dfsBackward.SearchFrom(start);
OnEndComponent();
}
}
}
public void OnAddNode(NodeType node)
{
if (AddNode != null) AddNode(node);
}
public void OnBeginComponent()
{
if (BeginComponent != null) BeginComponent();
}
public void OnEndComponent()
{
if (EndComponent != null) EndComponent();
}
}
/// <summary>
/// Find the strongly connected components of a directed graph.
/// </summary>
/// <typeparam name="NodeType">The node type.</typeparam>
/// <remarks><para>
/// A strongly connected component is a maximal set of nodes that can all reach each other by a directed path.
/// Every directed graph has a unique partition of nodes into strongly connected components.
/// The components form a DAG, since there cannot be any directed cycle among the components.
/// </para><para>
/// Given a graph and a set of start nodes, this class enumerates the strongly connected components
/// that are reachable from the start nodes.
/// The components are described by firing actions according to the pattern:
/// BeginComponent, AddNode, AddNode, ..., AddNode, EndComponent, BeginComponent, ..., EndComponent.
/// The nodes in a component will appear in an arbitrary order.
/// The components will appear in reverse topological order, i.e. there are no edges from
/// an earlier component to a later component.
/// </para><para>
/// The implementation uses Pierce's algorithm (a modification of Tarjan's algorithm):
/// "An Improved Algorithm for Finding the Strongly Connected Components of a Directed Graph"
/// David J. Pearce, Technical Report, 2005
/// http://www.mcs.vuw.ac.nz/~djp/files/P05.ps
/// </para></remarks>
internal class StrongComponents2<NodeType>
{
private DepthFirstSearch<NodeType> dfs;
public IndexedProperty<NodeType, int> DiscoverTime, RootDiscoverTime;
protected Stack<NodeType> finished;
public event Action<NodeType> AddNode;
public event Action BeginComponent , EndComponent;
public int time;
public StrongComponents2(IDirectedGraph<NodeType> graph)
: this(graph.TargetsOf, (CanCreateNodeData<NodeType>) graph)
{
}
public void Clear()
{
time = 0;
dfs.Clear();
}
public StrongComponents2(Converter<NodeType, IEnumerable<NodeType>> successors, CanCreateNodeData<NodeType> data)
{
dfs = new DepthFirstSearch<NodeType>(successors, data);
finished = new Stack<NodeType>();
DiscoverTime = data.CreateNodeData<int>(0);
RootDiscoverTime = data.CreateNodeData<int>(0);
dfs.DiscoverNode += delegate(NodeType node)
{
DiscoverTime[node] = time;
RootDiscoverTime[node] = time;
time++;
};
dfs.BackEdge += delegate(Edge<NodeType> edge)
{
if (RootDiscoverTime[edge.Target] < RootDiscoverTime[edge.Source])
RootDiscoverTime[edge.Source] = RootDiscoverTime[edge.Target];
};
dfs.CrossEdge += delegate(Edge<NodeType> edge)
{
if (RootDiscoverTime[edge.Target] < RootDiscoverTime[edge.Source])
RootDiscoverTime[edge.Source] = RootDiscoverTime[edge.Target];
};
dfs.FinishTreeEdge += delegate(Edge<NodeType> edge)
{
if (RootDiscoverTime[edge.Target] < RootDiscoverTime[edge.Source])
RootDiscoverTime[edge.Source] = RootDiscoverTime[edge.Target];
};
dfs.FinishNode += delegate(NodeType node)
{
int thisRootDiscoverTime = RootDiscoverTime[node];
if (thisRootDiscoverTime < DiscoverTime[node])
{
// not a root
finished.Push(node);
}
else
{
// root of a component
OnBeginComponent();
OnAddNode(node);
while (finished.Count > 0 && RootDiscoverTime[finished.Peek()] >= thisRootDiscoverTime)
{
NodeType child = finished.Pop();
OnAddNode(child);
RootDiscoverTime[child] = Int32.MaxValue; // prevent child from affecting any other components
}
RootDiscoverTime[node] = Int32.MaxValue; // prevent node from affecting any other components
OnEndComponent();
}
};
}
public void SearchFrom(IEnumerable<NodeType> starts)
{
dfs.SearchFrom(starts);
}
public void SearchFrom(NodeType start)
{
dfs.SearchFrom(start);
}
public void OnAddNode(NodeType node)
{
if (AddNode != null) AddNode(node);
}
public void OnBeginComponent()
{
if (BeginComponent != null) BeginComponent();
}
public void OnEndComponent()
{
if (EndComponent != null) EndComponent();
}
}
internal class StrongComponentChecker<NodeType, EdgeType>
{
private IDirectedGraph<NodeType, EdgeType> graph;
private DepthFirstSearch<NodeType, EdgeType> dfs;
public IndexedProperty<NodeType, int> DiscoverTime, RootDiscoverTime;
/// <summary>
/// Modified by SearchFrom
/// </summary>
public bool IsStrong;
public List<EdgeType> RedundantEdges = new List<EdgeType>();
private int time = 0;
public StrongComponentChecker(IDirectedGraph<NodeType, EdgeType> graph)
{
this.graph = graph;
dfs = new DepthFirstSearch<NodeType, EdgeType>(graph);
CanCreateNodeData<NodeType> data = (CanCreateNodeData<NodeType>) graph;
DiscoverTime = data.CreateNodeData<int>(0);
RootDiscoverTime = data.CreateNodeData<int>(0);
dfs.DiscoverNode += delegate(NodeType node)
{
DiscoverTime[node] = time;
RootDiscoverTime[node] = time;
time++;
};
dfs.BackEdge += ProcessEdge2;
dfs.CrossEdge += ProcessEdge2;
dfs.FinishTreeEdge += ProcessEdge;
dfs.FinishNode += delegate(NodeType node)
{
int thisRootDiscoverTime = RootDiscoverTime[node];
if (thisRootDiscoverTime < DiscoverTime[node])
{
// not a root
}
else
{
// root of a component
if (thisRootDiscoverTime != 0) IsStrong = false;
}
};
}
public void ProcessEdge(EdgeType edge)
{
NodeType source = graph.SourceOf(edge);
NodeType target = graph.TargetOf(edge);
if (RootDiscoverTime[target] < RootDiscoverTime[source])
RootDiscoverTime[source] = RootDiscoverTime[target];
}
public void ProcessEdge2(EdgeType edge)
{
NodeType source = graph.SourceOf(edge);
NodeType target = graph.TargetOf(edge);
if (RootDiscoverTime[target] < RootDiscoverTime[source])
RootDiscoverTime[source] = RootDiscoverTime[target];
else
RedundantEdges.Add(edge);
}
public void SearchFrom(NodeType start)
{
IsStrong = true;
dfs.SearchFrom(start);
if (time < graph.Nodes.Count) IsStrong = false;
if (!IsStrong) RedundantEdges.Clear();
}
public void Clear()
{
IsStrong = false;
RedundantEdges.Clear();
time = 0;
dfs.Clear();
// do not need to clear DiscoverTime
}
}
/// <summary>
/// A subgraph with the same nodes but fewer edges.
/// </summary>
/// <typeparam name="NodeType"></typeparam>
/// <typeparam name="EdgeType"></typeparam>
internal class DirectedGraphFilter<NodeType, EdgeType> : IDirectedGraph<NodeType, EdgeType>, CanCreateNodeData<NodeType>, CanCreateEdgeData<EdgeType>
{
private IDirectedGraph<NodeType, EdgeType> graph;
private Predicate<EdgeType> predicate;
public DirectedGraphFilter(IDirectedGraph<NodeType, EdgeType> graph, Predicate<EdgeType> predicate)
{
this.graph = graph;
this.predicate = predicate;
}
public NodeType SourceOf(EdgeType edge)
{
return graph.SourceOf(edge);
}
public NodeType TargetOf(EdgeType edge)
{
return graph.TargetOf(edge);
}
public IEnumerable<EdgeType> EdgesOutOf(NodeType source)
{
foreach (EdgeType edge in graph.EdgesOutOf(source))
{
if (predicate(edge)) yield return edge;
}
}
public IEnumerable<EdgeType> EdgesInto(NodeType target)
{
foreach (EdgeType edge in graph.EdgesInto(target))
{
if (predicate(edge)) yield return edge;
}
}
private IEnumerable<EdgeType> AllEdges()
{
foreach (EdgeType edge in graph.Edges)
{
if (predicate(edge)) yield return edge;
}
}
public IEnumerable<EdgeType> Edges
{
get { return AllEdges(); }
}
public EdgeType GetEdge(NodeType source, NodeType target)
{
EdgeType edge = graph.GetEdge(source, target);
if (predicate(edge)) return edge;
else throw new EdgeNotFoundException(source, target);
}
public bool TryGetEdge(NodeType source, NodeType target, out EdgeType edge)
{
if (graph.TryGetEdge(source, target, out edge)) return predicate(edge);
else return false;
}
public IEnumerable<EdgeType> EdgesOf(NodeType node)
{
foreach (EdgeType edge in graph.EdgesOf(node))
{
if (predicate(edge)) yield return edge;
}
}
public ICollection<NodeType> Nodes
{
get { return graph.Nodes; }
}
public int EdgeCount()
{
throw new NotImplementedException();
}
public int NeighborCount(NodeType node)
{
throw new NotImplementedException();
}
public IEnumerable<NodeType> NeighborsOf(NodeType node)
{
return SourcesOf(node).Concat(TargetsOf(node));
}
public bool ContainsEdge(NodeType source, NodeType target)
{
EdgeType edge;
if (graph.TryGetEdge(source, target, out edge)) return predicate(edge);
else return false;
}
public int TargetCount(NodeType source)
{
throw new NotImplementedException();
}
public int SourceCount(NodeType target)
{
throw new NotImplementedException();
}
public IEnumerable<NodeType> TargetsOf(NodeType source)
{
foreach (EdgeType edge in EdgesOutOf(source)) yield return TargetOf(edge);
}
public IEnumerable<NodeType> SourcesOf(NodeType target)
{
foreach (EdgeType edge in EdgesInto(target)) yield return SourceOf(edge);
}
public IndexedProperty<NodeType, T> CreateNodeData<T>(T defaultValue)
{
return ((CanCreateNodeData<NodeType>) graph).CreateNodeData<T>(defaultValue);
}
public IndexedProperty<EdgeType, T> CreateEdgeData<T>(T defaultValue)
{
return ((CanCreateEdgeData<EdgeType>) graph).CreateEdgeData<T>(defaultValue);
}
public override string ToString()
{
StringBuilder s = new StringBuilder();
foreach (NodeType node in Nodes)
{
s.AppendFormat("{0} -> ", node);
bool first = true;
foreach (NodeType target in TargetsOf(node))
{
if (!first) s.Append(" ");
else first = false;
s.Append(target.ToString());
}
s.AppendLine();
}
return s.ToString();
}
}
}

Просмотреть файл

@ -0,0 +1,114 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using Microsoft.ML.Probabilistic.Compiler.Transforms;
using Microsoft.ML.Probabilistic.Compiler.CodeModel;
using Microsoft.ML.Probabilistic.Compiler;
using Microsoft.ML.Probabilistic.Models.Attributes;
namespace Microsoft.ML.Probabilistic.Algorithms
{
/// <summary>
/// Abstract base class for all algorithms
/// </summary>
public abstract class AlgorithmBase : IAlgorithm
{
/// <summary>
/// The algorithm's name
/// </summary>
public abstract string Name { get; }
/// <summary>
/// Short name for the inference algorithm
/// </summary>
public abstract string ShortName { get; }
public abstract Delegate GetVariableFactor(bool derived, bool initialised);
/// <summary>
/// Algorithm's operator suffix - used in in message update methods
/// </summary>
/// <param name="factorAttributes"></param>
/// <returns></returns>
public abstract string GetOperatorMethodSuffix(List<ICompilerAttribute> factorAttributes);
/// <summary>
/// Gets the operator which converts a message to/from another algorithm
/// </summary>
/// <param name="channelType">Type of message</param>
/// <param name="alg2">The other algorithm</param>
/// <param name="isFromFactor">True if from, false if to</param>
/// <param name="args">Where to add arguments of the operator</param>
/// <returns>A method reference for the operator</returns>
public abstract MethodReference GetAlgorithmConversionOperator(Type channelType, IAlgorithm alg2, bool isFromFactor, List<object> args);
/// <summary>
/// Gets the suffix for this algorithm's evidence method
/// </summary>
/// <param name="factorAttributes"></param>
/// <returns></returns>
public abstract string GetEvidenceMethodName(List<ICompilerAttribute> factorAttributes);
/// <summary>
/// Get the message prototype for this algorithm in the specified direction
/// </summary>
/// <param name="channelInfo">The channel information</param>
/// <param name="direction">The direction</param>
/// <param name="marginalPrototypeExpression">The marginal prototype expression</param>
/// <param name="path">Path name of message</param>
/// <param name="queryTypes"></param>
/// <returns></returns>
public virtual IExpression GetMessagePrototype(
ChannelInfo channelInfo, MessageDirection direction, IExpression marginalPrototypeExpression, string path, IList<QueryType> queryTypes)
{
return marginalPrototypeExpression;
}
/// <summary>
/// Allows the algorithm to modify the attributes on a factor. For example context-specific
/// message attributes on a method invoke expression
/// </summary>
/// <param name="factorExpression">The expression</param>
/// <param name="factorAttributes">Attribute registry</param>
/// <returns></returns>
public virtual void ModifyFactorAttributes(IExpression factorExpression, AttributeRegistry<object, ICompilerAttribute> factorAttributes)
{
// By default, remove any message path attributes
factorAttributes.Remove<MessagePathAttribute>(factorExpression);
return;
}
/// <summary>
/// Get the default inference query types for a variable for this algorithm.
/// </summary>
public virtual void ForEachDefaultQueryType(Action<QueryType> action)
{
action(QueryTypes.Marginal);
}
/// <summary>
/// Get the query type binding - this is the path to the given query type
/// relative to the raw marginal type.
/// </summary>
/// <param name="qt">The query type</param>
/// <returns></returns>
public virtual string GetQueryTypeBinding(QueryType qt)
{
return "";
}
private int defaultNumberOfIterations = -1;
/// <summary>
/// Default number of iterations for this algorithm
/// </summary>
public virtual int DefaultNumberOfIterations
{
get { return (defaultNumberOfIterations < 0) ? 50 : defaultNumberOfIterations; }
set { defaultNumberOfIterations = value; }
}
}
}

Просмотреть файл

@ -0,0 +1,96 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using Microsoft.ML.Probabilistic.Compiler.Transforms;
using Microsoft.ML.Probabilistic.Factors;
using Microsoft.ML.Probabilistic.Compiler;
using Microsoft.ML.Probabilistic.Models;
using Microsoft.ML.Probabilistic.Models.Attributes;
namespace Microsoft.ML.Probabilistic.Algorithms
{
/// <summary>
/// The expectation propagation inference algorithm, see also
/// http://research.microsoft.com/~minka/papers/ep/roadmap.html.
/// </summary>
public class ExpectationPropagation : AlgorithmBase, IAlgorithm
{
#region IAlgorithm Members
public override Delegate GetVariableFactor(bool derived, bool initialised)
{
if (derived)
{
if (initialised) return new FuncOut<PlaceHolder, PlaceHolder, PlaceHolder, PlaceHolder>(Factor.DerivedVariableInit);
else return new FuncOut<PlaceHolder, PlaceHolder, PlaceHolder>(Factor.DerivedVariable);
}
else
{
if (initialised) return new FuncOut<PlaceHolder, PlaceHolder, PlaceHolder, PlaceHolder>(Factor.VariableInit);
else return new FuncOut<PlaceHolder, PlaceHolder, PlaceHolder>(Factor.Variable);
}
}
/// <summary>
/// Gets the suffix for Expectation Propagation operator methods
/// </summary>
/// <param name="factorAttributes"></param>
/// <returns></returns>
public override string GetOperatorMethodSuffix(List<ICompilerAttribute> factorAttributes)
{
return "AverageConditional";
}
/// <summary>
/// Gets the suffix for Expectation Propagation evidence method
/// </summary>
/// <param name="factorAttributes"></param>
/// <returns></returns>
public override string GetEvidenceMethodName(List<ICompilerAttribute> factorAttributes)
{
// If this is changed, must also change #define at top of GateEnter.cs and GateExit.cs
return "LogEvidenceRatio";
}
/// <summary>
/// Name of the algorithm
/// </summary>
public override string Name
{
get { return "ExpectationPropagation"; }
}
/// <summary>
/// Short name of the algorithm
/// </summary>
public override string ShortName
{
get { return "EP"; }
}
/// <summary>
/// Gets the operator which converts a message to/from another algorithm
/// </summary>
/// <param name="channelType">Type of message</param>
/// <param name="alg2">The other algorithm</param>
/// <param name="isFromFactor">True if from, false if to</param>
/// <param name="args">Where to add arguments of the operator</param>
/// <returns>A method reference for the operator</returns>
public override MethodReference GetAlgorithmConversionOperator(Type channelType, IAlgorithm alg2, bool isFromFactor, List<object> args)
{
if (alg2 is VariationalMessagePassing)
{
MethodReference mref = new MethodReference(typeof (ShiftAlpha), isFromFactor ? "FromFactor<>" : "ToFactor<>");
mref.TypeArguments = new Type[] {channelType};
args.Add(-1.0);
return mref;
}
throw new InferCompilerException("Cannot convert from " + Name + " to " + alg2.Name);
}
#endregion
}
}

Просмотреть файл

@ -0,0 +1,246 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using Microsoft.ML.Probabilistic.Compiler.Transforms;
using Microsoft.ML.Probabilistic.Factors;
using Microsoft.ML.Probabilistic.Compiler.CodeModel;
using Microsoft.ML.Probabilistic.Compiler;
using Microsoft.ML.Probabilistic.Distributions;
using Microsoft.ML.Probabilistic.Models;
using Microsoft.ML.Probabilistic.Models.Attributes;
namespace Microsoft.ML.Probabilistic.Algorithms
{
/// <summary>
/// Gibbs sampling algorithm - includes block Gibbs sampling
/// </summary>
public class GibbsSampling : AlgorithmBase, IAlgorithm
{
public static bool DefaultSideChannels = false;
public bool UseSideChannels = DefaultSideChannels;
#region IAlgorithm Members
public override Delegate GetVariableFactor(bool derived, bool initialised)
{
if (derived) return new FuncOut<PlaceHolder, PlaceHolder, PlaceHolder>(Factor.DerivedVariableGibbs);
else return new FuncOut<PlaceHolder, PlaceHolder, PlaceHolder>(Factor.VariableGibbs);
}
/// <summary>
/// Gets the suffix for Gibbs Sampling operator methods
/// </summary>
/// <param name="factorAttributes"></param>
/// <returns></returns>
public override string GetOperatorMethodSuffix(List<ICompilerAttribute> factorAttributes)
{
if (factorAttributes.Find(o => o.GetType().IsAssignableFrom(typeof (IsVariableFactor))) != null)
return "Gibbs";
else
return "AverageConditional";
}
/// <summary>
/// Gets the suffix for Gibbs Sampling evidence method
/// Evidence is not supported or supportable for Gibbs. The message
/// update methods are marked as unsupported so that an appropriate
/// error message is generated by the model compiler
/// </summary>
/// <param name="factorAttributes"></param>
/// <returns></returns>
public override string GetEvidenceMethodName(List<ICompilerAttribute> factorAttributes)
{
if (factorAttributes.Find(o => o.GetType().IsAssignableFrom(typeof (IsVariableFactor))) != null)
return "GibbsEvidence";
else
return "LogEvidenceRatio";
//return "LogAverageFactor";
}
/// <summary>
/// Name of the algorithm
/// </summary>
public override string Name
{
get { return "GibbsSampling"; }
}
/// <summary>
/// Short name of the algorithm
/// </summary>
public override string ShortName
{
get { return "Gibbs"; }
}
/// <summary>
/// Gets the operator which converts a message to/from another algorithm
/// </summary>
/// <param name="channelType">Type of message</param>
/// <param name="alg2">The other algorithm</param>
/// <param name="isFromFactor">True if from, false if to</param>
/// <param name="args">Where to add arguments of the operator</param>
/// <returns>A method reference for the operator</returns>
public override MethodReference GetAlgorithmConversionOperator(Type channelType, IAlgorithm alg2, bool isFromFactor, List<object> args)
{
throw new InferCompilerException("Cannot convert from " + Name + " to " + alg2.Name);
}
/// <summary>
/// Get the message prototype in the specified direction
/// </summary>
/// <param name="channelInfo">The channel information</param>
/// <param name="direction">The direction</param>
/// <param name="marginalPrototypeExpression">The marginal prototype expression</param>
/// <param name="path">Path name of message</param>
/// <param name="queryTypes">The set of queries to support. Only used for marginal channels.</param>
/// <returns>An expression for the method prototype</returns>
public override IExpression GetMessagePrototype(
ChannelInfo channelInfo, MessageDirection direction,
IExpression marginalPrototypeExpression, string path, IList<QueryType> queryTypes)
{
Type t = null;
Type messTyp = null;
IExpression mp = null;
CodeBuilder Builder = CodeBuilder.Instance;
if (channelInfo.IsMarginal)
{
// We want the marginal variable to be a GibbsEstimator over the appropriate
// distribution type
if (direction == MessageDirection.Forwards && !UseSideChannels)
{
bool estimateMarginal = false;
bool collectSamples = false, collectDistributions = false;
foreach (QueryType qt in queryTypes)
{
if (qt.Name == "Marginal") estimateMarginal = true;
else if (qt.Name == "Samples") collectSamples = true;
else if (qt.Name == "Conditionals") collectDistributions = true;
}
Type innermostMessageType = marginalPrototypeExpression.GetExpressionType();
Type innermostElementType = Distribution.GetDomainType(innermostMessageType);
//t = MessageExpressionTransform.GetDistributionType(channelInfo.varInfo.varType, channelInfo.varInfo.innermostElementType, innermostMessageType, true);
t = MessageTransform.GetDistributionType(channelInfo.varInfo.varType, innermostElementType, innermostMessageType, true);
messTyp = typeof (GibbsMarginal<,>).MakeGenericType(t, channelInfo.varInfo.varType);
mp = Builder.NewObject(
messTyp, (t == innermostMessageType) ? marginalPrototypeExpression : Builder.DefaultExpr(t), Quoter.Quote(this.BurnIn), Quoter.Quote(this.Thin),
Quoter.Quote(estimateMarginal), Quoter.Quote(collectSamples), Quoter.Quote(collectDistributions));
}
else
mp = marginalPrototypeExpression;
return mp;
}
else
{
// Default is sample
t = marginalPrototypeExpression.GetExpressionType();
bool useSample = (path != "Distribution");
if (useSample)
{
messTyp = Distribution.GetDomainType(t);
while (messTyp.IsArray)
messTyp = messTyp.GetElementType();
mp = Builder.DefaultExpr(messTyp);
}
else
{
messTyp = t;
mp = marginalPrototypeExpression;
}
return mp;
}
}
/// <summary>
/// Allows the algorithm to modify the attributes on a factor. For example, in Gibbs sampling
/// different message types are passed depending on the context. This is signalled to the MessageTransform
/// by attaching a MessagePath attribute to the method invoke expression for the factor.
/// If the factor is a 'variable' pseudo-factor (UsesEqualsDef) then all incoming variables are
/// Distributions. Otherwise, incoming messages will depend on the grouping
/// </summary>
/// <param name="factorExpression">The factor expression</param>
/// <param name="factorAttributes">Attribute registry</param>
public override void ModifyFactorAttributes(IExpression factorExpression, AttributeRegistry<object, ICompilerAttribute> factorAttributes)
{
IList<MessagePathAttribute> mpas = factorAttributes.GetAll<MessagePathAttribute>(factorExpression);
bool isVariable = factorAttributes.Has<IsVariableFactor>(factorExpression);
if (isVariable) return;
// Process any Message Path attributes that may have been set by the Group transform
foreach (MessagePathAttribute mpa in mpas)
{
if (mpa.FromDistance >= mpa.ToDistance)
mpa.Path = "Distribution";
else
mpa.Path = "CurrentSample";
}
}
/// <summary>
/// Get the default inference query types for a variable for this algorithm.
/// </summary>
public override void ForEachDefaultQueryType(Action<QueryType> action)
{
action(QueryTypes.Marginal);
action(QueryTypes.Samples);
}
/// <summary>
/// Get the query type binding for Gibbs sampling - this is the path to the given query type
/// relative to the raw marginal type.
/// </summary>
/// <param name="qt">The query type</param>
/// <returns></returns>
public override string GetQueryTypeBinding(QueryType qt)
{
if (UseSideChannels) return null;
if (qt == QueryTypes.Marginal)
return "Distribution";
else if (qt == QueryTypes.Samples)
return "Samples";
else if (qt == QueryTypes.Conditionals)
return "Conditionals";
else
return "";
}
#endregion
private int burnIn = 100;
/// <summary>
/// The number of samples to discard at the beginning
/// </summary>
public int BurnIn
{
get { return burnIn; }
set { burnIn = value; }
}
private int thin = 5;
/// <summary>
/// Reduction factor when constructing sample and conditional lists
/// </summary>
public int Thin
{
get { return thin; }
set { thin = value; }
}
private int defaultNumberOfIterations = -1;
/// <summary>
/// Default number of iterations for Gibbs sampling
/// </summary>
public override int DefaultNumberOfIterations
{
get { return (defaultNumberOfIterations < 0) ? burnIn + 2000 : defaultNumberOfIterations; }
set { defaultNumberOfIterations = value; }
}
}
}

Просмотреть файл

@ -0,0 +1,97 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using Microsoft.ML.Probabilistic.Compiler.Transforms;
using Microsoft.ML.Probabilistic.Compiler.CodeModel;
using Microsoft.ML.Probabilistic.Compiler;
namespace Microsoft.ML.Probabilistic.Models.Attributes
{
/// <summary>
/// Interface for inference algorithms
/// </summary>
public interface IAlgorithm
{
/// <summary>
/// The name
/// </summary>
string Name { get; }
/// <summary>
/// Short name for the inference algorithm
/// </summary>
string ShortName { get; }
Delegate GetVariableFactor(bool derived, bool initialised);
/// <summary>
/// Gets the suffix for this algorithm's operator methods
/// </summary>
/// <param name="factorAttributes"></param>
/// <returns></returns>
string GetOperatorMethodSuffix(List<ICompilerAttribute> factorAttributes);
/// <summary>
/// Gets the operator which converts a message to/from another algorithm
/// </summary>
/// <param name="channelType">Type of message</param>
/// <param name="alg2">The other algorithm</param>
/// <param name="isFromFactor">True if from, false if to</param>
/// <param name="args">Where to add arguments of the operator</param>
/// <returns>A method reference for the operator</returns>
MethodReference GetAlgorithmConversionOperator(Type channelType, IAlgorithm alg2, bool isFromFactor, List<object> args);
/// <summary>
/// Gets the suffix for this algorithm's evidence method
/// </summary>
/// <param name="factorAttributes"></param>
/// <returns></returns>
string GetEvidenceMethodName(List<ICompilerAttribute> factorAttributes);
/// <summary>
/// Get the message prototype for this algorithm in the specified direction
/// </summary>
/// <param name="channelInfo">The channel information</param>
/// <param name="direction">The direction</param>
/// <param name="marginalPrototypeExpression">The marginal prototype expression</param>
/// <param name="path">Sub-channel path</param>
/// <param name="queryTypes"></param>
/// <returns></returns>
IExpression GetMessagePrototype(
ChannelInfo channelInfo,
MessageDirection direction,
IExpression marginalPrototypeExpression,
string path,
IList<QueryType> queryTypes);
/// <summary>
/// Allows the algorithm to modify the attributes on a factor. For example context-specific
/// message attributes on a method invoke expression
/// </summary>
/// <param name="factorExpression">The expression</param>
/// <param name="factorAttributes">Attribute registry</param>
/// <returns></returns>
void ModifyFactorAttributes(IExpression factorExpression, AttributeRegistry<object, ICompilerAttribute> factorAttributes);
/// <summary>
/// Get the default inference query types for a variable for this algorithm.
/// </summary>
void ForEachDefaultQueryType(Action<QueryType> action);
/// <summary>
/// Get the query type binding - this is the path to the given query type
/// relative to the raw marginal type.
/// </summary>
/// <param name="qt">The query type</param>
/// <returns></returns>
string GetQueryTypeBinding(QueryType qt);
/// <summary>
/// Default number of iterations for this algorithm
/// </summary>
int DefaultNumberOfIterations { get; set; }
}
}

Просмотреть файл

@ -0,0 +1,90 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using Microsoft.ML.Probabilistic.Compiler.Transforms;
using Microsoft.ML.Probabilistic.Factors;
using Microsoft.ML.Probabilistic.Compiler.CodeModel;
using Microsoft.ML.Probabilistic.Compiler;
using Microsoft.ML.Probabilistic.Distributions;
using Microsoft.ML.Probabilistic.Models;
using Microsoft.ML.Probabilistic.Models.Attributes;
namespace Microsoft.ML.Probabilistic.Algorithms
{
/// <summary>
/// Max product belief propagation.
/// </summary>
public class MaxProductBeliefPropagation : AlgorithmBase, IAlgorithm
{
#region IAlgorithm Members
public override Delegate GetVariableFactor(bool derived, bool initialised)
{
return new FuncOut<PlaceHolder, PlaceHolder, PlaceHolder>(Factor.VariableMax);
}
/// <summary>
/// Gets the suffix for Max Product operator methods
/// </summary>
/// <param name="factorAttributes"></param>
/// <returns></returns>
public override string GetOperatorMethodSuffix(List<ICompilerAttribute> factorAttributes)
{
return "MaxConditional";
}
/// <summary>
/// Gets the suffix for Max Product evidence method
/// </summary>
/// <param name="factorAttributes"></param>
/// <returns></returns>
public override string GetEvidenceMethodName(List<ICompilerAttribute> factorAttributes)
{
return "NotYetSupported";
}
/// <summary>
/// Name of the algorithm
/// </summary>
public override string Name
{
get { return "MaxProductBP"; }
}
/// <summary>
/// Short name of the algorithm
/// </summary>
public override string ShortName
{
get { return "MaxProd"; }
}
/// <summary>
/// Gets the operator which converts a message to/from another algorithm
/// </summary>
/// <param name="channelType">Type of message</param>
/// <param name="alg2">The other algorithm</param>
/// <param name="isFromFactor">True if from, false if to</param>
/// <param name="args">Where to add arguments of the operator</param>
/// <returns>A method reference for the operator</returns>
public override MethodReference GetAlgorithmConversionOperator(Type channelType, IAlgorithm alg2, bool isFromFactor, List<object> args)
{
throw new InferCompilerException("Cannot convert from " + Name + " to " + alg2.Name);
}
public override IExpression GetMessagePrototype(ChannelInfo channelInfo, MessageDirection direction, IExpression marginalPrototypeExpression, string path,
IList<QueryType> queryTypes)
{
if (marginalPrototypeExpression.GetExpressionType() == typeof (Discrete))
{
return CodeBuilder.Instance.StaticMethod(new Func<Discrete, UnnormalizedDiscrete>(UnnormalizedDiscrete.FromDiscrete), marginalPrototypeExpression);
}
return base.GetMessagePrototype(channelInfo, direction, marginalPrototypeExpression, path, queryTypes);
}
#endregion
}
}

Просмотреть файл

@ -0,0 +1,131 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using Microsoft.ML.Probabilistic.Compiler.Transforms;
using Microsoft.ML.Probabilistic.Factors;
using Microsoft.ML.Probabilistic.Compiler;
using Microsoft.ML.Probabilistic.Models;
using Microsoft.ML.Probabilistic.Models.Attributes;
namespace Microsoft.ML.Probabilistic.Algorithms
{
#if SUPPRESS_XMLDOC_WARNINGS
#pragma warning disable 1591
#endif
/// <summary>
/// The variational message passing algorithm, see also
/// http://www.johnwinn.org/Research/VMP.html and
/// http://en.wikipedia.org/wiki/Variational_message_passing.
/// </summary>
public class VariationalMessagePassing : AlgorithmBase, IAlgorithm
{
internal class GateExitRandomVariable : ICompilerAttribute
{
}
public bool UseGateExitRandom;
public bool UseDerivMessages;
#region IAlgorithm Members
public override Delegate GetVariableFactor(bool derived, bool initialised)
{
if (derived)
{
if (initialised) return new FuncOut<PlaceHolder, PlaceHolder, PlaceHolder, PlaceHolder>(Factor.DerivedVariableInitVmp);
else return new FuncOut<PlaceHolder, PlaceHolder, PlaceHolder>(Factor.DerivedVariableVmp);
}
else
{
if (initialised) return new FuncOut<PlaceHolder, PlaceHolder, PlaceHolder, PlaceHolder>(Factor.VariableInit);
else return new FuncOut<PlaceHolder, PlaceHolder, PlaceHolder>(Factor.Variable);
}
}
/// <summary>
/// Gets the suffix for variational message passing operator methods
/// </summary>
/// <param name="factorAttributes"></param>
/// <returns></returns>
public override string GetOperatorMethodSuffix(List<ICompilerAttribute> factorAttributes)
{
return "AverageLogarithm";
}
/// <summary>
/// Gets the suffix for variational message passing evidence method
/// </summary>
/// <param name="factorAttributes"></param>
/// <returns></returns>
public override string GetEvidenceMethodName(List<ICompilerAttribute> factorAttributes)
{
return "AverageLogFactor";
}
/// <summary>
/// Name of the algorithm
/// </summary>
public override string Name
{
get { return "VariationalMessagePassing"; }
}
/// <summary>
/// Short name of the algorithm
/// </summary>
public override string ShortName
{
get { return "VMP"; }
}
/// <summary>
/// Gets the operator which converts a message to/from another algorithm
/// </summary>
/// <param name="channelType">Type of message</param>
/// <param name="alg2">The other algorithm</param>
/// <param name="isFromFactor">True if from, false if to</param>
/// <param name="args">Where to add arguments of the operator</param>
/// <returns>A method reference for the operator</returns>
public override MethodReference GetAlgorithmConversionOperator(Type channelType, IAlgorithm alg2, bool isFromFactor, List<object> args)
{
if (alg2 is ExpectationPropagation)
{
MethodReference mref = new MethodReference(typeof (ShiftAlpha), isFromFactor ? "FromFactor<>" : "ToFactor<>");
mref.TypeArguments = new Type[] {channelType};
args.Add(1.0);
return mref;
}
throw new InferCompilerException("Cannot convert from " + Name + " to " + alg2.Name);
}
#endregion
}
/*public class VmpDeterministic : IAlgorithm
{
public string GetOperatorMethodSuffix(List<object> factorAttributes)
{
return "VmpDeterministic";
}
public MethodReference GetAlgorithmConversionOperator(Type channelType, IAlgorithm alg2, bool isFromFactor, List<object> args)
{
if (alg2 is VariationalMessagePassing)
{
return null;
}
}
public string Name
{
get { return "VariationalMessagePassing(Deterministic)"; }
}
}*/
#if SUPPRESS_XMLDOC_WARNINGS
#pragma warning restore 1591
#endif
}

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -0,0 +1,33 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using Microsoft.ML.Probabilistic.Compiler;
namespace Microsoft.ML.Probabilistic.Compiler.Attributes
{
/// <summary>
/// Stores debugging information to show in the transform browser. The browser will create a tab for each DebugInfo object.
/// This attribute should be attached to the top-level ITypeDeclaration produced by a transform.
/// </summary>
internal class DebugInfo : ICompilerAttribute
{
/// <summary>
/// The name of the tab in the browser.
/// </summary>
public string Name;
/// <summary>
/// A DataContext for DeclarationView. Currently this must be a code object (ITypeDeclaration or IStatement or Func&lt;SourceNode&gt;).
/// </summary>
public object Value;
/// <summary>
/// The transform in the browser that will show this tab.
/// </summary>
public ICodeTransform Transform;
public override string ToString()
{
return string.Format("DebugInfo({0},{1})", Transform == null ? "" : Transform.Name, Name);
}
}
}

Просмотреть файл

@ -0,0 +1,857 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Text;
using Microsoft.ML.Probabilistic.Compiler.Transforms;
using Microsoft.ML.Probabilistic.Compiler;
using Microsoft.ML.Probabilistic.Collections;
using Microsoft.ML.Probabilistic.Utilities;
using Microsoft.ML.Probabilistic.Compiler.CodeModel;
namespace Microsoft.ML.Probabilistic.Compiler.Attributes
{
#if SUPPRESS_XMLDOC_WARNINGS
#pragma warning disable 1591
#endif
/// <summary>
/// Stores information about how a statement depends on other statements.
/// </summary>
internal class DependencyInformation : ICompilerAttribute, ICloneable
{
/// <summary>
/// Stores the dependency type of all dependent statements.
/// </summary>
public Dictionary<IStatement, DependencyType> dependencyTypeOf = new Dictionary<IStatement, DependencyType>(new IdentityComparer<IStatement>());
//public SortedDictionary<IStatement,DependencyType> dependencyTypeOf = new SortedDictionary<IStatement, DependencyType>(new StatementComparer());
/// <summary>
/// Stores the offsets of all dependent statements that have an offset.
/// </summary>
public Dictionary<IStatement, IOffsetInfo> offsetIndexOf = new Dictionary<IStatement, IOffsetInfo>(new IdentityComparer<IStatement>());
/// <summary>
/// True if this statement assigns to an output variable of the inference.
/// </summary>
public bool IsOutput;
/// <summary>
/// True if this statement always returns a uniform distribution or zero evidence value.
/// </summary>
public bool IsUniform;
/// <summary>
/// True if this statement must be updated whenever any dependency changes.
/// </summary>
public bool IsFresh;
/// <summary>
/// List of method arguments that the statement depends on.
/// </summary>
public Set<IParameterDeclaration> ParameterDependencies = new Set<IParameterDeclaration>(new IdentityComparer<IParameterDeclaration>());
public void AddOffsetIndices(IOffsetInfo offsetIndex, IStatement ist)
{
AddOffsetIndices(offsetIndexOf, offsetIndex, ist);
}
private static void AddOffsetIndices(Dictionary<IStatement, IOffsetInfo> offsetIndexOf, IOffsetInfo offsetIndex, IStatement ist)
{
IOffsetInfo offsetIndices;
if (!offsetIndexOf.TryGetValue(ist, out offsetIndices))
{
offsetIndexOf[ist] = offsetIndex;
}
else
{
OffsetInfo newOffsetInfo = new OffsetInfo();
newOffsetInfo.AddRange(offsetIndices);
newOffsetInfo.AddRange(offsetIndex);
offsetIndexOf[ist] = newOffsetInfo;
}
}
public bool HasAnyDependencyOfType(DependencyType type)
{
foreach (IStatement ist in GetDependenciesOfType(type))
return true;
return false;
}
public int Count(DependencyType type)
{
int count = 0;
foreach (IStatement ist in GetDependenciesOfType(type))
count++;
return count;
}
public bool HasDependency(DependencyType type, IStatement ist)
{
if (type == DependencyType.SkipIfUniform)
{
foreach (var dependencySt in SkipIfUniform)
{
if (ContainsStatement(dependencySt, ist))
{
return true;
}
}
return false;
}
else
{
DependencyType depType;
if (!dependencyTypeOf.TryGetValue(ist, out depType))
return false;
return (depType & type) > 0;
}
}
private bool ContainsStatement(IStatement dependencySt, IStatement ist)
{
if (ReferenceEquals(dependencySt, ist))
return true;
if (dependencySt is AnyStatement)
{
AnyStatement anySt = (AnyStatement)dependencySt;
bool found = false;
ForEachStatement(anySt, st =>
{
if (ContainsStatement(st, ist))
found = true;
});
return found;
}
if (dependencySt is IExpressionStatement && ist is IExpressionStatement)
{
IExpressionStatement es = (IExpressionStatement)dependencySt;
IExpressionStatement ies = (IExpressionStatement)ist;
return CodeBuilder.Instance.ContainsExpression(es.Expression, ies.Expression);
}
return false;
}
public IEnumerable<IStatement> GetDependenciesOfType(DependencyType type)
{
foreach (KeyValuePair<IStatement, DependencyType> entry in dependencyTypeOf)
{
if ((entry.Value & type) > 0)
yield return entry.Key;
}
}
public void Add(DependencyType type, IStatement ist)
{
Add(dependencyTypeOf, type, ist);
}
private static void Add(IDictionary<IStatement, DependencyType> dependencyTypeOf, DependencyType type, IStatement ist)
{
DependencyType depType;
dependencyTypeOf.TryGetValue(ist, out depType);
dependencyTypeOf[ist] = type | depType;
}
public void AddRange(DependencyType type, IEnumerable<IStatement> stmts)
{
foreach (IStatement ist in stmts)
Add(type, ist);
}
public void AddRange(IEnumerable<KeyValuePair<IStatement, DependencyType>> pairs)
{
foreach (KeyValuePair<IStatement, DependencyType> pair in pairs)
{
Add(pair.Value, pair.Key);
}
}
public void Remove(DependencyType type, IStatement stmt)
{
DependencyType depType;
if (dependencyTypeOf.TryGetValue(stmt, out depType))
{
depType = depType & (~type);
if (depType == 0)
dependencyTypeOf.Remove(stmt);
else
dependencyTypeOf[stmt] = depType;
}
}
public void Remove(DependencyType type)
{
Remove(type, _ => true);
}
public void Remove(DependencyType type, Predicate<IStatement> predicate)
{
List<IStatement> deps = new List<IStatement>();
deps.AddRange(GetDependenciesOfType(type));
foreach (IStatement ist in deps)
{
if (predicate(ist))
{
DependencyType depType = dependencyTypeOf[ist] & (~type);
if (depType == 0)
dependencyTypeOf.Remove(ist);
else
dependencyTypeOf[ist] = depType;
}
}
}
public void Remove(IStatement stmt)
{
dependencyTypeOf.Remove(stmt);
}
public void RemoveAll(IStatement stmt)
{
RemoveAll(ist => ReferenceEquals(ist, stmt));
}
public void RemoveAll(Predicate<IStatement> predicate)
{
Replace(ist => predicate(ist) ? null : ist);
}
public void Replace(IDictionary<IStatement, IStatement> replacements)
{
Replace(delegate (IStatement ist)
{
IStatement newStmt;
if (replacements.TryGetValue(ist, out newStmt))
return newStmt;
else
return ist;
});
}
public void Replace(Converter<IStatement, IStatement> converter)
{
List<IStatement> toRemove = new List<IStatement>();
Dictionary<IStatement, DependencyType> toAdd = new Dictionary<IStatement, DependencyType>(new IdentityComparer<IStatement>());
Dictionary<IStatement, IOffsetInfo> toAddOffset = new Dictionary<IStatement, IOffsetInfo>(new IdentityComparer<IStatement>());
foreach (KeyValuePair<IStatement, DependencyType> entry in dependencyTypeOf)
{
IStatement stmt = entry.Key;
IStatement newStmt = Replace(stmt, converter);
if (!ReferenceEquals(newStmt, stmt))
{
toRemove.Add(stmt);
if (newStmt != null)
{
DependencyType type = entry.Value;
IOffsetInfo offsetIndices;
offsetIndexOf.TryGetValue(stmt, out offsetIndices);
if (newStmt is AnyStatement)
{
AnyStatement anySt = (AnyStatement)newStmt;
DependencyType anyTypes = DependencyType.Requirement | DependencyType.SkipIfUniform;
DependencyType otherType = type & ~anyTypes;
if (otherType > 0)
{
// must split Any for these types
ForEachStatement(anySt, ist =>
{
Add(toAdd, otherType, ist);
if (offsetIndices != default(OffsetInfo))
AddOffsetIndices(toAddOffset, offsetIndices, ist);
});
type &= anyTypes;
}
}
if (type > 0)
{
Add(toAdd, type, newStmt);
if (offsetIndices != default(OffsetInfo))
AddOffsetIndices(toAddOffset, offsetIndices, newStmt);
}
}
}
}
foreach (IStatement ist in toRemove)
{
dependencyTypeOf.Remove(ist);
offsetIndexOf.Remove(ist);
}
foreach (KeyValuePair<IStatement, DependencyType> entry in toAdd)
{
Add(entry.Value, entry.Key);
}
foreach (KeyValuePair<IStatement, IOffsetInfo> entry in toAddOffset)
{
AddOffsetIndices(entry.Value, entry.Key);
}
}
private static IStatement Replace(IStatement stmt, Converter<IStatement, IStatement> converter)
{
IStatement newStmt = converter(stmt);
if (!ReferenceEquals(newStmt, stmt))
return newStmt;
else if (stmt is AnyStatement)
{
AnyStatement anySt = (AnyStatement)stmt;
AnyStatement newAnySt = new AnyStatement();
bool replaced = false;
foreach (IStatement ist in anySt.Statements)
{
newStmt = Replace(ist, converter);
if (!ReferenceEquals(newStmt, ist))
replaced = true;
if (newStmt != null)
{
// flatten nested Any statements
if (newStmt is AnyStatement)
newAnySt.Statements.AddRange(((AnyStatement)newStmt).Statements);
else
newAnySt.Statements.Add(newStmt);
}
}
if (replaced)
{
if (newAnySt.Statements.Count == 0)
return null;
else
return newAnySt;
}
else
return stmt;
}
else
return stmt;
}
/// <summary>
/// Change a dependency on a statement to also depend on its clones (in the same way).
/// </summary>
/// <param name="clonesOfStatement">Provides the clones of each statement that has clones. May be empty.</param>
public void AddClones(IDictionary<IStatement, IEnumerable<IStatement>> clonesOfStatement)
{
AddClones(delegate (IStatement ist)
{
IEnumerable<IStatement> clones;
if (clonesOfStatement.TryGetValue(ist, out clones))
return clones;
else
return null;
});
}
public void AddClones(Converter<IStatement, IEnumerable<IStatement>> getClones)
{
List<IStatement> toRemove = new List<IStatement>();
Dictionary<IStatement, DependencyType> toAdd = new Dictionary<IStatement, DependencyType>(new IdentityComparer<IStatement>());
Dictionary<IStatement, IOffsetInfo> toAddOffset = new Dictionary<IStatement, IOffsetInfo>(new IdentityComparer<IStatement>());
foreach (KeyValuePair<IStatement, DependencyType> entry in dependencyTypeOf)
{
IStatement stmt = entry.Key;
if (stmt is AnyStatement)
{
AnyStatement anySt = (AnyStatement)stmt;
AnyStatement newAnySt = new AnyStatement();
bool changed = false;
foreach (IStatement ist in anySt.Statements)
{
newAnySt.Statements.Add(ist);
var clones = getClones(ist);
if (clones != null)
{
changed = true;
// flatten nested Any statements
newAnySt.Statements.AddRange(clones);
}
}
if (changed)
{
toRemove.Add(stmt);
toAdd.Add(newAnySt, entry.Value);
}
}
else
{
var clones = getClones(stmt);
if (clones != null)
{
DependencyType type = entry.Value;
IOffsetInfo offsetIndices;
offsetIndexOf.TryGetValue(stmt, out offsetIndices);
if (type > 0)
{
foreach (var clone in clones)
{
Add(toAdd, type, clone);
if (offsetIndices != default(OffsetInfo))
AddOffsetIndices(toAddOffset, offsetIndices, clone);
}
}
}
}
}
foreach (IStatement ist in toRemove)
{
dependencyTypeOf.Remove(ist);
offsetIndexOf.Remove(ist);
}
foreach (KeyValuePair<IStatement, DependencyType> entry in toAdd)
{
Add(entry.Value, entry.Key);
}
foreach (KeyValuePair<IStatement, IOffsetInfo> entry in toAddOffset)
{
AddOffsetIndices(entry.Value, entry.Key);
}
}
private void ForEachStatement(AnyStatement anySt, Action<IStatement> action)
{
foreach (IStatement ist in anySt.Statements)
{
if (ist is AnyStatement)
ForEachStatement((AnyStatement)ist, action);
else
action(ist);
}
}
/// <summary>
/// Statements that modify variables used in this statement. Excludes initializers and allocations.
/// </summary>
public IEnumerable<IStatement> Dependencies
{
get
{
return GetDependenciesOfType(DependencyType.Dependency);
}
}
/// <summary>
/// Statements that allocate (or in some cases initialize) variables used in this statement.
/// </summary>
/// <remarks>
/// DeclDependencies and Dependencies must be disjoint.
/// </remarks>
public IEnumerable<IStatement> DeclDependencies
{
get
{
return GetDependenciesOfType(DependencyType.Declaration);
}
}
/// <summary>
/// Statements which determine whether or not this statement executes, or what its target is.
/// </summary>
public IEnumerable<IStatement> ContainerDependencies
{
get
{
return GetDependenciesOfType(DependencyType.Container);
}
}
/// <summary>
/// Statements which must be up-to-date before executing this statement.
/// </summary>
public IEnumerable<IStatement> FreshDependencies
{
get
{
return GetDependenciesOfType(DependencyType.Fresh);
}
}
/// <summary>
/// Statements that must be executed before this statement.
/// </summary>
/// <remarks>
/// AnyStatements can be used to create disjunctive requirements, e.g. "either A or B must execute before this statement".
/// </remarks>
public IEnumerable<IStatement> Requirements
{
get
{
return GetDependenciesOfType(DependencyType.Requirement);
}
}
/// <summary>
/// Statements that must be executed before this statement, and must return a non-uniform result.
/// </summary>
/// <remarks>
/// AnyStatements can be used to create disjunctive requirements, e.g. "either A or B must execute before this statement".
/// </remarks>
public IEnumerable<IStatement> SkipIfUniform
{
get
{
return GetDependenciesOfType(DependencyType.SkipIfUniform);
}
}
/// <summary>
/// Gets statements that modify or allocate variables that this statement mutates.
/// </summary>
public IEnumerable<IStatement> Overwrites
{
get
{
return GetDependenciesOfType(DependencyType.Overwrite);
}
}
/// <summary>
/// Statements whose execution invalidates the result of this statement.
/// </summary>
public IEnumerable<IStatement> Triggers
{
get
{
return GetDependenciesOfType(DependencyType.Trigger);
}
}
public object Clone()
{
DependencyInformation that = new DependencyInformation();
that.dependencyTypeOf = Clone(dependencyTypeOf);
// the OffsetInfo values inside the offsetIndexOf dictionary are not cloned.
that.offsetIndexOf = Clone(offsetIndexOf);
that.IsOutput = IsOutput;
that.IsUniform = IsUniform;
that.IsFresh = IsFresh;
that.ParameterDependencies = Clone(ParameterDependencies);
return that;
}
private List<T> Clone<T>(List<T> list)
{
List<T> result = new List<T>();
result.AddRange(list);
return result;
}
private T Clone<T>(T set)
where T : ICloneable
{
return (T)set.Clone();
}
private Dictionary<TKey, TValue> Clone<TKey, TValue>(Dictionary<TKey, TValue> that)
{
Dictionary<TKey, TValue> result = new Dictionary<TKey, TValue>(that.Comparer);
foreach (KeyValuePair<TKey, TValue> entry in that)
{
result[entry.Key] = entry.Value;
}
return result;
}
private SortedDictionary<TKey, TValue> Clone<TKey, TValue>(SortedDictionary<TKey, TValue> that)
{
SortedDictionary<TKey, TValue> result = new SortedDictionary<TKey, TValue>(that.Comparer);
foreach (KeyValuePair<TKey, TValue> entry in that)
{
result[entry.Key] = entry.Value;
}
return result;
}
public override string ToString()
{
StringBuilder sb = new StringBuilder();
sb.AppendLine("Dependency information:");
sb.AppendLine(" IsOutput=" + IsOutput + " DependsOnParameters=" + StringUtil.CollectionToString(ParameterDependencies, ","));
sb.AppendLine(" IsUniform=" + IsUniform + " IsFresh=" + IsFresh);
sb.AppendLine(StringUtil.JoinColumns(" ContainerDependencies=", StringUtil.VerboseToString(ContainerDependencies)));
sb.AppendLine(StringUtil.JoinColumns(" DeclDependencies=", StringUtil.VerboseToString(DeclDependencies)));
sb.AppendLine(StringUtil.JoinColumns(" Dependencies=", StringUtil.VerboseToString(Dependencies)));
sb.AppendLine(StringUtil.JoinColumns(" Requirements=", StringUtil.VerboseToString(Requirements)));
sb.AppendLine(StringUtil.JoinColumns(" SkipIfUniform=", StringUtil.VerboseToString(SkipIfUniform)));
sb.AppendLine(StringUtil.JoinColumns(" Triggers=", StringUtil.VerboseToString(Triggers)));
sb.AppendLine(StringUtil.JoinColumns(" FreshDependencies=", StringUtil.VerboseToString(FreshDependencies)));
sb.AppendLine(StringUtil.JoinColumns(" Overwrites=", StringUtil.ToString(Overwrites)));
if (HasAnyDependencyOfType(DependencyType.Cancels))
sb.AppendLine(StringUtil.JoinColumns(" Cancels=", StringUtil.ToString(GetDependenciesOfType(DependencyType.Cancels))));
if (HasAnyDependencyOfType(DependencyType.NoInit))
sb.AppendLine(StringUtil.JoinColumns(" NoInit=", StringUtil.ToString(GetDependenciesOfType(DependencyType.NoInit))));
if (HasAnyDependencyOfType(DependencyType.Diode))
sb.AppendLine(StringUtil.JoinColumns(" Diode=", StringUtil.ToString(GetDependenciesOfType(DependencyType.Diode))));
if (offsetIndexOf.Count > 0)
{
StringBuilder sb2 = new StringBuilder();
int count = 0;
foreach (KeyValuePair<IStatement, IOffsetInfo> entry in offsetIndexOf)
{
if (count > 0)
sb2.AppendLine();
sb2.Append("[");
sb2.Append(count++);
sb2.Append("] ");
IStatement ist = entry.Key;
sb2.Append(entry.Value);
sb2.Append(" ");
sb2.Append(ist);
}
sb.AppendLine(StringUtil.JoinColumns(" OffsetIndices=", sb2.ToString()));
}
return sb.ToString();
}
protected string ToString(IList<IStatement> ls)
{
StringBuilder sb = new StringBuilder("[");
foreach (IStatement st in ls)
sb.AppendLine(st.ToString());
sb.AppendLine("]");
return sb.ToString();
}
#if false
public override bool Equals(object obj)
{
DependencyInformation di = obj as DependencyInformation;
if (di == null) return false;
if (di.IsUniform != IsUniform) return false;
if (ParameterDependencies != di.ParameterDependencies) return false;
if (di.IsOutput != IsOutput) return false;
if (!SetsAreEqual(Dependencies, di.Dependencies)) return false;
// if (!SetEquals(DeclDependencies,di.DeclDependencies)) return false;
if (!SetsAreEqual(Requirements, di.Requirements)) return false;
//if (!SetEquals(RequiredNumberSet,di.RequiredNumberSet)) return false;
if (!SetsAreEqual(Triggers, di.Triggers)) return false;
if (!SetsAreEqual(FreshDependencies, di.FreshDependencies)) return false;
if (!SetsAreEqual(Initializers, di.Initializers)) return false;
return true;
}
public override int GetHashCode()
{
int hash = Hash.Start;
hash = Hash.Combine(hash, IsUniform.GetHashCode());
hash = Hash.Combine(hash, ParameterDependencies.GetHashCode());
hash = Hash.Combine(hash, IsOutput.GetHashCode());
hash = Hash.Combine(hash, Enumerable.GetHashCodeAsSet(Dependencies));
hash = Hash.Combine(hash, Enumerable.GetHashCodeAsSet(Requirements));
hash = Hash.Combine(hash, Enumerable.GetHashCodeAsSet(Triggers));
hash = Hash.Combine(hash, Enumerable.GetHashCodeAsSet(FreshDependencies));
hash = Hash.Combine(hash, Enumerable.GetHashCodeAsSet(Initializers));
return hash;
}
protected static bool SetsAreEqual(List<IStatement> l1, List<IStatement> l2)
{
if (l1.Count != l2.Count) return false;
foreach (IStatement ist in l1) {
if (!l2.Contains(ist)) return false;
}
return true;
}
#endif
// Compares statements by lexicographic order
internal class StatementComparer : IComparer<IStatement>
{
public int Compare(IStatement x, IStatement y)
{
return String.Compare(x.ToString(), y.ToString(), StringComparison.InvariantCulture);
}
}
}
[Flags]
internal enum DependencyType
{
/// <summary>
/// Statements that modify variables read by this statement, i.e. read-after-write dependencies. Excludes allocations.
/// </summary>
Dependency = 1,
/// <summary>
/// Statements that must be executed before this statement.
/// </summary>
/// <remarks>
/// AnyStatements can be used to create disjunctive requirements, e.g. "either A or B must execute before this statement".
/// </remarks>
Requirement = 2,
/// <summary>
/// Statements that must be executed before this statement, and must return a non-uniform result.
/// </summary>
/// <remarks>
/// AnyStatements can be used to create disjunctive requirements, e.g. "either A or B must execute before this statement".
/// </remarks>
SkipIfUniform = 4,
/// <summary>
/// Statements whose execution invalidates the result of this statement.
/// </summary>
Trigger = 8,
/// <summary>
/// Statements which must be up-to-date before executing this statement.
/// </summary>
Fresh = 16,
/// <summary>
/// Statements that allocate variables read by this statement.
/// </summary>
/// <remarks>
/// DeclDependencies and Dependencies must be disjoint.
/// </remarks>
Declaration = 32,
/// <summary>
/// Statements that modify variables used in the containers of this statement.
/// </summary>
Container = 64,
/// <summary>
/// Statements that modify or allocate variables that this statement modifies.
/// </summary>
Overwrite = 128,
Cancels = 256,
NoInit = 512,
Diode = 1024,
All = 2047
};
internal class AllTriggersAttribute : ICompilerAttribute
{
}
internal class Offset
{
public readonly IVariableDeclaration loopVar;
public readonly int offset;
public readonly bool isAvailable;
public Offset(IVariableDeclaration loopVar, int offset, bool isAvailable)
{
this.loopVar = loopVar;
this.offset = offset;
this.isAvailable = isAvailable;
}
}
/// <summary>
/// Read-only interface to OffsetInfo.
/// </summary>
internal interface IOffsetInfo : IEnumerable<Offset>
{
bool ContainsKey(IVariableDeclaration ivd);
}
internal class OffsetInfo : ICollection<Offset>, IOffsetInfo
{
HashSet<Offset> offsetOfVar;
public IEnumerator<Offset> GetEnumerator()
{
if (offsetOfVar == null)
return new HashSet<Offset>().GetEnumerator();
return offsetOfVar.GetEnumerator();
}
System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
public bool ContainsKey(IVariableDeclaration ivd)
{
if (offsetOfVar == null)
return false;
foreach (var entry in this)
{
if (entry.loopVar.Equals(ivd))
return true;
}
return false;
}
public override string ToString()
{
StringBuilder sb = new StringBuilder();
sb.Append("OffsetInfo");
if (offsetOfVar != null)
{
foreach (var entry in offsetOfVar)
{
sb.Append("(");
sb.Append(entry.loopVar.Name);
sb.Append(",");
sb.Append(entry.offset);
sb.Append(",");
sb.Append(entry.isAvailable);
sb.Append(")");
}
}
return sb.ToString();
}
/// <summary>
/// Add an offset dependency
/// </summary>
/// <param name="ivd">Loop counter</param>
/// <param name="offset">(loop counter of write) - (loop counter of read)</param>
/// <param name="isAvailable">True if the first affected element will be mutated at an earlier loop iteration, based on the loop direction</param>
public void Add(IVariableDeclaration ivd, int offset, bool isAvailable)
{
Add(new Offset(ivd, offset, isAvailable));
}
public void Add(Offset item)
{
if (offsetOfVar == null)
offsetOfVar = new HashSet<Offset>();
offsetOfVar.Add(item);
}
public void Clear()
{
if (offsetOfVar != null)
offsetOfVar.Clear();
}
public bool Contains(Offset item)
{
throw new NotImplementedException();
}
public void CopyTo(Offset[] array, int arrayIndex)
{
throw new NotImplementedException();
}
public int Count
{
get
{
return (offsetOfVar == null) ? 0 : offsetOfVar.Count;
}
}
public bool IsReadOnly
{
get
{
return false;
}
}
public bool Remove(Offset item)
{
throw new NotImplementedException();
}
}
#if SUPPRESS_XMLDOC_WARNINGS
#pragma warning restore 1591
#endif
}

Просмотреть файл

@ -0,0 +1,28 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using Microsoft.ML.Probabilistic.Compiler;
namespace Microsoft.ML.Probabilistic.Compiler.Attributes
{
/// <summary>
/// Specifies a description for whatever the attribute is attached to.
/// </summary>
internal class DescriptionAttribute : ICompilerAttribute
{
/// <summary>
/// The description for the attribute.
/// </summary>
public string Description { get; private set; }
/// <summary>
/// Initializes a new instance of the <see cref="DescriptionAttribute"/> class.
/// </summary>
/// <param name="description">The description for the attribute</param>
public DescriptionAttribute(string description)
{
this.Description = description;
}
}
}

Просмотреть файл

@ -0,0 +1,109 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using Microsoft.ML.Probabilistic.Compiler;
using Microsoft.ML.Probabilistic.Utilities;
using Microsoft.ML.Probabilistic.Compiler.CodeModel;
namespace Microsoft.ML.Probabilistic.Compiler.Attributes
{
#if SUPPRESS_XMLDOC_WARNINGS
#pragma warning disable 1591
#endif
/// <summary>
/// Represents a loop context i.e. the set of loops that an expression occurs in.
/// </summary>
internal class LoopContext : ICompilerAttribute
{
/// <summary>
/// Helps recognize code patterns
/// </summary>
private static CodeRecognizer Recognizer = CodeRecognizer.Instance;
/// <summary>
/// Loops that contain the expression, outermost first.
/// </summary>
internal List<IForStatement> loops; // = new List<IForStatement>();
/// <summary>
/// The loop variables for all contained loops
/// </summary>
internal List<IVariableDeclaration> loopVariables = new List<IVariableDeclaration>();
/// <summary>
/// Creates a loop context, given the current transform context.
/// </summary>
internal LoopContext(BasicTransformContext context) : this(context.FindAncestors<IForStatement>())
{
}
/// <summary>
/// Creates a loop context, given the current transform context.
/// </summary>
internal LoopContext(List<IForStatement> loops)
{
this.loops = loops;
foreach (IForStatement loop in loops)
{
loopVariables.Add(Recognizer.LoopVariable(loop));
}
}
/// <summary>
/// Gets the reference loop context for a reference to a local variable. A reference loop context
/// is the set of loops that a variable reference occurs in, less any loops that the variable declaration
/// occurred in.
/// </summary>
/// <returns></returns>
internal RefLoopContext GetReferenceLoopContext(BasicTransformContext context)
{
List<IForStatement> loops = context.FindAncestors<IForStatement>();
RefLoopContext rlc = new RefLoopContext();
try
{
foreach (IForStatement loop in loops)
{
IVariableDeclaration loopVar = Recognizer.LoopVariable(loop);
if (loopVariables.Contains(loopVar)) continue;
rlc.loopVariables.Add(loopVar);
rlc.loops.Add(loop);
}
}
catch (Exception ex)
{
context.Error("Could not get loop index variables", ex);
}
return rlc;
}
public override string ToString()
{
return "LoopContext" + Util.CollectionToString(loopVariables);
}
}
/// <summary>
/// Represents a reference loop context i.e. the set of loops that a variable reference
/// occurs in, less any loops that the variable declaration occurred in.
/// </summary>
internal class RefLoopContext
{
// Loops that the reference is in that the declaration isn't
internal List<IForStatement> loops = new List<IForStatement>();
// Indices for the above loops
internal List<IVariableDeclaration> loopVariables = new List<IVariableDeclaration>();
public override string ToString()
{
return "RefLoopContext" + Util.CollectionToString(loopVariables);
}
}
#if SUPPRESS_XMLDOC_WARNINGS
#pragma warning restore 1591
#endif
}

Просмотреть файл

@ -0,0 +1,283 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Probabilistic.Compiler.Attributes
{
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Compiler.Graphs;
using Microsoft.ML.Probabilistic.Compiler.CodeModel;
using Microsoft.ML.Probabilistic.Compiler;
using NodeIndex = System.Int32;
using EdgeIndex = System.Int32;
using Collections;
using Utilities;
#if SUPPRESS_XMLDOC_WARNINGS
#pragma warning disable 1591
#endif
internal class LoopMergingInfo : ICompilerAttribute
{
/// <summary>
/// Maps statements into graph node numbers
/// </summary>
private Dictionary<IStatement, NodeIndex> indexOf = new Dictionary<IStatement, NodeIndex>(new IdentityComparer<IStatement>());
/// <summary>
/// A graph where nodes are statements and edges indicate that loop merging is prohibited between them.
/// </summary>
public IndexedGraph graph;
/// <summary>
/// For each graph edge, stores the loop variables on which merging is prohibited.
/// </summary>
private IndexedProperty<EdgeIndex, ICollection<IVariableDeclaration>> prohibitedLoopVars;
/// <summary>
/// For each graph edge, stores a pairing of loop variables and their offsets.
/// </summary>
public IndexedProperty<EdgeIndex, IOffsetInfo> offsetInfos;
public LoopMergingInfo(IList<IStatement> stmts)
{
graph = new IndexedGraph();
foreach (var stmt in stmts)
{
indexOf[stmt] = graph.AddNode();
}
prohibitedLoopVars = graph.CreateEdgeData<ICollection<IVariableDeclaration>>();
offsetInfos = graph.CreateEdgeData<IOffsetInfo>();
}
/// <summary>
/// Get the index of a top-level statement.
/// </summary>
/// <param name="statement">A top-level statement.</param>
/// <returns></returns>
public NodeIndex GetIndexOf(IStatement statement)
{
return indexOf[statement];
}
/// <summary>
/// Add a statement that shares all conflicts with an existing node.
/// </summary>
/// <param name="statement"></param>
/// <param name="node"></param>
public void AddEquivalentStatement(IStatement statement, NodeIndex node)
{
indexOf[statement] = node;
}
/// <summary>
/// Add a new statement with no conflicts.
/// </summary>
/// <param name="statement"></param>
/// <returns></returns>
public NodeIndex AddNode(IStatement statement)
{
NodeIndex node = graph.AddNode();
indexOf[statement] = node;
return node;
}
public void InheritSourceConflicts(NodeIndex newNode, NodeIndex oldNode)
{
foreach (EdgeIndex edge in graph.EdgesInto(oldNode).ToArray())
{
int source = graph.SourceOf(edge);
EdgeIndex edge2 = graph.AddEdge(source, newNode);
prohibitedLoopVars[edge2] = prohibitedLoopVars[edge];
offsetInfos[edge2] = offsetInfos[edge];
}
}
public void InheritTargetConflicts(NodeIndex newNode, NodeIndex oldNode)
{
foreach (EdgeIndex edge in graph.EdgesOutOf(oldNode).ToArray())
{
int source = graph.SourceOf(edge);
EdgeIndex edge2 = graph.AddEdge(source, newNode);
prohibitedLoopVars[edge2] = prohibitedLoopVars[edge];
offsetInfos[edge2] = offsetInfos[edge];
}
}
/// <summary>
/// Get the index of the statement that prevents loop merging, or -1 if none
/// </summary>
/// <param name="stmts"></param>
/// <param name="stmtIndex"></param>
/// <param name="loopVar"></param>
/// <param name="isForwardLoop"></param>
/// <returns></returns>
public int GetConflictingStmt(Set<int> stmts, int stmtIndex, IVariableDeclaration loopVar, bool isForwardLoop)
{
foreach (EdgeIndex edge in graph.EdgesInto(stmtIndex))
{
int source = graph.SourceOf(edge);
if (stmts.Contains(source) && IsProhibited(edge, loopVar, isForwardLoop, true))
return source;
}
foreach (EdgeIndex edge in graph.EdgesOutOf(stmtIndex))
{
int target = graph.TargetOf(edge);
if (stmts.Contains(target) && IsProhibited(edge, loopVar, isForwardLoop, false))
return target;
}
return -1;
}
private bool IsProhibited(int edge, IVariableDeclaration loopVar, bool isForwardLoop, bool isForwardEdge)
{
ICollection<IVariableDeclaration> prohibited = prohibitedLoopVars[edge];
if (prohibited != null && prohibited.Contains(loopVar))
return true;
IOffsetInfo offsetInfo = offsetInfos[edge];
if (offsetInfo != null)
{
foreach (var entry in offsetInfo)
{
if (entry.loopVar == loopVar)
{
int offset = entry.offset;
if ((offset > 0) && isForwardLoop && isForwardEdge)
return true;
if ((offset < 0) && !isForwardLoop && isForwardEdge)
return true;
}
}
}
return false;
}
public void PreventLoopMerging(IStatement mutated, IStatement affected, ICollection<IVariableDeclaration> loopVars)
{
int edge;
int source = indexOf[mutated];
int target = indexOf[affected];
if (!graph.TryGetEdge(source, target, out edge))
{
edge = graph.AddEdge(source, target);
prohibitedLoopVars[edge] = loopVars;
}
else
{
ICollection<IVariableDeclaration> list = prohibitedLoopVars[edge];
if (list == null)
{
list = new Set<IVariableDeclaration>(new IdentityComparer<IVariableDeclaration>());
prohibitedLoopVars[edge] = list;
}
list.AddRange(loopVars);
}
}
public void SetOffsetInfo(IStatement mutated, IStatement affected, IOffsetInfo offsetInfo)
{
int edge;
int source = indexOf[mutated];
int target = indexOf[affected];
if (!graph.TryGetEdge(source, target, out edge))
{
edge = graph.AddEdge(source, target);
prohibitedLoopVars[edge] = null;
offsetInfos[edge] = offsetInfo;
}
else
{
offsetInfos[edge] = offsetInfo;
}
}
public IStatement GetStatement(int index)
{
foreach (KeyValuePair<IStatement, int> entry in indexOf)
{
if (entry.Value == index)
return entry.Key;
}
return null;
}
public string VerboseToString()
{
StringBuilder sb = new StringBuilder();
for (int edge = 0; edge < graph.EdgeCount(); edge++)
{
ICollection<IVariableDeclaration> list = prohibitedLoopVars[edge];
string indexString = (list == null) ? "" : list.Aggregate("", (s, ivd) => (s + ivd.Name + " "));
string stmtString = GetStatement(graph.SourceOf(edge)).ToString() + GetStatement(graph.TargetOf(edge)).ToString();
sb.AppendLine(StringUtil.JoinColumns(indexString, stmtString));
}
return sb.ToString();
}
public DebugInfo GetDebugInfo(ICodeTransform transform)
{
CodeBuilder Builder = CodeBuilder.Instance;
IBlockStatement block = Builder.BlockStmt();
bool includeStatementNumbers = false;
if (includeStatementNumbers)
{
List<List<IStatement>> stmts = new List<List<IStatement>>();
foreach (var entry in indexOf)
{
IStatement ist = entry.Key;
int index = entry.Value;
while (stmts.Count <= index)
stmts.Add(new List<IStatement>());
stmts[index].Add(ist);
}
for (int i = 0; i < stmts.Count; i++)
{
block.Statements.Add(Builder.CommentStmt(i.ToString()));
foreach (var ist in stmts[i])
{
block.Statements.Add(ist);
}
}
}
for (int edge = 0; edge < graph.EdgeCount(); edge++)
{
ICollection<IVariableDeclaration> list = prohibitedLoopVars[edge];
IBlockStatement body = Builder.BlockStmt();
body.Statements.Add(GetStatement(graph.SourceOf(edge)));
body.Statements.Add(GetStatement(graph.TargetOf(edge)));
if (list != null)
{
foreach (IVariableDeclaration ivd in list)
{
IForEachStatement ifes = Builder.ForEachStmt();
ifes.Variable = ivd;
ifes.Expression = null;
ifes.Body = body;
body = Builder.BlockStmt();
body.Statements.Add(ifes);
}
}
if (body.Statements.Count == 1)
block.Statements.Add(body.Statements[0]);
else
block.Statements.Add(body);
}
DebugInfo info = new DebugInfo();
info.Transform = transform;
info.Name = "LoopMergingInfo";
info.Value = block;
return info;
}
public override string ToString()
{
return "LoopMergingInfo";
}
}
#if SUPPRESS_XMLDOC_WARNINGS
#pragma warning restore 1591
#endif
}

Просмотреть файл

@ -0,0 +1,12 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using Microsoft.ML.Probabilistic.Compiler;
namespace Microsoft.ML.Probabilistic.Compiler.Attributes
{
internal class MultiplyAllCompilerAttribute : ICompilerAttribute
{
}
}

Просмотреть файл

@ -0,0 +1,19 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using Microsoft.ML.Probabilistic.Factors.Attributes;
using Microsoft.ML.Probabilistic.Compiler;
namespace Microsoft.ML.Probabilistic.Compiler.Attributes
{
internal class QualityBandCompilerAttribute : ICompilerAttribute
{
public QualityBandCompilerAttribute(QualityBand qualityBand)
{
this.QualityBand = qualityBand;
}
public QualityBand QualityBand { get; private set; }
}
}

Просмотреть файл

@ -0,0 +1,18 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using Microsoft.ML.Probabilistic.Compiler;
namespace Microsoft.ML.Probabilistic.Compiler.Attributes
{
internal class QueryTypeCompilerAttribute : ICompilerAttribute
{
public QueryTypeCompilerAttribute(QueryType queryType)
{
this.QueryType = queryType;
}
public QueryType QueryType { get; private set; }
}
}

Просмотреть файл

@ -0,0 +1,109 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Collections.Generic;
using Microsoft.ML.Probabilistic.Compiler;
using Microsoft.ML.Probabilistic.Utilities;
using Microsoft.ML.Probabilistic.Compiler.CodeModel;
namespace Microsoft.ML.Probabilistic.Compiler.Attributes
{
#if SUPPRESS_XMLDOC_WARNINGS
#pragma warning disable 1591
#endif
/// <summary>
/// Represents a repeat context i.e. the set of repeat blocks that an expression occurs in.
/// </summary>
internal class RepeatContext : ICompilerAttribute
{
/// <summary>
/// Helps recognize code patterns
/// </summary>
private static CodeRecognizer Recognizer = CodeRecognizer.Instance;
/// <summary>
/// Repeat blocks that contain the expression, outermost first.
/// </summary>
internal List<IRepeatStatement> repeats; // = new List<IForStatement>();
/// <summary>
/// The repeat counts for all contained repeat blocks
/// </summary>
internal List<IExpression> repeatCounts = new List<IExpression>();
/// <summary>
/// Creates a repeat context, given the current transform context.
/// </summary>
internal RepeatContext(BasicTransformContext context) : this(context.FindAncestors<IRepeatStatement>())
{
}
/// <summary>
/// Creates a repeat context, given the current transform context.
/// </summary>
internal RepeatContext(List<IRepeatStatement> repeats)
{
this.repeats = repeats;
foreach (IRepeatStatement rep in repeats)
{
repeatCounts.Add(rep.Count);
}
}
/// <summary>
/// Gets the reference loop context for a reference to a local variable. A reference loop context
/// is the set of loops that a variable reference occurs in, less any loops that the variable declaration
/// occurred in.
/// </summary>
/// <returns></returns>
internal RefRepeatContext GetReferenceRepeatContext(BasicTransformContext context)
{
List<IRepeatStatement> reps = context.FindAncestors<IRepeatStatement>();
// Make a cloned list of repeat counts and remove when found
var rcs = new List<IExpression>(repeatCounts);
RefRepeatContext rlc = new RefRepeatContext();
foreach (IRepeatStatement rep in reps)
{
IExpression repCount = rep.Count;
int k = rcs.IndexOf(repCount);
if (k != -1)
{
rcs.RemoveAt(k); // remove this count so we don't use it again.
continue;
}
rlc.repeatCounts.Add(repCount);
rlc.repeats.Add(rep);
}
return rlc;
}
public override string ToString()
{
return "RepeatContext" + Util.CollectionToString(repeatCounts);
}
}
/// <summary>
/// Represents a reference loop context i.e. the set of loops that a variable reference
/// occurs in, less any loops that the variable declaration occurred in.
/// </summary>
internal class RefRepeatContext
{
// Repeat blocks that the reference is in that the declaration isn't
internal List<IRepeatStatement> repeats = new List<IRepeatStatement>();
// Counts for the above repeat blocks
internal List<IExpression> repeatCounts = new List<IExpression>();
public override string ToString()
{
return "RefLoopContext" + Util.CollectionToString(repeatCounts);
}
}
#if SUPPRESS_XMLDOC_WARNINGS
#pragma warning restore 1591
#endif
}

Просмотреть файл

@ -0,0 +1,825 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Text;
using Microsoft.ML.Probabilistic.Compiler.Transforms;
using Microsoft.ML.Probabilistic.Compiler;
using Microsoft.ML.Probabilistic.Distributions;
using Microsoft.ML.Probabilistic.Utilities;
using Microsoft.ML.Probabilistic.Compiler.CodeModel;
using Microsoft.ML.Probabilistic.Models.Attributes;
namespace Microsoft.ML.Probabilistic.Compiler.Attributes
{
#if SUPPRESS_XMLDOC_WARNINGS
#pragma warning disable 1591
#endif
/// <summary>
/// Describes a variable in MSL (random, constant, or loop variable)
/// </summary>
/// <remarks>
/// </remarks>
internal class VariableInformation : ICompilerAttribute
{
/// <summary>
/// Helps build class declarations
/// </summary>
private static readonly CodeBuilder Builder = CodeBuilder.Instance;
/// <summary>
/// Helps recognize code patterns
/// </summary>
private static CodeRecognizer Recognizer = CodeRecognizer.Instance;
/// <summary>
/// Stores the lengths that were used to define an array in MSL.
/// </summary>
internal IList<IExpression[]> sizes = new List<IExpression[]>();
/// <summary>
/// For jagged arrays, the index variables used in the loops corresponding to the above sizes.
/// </summary>
/// <remarks>
/// May contain null elements for indices that are not variables.
/// </remarks>
internal IList<IVariableDeclaration[]> indexVars = new List<IVariableDeclaration[]>();
/// <summary>
/// True if this is a stochastic variable. False for constants and loop variables.
/// </summary>
public bool IsStochastic;
// Marginal prototype (this is the prototype of an element, if this is an array)
internal IExpression marginalPrototypeExpression;
//internal Type marginalType;
internal readonly IVariableDeclaration declaration;
public string Name
{
get { return declaration.Name; }
}
public IType VariableType
{
get { return declaration.VariableType; }
}
/// <summary>
/// A cache of ToType(VariableType)
/// </summary>
internal readonly Type varType;
private Type innermostElementType;
internal Type InnermostElementType
{
get
{
if (innermostElementType == null)
{
innermostElementType = varType;
for (int bracket = 0; bracket < sizes.Count; bracket++)
{
innermostElementType = Util.GetElementType(innermostElementType);
}
}
return innermostElementType;
}
}
/// <summary>
/// The number of the indexing bracket which always appears indexed by literals, or 0 if none
/// </summary>
public int LiteralIndexingDepth;
/// <summary>
/// jagged array depth of varType
/// </summary>
private readonly int arrayDepth;
/// <summary>
/// Returns the array depth of this variable. This is the number of pairs of square brackets needed
/// to fully index the variable i.e. 0 for a non-array, 1 for x[] or x[,], 2 for x[][] or x[,][] or x[,][,] etc.
/// </summary>
public int ArrayDepth
{
get { return arrayDepth; }
}
public VariableInformation(IVariableDeclaration declaration)
{
this.declaration = declaration;
varType = Builder.ToType(declaration.VariableType);
var elementType = varType;
arrayDepth = 0;
while (elementType.IsArray)
{
elementType = elementType.GetElementType();
arrayDepth++;
}
}
public void SetSizesAtDepth(int depth, IExpression[] lengths)
{
if (sizes.Count > depth) throw new NotSupportedException("Attempt to redefine sizes at depth " + depth + ".");
if (sizes.Count < depth) throw new InferCompilerException("Attempt to set sizes at depth " + depth + " before depth " + (depth - 1) + ".");
sizes.Add(lengths);
}
/// <summary>
/// Provide missing index variables.
/// </summary>
/// <param name="depth">Bracket depth (0 is first bracket)</param>
/// <param name="vars">May contain null entries for indices that are not variables.</param>
/// <param name="allowMismatch"></param>
public void SetIndexVariablesAtDepth(int depth, IVariableDeclaration[] vars, bool allowMismatch = false)
{
if (indexVars.Count > depth)
{
for (int i = 0; i < vars.Length; i++)
{
if (vars[i] != null)
{
if (indexVars[depth][i] == null)
{
indexVars[depth][i] = vars[i];
}
else if (vars[i] != indexVars[depth][i] && !allowMismatch)
{
throw new ArgumentException("Invalid definition of array '" + this.Name + "'. Variable '" + vars[i].Name +
"' cannot be used as an index on the left hand side. Must use '" + indexVars[depth][i].Name + "'.");
}
}
}
return;
}
else if (indexVars.Count < depth) throw new InferCompilerException("Attempt to set index var at depth " + depth + " before depth " + (depth - 1) + ".");
else indexVars.Add(vars);
}
public static string GenerateName(BasicTransformContext context, string prefix)
{
int ancIndex = context.FindAncestorIndex<ITypeDeclaration>();
object input = context.GetAncestor(ancIndex);
NameGenerator ng = context.InputAttributes.Get<NameGenerator>(input);
if (ng == null)
{
ng = new NameGenerator();
context.InputAttributes.Set(input, ng);
object output = context.GetOutputForAncestorIndex<object>(ancIndex);
context.OutputAttributes.Set(output, ng);
}
return ng.GenerateName(prefix);
}
public static IVariableDeclaration GenerateLoopVar(BasicTransformContext context, string prefix)
{
IVariableDeclaration ivd = Builder.VarDecl(GenerateName(context, prefix), typeof(int));
//VariableInformation.GetVariableInformation(context, ivd);
return ivd;
}
public void DefineAllIndexVars(BasicTransformContext context)
{
DefineIndexVarsUpToDepth(context, sizes.Count);
}
public void DefineIndexVarsUpToDepth(BasicTransformContext context, int depth)
{
for (int d = 0; d < depth; d++)
{
for (int i = 0; i < sizes[d].Length; i++)
{
IVariableDeclaration v = (indexVars.Count <= d) ? null : indexVars[d][i];
if (v == null)
{
v = GenerateLoopVar(context, "_iv");
}
if (indexVars.Count == d) indexVars.Add(new IVariableDeclaration[sizes[d].Length]);
indexVars[d][i] = v;
}
}
}
public List<IList<IExpression>> GetIndexExpressions(BasicTransformContext context, int depth)
{
DefineIndexVarsUpToDepth(context, depth);
List<IList<IExpression>> indexExprs = new List<IList<IExpression>>();
for (int d = 0; d < depth; d++)
{
IList<IExpression> bracketExprs = Builder.ExprCollection();
for (int i = 0; i < indexVars[d].Length; i++)
{
IVariableDeclaration indexVar = indexVars[d][i];
bracketExprs.Add(Builder.VarRefExpr(indexVar));
}
indexExprs.Add(bracketExprs);
}
return indexExprs;
}
public void DefineSizesUpToDepth(BasicTransformContext context, int arrayDepth)
{
IExpression sourceArray = Builder.VarRefExpr(declaration);
for (int depth = 0; depth < arrayDepth; depth++)
{
bool notLast = (depth < arrayDepth - 1);
int rank;
Type arrayType = sourceArray.GetExpressionType();
Util.GetElementType(arrayType, out rank);
if (sizes.Count <= depth) sizes.Add(new IExpression[rank]);
IExpression[] indices = new IExpression[rank];
for (int i = 0; i < rank; i++)
{
if (sizes.Count <= depth || sizes[depth][i] == null)
{
if (rank == 1)
{
sizes[depth][i] = Builder.PropRefExpr(sourceArray, arrayType, arrayType.IsArray ? "Length" : "Count", typeof(int));
}
else
{
sizes[depth][i] = Builder.Method(sourceArray, typeof(Array).GetMethod("GetLength"), Builder.LiteralExpr(i));
}
}
if (notLast)
{
if (indexVars.Count <= depth) indexVars.Add(new IVariableDeclaration[rank]);
IVariableDeclaration v = indexVars[depth][i];
if (v == null)
{
v = GenerateLoopVar(context, "_iv");
indexVars[depth][i] = v;
}
indices[i] = Builder.VarRefExpr(v);
}
}
if (notLast) sourceArray = Builder.ArrayIndex(sourceArray, indices);
}
}
/// <summary>
/// Gets the VariableInformation attribute of ivd, or creates one if it doesn't already exist
/// </summary>
internal static VariableInformation GetVariableInformation(BasicTransformContext context, object target)
{
IVariableDeclaration ivd = null;
if (target is IVariableDeclaration) ivd = (IVariableDeclaration)target;
else if (target is IParameterDeclaration)
{
IParameterDeclaration ipd = (IParameterDeclaration)target;
ivd = Builder.VarDecl(ipd.Name, ipd.ParameterType);
}
else if (target is IFieldDeclaration)
{
IFieldDeclaration ifd = (IFieldDeclaration)target;
ivd = Builder.VarDecl(ifd.Name, ifd.FieldType);
}
else throw new ArgumentException("target is not a variable or parameter");
VariableInformation vi = context.InputAttributes.Get<VariableInformation>(target);
if (vi == null)
{
vi = new VariableInformation(ivd);
context.InputAttributes.Set(target, vi);
}
return vi;
}
public override string ToString()
{
StringBuilder s = new StringBuilder();
string stocString = IsStochastic ? "stoc " : "";
foreach (IExpression[] lengths in sizes)
{
s.Append('[');
bool notFirst = false;
foreach (IExpression length in lengths)
{
if (notFirst) s.Append(",");
else notFirst = true;
if (length != null) s.Append(length);
}
s.Append(']');
}
string sizesString = s.ToString();
s = new StringBuilder();
foreach (IVariableDeclaration[] vars in indexVars)
{
s.Append('[');
bool notFirst = false;
foreach (IVariableDeclaration v in vars)
{
if (notFirst) s.Append(",");
else notFirst = true;
if (v != null) s.Append(v.Name);
//if (v != null) s.Append("("+System.Runtime.CompilerServices.RuntimeHelpers.GetHashCode(v)+")");
}
s.Append(']');
}
string indexVarsString = s.ToString();
if (indexVarsString.Length > 0) indexVarsString = ",indexVars=" + indexVarsString;
string optionString = (LiteralIndexingDepth != 0) ? (",LiteralIndexingDepth=" + LiteralIndexingDepth) : "";
return "VariableInformation(" + stocString + declaration + sizesString + indexVarsString + optionString + ")";
}
private bool TrySetMarginalPrototypeAutomatically()
{
Type marginalType;
if (this.InnermostElementType == typeof(double))
{
marginalType = typeof(Gaussian);
}
else if (this.InnermostElementType == typeof(bool))
{
marginalType = typeof(Bernoulli);
}
else if (this.InnermostElementType == typeof(string))
{
marginalType = typeof(StringDistribution);
}
else if (this.InnermostElementType == typeof(char))
{
marginalType = typeof(DiscreteChar);
}
else
{
return false;
}
this.marginalPrototypeExpression = Builder.NewObject(marginalType);
return true;
}
/// <summary>
/// Sets the marginal prototype from a supplied MarginalPrototype attribute.
/// If this is null, attempts to set the marginal prototype automatically.
/// </summary>
/// <param name="mpa"></param>
/// <param name="throwIfMissing"></param>
internal bool SetMarginalPrototypeFromAttribute(Models.Attributes.MarginalPrototype mpa, bool throwIfMissing = true)
{
if (mpa != null)
{
if (mpa.prototypeExpression != null)
{
marginalPrototypeExpression = mpa.prototypeExpression;
Type marginalType = mpa.prototypeExpression.GetExpressionType();
if (marginalType == null) throw new InferCompilerException("Cannot determine type of marginal prototype expression: " + mpa.prototypeExpression);
}
else
{
marginalPrototypeExpression = Quoter.Quote(mpa.prototype);
}
}
else if (!TrySetMarginalPrototypeAutomatically())
{
if (throwIfMissing)
throw new ArgumentException("Cannot automatically determine distribution type for variable type '" + StringUtil.TypeToString(InnermostElementType) + "'" +
": you must specify a ValueRange or MarginalPrototype attribute for variable '" + Name + "' or its parent variables.");
else
return false;
}
return true;
}
internal IExpression GetMarginalPrototypeExpression(BasicTransformContext context, IExpression prototypeExpression,
IList<IList<IExpression>> indices, IList<IList<IExpression>> wildcardVars = null)
{
IExpression original = prototypeExpression;
int replaceCount = 0;
prototypeExpression = ReplaceIndexVars(context, prototypeExpression, indices, wildcardVars, ref replaceCount);
int mpDepth = Util.GetArrayDepth(varType, Distribution.GetDomainType(prototypeExpression.GetExpressionType()));
int indexingDepth = indices.Count;
int wildcardBracket = 0;
for (int depth = mpDepth; depth < indexingDepth; depth++)
{
IList<IExpression> indexCollection = Builder.ExprCollection();
int wildcardCount = 0;
for (int i = 0; i < indices[depth].Count; i++)
{
if (Recognizer.IsStaticMethod(indices[depth][i], new Func<int>(GateAnalysisTransform.AnyIndex)))
{
indexCollection.Add(wildcardVars[wildcardBracket][wildcardCount]);
wildcardCount++;
}
else
{
indexCollection.Add(indices[depth][i]);
}
}
if (indexCollection.Count > 0)
{
if (wildcardCount > 0) wildcardBracket++;
prototypeExpression = Builder.ArrayIndex(prototypeExpression, indexCollection);
replaceCount++;
}
}
if (replaceCount > 0) return prototypeExpression;
else return original;
}
/// <summary>
/// Create an array of this variable type, optionally after slicing it.
/// </summary>
/// <param name="addTo"></param>
/// <param name="context"></param>
/// <param name="name">Name of the new variable.</param>
/// <param name="arraySize">Length of the array. Cannot be null.</param>
/// <param name="newIndexVar">Name of a new integer variable used to index the array. Cannot be null.</param>
/// <param name="indices">Indices applied to this, before creating the array. May be null. May contain wildcards.</param>
/// <param name="wildcardVars">Loop variables to use for wildcards. May be null if there are no wildcards.</param>
/// <param name="useLiteralIndices">If true, literal indices will be used instead of newIndexVar.</param>
/// <param name="copyInitializer">If true, the new variable will be given an InitialiseTo attribute if this variable had one</param>
/// <remarks>
/// The new array is indexed by wildcards first, then newIndexVar, then the indices remaining from the original array.
/// For example, if original array is indexed [i,j][k,l][m,n] and indices = [*,*][3,*] then
/// the new array is indexed [wildcard0,wildcard1][wildcard2][newIndexVar][m,n] where sizes and marginalPrototype have expressions replaced according
/// to (i=wildcard0, j=wildcard1, k=3, l=wildcard2).
/// </remarks>
/// <returns>A new array of depth <c>(arraySize != null) + ArrayDepth - indices.Count + wildcardVars.Count</c></returns>
internal IVariableDeclaration DeriveArrayVariable(ICollection<IStatement> addTo, BasicTransformContext context, string name,
IExpression arraySize, IVariableDeclaration newIndexVar,
IList<IList<IExpression>> indices = null,
IList<IList<IExpression>> wildcardVars = null,
bool useLiteralIndices = false,
bool copyInitializer = false)
{
if (arraySize == null)
throw new ArgumentException("arraySize is null");
if (newIndexVar == null)
throw new ArgumentException("newIndexVar is null");
return DeriveArrayVariable(addTo, context, name,
new IExpression[][] { new[] { arraySize } },
new IVariableDeclaration[][] { new[] { newIndexVar } },
indices, wildcardVars, useLiteralIndices, copyInitializer);
}
internal IVariableDeclaration DeriveArrayVariable(ICollection<IStatement> addTo, BasicTransformContext context, string name,
IList<IExpression[]> arraySize, IList<IVariableDeclaration[]> newIndexVar,
IList<IList<IExpression>> indices = null,
IList<IList<IExpression>> wildcardVars = null,
bool useLiteralIndices = false,
bool copyInitializer = false)
{
List<IExpression[]> newSizes = new List<IExpression[]>();
List<IVariableDeclaration[]> newIndexVars = new List<IVariableDeclaration[]>();
Type innerType = varType;
if (indices != null)
{
// add wildcard variables to newIndexVars
for (int i = 0; i < indices.Count; i++)
{
List<IExpression> sizeBracket = new List<IExpression>();
List<IVariableDeclaration> indexVarsBracket = new List<IVariableDeclaration>();
for (int j = 0; j < indices[i].Count; j++)
{
IExpression index = indices[i][j];
if (Recognizer.IsStaticMethod(index, new Func<int>(GateAnalysisTransform.AnyIndex)))
{
int replaceCount = 0;
sizeBracket.Add(ReplaceIndexVars(context, sizes[i][j], indices, wildcardVars, ref replaceCount));
IVariableDeclaration v = indexVars[i][j];
if (wildcardVars != null) v = Recognizer.GetVariableDeclaration(wildcardVars[newIndexVars.Count][indexVarsBracket.Count]);
else if (Recognizer.GetLoopForVariable(context, v) != null)
{
// v is already used in a parent loop. must generate a new variable.
v = GenerateLoopVar(context, "_a");
}
indexVarsBracket.Add(v);
}
}
if (sizeBracket.Count > 0)
{
newSizes.Add(sizeBracket.ToArray());
newIndexVars.Add(indexVarsBracket.ToArray());
}
innerType = Util.GetElementType(innerType);
}
}
int literalIndexingDepth = 0;
if (arraySize != null)
{
newSizes.AddRange(arraySize);
if (useLiteralIndices)
literalIndexingDepth = newSizes.Count;
newIndexVars.AddRange(newIndexVar);
}
// innerType may not be an array type, so we create the new array type here instead of descending further.
Type tp = CodeBuilder.MakeJaggedArrayType(innerType, newSizes);
int indexingDepth = (indices == null) ? 0 : indices.Count;
List<IList<IExpression>> replacements = new List<IList<IExpression>>();
if (indices != null) replacements.AddRange(indices);
for (int i = indexingDepth; i < sizes.Count; i++)
{
if (indices == null)
{
newSizes.Add(sizes[i]);
if (indexVars.Count > i) newIndexVars.Add(indexVars[i]);
}
else
{
// must substitute references to indexVars with indices
IExpression[] sizeBracket = new IExpression[sizes[i].Length];
IVariableDeclaration[] indexVarBracket = new IVariableDeclaration[sizes[i].Length];
IList<IExpression> replacementBracket = Builder.ExprCollection();
for (int j = 0; j < sizeBracket.Length; j++)
{
int replaceCount = 0;
sizeBracket[j] = ReplaceIndexVars(context, sizes[i][j], replacements, wildcardVars, ref replaceCount);
if (replaceCount > 0) indexVarBracket[j] = GenerateLoopVar(context, "_a");
else if (indexVars.Count > i) indexVarBracket[j] = indexVars[i][j];
if (indexVarBracket[j] != null) replacementBracket.Add(Builder.VarRefExpr(indexVarBracket[j]));
}
newSizes.Add(sizeBracket);
newIndexVars.Add(indexVarBracket);
replacements.Add(replacementBracket);
}
}
IVariableDeclaration arrayvd = Builder.VarDecl(CodeBuilder.MakeValid(name), tp);
Builder.NewJaggedArray(addTo, arrayvd, newIndexVars, newSizes, literalIndexingDepth);
context.InputAttributes.CopyObjectAttributesTo(declaration, context.OutputAttributes, arrayvd);
// cannot copy the initializer since it will have a different size.
context.OutputAttributes.Remove<InitialiseTo>(arrayvd);
context.OutputAttributes.Remove<InitialiseBackwardTo>(arrayvd);
context.OutputAttributes.Remove<InitialiseBackward>(arrayvd);
context.OutputAttributes.Remove<VariableInformation>(arrayvd);
context.OutputAttributes.Remove<SuppressVariableFactor>(arrayvd);
context.OutputAttributes.Remove<LoopContext>(arrayvd);
context.OutputAttributes.Remove<Containers>(arrayvd);
context.OutputAttributes.Remove<ChannelInfo>(arrayvd);
context.OutputAttributes.Remove<IsInferred>(arrayvd);
context.OutputAttributes.Remove<QueryTypeCompilerAttribute>(arrayvd);
context.OutputAttributes.Remove<DerivMessage>(arrayvd);
context.OutputAttributes.Remove<PointEstimate>(arrayvd);
VariableInformation vi = VariableInformation.GetVariableInformation(context, arrayvd);
vi.IsStochastic = IsStochastic;
vi.sizes = newSizes;
vi.indexVars = newIndexVars;
if (useLiteralIndices)
vi.LiteralIndexingDepth = literalIndexingDepth;
if (indexingDepth > 0)
{
// substitute indices in the marginal prototype expression (if any)
MarginalPrototype mpa = context.InputAttributes.Get<MarginalPrototype>(declaration);
if (mpa != null && mpa.prototypeExpression != null)
{
IExpression mpe = GetMarginalPrototypeExpression(context, mpa.prototypeExpression, replacements, wildcardVars);
if (mpe != mpa.prototypeExpression)
{
MarginalPrototype mpa2 = new MarginalPrototype(null);
mpa2.prototypeExpression = mpe;
context.OutputAttributes.Remove<MarginalPrototype>(arrayvd);
context.OutputAttributes.Set(arrayvd, mpa2);
}
}
}
InitialiseTo it = context.InputAttributes.Get<InitialiseTo>(declaration);
if (it != null && copyInitializer)
{
// if original array is indexed [i,j][k,l][m,n] and indices = [*,*][3,*] then
// initExpr2 = new PlaceHolder[wildcard0,wildcard1] { new PlaceHolder[wildcard2] { new PlaceHolder[newIndexVar] { initExpr[wildcard0,wildcard1][3,wildcard2] } } }
IExpression initExpr = it.initialMessagesExpression;
// add indices to the initialiser expression
int wildcardBracket = 0;
for (int depth = 0; depth < indexingDepth; depth++)
{
IList<IExpression> indexCollection = Builder.ExprCollection();
int wildcardCount = 0;
for (int i = 0; i < indices[depth].Count; i++)
{
if (Recognizer.IsStaticMethod(indices[depth][i], new Func<int>(GateAnalysisTransform.AnyIndex)))
{
indexCollection.Add(wildcardVars[wildcardBracket][wildcardCount]);
wildcardCount++;
}
else
{
indexCollection.Add(indices[depth][i]);
}
}
if (indexCollection.Count > 0)
{
if (wildcardCount > 0) wildcardBracket++;
initExpr = Builder.ArrayIndex(initExpr, indexCollection);
}
}
// add array creates to the initialiser expression
if (newIndexVar != null)
{
initExpr = MakePlaceHolderArrayCreate(initExpr, newIndexVar);
}
if (wildcardBracket > 0)
{
while (wildcardBracket > 0)
{
wildcardBracket--;
initExpr = MakePlaceHolderArrayCreate(initExpr, vi.indexVars[wildcardBracket]);
}
}
context.OutputAttributes.Set(arrayvd, new InitialiseTo(initExpr));
}
ChannelTransform.setAllGroupRoots(context, arrayvd, false);
return arrayvd;
}
internal static IExpression MakePlaceHolderArrayCreate(IExpression expr, IList<IVariableDeclaration[]> indexVars)
{
for (int bracket = indexVars.Count - 1; bracket >= 0; bracket--)
{
expr = MakePlaceHolderArrayCreate(expr, indexVars[bracket]);
}
return expr;
}
internal static IExpression MakePlaceHolderArrayCreate(IExpression expr, IList<IVariableDeclaration> indexVars)
{
CodeBuilder Builder = CodeBuilder.Instance;
IArrayCreateExpression iace = Builder.ArrayCreateExpr(typeof(PlaceHolder), Util.ArrayInit(indexVars.Count, i => Builder.VarRefExpr(indexVars[i])));
iace.Initializer = Builder.BlockExpr();
iace.Initializer.Expressions.Add(expr);
return iace;
}
/// <summary>
/// Create a slice of this variable array, where all indices up to a certain depth are given.
/// </summary>
/// <param name="addTo"></param>
/// <param name="context"></param>
/// <param name="name">Name of the new variable array</param>
/// <param name="indices">Expressions used to index the variable array. May contain wildcards.</param>
/// <param name="wildcardVars">Loop variables to use for wildcards. May be null if there are no wildcards.</param>
/// <param name="copyInitializer">If true, the new variable will be given an InitialiseTo attribute if this variable had one</param>
/// <returns>The declaration of the new variable.</returns>
/// <remarks>
/// For example, suppose we want to slice a[i][2][j][k] into b[j][k].
/// Then <paramref name="name"/>="b", <paramref name="indices"/>=<c>[i][2]</c>.
/// </remarks>
internal IVariableDeclaration DeriveIndexedVariable(IList<IStatement> addTo, BasicTransformContext context, string name,
List<IList<IExpression>> indices = null, IList<IList<IExpression>> wildcardVars = null,
bool copyInitializer = false)
{
return DeriveArrayVariable(addTo, context, name, (IList<IExpression[]>)null, null, indices, wildcardVars, false, copyInitializer);
}
/// <summary>
/// Replace all indexVars which appear in expr with the given indices.
/// </summary>
/// <param name="context"></param>
/// <param name="expr">Any expression</param>
/// <param name="indices">A list of lists of index expressions (one list for each indexing bracket).</param>
/// <param name="wildcardIndices">Expressions used to replace wildcards. May be null if there are no wildcards.</param>
/// <param name="replaceCount">Incremented for each replacement.</param>
/// <returns>A new expression.</returns>
internal IExpression ReplaceIndexVars(BasicTransformContext context, IExpression expr, IList<IList<IExpression>> indices,
IList<IList<IExpression>> wildcardIndices, ref int replaceCount)
{
Dictionary<IVariableDeclaration, IExpression> replacedIndexVars = new Dictionary<IVariableDeclaration, IExpression>();
int wildcardBracket = 0;
for (int depth = 0; depth < indices.Count; depth++)
{
if (indexVars.Count > depth)
{
int wildcardCount = 0;
for (int i = 0; i < indices[depth].Count; i++)
{
if (indexVars[depth].Length > i)
{
IVariableDeclaration indexVar = indexVars[depth][i];
if (indexVar != null)
{
IExpression actualIndex = indices[depth][i];
if (Recognizer.IsStaticMethod(actualIndex, new Func<int>(GateAnalysisTransform.AnyIndex)))
{
actualIndex = wildcardIndices[wildcardBracket][wildcardCount];
wildcardCount++;
}
IExpression formalIndex = Builder.VarRefExpr(indexVar);
if (!formalIndex.Equals(actualIndex))
{
expr = Builder.ReplaceExpression(expr, formalIndex, actualIndex, ref replaceCount);
replacedIndexVars.Add(indexVar, actualIndex);
}
}
}
}
if (wildcardCount > 0) wildcardBracket++;
}
}
CheckReplacements(context, expr, replacedIndexVars);
return expr;
}
/// <summary>
/// Check that the replacements are safe.
/// </summary>
/// <param name="context"></param>
/// <param name="expr"></param>
/// <param name="replacedIndexVars"></param>
private static void CheckReplacements(BasicTransformContext context, IExpression expr, Dictionary<IVariableDeclaration, IExpression> replacedIndexVars)
{
foreach(var v in Recognizer.GetVariables(expr))
{
Containers containers = context.InputAttributes.Get<Containers>(v);
if (containers != null && !replacedIndexVars.ContainsKey(v))
{
foreach (IStatement container in containers.inputs)
{
if (container is IForStatement)
{
IVariableDeclaration loopVar = Recognizer.LoopVariable((IForStatement)container);
IExpression actualIndex;
if (replacedIndexVars.TryGetValue(loopVar, out actualIndex))
{
context.Error($"Cannot index {expr} by {loopVar.Name}={actualIndex} since {v.Name} has an implicit dependency on {loopVar.Name}. Try making the dependency explicit by putting {v.Name} into an array indexed by {loopVar.Name}");
}
}
}
}
}
}
internal bool HasIndexVar(IVariableDeclaration ivd)
{
foreach (IVariableDeclaration[] bracket in indexVars)
{
foreach (IVariableDeclaration indexVar in bracket)
{
if (indexVar == null) continue;
if (indexVar.Name == ivd.Name)
{
return true;
}
}
}
return false;
}
internal List<IStatement> BuildWildcardLoops(IList<IList<IExpression>> wildcardVars)
{
List<IStatement> loops = new List<IStatement>();
for (int i = 0; i < wildcardVars.Count; i++)
{
for (int j = 0; j < wildcardVars[i].Count; j++)
{
IExpression size = sizes[i][j];
IVariableDeclaration v = Recognizer.GetVariableDeclaration(wildcardVars[i][j]);
loops.Add(Builder.ForStmt(v, size));
}
}
return loops;
}
internal bool IsPartitionedAtDepth(BasicTransformContext context, int depth)
{
if (depth >= indexVars.Count) return false;
IVariableDeclaration[] bracket = indexVars[depth];
bool allPartitioned = true;
bool anyPartitioned = false;
for (int i = 0; i < bracket.Length; i++)
{
IVariableDeclaration indexVar = bracket[i];
bool isPartitioned = (indexVar != null && context.InputAttributes.Has<Partitioned>(indexVar));
if (isPartitioned) anyPartitioned = true;
else allPartitioned = false;
}
if (allPartitioned) return true;
else if (anyPartitioned) throw new Exception("indexing bracket is partially partitioned");
else return false;
}
}
internal class NameGenerator : ICompilerAttribute
{
private Dictionary<string, int> counts = new Dictionary<string, int>();
public string GenerateName(string prefix)
{
if (prefix.Length > 0)
{
// If prefix ends with a digit, append an underscore.
// This ensures that names generated from different prefixes cannot collide.
char lastChar = prefix[prefix.Length - 1];
if (char.IsDigit(lastChar))
prefix += "_";
}
int count;
counts.TryGetValue(prefix, out count);
if (count == 0) count = 1;
counts[prefix] = count + 1;
if (count == 1) return prefix;
return prefix + count;
}
}
#if SUPPRESS_XMLDOC_WARNINGS
#pragma warning restore 1591
#endif
}

Двоичные данные
src/Compiler/Infer/Infer.ico Normal file

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 9.4 KiB

Просмотреть файл

@ -0,0 +1,47 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Probabilistic.Compiler
{
using System;
using System.Runtime.Serialization;
/// <summary>
/// The exception that is thrown in the case of an issue encountered by the Infer.NET Compiler.
/// </summary>
public class InferCompilerException : Exception
{
/// <summary>
/// Initializes a new instance of the <see cref="InferCompilerException"/> class.
/// </summary>
public InferCompilerException()
{
}
/// <summary>
/// Initializes a new instance of the <see cref="InferCompilerException"/> class with a specified error message.
/// </summary>
/// <param name="message">The error message.</param>
public InferCompilerException(string message)
: base(message)
{
}
/// <summary>
/// Initializes a new instance of the <see cref="InferCompilerException"/> class with a specified error message
/// and a reference to the inner exception that is the cause of this exception.
/// </summary>
/// <param name="message">The error message that explains the reason for the exception.</param>
/// <param name="inner">The exception that is the cause of the current exception.</param>
public InferCompilerException(string message, Exception inner)
: base(message, inner)
{
}
// This constructor is needed for serialization.
protected InferCompilerException(SerializationInfo info, StreamingContext context) : base(info, context)
{
}
}
}

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -0,0 +1,556 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using Microsoft.ML.Probabilistic.Compiler;
using Microsoft.ML.Probabilistic.Collections;
using Microsoft.ML.Probabilistic.Compiler.CodeModel;
using System.Runtime.Serialization;
namespace Microsoft.ML.Probabilistic.Models
{
internal interface IStatementBlock
{
/// <summary>
/// Get a statement for the entire block, and a pointer to its body.
/// </summary>
/// <param name="innerBlock">On return, a pointer to the body of the block.</param>
/// <returns></returns>
IStatement GetStatement(out IList<IStatement> innerBlock);
}
/// <summary>
/// Thrown when an empty block is closed.
/// </summary>
public class EmptyBlockException : Exception
{
/// <summary>
/// Initializes a new instance of the <see cref="EmptyBlockException"/> class.
/// </summary>
public EmptyBlockException()
{
}
/// <summary>
/// Initializes a new instance of the <see cref="EmptyBlockException"/> class with a specified error message.
/// </summary>
/// <param name="message">The error message.</param>
public EmptyBlockException(string message)
: base(message)
{
}
// This constructor is needed for serialization.
protected EmptyBlockException(SerializationInfo info, StreamingContext context) : base(info, context)
{
}
}
/// <summary>
/// Abstract base class for statement blocks
/// </summary>
public abstract class StatementBlock : IDisposable, IStatementBlock
{
/// <summary>
/// A list of currently open blocks. This is a thread-specific static variable and
/// will have a different value for each thread.
/// </summary>
[ThreadStatic] private static List<IStatementBlock> openBlocks; // note cannot initalise thread-static variables here since it will only work in one thread
internal static List<IStatementBlock> GetOpenBlocks()
{
if (openBlocks == null) openBlocks = new List<IStatementBlock>();
return openBlocks;
}
internal static List<T> GetOpenBlocks<T>()
{
List<T> list = new List<T>();
List<IStatementBlock> blocks = GetOpenBlocks();
foreach (IStatementBlock sb in blocks)
{
if (sb is T) list.Add((T) sb);
}
return list;
}
internal static IEnumerable<T> EnumerateOpenBlocks<T>()
{
return EnumerateBlocks<T>(GetOpenBlocks());
}
internal static IEnumerable<T> EnumerateBlocks<T>(IEnumerable<IStatementBlock> blocks)
{
foreach (IStatementBlock sb in blocks)
{
if (sb is T) yield return (T) sb;
}
}
/// <summary>
/// Adds this block to a thread-specific list of open blocks.
/// </summary>
internal virtual void OpenBlock()
{
GetOpenBlocks().Add(this);
}
/// <summary>
/// Removes this block from a thread-specific list of open blocks.
/// If this block is not the final element of the list, gives an error.
/// </summary>
public void CloseBlock()
{
List<IStatementBlock> blocks = GetOpenBlocks();
int k = blocks.IndexOf(this);
if (k == -1)
{
throw new InvalidOperationException("Cannot close a block that is not open.");
}
if (k != blocks.Count - 1)
{
throw new InvalidOperationException("Blocks must be closed in the reverse order that they were opened.");
}
blocks.Remove(this);
}
/// <summary>
/// Close blocks in order to recover from exceptions
/// </summary>
internal static void CloseAllBlocks()
{
List<IStatementBlock> blocks = new List<IStatementBlock>(StatementBlock.GetOpenBlocks());
blocks.Reverse();
foreach (StatementBlock block in blocks) block.CloseBlock();
}
/// <summary>
/// Causes CloseBlock() to be called, so that this class can be used as the argument of a using() statement.
/// </summary>
/// <exclude/>
public void Dispose()
{
CloseBlock();
}
/// <summary>
/// Get a statement for the entire block, and a pointer to its body.
/// </summary>
/// <param name="innerBlock">On return, a pointer to the body of the block.</param>
/// <returns></returns>
internal abstract IStatement GetStatement(out IList<IStatement> innerBlock);
IStatement IStatementBlock.GetStatement(out IList<IStatement> innerBlock)
{
return GetStatement(out innerBlock);
}
}
/// <summary>
/// Indicates that a StatementBlock has an associated range that it loops over.
/// </summary>
public interface HasRange
{
/// <summary>
/// The Range being looped over.
/// </summary>
Range Range { get; }
}
/// <summary>
/// 'For each' block
/// </summary>
public class ForEachBlock : StatementBlock, HasRange
{
/// <summary>
/// Range associated with the 'for each' block
/// </summary>
protected Range range;
/// <summary>
/// Range associated with the 'for each' block
/// </summary>
public Range Range
{
get { return range; }
}
private Variable<int> indexVar;
/// <summary>
/// The index variable associated with the range
/// </summary>
public Variable<int> Index
{
get { return indexVar; }
}
/// <summary>
/// Constructs 'for each' block from a range
/// </summary>
/// <param name="range">The range</param>
public ForEachBlock(Range range)
{
this.range = range;
OpenBlock();
}
/// <summary>
/// ToString override
/// </summary>
/// <returns></returns>
/// <exclude/>
public override string ToString()
{
return "ForEach(" + range + ")";
}
internal static void CheckRangeCanBeOpened(Range range)
{
// check that all ranges in Range.Size are already opened.
Set<Range> openRanges = new Set<Range>();
foreach (HasRange fb in EnumerateOpenBlocks<HasRange>())
{
openRanges.Add(fb.Range);
}
if (openRanges.Contains(range))
{
throw new InvalidOperationException("Range '" + range + "' is already open in a ForEach or Switch block");
}
Models.MethodInvoke.ForEachRange(range.Size,
delegate(Range r)
{
if (!openRanges.Contains(r))
throw new InvalidOperationException("Range '" + range + "' depends on range '" + r + "', but range '" + r +
"' is not open in a ForEach block. Insert 'Variable.ForEach(" + r +
")' around 'Variable.ForEach(" + range + ")'.");
});
}
/// <summary>
/// Adds this block to a thread-specific list of open blocks.
/// </summary>
internal override void OpenBlock()
{
CheckRangeCanBeOpened(range);
indexVar = new Variable<int>(range); // Needs to be here to prevent error when creating grid with .Index syntax
base.OpenBlock();
}
/// <summary>
/// Get a statement for the entire block, and a pointer to its body.
/// </summary>
/// <param name="innerBlock">On return, a pointer to the body of the block.</param>
/// <returns></returns>
internal override IStatement GetStatement(out IList<IStatement> innerBlock)
{
return range.GetStatement(out innerBlock);
}
}
/// <summary>
/// 'Repeat' block
/// </summary>
public class RepeatBlock : StatementBlock
{
private Variable<double> countVar;
/// <summary>
/// The variable that indicates the (possibly fractional) number of repeats.
/// </summary>
public Variable<double> Count
{
get { return countVar; }
}
/// <summary>
/// Constructs 'for each' block from a range
/// </summary>
/// <param name="count"></param>
public RepeatBlock(Variable<double> count)
{
this.countVar = count;
OpenBlock();
}
/// <summary>
/// ToString override
/// </summary>
/// <returns></returns>
/// <exclude/>
public override string ToString()
{
return "Repeat(" + countVar + ")";
}
/// <summary>
/// Get a statement for the entire block, and a pointer to its body.
/// </summary>
/// <param name="innerBlock">On return, a pointer to the body of the block.</param>
/// <returns></returns>
internal override IStatement GetStatement(out IList<IStatement> innerBlock)
{
IRepeatStatement rs = CodeBuilder.Instance.RepeatStmt(countVar.GetExpression());
innerBlock = rs.Body.Statements;
return rs;
}
}
/// <summary>
/// Base class for condition blocks
/// </summary>
public abstract class ConditionBlock : StatementBlock
{
/// <summary>
/// Helps build class declarations
/// </summary>
private static CodeBuilder Builder = CodeBuilder.Instance;
/// <summary>
/// Adds this block to a thread-specific list of open blocks.
/// </summary>
internal override void OpenBlock()
{
foreach (ConditionBlock cb in EnumerateOpenBlocks<ConditionBlock>())
{
if (cb.ConditionVariableUntyped == ConditionVariableUntyped)
{
throw new InvalidOperationException("Variable '" + ConditionVariableUntyped + "' is already being conditioned on.");
}
}
base.OpenBlock();
}
internal static ConditionBlock GetConditionBlock(Variable conditionVar)
{
foreach (ConditionBlock cb in EnumerateOpenBlocks<ConditionBlock>())
{
if (cb.ConditionVariableUntyped == conditionVar) return cb;
}
return null;
}
internal abstract IExpression GetConditionExpression();
/// <summary>
/// The condition variable for this condition block.
/// </summary>
public abstract Variable ConditionVariableUntyped { get; }
/// <summary>
/// Gets a statement for the entire block, and a pointer to its body.
/// </summary>
/// <param name="innerBlock">On return, a pointer to the body of the block.</param>
/// <returns></returns>
internal override IStatement GetStatement(out IList<IStatement> innerBlock)
{
IConditionStatement cs = Builder.CondStmt();
cs.Condition = GetConditionExpression();
cs.Then = Builder.BlockStmt();
innerBlock = cs.Then.Statements;
return cs;
}
}
/// <summary>
/// Represents a conditional block in a model definition. Anything defined inside
/// the block is placed inside a gate, whose condition is the condition of the block.
/// </summary>
public class ConditionBlock<T> : ConditionBlock
{
/// <summary>
/// Helps build class declarations
/// </summary>
private static CodeBuilder Builder = CodeBuilder.Instance;
private readonly Variable<T> conditionVariable;
private readonly T conditionValue;
internal ConditionBlock(Variable<T> conditionVariable, T conditionValue)
: this(conditionVariable, conditionValue, true)
{
}
internal ConditionBlock(Variable<T> conditionVariable, T conditionValue, bool openBlock)
{
// check that all ranges in the conditionVariable are already opened.
Set<Range> openRanges = new Set<Range>();
foreach (HasRange fb in EnumerateOpenBlocks<HasRange>())
{
openRanges.Add(fb.Range);
}
Models.MethodInvoke.ForEachRange(conditionVariable,
delegate(Range r)
{
if (!openRanges.Contains(r))
throw new InvalidOperationException(conditionVariable + " depends on range '" + r + "', but range '" + r +
"' is not open in a ForEach block. Insert 'Variable.ForEach(" + r +
")' around this block.");
});
this.conditionVariable = conditionVariable;
this.conditionValue = conditionValue;
if (openBlock) OpenBlock();
}
/// <summary>
/// The random variable which controls when this IfBlock is active.
/// </summary>
public Variable<T> ConditionVariable
{
get { return conditionVariable; }
}
/// <summary>
/// The value of the condition variable which switches on this IfBlock.
/// </summary>
public T ConditionValue
{
get { return conditionValue; }
}
/// <summary>
/// Equals override
/// </summary>
/// <param name="obj"></param>
/// <returns></returns>
/// <exclude/>
public override bool Equals(object obj)
{
ConditionBlock<T> cb = obj as ConditionBlock<T>;
if (cb == null) return false;
if (!ReferenceEquals(conditionVariable, cb.conditionVariable)) return false;
return conditionValue.Equals(cb.conditionValue);
}
/// <summary>
/// Hash code override
/// </summary>
/// <returns></returns>
/// <exclude/>
public override int GetHashCode()
{
int hash = conditionVariable.GetHashCode() + conditionValue.GetHashCode();
return hash;
}
/// <summary>
/// ToString override
/// </summary>
/// <returns></returns>
/// <exclude/>
public override string ToString()
{
return GetType().Name + "(" + GetConditionExpression().ToString() + ")";
}
internal override IExpression GetConditionExpression()
{
return Builder.BinaryExpr(conditionVariable.GetExpression(), BinaryOperator.ValueEquality, Builder.LiteralExpr(conditionValue));
}
/// <summary>
/// The condition variable for this condition block.
/// </summary>
public override Variable ConditionVariableUntyped
{
get { return conditionVariable; }
}
}
/// <summary>
/// An If block is a condition block with a binary condition.
/// </summary>
public class IfBlock : ConditionBlock<bool>
{
/// <summary>
/// Helps build class declarations
/// </summary>
private static CodeBuilder Builder = CodeBuilder.Instance;
internal IfBlock(Variable<bool> conditionVariable, bool value)
: base(conditionVariable, value)
{
}
internal override IExpression GetConditionExpression()
{
IExpression expr = ConditionVariable.GetExpression();
if (ConditionValue) return expr;
else return Builder.NotExpr(expr);
}
}
/// <summary>
/// A case block is a condition block with a condition of the form (i==value) for integer i.
/// </summary>
public class CaseBlock : ConditionBlock<int>
{
internal CaseBlock(Variable<int> conditionVariable, int value)
: base(conditionVariable, value)
{
}
}
/// <summary>
/// A switch block is a condition block which acts like multiple case blocks ranging over the values
/// of the integer condition variable.
/// </summary>
public class SwitchBlock : ConditionBlock<int>, HasRange
{
/// <summary>
/// Helps build class declarations
/// </summary>
private static CodeBuilder Builder = CodeBuilder.Instance;
private Range range;
internal SwitchBlock(Variable<int> conditionVariable, Range range)
: base(conditionVariable, -1, false)
{
this.range = range;
OpenBlock();
}
/// <summary>
/// Adds this block to a thread-specific list of open blocks.
/// </summary>
internal override void OpenBlock()
{
ForEachBlock.CheckRangeCanBeOpened(range);
base.OpenBlock();
}
/// <summary>
/// Get switch block's range
/// </summary>
public Range Range
{
get { return range; }
}
internal override IExpression GetConditionExpression()
{
return Builder.BinaryExpr(ConditionVariable.GetExpression(), BinaryOperator.ValueEquality, range.GetExpression());
}
/// <summary>
/// Gets a statement for the entire block, and a pointer to its body.
/// </summary>
/// <param name="innerBlock">On return, a pointer to the body of the block.</param>
/// <returns></returns>
internal override IStatement GetStatement(out IList<IStatement> innerBlock)
{
if (ConditionVariable.IsObserved)
{
innerBlock = null;
return null;
}
IForStatement ifs = Builder.ForStmt(range.GetIndexDeclaration(), range.GetSizeExpression());
ifs.Body.Statements.Add(base.GetStatement(out innerBlock));
return ifs;
}
}
}

Просмотреть файл

@ -0,0 +1,51 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Probabilistic.Models
{
#pragma warning disable 1591
public delegate TResult FuncOut<in T1, TOut, out TResult>(T1 arg, out TOut output);
public delegate TResult FuncOut<in T1, in T2, TOut, out TResult>(T1 arg, T2 arg2, out TOut output);
public delegate TResult FuncOut2<in T1, TOut, TOut2, out TResult>(T1 arg, out TOut output, out TOut2 output2);
public delegate TResult FuncOut2<in T1, in T2, TOut, TOut2, out TResult>(T1 arg, T2 arg2, out TOut output, out TOut2 output2);
public delegate TResult FuncOut3<in T1, in T2, TOut, TOut2, TOut3, out TResult>(T1 arg, T2 arg2, out TOut output, out TOut2 output2, out TOut3 output3);
public delegate TResult FuncOut3<in T1, in T2, in T3, TOut, TOut2, TOut3, out TResult>(T1 arg, T2 arg2, T3 arg3, out TOut output, out TOut2 output2, out TOut3 output3);
public delegate TResult FuncOut3<in T1, in T2, in T3, in T4, TOut, TOut2, TOut3, out TResult>(T1 arg, T2 arg2, T3 arg3, T4 arg4, out TOut output, out TOut2 output2, out TOut3 output3);
public delegate TResult FuncOut4<in T1, in T2, TOut, TOut2, TOut3, TOut4, out TResult>(T1 arg, T2 arg2, out TOut output, out TOut2 output2, out TOut3 output3, out TOut4 output4);
public delegate TResult FuncOut4<in T1, in T2, in T3, TOut, TOut2, TOut3, TOut4, out TResult>(
T1 arg, T2 arg2, T3 arg3, out TOut output, out TOut2 output2, out TOut3 output3, out TOut4 output4);
#pragma warning restore 1591
/// <summary>
/// Generic delegate with 2 out parameters
/// </summary>
/// <typeparam name="T1">Type of first argument</typeparam>
/// <typeparam name="T2">Type of second argument</typeparam>
/// <typeparam name="T3">Type of third argument</typeparam>
/// <param name="arg1">First argument</param>
/// <param name="arg2">Second argument</param>
/// <param name="arg3">Third argument</param>
public delegate void ActionOut2<in T1, T2, T3>(T1 arg1, out T2 arg2, out T3 arg3);
/// <summary>
/// Generic delegate with 2 out parameters
/// </summary>
/// <typeparam name="T1">Type of first argument</typeparam>
/// <typeparam name="T2">Type of second argument</typeparam>
/// <typeparam name="T3">Type of third argument</typeparam>
/// <typeparam name="T4">Type of fourth argument</typeparam>
/// <param name="arg1">First argument</param>
/// <param name="arg2">Second argument</param>
/// <param name="arg3">Third argument</param>
/// <param name="arg4">Fourth argument</param>
public delegate void ActionOut2<in T1, in T2, T3, T4>(T1 arg1, T2 arg2, out T3 arg3, out T4 arg4);
}

Просмотреть файл

@ -0,0 +1,18 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Threading;
namespace Microsoft.ML.Probabilistic.Models
{
internal class GlobalCounter
{
private int count = -1;
public int GetNext()
{
return Interlocked.Increment(ref count);
}
}
}

Просмотреть файл

@ -0,0 +1,70 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Collections.Generic;
using Microsoft.ML.Probabilistic.Compiler.CodeModel;
namespace Microsoft.ML.Probabilistic.Models
{
/// <summary>
/// Interface to a modelling expression, such as a constant, variable or parameter.
/// </summary>
public interface IModelExpression
{
/// <summary>
/// Get the code model expression
/// </summary>
/// <returns></returns>
IExpression GetExpression();
/// <summary>
/// Expression name
/// </summary>
string Name { get; }
}
/// <summary>
/// Generic inferface to a modelling expression of type T.
/// </summary>
/// <typeparam name="T"></typeparam>
public interface IModelExpression<T> : IModelExpression
{
}
/// <summary>
/// A marker interface for variables.
/// </summary>
public interface IVariable : IModelExpression, HasObservedValue
{
}
/// <summary>
/// Interface for getting list of containers
/// </summary>
public interface CanGetContainers
{
/// <summary>
/// Get list of containers for a variable
/// </summary>
/// <typeparam name="T">Type of variable</typeparam>
/// <returns></returns>
List<T> GetContainers<T>();
}
/// <summary>
/// Interface for a variable to have an observed value
/// </summary>
public interface HasObservedValue
{
/// <summary>
/// Returns true if the variable is observed.
/// </summary>
bool IsObserved { get; }
/// <summary>
/// Observed value property
/// </summary>
object ObservedValue { get; set; }
}
}

Просмотреть файл

@ -0,0 +1,921 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Text;
using System.Diagnostics;
using Microsoft.ML.Probabilistic.Math;
using Microsoft.ML.Probabilistic.Collections;
using Microsoft.ML.Probabilistic.Distributions;
using Microsoft.ML.Probabilistic.Compiler.Transforms;
using Microsoft.ML.Probabilistic.Compiler;
using System.Reflection;
using Microsoft.ML.Probabilistic.Utilities;
using Microsoft.ML.Probabilistic.Compiler.CodeModel;
using System.Collections.ObjectModel;
using System.IO;
using System.Collections.Concurrent;
using System.Linq;
using Microsoft.ML.Probabilistic.Models.Attributes;
using Microsoft.ML.Probabilistic.Compiler.Visualizers;
namespace Microsoft.ML.Probabilistic.Models
{
/// <summary>
/// An inference engine, used to perform inference tasks in Infer.NET.
/// </summary>
/// <remarks>
/// The Debug class may be used to get debug messages for the inference engine.
/// For example, use <code>Debug.Listeners.Add(new TextWriterTraceListener(Console.Out));</code>
/// to get debug information for when compiled models and marginals are re-used.
/// </remarks>
public class InferenceEngine : SettableTo<InferenceEngine>
{
/// <summary>
/// Bag of weak references to engine instances. The weak references allow the instances
/// to be garbage collected when they are no longer being used.
/// </summary>
private static readonly ConcurrentDictionary<WeakReference, EmptyStruct> allEngineInstances = new ConcurrentDictionary<WeakReference, EmptyStruct>();
protected struct EmptyStruct
{
}
/// <summary>
/// Internal list of built-in algorithms
/// </summary>
private static readonly IAlgorithm[] algs =
{
new Algorithms.ExpectationPropagation(), new Algorithms.VariationalMessagePassing(),
new Algorithms.GibbsSampling(), new Algorithms.MaxProductBeliefPropagation()
};
/// <summary>
/// Default inference engine whose settings will be copied onto newly created engines.
/// </summary>
public static readonly InferenceEngine DefaultEngine = new InferenceEngine(false);
private static bool IsUnitTest()
{
// FriendlyName works for VS 2005 but not VS 2008
if (AppDomain.CurrentDomain.FriendlyName.Contains("UnitTest")) return true;
Assembly[] assemblies = AppDomain.CurrentDomain.GetAssemblies();
foreach (Assembly assembly in assemblies)
{
if (assembly.GetName().Name == "Microsoft.VisualStudio.QualityTools.AgentObject") return true;
}
return false;
}
static InferenceEngine()
{
DefaultEngine.ResetOnObservedValueChanged = true;
}
/// <summary>
/// The full name of the inference engine, including version
/// </summary>
public static string Name
{
get { return "Infer.NET " + Assembly.GetExecutingAssembly().GetName().Version; }
}
/// <summary>
/// Model namespace, used when naming generated classes.
/// </summary>
public string ModelNamespace { get; set; } = ModelBuilder.ModelNamespace;
/// <summary>
/// Model name, used when naming generated classes.
/// </summary>
public string ModelName { get; set; } = "Model";
/// <summary>
/// Provides the implementation of ShowFactorGraph, ShowSchedule, and BrowserMode.
/// </summary>
public static Visualizer Visualizer { get; set; } = new DefaultVisualizer();
/// <summary>
/// The ModelBuilder used to construct MSL from in-memory graphs of Variables etc.
/// </summary>
private ModelBuilder mb = new ModelBuilder();
private ConcurrentStack<CompiledAlgorithmInfo> compiledAlgorithms = new ConcurrentStack<CompiledAlgorithmInfo>();
private Dictionary<IVariable, CompiledAlgorithmInfo> compiledAlgorithmForVariable = new Dictionary<IVariable, CompiledAlgorithmInfo>();
private class CompiledAlgorithmInfo
{
public IGeneratedAlgorithm exec;
public List<Variable> observedVarsInOrder = new List<Variable>();
public Set<Variable> observedVars = new Set<Variable>();
public CompiledAlgorithmInfo(IGeneratedAlgorithm exec, IEnumerable<Variable> observedVars)
{
this.exec = exec;
this.observedVarsInOrder.AddRange(observedVars);
this.observedVars.AddRange(observedVars);
}
}
/// <summary>
/// The Compiler used to compile MSL into a compiled algorithm.
/// </summary>
protected ModelCompiler compiler;
/// <summary>
/// Creates an inference engine which uses the default inference algorithm
/// (currently this is expectation propagation).
/// </summary>
public InferenceEngine()
: this(true)
{
allEngineInstances.TryAdd(new WeakReference(this), new EmptyStruct());
}
/// <summary>
/// Creates an inference engine which uses the specified inference algorithm.
/// </summary>
public InferenceEngine(IAlgorithm algorithm)
: this()
{
this.Algorithm = algorithm;
}
/// <summary>
/// Create a new ModelCompiler object
/// </summary>
private void CreateCompiler()
{
compiler = new ModelCompiler();
compiler.ParametersChanged += delegate (object sender, EventArgs e) { InvalidateCompiledAlgorithms(); };
compiler.Compiling += delegate (ModelCompiler sender, ModelCompiler.CompileEventArgs e) { if (ShowProgress) Console.Write("Compiling model..."); };
compiler.Compiled += delegate (ModelCompiler sender, ModelCompiler.CompileEventArgs e)
{
if (ShowWarnings && e.Warnings != null && (e.Warnings.Count > 0))
{
Console.WriteLine("compilation had " + e.Warnings.Count + " warning(s).");
int count = 1;
foreach (TransformError te in e.Warnings)
{
if (!te.IsWarning)
continue;
Console.WriteLine(" [" + count + "] " + te.ErrorText);
count++;
}
}
if (e.Exception != null)
{
if (ShowProgress) Console.WriteLine("compilation failed.");
}
else
{
if (ShowProgress) Console.WriteLine("done.");
}
};
}
/// <summary>
/// Creates an inference engine, optionally copying values from the default engine.
/// </summary>
/// <param name="copyValuesFromDefault"></param>
internal InferenceEngine(bool copyValuesFromDefault)
{
CreateCompiler();
if (copyValuesFromDefault)
{
lock (DefaultEngine)
{
SetTo(DefaultEngine);
}
}
else
{
Algorithm = algs[0];
}
}
/// <summary>
/// Get the abstract syntax tree for the generated code.
/// </summary>
/// <returns>A list of type declaration objects.</returns>
public List<ITypeDeclaration> GetCodeToInfer(IVariable var)
{
if (!mb.variablesToInfer.Contains(var))
{
mb.Build(this, false, new IVariable[] { var });
}
return mb.GetGeneratedSyntax(this);
}
/// <summary>
/// Compiles the last built model into a CompiledAlgorithm which implements
/// the specified inference algorithm on the model.
/// </summary>
/// <returns></returns>
private IGeneratedAlgorithm Compile()
{
mb.SetModelName(ModelNamespace, ModelName);
if (ShowMsl) Console.WriteLine(mb.ModelString());
if (ShowFactorGraph || SaveFactorGraphToFolder != null)
{
if (SaveFactorGraphToFolder != null && Visualizer?.GraphWriter != null)
{
Directory.CreateDirectory(SaveFactorGraphToFolder);
Visualizer.GraphWriter.WriteGraph(mb, SaveFactorGraphToFolder + @"\" + ModelName);
}
if (ShowFactorGraph && Visualizer?.FactorGraphVisualizer != null)
Visualizer.FactorGraphVisualizer.VisualizeFactorGraph(mb);
}
Stopwatch s = null;
if (ShowTimings)
{
s = new Stopwatch();
s.Start();
}
IGeneratedAlgorithm compiledAlgorithm = Compiler.CompileWithoutParams(mb.modelType, null, mb.Attributes);
if (ShowTimings)
{
s.Stop();
Console.WriteLine("Compilation time was " + s.ElapsedMilliseconds + "ms.");
}
CompiledAlgorithmInfo info = new CompiledAlgorithmInfo(compiledAlgorithm, mb.observedVars);
compiledAlgorithms.Push(info);
foreach (IVariable v in mb.variablesToInfer)
{
compiledAlgorithmForVariable[v] = info;
}
SetObservedValues(info);
return info.exec;
}
/// <summary>
/// Infers the marginal distribution for the specified variable.
/// </summary>
/// <param name="var">The variable whose marginal is to be inferred</param>
/// <returns>The marginal distribution (or an approximation to it)</returns>
public object Infer(IVariable var)
{
IGeneratedAlgorithm ca = InferAll(false, var);
return ca.Marginal(var.Name);
}
/// <summary>
/// Performs an inference query for the specified variable, given a query type.
/// </summary>
/// <param name="var">The variable whose marginal is to be inferred</param>
/// <param name="queryType">The type of query</param>
/// <returns>The marginal distribution (or an approximation to it)</returns>
public object Infer(IVariable var, QueryType queryType)
{
var ca = InferAll(false, var);
return ca.Marginal(var.Name, queryType.Name);
}
/// <summary>
/// Infers the marginal distribution for the specified variable.
/// </summary>
/// <typeparam name="TReturn">Desired return type which may be a distribution type or an array type if the argument is a VariableArray</typeparam>
/// <param name="var">The variable whose marginal is to be inferred</param>
/// <returns>The marginal distribution (or an approximation to it)</returns>
public TReturn Infer<TReturn>(IVariable var)
{
IGeneratedAlgorithm ca = InferAll(false, var);
return ca.Marginal<TReturn>(var.Name);
}
/// <summary>
/// Infers the marginal distribution for the specified variable, and the specified
/// query type
/// </summary>
/// <typeparam name="TReturn">Desired return type</typeparam>
/// <param name="var">The variable whose marginal is to be inferred</param>
/// <param name="queryType">The query type</param>
/// <returns>The marginal distribution (or an approximation to it)</returns>
public TReturn Infer<TReturn>(IVariable var, QueryType queryType)
{
// If asked for a non-default QueryType that is not an attribute of var,
// this code will give an error message.
// TM: This code previously recompiled the model if a new QueryType was requested.
// This is bad because it leads to inconsistent inference results for different QueryTypes.
// For example, if someone infers Samples and then infers Conditionals with a recompile in between.
// Or if someone infers Marginal and then MarginalDividedByPrior with a recompile in between.
var ca = InferAll(false, var);
return ca.Marginal<TReturn>(var.Name, queryType.Name);
}
internal static T ConvertValueToDistribution<T>(object value)
{
if (value is T) return (T)value;
Type toType = typeof(T);
Type domainType = value.GetType();
MethodInfo method = new Func<Converter<object, object>>(InferenceEngine.GetValueToDistributionConverter<object, object>).Method.GetGenericMethodDefinition();
method = method.MakeGenericMethod(domainType, toType);
Delegate converter = (Delegate)Util.Invoke(method, null);
return (T)Util.DynamicInvoke(converter, value);
}
internal static TOutput[] ArrayConvertAll<TInput, TOutput>(Converter<TInput, TOutput> converter, TInput[] array)
{
return Array.ConvertAll(array, converter);
}
internal static Converter<TInput, TOutput> GetValueToDistributionConverter<TInput, TOutput>()
{
Exception exception = null;
Type domainType = typeof(TInput);
Type toType = typeof(TOutput);
if (toType.IsArray)
{
Type toEltType = toType.GetElementType();
if (domainType.IsArray)
{
try
{
Type fromEltType = domainType.GetElementType();
// ArrayConvertAll<TInput,TOutput>(itemConverter, TInput[] array)
MethodInfo convertAll =
new Func<Converter<object, object>, object[], object[]>(InferenceEngine.ArrayConvertAll<object, object>).Method.GetGenericMethodDefinition();
convertAll = convertAll.MakeGenericMethod(fromEltType, toEltType);
MethodInfo thisMethod =
new Func<Converter<object, object>>(InferenceEngine.GetValueToDistributionConverter<object, object>).Method.GetGenericMethodDefinition();
thisMethod = thisMethod.MakeGenericMethod(fromEltType, toEltType);
object itemConverter = Util.Invoke(thisMethod, null);
return (Converter<TInput, TOutput>)Delegate.CreateDelegate(typeof(Converter<TInput, TOutput>), itemConverter, convertAll);
}
catch (Exception e)
{
// fall through to exception below
exception = e;
}
}
// fall through
}
else
{
Type hasPointType = typeof(HasPoint<>).MakeGenericType(domainType);
if (hasPointType.IsAssignableFrom(toType))
{
MethodInfo method =
(MethodInfo)
Microsoft.ML.Probabilistic.Compiler.Reflection.Invoker.GetBestMethod(toType, "PointMass",
BindingFlags.Public | BindingFlags.Static | BindingFlags.InvokeMethod | BindingFlags.FlattenHierarchy, null,
new Type[] { domainType }, out exception);
if (method != null)
{
return (Converter<TInput, TOutput>)Delegate.CreateDelegate(typeof(Converter<TInput, TOutput>), method);
}
else if (toType.IsGenericType && toType.GetGenericTypeDefinition().Equals(typeof(PointMass<>)))
{
MethodInfo method2 =
(MethodInfo)
Microsoft.ML.Probabilistic.Compiler.Reflection.Invoker.GetBestMethod(toType, "Create",
BindingFlags.Public | BindingFlags.Static | BindingFlags.InvokeMethod | BindingFlags.FlattenHierarchy, null,
new Type[] { domainType }, out exception);
return (Converter<TInput, TOutput>)Delegate.CreateDelegate(typeof(Converter<TInput, TOutput>), method2);
}
// fall through
}
else exception = new Exception(StringUtil.TypeToString(toType) + " does not implement " + StringUtil.TypeToString(hasPointType));
}
throw new ArgumentException("Cannot convert to distribution type " + StringUtil.TypeToString(toType) + " from type " + StringUtil.TypeToString(domainType) + ".",
exception);
}
/// <summary>
/// Attempts to convert the supplied object to the specified target type.
/// Throws an ArgumentException if this is not possible.
/// </summary>
/// <remarks>
/// Currently supports converting DistributionArray instances to .NET arrays and
/// converting PointMass instances to distributions configured as point masses.
/// </remarks>
/// <typeparam name="T">The target type</typeparam>
/// <param name="obj">The source object</param>
/// <returns>The source object converted to type T</returns>
internal static T ConvertDistributionToType<T>(object obj)
{
// Fast path if the object is already of the right type
if (obj is T) return (T)obj;
// Conversion from PointMass to an instance of T set to a point mass, if T supports HasPoint.
Type fromType = obj.GetType();
if (fromType.IsGenericType && fromType.GetGenericTypeDefinition().Equals(typeof(PointMass<>)))
{
object value = fromType.GetProperty("Point").GetValue(obj, null);
return ConvertValueToDistribution<T>(value);
}
Type toType = typeof(T);
// Conversion from DistributionArray to dotNET array
if (toType.IsArray)
{
try
{
return Distribution.ToArray<T>(obj);
}
catch (Exception)
{
// throw exception below instead
}
}
return ConvertValueToDistribution<T>(obj);
throw new ArgumentException("Cannot convert to type " + StringUtil.TypeToString(toType) + " from type " + StringUtil.TypeToString(fromType) + ".");
}
/// <summary>
/// Computes the output message (message to the prior) for the specified variable.
/// </summary>
/// <typeparam name="Distribution">Desired distribution type</typeparam>
/// <param name="var">The variable whose output message is to be inferred</param>
/// <returns>The output message (or an approximation to it)</returns>
public Distribution GetOutputMessage<Distribution>(IVariable var)
{
return Infer<Distribution>(var, QueryTypes.MarginalDividedByPrior);
}
///// <summary>
///// Optimize the engine to infer only the specified variables, overriding any previous calls to RestrictInferenceTo.
///// </summary>
///// <param name="vars">The variables whose marginals are to be inferred</param>
//public void RestrictInferenceTo(params IVariable[] vars)
//{
// RestrictInferenceTo((IEnumerable<IVariable>)vars);
//}
///// <summary>
///// Optimize the engine to infer only the specified variables, overriding any previous calls to RestrictInferenceTo.
///// </summary>
///// <param name="vars">The variables whose marginals are to be inferred</param>
//public void RestrictInferenceTo(IEnumerable<IVariable> vars)
//{
// mb.Build(this, true, vars);
// compiledAlgorithms.Clear();
//}
private IList<IVariable> optimiseForVariables = null;
/// <summary>
/// The variables to optimize the engine to infer.
/// If set to a list of variables, only the specified variables can be inferred by this engine.
/// If set to null, any variable can be inferred by this engine.
/// </summary>
/// <remarks>
/// Setting this property to a list of variables can improve performance by removing redundent
/// computation and storage needed to infer marginals for variables which are not on the list.</remarks>
public IList<IVariable> OptimiseForVariables
{
get
{
if (optimiseForVariables == null) return null;
// return a read only view of the internal list
return new ReadOnlyCollection<IVariable>(optimiseForVariables);
}
set
{
if (value == null)
{
optimiseForVariables = null;
}
else
{
// make a copy of the passed in list
optimiseForVariables = new List<IVariable>(value);
}
InvalidateCompiledAlgorithms();
}
}
protected IGeneratedAlgorithm InferAll(bool inferOnlySpecifiedVars, IVariable var)
{
IGeneratedAlgorithm ca = GetCompiledInferenceAlgorithm(inferOnlySpecifiedVars, var);
Execute(ca);
return ca;
}
protected IGeneratedAlgorithm InferAll(bool inferOnlySpecifiedVars, IEnumerable<IVariable> vars)
{
IGeneratedAlgorithm ca = GetCompiledInferenceAlgorithm(inferOnlySpecifiedVars, vars);
Execute(ca);
return ca;
}
private void SetObservedValues(CompiledAlgorithmInfo info)
{
foreach (Variable var in info.observedVarsInOrder)
{
info.exec.SetObservedValue(var.NameInGeneratedCode, ((HasObservedValue)var).ObservedValue);
}
}
private void Execute(IGeneratedAlgorithm ca)
{
// If there is a message update listener, try to add in the engine to listen to messages.
if (this.MessageUpdated != null)
{
DebuggingSupport.TryAddRemoveEventListenerDynamic(ca, OnMessageUpdated, add: true);
}
// Register the ProgressChanged handler only while doing inference within InferenceEngine.
// We do not want the handler to run if the user accesses the GeneratedAlgorithms directly.
ca.ProgressChanged += OnProgressChanged;
try
{
Stopwatch s = null;
if (ShowTimings)
{
s = new Stopwatch();
s.Start();
FileStats.Clear();
}
if (ResetOnObservedValueChanged)
ca.Execute(NumberOfIterations);
else
ca.Update(NumberOfIterations - ca.NumberOfIterationsDone);
if (s != null)
{
long elapsed = s.ElapsedMilliseconds;
Console.WriteLine("Inference time was {1}ms (max {0} iterations)",
NumberOfIterations, elapsed);
if (FileStats.ReadCount > 0 || FileStats.WriteCount > 0)
Console.WriteLine("{0} file reads {1} file writes", FileStats.ReadCount, FileStats.WriteCount);
}
}
finally
{
ca.ProgressChanged -= OnProgressChanged;
if (this.MessageUpdated != null)
{
DebuggingSupport.TryAddRemoveEventListenerDynamic(ca, OnMessageUpdated, add: false);
}
}
}
/// <summary>
/// Returns a compiled algorithm which can later be used to infer marginal
/// distributions for the specified variables. This method allows more fine-grained
/// control over the inference procedure.
/// </summary>
/// <remarks>This method should not be used unless fine-grained control over the
/// inference is required. Infer.NET will cache the last compiled algorithm
/// and re-use it if possible.
/// </remarks>
/// <param name="vars">The variables whose marginals are to be computed by the returned algorithm.</param>
/// <returns>An IGeneratedAlgorithm object</returns>
public IGeneratedAlgorithm GetCompiledInferenceAlgorithm(params IVariable[] vars)
{
return GetCompiledInferenceAlgorithm(true, vars);
}
/// <summary>
/// For advanced use. Returns all the model expressions that are relevant to
/// inferring the set of variables provided. This may be useful for constructing visualisations of the model.
/// </summary>
/// <remarks>
/// The returned collection includes Variable and VariableArray objects which the engine has determined are
/// relevant to inferring marginals over the variables provided. This will at least include
/// the provided variables, but may include other relevant variables as well. It will also
/// include MethodInvoke objects which act as priors, constraints or factors in the model.
/// </remarks>
/// <param name="vars">The variables to build a model for</param>
/// <returns>A collection of model expressions</returns>
public IReadOnlyCollection<IModelExpression> GetRelevantModelExpressions(params IVariable[] vars)
{
ModelBuilder mb2 = new ModelBuilder();
mb2.Build(this, true, vars);
return mb2.ModelExpressions;
}
internal IGeneratedAlgorithm GetCompiledInferenceAlgorithm(bool inferOnlySpecifiedVars, IVariable var)
{
// optimize the case of repeated inference on the same variable
CompiledAlgorithmInfo info;
if (compiledAlgorithmForVariable.TryGetValue(var, out info))
{
//SetObservedValues(info);
return info.exec;
}
else
{
return BuildAndCompile(false, new IVariable[] { var });
}
}
internal IGeneratedAlgorithm GetCompiledInferenceAlgorithm(bool inferOnlySpecifiedVars, IEnumerable<IVariable> vars)
{
// If a single compiledAlgorithm is available to infer all of the vars, then return it.
// otherwise, build a new one.
CompiledAlgorithmInfo info = null;
foreach (IVariable var in vars)
{
CompiledAlgorithmInfo info2;
if (!compiledAlgorithmForVariable.TryGetValue(var, out info2)) return BuildAndCompile(inferOnlySpecifiedVars, vars);
if (info == null) info = info2;
else if (!ReferenceEquals(info, info2)) return BuildAndCompile(inferOnlySpecifiedVars, vars);
}
if (info == null) throw new ArgumentException("Empty set of variables to infer");
return info.exec;
}
private IGeneratedAlgorithm BuildAndCompile(bool inferOnlySpecifiedVars, IEnumerable<IVariable> vars)
{
if (optimiseForVariables != null)
{
foreach (IVariable v in vars)
{
if (!optimiseForVariables.Contains(v))
{
throw new ArgumentException("Cannot call ML.Probabilistic() on variable '" + v.Name +
"' which is not in the OptimiseForVariables list. The list currently contains: " +
StringUtil.CollectionToString(OptimiseForVariables.ListSelect(x => "'" + x.Name + "'"), ",") + ".");
}
}
if (optimiseForVariables.Contains(null))
throw new ArgumentException("OptimiseForVariables contains a null variable");
mb.Build(this, true, optimiseForVariables);
}
else
{
mb.Build(this, inferOnlySpecifiedVars, vars);
}
return Compile();
}
internal void OnProgressChanged(object sender, ProgressChangedEventArgs progress)
{
if (ProgressChanged != null) ProgressChanged(this, new InferenceProgressEventArgs() { Iteration = progress.Iteration, Algorithm = (IGeneratedAlgorithm)sender });
if (!ShowProgress) return;
int iteration = progress.Iteration + 1;
if (iteration == 1) Console.WriteLine("Iterating: ");
Console.Write(iteration % 10 == 0 ? "|" : ".");
if ((iteration % 50 == 0) || (iteration == NumberOfIterations))
{
Console.WriteLine(" " + iteration);
}
}
/// <summary>
/// Event that is fired when the progress of inference changes, typically at the
/// end of one iteration of the inference algorithm.
/// </summary>
public event InferenceProgressEventHandler ProgressChanged;
internal void OnMessageUpdated(object sender, MessageUpdatedEventArgs messageEvent)
{
if (MessageUpdated != null) MessageUpdated(sender as IGeneratedAlgorithm, messageEvent);
}
/// <summary>
/// Event that is fired when a message that has been marked with ListenToMessages has been updated.
/// </summary>
public event MessageUpdatedEventHandler MessageUpdated;
/// <summary>
/// Ensures that the last compiled algorithm will not be re-used. This should be called
/// whenever a change is made that requires recompiling (but not rebuilding) the model.
/// </summary>
internal void InvalidateCompiledAlgorithms()
{
compiledAlgorithmForVariable.Clear();
compiledAlgorithms.Clear();
}
/// <summary>
/// For message passing algorithms, reset all messages to their initial values.
/// </summary>
protected void Reset()
{
foreach (CompiledAlgorithmInfo info in compiledAlgorithms)
{
info.exec.Reset();
}
}
/// <summary>
/// If true (default), Infer resets messages to their initial values if an observed value has changed.
/// </summary>
public bool ResetOnObservedValueChanged { get; set; }
internal static void InvalidateAllEngines(IModelExpression expr)
{
foreach (WeakReference weakRef in allEngineInstances.Keys)
{
InferenceEngine engine = weakRef.Target as InferenceEngine;
if (engine == null)
{
// The engine has been freed, so we can remove it from the dictionary.
EmptyStruct value;
allEngineInstances.TryRemove(weakRef, out value);
}
else
{
var modelExpressions = engine.mb.ModelExpressions;
if (modelExpressions != null && modelExpressions.Contains(expr))
{
engine.mb.Reset(); // must rebuild the model
engine.InvalidateCompiledAlgorithms();
}
}
}
}
internal static void ObservedValueChanged(Variable var)
{
foreach (WeakReference weakRef in allEngineInstances.Keys)
{
InferenceEngine engine = weakRef.Target as InferenceEngine;
if (engine == null)
{
// The engine has been freed, so we can remove it from the dictionary.
EmptyStruct value;
allEngineInstances.TryRemove(weakRef, out value);
}
else
{
foreach (CompiledAlgorithmInfo info in engine.compiledAlgorithms)
{
if (info.observedVars.Contains(var))
{
info.exec.SetObservedValue(var.NameInGeneratedCode, ((HasObservedValue)var).ObservedValue);
}
}
}
}
}
/// <summary>
/// The model compiler that this inference engine uses.
/// </summary>
public ModelCompiler Compiler
{
get { return compiler; }
}
/// <summary>
/// The default inference algorithm to use. This can be overridden for individual
/// variables or factors using the Algorithm attribute.
/// </summary>
public IAlgorithm Algorithm
{
get { return compiler.Algorithm; }
set { compiler.Algorithm = value; }
}
private int numberOfIterations = -1;
/// <summary>
/// The number of iterations to use when executing the compiled inference algorithm.
/// </summary>
public int NumberOfIterations
{
get { return (numberOfIterations < 0) ? Algorithm.DefaultNumberOfIterations : numberOfIterations; }
set { numberOfIterations = value; }
}
private bool showProgress = true;
/// <summary>
/// If true, prints progress information to the console during inference.
/// </summary>
public bool ShowProgress
{
get { return showProgress; }
set { showProgress = value; }
}
private bool showTimings = false;
/// <summary>
/// If true, prints timing information to the console during inference.
/// </summary>
public bool ShowTimings
{
get { return showTimings; }
set { showTimings = value; }
}
private bool showMsl = false;
/// <summary>
/// If true, prints the model definition in Model Specification Language (MSL), prior
/// to compiling the model.
/// </summary>
public bool ShowMsl
{
get { return showMsl; }
set { showMsl = value; }
}
/// <summary>
/// If true, any warnings encountered during model compilation will be printed to the console.
/// </summary>
public bool ShowWarnings
{
get { return compiler.ShowWarnings; }
set { compiler.ShowWarnings = value; }
}
/// <summary>
/// If true, displays the factor graph for the model, prior to compiling it.
/// </summary>
public bool ShowFactorGraph
{
get;
set;
}
/// <summary>
/// If not null, the factor graph will be saved (in DGML format) to a file in the specified folder (created if necessary) under the model name and the extension ".dgml"
/// </summary>
public string SaveFactorGraphToFolder
{
get;
set;
}
/// <summary>
/// If true, displays the schedule for the model, after the scheduler has run.
/// </summary>
public bool ShowSchedule
{
get { return Compiler.ShowSchedule; }
set { Compiler.ShowSchedule = value; }
}
/// <summary>
/// Configures this inference engine by copying the settings from the supplied inference engine.
/// </summary>
/// <param name="engine"></param>
public void SetTo(InferenceEngine engine)
{
// note this does not copy events
compiler.SetTo(engine.compiler);
groups = new List<VariableGroup>();
ModelName = engine.ModelName;
numberOfIterations = engine.numberOfIterations;
ShowFactorGraph = engine.ShowFactorGraph;
SaveFactorGraphToFolder = engine.SaveFactorGraphToFolder;
showMsl = engine.showMsl;
showProgress = engine.showProgress;
showTimings = engine.showTimings;
ResetOnObservedValueChanged = engine.ResetOnObservedValueChanged;
}
/// <summary>
/// Shows the factor manager, indicating which factors are available in Infer.NET and which
/// are supported for each built-in inference algorithm.
/// </summary>
public static void ShowFactorManager(bool showMissingEvidences)
{
ShowFactorManager(showMissingEvidences, GetBuiltInAlgorithms());
}
/// <summary>
/// Returns an array of the built-in inference algorithms.
/// </summary>
public static IAlgorithm[] GetBuiltInAlgorithms()
{
return algs;
}
/// <summary>
/// Shows the factor manager, indicating which factors are available in Infer.NET and which
/// are supported for the supplied list of inference algorithms.
/// </summary>
public static void ShowFactorManager(bool showMissingEvidences, params IAlgorithm[] algorithms)
{
if (Visualizer?.FactorManager != null)
Visualizer.FactorManager.ShowFactorManager(showMissingEvidences, algorithms);
}
/// <summary>
/// Variable groupings for the algorithm
/// </summary>
private List<VariableGroup> groups = new List<VariableGroup>();
/// <summary>
/// List of groups
/// </summary>
public IList<VariableGroup> Groups
{
get { return groups.AsReadOnly(); }
}
/// <summary>
/// Add a variable group
/// </summary>
/// <param name="variables"></param>
/// <returns></returns>
public VariableGroup Group(params Variable[] variables)
{
VariableGroup vg = VariableGroup.FromVariables(variables);
for (int i = 0; i < variables.Length; i++)
{
variables[i].AddAttribute(new Models.Attributes.GroupMember(vg, i == 0));
}
groups.Add(vg);
return vg;
}
}
}

Просмотреть файл

@ -0,0 +1,32 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
namespace Microsoft.ML.Probabilistic.Models
{
/// <summary>
/// Delegate for handlers of inference progress events.
/// </summary>
/// <param name="engine">The inference engine which invoked the inference query</param>
/// <param name="progress">The progress object describing the progress of the inference algorithm</param>
public delegate void InferenceProgressEventHandler(InferenceEngine engine, InferenceProgressEventArgs progress);
/// <summary>
/// Provides information about the progress of the inference algorithm, as it
/// is being executed.
/// </summary>
public class InferenceProgressEventArgs : EventArgs
{
/// <summary>
/// The iteration of inference that has just been completed.
/// </summary>
public int Iteration { get; internal set; }
/// <summary>
/// The compiled algorithm which is performing the inference.
/// </summary>
public IGeneratedAlgorithm Algorithm { get; internal set; }
}
}

Просмотреть файл

@ -0,0 +1,404 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Text;
using System.Reflection;
using Microsoft.ML.Probabilistic.Compiler;
using Microsoft.ML.Probabilistic.Collections;
using Microsoft.ML.Probabilistic.Compiler.CodeModel;
namespace Microsoft.ML.Probabilistic.Models
{
#if SUPPRESS_XMLDOC_WARNINGS
#pragma warning disable 1591
#endif
internal class MethodInvoke : IModelExpression
{
/// <summary>
/// Helps build class declarations
/// </summary>
private static readonly CodeBuilder Builder = CodeBuilder.Instance;
// The factor or constraint method
internal MethodInfo method;
// The arguments
internal List<IModelExpression> args = new List<IModelExpression>();
// The return value (or null if void)
internal IModelExpression returnValue;
// The operator if this method was created from an operator
internal Variable.Operator? op = null;
// Attributes of the method invocation i.e. the factor or constraint.
internal List<ICompilerAttribute> attributes = new List<ICompilerAttribute>();
// The condition blocks this method is contained in
private List<IStatementBlock> containers;
internal List<IStatementBlock> Containers
{
get { return containers; }
}
// Provides global ordering for ModelBuilder
internal readonly int timestamp;
private static readonly GlobalCounter globalCounter = new GlobalCounter();
internal static int GetTimestamp()
{
return globalCounter.GetNext();
}
internal MethodInvoke(MethodInfo method, params IModelExpression[] args)
: this(StatementBlock.GetOpenBlocks(), method, args)
{
}
internal MethodInvoke(IEnumerable<IStatementBlock> containers, MethodInfo method, params IModelExpression[] args)
{
this.timestamp = GetTimestamp();
this.method = method;
this.args.AddRange(args);
this.containers = new List<IStatementBlock>(containers);
foreach (IModelExpression arg in args)
{
if (ReferenceEquals(arg, null)) throw new ArgumentNullException();
if (arg is Variable)
{
Variable v = (Variable) arg;
if (v.IsObserved) continue;
foreach (ConditionBlock cb in v.GetContainers<ConditionBlock>())
{
if (!this.containers.Contains(cb))
{
throw new InvalidOperationException(arg + " was created in condition " + cb + " and cannot be used outside. To give " + arg +
" a conditional definition, use SetTo inside " + cb + " rather than assignment (=)");
}
}
}
}
foreach (ConditionBlock cb in StatementBlock.EnumerateBlocks<ConditionBlock>(containers))
{
cb.ConditionVariableUntyped.constraints.Add(this);
}
}
/// <summary>
/// The name of the method
/// </summary>
public string Name
{
get { return method.Name; }
}
/// <summary>
/// The method arguments
/// </summary>
public List<IModelExpression> Arguments
{
get { return args; }
}
/// <summary>
/// The expression the return value of the method will be assigned to.
/// </summary>
public IModelExpression ReturnValue
{
get { return returnValue; }
}
public void AddAttribute(ICompilerAttribute attr)
{
InferenceEngine.InvalidateAllEngines(this);
attributes.Add(attr);
}
/// <summary>
/// Inline method for adding an attribute to a method invoke. This method
/// returns the method invoke object, so that is can be used in an inline expression.
/// </summary>
/// <param name="attr">The attribute to add</param>
/// <returns>This object</returns>
public MethodInvoke Attrib(ICompilerAttribute attr)
{
AddAttribute(attr);
return this;
}
/// <summary>
/// Get all attributes of this variable having type AttributeType.
/// </summary>
/// <typeparam name="AttributeType"></typeparam>
/// <returns></returns>
public IEnumerable<AttributeType> GetAttributes<AttributeType>() where AttributeType : ICompilerAttribute
{
// find the base variable
foreach (ICompilerAttribute attr in attributes)
{
if (attr is AttributeType) yield return (AttributeType) attr;
}
}
public IExpression GetExpression()
{
IExpression expr = GetMethodInvokeExpression();
if (returnValue == null) return expr;
expr = Builder.AssignExpr(returnValue.GetExpression(), expr);
return expr;
}
/// <summary>
/// True if the expression contains a loop index and all other variable references are givens.
/// </summary>
/// <returns></returns>
internal bool CanBeInlined()
{
if (op == null) return false;
bool hasLoopIndex = false;
for (int i = 0; i < args.Count; i++)
{
if (args[i] is Variable<int>)
{
Variable<int> v = (Variable<int>) args[i];
if (v.IsLoopIndex) hasLoopIndex = true;
else if (!v.IsObserved) return false;
}
else return false;
}
return hasLoopIndex;
}
internal IExpression GetMethodInvokeExpression(bool inline = false)
{
IExpression[] argExprs = new IExpression[args.Count];
for (int i = 0; i < argExprs.Length; i++)
{
argExprs[i] = args[i].GetExpression();
}
if (inline || CanBeInlined())
{
if (op == Variable.Operator.Plus) return Builder.BinaryExpr(argExprs[0], BinaryOperator.Add, argExprs[1]);
else if (op == Variable.Operator.Minus) return Builder.BinaryExpr(argExprs[0], BinaryOperator.Subtract, argExprs[1]);
else if (op == Variable.Operator.LessThan) return Builder.BinaryExpr(argExprs[0], BinaryOperator.LessThan, argExprs[1]);
else if (op == Variable.Operator.LessThanOrEqual) return Builder.BinaryExpr(argExprs[0], BinaryOperator.LessThanOrEqual, argExprs[1]);
else if (op == Variable.Operator.GreaterThan) return Builder.BinaryExpr(argExprs[0], BinaryOperator.GreaterThan, argExprs[1]);
else if (op == Variable.Operator.GreaterThanOrEqual) return Builder.BinaryExpr(argExprs[0], BinaryOperator.GreaterThanOrEqual, argExprs[1]);
else if (op == Variable.Operator.Equal) return Builder.BinaryExpr(argExprs[0], BinaryOperator.ValueEquality, argExprs[1]);
else if (op == Variable.Operator.NotEqual) return Builder.BinaryExpr(argExprs[0], BinaryOperator.ValueInequality, argExprs[1]);
}
IMethodInvokeExpression imie = null;
if (method.IsGenericMethod && !method.ContainsGenericParameters)
{
imie = Builder.StaticGenericMethod(method, argExprs);
}
else
{
imie = Builder.StaticMethod(method, argExprs);
}
return imie;
}
public override string ToString()
{
StringBuilder sb = new StringBuilder(method.Name);
sb.Append('(');
bool isFirst = true;
foreach (IModelExpression arg in args)
{
if (!isFirst)
sb.Append(',');
else
isFirst = false;
if(arg != null)
sb.Append(arg.ToString());
}
sb.Append(')');
return sb.ToString();
}
internal IEnumerable<IModelExpression> returnValueAndArgs()
{
if (returnValue != null) yield return returnValue;
foreach (IModelExpression arg in args) yield return arg;
}
/// <summary>
/// Get the set of ranges used as indices in the arguments of the MethodInvoke, that are not included in its ForEach containers.
/// </summary>
/// <returns></returns>
internal Set<Range> GetLocalRangeSet()
{
Set<Range> ranges = new Set<Range>();
foreach (IModelExpression arg in returnValueAndArgs()) ForEachRange(arg, ranges.Add);
foreach (IStatementBlock b in containers)
{
if (b is HasRange)
{
HasRange br = (HasRange) b;
ranges.Remove(br.Range);
}
}
return ranges;
}
/// <summary>
/// Get the set of ranges used as indices in the arguments of the MethodInvoke, that are not included in its ForEach containers.
/// </summary>
/// <returns></returns>
internal List<Range> GetLocalRangeList()
{
List<Range> ranges = new List<Range>();
foreach (IModelExpression arg in returnValueAndArgs())
{
ForEachRange(arg, delegate(Range r) { if (!ranges.Contains(r)) ranges.Add(r); });
}
foreach (IStatementBlock b in containers)
{
if (b is HasRange)
{
HasRange br = (HasRange) b;
ranges.Remove(br.Range);
}
}
return ranges;
}
internal static void ForEachRange(IModelExpression arg, Action<Range> action)
{
if (arg is Range)
{
action((Range) arg);
return;
}
else if (arg is Variable)
{
Variable v = (Variable) arg;
if (v.IsLoopIndex)
{
action(v.loopRange);
}
if (v.IsArrayElement)
{
ForEachRange(v.ArrayVariable, action);
// must add item indices after array's indices
foreach (IModelExpression expr in v.indices)
{
ForEachRange(expr, action);
}
}
}
}
/// <summary>
/// Get a dictionary mapping all array indexer expressions (including sub-expressions) to a list of their Range indexes, in order.
/// </summary>
/// <param name="args"></param>
/// <returns></returns>
internal static Dictionary<IModelExpression, List<List<Range>>> GetRangeBrackets(IEnumerable<IModelExpression> args)
{
Dictionary<IModelExpression, List<List<Range>>> dict = new Dictionary<IModelExpression, List<List<Range>>>();
foreach (IModelExpression arg in args)
{
List<List<Range>> brackets = GetRangeBrackets(arg, dict);
dict[arg] = brackets;
}
return dict;
}
/// <summary>
/// If arg is an array indexer expression, get a list of all Range indexes, in order. Indexes that are not Ranges instead get their Ranges added to dict.
/// </summary>
/// <param name="arg"></param>
/// <param name="dict"></param>
/// <returns></returns>
internal static List<List<Range>> GetRangeBrackets(IModelExpression arg, IDictionary<IModelExpression, List<List<Range>>> dict)
{
if (arg is Variable)
{
Variable v = (Variable) arg;
if (v.IsArrayElement)
{
List<List<Range>> brackets = GetRangeBrackets(v.ArrayVariable, dict);
List<Range> indices = new List<Range>();
// must add item indices after array's indices
foreach (IModelExpression expr in v.indices)
{
if (expr is Range) indices.Add((Range) expr);
else
{
List<List<Range>> argBrackets = GetRangeBrackets(expr, dict);
dict[expr] = argBrackets;
}
}
brackets.Add(indices);
return brackets;
}
}
return new List<List<Range>>();
}
internal static int CompareRanges(IDictionary<IModelExpression, List<List<Range>>> dict, Range a, Range b)
{
foreach (List<List<Range>> brackets in dict.Values)
{
bool aInPreviousBracket = false;
bool bInPreviousBracket = false;
foreach (List<Range> bracket in brackets)
{
bool aInThisBracket = false;
bool bInThisBracket = false;
foreach (Range range in bracket)
{
if (range == a) aInThisBracket = true;
if (range == b) bInThisBracket = true;
}
if (bInThisBracket && aInPreviousBracket && !bInPreviousBracket) return -1;
if (aInThisBracket && bInPreviousBracket && !aInPreviousBracket) return 1;
aInPreviousBracket = aInThisBracket;
bInPreviousBracket = bInThisBracket;
}
}
return 0;
}
/// <summary>
/// True if arg is indexed by at least the given ranges.
/// </summary>
/// <param name="arg"></param>
/// <param name="ranges"></param>
/// <returns></returns>
internal static bool IsIndexedByAll(IModelExpression arg, ICollection<Range> ranges)
{
Set<Range> argRanges = new Set<Range>();
ForEachRange(arg, argRanges.Add);
foreach (Range r in ranges)
{
if (!argRanges.Contains(r)) return false;
}
return true;
}
/*internal string GetReturnValueName()
{
if (method == null) return "";
if (op != null)
{
return args[0].Name + " " + op + " " + args[1].Name;
}
string;
}*/
}
#if SUPPRESS_XMLDOC_WARNINGS
#pragma warning restore 1591
#endif
}

Просмотреть файл

@ -0,0 +1,501 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using Microsoft.ML.Probabilistic.Compiler.CodeModel;
using Microsoft.ML.Probabilistic.Compiler;
namespace Microsoft.ML.Probabilistic.Models.Attributes
{
/// <summary>
/// Specifies the range of values taken by an integer variable, or the dimension of a Dirichlet variable.
/// This attribute can be used to explicitly specify the value range for a variable
/// in cases where it cannot be deduced by the model compiler.
/// </summary>
[AttributeUsage(AttributeTargets.All, AllowMultiple = false)]
public class ValueRange : Attribute, ICompilerAttribute
{
/// <summary>
/// The range indicating the values a variable can take or the dimension of the variable.
/// </summary>
public Range Range;
/// <summary>
/// Creates a ValueRange with the specified range.
/// </summary>
/// <param name="range"></param>
public ValueRange(Range range)
{
this.Range = range;
}
/// <summary>
/// Returns a string representation of the ValueRange.
/// </summary>
/// <returns></returns>
public override string ToString()
{
return "ValueRange(" + Range + ")";
}
}
/// <summary>
/// Specifies a prototype marginal distribution for a variable. This attribute
/// can be used to explicitly specify the marginal distribution type for a variable
/// in cases where it cannot be deduced by the model compiler.
/// </summary>
[AttributeUsage(AttributeTargets.All, AllowMultiple = false)]
public class MarginalPrototype : Attribute, ICompilerAttribute
{
/// <summary>
/// The prototype marginal distribution
/// </summary>
public object prototype;
internal IExpression prototypeExpression;
/// <summary>
/// Creates a new marginal prototype attribute. This attribute
/// targets variables.
/// </summary>
/// <param name="prototype">The marginal prototype</param>
public MarginalPrototype(object prototype)
{
this.prototype = prototype;
}
/// <summary>
/// ToString override
/// </summary>
/// <returns></returns>
/// <exclude/>
public override string ToString()
{
return "MarginalPrototype(" + ((prototypeExpression != null) ? prototypeExpression : prototype) + ")";
}
}
///// <summary>
///// Attribute which indicates a sparse marginal prototype
///// </summary>
//public class Sparse : Attribute { }
/// <summary>
/// When attached to a Range, indicates that the elements of the range should be updated sequentially rather than in parallel.
/// </summary>
public class Sequential : ICompilerAttribute
{
/// <summary>
/// If true, updates should be done in both directions of the loop
/// </summary>
public bool BackwardPass;
}
/// <summary>
/// When attached to a Sequential Range, specifies which indices should be processed by each thread
/// </summary>
public class ParallelSchedule : ICompilerAttribute
{
internal IModelExpression scheduleExpression;
/// <summary>
/// Create a new ParallelSchedule attribute
/// </summary>
/// <param name="scheduleExpression">An observed variable of type int[][][], whose dimensions are [thread][block][item]. Each thread must have the same number of blocks, but blocks can be different sizes. Must have at least one thread.</param>
public ParallelSchedule(Variable<int[][][]> scheduleExpression)
{
this.scheduleExpression = scheduleExpression;
}
public override string ToString()
{
return "ParallelSchedule(" + scheduleExpression + ")";
}
}
/// <summary>
/// When attached to a Sequential Range, specifies which indices should be processed by each thread
/// </summary>
internal class ParallelScheduleExpression : ICompilerAttribute
{
internal IExpression scheduleExpression;
public ParallelScheduleExpression(IExpression scheduleExpression)
{
this.scheduleExpression = scheduleExpression;
}
public override string ToString()
{
return "ParallelScheduleExpression(" + scheduleExpression + ")";
}
}
/// <summary>
/// When attached to a Sequential Range, specifies which indices should be processed by each thread
/// </summary>
public class DistributedSchedule : ICompilerAttribute
{
internal IModelExpression commExpression;
internal IModelExpression scheduleExpression;
internal IModelExpression schedulePerThreadExpression;
/// <summary>
/// Create a new DistributedSchedule attribute
/// </summary>
/// <param name="commExpression"></param>
public DistributedSchedule(Variable<ICommunicator> commExpression)
{
this.commExpression = commExpression;
}
/// <summary>
/// Create a new DistributedSchedule attribute
/// </summary>
/// <param name="commExpression"></param>
/// <param name="scheduleExpression">An observed variable of type int[][], whose dimensions are [block][item].</param>
public DistributedSchedule(Variable<ICommunicator> commExpression, Variable<int[][]> scheduleExpression)
{
this.commExpression = commExpression;
this.scheduleExpression = scheduleExpression;
}
/// <summary>
/// Create a new DistributedSchedule attribute
/// </summary>
/// <param name="commExpression"></param>
/// <param name="schedulePerThreadExpression">An observed variable of type int[][][][], whose dimensions are [distributedStage][thread][block][item]. Each thread must have the same number of blocks, but blocks can be different sizes. Must have at least one thread.</param>
public DistributedSchedule(Variable<ICommunicator> commExpression, Variable<int[][][][]> schedulePerThreadExpression)
{
this.commExpression = commExpression;
this.schedulePerThreadExpression = schedulePerThreadExpression;
}
public override string ToString()
{
return $"DistributedSchedule({scheduleExpression}, {schedulePerThreadExpression})";
}
}
/// <summary>
/// When attached to a Sequential Range, specifies which indices should be processed by each thread
/// </summary>
internal class DistributedScheduleExpression : ICompilerAttribute
{
internal IExpression commExpression;
internal IExpression scheduleExpression;
internal IExpression schedulePerThreadExpression;
public DistributedScheduleExpression(IExpression commExpression, IExpression scheduleExpression, IExpression schedulePerThreadExpression)
{
this.commExpression = commExpression;
this.scheduleExpression = scheduleExpression;
this.schedulePerThreadExpression = schedulePerThreadExpression;
}
public override string ToString()
{
return $"DistributedScheduleExpression({scheduleExpression}, {schedulePerThreadExpression})";
}
}
/// <summary>
/// Attached to an index array.
/// </summary>
public class DistributedCommunication : ICompilerAttribute
{
internal IModelExpression arrayIndicesToSendExpression;
internal IModelExpression arrayIndicesToReceiveExpression;
/// <summary>
/// Creates a new DistributedCommunication attribute
/// </summary>
/// <param name="arrayIndicesToSendExpression"></param>
/// <param name="arrayIndicesToReceiveExpression"></param>
public DistributedCommunication(IModelExpression arrayIndicesToSendExpression, IModelExpression arrayIndicesToReceiveExpression)
{
this.arrayIndicesToSendExpression = arrayIndicesToSendExpression;
this.arrayIndicesToReceiveExpression = arrayIndicesToReceiveExpression;
}
public override string ToString()
{
return $"DistributedCommunication({arrayIndicesToSendExpression}, {arrayIndicesToReceiveExpression})";
}
}
/// <summary>
/// Attached to an index array.
/// </summary>
public class DistributedCommunicationExpression : ICompilerAttribute
{
internal IExpression arrayIndicesToSendExpression;
internal IExpression arrayIndicesToReceiveExpression;
/// <summary>
/// Creates a new DistributedCommunication attribute
/// </summary>
/// <param name="arrayIndicesToSendExpression"></param>
/// <param name="arrayIndicesToReceiveExpression"></param>
public DistributedCommunicationExpression(IExpression arrayIndicesToSendExpression, IExpression arrayIndicesToReceiveExpression)
{
this.arrayIndicesToSendExpression = arrayIndicesToSendExpression;
this.arrayIndicesToReceiveExpression = arrayIndicesToReceiveExpression;
}
public override string ToString()
{
return $"DistributedCommunicationExpression({arrayIndicesToSendExpression}, {arrayIndicesToReceiveExpression})";
}
}
/// <summary>
/// When attached to a Variable, specifies the initial forward messages to be used at the start of inference.
/// </summary>
internal class InitialiseTo : ICompilerAttribute
{
internal IExpression initialMessagesExpression;
public InitialiseTo(IExpression initialMessagesExpression)
{
this.initialMessagesExpression = initialMessagesExpression;
}
public override string ToString()
{
return "InitialiseTo(" + initialMessagesExpression + ")";
}
}
/// <summary>
/// When attached to a Variable, specifies the initial backward messages to be used at the start of inference.
/// </summary>
internal class InitialiseBackwardTo : ICompilerAttribute
{
internal IExpression initialMessagesExpression;
public InitialiseBackwardTo(IExpression initialMessagesExpression)
{
this.initialMessagesExpression = initialMessagesExpression;
}
public override string ToString()
{
return "InitialiseBackwardTo(" + initialMessagesExpression + ")";
}
}
/// <summary>
/// When attached to a Variable, indicates that the backward messages to factors with NoInit attributes should be treated as initialised by the scheduler, even though they will be initialised to uniform
/// </summary>
public class InitialiseBackward : ICompilerAttribute
{
}
/// <summary>
/// Attribute which associates a specified algorithm to a targetted variable or statement.
/// This is used for hybrid inference where different algorithms are used for different parts
/// of the model
/// </summary>
[AttributeUsage(AttributeTargets.All, AllowMultiple = false)]
public class Algorithm : Attribute, ICompilerAttribute
{
/// <summary>
/// The algorithm
/// </summary>
public IAlgorithm algorithm;
/// <summary>
/// Creates a new Algorithm attribute which assigns the given algorithm to the target
/// </summary>
/// <param name="algorithm"></param>
public Algorithm(IAlgorithm algorithm)
{
this.algorithm = algorithm;
}
/// <summary>
/// ToString override
/// </summary>
/// <returns></returns>
public override string ToString()
{
return string.Format("Algorithm({0})", algorithm);
}
}
/// <summary>
/// Attribute which associates a specified algorithm to all factors that define a variable.
/// This is used for hybrid inference where different algorithms are used for different parts
/// of the model
/// </summary>
public class FactorAlgorithm : ICompilerAttribute
{
/// <summary>
/// The algorithm
/// </summary>
public IAlgorithm algorithm;
/// <summary>
/// Creates a new Algorithm attribute which assigns the given algorithm to the target's factor
/// </summary>
/// <param name="algorithm"></param>
public FactorAlgorithm(IAlgorithm algorithm)
{
this.algorithm = algorithm;
}
/// <summary>
/// ToString override
/// </summary>
/// <returns></returns>
public override string ToString()
{
return string.Format("FactorAlgorithm({0})", algorithm);
}
}
/// <summary>
/// Group member attribute - attached to MSL variables based on
/// inference engine groups
/// </summary>
[AttributeUsage(AttributeTargets.All, AllowMultiple = true)]
public class GroupMember : Attribute, ICompilerAttribute
{
/// <summary>
/// The associated variable group
/// </summary>
public VariableGroup Group;
/// <summary>
/// This variable is a root in this group
/// </summary>
public bool IsRoot;
/// <summary>
/// Creates a group member attribute on a variable
/// </summary>
/// <param name="vg">The variable group</param>
/// <param name="isRoot">Whether this variable is the root of the group</param>
public GroupMember(VariableGroup vg, bool isRoot)
{
Group = vg;
IsRoot = isRoot;
}
/// <summary>
/// Returns a string representation of this group member attribute.
/// </summary>
/// <returns></returns>
public override string ToString()
{
string strRoot = (IsRoot) ? " (root)" : "";
return String.Format("GroupMember({0}{1})", Group, strRoot);
}
}
/// <summary>
/// When attached to a variable, indicates that the variable will not be inferred, producing more efficient generated code.
/// </summary>
public class DoNotInfer : Attribute, ICompilerAttribute
{
}
/// <summary>
/// Attribute which indicates that the output message will be recovered from
/// the targetted variable. The output message of a variable is its marginal divided by
/// its inbox message, and is used in situations where variables are shared
/// between different models
/// </summary>
[Obsolete("Use QueryTypes.MarginalDividedByPrior")]
public class Output : Attribute, ICompilerAttribute
{
}
/// <summary>
/// For expert use only! When sharing a variable between multiple models (e.g. using SharedVariable)
/// you can add this attribute to have the variable be treated as a derived variable, even if it
/// is not derived in the submodel where it appears.
/// </summary>
public class DerivedVariable : Attribute, ICompilerAttribute
{
}
/// <summary>
/// Attribute to generate trace outputs for the messages associated with the target variable
/// </summary>
public class TraceMessages : Attribute, ICompilerAttribute
{
/// <summary>
/// If non-null, only trace messages where the string representing the message expression
/// contains this string.
/// </summary>
public string Containing { get; set; }
}
/// <summary>
/// Attribute to cause message update events to be generated for the messages associated with the target variable
/// </summary>
public class ListenToMessages : Attribute, ICompilerAttribute
{
/// <summary>
/// If non-null, only trace messages where the string representing the message expression
/// contains this string.
/// </summary>
public string Containing { get; set; }
}
/// <summary>
/// Attached to Variable or MethodInvoke to give priority in the operator search path
/// </summary>
public class GivePriorityTo : ICompilerAttribute
{
public object Container;
public GivePriorityTo(object container)
{
this.Container = container;
}
public override string ToString()
{
return "GivePriorityTo(" + Container + ")";
}
}
/// <summary>
/// Attached to Variable objects to specify if outgoing messages should be computed by division
/// </summary>
public class DivideMessages : ICompilerAttribute
{
public bool useDivision;
public DivideMessages(bool useDivision = true)
{
this.useDivision = useDivision;
}
public override string ToString()
{
return "DivideMessages(" + useDivision + ")";
}
}
/// <summary>
/// Attached to Ranges to specify that only one element should be in memory at a time (per thread)
/// </summary>
public class Partitioned : ICompilerAttribute
{
}
/// <summary>
/// Attached to Variable objects to indicate that their uncertainty should be ignored during inference.
/// The inferred marginal will always be a point mass.
/// </summary>
public class PointEstimate : ICompilerAttribute
{
}
}

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -0,0 +1,377 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Text;
using Microsoft.ML.Probabilistic.Compiler;
using Microsoft.ML.Probabilistic.Utilities;
using Microsoft.ML.Probabilistic.Compiler.Reflection;
using Microsoft.ML.Probabilistic.Compiler.CodeModel;
using System.Reflection;
namespace Microsoft.ML.Probabilistic.Models
{
/// <summary>
/// A range of values from 0 to N-1. The size N may be an integer expression or constant.
/// </summary>
public class Range : IModelExpression, IStatementBlock
{
/// <summary>
/// Helps build class declarations
/// </summary>
private static CodeBuilder Builder = CodeBuilder.Instance;
/// <summary>
/// Model expression for size of the range
/// </summary>
public IModelExpression<int> Size { get; private set; }
/// <summary>
/// Range from which this range was cloned, or null if none.
/// </summary>
public Range Parent { get; private set; }
/// <summary>
/// Name
/// </summary>
protected string name;
/// <summary>
/// Name of the range
/// </summary>
public string Name
{
get { return name; }
set { name = value; }
}
/// <summary>
/// Name used in generated code
/// </summary>
protected string nameInGeneratedCode;
/// <summary>
/// Name used in generated code
/// </summary>
internal string NameInGeneratedCode
{
get
{
if (nameInGeneratedCode == null) nameInGeneratedCode = CodeBuilder.MakeValid(Name);
return nameInGeneratedCode;
}
}
string IModelExpression.Name
{
get { return NameInGeneratedCode; }
}
/// <summary>
/// The attributes associated with this Range.
/// </summary>
protected List<ICompilerAttribute> attributes = new List<ICompilerAttribute>();
/// <summary>
/// Inline method for adding an attribute to a range. This method
/// returns the range object, so that is can be used in an inline expression.
/// </summary>
/// <param name="attr">The attribute to add</param>
/// <returns>The range object</returns>
public Range Attrib(ICompilerAttribute attr)
{
AddAttribute(attr);
return this;
}
/// <summary>
/// Adds an attribute to this range. Attributes can be used
/// to modify how inference is performed on the range.
/// </summary>
/// <param name="attr">The attribute to add</param>
public void AddAttribute(ICompilerAttribute attr)
{
InferenceEngine.InvalidateAllEngines(this);
attributes.Add(attr);
}
/// <summary>
/// Get all attributes of this range having type AttributeType.
/// </summary>
/// <typeparam name="AttributeType"></typeparam>
/// <returns></returns>
public IEnumerable<AttributeType> GetAttributes<AttributeType>() where AttributeType : ICompilerAttribute
{
foreach (ICompilerAttribute attr in attributes)
{
if (attr is AttributeType) yield return (AttributeType) attr;
}
}
/// <summary>
/// Global counter used to generate variable names.
/// </summary>
private static readonly GlobalCounter globalCounter = new GlobalCounter();
/// <summary>
/// Constructs a range containing values from 0 to N-1.
/// </summary>
/// <param name="N">The number of elements in the range, including zero.</param>
public Range(int N)
: this(Variable.Constant(N))
{
}
/// <summary>
/// Constructs a range whose size is given by an integer-value expression.
/// </summary>
/// <param name="size">An expression giving the size of the range</param>
public Range(IModelExpression<int> size)
{
this.name = $"index{globalCounter.GetNext()}";
this.Size = size;
}
/// <summary>
/// Copy constructor
/// </summary>
/// <param name="parent"></param>
protected Range(Range parent)
: this(parent.Size)
{
this.Parent = parent;
}
/// <summary>
/// Create a copy of a range. The copy can be used to index the same arrays as the original range.
/// </summary>
/// <returns></returns>
public Range Clone()
{
return new Range(this);
}
/// <summary>
/// Returns the size of the range as an integer. This will fail if the size is not a constant,
/// for example, if it is a Given value.
/// </summary>
public int SizeAsInt
{
get
{
if (!(Size is Variable<int>)) throw new InvalidOperationException("The Range does not have constant size. Set IsReadOnly=true on the range size.");
Variable<int> sizeVar = (Variable<int>) Size;
if (!(sizeVar.IsObserved && sizeVar.IsReadOnly))
throw new InvalidOperationException("The Range does not have constant size. To use SizeAsInt, set IsReadOnly=true on the range size.");
return sizeVar.ObservedValue;
}
}
/// <summary>
/// ToString override
/// </summary>
/// <returns></returns>
/// <exclude/>
public override string ToString()
{
return Name;
//return "Range " + Name + " (Size=" + Size + ")";
}
/// <summary>
/// Inline method to name a range
/// </summary>
/// <param name="name">Name for the range</param>
/// <returns>this</returns>
public Range Named(string name)
{
this.name = name;
return this;
}
internal IExpression GetSizeExpression()
{
return Size.GetExpression();
}
private IVariableDeclaration index;
internal IVariableDeclaration GetIndexDeclaration()
{
if (index == null) index = Builder.VarDecl(NameInGeneratedCode, typeof (int));
return index;
}
/// <summary>
/// Gets the expression for the index variable
/// </summary>
/// <returns></returns>
public IExpression GetExpression()
{
return Builder.VarRefExpr(GetIndexDeclaration());
}
internal static string ToString(IList<Range> ranges)
{
StringBuilder sb = new StringBuilder("[");
foreach (Range r in ranges)
{
if (sb.Length > 1) sb.Append(",");
sb.Append(r.Name);
}
sb.Append("]");
return sb.ToString();
}
internal Range GetRoot()
{
Range root = this;
while (root.Parent != null) root = root.Parent;
return root;
}
private static Range ReplaceExpressions(Range r, Dictionary<IModelExpression, IModelExpression> replacements)
{
IModelExpression<int> newSize = (IModelExpression<int>)ReplaceExpressions(r.Size, replacements);
if (ReferenceEquals(newSize, r.Size))
return r;
Range newRange = new Range(newSize);
newRange.Parent = r;
replacements.Add(r, newRange);
return newRange;
}
private static IModelExpression ReplaceExpressions(IModelExpression expr, Dictionary<IModelExpression, IModelExpression> replacements)
{
if (replacements.ContainsKey(expr)) return replacements[expr];
if (expr is Range)
{
return ReplaceExpressions((Range)expr, replacements);
}
else if (expr is Variable)
{
Variable v = (Variable) expr;
if (v.IsArrayElement)
{
bool changed = false;
IVariableArray newArray = (IVariableArray) ReplaceExpressions(v.ArrayVariable, replacements);
if (!ReferenceEquals(newArray, v.ArrayVariable)) changed = true;
IModelExpression[] newIndices = new IModelExpression[v.indices.Count];
for (int i = 0; i < newIndices.Length; i++)
{
newIndices[i] = ReplaceExpressions(v.indices[i], replacements);
if (!ReferenceEquals(newIndices[i], v.indices[i])) changed = true;
}
if (changed)
return
(IModelExpression)
Invoker.InvokeMember(newArray.GetType(), "get_Item", BindingFlags.Public | BindingFlags.Instance | BindingFlags.InvokeMethod, newArray, newIndices);
}
}
return expr;
}
/// <summary>
/// Construct a new Range in which all subranges and size expressions have been replaced according to given Dictionaries.
/// </summary>
/// <param name="rangeReplacements"></param>
/// <param name="expressionReplacements">Modified on exit to contain newly created ranges</param>
/// <returns></returns>
internal Range Replace(Dictionary<Range, Range> rangeReplacements, Dictionary<IModelExpression, IModelExpression> expressionReplacements)
{
if (rangeReplacements.ContainsKey(this)) return rangeReplacements[this];
return ReplaceExpressions(this, expressionReplacements);
}
/// <summary>
/// True if index is compatible with this range
/// </summary>
/// <param name="index">Index expression</param>
/// <returns></returns>
/// <exclude/>
internal bool IsCompatibleWith(IModelExpression index)
{
if (index is Range) return (((Range) index).GetRoot() == GetRoot());
else if (index is Variable)
{
Variable indexVar = (Variable) index;
Range range = indexVar.GetValueRange(false);
if (range == null) return true;
return IsCompatibleWith(range);
}
else
{
return true;
}
}
/// <summary>
/// Throws an exception if an index expression is not valid for subscripting an array.
/// </summary>
/// <param name="index">Index expression</param>
/// <param name="array">Array that the expression is indexing</param>
/// <exclude/>
internal void CheckCompatible(IModelExpression index, IVariableArray array)
{
if (IsCompatibleWith(index)) return;
string message = StringUtil.TypeToString(array.GetType()) + " " + array + " cannot be indexed by " + index + ".";
if (index is Range)
{
string constructorName = "the constructor";
message += " Perhaps you omitted " + index + " as an argument to " + constructorName + "?";
}
throw new ArgumentException(message, "index");
}
/// <summary>
/// Throws an exception if two index expression collections do not contain the same elements (regardless of order).
/// </summary>
/// <param name="set1">First set of index expressions</param>
/// <param name="set2">Second set of index expressions</param>
/// <exclude/>
internal static void CheckCompatible(ICollection<IModelExpression> set1, ICollection<IModelExpression> set2)
{
if (set2.Count == 0)
{
if (set1.Count > 0)
throw new ArgumentException("The right-hand side is missing .ForEach(" + StringUtil.CollectionToString(set1, ",") + ")");
}
foreach (IModelExpression expr in set1)
{
if (!set2.Contains(expr))
throw new ArgumentException("The right-hand side indices " + Util.CollectionToString(set2) + " do not include the range '" + expr +
"'. Try adding .ForEach(" + expr + ")");
}
foreach (IModelExpression expr in set2)
{
if (!set1.Contains(expr))
throw new ArgumentException("The left-hand side indices " + Util.CollectionToString(set1) + " do not include the range '" + expr +
"', which appears on the right-hand side (perhaps implicitly by an open ForEach block).");
}
}
#region IStatementBlock Members
/// <summary>
/// Get 'for statement' for iterating over the range.
/// </summary>
/// <param name="innerBlock"></param>
/// <returns></returns>
internal IStatement GetStatement(out IList<IStatement> innerBlock)
{
IForStatement fs = Builder.ForStmt(GetIndexDeclaration(), GetSizeExpression());
innerBlock = fs.Body.Statements;
return fs;
}
IStatement IStatementBlock.GetStatement(out IList<IStatement> innerBlock)
{
return GetStatement(out innerBlock);
}
#endregion
}
}

Просмотреть файл

@ -0,0 +1,764 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using Microsoft.ML.Probabilistic.Distributions;
using Microsoft.ML.Probabilistic.Factors;
using Microsoft.ML.Probabilistic.Collections;
using Microsoft.ML.Probabilistic.Math;
using Microsoft.ML.Probabilistic.Utilities;
using Microsoft.ML.Probabilistic.Models.Attributes;
using Microsoft.ML.Probabilistic.Compiler;
namespace Microsoft.ML.Probabilistic.Models
{
/// <summary>
/// Abstract base class for shared variables. Shared variables allow a model to be split
/// into submodels in which variables are shared. Each submodel can have many copies.
/// </summary>
/// <typeparam name="DomainType">Domian type of the variable</typeparam>
/// <remarks>A typical use of this is for large data sets where the likelihood parts of the
/// model cannot all fit in memory. The solution is to divide the data into chunks (or 'batches'), and specify
/// a single submodel which includes the likelihood factors and variables for one chunk, along
/// with the shared parameters; the number of copies of the submodel is set to the number
/// of chunks. In a related pattern, there are one or more additional submodels for defining
/// the parameter variables.</remarks>
public abstract class SharedVariable<DomainType> : ISharedVariable
{
/// <summary>
/// Name of the shared variable.
/// </summary>
public string Name;
/// <summary>
/// Creates a shared random variable with the specified prior distribution.
/// </summary>
/// <typeparam name="DistributionType">Distribution type</typeparam>
/// <param name="prior">Prior</param>
/// <param name="divideMessages">Use division (the faster default) for calculating messages to batches</param>
/// <returns></returns>
public static SharedVariable<DomainType> Random<DistributionType>(DistributionType prior, bool divideMessages = true)
where DistributionType : IDistribution<DomainType>, Sampleable<DomainType>, SettableToProduct<DistributionType>,
ICloneable, SettableToUniform, SettableTo<DistributionType>, SettableToRatio<DistributionType>, CanGetLogAverageOf<DistributionType>
{
return new SharedVariable<DomainType, DistributionType>(prior, divideMessages);
}
/// <summary>
/// Creates a 1D array of shared random variables of size given by the specified range.
/// </summary>
/// <typeparam name="DistributionArrayType">The type of the supplied prior</typeparam>
/// <param name="range">Range.</param>
/// <param name="prior">A distribution over an array, to use as the prior.</param>
/// <param name="divideMessages">Use division (the faster default) for calculating messages to batches</param>
/// <returns></returns>
public static SharedVariableArray<DomainType> Random<DistributionArrayType>(Range range, DistributionArrayType prior, bool divideMessages = true)
where DistributionArrayType : IDistribution<DomainType[]>, Sampleable<DomainType[]>, SettableToProduct<DistributionArrayType>,
ICloneable, SettableToUniform, SettableTo<DistributionArrayType>, SettableToRatio<DistributionArrayType>, CanGetLogAverageOf<DistributionArrayType>
{
return new SharedVariableArray<DomainType, DistributionArrayType>(range, prior, divideMessages);
}
/// <summary>
/// Creates a 1D jagged array of shared random variables.
/// </summary>
/// <typeparam name="DistributionArrayType"></typeparam>
/// <param name="itemPrototype">A fresh variable object representing an array element.</param>
/// <param name="range">Outer range.</param>
/// <param name="prior">Prior for the array.</param>
/// <param name="divideMessages">Use division (the faster default) for calculating messages to batches</param>
/// <returns></returns>
public static ISharedVariableArray<VariableArray<DomainType>, DomainType[][]> Random<DistributionArrayType>(VariableArray<DomainType> itemPrototype, Range range,
DistributionArrayType prior, bool divideMessages = true)
where DistributionArrayType : IDistribution<DomainType[][]>, Sampleable<DomainType[][]>, SettableToProduct<DistributionArrayType>,
ICloneable, SettableToUniform, SettableTo<DistributionArrayType>, SettableToRatio<DistributionArrayType>, CanGetLogAverageOf<DistributionArrayType>
{
return new SharedVariableArray<VariableArray<DomainType>, DomainType[][], DistributionArrayType>(itemPrototype, range, prior, divideMessages);
}
/// <summary>
/// Creates a generic jagged array of shared random variables.
/// </summary>
/// <typeparam name="ItemType">Item type</typeparam>
/// <typeparam name="DistributionArrayType">Distribution array type</typeparam>
/// <param name="itemPrototype">A fresh variable object representing an array element.</param>
/// <param name="range">Outer range</param>
/// <param name="prior">Prior for the array.</param>
/// <param name="divideMessages">Use division (the faster default) for calculating messages to batches</param>
/// <returns></returns>
public static ISharedVariableArray<ItemType, DomainType> Random<ItemType, DistributionArrayType>(ItemType itemPrototype, Range range, DistributionArrayType prior,
bool divideMessages = true)
where ItemType : Variable, SettableTo<ItemType>, ICloneable
where DistributionArrayType : IDistribution<DomainType>, Sampleable<DomainType>, SettableToProduct<DistributionArrayType>,
ICloneable, SettableToUniform, SettableTo<DistributionArrayType>, SettableToRatio<DistributionArrayType>, CanGetLogAverageOf<DistributionArrayType>
{
return new SharedVariableArray<ItemType, DomainType, DistributionArrayType>(itemPrototype, range, prior, divideMessages);
}
/// <summary>
/// Inline method for naming a shared variable.
/// </summary>
/// <param name="name">The name</param>
/// <returns>this</returns>
public SharedVariable<DomainType> Named(string name)
{
this.Name = name;
return this;
}
/// <summary>
/// ToString override.
/// </summary>
/// <returns></returns>
/// <exclude/>
public override string ToString()
{
return Name;
}
/// <summary>
/// Get the marginal distribution for the shared variable, converted to type T.
/// </summary>
/// <typeparam name="T">The desired type</typeparam>
/// <returns></returns>
public abstract T Marginal<T>();
/// <summary>
/// Gets a copy of the variable for the specified model.
/// </summary>
/// <param name="model">Model id.</param>
/// <returns></returns>
public abstract Variable<DomainType> GetCopyFor(Model model);
/// <summary>
/// Sets the definition of the shared variable.
/// </summary>
/// <param name="model">Model id.</param>
/// <param name="definition">Defining variable.</param>
/// <returns></returns>
/// <remarks>Use this method if the model is defining the shared variable rather than
/// using one defined in this or another model.</remarks>
public abstract void SetDefinitionTo(Model model, Variable<DomainType> definition);
/// <summary>
/// Sets the shared variable's inbox for a given model and batch.
/// </summary>
/// <param name="modelNumber"></param>
/// <param name="batchNumber"></param>
public abstract void SetInput(Model modelNumber, int batchNumber);
/// <summary>
/// Infer the shared variable's output message for the given model and batch number.
/// </summary>
/// <param name="engine">The inference engine.</param>
/// <param name="modelNumber">The model id.</param>
/// <param name="batchNumber">The batch number.</param>
public abstract void InferOutput(InferenceEngine engine, Model modelNumber, int batchNumber);
/// <summary>
/// Infer the shared variable's output message for the given model and batch number.
/// </summary>
/// <param name="ca">The compiled algorithm.</param>
/// <param name="modelNumber">The model id.</param>
/// <param name="batchNumber">The batch number.</param>
public abstract void InferOutput(IGeneratedAlgorithm ca, Model modelNumber, int batchNumber);
/// <summary>
/// Gets the evidence correction for this shared variable.
/// </summary>
/// <returns></returns>
public abstract double GetEvidenceCorrection();
/// <summary>
/// Marks this shared variable as one that calculates evidence
/// </summary>
public bool IsEvidenceVariable { get; set; }
}
#if false
public interface SharedVariable<DomainType> : ISharedVariable
{
Variable<DomainType> GetCopyFor(Model model);
SharedVariable<DomainType> Named(string name);
}
#endif
/// <summary>
/// A helper class that represents a variable which is shared between multiple models.
/// For example, where a very large model has been divided into sections corresponding to
/// batches of data, an instance of this class can be used to help learn each parameter
/// shared between the batches.
/// </summary>
/// <remarks>
/// <para>
/// Shared variables are used as follows. First the shared variable is created with a prior distribution.
/// Then a copy is created for each model using the <see cref="GetCopyFor"/> method.
/// Each model has a BatchCount which is the number of data batches you want to process with that model.
/// Before performing inference in each model and batch, <see cref="SetInput"/> should be called for each shared variable.
/// After all shared variables have their inputs set, <see cref="InferOutput(InferenceEngine,Model,int)"/> should then be called for each model and batch.
/// These two steps are done automatically by <see cref="Model.InferShared(InferenceEngine,int)"/>.
/// For inference to converge, you must loop multiple times through all the models, calling <see cref="Model.InferShared(InferenceEngine,int)"/> or SetInput/InferOutput each time.
/// At any point the current marginal of the shared variable can be retrieved using <see cref="Marginal"/>.
/// </para>
/// <para>In some situations, shared variables cannot be created directly from a prior distribution, for
/// example in a hierarchical model. In these situations, create the shared variable with a uniform
/// prior, and use <see cref="SetDefinitionTo"/> to define the variable.
/// </para>
/// <para>A shared variable which calculates evidence must be treated as a special case; such variables can be marked
/// using <see cref="SharedVariable{DomainType}.IsEvidenceVariable"/>, and the evidence is recovered using <see cref="Model.GetEvidenceForAll"/></para>
/// </remarks>
/// <typeparam name="DomainType">The domain type</typeparam>
/// <typeparam name="DistributionType">The marginal distribution type</typeparam>
internal class SharedVariable<DomainType, DistributionType> : SharedVariable<DomainType>
where DistributionType : IDistribution<DomainType>, Sampleable<DomainType>, SettableToProduct<DistributionType>, SettableToRatio<DistributionType>,
ICloneable, SettableToUniform, SettableTo<DistributionType>, CanGetLogAverageOf<DistributionType>
{
/// <summary>
/// Prior
/// </summary>
protected DistributionType Prior;
/// <summary>
/// Marginal
/// </summary>
protected DistributionType CurrentMarginal;
/// <summary>
/// Dictionary of output messages keyed by model
/// </summary>
protected Dictionary<Model, DistributionType[]> Outputs = new Dictionary<Model, DistributionType[]>();
/// <summary>
/// Dictionary of variable copies indexed by model
/// </summary>
protected Dictionary<Model, Variable<DomainType>> variables = new Dictionary<Model, Variable<DomainType>>();
/// <summary>
/// Dictionary of priors indexed by model
/// </summary>
protected Dictionary<Model, Variable<DistributionType>> priors = new Dictionary<Model, Variable<DistributionType>>();
/// <summary>
/// Defining model - only one single-batch model can define a shared variable, and this is optional.
/// </summary>
protected Model DefiningModel = null;
/// <summary>
/// The algorithm
/// </summary>
protected IAlgorithm algorithm;
/// <summary>
/// Global counter used to generate variable names.
/// </summary>
private static readonly GlobalCounter globalCounter = new GlobalCounter();
/// <summary>
/// If true (the default), uses division to calculate the messages to batches.
/// This is more efficient, but may introduce round-off errors.
/// </summary>
protected bool DivideMessages;
internal SharedVariable(DistributionType prior, bool divideMessages = true)
{
this.Name = $"shared{StringUtil.TypeToString(typeof (DomainType))}{StringUtil.TypeToString(typeof (DistributionType))}{globalCounter.GetNext()}";
this.Prior = prior;
this.DivideMessages = divideMessages;
if (divideMessages)
this.CurrentMarginal = (DistributionType) this.Prior.Clone();
}
/// <summary>
/// Constructs a new shared variable with a given domain type and distribution type
/// </summary>
/// <param name="name">Name of the shared variable</param>
/// <returns></returns>
public new SharedVariable<DomainType, DistributionType> Named(string name)
{
base.Named(name);
return this;
}
#if false
SharedVariable<DomainType> SharedVariable<DomainType>.Named(string name)
{
return Named(name);
}
#endif
/// <summary>
/// Gets a copy of this shared variable for the specified model
/// </summary>
/// <param name="model">The model identifier</param>
/// <returns></returns>
public override Variable<DomainType> GetCopyFor(Model model)
{
if (model == DefiningModel)
throw new InferCompilerException("You cannot get a copy as the shared variable is defined by this model");
Variable<DomainType> v;
if (!variables.TryGetValue(model, out v))
{
Variable<DistributionType> vPrior = Variable.New<DistributionType>().Named(Name + "Prior");
vPrior.ObservedValue = default(DistributionType);
v = Variable<DomainType>.Random(vPrior).Named(Name).Attrib(QueryTypes.Marginal).Attrib(QueryTypes.MarginalDividedByPrior);
variables[model] = v;
model.SharedVariables.Add(this);
priors[model] = vPrior;
DistributionType[] messages = new DistributionType[model.BatchCount];
for (int i = 0; i < messages.Length; i++)
{
messages[i] = (DistributionType) Prior.Clone();
messages[i].SetToUniform();
}
if (DivideMessages)
CurrentMarginal = (DistributionType) Prior.Clone();
Outputs[model] = messages;
}
return v;
}
/// <summary>
/// Sets the definition of the shared variable
/// </summary>
/// <param name="model">Model id</param>
/// <param name="definition">Defining variable</param>
/// <returns></returns>
/// <remarks>Use this method if the model is defining the shared variable rather than
/// using one defined in this or another model.</remarks>
public override void SetDefinitionTo(Model model, Variable<DomainType> definition)
{
if (DefiningModel != null)
throw new InferCompilerException("You can only define a shared variable once");
if (model.BatchCount != 1)
throw new InferCompilerException("You can only define a shared variable from a model with a batch count of 1");
if (!definition.IsBase)
throw new InferCompilerException("You cannot set a shared variable to a derived variable");
Variable<DomainType> v;
if (!variables.TryGetValue(model, out v))
{
definition.AddAttribute(QueryTypes.Marginal);
Variable<DistributionType> vPrior = Variable.New<DistributionType>().Named(Name + "Constraint");
vPrior.ObservedValue = default(DistributionType);
Variable.ConstrainEqualRandom<DomainType, DistributionType>(definition, vPrior);
variables[model] = definition;
model.SharedVariables.Add(this);
priors[model] = vPrior;
DistributionType[] messages = new DistributionType[model.BatchCount];
for (int i = 0; i < messages.Length; i++)
{
messages[i] = (DistributionType) Prior.Clone();
messages[i].SetToUniform();
}
if (DivideMessages)
CurrentMarginal = (DistributionType) Prior.Clone();
// In this case, output refers to the forward message from the definition.
// There is only one as we are requiring that batch count = 1.
Outputs[model] = messages;
// This is the defining model for the variable
DefiningModel = model;
}
}
/// <summary>
/// Sets the shared variable's inbox given model and batch number
/// </summary>
/// <param name="model">Model id</param>
/// <param name="batchNumber">Batch number</param>
public override void SetInput(Model model, int batchNumber)
{
priors[model].ObservedValue = MessageToBatch(model, batchNumber);
// this version mutates the ObservedValue in place. unfortunately, if we do this the inference object will not detect that the value has changed.
//priors[model].ObservedValue = MessageToBatch(model, batchNumber, priors[model].ObservedValue);
}
/// <summary>
/// Returns the shared variable's inbox message given model and batch number
/// </summary>
/// <param name="model">Model id</param>
/// <param name="batchNumber">Batch number</param>
/// <returns>The inbox message</returns>
public DistributionType MessageToBatch(Model model, int batchNumber)
{
return MessageToBatch(model, batchNumber, default(DistributionType));
}
/// <summary>
/// Returns the shared variable's inbox message given model and batch number
/// </summary>
/// <param name="modelNumber">Model id</param>
/// <param name="batchNumber">Batch number</param>
/// <param name="result">Where to put the result</param>
/// <returns>The inbox message</returns>
public DistributionType MessageToBatch(Model modelNumber, int batchNumber, DistributionType result)
{
if (DivideMessages)
{
if (object.ReferenceEquals(result, default(DistributionType))) result = (DistributionType) CurrentMarginal.Clone();
else result.SetTo(CurrentMarginal);
if (algorithm == null) return result;
foreach (KeyValuePair<Model, DistributionType[]> entry in Outputs)
{
if (entry.Key == modelNumber)
{
// correct even for VMP
result.SetToRatio(result, entry.Value[batchNumber]);
}
}
}
else
{
if (object.ReferenceEquals(result, default(DistributionType))) result = (DistributionType) Prior.Clone();
else result.SetTo(Prior);
if (algorithm == null) return result;
foreach (KeyValuePair<Model, DistributionType[]> entry in Outputs)
{
if (entry.Key == modelNumber)
{
// correct even for VMP
result = Distribution.SetToProductWithAllExcept(result, entry.Value, batchNumber);
}
else
{
result = Distribution.SetToProductWithAll(result, entry.Value);
}
}
}
return result;
}
/// <summary>
/// Get the marginal distribution, converted to type T
/// </summary>
/// <typeparam name="T">The desired type</typeparam>
/// <returns></returns>
public override T Marginal<T>()
{
return Distribution.ChangeType<T>(Marginal());
}
/// <summary>
/// Returns the marginal distribution
/// </summary>
/// <returns></returns>
public DistributionType Marginal()
{
return MessageToBatch(null, -1, default(DistributionType));
}
/// <summary>
/// Gets the evidence correction for this shared variable
/// </summary>
/// <returns></returns>
public override double GetEvidenceCorrection()
{
List<DistributionType> uses = new List<DistributionType>();
foreach (DistributionType[] dists in Outputs.Values)
{
uses.AddRange(dists);
}
// this is correct for EP and VMP
double result = UsesEqualDefOp.LogEvidenceRatio1(uses, Prior);
if (DefiningModel != null)
{
if (!Prior.IsUniform())
throw new InferCompilerException("Shared variable has a non-uniform prior and a definition - try using a uniform prior instead");
result -= Prior.GetLogAverageOf(Prior);
}
return result;
}
/// <summary>
/// Infer the output message given a model id and a batch id
/// </summary>
/// <param name="engine">Inference engine</param>
/// <param name="modelNumber">Model number</param>
/// <param name="batchNumber">Batch number</param>
public override void InferOutput(InferenceEngine engine, Model modelNumber, int batchNumber)
{
algorithm = engine.Algorithm;
if (DivideMessages)
{
if (modelNumber != DefiningModel)
{
Outputs[modelNumber][batchNumber] = (DistributionType) engine.GetOutputMessage<DistributionType>(GetCopyFor(modelNumber)).Clone();
}
else
{
Outputs[modelNumber][batchNumber] = (DistributionType) engine.Infer<DistributionType>(variables[modelNumber]).Clone();
Outputs[modelNumber][batchNumber].SetToRatio(Outputs[modelNumber][batchNumber], priors[modelNumber].ObservedValue);
}
CurrentMarginal = engine.Infer<DistributionType>(variables[modelNumber]);
}
else
{
algorithm = engine.Algorithm;
if (modelNumber != DefiningModel)
{
Outputs[modelNumber][batchNumber] = engine.GetOutputMessage<DistributionType>(GetCopyFor(modelNumber));
}
else
{
Outputs[modelNumber][batchNumber] = engine.Infer<DistributionType>(variables[modelNumber]);
Outputs[modelNumber][batchNumber].SetToRatio(Outputs[modelNumber][batchNumber], priors[modelNumber].ObservedValue);
}
}
}
/// <summary>
/// Infer the output message given a model id and a batch id
/// </summary>
/// <param name="ca">Compiled algorithm</param>
/// <param name="modelNumber">Model number</param>
/// <param name="batchNumber">Batch number</param>
public override void InferOutput(IGeneratedAlgorithm ca, Model modelNumber, int batchNumber)
{
throw new NotImplementedException();
//if (algorithm is ExpectationPropagation && modelNumber != DefiningModel) {
// Outputs[modelNumber][batchNumber] = (DistributionType)ca.GetOutputMessage(Name);
//} else {
// Outputs[modelNumber][batchNumber] = ca.Marginal<DistributionType>(variables[modelNumber].Name);
// // this can be avoided for VMP by labelling the variable as deterministic
// Outputs[modelNumber][batchNumber].SetToRatio(Outputs[modelNumber][batchNumber], priors[modelNumber].ObservedValue);
//}
//CurrentMarginal = ca.Marginal<DistributionType>(variables[modelNumber].Name);
}
}
/// <summary>
/// Interface for shared variables
/// </summary>
public interface ISharedVariable
{
/// <summary>
/// Sets the shared variable's inbox for a given model and batch number
/// </summary>
/// <param name="modelNumber">Model id</param>
/// <param name="batchNumber">Batch number</param>
void SetInput(Model modelNumber, int batchNumber);
/// <summary>
/// Infers the shared variable's output message for a given model and batch number
/// </summary>
/// <param name="engine">Inference engine</param>
/// <param name="modelNumber">Model id</param>
/// <param name="batchNumber">Batch number</param>
void InferOutput(InferenceEngine engine, Model modelNumber, int batchNumber);
/// <summary>
/// Infers the shared variable's output message for a given model and batch number
/// </summary>
/// <param name="ca">Compiled algorithm</param>
/// <param name="modelNumber">Model id</param>
/// <param name="batchNumber">Batch number</param>
void InferOutput(IGeneratedAlgorithm ca, Model modelNumber, int batchNumber);
/// <summary>
/// Gets the evidence correction for this shared variable
/// </summary>
/// <returns></returns>
double GetEvidenceCorrection();
/// <summary>
/// Whether this shared variable is an evidence variable
/// </summary>
bool IsEvidenceVariable { get; set; }
}
/// <summary>
/// A Set of SharedVariables that allows SetInput/InferOutput to be called on all of them at once.
/// </summary>
public class SharedVariableSet : Set<ISharedVariable>, ISharedVariable
{
/// <summary>
/// Constructs a set of shared variables
/// </summary>
public SharedVariableSet()
: base()
{
}
#if false
public SharedVariableSet(IEnumerable<ISharedVariable> variables) : base(variables)
{
}
#endif
/// <summary>
/// Set inboxes, for the given model and batch number, for all
/// shared variables in this set
/// </summary>
/// <param name="modelNumber">Model id</param>
/// <param name="batchNumber">Batch number</param>
public void SetInput(Model modelNumber, int batchNumber)
{
foreach (ISharedVariable v in this)
{
v.SetInput(modelNumber, batchNumber);
}
}
/// <summary>
/// Infer the output messages, for the given model and batch number, for all
/// shared variables in this set
/// </summary>
/// <param name="engine">Inference engine</param>
/// <param name="modelNumber">Model id</param>
/// <param name="batchNumber">Batch number</param>
public void InferOutput(InferenceEngine engine, Model modelNumber, int batchNumber)
{
foreach (ISharedVariable v in this)
{
v.InferOutput(engine, modelNumber, batchNumber);
}
}
/// <summary>
/// Infer the output messages, for the given model and batch number, for all
/// shared variables in this set
/// </summary>
/// <param name="ca">Compiled algorithm</param>
/// <param name="modelNumber">Model id</param>
/// <param name="batchNumber">Batch number</param>
public void InferOutput(IGeneratedAlgorithm ca, Model modelNumber, int batchNumber)
{
foreach (ISharedVariable v in this)
{
v.InferOutput(ca, modelNumber, batchNumber);
}
}
/// <summary>
/// Gets the evidence for this set of shared variable
/// </summary>
/// <returns></returns>
public double GetEvidence()
{
double sum = 0.0;
foreach (ISharedVariable v in this)
{
if (!v.IsEvidenceVariable)
sum += v.GetEvidenceCorrection();
else
sum += ((SharedVariable<bool, Bernoulli>) v).Marginal<Bernoulli>().LogOdds;
}
return sum;
}
/// <summary>
/// Not supported for <see cref="SharedVariableSet"/>
/// </summary>
/// <returns></returns>
public double GetEvidenceCorrection()
{
throw new NotSupportedException();
}
/// <summary>
/// Not supported for <see cref="SharedVariableSet"/>
/// </summary>
public bool IsEvidenceVariable
{
get { return false; }
set { throw new Exception("Cannot be an evidence variable"); }
}
}
/// <summary>
/// A model identifier used to manage SharedVariables.
/// </summary>
public class Model
{
/// <summary>
/// The set of SharedVariables registered with this model.
/// </summary>
public SharedVariableSet SharedVariables = new SharedVariableSet();
/// <summary>
/// The number of data batches that will be processed with this model.
/// </summary>
public int BatchCount;
/// <summary>
/// Name of the model
/// </summary>
public string Name;
private static readonly GlobalCounter globalCounter = new GlobalCounter();
/// <summary>
/// Create a new model identifier to which SharedVariables can be registered.
/// </summary>
/// <param name="batchCount">The number of data batches that will be processed with this model.</param>
public Model(int batchCount)
{
Name = $"model{globalCounter.GetNext()}";
BatchCount = batchCount;
}
/// <summary>
/// Inline method for naming a shared variable model
/// </summary>
/// <param name="name"></param>
/// <returns></returns>
public Model Named(string name)
{
Name = name;
return this;
}
/// <summary>
/// Update all the SharedVariables registered with this model.
/// </summary>
/// <param name="engine"></param>
/// <param name="batchNumber">A number from 0 to BatchCount-1</param>
public void InferShared(InferenceEngine engine, int batchNumber)
{
SharedVariables.SetInput(this, batchNumber);
SharedVariables.InferOutput(engine, this, batchNumber);
}
/// <summary>
/// Update all the SharedVariables registered with this model.
/// </summary>
/// <param name="engine"></param>
/// <param name="batchNumber">A number from 0 to BatchCount-1</param>
public void InferShared(IGeneratedAlgorithm engine, int batchNumber)
{
SharedVariables.SetInput(this, batchNumber);
SharedVariables.InferOutput(engine, this, batchNumber);
}
/// <summary>
/// Gets evidence for all the specified models
/// </summary>
/// <param name="models">An array of models</param>
/// <returns></returns>
public static double GetEvidenceForAll(params Model[] models)
{
SharedVariableSet allVariables = new SharedVariableSet();
foreach (Model model in models)
{
allVariables.AddRange(model.SharedVariables);
}
return allVariables.GetEvidence();
}
/// <summary>
/// ToString override
/// </summary>
/// <returns></returns>
public override string ToString()
{
if (BatchCount == 1) return Name;
else return Name + "(" + BatchCount + ")";
}
}
}

Просмотреть файл

@ -0,0 +1,337 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using Microsoft.ML.Probabilistic.Distributions;
using Microsoft.ML.Probabilistic.Math;
namespace Microsoft.ML.Probabilistic.Models
{
/// <summary>
/// Interface for jagged 1D shared variable arrays
/// </summary>
/// <typeparam name="ItemType">Variable type of an item</typeparam>
/// <typeparam name="ArrayType">Domain type of the array</typeparam>
public interface ISharedVariableArray<ItemType, ArrayType> : ISharedVariable
where ItemType : Variable, ICloneable, SettableTo<ItemType>
{
/// <summary>
/// Get the marginal, converted to type T
/// </summary>
/// <typeparam name="T">The desired type</typeparam>
/// <returns></returns>
T Marginal<T>();
/// <summary>
/// Get a copy of the variable array for the specified model
/// </summary>
/// <param name="model">The model id</param>
/// <returns></returns>
VariableArray<ItemType, ArrayType> GetCopyFor(Model model);
/// <summary>
/// Sets the definition of the shared variable
/// </summary>
/// <param name="model">Model id</param>
/// <param name="definition">Defining variable</param>
void SetDefinitionTo(Model model, VariableArray<ItemType, ArrayType> definition);
/// <summary>
/// Inline method to name shared variable arrays
/// </summary>
/// <param name="name"></param>
/// <returns></returns>
ISharedVariableArray<ItemType, ArrayType> Named(string name);
}
/// <summary>
/// Interface for flat 1D shared variable arrays
/// </summary>
/// <typeparam name="DomainType">Domain type of the variable</typeparam>
public interface SharedVariableArray<DomainType> : ISharedVariable
//public interface SharedVariableArray<DomainType> : SharedVariableArray<VariableArray<DomainType>,DomainType[]>
{
/// <summary>
/// Get the marginal, converted to type T
/// </summary>
/// <typeparam name="T">The desired type</typeparam>
/// <returns></returns>
T Marginal<T>();
/// <summary>
/// Get a copy of the variable array for the specified model
/// </summary>
/// <param name="model">The model id</param>
/// <returns></returns>
VariableArray<DomainType> GetCopyFor(Model model);
/// <summary>
/// Sets the definition of the shared variable
/// </summary>
/// <param name="model">The model id</param>
/// <param name="definition">Defining variable</param>
/// <returns></returns>
void SetDefinitionTo(Model model, VariableArray<DomainType> definition);
/// <summary>
/// Inline method to name shared variable arrays
/// </summary>
/// <param name="name"></param>
/// <returns></returns>
SharedVariableArray<DomainType> Named(string name);
}
/// <summary>
/// A helper class that represents a variable array which is shared between multiple models.
/// For example, where a very large model has been divided into sections corresponding to
/// batches of data, an instance of this class can be used to help learn each parameter
/// shared between the batches.
/// </summary>
/// <remarks>
/// <para>
/// Shared variable arrays are used as follows. First the shared variable array is created with a prior distribution.
/// Then a copy is created for each model using the <see cref="SharedVariableArray{DomainType}.GetCopyFor(Model)"/> method.
/// Each model has a BatchCount which is the number of data batches you want to process with that model.
/// Before performing inference in each model and batch, <see cref="SharedVariable{DomainType, DistributionType}.SetInput"/> should be called for each shared variable.
/// After all shared variables have their inputs set, <see cref="SharedVariable{DomainType, DistributionType}.InferOutput(InferenceEngine,Model,int)"/> should then be called for each model and batch.
/// These two steps are done automatically by <see cref="Model.InferShared(InferenceEngine,int)"/>.
/// For inference to converge, you must loop multiple times through all the models, calling <see cref="Model.InferShared(InferenceEngine,int)"/> or SetInput/InferOutput each time.
/// At any point the current marginal of the shared variable array can be retrieved using <see cref="SharedVariable{DomainType, DistributionType}.Marginal"/>.
/// </para><para>In some situations, shared variable arrays cannot be created directly from a prior distribution, for
/// example in a hierarchical model. In these situations, create the shared variable array with a uniform
/// prior, and use <see cref="SharedVariableArray{DomainType}.SetDefinitionTo"/> to define the variable.
/// </para>
/// </remarks>
/// <typeparam name="DomainType">The domain type of an array element</typeparam>
/// <typeparam name="DistributionArrayType">The marginal distribution type of the array</typeparam>
internal class SharedVariableArray<DomainType, DistributionArrayType> : SharedVariable<DomainType[], DistributionArrayType>, SharedVariableArray<DomainType>
where DistributionArrayType : IDistribution<DomainType[]>, Sampleable<DomainType[]>, SettableToProduct<DistributionArrayType>, SettableToRatio<DistributionArrayType>,
ICloneable, SettableToUniform, SettableTo<DistributionArrayType>, CanGetLogAverageOf<DistributionArrayType>
{
/// <summary>
/// Range for the array of shared variables
/// </summary>
public Range range;
internal SharedVariableArray(Range range, DistributionArrayType prior, bool divideMessages = true)
: base(prior, divideMessages)
{
this.range = range;
}
/// <summary>
/// Inline method for naming an array of shared variables
/// </summary>
/// <param name="name">Name</param>
/// <returns>this</returns>
public new SharedVariableArray<DomainType, DistributionArrayType> Named(string name)
{
base.Named(name);
return this;
}
SharedVariableArray<DomainType> SharedVariableArray<DomainType>.Named(string name)
{
return Named(name);
}
VariableArray<DomainType> SharedVariableArray<DomainType>.GetCopyFor(Model model)
{
if (model == DefiningModel)
throw new ArgumentException("The shared variable is already defined by this model");
Variable<DomainType[]> v;
if (!variables.TryGetValue(model, out v))
{
Variable<DistributionArrayType> vPrior = Variable.New<DistributionArrayType>()
.Named(Name + "Prior");
vPrior.ObservedValue = default(DistributionArrayType);
VariableArray<DomainType> va = Variable.Array<DomainType>(range).Named(Name).Attrib(QueryTypes.MarginalDividedByPrior).Attrib(QueryTypes.Marginal);
va.SetTo(Variable<DomainType[]>.Random(vPrior));
v = va;
variables[model] = va;
model.SharedVariables.Add(this);
priors[model] = vPrior;
DistributionArrayType[] messages = new DistributionArrayType[model.BatchCount];
for (int i = 0; i < messages.Length; i++)
{
messages[i] = (DistributionArrayType) Prior.Clone();
messages[i].SetToUniform();
}
if (DivideMessages)
CurrentMarginal = (DistributionArrayType) Prior.Clone();
Outputs[model] = messages;
}
return (VariableArray<DomainType>) v;
}
void SharedVariableArray<DomainType>.SetDefinitionTo(Model model, VariableArray<DomainType> definition)
{
if (DefiningModel != null)
throw new InvalidOperationException("Shared variable is already defined");
if (model.BatchCount != 1)
throw new ArgumentException("model.BatchCount != 1");
if (!definition.IsBase)
throw new ArgumentException("definition is a derived variable");
Variable<DomainType[]> v;
if (!variables.TryGetValue(model, out v))
{
Variable<DistributionArrayType> vPrior = Variable.New<DistributionArrayType>()
.Named(Name + "Constraint");
vPrior.ObservedValue = default(DistributionArrayType);
Variable.ConstrainEqualRandom<DomainType[], DistributionArrayType>(definition, vPrior);
variables[model] = definition;
model.SharedVariables.Add(this);
priors[model] = vPrior;
DistributionArrayType[] messages = new DistributionArrayType[model.BatchCount];
for (int i = 0; i < messages.Length; i++)
{
messages[i] = (DistributionArrayType) Prior.Clone();
messages[i].SetToUniform();
}
CurrentMarginal = (DistributionArrayType) Prior.Clone();
// In this case, output refers to the forward message from the definition.
// There is only one as we are requiring that batch count = 1.
Outputs[model] = messages;
// This is the defining model for the variable
DefiningModel = model;
}
}
}
/// <summary>
/// A helper class that represents a jagged variable array which is shared between multiple models.
/// For example, where a very large model has been divided into sections corresponding to
/// batches of data, an instance of this class can be used to help learn each parameter
/// shared between the batches.
/// </summary>
/// <remarks>
/// <para>
/// Shared variable arrays are used as follows. First the shared variable array is created with a prior distribution.
/// Then a copy is created for each model using the <see cref="SharedVariableArray{DomainType}.GetCopyFor(Model)"/> method.
/// Each model has a BatchCount which is the number of data batches you want to process with that model.
/// Before performing inference in each model and batch, <see cref="SharedVariable{DomainType, DistributionType}.SetInput"/> should be called for each shared variable.
/// After all shared variables have their inputs set, <see cref="SharedVariable{DomainType, DistributionType}.InferOutput(InferenceEngine,Model,int)"/> should then be called for each model and batch.
/// These two steps are done automatically by <see cref="Model.InferShared(InferenceEngine,int)"/>.
/// For inference to converge, you must loop multiple times through all the models, calling <see cref="Model.InferShared(InferenceEngine,int)"/> or SetInput/InferOutput each time.
/// At any point the current marginal of the shared variable array can be retrieved using <see cref="SharedVariable{DomainType, DistributionType}.Marginal"/>.
/// </para><para>In some situations, shared variable arrays cannot be created directly from a prior distribution, for
/// example in a hierarchical model. In these situations, create the shared variable array with a uniform
/// prior, and use <see cref="SharedVariableArray{DomainType}.SetDefinitionTo"/> to define the variable.
/// </para>
/// </remarks>
/// <typeparam name="ItemType">The variable type of an array element</typeparam>
/// <typeparam name="ArrayType">The domain type of the array.</typeparam>
/// <typeparam name="DistributionArrayType">The marginal distribution type of the array</typeparam>
internal class SharedVariableArray<ItemType, ArrayType, DistributionArrayType> : SharedVariable<ArrayType, DistributionArrayType>,
ISharedVariableArray<ItemType, ArrayType>
where DistributionArrayType : IDistribution<ArrayType>, Sampleable<ArrayType>, SettableToProduct<DistributionArrayType>, SettableToRatio<DistributionArrayType>,
ICloneable, SettableToUniform, SettableTo<DistributionArrayType>, CanGetLogAverageOf<DistributionArrayType>
where ItemType : Variable, ICloneable, SettableTo<ItemType>
{
/// <summary>
/// Range for the array of shared variables
/// </summary>
public Range range;
private ItemType itemPrototype;
internal SharedVariableArray(ItemType itemPrototype, Range range, DistributionArrayType prior, bool divideMessages = true)
: base(prior, divideMessages)
{
this.itemPrototype = itemPrototype;
this.range = range;
}
/// <summary>
/// Inline method for naming an array of shared variables
/// </summary>
/// <param name="name">Name</param>
/// <returns>this</returns>
public new SharedVariableArray<ItemType, ArrayType, DistributionArrayType> Named(string name)
{
base.Named(name);
return this;
}
ISharedVariableArray<ItemType, ArrayType> ISharedVariableArray<ItemType, ArrayType>.Named(string name)
{
return Named(name);
}
VariableArray<ItemType, ArrayType> ISharedVariableArray<ItemType, ArrayType>.GetCopyFor(Model model)
{
if (model == DefiningModel)
throw new ArgumentException("The shared variable is already defined by this model");
Variable<ArrayType> v;
if (!variables.TryGetValue(model, out v))
{
Variable<DistributionArrayType> vPrior = Variable.New<DistributionArrayType>()
.Named(Name + "Prior");
vPrior.ObservedValue = default(DistributionArrayType);
// va's containers are obtained from itemPrototype's containers, so we must set them first
itemPrototype.Containers.Clear();
itemPrototype.Containers.AddRange(StatementBlock.GetOpenBlocks());
VariableArray<ItemType, ArrayType> va = Variable.Array<ItemType, ArrayType>(itemPrototype, range)
.Named(Name).Attrib(QueryTypes.MarginalDividedByPrior).Attrib(QueryTypes.Marginal);
va.SetTo(Variable<ArrayType>.Random(vPrior));
v = va;
variables[model] = va;
model.SharedVariables.Add(this);
priors[model] = vPrior;
DistributionArrayType[] messages = new DistributionArrayType[model.BatchCount];
for (int i = 0; i < messages.Length; i++)
{
messages[i] = (DistributionArrayType) Prior.Clone();
messages[i].SetToUniform();
}
if (DivideMessages)
CurrentMarginal = (DistributionArrayType) Prior.Clone();
Outputs[model] = messages;
}
return (VariableArray<ItemType, ArrayType>) v;
}
void ISharedVariableArray<ItemType, ArrayType>.SetDefinitionTo(Model model, VariableArray<ItemType, ArrayType> definition)
{
if (DefiningModel != null)
throw new InvalidOperationException("Shared variable is already defined");
if (model.BatchCount != 1)
throw new ArgumentException("model.BatchCount != 1");
if (!definition.IsBase)
throw new ArgumentException("definition is a derived variable");
Variable<ArrayType> v;
if (!variables.TryGetValue(model, out v))
{
definition.AddAttribute(QueryTypes.Marginal);
Variable<DistributionArrayType> vPrior = Variable.New<DistributionArrayType>()
.Named(Name + "Constraint");
vPrior.ObservedValue = default(DistributionArrayType);
Variable.ConstrainEqualRandom<ArrayType, DistributionArrayType>(definition, vPrior);
variables[model] = definition;
model.SharedVariables.Add(this);
priors[model] = vPrior;
DistributionArrayType[] messages = new DistributionArrayType[model.BatchCount];
for (int i = 0; i < messages.Length; i++)
{
messages[i] = (DistributionArrayType) Prior.Clone();
messages[i].SetToUniform();
}
// In this case, output refers to the forward message from the definition.
// There is only one as we are requiring that batch count = 1.
Outputs[model] = messages;
// This is the defining model for the variable
DefiningModel = model;
}
}
}
}

Просмотреть файл

@ -0,0 +1,170 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using Microsoft.ML.Probabilistic.Distributions;
using Microsoft.ML.Probabilistic.Math;
namespace Microsoft.ML.Probabilistic.Models
{
/// <summary>
/// Interface for flat 2D shared variable arrays
/// </summary>
/// <typeparam name="DomainType">Domain type of the variable</typeparam>
public interface SharedVariableArray2D<DomainType> : ISharedVariable
{
/// <summary>
/// Gets the marginal
/// </summary>
/// <typeparam name="DistributionType">The returned distribution array type</typeparam>
/// <returns></returns>
DistributionType Marginal<DistributionType>();
/// <summary>
/// Get a copy of the variable array for the specified model
/// </summary>
/// <param name="model">The model id</param>
/// <returns></returns>
VariableArray2D<DomainType> GetCopyFor(Model model);
/// <summary>
/// Sets a copy of the shared variable to a definition
/// </summary>
/// <param name="model">Model id</param>
/// <param name="definition">Defining variable</param>
/// <returns></returns>
void SetDefinitionTo(Model model, VariableArray2D<DomainType> definition);
/// <summary>
/// Inline method to name shared variable arrays
/// </summary>
/// <param name="name"></param>
/// <returns></returns>
SharedVariableArray2D<DomainType> Named(string name);
}
/// <summary>
/// A helper class that represents a variable array which is shared between multiple models.
/// For example, where a very large model has been divided into sections corresponding to
/// batches of data, an instance of this class can be used to help learn each parameter
/// shared between the batches.
/// </summary>
/// <remarks>
/// <para>
/// Shared variable arrays are used as follows. First the shared variable array is created with a prior distribution.
/// Then a copy is created for each model using the <see cref="SharedVariableArray{DomainType}.GetCopyFor(Model)"/> method.
/// Each model has a BatchCount which is the number of data batches you want to process with that model.
/// Before performing inference in each model and batch, <see cref="SharedVariable{DomainType, DistributionType}.SetInput"/> should be called for each shared variable.
/// After all shared variables have their inputs set, <see cref="SharedVariable{DomainType, DistributionType}.InferOutput(InferenceEngine,Model,int)"/> should then be called for each model and batch.
/// These two steps are done automatically by <see cref="Model.InferShared(InferenceEngine,int)"/>.
/// For inference to converge, you must loop multiple times through all the models, calling <see cref="Model.InferShared(InferenceEngine,int)"/> or SetInput/InferOutput each time.
/// At any point the current marginal of the shared variable array can be retrieved using <see cref="SharedVariable{DomainType, DistributionType}.Marginal"/>.
/// </para><para>In some situations, shared variable arrays cannot be created directly from a prior distribution, for
/// example in a hierarchical model. In these situations, create the shared variable array with a uniform
/// prior, and use <see cref="SharedVariableArray2D{DomainType}.SetDefinitionTo"/> to define the variable.
/// </para>
/// </remarks>
/// <typeparam name="DomainType">The domain type of an array element</typeparam>
/// <typeparam name="DistributionArrayType">The marginal distribution type of the array</typeparam>
internal class SharedVariableArray2D<DomainType, DistributionArrayType> : SharedVariable<DomainType[,], DistributionArrayType>, SharedVariableArray2D<DomainType>
where DistributionArrayType : IDistribution<DomainType[,]>, Sampleable<DomainType[,]>, SettableToProduct<DistributionArrayType>, SettableToRatio<DistributionArrayType>
,
ICloneable, SettableToUniform, SettableTo<DistributionArrayType>, CanGetLogAverageOf<DistributionArrayType>
{
/// <summary>
/// Ranges for the array of shared variables
/// </summary>
public Range range0, range1;
internal SharedVariableArray2D(Range range0, Range range1, DistributionArrayType prior, bool divideMessages = true)
: base(prior, divideMessages)
{
this.range0 = range0;
this.range1 = range1;
}
/// <summary>
/// Inline method for naming an array of shared variables
/// </summary>
/// <param name="name">Name</param>
/// <returns>this</returns>
public new SharedVariableArray2D<DomainType, DistributionArrayType> Named(string name)
{
base.Named(name);
return this;
}
SharedVariableArray2D<DomainType> SharedVariableArray2D<DomainType>.Named(string name)
{
return Named(name);
}
VariableArray2D<DomainType> SharedVariableArray2D<DomainType>.GetCopyFor(Model model)
{
if (model == DefiningModel)
throw new ArgumentException("The shared variable is already defined by this model");
Variable<DomainType[,]> v;
if (!variables.TryGetValue(model, out v))
{
Variable<DistributionArrayType> vPrior = Variable.New<DistributionArrayType>()
.Named(Name + "Prior");
vPrior.ObservedValue = default(DistributionArrayType);
VariableArray2D<DomainType> va = Variable.Array<DomainType>(range0, range1).Named(Name).Attrib(QueryTypes.MarginalDividedByPrior).Attrib(QueryTypes.Marginal);
va.SetTo(Variable<DomainType[,]>.Random(vPrior));
v = va;
variables[model] = va;
model.SharedVariables.Add(this);
priors[model] = vPrior;
DistributionArrayType[] messages = new DistributionArrayType[model.BatchCount];
for (int i = 0; i < messages.Length; i++)
{
messages[i] = (DistributionArrayType) Prior.Clone();
messages[i].SetToUniform();
}
if (DivideMessages)
CurrentMarginal = (DistributionArrayType) Prior.Clone();
Outputs[model] = messages;
}
return (VariableArray2D<DomainType>) v;
}
void SharedVariableArray2D<DomainType>.SetDefinitionTo(Model model, VariableArray2D<DomainType> definition)
{
if (DefiningModel != null)
throw new InvalidOperationException("Shared variable is already defined");
if (model.BatchCount != 1)
throw new ArgumentException("model.BatchCount != 1");
if (!definition.IsBase)
throw new ArgumentException("definition is a derived variable");
Variable<DomainType[,]> v;
if (!variables.TryGetValue(model, out v))
{
Variable<DistributionArrayType> vPrior = Variable.New<DistributionArrayType>()
.Named(Name + "Constraint");
vPrior.ObservedValue = default(DistributionArrayType);
Variable.ConstrainEqualRandom<DomainType[,], DistributionArrayType>(definition, vPrior);
variables[model] = definition;
model.SharedVariables.Add(this);
priors[model] = vPrior;
DistributionArrayType[] messages = new DistributionArrayType[model.BatchCount];
for (int i = 0; i < messages.Length; i++)
{
messages[i] = (DistributionArrayType) Prior.Clone();
messages[i].SetToUniform();
}
CurrentMarginal = (DistributionArrayType) Prior.Clone();
// In this case, output refers to the forward message from the definition.
// There is only one as we are requiring that batch count = 1.
Outputs[model] = messages;
// This is the defining model for the variable
DefiningModel = model;
}
}
}
}

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Некоторые файлы не были показаны из-за слишком большого количества измененных файлов Показать больше