Merge branch 'master' into InnerProduct

This commit is contained in:
Tom Minka 2019-09-21 14:25:03 +01:00 коммит произвёл GitHub
Родитель a280a6272f cfe2f72f98
Коммит e642655221
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
23 изменённых файлов: 929 добавлений и 171 удалений

Просмотреть файл

@ -86,6 +86,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "CrowdsourcingWithWords", "s
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "ReviewerCalibration", "src\Examples\ReviewerCalibration\ReviewerCalibration.csproj", "{7D7EA3FD-8D1A-4E07-B34C-5B2B86CAFCAA}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "RobustGaussianProcess", "src\Examples\RobustGaussianProcess\RobustGaussianProcess.csproj", "{62719C63-ECED-49CD-B73C-6DAF2A18033B}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@ -470,6 +472,18 @@ Global
{7D7EA3FD-8D1A-4E07-B34C-5B2B86CAFCAA}.ReleaseCore|Any CPU.Build.0 = ReleaseCore|Any CPU
{7D7EA3FD-8D1A-4E07-B34C-5B2B86CAFCAA}.ReleaseFull|Any CPU.ActiveCfg = ReleaseFull|Any CPU
{7D7EA3FD-8D1A-4E07-B34C-5B2B86CAFCAA}.ReleaseFull|Any CPU.Build.0 = ReleaseFull|Any CPU
{62719C63-ECED-49CD-B73C-6DAF2A18033B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{62719C63-ECED-49CD-B73C-6DAF2A18033B}.Debug|Any CPU.Build.0 = Debug|Any CPU
{62719C63-ECED-49CD-B73C-6DAF2A18033B}.DebugCore|Any CPU.ActiveCfg = DebugCore|Any CPU
{62719C63-ECED-49CD-B73C-6DAF2A18033B}.DebugCore|Any CPU.Build.0 = DebugCore|Any CPU
{62719C63-ECED-49CD-B73C-6DAF2A18033B}.DebugFull|Any CPU.ActiveCfg = DebugFull|Any CPU
{62719C63-ECED-49CD-B73C-6DAF2A18033B}.DebugFull|Any CPU.Build.0 = DebugFull|Any CPU
{62719C63-ECED-49CD-B73C-6DAF2A18033B}.Release|Any CPU.ActiveCfg = Release|Any CPU
{62719C63-ECED-49CD-B73C-6DAF2A18033B}.Release|Any CPU.Build.0 = Release|Any CPU
{62719C63-ECED-49CD-B73C-6DAF2A18033B}.ReleaseCore|Any CPU.ActiveCfg = ReleaseCore|Any CPU
{62719C63-ECED-49CD-B73C-6DAF2A18033B}.ReleaseCore|Any CPU.Build.0 = ReleaseCore|Any CPU
{62719C63-ECED-49CD-B73C-6DAF2A18033B}.ReleaseFull|Any CPU.ActiveCfg = ReleaseFull|Any CPU
{62719C63-ECED-49CD-B73C-6DAF2A18033B}.ReleaseFull|Any CPU.Build.0 = ReleaseFull|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
@ -496,6 +510,7 @@ Global
{816CD64D-7189-46E3-8C54-D4ED4C0BB758} = {DC5F5BC4-CDB0-41F7-8B03-CD4C38C8DEB2}
{6563C4C6-411E-4D67-B458-830AD9B311D2} = {DC5F5BC4-CDB0-41F7-8B03-CD4C38C8DEB2}
{7D7EA3FD-8D1A-4E07-B34C-5B2B86CAFCAA} = {DC5F5BC4-CDB0-41F7-8B03-CD4C38C8DEB2}
{62719C63-ECED-49CD-B73C-6DAF2A18033B} = {DC5F5BC4-CDB0-41F7-8B03-CD4C38C8DEB2}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {160F773C-9CF5-4F8D-B45A-1112A1BC5E16}

Просмотреть файл

@ -0,0 +1,69 @@
Auto Insurance in Sweden
In the following data
X = number of claims
Y = total payment for all the claims in thousands of Swedish Kronor
108, 392.5
19, 46.2
13, 15.7
124, 422.2
40, 119.4
57, 170.9
23, 56.9
14, 77.5
45, 214
10, 65.3
5, 20.9
48, 248.1
11, 23.5
23, 39.6
7, 48.8
2, 6.6
24, 134.9
6, 50.9
3, 4.4
23, 113
6, 14.8
9, 48.7
9, 52.1
3, 13.2
29, 103.9
7, 77.5
4, 11.8
20, 98.1
7, 27.9
4, 38.1
0, 0
25, 69.2
6, 14.6
5, 40.3
22, 161.5
11, 57.2
61, 217.6
12, 58.1
4, 12.6
16, 59.6
13, 89.9
60, 202.4
41, 181.3
37, 152.8
55, 162.8
41, 73.4
11, 21.3
27, 92.6
8, 76.1
3, 39.9
17, 142.1
13, 93
13, 31.9
15, 32.1
8, 55.6
29, 133.3
30, 194.5
24, 137.9
9, 87.4
31, 209.8
14, 95.5
53, 244.6
26, 187.5
1 Auto Insurance in Sweden
2 In the following data
3 X = number of claims
4 Y = total payment for all the claims in thousands of Swedish Kronor
5
6 108, 392.5
7 19, 46.2
8 13, 15.7
9 124, 422.2
10 40, 119.4
11 57, 170.9
12 23, 56.9
13 14, 77.5
14 45, 214
15 10, 65.3
16 5, 20.9
17 48, 248.1
18 11, 23.5
19 23, 39.6
20 7, 48.8
21 2, 6.6
22 24, 134.9
23 6, 50.9
24 3, 4.4
25 23, 113
26 6, 14.8
27 9, 48.7
28 9, 52.1
29 3, 13.2
30 29, 103.9
31 7, 77.5
32 4, 11.8
33 20, 98.1
34 7, 27.9
35 4, 38.1
36 0, 0
37 25, 69.2
38 6, 14.6
39 5, 40.3
40 22, 161.5
41 11, 57.2
42 61, 217.6
43 12, 58.1
44 4, 12.6
45 16, 59.6
46 13, 89.9
47 60, 202.4
48 41, 181.3
49 37, 152.8
50 55, 162.8
51 41, 73.4
52 11, 21.3
53 27, 92.6
54 8, 76.1
55 3, 39.9
56 17, 142.1
57 13, 93
58 13, 31.9
59 15, 32.1
60 8, 55.6
61 29, 133.3
62 30, 194.5
63 24, 137.9
64 9, 87.4
65 31, 209.8
66 14, 95.5
67 53, 244.6
68 26, 187.5

Просмотреть файл

@ -0,0 +1,73 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using Microsoft.ML.Probabilistic.Models;
using Microsoft.ML.Probabilistic.Math;
using Microsoft.ML.Probabilistic.Distributions;
using Microsoft.ML.Probabilistic.Distributions.Kernels;
using System;
using System.Linq;
namespace RobustGaussianProcess
{
/// <summary>
/// Class to generate synthetic data
/// 1) randomly sample a 1D function from a GP;
/// 2) pick a random subset of 'numData' points;
/// 3) pick a further random proportion 'propCorrupt' of 'numData' to corrupt according to a uniform distribution with a range of -1 to 1
/// </summary>
/// <summary>
/// Class to generate synthetic data:
/// 1) randomly sample a 1D function from a GP;
/// 2) pick a random subset of 'numData' points;
/// 3) pick a further random proportion 'proportionCorrupt' of 'numData' to corrupt
///    according to a uniform distribution with a range of -1 to 1.
/// </summary>
class GaussianProcessDataGenerator
{
    /// <summary>
    /// Generates noisy synthetic data by sampling a random function from a GP prior
    /// and corrupting a random subset of the outputs.
    /// </summary>
    /// <param name="numData">Number of data points to generate.</param>
    /// <param name="proportionCorrupt">Proportion (0..1) of points to corrupt.</param>
    /// <returns>Input vectors and the corresponding (possibly corrupted) outputs.</returns>
    public static (Vector[] dataX, double[] dataY) GenerateRandomData(int numData, double proportionCorrupt)
    {
        // Fixed seeds so the generated dataset is reproducible
        int randomSeed = 9876;
        Random rng = new Random(randomSeed);
        Rand.Restart(randomSeed);

        InferenceEngine engine = Utilities.GetInferenceEngine();

        // The points to evaluate
        Vector[] randomInputs = Utilities.VectorRange(0, 1, numData, null);
        var gaussianProcessGenerator = new GaussianProcessRegressor(randomInputs);

        // The basis
        Vector[] basis = Utilities.VectorRange(0, 1, 6, rng);

        // The kernel
        var kf = new SummationKernel(new SquaredExponential(-1)) + new WhiteNoise();

        // Fill in the sparse GP prior
        GaussianProcess gp = new GaussianProcess(new ConstantFunction(0), kf);
        gaussianProcessGenerator.Prior.ObservedValue = new SparseGP(new SparseGPFixed(gp, basis));

        // Infer the posterior Sparse GP, and sample a random function from it
        SparseGP sgp = engine.Infer<SparseGP>(gaussianProcessGenerator.F);
        var randomFunc = sgp.Sample();

        double[] randomOutputs = new double[randomInputs.Length];
        int numCorrupted = (int)Math.Ceiling(numData * proportionCorrupt);

        // Pick exactly 'numCorrupted' distinct indices to corrupt.
        // Materialize the query: the original deferred query re-shuffled on every
        // Contains() call (OrderBy with a random key is re-evaluated per enumeration),
        // so a different subset was consulted at each loop iteration. The range is also
        // limited to valid indices (the original Range(0, Length + 1) could select
        // index numData, which never matches any point, corrupting fewer points than
        // reported).
        var subset = Enumerable.Range(0, randomInputs.Length)
            .OrderBy(x => rng.Next())
            .Take(numCorrupted)
            .ToList();

        // Get random data
        for (int i = 0; i < randomInputs.Length; i++)
        {
            double post = randomFunc.Evaluate(randomInputs[i]);

            // Corrupt this data point if it was selected above
            if (subset.Contains(i))
            {
                double sign = rng.NextDouble() > 0.5 ? 1 : -1;
                double distance = rng.NextDouble();
                post = (sign * distance) + post;
            }

            randomOutputs[i] = post;
        }

        Console.WriteLine("Model complete: Generated {0} points with {1} corrupted", numData, numCorrupted);
        return (randomInputs, randomOutputs);
    }
}
}

Просмотреть файл

@ -0,0 +1,86 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using Microsoft.ML.Probabilistic.Distributions;
using Microsoft.ML.Probabilistic.Math;
using Microsoft.ML.Probabilistic.Models;
using System;
namespace RobustGaussianProcess
{
/// <summary>
/// Builds the Infer.NET model for Gaussian Process regression.
/// The whole model sits inside an evidence gate (Variable.If) so the model
/// evidence can be queried after inference.
/// </summary>
public class GaussianProcessRegressor
{
/// <summary>Sparse GP prior over functions; its ObservedValue must be set by the caller before inference.</summary>
public Variable<SparseGP> Prior { get; }
/// <summary>Evidence variable gating the model; infer its Bernoulli marginal to obtain model evidence.</summary>
public Variable<bool> Evidence { get; }
/// <summary>Latent function drawn from <see cref="Prior"/>.</summary>
public Variable<IFunction> F { get; }
// The open evidence block; closed either in this constructor (closeBlock == true)
// or by the training constructor after the likelihood has been attached
public IfBlock Block;
/// <summary>Observed training inputs.</summary>
public VariableArray<Vector> X { get; }
// Range over the training points
public Range J;
/// <summary>
/// Helper class to instantiate Gaussian Process regressor
/// </summary>
/// <param name="trainingInputs">Training input points (each a 1D value wrapped in a vector).</param>
/// <param name="closeBlock">Whether to close the evidence block immediately (true when only generating data).</param>
public GaussianProcessRegressor(Vector[] trainingInputs, bool closeBlock = true)
{
// Modelling code
Evidence = Variable.Bernoulli(0.5).Named("evidence");
Block = Variable.If(Evidence);
Prior = Variable.New<SparseGP>().Named("prior");
F = Variable<IFunction>.Random(Prior).Named("f");
X = Variable.Observed(trainingInputs).Named("x");
J = X.Range.Named("j");
// If generating data, we can close block here
if (closeBlock)
{
Block.CloseBlock();
}
}
/// <summary>
/// Builds the regression model with observed outputs, using either a standard
/// Gaussian likelihood or an outlier-robust Student-t likelihood.
/// </summary>
/// <param name="trainingInputs">Training input points.</param>
/// <param name="useStudentTLikelihood">True to use the Student-t likelihood.</param>
/// <param name="trainingOutputs">Observed training targets.</param>
public GaussianProcessRegressor(Vector[] trainingInputs, bool useStudentTLikelihood, double[] trainingOutputs) : this(trainingInputs, closeBlock: false)
{
VariableArray<double> y = Variable.Observed(trainingOutputs, J).Named("y");
if (!useStudentTLikelihood)
{
// Standard Gaussian Process
Console.WriteLine("Training a Gaussian Process regressor");
var score = GetScore(X, F, J);
y[J] = Variable.GaussianFromMeanAndVariance(score, 0.0);
}
else
{
// Gaussian Process with Student-t likelihood
Console.WriteLine("Training a Gaussian Process regressor with Student-t likelihood");
var noisyScore = GetNoisyScore(X, F, J, trainingOutputs);
y[J] = Variable.GaussianFromMeanAndVariance(noisyScore[J], 0.0);
}
Block.CloseBlock();
}
/// <summary>
/// Score for standard Gaussian Process
/// </summary>
private static Variable<double> GetScore(VariableArray<Vector> x, Variable<IFunction> f, Range j)
{
return Variable.FunctionEvaluate(f, x[j]);
}
/// <summary>
/// Score for Gaussian Process with Student-t
/// </summary>
private static VariableArray<double> GetNoisyScore(VariableArray<Vector> x, Variable<IFunction> f, Range j, double[] trainingOutputs)
{
// The student-t distribution arises as the mean of a normal distribution once an unknown precision is marginalised out
Variable<double> score = GetScore(x, f, j);
VariableArray<double> noisyScore = Variable.Observed(trainingOutputs, j).Named("noisyScore");
using (Variable.ForEach(j))
{
// The precision of the Gaussian is modelled with a Gamma distribution
var precision = Variable.GammaFromShapeAndRate(4, 1).Named("precision");
noisyScore[j] = Variable.GaussianFromMeanAndPrecision(score, precision);
}
return noisyScore;
}
}
}

Просмотреть файл

@ -0,0 +1,68 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using Microsoft.ML.Probabilistic.Distributions;
using Microsoft.ML.Probabilistic.Distributions.Kernels;
using Microsoft.ML.Probabilistic.Math;
using Microsoft.ML.Probabilistic.Models;
using System;
using System.Linq;
namespace RobustGaussianProcess
{
class Program
{
/// <summary>
/// Main for Gaussian Process regression example
/// Fits two datasets (real and synthetic) using a standard Gaussian Process and a Robust Student-T Gaussian Process
/// </summary>
static void Main()
{
// NOTE(review): the real (AIS) dataset run is commented out, so only the synthetic
// dataset is fitted despite the summary above — confirm whether this is intentional
//FitDataset(useSynthetic: false);
FitDataset(useSynthetic: true);
}
/// <summary>
/// Fits the chosen dataset twice: first with a standard Gaussian likelihood,
/// then with a Student-t likelihood (robust to outliers).
/// </summary>
/// <param name="useSynthetic">True to generate synthetic data; false to load the AIS CSV dataset.</param>
static void FitDataset(bool useSynthetic)
{
Vector[] trainingInputs;
double[] trainingOutputs;
if (!useSynthetic)
{
// Load the 'Auto Insurance in Sweden' dataset, wrapping each scalar input in a 1D vector
var trainingData = Utilities.LoadAISDataset();
trainingInputs = trainingData.Select(tup => Vector.FromArray(new double[1] { tup.x })).ToArray();
trainingOutputs = trainingData.Select(tup => tup.y).ToArray();
}
else
{
// 30 synthetic points, 30% of them corrupted
(trainingInputs, trainingOutputs) = GaussianProcessDataGenerator.GenerateRandomData(30, 0.3);
}
InferenceEngine engine = Utilities.GetInferenceEngine();
// First fit standard GP, then fit Student-T GP
foreach (var useStudentTLikelihood in new[] { false, true })
{
var gaussianProcessRegressor = new GaussianProcessRegressor(trainingInputs, useStudentTLikelihood, trainingOutputs);
// Log length scale estimated as -1
var noiseVariance = 0.8;
var kf = new SummationKernel(new SquaredExponential(-1)) + new WhiteNoise(Math.Log(noiseVariance) / 2);
GaussianProcess gp = new GaussianProcess(new ConstantFunction(0), kf);
// Convert SparseGP to full Gaussian Process by evaluating at all the training points
gaussianProcessRegressor.Prior.ObservedValue = new SparseGP(new SparseGPFixed(gp, trainingInputs.ToArray()));
// Model evidence (log-odds of the evidence gate), used to compare the two likelihoods
double logOdds = engine.Infer<Bernoulli>(gaussianProcessRegressor.Evidence).LogOdds;
Console.WriteLine("{0} evidence = {1}", kf, logOdds.ToString("g4"));
// Infer the posterior Sparse GP
SparseGP sgp = engine.Infer<SparseGP>(gaussianProcessRegressor.F);
#if NETFULL
// Plotting uses OxyPlot.Wpf, which is only available on the full framework
string datasetName = useSynthetic ? "Synthetic" : "AIS";
Utilities.PlotPredictions(sgp, trainingInputs, trainingOutputs, useStudentTLikelihood, datasetName);
#endif
}
}
}
}

Просмотреть файл

@ -0,0 +1,79 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<GenerateAssemblyInfo>false</GenerateAssemblyInfo>
<OutputType>Exe</OutputType>
<WarningLevel>4</WarningLevel>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
<WarningsAsErrors />
<ErrorReport>prompt</ErrorReport>
<Prefer32Bit>false</Prefer32Bit>
<DefineConstants>TRACE</DefineConstants>
<Configurations>Debug;Release;DebugFull;DebugCore;ReleaseFull;ReleaseCore</Configurations>
</PropertyGroup>
<Choose>
<When Condition="'$(Configuration)'=='DebugFull' OR '$(Configuration)'=='ReleaseFull'">
<PropertyGroup>
<TargetFramework>net461</TargetFramework>
</PropertyGroup>
</When>
<When Condition="'$(Configuration)'=='DebugCore' OR '$(Configuration)'=='ReleaseCore'">
<PropertyGroup>
<TargetFramework>netcoreapp2.1</TargetFramework>
</PropertyGroup>
</When>
<Otherwise>
<PropertyGroup>
<TargetFrameworks>netcoreapp2.1;net461</TargetFrameworks>
</PropertyGroup>
</Otherwise>
</Choose>
<PropertyGroup Condition=" '$(TargetFramework)' == 'netcoreapp2.1'">
<DefineConstants>$(DefineConstants);NETCORE;NETSTANDARD;NETSTANDARD2_0</DefineConstants>
</PropertyGroup>
<PropertyGroup Condition=" '$(TargetFramework)' == 'net461'">
<DefineConstants>$(DefineConstants);NETFULL</DefineConstants>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU' OR '$(Configuration)|$(Platform)'=='DebugFull|AnyCPU' OR '$(Configuration)|$(Platform)'=='DebugCore|AnyCPU'">
<DebugSymbols>true</DebugSymbols>
<DebugType>full</DebugType>
<Optimize>false</Optimize>
<DefineConstants>$(DefineConstants);DEBUG</DefineConstants>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'Release|AnyCPU' OR '$(Configuration)|$(Platform)'=='ReleaseFull|AnyCPU' OR '$(Configuration)|$(Platform)'=='ReleaseCore|AnyCPU'">
<DebugType>pdbonly</DebugType>
<Optimize>true</Optimize>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='DebugFull|AnyCPU'">
<PlatformTarget>AnyCPU</PlatformTarget>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='ReleaseFull|AnyCPU'">
<PlatformTarget>AnyCPU</PlatformTarget>
</PropertyGroup>
<ItemGroup>
<ProjectReference Include="..\..\Compiler\Compiler.csproj" />
<ProjectReference Include="..\..\Runtime\Runtime.csproj" />
</ItemGroup>
<ItemGroup>
<Compile Include="..\..\Shared\SharedAssemblyFileVersion.cs" />
<Compile Include="..\..\Shared\SharedAssemblyInfo.cs" />
</ItemGroup>
<ItemGroup Condition="$(DefineConstants.Contains('NETFULL'))">
<PackageReference Include="OxyPlot.Wpf" Version="1.0.0" />
</ItemGroup>
<ItemGroup>
<None Update="Data\insurance.csv">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
</ItemGroup>
</Project>

Просмотреть файл

@ -0,0 +1,176 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using Microsoft.ML.Probabilistic.Algorithms;
using Microsoft.ML.Probabilistic.Math;
using Microsoft.ML.Probabilistic.Models;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
#if NETFULL
using OxyPlot.Wpf;
using OxyPlot;
using System.Threading;
using Microsoft.ML.Probabilistic.Distributions;
#endif
namespace RobustGaussianProcess
{
/// <summary>
/// Shared helpers for the Gaussian Process example: data generation, CSV loading,
/// preprocessing, engine setup and (on the full framework) plotting.
/// </summary>
class Utilities
{
    // Train Gaussian Process on the small 'Auto Insurance in Sweden' dataset
    // The insurance.csv file can be found in the Data directory
    private const string AisCsvPath = @"Data\insurance.csv";

    // Path for the results plot
    private const string OutputPlotPath = @"Data";

    /// <summary>
    /// Generates a 1D vector with length len having a min and max; data points are
    /// randomly distributed when rng is supplied, evenly spaced otherwise, and are
    /// always returned in ascending order.
    /// </summary>
    /// <param name="min">Lower bound of the range.</param>
    /// <param name="max">Upper bound of the range.</param>
    /// <param name="len">Number of points to generate.</param>
    /// <param name="rng">Random source, or null for an evenly spaced grid.</param>
    public static Vector[] VectorRange(double min, double max, int len, Random rng)
    {
        var inputs = new double[len];
        for (int i = 0; i < len; i++)
        {
            double num;
            if (rng != null)
            {
                num = rng.NextDouble();
            }
            else
            {
                // Evenly spaced in [0, 1]; guard against division by zero when len == 1
                num = (len == 1) ? 0 : i / (double)(len - 1);
            }
            inputs[i] = (num * (max - min)) + min;
        }
        if (rng != null)
        {
            // Random draws are unordered; sort so inputs are ascending like the grid case
            inputs = inputs.OrderBy(x => x).ToArray();
        }
        return inputs.Select(x => Vector.FromArray(new double[1] { x })).ToArray();
    }

    /// <summary>
    /// Read the Auto Insurance in Sweden dataset from its CSV file.
    /// Lines without exactly two comma-separated values (e.g. the header lines)
    /// are skipped. Returns the preprocessed (centred/scaled) data.
    /// </summary>
    public static IEnumerable<(double x, double y)> LoadAISDataset()
    {
        var data = new List<(double x, double y)>();

        // Read CSV file
        using (var reader = new StreamReader(AisCsvPath))
        {
            while (!reader.EndOfStream)
            {
                var line = reader.ReadLine();
                var values = line.Split(',');
                if (values.Length == 2)
                {
                    // Parse with the invariant culture: the file uses '.' as the decimal
                    // separator, which double.Parse would reject under some locales
                    data.Add((
                        double.Parse(values[0], System.Globalization.CultureInfo.InvariantCulture),
                        double.Parse(values[1], System.Globalization.CultureInfo.InvariantCulture)));
                }
            }
        }

        return PreprocessData(data).ToList();
    }

    /// <summary>
    /// Centre and scale the data: targets are shifted to zero mean and scaled to lie
    /// in [-1, 1]; inputs are divided by their maximum. Result is ordered by input.
    /// </summary>
    private static IEnumerable<(double x, double y)> PreprocessData(
        IEnumerable<(double x, double y)> data)
    {
        // Materialize once so the aggregates below don't repeatedly re-enumerate
        // a lazy source
        var x = data.Select(tup => tup.x).ToArray();
        var y = data.Select(tup => tup.y).ToArray();

        // Shift targets so mean is 0
        var meanY = y.Average();
        var centredY = y.Select(val => val - meanY).ToArray();

        // Scale data to lie between 1 and -1
        var absoluteMaxY = centredY.Max(val => Math.Abs(val));
        var scaledY = centredY.Select(val => val / absoluteMaxY);

        var maxX = x.Max();
        var scaledX = x.Select(val => val / maxX);

        // Order data by input value
        return scaledX.Zip(scaledY, (a, b) => (a, b)).OrderBy(tup => tup.Item1);
    }

    /// <summary>
    /// Creates the inference engine, verifying that the default algorithm is
    /// Expectation Propagation (the only algorithm this example supports).
    /// </summary>
    public static InferenceEngine GetInferenceEngine()
    {
        InferenceEngine engine = new InferenceEngine();
        if (!(engine.Algorithm is ExpectationPropagation))
        {
            throw new ArgumentException("This example only runs with Expectation Propagation");
        }
        return engine;
    }

#if NETFULL
    /// <summary>
    /// Exports the plot model to a PNG file on an STA thread (required by WPF rendering).
    /// </summary>
    public static void PlotGraph(PlotModel model, string graphPath)
    {
        // Required for plotting
        Thread thread = new Thread(() => PngExporter.Export(model, graphPath, 800, 600, OxyColors.White));
        thread.SetApartmentState(ApartmentState.STA);
        thread.Start();
    }

    /// <summary>
    /// Plots the posterior mean, a +/- 2 standard deviation band and the training
    /// points, saves the plot as a PNG, and prints the training-set RMSE.
    /// </summary>
    public static void PlotPredictions(SparseGP sgp, Vector[] trainingInputs, double[] trainingOutputs, bool useStudentT, string dataset)
    {
        var meanSeries = new OxyPlot.Series.LineSeries { Title = "Mean function", Color = OxyColors.SkyBlue, };
        var scatterSeries = new OxyPlot.Series.ScatterSeries { Title = "Training points" };
        var areaSeries = new OxyPlot.Series.AreaSeries { Title = "\u00B1 2\u03C3", Color = OxyColors.PowderBlue };
        double sqDiff = 0;
        for (int i = 0; i < trainingInputs.Length; i++)
        {
            // Posterior marginal at this training input
            Gaussian post = sgp.Marginal(trainingInputs[i]);
            double postMean = post.GetMean();
            var xTrain = trainingInputs[i][0];
            meanSeries.Points.Add(new DataPoint(xTrain, postMean));
            scatterSeries.Points.Add(new OxyPlot.Series.ScatterPoint(xTrain, trainingOutputs[i]));
            var stdDev = Math.Sqrt(post.GetVariance());
            areaSeries.Points.Add(new DataPoint(xTrain, postMean + (2 * stdDev)));
            areaSeries.Points2.Add(new DataPoint(xTrain, postMean - (2 * stdDev)));
            sqDiff += Math.Pow(postMean - trainingOutputs[i], 2);
        }
        Console.WriteLine("RMSE is: {0}", Math.Sqrt(sqDiff / trainingOutputs.Length));

        var model = new PlotModel();
        string pngPath;
        if (!useStudentT)
        {
            model.Title = $"Gaussian Process trained on {dataset} dataset";
            pngPath = Path.Combine(OutputPlotPath, $"{dataset}.png");
        }
        else
        {
            model.Title = $"Gaussian Process trained on {dataset} dataset (Student-t likelihood)";
            pngPath = Path.Combine(OutputPlotPath, $"StudentT{dataset}.png");
        }
        model.Series.Add(meanSeries);
        model.Series.Add(scatterSeries);
        model.Series.Add(areaSeries);
        model.Axes.Add(new OxyPlot.Axes.LinearAxis {
            Position = OxyPlot.Axes.AxisPosition.Bottom,
            Title = "x" });
        model.Axes.Add(new OxyPlot.Axes.LinearAxis {
            Position = OxyPlot.Axes.AxisPosition.Left,
            Title = "y" });
        PlotGraph(model, pngPath);
        Console.WriteLine($"Saved PNG to {Path.GetFullPath(pngPath)}");
    }
#endif
}
}

Просмотреть файл

@ -3009,7 +3009,7 @@ rr = mpf('-0.99999824265582826');
}
// logProbX - logProbY = -x^2/2 + y^2/2 = (y+x)*(y-x)/2
ExtendedDouble n;
if (logProbX > logProbY)
if (logProbX > logProbY || (logProbX == logProbY && x < y))
{
n = new ExtendedDouble(rPlus1 + r * ExpMinus1(xPlusy * (x - y) / 2), logProbX);
}
@ -4348,10 +4348,12 @@ else if (m < 20.0 - 60.0/11.0 * s) {
// subnormal numbers are linearly spaced, which can lead to lowerBound being too large. Set lowerBound to zero to avoid this.
const double maxSubnormal = 2.3e-308;
if (lowerBound > 0 && lowerBound < maxSubnormal) lowerBound = 0;
else if (lowerBound < 0 && lowerBound > -maxSubnormal) lowerBound = -maxSubnormal;
double upperBound = (double)Math.Min(double.MaxValue, denominator * NextDouble(ratio));
if (upperBound == 0 && ratio > 0) upperBound = denominator; // must have ratio < 1
if (double.IsNegativeInfinity(upperBound)) return upperBound; // must have ratio < -1 and denominator > 1
if (upperBound < 0 && upperBound > -maxSubnormal) upperBound = 0;
else if (upperBound > 0 && upperBound < maxSubnormal) upperBound = maxSubnormal;
if (double.IsNegativeInfinity(ratio))
{
if (AreEqual(upperBound / denominator, ratio)) return upperBound;

Просмотреть файл

@ -378,8 +378,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// <summary>
/// Stores built automaton in pre-allocated <see cref="Automaton{TSequence,TElement,TElementDistribution,TSequenceManipulator,TThis}"/> object.
/// </summary>
public DataContainer GetData(
DeterminizationState determinizationState = DeterminizationState.Unknown)
public DataContainer GetData(bool? isDeterminized = null)
{
if (this.StartStateIndex < 0 || this.StartStateIndex >= this.states.Count)
{
@ -425,11 +424,12 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
return new DataContainer(
this.StartStateIndex,
resultStates,
resultTransitions,
!hasEpsilonTransitions,
usesGroups,
determinizationState,
resultStates,
resultTransitions);
isDeterminized,
isZero: null);
}
#endregion

Просмотреть файл

@ -51,47 +51,69 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
public bool UsesGroups => (this.flags & Flags.UsesGroups) != 0;
/// <summary>
/// Gets value indicating whether this automaton is
/// Gets value indicating whether this automaton is determinized
/// </summary>
public DeterminizationState DeterminizationState =>
((this.flags & Flags.DeterminizationStateKnown) == 0)
? DeterminizationState.Unknown
: ((this.flags & Flags.IsDeterminized) != 0
? DeterminizationState.IsDeterminized
: DeterminizationState.IsNonDeterminizable);
/// <remarks>
/// Null value means that this property is unknown.
/// False value means that this automaton can not be determinized
/// </remarks>
public bool? IsDeterminized =>
(this.flags & Flags.DeterminizationStateKnown) != 0
? (this.flags & Flags.IsDeterminized) != 0
: (bool?)null;
/// <summary>
/// Gets value indicating whether this automaton is zero
/// </summary>
/// <remarks>
/// Null value means that this property is unknown.
/// </remarks>
public bool? IsZero =>
((this.flags & Flags.IsZeroStateKnown) != 0)
? (this.flags & Flags.IsZero) != 0
: (bool?)null;
/// <summary>
/// Initializes instance of <see cref="DataContainer"/>.
/// </summary>
[Construction("StartStateIndex", "IsEpsilonFree", "UsesGroups", "DeterminizationState", "States", "Transitions")]
[Construction("StartStateIndex", "States", "Transitions", "IsEpsilonFree", "UsesGroups", "IsDeterminized", "IsZero")]
public DataContainer(
int startStateIndex,
ReadOnlyArray<StateData> states,
ReadOnlyArray<Transition> transitions,
bool isEpsilonFree,
bool usesGroups,
DeterminizationState determinizationState,
ReadOnlyArray<StateData> states,
ReadOnlyArray<Transition> transitions)
bool? isDeterminized,
bool? isZero)
{
this.flags =
(isEpsilonFree ? Flags.IsEpsilonFree : 0) |
(usesGroups ? Flags.UsesGroups : 0) |
(determinizationState != DeterminizationState.Unknown ? Flags.DeterminizationStateKnown : 0) |
(determinizationState == DeterminizationState.IsDeterminized ? Flags.IsDeterminized : 0);
(isDeterminized.HasValue ? Flags.DeterminizationStateKnown : 0) |
(isDeterminized == true ? Flags.IsDeterminized : 0) |
(isZero.HasValue ? Flags.IsZeroStateKnown : 0) |
(isZero == true ? Flags.IsZero : 0);
this.StartStateIndex = startStateIndex;
this.States = states;
this.Transitions = transitions;
}
public DataContainer WithDeterminizationState(DeterminizationState determinizationState)
public DataContainer With(
bool? isDeterminized = null,
bool? isZero= null)
{
Debug.Assert(this.DeterminizationState == DeterminizationState.Unknown);
// Can't overwrite known properties
Debug.Assert(isDeterminized.HasValue != this.IsDeterminized.HasValue || isDeterminized == this.IsDeterminized);
Debug.Assert(isZero.HasValue != this.IsZero.HasValue || isZero == this.IsZero);
return new DataContainer(
this.StartStateIndex,
this.States,
this.Transitions,
this.IsEpsilonFree,
this.UsesGroups,
determinizationState,
this.States,
this.Transitions);
isDeterminized ?? this.IsDeterminized,
isZero ?? this.IsZero);
}
/// <summary>
@ -170,14 +192,9 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
UsesGroups = 0x2,
DeterminizationStateKnown = 0x4,
IsDeterminized = 0x8,
IsZeroStateKnown = 0x10,
IsZero = 0x20,
}
}
public enum DeterminizationState
{
Unknown,
IsDeterminized,
IsNonDeterminizable,
}
}
}

