зеркало из https://github.com/dotnet/infer.git
Added Discrete.Truncate (#446)
* Added DiscreteEstimator.Add(DiscreteEstimator)
This commit is contained in:
Родитель
f7852106ca
Коммит
e139be6539
|
@ -1042,5 +1042,21 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
return prob.SumISq() - mean * mean;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Creates a distribution with reduced support.
|
||||
/// </summary>
|
||||
/// <param name="lowerBound">The smallest allowed value.</param>
|
||||
/// <param name="upperBound">The largest allowed value.</param>
|
||||
/// <returns></returns>
|
||||
public Discrete Truncate(int lowerBound, int upperBound)
|
||||
{
|
||||
Vector probs = this.prob.Subvector(0, Math.Min(upperBound + 1, this.Dimension));
|
||||
if (lowerBound > 0)
|
||||
{
|
||||
probs.SetSubvector(0, Vector.Zero(lowerBound));
|
||||
}
|
||||
return new Discrete(probs);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -46,6 +46,16 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Adds all items in another estimator to this estimator.
|
||||
/// </summary>
|
||||
/// <param name="that">Another estimator</param>
|
||||
public void Add(DiscreteEstimator that)
|
||||
{
|
||||
NProb.SetToSum(NProb, that.NProb);
|
||||
N += that.N;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Adds a discrete distribution item to the estimator
|
||||
/// </summary>
|
||||
|
|
|
@ -24,15 +24,18 @@ namespace Microsoft.ML.Probabilistic.Tests.Core
|
|||
Assert.Equal(0UL, new ulong[0].Sum());
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// This also tests DiscreteEstimator.Add(DiscreteEstimator)
|
||||
/// </summary>
|
||||
[Fact]
|
||||
public void TakeRandom_HasCorrectDistribution()
|
||||
{
|
||||
Rand.Restart(0);
|
||||
int universe = 10;
|
||||
int count = 5;
|
||||
DiscreteEstimator discreteEstimator = new DiscreteEstimator(universe);
|
||||
for (int trial = 0; trial < 10000; trial++)
|
||||
var combinedEstimator = ParallelEnumerable.Range(0, 10000).Select(block =>
|
||||
{
|
||||
DiscreteEstimator discreteEstimator = new DiscreteEstimator(universe);
|
||||
HashSet<int> set = new HashSet<int>();
|
||||
foreach (var value in Enumerable.Range(0, universe).TakeRandom(count))
|
||||
{
|
||||
|
@ -40,8 +43,9 @@ namespace Microsoft.ML.Probabilistic.Tests.Core
|
|||
set.Add(value);
|
||||
discreteEstimator.Add(value);
|
||||
}
|
||||
}
|
||||
var dist = discreteEstimator.GetDistribution(Discrete.Uniform(universe));
|
||||
return discreteEstimator;
|
||||
}).Aggregate((a, b) => { a.Add(b); return a; });
|
||||
var dist = combinedEstimator.GetDistribution(Discrete.Uniform(universe));
|
||||
for (int i = 0; i < universe; i++)
|
||||
{
|
||||
Assert.True(dist[i] > 0.08 && dist[i] < 0.12);
|
||||
|
|
|
@ -1412,6 +1412,20 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
Assert.True(d.IsPointMass && d.Point == 2);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void DiscreteTruncateTest()
|
||||
{
|
||||
Discrete d = new Discrete(SparseVector.FromArray(0.1, 0.0, 0.3, 0.6));
|
||||
Discrete truncated = d.Truncate(0, 2);
|
||||
Assert.True(truncated.Dimension == 3);
|
||||
truncated = d.Truncate(1, 2);
|
||||
Assert.True(truncated.IsPointMass);
|
||||
Assert.True(truncated.Point == 2);
|
||||
truncated = d.Truncate(2, 2);
|
||||
Assert.True(truncated.IsPointMass);
|
||||
Assert.True(truncated.Point == 2);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[Trait("Category", "ModifiesGlobals")]
|
||||
public void DirichletTest()
|
||||
|
|
Загрузка…
Ссылка в новой задаче