From 70ae46cd792333df856b0b88b34cd1cea68c10cc Mon Sep 17 00:00:00 2001 From: Ivan Korostelev Date: Tue, 28 May 2019 15:52:38 +0100 Subject: [PATCH] Fix DiscreteChar.Sample() (#153) Due to a refactoring mistake, `sampleProb / prob * intervalLength` turned into `sampleProb / (prob * intervalLength)`, which is obviously incorrect. Fixed that and added a rudimentary test for `DiscreteChar.Sample()`. --- src/Runtime/Distributions/DiscreteChar.cs | 2 +- test/Tests/Strings/DiscreteCharTest.cs | 27 +++++++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/src/Runtime/Distributions/DiscreteChar.cs b/src/Runtime/Distributions/DiscreteChar.cs index 648a463b..7b501f50 100644 --- a/src/Runtime/Distributions/DiscreteChar.cs +++ b/src/Runtime/Distributions/DiscreteChar.cs @@ -859,7 +859,7 @@ namespace Microsoft.ML.Probabilistic.Distributions sampleProb -= prob.Value; if (sampleProb < 0) { - return (char)(interval.StartInclusive - sampleProb / (prob * intervalLength).Value); + return (char)(interval.StartInclusive - sampleProb / interval.Probability.Value); } } diff --git a/test/Tests/Strings/DiscreteCharTest.cs b/test/Tests/Strings/DiscreteCharTest.cs index f8248a50..eb759727 100644 --- a/test/Tests/Strings/DiscreteCharTest.cs +++ b/test/Tests/Strings/DiscreteCharTest.cs @@ -2,6 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
+using Microsoft.ML.Probabilistic.Math; + namespace Microsoft.ML.Probabilistic.Tests { using System; @@ -64,6 +66,31 @@ namespace Microsoft.ML.Probabilistic.Tests Assert.True(false); } + [Fact] + [Trait("Category", "StringInference")] + public void SampleFromUniformCharDistribution() + { + // Make test deterministic + Rand.Restart(7); + + // 10 chars in distribution + const int numChars = 10; + const int numSamples = 100000; + var dist = DiscreteChar.UniformInRanges("aj"); + + var hist = Vector.Zero(numChars); + for (var i = 0; i < numSamples; ++i) + { + hist[dist.Sample() - 'a'] += 1; + } + + hist = hist * (1.0 / numSamples); + var unif = Vector.Constant(numChars, 1.0 / numChars); + var maxDiff = hist.MaxDiff(unif); + + Assert.True(maxDiff < 0.01); + } + /// <summary> /// Tests the support of a character distribution. /// </summary>