Просмотреть файл

@ -32,9 +32,9 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// <remarks>See <a href="http://www.cs.nyu.edu/~mohri/pub/hwa.pdf"/> for algorithm details.</remarks>
public bool TryDeterminize()
{
if (this.Data.DeterminizationState != DeterminizationState.Unknown)
if (this.Data.IsDeterminized != null)
{
return this.Data.DeterminizationState == DeterminizationState.IsDeterminized;
return this.Data.IsDeterminized == true;
}
int maxStatesBeforeStop = Math.Min(this.States.Count * 3, MaxStateCount);
@ -44,7 +44,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
if (this.UsesGroups)
{
// Determinization will result in lost of group information, which we cannot allow
this.Data = this.Data.WithDeterminizationState(DeterminizationState.IsNonDeterminizable);
this.Data = this.Data.With(isDeterminized: false);
return false;
}
@ -87,6 +87,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
if (!EnqueueOutgoingTransitions(currentWeightedStateSet))
{
this.Data = this.Data.With(isDeterminized: false);
return false;
}
}
@ -99,7 +100,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
var simplification = new Simplification(builder, this.PruneStatesWithLogEndWeightLessThan);
simplification.MergeParallelTransitions(); // Determinization produces a separate transition for each segment
this.Data = builder.GetData().WithDeterminizationState(DeterminizationState.IsDeterminized);
this.Data = builder.GetData().With(isDeterminized: true);
this.PruneStatesWithLogEndWeightLessThan = this.PruneStatesWithLogEndWeightLessThan;
this.LogValueOverride = this.LogValueOverride;
@ -152,8 +153,8 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
}
// Checks that all transitions from state end up in the same destination. This is used
// as a very fast "is determenistic" check, that doesn't care about distributions.
// State can have determenistic transitions with different destinations. This case will be
// as a very fast "is deterministic" check, that doesn't care about distributions.
// State can have deterministic transitions with different destinations. This case will be
// handled by slow path.
bool AllDestinationsAreSame(int stateIndex)
{

Просмотреть файл

@ -1006,38 +1006,45 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// </remarks>
public bool IsZero()
{
if (this.IsCanonicZero())
// Return cached value if available
if (this.Data.IsZero.HasValue)
{
return true;
return this.Data.IsZero.Value;
}
var visitedStates = new BitArray(this.States.Count, false);
return DoIsZero(this.Start.Index);
// Calculate and cache whether this automaton is zero
var isZero = DoIsZero();
this.Data = this.Data.With(isZero: isZero);
return isZero;
bool DoIsZero(int stateIndex)
bool DoIsZero()
{
if (visitedStates[stateIndex])
{
return true;
}
var visited = new BitArray(this.States.Count, false);
var stack = new Stack<int>();
stack.Push(this.Start.Index);
visited[this.Start.Index] = true;
visitedStates[stateIndex] = true;
var state = this.States[stateIndex];
var isZero = !state.CanEnd;
var transitionIndex = 0;
while (isZero && transitionIndex < state.Transitions.Count)
while (stack.Count > 0)
{
var transition = state.Transitions[transitionIndex];
if (!transition.Weight.IsZero)
var stateIndex = stack.Pop();
var state = this.States[stateIndex];
if (state.CanEnd)
{
isZero = DoIsZero(transition.DestinationStateIndex);
return false;
}
++transitionIndex;
foreach (var transition in state.Transitions)
{
if (!visited[transition.DestinationStateIndex] && !transition.Weight.IsZero)
{
stack.Push(transition.DestinationStateIndex);
visited[transition.DestinationStateIndex] = true;
}
}
}
return isZero;
return true;
}
}
@ -1420,11 +1427,8 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
simplification.RemoveDeadStates(); // Product can potentially create dead states
simplification.SimplifyIfNeeded();
var bothInputsDeterminized =
automaton1.Data.DeterminizationState == DeterminizationState.IsDeterminized &&
automaton2.Data.DeterminizationState == DeterminizationState.IsDeterminized;
var determinizationState =
bothInputsDeterminized ? DeterminizationState.IsDeterminized : DeterminizationState.Unknown;
var bothInputsDeterminized = automaton1.Data.IsDeterminized == true && automaton2.Data.IsDeterminized == true;
var determinizationState = bothInputsDeterminized ? (bool?)true : null;
this.Data = builder.GetData(determinizationState);
if (this is StringAutomaton && tryDeterminize)
@ -1620,7 +1624,13 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
public void SetToZero()
{
this.Data = new DataContainer(
0, true, false, DeterminizationState.IsDeterminized, ZeroStates, ZeroTransitions);
0,
ZeroStates,
ZeroTransitions,
isEpsilonFree: true,
usesGroups: false,
isDeterminized: true,
isZero: true);
}
/// <summary>

