зеркало из https://github.com/dotnet/infer.git
Better handling of Sequential attribute (#279)
IncrementTransform handles GetJaggedItemsOp and GetDeepJaggedItemsOp. IndexingTransform gives a warning for unimplemented cases instead of throwing. Improved code doc for Damp functions. Changed "#if HAS_BINARY_FORMATTER" to "#if NETFULL". Removed Assert.Timeout from performance tests.
This commit is contained in:
Родитель
90607a287f
Коммит
6186fff2a6
|
@ -313,7 +313,9 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
// forwardExpr = GetItemsOp<>.ItemsAverageConditional(backwardExpr[index], *, marginalExpr, indices, index, forwardExpr)
|
||||
// then when backwardExpr is updated, we insert the following statement:
|
||||
// MarginalIncrement(marginalExpr, forwardExpr, backwardExpr[index], indices, index)
|
||||
if (Recognizer.IsStaticGenericMethod(imie, typeof(GetItemsOp<>), "ItemsAverageConditional"))
|
||||
if (Recognizer.IsStaticGenericMethod(imie, typeof(GetItemsOp<>), "ItemsAverageConditional") ||
|
||||
Recognizer.IsStaticGenericMethod(imie, typeof(GetJaggedItemsOp<>), "ItemsAverageConditional") ||
|
||||
Recognizer.IsStaticGenericMethod(imie, typeof(GetDeepJaggedItemsOp<>), "ItemsAverageConditional"))
|
||||
{
|
||||
IExpression backwardExpr = imie.Arguments[0];
|
||||
object backwardDecl = Recognizer.GetArrayDeclaration(backwardExpr);
|
||||
|
|
|
@ -353,6 +353,8 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
{
|
||||
innerLoops.Add(innerLoop2);
|
||||
indexTarget = index2.Target;
|
||||
// This limit must match the number of handled cases below.
|
||||
if (innerLoops.Count == 3) break;
|
||||
}
|
||||
else
|
||||
break;
|
||||
|
|
|
@ -118,15 +118,15 @@ namespace Microsoft.ML.Probabilistic.Factors
|
|||
public static class Damp
|
||||
{
|
||||
/// <summary>
|
||||
/// Copy a value and damp the backward message.
|
||||
/// Copy a value and damp the message from the copy to the original.
|
||||
/// </summary>
|
||||
/// <typeparam name="T"></typeparam>
|
||||
/// <param name="value"></param>
|
||||
/// <param name="stepsize">1.0 means no damping, 0.0 is infinite damping.</param>
|
||||
/// <returns></returns>
|
||||
/// <remarks>
|
||||
/// If you use this factor, be sure to increase the number of algorithm iterations appropriately.
|
||||
/// The number of iterations should increase according to the reciprocal of stepsize.
|
||||
/// If you use this factor, be sure to multiply your convergence tolerance by stepsize.
|
||||
/// Equivalently, use the same convergence tolerance but measure the difference between iterations spaced apart by 1/stepsize.
|
||||
/// </remarks>
|
||||
public static T Backward<T>([IsReturned] T value, double stepsize)
|
||||
{
|
||||
|
@ -134,15 +134,15 @@ namespace Microsoft.ML.Probabilistic.Factors
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// Copy a value and damp the forward message.
|
||||
/// Copy a value and damp the message from the original to the copy.
|
||||
/// </summary>
|
||||
/// <typeparam name="T"></typeparam>
|
||||
/// <param name="value"></param>
|
||||
/// <param name="stepsize">1.0 means no damping, 0.0 is infinite damping.</param>
|
||||
/// <returns></returns>
|
||||
/// <remarks>
|
||||
/// If you use this factor, be sure to increase the number of algorithm iterations appropriately.
|
||||
/// The number of iterations should increase according to the reciprocal of stepsize.
|
||||
/// If you use this factor, be sure to multiply your convergence tolerance by stepsize.
|
||||
/// Equivalently, use the same convergence tolerance but measure the difference between iterations spaced apart by 1/stepsize.
|
||||
/// </remarks>
|
||||
public static T Forward<T>([IsReturned] T value, double stepsize)
|
||||
{
|
||||
|
|
|
@ -56,6 +56,7 @@ namespace TestApp
|
|||
//InferenceEngine.DefaultEngine.Compiler.CompilerChoice = Microsoft.ML.Probabilistic.Compiler.CompilerChoice.Roslyn;
|
||||
//InferenceEngine.DefaultEngine.Compiler.GenerateInMemory = false;
|
||||
InferenceEngine.DefaultEngine.Compiler.WriteSourceFiles = true;
|
||||
InferenceEngine.DefaultEngine.Compiler.IncludeDebugInformation = true;
|
||||
//InferenceEngine.DefaultEngine.Compiler.OptimiseInferenceCode = false;
|
||||
//InferenceEngine.DefaultEngine.Compiler.FreeMemory = false;
|
||||
//InferenceEngine.DefaultEngine.Compiler.ReturnCopies = false;
|
||||
|
|
|
@ -18,6 +18,9 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
{
|
||||
using GaussianArray = DistributionStructArray<Gaussian, double>;
|
||||
|
||||
/// <summary>
|
||||
/// Measure the speed of different implementations of the Bayes Point Machine.
|
||||
/// </summary>
|
||||
public class BpmSpeedTests
|
||||
{
|
||||
public class Instance
|
||||
|
@ -82,14 +85,15 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
#pragma warning disable 162
|
||||
#endif
|
||||
|
||||
static readonly string dataFolder = @"c:\Users\minka\Downloads\rcv1";
|
||||
|
||||
public static void Rcv1Test(double wVariance, double biasVariance)
|
||||
{
|
||||
int count = 0;
|
||||
if (false)
|
||||
{
|
||||
int maxFeatureIndex = 0;
|
||||
//TODO: change path
|
||||
foreach (Instance instance in new VwReader(@"c:\Users\minka\Downloads\rcv1\rcv1.train.vw.gz"))
|
||||
foreach (Instance instance in new VwReader(Path.Combine(dataFolder, "rcv1.train.vw.gz")))
|
||||
{
|
||||
count++;
|
||||
if (count%10000 == 0) Console.WriteLine(count);
|
||||
|
@ -105,18 +109,15 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
var predict = new BpmPredict2();
|
||||
train.SetPriors(nf, wVariance, biasVariance);
|
||||
|
||||
//TODO: change path
|
||||
StreamWriter writer = new StreamWriter(@"c:\Users\minka\Downloads\rcv1\log.txt");
|
||||
StreamWriter writer = new StreamWriter(Path.Combine(dataFolder, "log.txt"));
|
||||
int errors = 0;
|
||||
//int errors2 = 0;
|
||||
//TODO: change path
|
||||
//StreamReader reader = new StreamReader(@"c:\Users\minka\Downloads\rcv1\preds.txt");
|
||||
//StreamReader reader = new StreamReader(Path.Combine(dataFolder, "preds.txt"));
|
||||
// takes 92s to train
|
||||
// takes 74s just to read the data
|
||||
// takes 15s just to do 'wc' on the data
|
||||
// there are 781265 data points in train, 23149 in test
|
||||
//TODO: change path
|
||||
foreach (Instance instance in new VwReader(@"c:\Users\minka\Downloads\rcv1\rcv1.train.vw.gz"))
|
||||
foreach (Instance instance in new VwReader(Path.Combine(dataFolder, "rcv1.train.vw.gz")))
|
||||
{
|
||||
predict.SetPriors(train.wPost, train.biasPost);
|
||||
bool yPred = predict.Predict(instance);
|
||||
|
@ -135,16 +136,15 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
}
|
||||
}
|
||||
writer.Dispose();
|
||||
#if HAS_BINARY_FORMATTER
|
||||
// In the .NET 5.0 BinaryFormatter is obsolete
|
||||
// and would produce errors. This test code should be migrated.
|
||||
// See https://github.com/GrabYourPitchforks/docs/blob/bf_obsoletion_docs/docs/standard/serialization/resolving-binaryformatter-obsoletion-errors.md
|
||||
#if NETFULL
|
||||
// In .NET 5.0, BinaryFormatter is obsolete
|
||||
// and would produce errors. This test code should be migrated.
|
||||
// See https://aka.ms/binaryformatter
|
||||
|
||||
if (true)
|
||||
{
|
||||
BinaryFormatter serializer = new BinaryFormatter();
|
||||
//TODO: change path
|
||||
using (Stream stream = File.Create(@"c:\Users\minka\Downloads\rcv1\weights.bin"))
|
||||
using (Stream stream = File.Create(Path.Combine(dataFolder, "weights.bin")))
|
||||
{
|
||||
serializer.Serialize(stream, train.wPost);
|
||||
serializer.Serialize(stream, train.biasPost);
|
||||
|
@ -170,17 +170,16 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
#pragma warning restore 162
|
||||
#endif
|
||||
|
||||
#if HAS_BINARY_FORMATTER
|
||||
#if NETFULL
|
||||
public static void Rcv1Test2()
|
||||
{
|
||||
GaussianArray wPost;
|
||||
Gaussian biasPost;
|
||||
BinaryFormatter serializer = new BinaryFormatter();
|
||||
//TODO: change path
|
||||
using (Stream stream = File.OpenRead(@"c:\Users\minka\Downloads\rcv1\weights.bin"))
|
||||
using (Stream stream = File.OpenRead(Path.Combine(dataFolder, "weights.bin")))
|
||||
{
|
||||
wPost = (GaussianArray) serializer.Deserialize(stream);
|
||||
biasPost = (Gaussian) serializer.Deserialize(stream);
|
||||
wPost = (GaussianArray)serializer.Deserialize(stream);
|
||||
biasPost = (Gaussian)serializer.Deserialize(stream);
|
||||
}
|
||||
if (true)
|
||||
{
|
||||
|
@ -192,8 +191,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
predict.SetPriors(wPost, biasPost);
|
||||
int count = 0;
|
||||
int errors = 0;
|
||||
//TODO: change path
|
||||
foreach (Instance instance in new VwReader(@"c:\Users\minka\Downloads\rcv1\rcv1.test.vw.gz"))
|
||||
foreach (Instance instance in new VwReader(Path.Combine(dataFolder, "rcv1.test.vw.gz")))
|
||||
{
|
||||
bool yPred = predict.Predict(instance);
|
||||
if (yPred != instance.label) errors++;
|
||||
|
@ -211,8 +209,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
double biasVariance = 10;
|
||||
train.SetPriors(nf, wVariance, biasVariance);
|
||||
int count = 0;
|
||||
//TODO: change path
|
||||
foreach (Instance instance in new VwReader(@"c:\Users\minka\Downloads\rcv1\rcv1.train.vw.gz"))
|
||||
foreach (Instance instance in new VwReader(Path.Combine(dataFolder, "rcv1.train.vw.gz")))
|
||||
{
|
||||
train.Train(instance);
|
||||
count++;
|
||||
|
|
|
@ -87,7 +87,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
mc.AssertEqualTo(mc2);
|
||||
}
|
||||
|
||||
#if HAS_BINARY_FORMATTER
|
||||
#if NETFULL
|
||||
[Fact]
|
||||
public void BinaryFormatterTest()
|
||||
{
|
||||
|
@ -109,7 +109,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
mc.AssertEqualTo(mc2);
|
||||
}
|
||||
|
||||
#if HAS_BINARY_FORMATTER
|
||||
#if NETFULL
|
||||
[Fact]
|
||||
public void VectorSerializeTests()
|
||||
{
|
||||
|
@ -311,7 +311,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
}
|
||||
}
|
||||
|
||||
#if HAS_BINARY_FORMATTER
|
||||
#if NETFULL
|
||||
private static T CloneBinaryFormatter<T>(T obj)
|
||||
{
|
||||
var bf = new BinaryFormatter();
|
||||
|
|
|
@ -1464,7 +1464,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
Console.WriteLine(engine.Infer(mean));
|
||||
}
|
||||
|
||||
#if HAS_BINARY_FORMATTER
|
||||
#if NETFULL
|
||||
internal void BinarySerializationExample()
|
||||
{
|
||||
Dirichlet d = new Dirichlet(3.0, 1.0, 2.0);
|
||||
|
|
|
@ -2137,6 +2137,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
}
|
||||
|
||||
// Gives the wrong answer due to an undetected deterministic loop.
|
||||
// Chooses one of two possible solutions instead of averaging them.
|
||||
[Fact]
|
||||
[Trait("Category", "OpenBug")]
|
||||
public void AndOrXorTest()
|
||||
|
@ -2146,6 +2147,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
var xt = x & (~y);
|
||||
var yt = (~x) & y;
|
||||
var z = yt | xt;
|
||||
// This is an equivalent model that gets the right answer.
|
||||
//z = (x != y);
|
||||
z.ObservedValue = true;
|
||||
z.Name = "z";
|
||||
|
@ -2153,14 +2155,15 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
Bernoulli xExpected = new Bernoulli(0.5);
|
||||
Bernoulli xActual = ie.Infer<Bernoulli>(x);
|
||||
Console.WriteLine("x = {0} should be {1}", xActual, xExpected);
|
||||
Bernoulli yExpected = new Bernoulli(0.5);
|
||||
Bernoulli yActual = ie.Infer<Bernoulli>(y);
|
||||
Console.WriteLine("y = {0} should be {1}", yActual, yExpected);
|
||||
Assert.True(xExpected.MaxDiff(xActual) < 1e-10);
|
||||
}
|
||||
|
||||
// An interesting failure case for belief propagation.
|
||||
// Fails due to slow convergence around a deterministic loop.
|
||||
// An interesting case for belief propagation.
|
||||
// Constant propagation would help to remove the final "and" factor and make the model simpler.
|
||||
[Fact]
|
||||
[Trait("Category", "OpenBug")]
|
||||
public void AndOrXorTest2()
|
||||
{
|
||||
Variable<bool> x = Variable.Bernoulli(0.5).Named("x");
|
||||
|
|
|
@ -679,7 +679,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
}
|
||||
}
|
||||
|
||||
#if HAS_BINARY_FORMATTER
|
||||
#if NETFULL
|
||||
if (false)
|
||||
{
|
||||
BinaryFormatter serializer = new BinaryFormatter();
|
||||
|
|
|
@ -23,6 +23,12 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
/// </summary>
|
||||
public class StringInferencePerformanceTests
|
||||
{
|
||||
private void AssertTimeout(Action action, int timeout)
|
||||
{
|
||||
// Don't impose a time limit since runtimes are very inconsistent on Azure.
|
||||
action();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Measures automaton normalization performance.
|
||||
/// </summary>
|
||||
|
@ -31,7 +37,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
[Trait("Category", "StringInference")]
|
||||
public void AutomatonNormalizationPerformance1()
|
||||
{
|
||||
Assert.Timeout(() =>
|
||||
AssertTimeout(() =>
|
||||
{
|
||||
var builder = new StringAutomaton.Builder();
|
||||
var nextState = builder.Start.AddTransitionsForSequence("abc");
|
||||
|
@ -53,7 +59,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
[Trait("Category", "StringInference")]
|
||||
public void AutomatonNormalizationPerformance2()
|
||||
{
|
||||
Assert.Timeout(() =>
|
||||
AssertTimeout(() =>
|
||||
{
|
||||
var builder = new StringAutomaton.Builder();
|
||||
var nextState = builder.Start.AddTransitionsForSequence("abc");
|
||||
|
@ -80,7 +86,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
[Trait("Category", "StringInference")]
|
||||
public void AutomatonNormalizationPerformance3()
|
||||
{
|
||||
Assert.Timeout(() =>
|
||||
AssertTimeout(() =>
|
||||
{
|
||||
var builder = new StringAutomaton.Builder();
|
||||
builder.Start.AddSelfTransition('a', Weight.FromValue(0.5));
|
||||
|
@ -107,7 +113,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
[Trait("Category", "StringInference")]
|
||||
public void RegexpBuildingPerformanceTest1()
|
||||
{
|
||||
Assert.Timeout(() =>
|
||||
AssertTimeout(() =>
|
||||
{
|
||||
StringDistribution dist =
|
||||
StringDistribution.OneOf(
|
||||
|
@ -128,7 +134,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
[Trait("Category", "StringInference")]
|
||||
public void RegexpBuildingPerformanceTest2()
|
||||
{
|
||||
Assert.Timeout(() =>
|
||||
AssertTimeout(() =>
|
||||
{
|
||||
StringDistribution dist = StringDistribution.OneOf(StringDistribution.Lower(), StringDistribution.Upper());
|
||||
for (int i = 0; i < 3; ++i)
|
||||
|
@ -150,7 +156,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
[Trait("Category", "StringInference")]
|
||||
public void RegexpBuildingPerformanceTest3()
|
||||
{
|
||||
Assert.Timeout(() =>
|
||||
AssertTimeout(() =>
|
||||
{
|
||||
StringDistribution dist = StringFormatOp_RequireEveryPlaceholder_NoArgumentNames.FormatAverageConditional(
|
||||
StringDistribution.String("aaaaaaaaaaa"),
|
||||
|
@ -170,7 +176,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
[Trait("Category", "StringInference")]
|
||||
public void StringFormatPerformanceTest1()
|
||||
{
|
||||
Assert.Timeout(() =>
|
||||
AssertTimeout(() =>
|
||||
{
|
||||
Rand.Restart(777);
|
||||
|
||||
|
@ -216,7 +222,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
[Trait("Category", "StringInference")]
|
||||
public void StringFormatPerformanceTest2()
|
||||
{
|
||||
Assert.Timeout(() =>
|
||||
AssertTimeout(() =>
|
||||
{
|
||||
Rand.Restart(777);
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче