ModelCompiler.TraceAllMessages activates Tracing.Trace for all variables (#145)

Tracing.Trace uses System.Diagnostics.Trace.
Added DifficultyAbility.fs to TestFSharp.
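As a usage sketch (mirroring the TraceAllMessagesTest added in this commit, plus standard System.Diagnostics listener setup that is not part of the change), enabling the option looks like this:

using System;
using System.Diagnostics;
using Microsoft.ML.Probabilistic.Models;

// Route System.Diagnostics.Trace output to the console so the traced messages are visible.
Trace.Listeners.Add(new TextWriterTraceListener(Console.Out));
Trace.AutoFlush = true;

Variable<double> x = Variable.GaussianFromMeanAndPrecision(0, 1);
Variable.ConstrainPositive(x);

InferenceEngine engine = new InferenceEngine();
engine.Compiler.TraceAllMessages = true;  // every variable implicitly gets a TraceMessages attribute
engine.Infer(x);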
This commit is contained in:
Tom Minka 2019-04-01 22:52:22 +01:00 committed by GitHub
Parent 9854cb768a
Commit 8398c1d4f7
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 347 additions and 208 deletions

View file

@@ -60,8 +60,7 @@ namespace Microsoft.ML.Probabilistic.Compiler
}
/// <summary>
/// If true, all messages after each iteration will be logged to csv files in a folder named with the model name.
/// Use MatlabWriter.WriteFromCsvFolder to convert these to a mat file.
/// If true, all variables will implicitly have a TraceMessages attribute.
/// </summary>
public bool TraceAllMessages { get; set; }
@@ -987,7 +986,10 @@ namespace Microsoft.ML.Probabilistic.Compiler
if (OptimiseInferenceCode)
tc.AddTransform(new DeadCode2Transform(this));
tc.AddTransform(new ParallelScheduleTransform());
if (TraceAllMessages)
// All messages after each iteration will be logged to csv files in a folder named with the model name.
// Use MatlabWriter.WriteFromCsvFolder to convert these to a mat file.
bool useTracingTransform = false;
if (TraceAllMessages && useTracingTransform)
tc.AddTransform(new TracingTransform());
bool useArraySizeTracing = false;
if (useArraySizeTracing)

View file

@@ -838,8 +838,9 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
}
// Support for the 'TraceMessages' and 'ListenToMessages' attributes
if (mi.channelDecl != null && (context.InputAttributes.Has<TraceMessages>(mi.channelDecl) ||
context.InputAttributes.Has<ListenToMessages>(mi.channelDecl)))
if (compiler.TraceAllMessages ||
(mi.channelDecl != null && (context.InputAttributes.Has<TraceMessages>(mi.channelDecl) ||
context.InputAttributes.Has<ListenToMessages>(mi.channelDecl))))
{
string msgText = msg.ToString();
// Look for TraceMessages attribute that matches this message
@@ -851,7 +852,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
if (listenTo != null && listenTo.Containing != null && !msgText.Contains(listenTo.Containing)) listenTo = null;
if ((listenTo != null) || (trace != null))
if ((listenTo != null) || (trace != null) || compiler.TraceAllMessages)
{
IExpression textExpr = DebuggingSupport.GetExpressionTextExpression(msg);
if (listenTo != null)
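Conceptually, when tracing is active for a message (via TraceMessages, ListenToMessages, or now TraceAllMessages), the compiled inference code wraps the message update in the Tracing.Trace factor shown further below in Tracing.cs. An illustrative (not literal) generated statement, using a hypothetical message name:

// "x_marginal_F" is an invented name; real generated code uses model-specific message variables.
x_marginal_F = Tracing.Trace(x_marginal_F, "x_marginal_F");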

View file

@@ -24,6 +24,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
Dictionary<Set<IVariableDeclaration>, TableInfo> tableOfIndexVars = new Dictionary<Set<IVariableDeclaration>, TableInfo>();
MethodInfo writeMethod, writeBytesMethod, writeLineMethod, flushMethod, disposeMethodInfo;
IMethodDeclaration traceWriterMethod, disposeMethod;
public static bool UseToString = true;
public override string Name
{
@@ -66,7 +67,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
var stmts = traceWriterMethod.Body.Statements;
string folder = td.Name;
stmts.Add(Builder.ExprStatement(Builder.StaticMethod(new Func<string, DirectoryInfo>(Directory.CreateDirectory), Builder.LiteralExpr(folder))));
IExpression pathExpr = Builder.BinaryExpr(BinaryOperator.Add, name, Builder.LiteralExpr(".csv"));
IExpression pathExpr = Builder.BinaryExpr(BinaryOperator.Add, name, Builder.LiteralExpr(UseToString ? ".tsv" : ".csv"));
pathExpr = Builder.BinaryExpr(BinaryOperator.Add, Builder.LiteralExpr(folder + "/"), pathExpr);
var writerDecl = Builder.VarDecl("writer", typeof(StreamWriter));
var ctorExpr = Builder.NewObject(typeof(StreamWriter), pathExpr);
@@ -116,12 +117,13 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
StringBuilder header = new StringBuilder();
header.Append("iteration");
output2.Add(GetWriteStatement(writer, Builder.VarRefExpr(iterationVar)));
var delimiter = GetWriteStatement(writer, Builder.LiteralExpr(","));
string delimiter = UseToString ? "\t" : ",";
var writeDelimiter = GetWriteStatement(writer, Builder.LiteralExpr(delimiter));
foreach (var indexVar in table.indexVars)
{
header.Append(",");
header.Append(delimiter);
header.Append(indexVar.Name);
output2.Add(delimiter);
output2.Add(writeDelimiter);
output2.Add(GetWriteStatement(writer, Builder.VarRefExpr(indexVar)));
}
foreach (var messageBaseExpr in table.messageExprs)
@@ -141,9 +143,9 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
conditions.Push(Builder.BinaryExpr(BinaryOperator.IdentityEquality, messageExpr, Builder.LiteralExpr(null)));
if (messageExpr.GetExpressionType().IsPrimitive)
{
header.Append(",");
header.Append(delimiter);
header.Append(varInfo.Name);
output2.Add(delimiter);
output2.Add(writeDelimiter);
output2.Add(GetWriteStatement(writer, AddConditions(messageExpr, conditions)));
}
else
@@ -151,9 +153,9 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
Dictionary<string, IExpression> dict = GetProperties(messageExpr);
foreach (var entry in dict)
{
header.Append(",");
header.Append(delimiter);
header.Append(varInfo.Name + entry.Key);
output2.Add(delimiter);
output2.Add(writeDelimiter);
output2.Add(GetWriteStatement(writer, AddConditions(entry.Value, conditions)));
}
}
@@ -170,25 +172,33 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
{
Dictionary<string, IExpression> dict = new Dictionary<string, IExpression>();
Type type = expr.GetExpressionType();
Type[] faces = type.GetInterfaces();
bool hasGetMean = false;
bool hasGetVariance = false;
foreach (Type face in faces)
if (UseToString)
{
if (face.Name == "CanGetMean`1")
hasGetMean = true;
else if (face.Name == "CanGetVariance`1")
hasGetVariance = true;
var toStringMethod = type.GetMethod("ToString", new Type[0]);
dict["ToString"] = Builder.Method(expr, toStringMethod);
}
if (hasGetMean)
else
{
var meanMethod = type.GetMethod("GetMean", new Type[0]);
dict["Mean"] = Builder.Method(expr, meanMethod);
}
if (hasGetVariance)
{
var varianceMethod = type.GetMethod("GetVariance", new Type[0]);
dict["Variance"] = Builder.Method(expr, varianceMethod);
Type[] faces = type.GetInterfaces();
bool hasGetMean = false;
bool hasGetVariance = false;
foreach (Type face in faces)
{
if (face.Name == "CanGetMean`1")
hasGetMean = true;
else if (face.Name == "CanGetVariance`1")
hasGetVariance = true;
}
if (hasGetMean)
{
var meanMethod = type.GetMethod("GetMean", new Type[0]);
dict["Mean"] = Builder.Method(expr, meanMethod);
}
if (hasGetVariance)
{
var varianceMethod = type.GetMethod("GetVariance", new Type[0]);
dict["Variance"] = Builder.Method(expr, varianceMethod);
}
}
return dict;
}
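The non-ToString branch of GetProperties relies on the CanGetMean`1 and CanGetVariance`1 interfaces implemented by distribution types such as Gaussian. A self-contained sketch of that detection pattern (assumes a reference to Microsoft.ML.Probabilistic; not code from this commit):

using System;
using System.Linq;
using Microsoft.ML.Probabilistic.Distributions;

class InterfaceDetectionSketch
{
    static void Main()
    {
        object message = new Gaussian(0.0, 1.0);   // a standard-normal message
        Type type = message.GetType();
        Type[] faces = type.GetInterfaces();
        // Same name-based test the transform uses to decide which columns to emit.
        bool hasGetMean = faces.Any(face => face.Name == "CanGetMean`1");
        bool hasGetVariance = faces.Any(face => face.Name == "CanGetVariance`1");
        if (hasGetMean)
            Console.WriteLine("Mean = " + type.GetMethod("GetMean", new Type[0]).Invoke(message, null));
        if (hasGetVariance)
            Console.WriteLine("Variance = " + type.GetMethod("GetVariance", new Type[0]).Invoke(message, null));
    }
}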

View file

@@ -20,7 +20,7 @@ namespace Microsoft.ML.Probabilistic.Factors
/// <returns><paramref name="input"/></returns>
public static T Trace<T>([IsReturned] T input, string text)
{
Debug.WriteLine(StringUtil.JoinColumns(text, ": ", input));
System.Diagnostics.Trace.WriteLine(StringUtil.JoinColumns(text, ": ", input));
return input;
}
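Because the factor now writes through System.Diagnostics.Trace rather than Debug, its output is controlled by standard trace listeners and is also emitted in Release builds (TRACE is defined by default in both configurations, DEBUG only in Debug). A minimal redirect to a file (standard .NET API, not part of this change):

using System.Diagnostics;

// Capture trace output, including messages traced during inference, in a log file.
Trace.Listeners.Add(new TextWriterTraceListener("inference-trace.log"));
Trace.AutoFlush = true;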

View file

@@ -0,0 +1,104 @@
namespace DifficultyAbilityExample
open System.Collections.Generic
open Microsoft.ML.Probabilistic
open Microsoft.ML.Probabilistic.FSharp
open Microsoft.ML.Probabilistic.Models
open Microsoft.ML.Probabilistic.Utilities
open Microsoft.ML.Probabilistic.Distributions
open Microsoft.ML.Probabilistic.Math
module DifficultyAbility =
let main() =
Rand.Restart(0);
let nQuestions = 100
let nSubjects = 40
let nChoices = 4
let abilityPrior = Gaussian(0.0, 1.0)
let difficultyPrior = Gaussian(0.0, 1.0)
let discriminationPrior = Gamma.FromMeanAndVariance(1.0, 0.01)
let Sample(nSubjects:int,nQuestions:int,nChoices:int,abilityPrior:Gaussian,difficultyPrior:Gaussian,discriminationPrior:Gamma)=
let ability= Util.ArrayInit( nSubjects, (fun _-> abilityPrior.Sample()))
let difficulty = Util.ArrayInit(nQuestions, (fun _ -> difficultyPrior.Sample()))
let discrimination = Util.ArrayInit(nQuestions, (fun _ -> discriminationPrior.Sample()))
let trueAnswer = Util.ArrayInit(nQuestions, (fun _ -> Rand.Int(nChoices)))
let response:int[][] = Array.zeroCreate nSubjects
for s in 0..(nSubjects-1) do
response.[s] <- Array.zeroCreate nQuestions
for q in 0..(nQuestions-1) do
let advantage = ability.[s] - difficulty.[q]
let noise = Gaussian.Sample(0.0, discrimination.[q])
let correct = (advantage > noise)
if (correct) then
response.[s].[q] <- trueAnswer.[q]
else
response.[s].[q] <- Rand.Int(nChoices)
(response, ability,difficulty,discrimination,trueAnswer)
let data,trueAbility,trueDifficulty,trueDiscrimination,trueTrueAnswer = Sample(nSubjects,nQuestions,nChoices,abilityPrior,difficultyPrior,discriminationPrior)
let question = Range(nQuestions).Named("question")
let subject = Range(nSubjects).Named("subject")
let choice = Range(nChoices).Named("choice")
//let response = Variable.Array(Variable.Array<int>(question), subject).Named("response")
let response = Variable.Array<VariableArray<int>, int [][]>(Variable.Array<int>(question), subject).Named("response")
response.ObservedValue <- data
let ability = Variable.Array<double>(subject).Named("ability")
Variable.ForeachBlock subject ( fun s -> ability.[s] <- Variable.Random(abilityPrior) )
let difficulty = Variable.Array<double>(question).Named("difficulty")
Variable.ForeachBlock question ( fun q -> difficulty.[q] <- Variable.Random(difficultyPrior) )
let discrimination = Variable.Array<double>(question).Named("discrimination")
Variable.ForeachBlock question ( fun q -> discrimination.[q] <- Variable.Random(discriminationPrior) )
let trueAnswer = Variable.Array<int>(question).Named("trueAnswer")
Variable.ForeachBlock question ( fun q -> trueAnswer.[q] <- Variable.DiscreteUniform(choice) )
Variable.ForeachBlock subject (fun s ->
Variable.ForeachBlock question (fun q ->
let advantage = (ability.[s] - difficulty.[q]).Named("advantage")
let advantageNoisy = Variable.GaussianFromMeanAndPrecision(advantage, discrimination.[q]).Named("advantageNoisy")
let correct = (advantageNoisy >> 0.0).Named("correct")
Variable.IfBlock correct (fun _->response.[s].[q] <- trueAnswer.[q]) (fun _->response.[s].[q] <- Variable.DiscreteUniform(choice))
()
)
)
let engine = InferenceEngine()
engine.NumberOfIterations <- 5
subject.AddAttribute(Models.Attributes.Sequential())
question.AddAttribute(Models.Attributes.Sequential())
let doMajorityVoting = false; // set this to 'true' to do majority voting
if doMajorityVoting then
ability.ObservedValue <- Util.ArrayInit(nSubjects, (fun i -> 0.0))
difficulty.ObservedValue <- Util.ArrayInit(nQuestions, (fun i -> 0.0))
discrimination.ObservedValue <- Util.ArrayInit(nQuestions, (fun i -> 0.0))
let trueAnswerPosterior = engine.Infer<IReadOnlyList<Discrete>>(trueAnswer)
let mutable numCorrect = 0
for q in 0..(nQuestions-1) do
let bestGuess = trueAnswerPosterior.[q].GetMode()
if (bestGuess = trueTrueAnswer.[q]) then
numCorrect<-numCorrect+1
let pctCorrect:float = 100.0 * (float numCorrect) / (float nQuestions)
printfn "%f TrueAnswers correct" pctCorrect
let difficultyPosterior = engine.Infer<IReadOnlyList<Gaussian>>(difficulty)
for q in 0..(System.Math.Min(nQuestions, 4)-1) do
printfn "difficulty[%i] = %A (sampled from %f)" q difficultyPosterior.[q] trueDifficulty.[q]
let discriminationPosterior = engine.Infer<IReadOnlyList<Gamma>>(discrimination)
for q in 0..(System.Math.Min(nQuestions, 4)-1) do
printfn "discrimination[%i] = %A (sampled from %f)" q discriminationPosterior.[q] trueDiscrimination.[q]
let abilityPosterior = engine.Infer<IReadOnlyList<Gaussian>>(ability)
for s in 0..(System.Math.Min(nSubjects, 4)-1) do
printfn "ability[%i] = %A (sampled from %f)" s abilityPosterior.[s] trueAbility.[s]

View file

@@ -4,6 +4,10 @@
#light
open System
open System.Diagnostics
let coreAssemblyInfo = FileVersionInfo.GetVersionInfo(typeof<Object>.Assembly.Location)
printfn "%s .NET version %s mscorlib %s" (if Environment.Is64BitProcess then "64-bit" else "32-bit") (Environment.Version.ToString ()) coreAssemblyInfo.ProductVersion
//main Smoke Test .............................................
@@ -13,5 +17,6 @@ let _ = GaussianRangesTutorial.ranges.rangesTestFunc()
let _ = ClinicalTrialTutorial.clinical.clinicalTestFunc()
let _ = BayesPointTutorial.bayes.bayesTestFunc()
let _ = MixtureGaussiansTutorial.mixture.mixtureTestFunc()
let _ = DifficultyAbilityExample.DifficultyAbility.main()
Console.ReadLine() |> ignore

View file

@@ -7,6 +7,9 @@
<PlatformTarget>AnyCPU</PlatformTarget>
<Configurations>Debug;Release;DebugFull;DebugCore;ReleaseFull;ReleaseCore</Configurations>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='DebugFull|AnyCPU'">
<Prefer32Bit>false</Prefer32Bit>
</PropertyGroup>
<Choose>
<When Condition="'$(Configuration)'=='DebugFull' OR '$(Configuration)'=='ReleaseFull'">
<PropertyGroup>
@@ -35,6 +38,7 @@
<ProjectReference Include="..\..\src\FSharpWrapper\FSharpWrapper.fsproj" />
</ItemGroup>
<ItemGroup>
<Compile Include="DifficultyAbility.fs" />
<Compile Include="..\..\src\Shared\SharedAssemblyFileVersion.fs">
<Link>SharedAssemblyFileVersion.fs</Link>
</Compile>

View file

@@ -0,0 +1,172 @@
using System;
using System.Collections.Generic;
using System.IO;
using Microsoft.ML.Probabilistic.Math;
using Microsoft.ML.Probabilistic.Serialization;
using Xunit;
namespace Microsoft.ML.Probabilistic.Tests.Core
{
using Assert = Xunit.Assert;
public class MatlabSerializationTests
{
[Fact]
//[DeploymentItem(@"Data\IRT2PL_10_250.mat", "Data")]
public void MatlabReaderTest2()
{
Dictionary<string, object> dict = MatlabReader.Read(Path.Combine(
#if NETCORE
Path.GetDirectoryName(typeof(PsychTests).Assembly.Location), // work dir is not the one with Microsoft.ML.Probabilistic.Tests.dll on netcore and neither is .Location on netfull
#endif
"Data", "IRT2PL_10_250.mat"));
Assert.Equal(5, dict.Count);
Matrix m = (Matrix)dict["Y"];
Assert.True(m.Rows == 250);
Assert.True(m.Cols == 10);
Assert.True(m[0, 1] == 0.0);
Assert.True(m[1, 0] == 1.0);
m = (Matrix)dict["difficulty"];
Assert.True(m.Rows == 10);
Assert.True(m.Cols == 1);
Assert.True(MMath.AbsDiff(m[1], 0.7773) < 2e-4);
}
[Fact]
////[DeploymentItem(@"Data\test.mat", "Data")]
public void MatlabReaderTest()
{
MatlabReaderTester(Path.Combine(
#if NETCORE
Path.GetDirectoryName(typeof(PsychTests).Assembly.Location), // work dir is not the one with Microsoft.ML.Probabilistic.Tests.dll on netcore and neither is .Location on netfull
#endif
"Data", "test.mat"));
}
private void MatlabReaderTester(string fileName)
{
Dictionary<string, object> dict = MatlabReader.Read(fileName);
Assert.Equal(12, dict.Count);
Matrix aScalar = (Matrix)dict["aScalar"];
Assert.Equal(1, aScalar.Rows);
Assert.Equal(1, aScalar.Cols);
Assert.Equal(5.0, aScalar[0, 0]);
Assert.Equal("string", (string)dict["aString"]);
MatlabReader.ComplexMatrix aComplexScalar = (MatlabReader.ComplexMatrix)dict["aComplexScalar"];
Assert.Equal(5.0, aComplexScalar.Real[0, 0]);
Assert.Equal(3.0, aComplexScalar.Imaginary[0, 0]);
MatlabReader.ComplexMatrix aComplexVector = (MatlabReader.ComplexMatrix)dict["aComplexVector"];
Assert.Equal(1.0, aComplexVector.Real[0, 0]);
Assert.Equal(2.0, aComplexVector.Imaginary[0, 0]);
Assert.Equal(3.0, aComplexVector.Real[0, 1]);
Assert.Equal(4.0, aComplexVector.Imaginary[0, 1]);
var aStruct = (Dictionary<string, object>)dict["aStruct"];
Assert.Equal(2, aStruct.Count);
Assert.Equal(1.0, ((Matrix)aStruct["field1"])[0]);
Assert.Equal("two", (string)aStruct["field2"]);
object[,] aCell = (object[,])dict["aCell"];
Assert.Equal(1.0, ((Matrix)aCell[0, 0])[0]);
int[] intArray = (int[])dict["intArray"];
Assert.Equal(1, intArray[0]);
int[] uintArray = (int[])dict["uintArray"];
Assert.Equal(1, uintArray[0]);
bool[] aLogical = (bool[])dict["aLogical"];
Assert.True(aLogical[0]);
Assert.True(aLogical[1]);
Assert.False(aLogical[2]);
object[,,] aCell3D = (object[,,])dict["aCell3D"];
Assert.Null(aCell3D[0, 0, 0]);
Assert.Equal(7.0, ((Matrix)aCell3D[0, 0, 1])[0, 0]);
Assert.Equal(6.0, ((Matrix)aCell3D[0, 1, 0])[0, 0]);
double[,,,] array4D = (double[,,,])dict["array4D"];
Assert.Equal(4.0, array4D[0, 0, 1, 0]);
Assert.Equal(5.0, array4D[0, 0, 0, 1]);
long[] aLong = (long[])dict["aLong"];
Assert.Equal(1234567890123456789L, aLong[0]);
}
[Fact]
//[DeploymentItem(@"Data\test.mat", "Data")]
public void MatlabWriterTest()
{
Dictionary<string, object> dict = MatlabReader.Read(Path.Combine(
#if NETCORE
Path.GetDirectoryName(typeof(PsychTests).Assembly.Location), // work dir is not the one with Microsoft.ML.Probabilistic.Tests.dll on netcore and neither is .Location on netfull
#endif
"Data", "test.mat"));
string fileName = $"{System.IO.Path.GetTempPath()}MatlabWriterTest{Environment.CurrentManagedThreadId}.mat";
using (MatlabWriter writer = new MatlabWriter(fileName))
{
foreach (var entry in dict)
{
writer.Write(entry.Key, entry.Value);
}
}
MatlabReaderTester(fileName);
}
[Fact]
public void MatlabWriteStringDictionaryTest()
{
Dictionary<string, string> dictString = new Dictionary<string, string>();
dictString["a"] = "a";
dictString["b"] = "b";
string fileName = $"{System.IO.Path.GetTempPath()}MatlabWriteStringDictionaryTest{Environment.CurrentManagedThreadId}.mat";
using (MatlabWriter writer = new MatlabWriter(fileName))
{
writer.Write("dictString", dictString);
}
Dictionary<string, object> vars = MatlabReader.Read(fileName);
Dictionary<string, object> dict = (Dictionary<string, object>)vars["dictString"];
foreach (var entry in dictString)
{
Assert.Equal(dictString[entry.Key], dict[entry.Key]);
}
}
[Fact]
public void MatlabWriteStringListTest()
{
List<string> strings = new List<string>();
strings.Add("a");
strings.Add("b");
string fileName = $"{System.IO.Path.GetTempPath()}MatlabWriteStringListTest{Environment.CurrentManagedThreadId}.mat";
using (MatlabWriter writer = new MatlabWriter(fileName))
{
writer.Write("strings", strings);
}
Dictionary<string, object> vars = MatlabReader.Read(fileName);
string[] array = (string[])vars["strings"];
for (int i = 0; i < array.Length; i++)
{
Assert.Equal(strings[i], array[i]);
}
}
[Fact]
public void MatlabWriteEmptyArrayTest()
{
string fileName = $"{System.IO.Path.GetTempPath()}MatlabWriteEmptyArrayTest{Environment.CurrentManagedThreadId}.mat";
using (MatlabWriter writer = new MatlabWriter(fileName))
{
writer.Write("ints", new int[0]);
}
Dictionary<string, object> vars = MatlabReader.Read(fileName);
int[] ints = (int[])vars["ints"];
Assert.Empty(ints);
}
[Fact]
public void MatlabWriteNumericNameTest()
{
string fileName = $"{System.IO.Path.GetTempPath()}MatlabWriteNumericNameTest{Environment.CurrentManagedThreadId}.mat";
using (MatlabWriter writer = new MatlabWriter(fileName))
{
writer.Write("24", new int[0]);
}
Dictionary<string, object> vars = MatlabReader.Read(fileName);
int[] ints = (int[])vars["24"];
Assert.Empty(ints);
}
}
}

View file

@@ -454,13 +454,14 @@ namespace Microsoft.ML.Probabilistic.Tests
// generate data from the model
var hPrior = new Gaussian(hMean, hVariance);
var hSample = Util.ArrayInit(n, i => hPrior.Sample());
// When xMultiplier != 1, we have model mismatch so we want the learned xPrecision to decrease.
double xMultiplier = 5;
var xData = Util.ArrayInit(n, i => Gaussian.Sample(xMultiplier * hSample[i], xPrecisionTrue));
var yData = Util.ArrayInit(n, i => Gaussian.Sample(hSample[i], yPrecisionTrue));
x.ObservedValue = xData;
y.ObservedValue = yData;
// N(x; ah, vx) N(h; mh, vh) = N(h; mh + k*(x - a*mh), (1-ka)vh)
// N(x; a*h, vx) N(h; mh, vh) = N(h; mh + k*(x - a*mh), (1-ka)vh)
// where k = vh*a/(a^2*vh + vx)
// if x = a*x' then k(x - a*mh) = a*k(x' - mh)
// a*k = vh/(vh + vx/a^2)
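Restated in LaTeX, the identity in the comment is the standard Gaussian conditioning result (up to the evidence factor \mathcal{N}(x;\, a m_h,\, a^2 v_h + v_x)):

\mathcal{N}(x;\, a h,\, v_x)\,\mathcal{N}(h;\, m_h,\, v_h)
  \propto \mathcal{N}\!\left(h;\; m_h + k\,(x - a m_h),\; (1 - k a)\, v_h\right),
\qquad k = \frac{a\, v_h}{a^2 v_h + v_x}.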

View file

@@ -209,6 +209,18 @@ namespace Microsoft.ML.Probabilistic.Tests
//(new ModelTests()).CoinRunLengths();
}
[Fact]
public void TraceAllMessagesTest()
{
Variable<double> x = Variable.GaussianFromMeanAndPrecision(0, 1);
Variable.ConstrainPositive(x);
Variable.ConstrainPositive(x);
InferenceEngine engine = new InferenceEngine();
engine.Compiler.TraceAllMessages = true;
engine.Infer(x);
}
[Fact]
public void MarginalWrongDistributionError()
{

View file

@@ -20,10 +20,6 @@ namespace Microsoft.ML.Probabilistic.Tests
public class PsychTests
{
#if SUPPRESS_UNREACHABLE_CODE_WARNINGS
#pragma warning disable 162
#endif
internal void LogisticIrtTest()
{
Variable<int> numStudents = Variable.New<int>().Named("numStudents");
@@ -40,7 +36,8 @@ namespace Microsoft.ML.Probabilistic.Tests
response[student, question] = Variable.BernoulliFromLogOdds(((ability[student] - difficulty[question]).Named("minus")*discrimination[question]).Named("product"));
bool[,] data;
double[] discriminationTrue = new double[0];
if (false)
bool useDummyData = false;
if (useDummyData)
{
data = new bool[4,2];
for (int i = 0; i < data.GetLength(0); i++)
@@ -70,10 +67,6 @@ namespace Microsoft.ML.Probabilistic.Tests
Console.WriteLine(StringUtil.JoinColumns(engine.Infer(discrimination), " should be ", StringUtil.ToString(discriminationTrue)));
}
#if SUPPRESS_UNREACHABLE_CODE_WARNINGS
#pragma warning restore 162
#endif
public static bool[,] ConvertToBool(double[,] array)
{
int rows = array.GetLength(0);
@@ -89,168 +82,6 @@ namespace Microsoft.ML.Probabilistic.Tests
return result;
}
[Fact]
//[DeploymentItem(@"Data\IRT2PL_10_250.mat", "Data")]
public void MatlabReaderTest2()
{
Dictionary<string, object> dict = MatlabReader.Read(Path.Combine(
#if NETCORE
Path.GetDirectoryName(typeof(PsychTests).Assembly.Location), // work dir is not the one with Microsoft.ML.Probabilistic.Tests.dll on netcore and neither is .Location on netfull
#endif
"Data", "IRT2PL_10_250.mat"));
Assert.Equal(5, dict.Count);
Matrix m = (Matrix) dict["Y"];
Assert.True(m.Rows == 250);
Assert.True(m.Cols == 10);
Assert.True(m[0, 1] == 0.0);
Assert.True(m[1, 0] == 1.0);
m = (Matrix) dict["difficulty"];
Assert.True(m.Rows == 10);
Assert.True(m.Cols == 1);
Assert.True(MMath.AbsDiff(m[1], 0.7773) < 2e-4);
}
[Fact]
////[DeploymentItem(@"Data\test.mat", "Data")]
public void MatlabReaderTest()
{
MatlabReaderTester(Path.Combine(
#if NETCORE
Path.GetDirectoryName(typeof(PsychTests).Assembly.Location), // work dir is not the one with Microsoft.ML.Probabilistic.Tests.dll on netcore and neither is .Location on netfull
#endif
"Data", "test.mat"));
}
private void MatlabReaderTester(string fileName)
{
Dictionary<string, object> dict = MatlabReader.Read(fileName);
Assert.Equal(12, dict.Count);
Matrix aScalar = (Matrix) dict["aScalar"];
Assert.Equal(1, aScalar.Rows);
Assert.Equal(1, aScalar.Cols);
Assert.Equal(5.0, aScalar[0, 0]);
Assert.Equal("string", (string) dict["aString"]);
MatlabReader.ComplexMatrix aComplexScalar = (MatlabReader.ComplexMatrix) dict["aComplexScalar"];
Assert.Equal(5.0, aComplexScalar.Real[0, 0]);
Assert.Equal(3.0, aComplexScalar.Imaginary[0, 0]);
MatlabReader.ComplexMatrix aComplexVector = (MatlabReader.ComplexMatrix) dict["aComplexVector"];
Assert.Equal(1.0, aComplexVector.Real[0, 0]);
Assert.Equal(2.0, aComplexVector.Imaginary[0, 0]);
Assert.Equal(3.0, aComplexVector.Real[0, 1]);
Assert.Equal(4.0, aComplexVector.Imaginary[0, 1]);
var aStruct = (Dictionary<string, object>) dict["aStruct"];
Assert.Equal(2, aStruct.Count);
Assert.Equal(1.0, ((Matrix) aStruct["field1"])[0]);
Assert.Equal("two", (string) aStruct["field2"]);
object[,] aCell = (object[,]) dict["aCell"];
Assert.Equal(1.0, ((Matrix) aCell[0, 0])[0]);
int[] intArray = (int[]) dict["intArray"];
Assert.Equal(1, intArray[0]);
int[] uintArray = (int[])dict["uintArray"];
Assert.Equal(1, uintArray[0]);
bool[] aLogical = (bool[]) dict["aLogical"];
Assert.True(aLogical[0]);
Assert.True(aLogical[1]);
Assert.False(aLogical[2]);
object[,,] aCell3D = (object[,,]) dict["aCell3D"];
Assert.Null(aCell3D[0, 0, 0]);
Assert.Equal(7.0, ((Matrix) aCell3D[0, 0, 1])[0, 0]);
Assert.Equal(6.0, ((Matrix) aCell3D[0, 1, 0])[0, 0]);
double[,,,] array4D = (double[,,,]) dict["array4D"];
Assert.Equal(4.0, array4D[0, 0, 1, 0]);
Assert.Equal(5.0, array4D[0, 0, 0, 1]);
long[] aLong = (long[]) dict["aLong"];
Assert.Equal(1234567890123456789L, aLong[0]);
}
[Fact]
//[DeploymentItem(@"Data\test.mat", "Data")]
public void MatlabWriterTest()
{
Dictionary<string, object> dict = MatlabReader.Read(Path.Combine(
#if NETCORE
Path.GetDirectoryName(typeof(PsychTests).Assembly.Location), // work dir is not the one with Microsoft.ML.Probabilistic.Tests.dll on netcore and neither is .Location on netfull
#endif
"Data", "test.mat"));
string fileName = $"{System.IO.Path.GetTempPath()}MatlabWriterTest{Environment.CurrentManagedThreadId}.mat";
using (MatlabWriter writer = new MatlabWriter(fileName))
{
foreach (var entry in dict)
{
writer.Write(entry.Key, entry.Value);
}
}
MatlabReaderTester(fileName);
}
[Fact]
public void MatlabWriteStringDictionaryTest()
{
Dictionary<string, string> dictString = new Dictionary<string, string>();
dictString["a"] = "a";
dictString["b"] = "b";
string fileName = $"{System.IO.Path.GetTempPath()}MatlabWriteStringDictionaryTest{Environment.CurrentManagedThreadId}.mat";
using (MatlabWriter writer = new MatlabWriter(fileName))
{
writer.Write("dictString", dictString);
}
Dictionary<string, object> vars = MatlabReader.Read(fileName);
Dictionary<string, object> dict = (Dictionary<string, object>)vars["dictString"];
foreach (var entry in dictString)
{
Assert.Equal(dictString[entry.Key], dict[entry.Key]);
}
}
[Fact]
public void MatlabWriteStringListTest()
{
List<string> strings = new List<string>();
strings.Add("a");
strings.Add("b");
string fileName = $"{System.IO.Path.GetTempPath()}MatlabWriteStringListTest{Environment.CurrentManagedThreadId}.mat";
using (MatlabWriter writer = new MatlabWriter(fileName))
{
writer.Write("strings", strings);
}
Dictionary<string, object> vars = MatlabReader.Read(fileName);
string[] array = (string[])vars["strings"];
for (int i = 0; i < array.Length; i++)
{
Assert.Equal(strings[i], array[i]);
}
}
[Fact]
public void MatlabWriteEmptyArrayTest()
{
string fileName = $"{System.IO.Path.GetTempPath()}MatlabWriteEmptyArrayTest{Environment.CurrentManagedThreadId}.mat";
using (MatlabWriter writer = new MatlabWriter(fileName))
{
writer.Write("ints", new int[0]);
}
Dictionary<string, object> vars = MatlabReader.Read(fileName);
int[] ints = (int[])vars["ints"];
Assert.Empty(ints);
}
[Fact]
public void MatlabWriteNumericNameTest()
{
string fileName = $"{System.IO.Path.GetTempPath()}MatlabWriteNumericNameTest{Environment.CurrentManagedThreadId}.mat";
using (MatlabWriter writer = new MatlabWriter(fileName))
{
writer.Write("24", new int[0]);
}
Dictionary<string, object> vars = MatlabReader.Read(fileName);
int[] ints = (int[])vars["24"];
Assert.Empty(ints);
}
#if SUPPRESS_UNREACHABLE_CODE_WARNINGS
#pragma warning disable 162
#endif
/// <summary>
/// Nonconjugate VMP crashes with improper message on the first iteration.
/// </summary>
@@ -280,7 +111,8 @@ namespace Microsoft.ML.Probabilistic.Tests
//response.AddAttribute(new MarginalPrototype(new Gaussian()));
bool[,] data;
double[] discriminationTrue = new double[0];
if (false)
bool useDummyData = false;
if (useDummyData)
{
data = new bool[4,2];
for (int i = 0; i < data.GetLength(0); i++)
@@ -315,10 +147,6 @@ namespace Microsoft.ML.Probabilistic.Tests
Console.WriteLine(marg[i].GetMean() + " \t " + discriminationTrue[i]);
}
#if SUPPRESS_UNREACHABLE_CODE_WARNINGS
#pragma warning restore 162
#endif
internal void LogisticIrtProductExpTest()
{
int numStudents = 20;