Просмотреть файл

@ -104,7 +104,6 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
else
{
Debug.Assert(currentSegmentStateWeights.ContainsKey(segmentBound.DestinationStateId), "We shouldn't exit a state we didn't enter.");
Debug.Assert(!segmentBound.Weight.IsInfinity);
currentSegmentTotal -= segmentBound.Weight;
var prevStateWeight = currentSegmentStateWeights[segmentBound.DestinationStateId];

Просмотреть файл

@ -319,12 +319,26 @@ namespace Microsoft.ML.Probabilistic.Distributions
{
if (x <= 0)
throw new ArgumentException("x <= 0");
double a = -x * x * ddLogP;
if (ddLogP == 0.0)
a = 0.0; // in case x is infinity
double a;
if(double.IsPositiveInfinity(x))
{
if (ddLogP < 0) return Gamma.PointMass(x);
else if (ddLogP == 0) a = 0.0;
else if (forceProper)
{
if (dLogP <= 0) return Gamma.FromShapeAndRate(1, -dLogP);
else return Gamma.PointMass(x);
}
else return Gamma.FromShapeAndRate(-x, -x - dLogP);
}
else
{
a = -x * x * ddLogP;
if (a+1 > double.MaxValue) return Gamma.PointMass(x);
}
if (forceProper)
{
if (dLogP < 0)
if (dLogP <= 0)
{
if (a < 0)
a = 0;
@ -333,7 +347,9 @@ namespace Microsoft.ML.Probabilistic.Distributions
{
double amin = x * dLogP;
if (a < amin)
a = amin;
{
return Gamma.FromShapeAndRate(amin + 1, 0);
}
}
}
double b = a / x - dLogP;

Просмотреть файл

@ -120,10 +120,17 @@ namespace Microsoft.ML.Probabilistic.Distributions
{
double prec = 1.0 / variance;
double meanTimesPrecision = prec * mean;
if ((prec > double.MaxValue) || (Math.Abs(meanTimesPrecision) > double.MaxValue))
if (prec > double.MaxValue)
{
Point = mean;
}
else if (Math.Abs(meanTimesPrecision) > double.MaxValue)
{
// This can happen when precision is too high.
// Lower the precision until meanTimesPrecision fits in the double-precision range.
MeanTimesPrecision = Math.Sign(mean) * double.MaxValue;
Precision = MeanTimesPrecision / mean;
}
else
{
Precision = prec;
@ -176,9 +183,9 @@ namespace Microsoft.ML.Probabilistic.Distributions
double meanTimesPrecision = precision * mean;
if (Math.Abs(meanTimesPrecision) > double.MaxValue)
{
// If the precision is so large that it causes numerical overflow,
// treat the distribution as a point mass.
Point = mean;
// Lower the precision until meanTimesPrecision fits in the double-precision range.
MeanTimesPrecision = Math.Sign(mean) * double.MaxValue;
Precision = MeanTimesPrecision / mean;
}
else
{

Просмотреть файл

@ -259,14 +259,44 @@ namespace Microsoft.ML.Probabilistic.Distributions
{
if (beta == null)
{
//beta = (FixedParameters.KernelOf_B_B + InducingDist.GetVariance()).Inverse();
beta = new PositiveDefiniteMatrix(FixedParameters.NumberBasisPoints, FixedParameters.NumberBasisPoints);
beta.SetToDifference(InducingDist.Precision, InducingDist.Precision*Var_B_B*InducingDist.Precision);
bool UseVarBB = InducingDist.Precision.Trace() < double.MaxValue;
if (UseVarBB)
{
beta = new PositiveDefiniteMatrix(FixedParameters.NumberBasisPoints, FixedParameters.NumberBasisPoints);
beta.SetToDifference(InducingDist.Precision, InducingDist.Precision * Var_B_B * InducingDist.Precision);
}
else
{
beta = GetInverse(FixedParameters.KernelOf_B_B + GetInverse(InducingDist.Precision));
}
}
return beta;
}
}
private static PositiveDefiniteMatrix GetInverse(PositiveDefiniteMatrix A)
{
PositiveDefiniteMatrix result = new PositiveDefiniteMatrix(A.Rows, A.Cols);
LowerTriangularMatrix L = new LowerTriangularMatrix(A.Rows, A.Cols);
L.SetToCholesky(A);
bool[] isZero = new bool[L.Rows];
for (int i = 0; i < L.Rows; i++)
{
if (L[i, i] == 0)
{
isZero[i] = true;
L[i, i] = 1;
}
}
L.SetToInverse(L);
result.SetToOuterTranspose(L);
for (int i = 0; i < isZero.Length; i++)
{
if (isZero[i]) result[i, i] = double.PositiveInfinity;
}
return result;
}
#endregion
#region Calculated properties
@ -485,7 +515,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
double kxx = FixedParameters.Prior.Variance(X);
Vector kxb = FixedParameters.KernelOf_X_B(X);
Gaussian result = new Gaussian(
kxb.Inner(Alpha) + FixedParameters.Prior.Mean(X), kxx - Beta.QuadraticForm(kxb));
kxb.Inner(Alpha) + FixedParameters.Prior.Mean(X), Math.Max(0, kxx - Beta.QuadraticForm(kxb)));
return result;
}
}
@ -1190,7 +1220,11 @@ namespace Microsoft.ML.Probabilistic.Distributions
public Vector YPoints
{
get { return ypoints; }
set { ypoints = value.Clone(); }
set
{
if (value.Count < 2) throw new ArgumentException($"value.Count ({value.Count}) < 2", nameof(value));
ypoints = value.Clone();
}
}
#region IFunction Members

Просмотреть файл

@ -502,7 +502,7 @@ namespace Microsoft.ML.Probabilistic.Factors
// x2ddlogf = (-yv*r-0.5+r*ym*ym)/(yv*r+1)^2 - r*ym*ym/(yv*r+1)^3
// xdlogf + x2ddlogf = (-0.5*yv*r + 0.5*r*ym*ym)/(yv*r+1)^2 - r*ym*ym/(yv*r+1)^3
// as r->0: -0.5*r*(yv + ym*ym)
if (precision == 0) return Gamma.FromShapeAndRate(1.5, 0.5 * (yv + ym * ym));
if (precision == 0 || yv == 0) return Gamma.FromShapeAndRate(1.5, 0.5 * (yv + ym * ym));
// point mass case
// f(r) = N(xm;mm, xv+mv+1/r)
// log f(r) = -0.5*log(xv+mv+1/r) - 0.5*(xm-mm)^2/(xv+mv+1/r)
@ -512,7 +512,7 @@ namespace Microsoft.ML.Probabilistic.Factors
// r^2 (log f)'' = -1/(yv*r + 1) + r*ym^2/(yv*r+1)^2) + 0.5/(yv*r+1)^2 - r*ym^2/(yv*r+1)^3
double vdenom = 1 / (yv * precision + 1);
double ymvdenom = ym * vdenom;
double ymvdenom2 = precision * ymvdenom * ymvdenom;
double ymvdenom2 = (precision > double.MaxValue) ? 0 : (precision * ymvdenom * ymvdenom);
//dlogf = (-0.5 * denom + 0.5 * ym2denom2) * (-v2);
//dlogf = 0.5 * (1 - ym * ymdenom) * denom * v2;
//dlogf = 0.5 * (v - ym * ym/(yv*precision+1))/(yv*precision + 1);

Просмотреть файл

@ -328,8 +328,7 @@ namespace Microsoft.ML.Probabilistic.Factors
{
// X is not a point mass or uniform
double d_p = 2 * isBetween.GetProbTrue() - 1;
double mx, vx;
X.GetMeanAndVariance(out mx, out vx);
double mx = X.GetMean();
double center = MMath.Average(lowerBound, upperBound);
if (d_p == 1.0)
{
@ -370,9 +369,9 @@ namespace Microsoft.ML.Probabilistic.Factors
// In this case, alpha and beta will be very small.
double logZ = LogAverageFactor(isBetween, X, lowerBound, upperBound);
if (logZ == 0) return Gaussian.Uniform();
double logPhiL = Gaussian.GetLogProb(lowerBound, mx, vx);
double logPhiL = X.GetLogProb(lowerBound);
double alphaL = d_p * Math.Exp(logPhiL - logZ);
double logPhiU = Gaussian.GetLogProb(upperBound, mx, vx);
double logPhiU = X.GetLogProb(upperBound);
double alphaU = d_p * Math.Exp(logPhiU - logZ);
double alphaX = alphaL - alphaU;
double betaX = alphaX * alphaX;
@ -419,9 +418,13 @@ namespace Microsoft.ML.Probabilistic.Factors
double rU = MMath.NormalCdfRatio(zU);
double r1U = MMath.NormalCdfMomentRatio(1, zU);
double r3U = MMath.NormalCdfMomentRatio(3, zU) * 6;
if (zU < -1e20)
if (zU < -173205080)
{
// in this regime, rU = -1/zU, r1U = rU*rU
// because rU = -1/zU + 1/zU^3 + ...
// and r1U = 1/zU^2 - 3/zU^4 + ...
// The second term is smaller by a factor of 3/zU^2.
// The threshold satisfies 3/zU^2 == 1e-16 or zU < -sqrt(3e16)
if (expMinus1 > 1e100)
{
double invzUs = 1 / (zU * sqrtPrec);
@ -455,10 +458,18 @@ namespace Microsoft.ML.Probabilistic.Factors
}
}
// Abs is needed to avoid some 32-bit oddities.
double prec2 = (expMinus1Ratio * expMinus1Ratio) /
Math.Abs(r1U / X.Precision * expMinus1 * expMinus1RatioMinus1RatioMinusHalf
+ rU / sqrtPrec * diff * (expMinus1RatioMinus1RatioMinusHalf - delta / 2 * (expMinus1RatioMinus1RatioMinusHalf + 1))
+ diff * diff / 4);
double prec2 = (expMinus1Ratio * expMinus1Ratio * X.Precision) /
Math.Abs(r1U * expMinus1 * expMinus1RatioMinus1RatioMinusHalf
+ rU * diffs * (expMinus1RatioMinus1RatioMinusHalf - delta / 2 * (expMinus1RatioMinus1RatioMinusHalf + 1))
+ diffs * diffs / 4);
if (prec2 > double.MaxValue)
{
// same as above but divide top and bottom by X.Precision, to avoid overflow
prec2 = (expMinus1Ratio * expMinus1Ratio) /
Math.Abs(r1U / X.Precision * expMinus1 * expMinus1RatioMinus1RatioMinusHalf
+ rU / sqrtPrec * diff * (expMinus1RatioMinus1RatioMinusHalf - delta / 2 * (expMinus1RatioMinus1RatioMinusHalf + 1))
+ diff * diff / 4);
}
return Gaussian.FromMeanAndPrecision(mp2, prec2) / X;
}
}
@ -476,9 +487,7 @@ namespace Microsoft.ML.Probabilistic.Factors
else mp2 = lowerBound + r1U / rU / sqrtPrec;
// This approach loses accuracy when r1U/(rU*rU) < 1e-3, which is zU > 3.5
if (zU > 3.5) throw new Exception("zU > 3.5");
double prec2 = rU * rU * X.Precision;
if (prec2 != 0) // avoid 0/0
prec2 /= NormalCdfRatioSqrMinusDerivative(zU, rU, r1U, r3U);
double prec2 = X.Precision * (rU * rU / NormalCdfRatioSqrMinusDerivative(zU, rU, r1U, r3U));
//Console.WriteLine($"z = {zU:r} r = {rU:r} r1 = {r1U:r} r1U/rU = {r1U / rU:r} r1U/rU/sqrtPrec = {r1U / rU / sqrtPrec:r} sqrtPrec = {sqrtPrec:r} mp = {mp2:r}");
return Gaussian.FromMeanAndPrecision(mp2, prec2) / X;
}
@ -571,12 +580,13 @@ namespace Microsoft.ML.Probabilistic.Factors
if (delta == 0) // avoid 0*infinity
qOverPrec = (r1L + drU2) * diff * (drU3 / sqrtPrec - diff / 4 + rL / 2 * diffs / 2 * diff / 2);
double vp = qOverPrec * alphaXcLprecDiff * alphaXcLprecDiff;
if (double.IsNaN(qOverPrec) || 1/vp < X.Precision) return Gaussian.FromMeanAndPrecision(mp, MMath.NextDouble(X.Precision)) / X;
return new Gaussian(mp, vp) / X;
}
else
{
double logZ = LogAverageFactor(isBetween, X, lowerBound, upperBound);
if (d_p == -1.0 && logZ < double.MinValue)
Gaussian GetPointMessage()
{
if (mx == center)
{
@ -596,9 +606,10 @@ namespace Microsoft.ML.Probabilistic.Factors
return Gaussian.PointMass(upperBound);
}
}
double logPhiL = Gaussian.GetLogProb(lowerBound, mx, vx);
if (d_p == -1.0 && logZ < double.MinValue) return GetPointMessage();
double logPhiL = X.GetLogProb(lowerBound);
double alphaL = d_p * Math.Exp(logPhiL - logZ);
double logPhiU = Gaussian.GetLogProb(upperBound, mx, vx);
double logPhiU = X.GetLogProb(upperBound);
double alphaU = d_p * Math.Exp(logPhiU - logZ);
double alphaX = alphaL - alphaU;
double betaX = alphaX * alphaX;
@ -610,6 +621,7 @@ namespace Microsoft.ML.Probabilistic.Factors
else betaL = (X.MeanTimesPrecision - lowerBound * X.Precision) * alphaL; // -(lowerBound - mx) / vx * alphaL;
if (Math.Abs(betaU) > Math.Abs(betaL)) betaX = (betaX + betaL) + betaU;
else betaX = (betaX + betaU) + betaL;
if (betaX > double.MaxValue && d_p == -1.0) return GetPointMessage();
return GaussianOp.GaussianFromAlphaBeta(X, alphaX, betaX, ForceProper);
}
}

