* Added DiscreteEstimator.Add(DiscreteEstimator)
This commit is contained in:
Tom Minka 2023-08-07 16:04:29 +01:00 коммит произвёл GitHub
Родитель f7852106ca
Коммит e139be6539
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
4 изменённых файла: 48 добавлений и 4 удаления

Просмотреть файл

@ -1042,5 +1042,21 @@ namespace Microsoft.ML.Probabilistic.Distributions
return prob.SumISq() - mean * mean;
}
}
/// <summary>
/// Creates a distribution with reduced support.
/// </summary>
/// <remarks>
/// The probability vector passed to the new distribution holds this distribution's entries
/// for values 0 through min(<paramref name="upperBound"/>, Dimension - 1), so values above
/// <paramref name="upperBound"/> are dropped and the dimension may shrink.  Entries below
/// <paramref name="lowerBound"/> are overwritten with zero but still count toward the dimension.
/// </remarks>
/// <param name="lowerBound">The smallest allowed value.</param>
/// <param name="upperBound">The largest allowed value.  Values at or beyond Dimension - 1, including <c>int.MaxValue</c>, leave the upper end untouched.</param>
/// <returns>A new <see cref="Discrete"/> constructed from the truncated probability vector.</returns>
public Discrete Truncate(int lowerBound, int upperBound)
{
    // Clamp first and add 1 afterwards: computing upperBound + 1 directly wraps around to a
    // negative count when upperBound == int.MaxValue (unchecked int arithmetic).
    int count = Math.Min(upperBound, this.Dimension - 1) + 1;
    // NOTE(review): presumably Subvector returns a copy, so the zeroing below does not
    // touch this.prob -- confirm against Vector.Subvector.
    Vector probs = this.prob.Subvector(0, count);
    if (lowerBound > 0)
    {
        // Remove the probability mass of every value below lowerBound.
        probs.SetSubvector(0, Vector.Zero(lowerBound));
    }
    return new Discrete(probs);
}
}
}

Просмотреть файл

@ -46,6 +46,16 @@ namespace Microsoft.ML.Probabilistic.Distributions
return result;
}
/// <summary>
/// Merges the contents of another estimator into this one.
/// </summary>
/// <remarks>
/// Adds <c>that.NProb</c> into <c>NProb</c> and <c>that.N</c> into <c>N</c>;
/// <paramref name="that"/> itself is only read.
/// </remarks>
/// <param name="that">The estimator whose accumulated items are added to this estimator.</param>
public void Add(DiscreteEstimator that)
{
    // Accumulate the other estimator's vector into ours in place.
    var otherCounts = that.NProb;
    this.NProb.SetToSum(this.NProb, otherCounts);

    // Keep the scalar total in step with the vector update.
    this.N = this.N + that.N;
}
/// <summary>
/// Adds a discrete distribution item to the estimator
/// </summary>

Просмотреть файл

@ -24,15 +24,18 @@ namespace Microsoft.ML.Probabilistic.Tests.Core
Assert.Equal(0UL, new ulong[0].Sum());
}
/// <summary>
/// This also tests DiscreteEstimator.Add(DiscreteEstimator)
/// </summary>
[Fact]
public void TakeRandom_HasCorrectDistribution()
{
Rand.Restart(0);
int universe = 10;
int count = 5;
DiscreteEstimator discreteEstimator = new DiscreteEstimator(universe);
for (int trial = 0; trial < 10000; trial++)
var combinedEstimator = ParallelEnumerable.Range(0, 10000).Select(block =>
{
DiscreteEstimator discreteEstimator = new DiscreteEstimator(universe);
HashSet<int> set = new HashSet<int>();
foreach (var value in Enumerable.Range(0, universe).TakeRandom(count))
{
@ -40,8 +43,9 @@ namespace Microsoft.ML.Probabilistic.Tests.Core
set.Add(value);
discreteEstimator.Add(value);
}
}
var dist = discreteEstimator.GetDistribution(Discrete.Uniform(universe));
return discreteEstimator;
}).Aggregate((a, b) => { a.Add(b); return a; });
var dist = combinedEstimator.GetDistribution(Discrete.Uniform(universe));
for (int i = 0; i < universe; i++)
{
Assert.True(dist[i] > 0.08 && dist[i] < 0.12);

Просмотреть файл

@ -1412,6 +1412,20 @@ namespace Microsoft.ML.Probabilistic.Tests
Assert.True(d.IsPointMass && d.Point == 2);
}
/// <summary>
/// Checks Discrete.Truncate on a four-value distribution whose value 1 has zero probability.
/// </summary>
[Fact]
public void DiscreteTruncateTest()
{
    // Probabilities for values 0..3.
    Discrete d = new Discrete(SparseVector.FromArray(0.1, 0.0, 0.3, 0.6));

    // Cutting off value 3 shrinks the dimension to 3.
    Discrete truncated = d.Truncate(0, 2);
    Assert.Equal(3, truncated.Dimension);

    // Range [1, 2]: value 1 carries no mass, so only value 2 remains.
    truncated = d.Truncate(1, 2);
    Assert.True(truncated.IsPointMass);
    Assert.Equal(2, truncated.Point);

    // Range [2, 2]: a single allowed value.
    truncated = d.Truncate(2, 2);
    Assert.True(truncated.IsPointMass);
    Assert.Equal(2, truncated.Point);
}
[Fact]
[Trait("Category", "ModifiesGlobals")]
public void DirichletTest()