Просмотреть файл

@ -70,16 +70,43 @@ namespace Microsoft.ML.Probabilistic.Factors
result.FixedParameters = func.FixedParameters;
result.IncludePrior = false;
double vf = func.Variance(x);
double my, vy;
y.GetMeanAndVariance(out my, out vy);
Vector kbx = func.FixedParameters.KernelOf_X_B(x);
Vector proj = func.FixedParameters.InvKernelOf_B_B * kbx;
double prec = 1.0 / (vy + vf - func.Var_B_B.QuadraticForm(proj));
result.InducingDist.Precision.SetToOuter(proj, proj);
result.InducingDist.Precision.Scale(prec);
result.InducingDist.MeanTimesPrecision.SetTo(proj);
result.InducingDist.MeanTimesPrecision.Scale(prec * my);
// To avoid computing Var_B_B:
// vf - func.Var_B_B.QuadraticForm(proj) = kxx - kbx'*beta*kbx - kbx'*inv(K)*Var_B_B*inv(K)*kbx
// = kxx - kbx'*(beta + inv(K)*Var_B_B*inv(K))*kbx
// = kxx - kbx'*(beta + inv(K)*(K - K*Beta*K)*inv(K))*kbx
// = kxx - kbx'*inv(K)*kbx
// Since Var_B_B = K - K*Beta*K
double kxx = func.FixedParameters.Prior.Variance(x);
if (y.Precision == 0)
{
result.InducingDist.Precision.SetAllElementsTo(0);
result.InducingDist.MeanTimesPrecision.SetTo(proj);
result.InducingDist.MeanTimesPrecision.Scale(y.MeanTimesPrecision);
}
else
{
double my, vy;
y.GetMeanAndVariance(out my, out vy);
double prec = 1.0 / (vy + kxx - kbx.Inner(proj));
//Console.WriteLine($"{vf - func.Var_B_B.QuadraticForm(proj)} {func.FixedParameters.Prior.Variance(x) - func.FixedParameters.InvKernelOf_B_B.QuadraticForm(kbx)}");
if (prec > double.MaxValue || prec < 0)
{
int i = proj.IndexOfMaximum();
result.InducingDist.Precision.SetAllElementsTo(0);
result.InducingDist.Precision[i, i] = double.PositiveInfinity;
result.InducingDist.MeanTimesPrecision.SetAllElementsTo(0);
result.InducingDist.MeanTimesPrecision[i] = my;
}
else
{
result.InducingDist.Precision.SetToOuter(proj, proj);
result.InducingDist.Precision.Scale(prec);
result.InducingDist.MeanTimesPrecision.SetTo(proj);
result.InducingDist.MeanTimesPrecision.Scale(prec * my);
}
}
result.ClearCachedValues();
return result;
}

Просмотреть файл

@ -6,6 +6,7 @@ using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Threading.Tasks;
using Xunit;
using Assert = Microsoft.ML.Probabilistic.Tests.AssertHelper;
using Microsoft.ML.Probabilistic.Utilities;
@ -2144,7 +2145,7 @@ exp(x*x/4)*pcfu(0.5+n,-x)
Assert.True(0 <= MMath.NormalCdfIntegral(213393529.2046707, -213393529.2046707, -1, 7.2893668811495072E-10).Mantissa);
Assert.True(0 < MMath.NormalCdfIntegral(-0.42146853220760722, 0.42146843802130329, -0.99999999999999989, 6.2292398855983019E-09).Mantissa);
foreach (var x in OperatorTests.Doubles())
Parallel.ForEach (OperatorTests.Doubles(), x =>
{
foreach (var y in OperatorTests.Doubles())
{
@ -2153,7 +2154,7 @@ exp(x*x/4)*pcfu(0.5+n,-x)
MMath.NormalCdfIntegral(x, y, r);
}
}
}
});
}
internal void NormalCdfIntegralTest2()

Просмотреть файл

@ -265,9 +265,9 @@ namespace Microsoft.ML.Probabilistic.Tests
Assert.Equal(g, Gaussian.PointMass(double.PositiveInfinity));
g.SetMeanAndPrecision(1e4, 1e306);
Assert.Equal(g, Gaussian.PointMass(1e4));
Assert.Equal(new Gaussian(1e4, 1E-306), Gaussian.PointMass(1e4));
Assert.Equal(new Gaussian(1e-155, 1E-312), Gaussian.PointMass(1e-155));
Assert.Equal(Gaussian.FromMeanAndPrecision(1e4, double.MaxValue / 1e4), g);
Assert.Equal(Gaussian.FromMeanAndPrecision(1e4, double.MaxValue/1e4), new Gaussian(1e4, 1E-306));
Assert.Equal(Gaussian.PointMass(1e-155), new Gaussian(1e-155, 1E-312));
Gaussian.FromNatural(1, 1e-309).GetMeanAndVarianceImproper(out m, out v);
if(v > double.MaxValue)
Assert.Equal(0, m);

Просмотреть файл

@ -57,6 +57,8 @@ namespace Microsoft.ML.Probabilistic.Tests
[Fact]
public void LargestDoubleProductTest2()
{
MMath.LargestDoubleProduct(1.0000000000000005E-09, 1.0000000000000166E-300);
MMath.LargestDoubleProduct(1.0000000000000005E-09, -1.0000000000000166E-300);
MMath.LargestDoubleProduct(0.00115249439895759, 4.9187693503017E-319);
MMath.LargestDoubleProduct(0.00115249439895759, -4.9187693503017E-319);
}
@ -103,6 +105,21 @@ namespace Microsoft.ML.Probabilistic.Tests
Assert.True(double.IsPositiveInfinity(a) || MMath.NextDouble(a) - b > sum);
}
[Fact]
public void PrecisionAverageConditional_Point_IsIncreasing()
{
foreach (var precision in DoublesAtLeastZero())
{
foreach (var yv in DoublesAtLeastZero())
{
var result0 = GaussianOp.PrecisionAverageConditional_Point(0, yv, precision);
var result = GaussianOp.PrecisionAverageConditional_Point(MMath.NextDouble(0), yv, precision);
Assert.True(result.Rate >= result0.Rate);
//Trace.WriteLine($"precision={precision} yv={yv}: {result0.Rate} {result.Rate}");
}
}
}
// Test inference on a model where precision is scaled.
internal void GammaProductTest()
{
@ -1260,10 +1277,10 @@ namespace Microsoft.ML.Probabilistic.Tests
[Fact]
public void GammaLower_IsIncreasingInX()
{
foreach (double a in DoublesGreaterThanZero())
Parallel.ForEach (DoublesGreaterThanZero(), a =>
{
IsIncreasingForAtLeastZero(x => MMath.GammaLower(a, x));
}
});
}
[Fact]
@ -1279,10 +1296,10 @@ namespace Microsoft.ML.Probabilistic.Tests
[Fact]
public void GammaLower_IsDecreasingInA()
{
foreach (double x in DoublesAtLeastZero())
Parallel.ForEach(DoublesAtLeastZero(), x =>
{
IsIncreasingForAtLeastZero(a => -MMath.GammaLower(a + double.Epsilon, x));
}
});
}
[Fact]
@ -2063,7 +2080,7 @@ zL = (L - mx)*sqrt(prec)
yield return MMath.NextDouble(0);
yield return MMath.PreviousDouble(double.PositiveInfinity);
yield return double.PositiveInfinity;
for (int i = 0; i <= 100; i++)
for (int i = 0; i <= 300; i++)
{
double bigValue = System.Math.Pow(10, i);
yield return -bigValue;
@ -2082,6 +2099,11 @@ zL = (L - mx)*sqrt(prec)
return Doubles().Where(value => value > 0);
}
public static IEnumerable<double> DoublesLessThanZero()
{
return Doubles().Where(value => value < 0);
}
public static IEnumerable<double> DoublesAtLeastZero()
{
return Doubles().Where(value => value >= 0);
@ -2137,11 +2159,10 @@ zL = (L - mx)*sqrt(prec)
double precMaxUlpErrorLowerBound = 0;
double precMaxUlpErrorUpperBound = 0;
Bernoulli precMaxUlpErrorIsBetween = new Bernoulli();
foreach (var isBetween in new[] { Bernoulli.PointMass(true), Bernoulli.PointMass(false), new Bernoulli(0.1) })
foreach(var isBetween in new[] { Bernoulli.PointMass(true), Bernoulli.PointMass(false), new Bernoulli(0.1) })
{
foreach (var lowerBound in Doubles())
Parallel.ForEach (DoublesLessThanZero(), lowerBound =>
{
if (lowerBound >= 0) continue;
//Console.WriteLine($"isBetween = {isBetween}, lowerBound = {lowerBound:r}");
foreach (var upperBound in new[] { -lowerBound })// UpperBounds(lowerBound))
{
@ -2182,7 +2203,7 @@ zL = (L - mx)*sqrt(prec)
}
}
}
}
});
}
Console.WriteLine($"meanMaxUlpError = {meanMaxUlpError}, lowerBound = {meanMaxUlpErrorLowerBound:r}, upperBound = {meanMaxUlpErrorUpperBound:r}, isBetween = {meanMaxUlpErrorIsBetween}");
Console.WriteLine($"precMaxUlpError = {precMaxUlpError}, lowerBound = {precMaxUlpErrorLowerBound:r}, upperBound = {precMaxUlpErrorUpperBound:r}, isBetween = {precMaxUlpErrorIsBetween}");
@ -2215,7 +2236,7 @@ zL = (L - mx)*sqrt(prec)
IEnumerable<double> lowerBounds = Doubles();
// maxUlpError = 22906784576, lowerBound = -0.010000000000000002, upperBound = -0.01
lowerBounds = new double[] { 0 };
foreach (double lowerBound in lowerBounds)
Parallel.ForEach(lowerBounds, lowerBound =>
{
foreach (double upperBound in new double[] { 1 })
//Parallel.ForEach(UpperBounds(lowerBound), upperBound =>
@ -2257,7 +2278,7 @@ zL = (L - mx)*sqrt(prec)
}
Trace.WriteLine($"maxUlpError = {maxUlpError}, lowerBound = {maxUlpErrorLowerBound:r}, upperBound = {maxUlpErrorUpperBound:r}");
}//);
}
});
Assert.True(maxUlpError < 1e3);
}
@ -2280,9 +2301,8 @@ zL = (L - mx)*sqrt(prec)
{
Trace.WriteLine($"lowerBound = {lowerBound:r}, upperBound = {upperBound:r}");
//foreach (var x in new[] { Gaussian.FromNatural(-0.1, 0.010000000000000002) })// Gaussians())
foreach (var x in Gaussians())
Parallel.ForEach (Gaussians().Where(g => !g.IsPointMass), x =>
{
if (x.IsPointMass) continue;
double mx = x.GetMean();
Gaussian toX = DoubleIsBetweenOp.XAverageConditional(isBetween, x, lowerBound, upperBound);
Gaussian xPost;
@ -2341,7 +2361,7 @@ zL = (L - mx)*sqrt(prec)
meanMaxUlpError = meanUlpDiff;
meanMaxUlpErrorLowerBound = lowerBound;
meanMaxUlpErrorUpperBound = upperBound;
//Assert.True(meanUlpDiff < 1e16);
Assert.True(meanUlpDiff < 1e16);
}
double variance2 = xPost2.GetVariance();
double precError2 = MMath.Ulp(xPost2.Precision);
@ -2356,19 +2376,19 @@ zL = (L - mx)*sqrt(prec)
precMaxUlpError = ulpDiff;
precMaxUlpErrorLowerBound = lowerBound;
precMaxUlpErrorUpperBound = upperBound;
//Assert.True(precMaxUlpError < 1e15);
Assert.True(precMaxUlpError < 1e16);
}
}
}
}
});
}//);
Trace.WriteLine($"meanMaxUlpError = {meanMaxUlpError}, lowerBound = {meanMaxUlpErrorLowerBound:r}, upperBound = {meanMaxUlpErrorUpperBound:r}");
Trace.WriteLine($"precMaxUlpError = {precMaxUlpError}, lowerBound = {precMaxUlpErrorLowerBound:r}, upperBound = {precMaxUlpErrorUpperBound:r}");
}
// meanMaxUlpError = 4271.53318407361, lowerBound = -1.0000000000000006E-12, upperBound = inf
// precMaxUlpError = 5008, lowerBound = 1E+40, upperBound = 1.00000001E+40
Assert.True(meanMaxUlpError < 1e4);
Assert.True(precMaxUlpError < 1e4);
Assert.True(meanMaxUlpError < 3);
Assert.True(precMaxUlpError < 1e16);
}
[Fact]
@ -2393,7 +2413,8 @@ zL = (L - mx)*sqrt(prec)
if (double.IsNegativeInfinity(lowerBound) && double.IsPositiveInfinity(upperBound))
center = 0;
//foreach (var x in new[] { Gaussian.FromNatural(0, 1e55) })// Gaussians())
foreach (var x in Gaussians())
//foreach (var x in Gaussians())
Parallel.ForEach(Gaussians(), x =>
{
double mx = x.GetMean();
Gaussian toX = DoubleIsBetweenOp.XAverageConditional(isBetween, x, lowerBound, upperBound);
@ -2433,13 +2454,16 @@ zL = (L - mx)*sqrt(prec)
// Increasing the prior mean should increase the posterior mean.
if (mean2 < mean)
{
// TEMPORARY
meanError = MMath.Ulp(mean);
meanError2 = MMath.Ulp(mean2);
double meanUlpDiff = (mean - mean2) / System.Math.Max(meanError, meanError2);
if (meanUlpDiff > meanMaxUlpError)
{
meanMaxUlpError = meanUlpDiff;
meanMaxUlpErrorLowerBound = lowerBound;
meanMaxUlpErrorUpperBound = upperBound;
Assert.True(meanUlpDiff < 1e5);
Assert.True(meanUlpDiff < 1e16);
}
}
// When mx > center, increasing prior mean should increase posterior precision.
@ -2451,19 +2475,19 @@ zL = (L - mx)*sqrt(prec)
precMaxUlpError = ulpDiff;
precMaxUlpErrorLowerBound = lowerBound;
precMaxUlpErrorUpperBound = upperBound;
Assert.True(precMaxUlpError < 1e6);
Assert.True(precMaxUlpError < 1e11);
}
}
}
}
});
}//);
Console.WriteLine($"meanMaxUlpError = {meanMaxUlpError}, lowerBound = {meanMaxUlpErrorLowerBound:r}, upperBound = {meanMaxUlpErrorUpperBound:r}");
Console.WriteLine($"precMaxUlpError = {precMaxUlpError}, lowerBound = {precMaxUlpErrorLowerBound:r}, upperBound = {precMaxUlpErrorUpperBound:r}");
}
// meanMaxUlpError = 104.001435643838, lowerBound = -1.0000000000000022E-37, upperBound = 9.9000000000000191E-36
// precMaxUlpError = 4960, lowerBound = -1.0000000000000026E-47, upperBound = -9.9999999000000263E-48
Assert.True(meanMaxUlpError < 1e3);
Assert.True(precMaxUlpError < 1e4);
Assert.True(meanMaxUlpError < 1e16);
Assert.True(precMaxUlpError < 1e11);
}
[Fact]
@ -2484,7 +2508,7 @@ zL = (L - mx)*sqrt(prec)
//Parallel.ForEach(UpperBounds(lowerBound), upperBound =>
{
Console.WriteLine($"lowerBound = {lowerBound:r}, upperBound = {upperBound:r}");
foreach (Gaussian x in Gaussians())
Parallel.ForEach (Gaussians(), x =>
{
Gaussian toX = DoubleIsBetweenOp.XAverageConditional(true, x, lowerBound, upperBound);
Gaussian xPost;
@ -2545,15 +2569,15 @@ zL = (L - mx)*sqrt(prec)
}
}
}
}
});
}//);
Console.WriteLine($"meanMaxUlpError = {meanMaxUlpError}, lowerBound = {meanMaxUlpErrorLowerBound:r}, upperBound = {meanMaxUlpErrorUpperBound:r}");
Console.WriteLine($"precMaxUlpError = {precMaxUlpError}, lowerBound = {precMaxUlpErrorLowerBound:r}, upperBound = {precMaxUlpErrorUpperBound:r}");
}
// meanMaxUlpError = 33584, lowerBound = -1E+30, upperBound = 9.9E+31
// precMaxUlpError = 256, lowerBound = -1, upperBound = 0
Assert.True(meanMaxUlpError < 1e5);
Assert.True(precMaxUlpError < 1e3);
Assert.True(meanMaxUlpError < 1e2);
Assert.True(precMaxUlpError < 1e2);
}
[Fact]
@ -2775,7 +2799,7 @@ weight * (tau + alphaX) + alphaX
// exact posterior mean = -0.00000000025231325216567798206492
// exact posterior variance = 0.00000000000000000003633802275634766987678763433333
expected = Gaussian.FromNatural(-6943505261.522269414985891, 17519383944062174805.8794215);
Assert.True(MaxUlpDiff(expected, result2) <= 5);
Assert.True(MaxUlpDiff(expected, result2) <= 7);
}
[Fact]
@ -2959,7 +2983,7 @@ weight * (tau + alphaX) + alphaX
Gaussian upperBound = Gaussian.FromNatural(412820.08287991461, 423722.55474045349);
for (int i = -10; i <= 0; i++)
{
lowerBound = Gaussian.FromNatural(17028358.45574614*System.Math.Pow(2,i), 9);
lowerBound = Gaussian.FromNatural(17028358.45574614 * System.Math.Pow(2, i), 9);
Gaussian toLowerBound = DoubleIsBetweenOp.LowerBoundAverageConditional_Slow(Bernoulli.PointMass(true), X, lowerBound, upperBound);
Trace.WriteLine($"{lowerBound}: {toLowerBound.MeanTimesPrecision} {toLowerBound.Precision}");
Assert.False(toLowerBound.IsPointMass);
@ -2980,7 +3004,7 @@ weight * (tau + alphaX) + alphaX
double lowerBoundMeanTimesPrecisionMaxUlpError = 0;
for (int i = 0; i < 200; i++)
{
Gaussian X = Gaussian.FromMeanAndPrecision(mean, System.Math.Pow(2, -i*1-20));
Gaussian X = Gaussian.FromMeanAndPrecision(mean, System.Math.Pow(2, -i * 1 - 20));
Gaussian toX = DoubleIsBetweenOp.XAverageConditional_Slow(Bernoulli.PointMass(true), X, lowerBound, upperBound);
Gaussian toLowerBound = toLowerBoundPrev;// DoubleIsBetweenOp.LowerBoundAverageConditional_Slow(Bernoulli.PointMass(true), X, lowerBound, upperBound);
Trace.WriteLine($"{i} {X}: {toX.MeanTimesPrecision:r} {toX.Precision:r} {toLowerBound.MeanTimesPrecision:r} {toLowerBound.Precision:r}");

Просмотреть файл

@ -761,8 +761,8 @@ namespace Microsoft.ML.Probabilistic.Tests
}
/// <summary>
/// Tests whether the point mass computation operations fails due to a stack overflow when
/// an automaton becomes sufficiently large.
/// Tests whether <see cref="StringAutomaton.TryComputePoint"/> fails due to a stack overflow
/// when an automaton becomes sufficiently large.
/// </summary>
[Fact]
[Trait("Category", "StringInference")]
@ -791,7 +791,8 @@ namespace Microsoft.ML.Probabilistic.Tests
}
/// <summary>
/// Tests whether the point mass computation operations fails due to a stack overflow when an automaton becomes sufficiently large.
/// Tests whether <see cref="StringAutomaton.Product"/> fails due to a stack overflow
/// when an automaton becomes sufficiently large.
/// </summary>
[Fact]
[Trait("Category", "StringInference")]
@ -821,7 +822,8 @@ namespace Microsoft.ML.Probabilistic.Tests
}
/// <summary>
/// Tests whether the point mass computation operations fails due to a stack overflow when an automaton becomes sufficiently large.
/// Tests whether <see cref="StringAutomaton.GetLogNormalizer"/> fails due to a stack overflow
/// when an automaton becomes sufficiently large.
/// </summary>
[Fact]
[Trait("Category", "StringInference")]
@ -847,6 +849,41 @@ namespace Microsoft.ML.Probabilistic.Tests
}
}
/// <summary>
/// Tests whether <see cref="StringAutomaton.IsZero"/> fails due to a stack overflow
/// when an automaton becomes sufficiently large.
/// </summary>
[Fact]
[Trait("Category", "StringInference")]
public void IsZeroLargeAutomaton()
{
using (var unlimited = new StringAutomaton.UnlimitedStatesComputation())
{
var zeroAutomaton = MakeAutomaton(Weight.Zero);
var nonZeroAutomaton = MakeAutomaton(Weight.One);
Assert.True(zeroAutomaton.IsZero());
Assert.False(nonZeroAutomaton.IsZero());
}
StringAutomaton MakeAutomaton(Weight endWeight)
{
const int StateCount = 100_000;
var builder = new StringAutomaton.Builder();
var state = builder.Start;
for (var i = 1; i < StateCount; ++i)
{
state = state.AddTransition('a', Weight.One);
}
state.SetEndWeight(endWeight);
return builder.GetAutomaton();
}
}
/// <summary>
/// Tests creating an automaton from state and transition lists.
/// </summary>
@ -858,15 +895,16 @@ namespace Microsoft.ML.Probabilistic.Tests
var automaton1 = StringAutomaton.FromData(
new StringAutomaton.DataContainer(
0,
true,
false,
StringAutomaton.DeterminizationState.Unknown,
new[]
{
new StringAutomaton.StateData(0, 1, Weight.One),
new StringAutomaton.StateData(1, 0, Weight.One),
},
new[] { new StringAutomaton.Transition(DiscreteChar.PointMass('a'), Weight.One, 1) }));
{
new StringAutomaton.StateData(0, 1, Weight.One),
new StringAutomaton.StateData(1, 0, Weight.One),
},
new[] { new StringAutomaton.Transition(DiscreteChar.PointMass('a'), Weight.One, 1) },
isEpsilonFree: true,
usesGroups: false,
isDeterminized: null,
isZero: null));
StringInferenceTestUtilities.TestValue(automaton1, 1.0, string.Empty, "a");
StringInferenceTestUtilities.TestValue(automaton1, 0.0, "b");
@ -875,11 +913,12 @@ namespace Microsoft.ML.Probabilistic.Tests
var automaton2 = StringAutomaton.FromData(
new StringAutomaton.DataContainer(
0,
true,
false,
StringAutomaton.DeterminizationState.IsDeterminized,
new[] { new StringAutomaton.StateData(0, 0, Weight.Zero) },
Array.Empty<StringAutomaton.Transition>()));
Array.Empty<StringAutomaton.Transition>(),
isEpsilonFree: true,
usesGroups: false,
isDeterminized: true,
isZero: true));
Assert.True(automaton2.IsZero());
// Bad start state index
@ -887,44 +926,48 @@ namespace Microsoft.ML.Probabilistic.Tests
() => StringAutomaton.FromData(
new StringAutomaton.DataContainer(
0,
true,
false,
StringAutomaton.DeterminizationState.IsNonDeterminizable,
Array.Empty<StringAutomaton.StateData>(),
Array.Empty<StringAutomaton.Transition>())));
Array.Empty<StringAutomaton.Transition>(),
isEpsilonFree: true,
usesGroups: false,
isDeterminized: false,
isZero: true)));
// automaton is actually epsilon-free, but data says that it is
Assert.Throws<ArgumentException>(
() => StringAutomaton.FromData(
new StringAutomaton.DataContainer(
0,
false,
false,
StringAutomaton.DeterminizationState.Unknown,
new[] { new StringAutomaton.StateData(0, 0, Weight.Zero) },
Array.Empty<StringAutomaton.Transition>())));
Array.Empty<StringAutomaton.Transition>(),
isEpsilonFree: false,
usesGroups: false,
isDeterminized: null,
isZero: null)));
// automaton is not epsilon-free
Assert.Throws<ArgumentException>(
() => StringAutomaton.FromData(
new StringAutomaton.DataContainer(
0,
false,
false,
StringAutomaton.DeterminizationState.Unknown,
new[] { new StringAutomaton.StateData(0, 1, Weight.Zero) },
new[] { new StringAutomaton.Transition(Option.None, Weight.One, 1) })));
new[] { new StringAutomaton.Transition(Option.None, Weight.One, 1) },
isEpsilonFree: false,
usesGroups: false,
isDeterminized: null,
isZero: null)));
// Incorrect transition index
Assert.Throws<ArgumentException>(
() => StringAutomaton.FromData(
new StringAutomaton.DataContainer(
0,
new[] { new StringAutomaton.StateData(0, 1, Weight.One) },
new[] { new StringAutomaton.Transition(Option.None, Weight.One, 2) },
true,
false,
StringAutomaton.DeterminizationState.Unknown,
new[] { new StringAutomaton.StateData(0, 1, Weight.One) },
new[] { new StringAutomaton.Transition(Option.None, Weight.One, 2) })));
isDeterminized: null,
isZero: null)));
}
#region ToString tests
@ -1929,7 +1972,6 @@ namespace Microsoft.ML.Probabilistic.Tests
/// </summary>
[Fact]
[Trait("Category", "StringInference")]
[Trait("Category", "OpenBug")]
public void Determinize10()
{
var builder = new StringAutomaton.Builder();