зеркало из https://github.com/dotnet/infer.git
ModelCompiler.TraceAllMessages activates Tracing.Trace for all variables (#145)
Tracing.Trace uses System.Diagnostics.Trace. Added DifficultyAbility.fs to TestFSharp.
This commit is contained in:
Родитель
9854cb768a
Коммит
8398c1d4f7
|
@ -60,8 +60,7 @@ namespace Microsoft.ML.Probabilistic.Compiler
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// If true, all messages after each iteration will be logged to csv files in a folder named with the model name.
|
||||
/// Use MatlabWriter.WriteFromCsvFolder to convert these to a mat file.
|
||||
/// If true, all variables will implicitly have a TraceMessages attribute.
|
||||
/// </summary>
|
||||
public bool TraceAllMessages { get; set; }
|
||||
|
||||
|
@ -987,7 +986,10 @@ namespace Microsoft.ML.Probabilistic.Compiler
|
|||
if (OptimiseInferenceCode)
|
||||
tc.AddTransform(new DeadCode2Transform(this));
|
||||
tc.AddTransform(new ParallelScheduleTransform());
|
||||
if (TraceAllMessages)
|
||||
// All messages after each iteration will be logged to csv files in a folder named with the model name.
|
||||
// Use MatlabWriter.WriteFromCsvFolder to convert these to a mat file.
|
||||
bool useTracingTransform = false;
|
||||
if (TraceAllMessages && useTracingTransform)
|
||||
tc.AddTransform(new TracingTransform());
|
||||
bool useArraySizeTracing = false;
|
||||
if (useArraySizeTracing)
|
||||
|
|
|
@ -838,8 +838,9 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
}
|
||||
|
||||
// Support for the 'TraceMessages' and 'ListenToMessages' attributes
|
||||
if (mi.channelDecl != null && (context.InputAttributes.Has<TraceMessages>(mi.channelDecl) ||
|
||||
context.InputAttributes.Has<ListenToMessages>(mi.channelDecl)))
|
||||
if (compiler.TraceAllMessages ||
|
||||
(mi.channelDecl != null && (context.InputAttributes.Has<TraceMessages>(mi.channelDecl) ||
|
||||
context.InputAttributes.Has<ListenToMessages>(mi.channelDecl))))
|
||||
{
|
||||
string msgText = msg.ToString();
|
||||
// Look for TraceMessages attribute that matches this message
|
||||
|
@ -851,7 +852,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
if (listenTo != null && listenTo.Containing != null && !msgText.Contains(listenTo.Containing)) listenTo = null;
|
||||
|
||||
|
||||
if ((listenTo != null) || (trace != null))
|
||||
if ((listenTo != null) || (trace != null) || compiler.TraceAllMessages)
|
||||
{
|
||||
IExpression textExpr = DebuggingSupport.GetExpressionTextExpression(msg);
|
||||
if (listenTo != null)
|
||||
|
|
|
@ -24,6 +24,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
Dictionary<Set<IVariableDeclaration>, TableInfo> tableOfIndexVars = new Dictionary<Set<IVariableDeclaration>, TableInfo>();
|
||||
MethodInfo writeMethod, writeBytesMethod, writeLineMethod, flushMethod, disposeMethodInfo;
|
||||
IMethodDeclaration traceWriterMethod, disposeMethod;
|
||||
public static bool UseToString = true;
|
||||
|
||||
public override string Name
|
||||
{
|
||||
|
@ -66,7 +67,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
var stmts = traceWriterMethod.Body.Statements;
|
||||
string folder = td.Name;
|
||||
stmts.Add(Builder.ExprStatement(Builder.StaticMethod(new Func<string, DirectoryInfo>(Directory.CreateDirectory), Builder.LiteralExpr(folder))));
|
||||
IExpression pathExpr = Builder.BinaryExpr(BinaryOperator.Add, name, Builder.LiteralExpr(".csv"));
|
||||
IExpression pathExpr = Builder.BinaryExpr(BinaryOperator.Add, name, Builder.LiteralExpr(UseToString ? ".tsv" : ".csv"));
|
||||
pathExpr = Builder.BinaryExpr(BinaryOperator.Add, Builder.LiteralExpr(folder + "/"), pathExpr);
|
||||
var writerDecl = Builder.VarDecl("writer", typeof(StreamWriter));
|
||||
var ctorExpr = Builder.NewObject(typeof(StreamWriter), pathExpr);
|
||||
|
@ -116,12 +117,13 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
StringBuilder header = new StringBuilder();
|
||||
header.Append("iteration");
|
||||
output2.Add(GetWriteStatement(writer, Builder.VarRefExpr(iterationVar)));
|
||||
var delimiter = GetWriteStatement(writer, Builder.LiteralExpr(","));
|
||||
string delimiter = UseToString ? "\t" : ",";
|
||||
var writeDelimiter = GetWriteStatement(writer, Builder.LiteralExpr(delimiter));
|
||||
foreach (var indexVar in table.indexVars)
|
||||
{
|
||||
header.Append(",");
|
||||
header.Append(delimiter);
|
||||
header.Append(indexVar.Name);
|
||||
output2.Add(delimiter);
|
||||
output2.Add(writeDelimiter);
|
||||
output2.Add(GetWriteStatement(writer, Builder.VarRefExpr(indexVar)));
|
||||
}
|
||||
foreach (var messageBaseExpr in table.messageExprs)
|
||||
|
@ -141,9 +143,9 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
conditions.Push(Builder.BinaryExpr(BinaryOperator.IdentityEquality, messageExpr, Builder.LiteralExpr(null)));
|
||||
if (messageExpr.GetExpressionType().IsPrimitive)
|
||||
{
|
||||
header.Append(",");
|
||||
header.Append(delimiter);
|
||||
header.Append(varInfo.Name);
|
||||
output2.Add(delimiter);
|
||||
output2.Add(writeDelimiter);
|
||||
output2.Add(GetWriteStatement(writer, AddConditions(messageExpr, conditions)));
|
||||
}
|
||||
else
|
||||
|
@ -151,9 +153,9 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
Dictionary<string, IExpression> dict = GetProperties(messageExpr);
|
||||
foreach (var entry in dict)
|
||||
{
|
||||
header.Append(",");
|
||||
header.Append(delimiter);
|
||||
header.Append(varInfo.Name + entry.Key);
|
||||
output2.Add(delimiter);
|
||||
output2.Add(writeDelimiter);
|
||||
output2.Add(GetWriteStatement(writer, AddConditions(entry.Value, conditions)));
|
||||
}
|
||||
}
|
||||
|
@ -170,25 +172,33 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
{
|
||||
Dictionary<string, IExpression> dict = new Dictionary<string, IExpression>();
|
||||
Type type = expr.GetExpressionType();
|
||||
Type[] faces = type.GetInterfaces();
|
||||
bool hasGetMean = false;
|
||||
bool hasGetVariance = false;
|
||||
foreach (Type face in faces)
|
||||
if (UseToString)
|
||||
{
|
||||
if (face.Name == "CanGetMean`1")
|
||||
hasGetMean = true;
|
||||
else if (face.Name == "CanGetVariance`1")
|
||||
hasGetVariance = true;
|
||||
var toStringMethod = type.GetMethod("ToString", new Type[0]);
|
||||
dict["ToString"] = Builder.Method(expr, toStringMethod);
|
||||
}
|
||||
if (hasGetMean)
|
||||
else
|
||||
{
|
||||
var meanMethod = type.GetMethod("GetMean", new Type[0]);
|
||||
dict["Mean"] = Builder.Method(expr, meanMethod);
|
||||
}
|
||||
if (hasGetVariance)
|
||||
{
|
||||
var varianceMethod = type.GetMethod("GetVariance", new Type[0]);
|
||||
dict["Variance"] = Builder.Method(expr, varianceMethod);
|
||||
Type[] faces = type.GetInterfaces();
|
||||
bool hasGetMean = false;
|
||||
bool hasGetVariance = false;
|
||||
foreach (Type face in faces)
|
||||
{
|
||||
if (face.Name == "CanGetMean`1")
|
||||
hasGetMean = true;
|
||||
else if (face.Name == "CanGetVariance`1")
|
||||
hasGetVariance = true;
|
||||
}
|
||||
if (hasGetMean)
|
||||
{
|
||||
var meanMethod = type.GetMethod("GetMean", new Type[0]);
|
||||
dict["Mean"] = Builder.Method(expr, meanMethod);
|
||||
}
|
||||
if (hasGetVariance)
|
||||
{
|
||||
var varianceMethod = type.GetMethod("GetVariance", new Type[0]);
|
||||
dict["Variance"] = Builder.Method(expr, varianceMethod);
|
||||
}
|
||||
}
|
||||
return dict;
|
||||
}
|
||||
|
|
|
@ -20,7 +20,7 @@ namespace Microsoft.ML.Probabilistic.Factors
|
|||
/// <returns><paramref name="input"/></returns>
|
||||
public static T Trace<T>([IsReturned] T input, string text)
|
||||
{
|
||||
Debug.WriteLine(StringUtil.JoinColumns(text, ": ", input));
|
||||
System.Diagnostics.Trace.WriteLine(StringUtil.JoinColumns(text, ": ", input));
|
||||
return input;
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,104 @@
|
|||
namespace DifficultyAbilityExample
|
||||
|
||||
open System.Collections.Generic
|
||||
open Microsoft.ML.Probabilistic
|
||||
open Microsoft.ML.Probabilistic.FSharp
|
||||
open Microsoft.ML.Probabilistic.Models
|
||||
open Microsoft.ML.Probabilistic.Utilities
|
||||
open Microsoft.ML.Probabilistic.Distributions
|
||||
open Microsoft.ML.Probabilistic.Math
|
||||
|
||||
module DifficultyAbility =
|
||||
let main() =
|
||||
Rand.Restart(0);
|
||||
let nQuestions = 100
|
||||
let nSubjects = 40
|
||||
let nChoices = 4
|
||||
let abilityPrior = Gaussian(0.0, 1.0)
|
||||
let difficultyPrior = Gaussian(0.0, 1.0)
|
||||
let discriminationPrior = Gamma.FromMeanAndVariance(1.0, 0.01)
|
||||
|
||||
let Sample(nSubjects:int,nQuestions:int,nChoices:int,abilityPrior:Gaussian,difficultyPrior:Gaussian,discriminationPrior:Gamma)=
|
||||
let ability= Util.ArrayInit( nSubjects, (fun _-> abilityPrior.Sample()))
|
||||
let difficulty = Util.ArrayInit(nQuestions, (fun _ -> difficultyPrior.Sample()))
|
||||
let discrimination = Util.ArrayInit(nQuestions, (fun _ -> discriminationPrior.Sample()))
|
||||
let trueAnswer = Util.ArrayInit(nQuestions, (fun _ -> Rand.Int(nChoices)))
|
||||
let response:int[][] = Array.zeroCreate nSubjects
|
||||
for s in 0..(nSubjects-1) do
|
||||
response.[s] <- Array.zeroCreate nQuestions
|
||||
for q in 0..(nQuestions-1) do
|
||||
let advantage = ability.[s] - difficulty.[q]
|
||||
let noise = Gaussian.Sample(0.0, discrimination.[q])
|
||||
let correct = (advantage > noise)
|
||||
if (correct) then
|
||||
response.[s].[q] <- trueAnswer.[q]
|
||||
else
|
||||
response.[s].[q] <- Rand.Int(nChoices)
|
||||
(response, ability,difficulty,discrimination,trueAnswer)
|
||||
|
||||
|
||||
let data,trueAbility,trueDifficulty,trueDiscrimination,trueTrueAnswer = Sample(nSubjects,nQuestions,nChoices,abilityPrior,difficultyPrior,discriminationPrior)
|
||||
|
||||
let question = Range(nQuestions).Named("question")
|
||||
let subject = Range(nSubjects).Named("subject")
|
||||
let choice = Range(nChoices).Named("choice")
|
||||
|
||||
//let response = Variable.Array(Variable.Array<int>(question), subject).Named("response")
|
||||
let response = Variable.Array<VariableArray<int>, int [][]>(Variable.Array<int>(question), subject).Named("response")
|
||||
|
||||
response.ObservedValue <- data
|
||||
|
||||
let ability = Variable.Array<double>(subject).Named("ability")
|
||||
Variable.ForeachBlock subject ( fun s -> ability.[s] <- Variable.Random(abilityPrior) )
|
||||
let difficulty = Variable.Array<double>(question).Named("difficulty")
|
||||
Variable.ForeachBlock question ( fun q -> difficulty.[q] <- Variable.Random(difficultyPrior) )
|
||||
let discrimination = Variable.Array<double>(question).Named("discrimination")
|
||||
Variable.ForeachBlock question ( fun q -> discrimination.[q] <- Variable.Random(discriminationPrior) )
|
||||
let trueAnswer = Variable.Array<int>(question).Named("trueAnswer")
|
||||
Variable.ForeachBlock question ( fun q -> trueAnswer.[q] <- Variable.DiscreteUniform(choice) )
|
||||
|
||||
Variable.ForeachBlock subject (fun s ->
|
||||
Variable.ForeachBlock question (fun q ->
|
||||
let advantage = (ability.[s] - difficulty.[q]).Named("advantage")
|
||||
let advantageNoisy = Variable.GaussianFromMeanAndPrecision(advantage, discrimination.[q]).Named("advantageNoisy")
|
||||
let correct = (advantageNoisy >> 0.0).Named("correct")
|
||||
Variable.IfBlock correct (fun _->response.[s].[q] <- trueAnswer.[q]) (fun _->response.[s].[q] <- Variable.DiscreteUniform(choice))
|
||||
()
|
||||
)
|
||||
)
|
||||
|
||||
let engine = InferenceEngine()
|
||||
engine.NumberOfIterations <- 5
|
||||
subject.AddAttribute(Models.Attributes.Sequential())
|
||||
question.AddAttribute(Models.Attributes.Sequential())
|
||||
let doMajorityVoting = false; // set this to 'true' to do majority voting
|
||||
if doMajorityVoting then
|
||||
ability.ObservedValue <- Util.ArrayInit(nSubjects, (fun i -> 0.0))
|
||||
difficulty.ObservedValue <- Util.ArrayInit(nQuestions, (fun i -> 0.0))
|
||||
discrimination.ObservedValue <- Util.ArrayInit(nQuestions, (fun i -> 0.0))
|
||||
|
||||
let trueAnswerPosterior = engine.Infer<IReadOnlyList<Discrete>>(trueAnswer)
|
||||
|
||||
let mutable numCorrect = 0
|
||||
for q in 0..(nQuestions-1) do
|
||||
let bestGuess = trueAnswerPosterior.[q].GetMode()
|
||||
if (bestGuess = trueTrueAnswer.[q]) then
|
||||
numCorrect<-numCorrect+1
|
||||
|
||||
let pctCorrect:float = 100.0 * (float numCorrect) / (float nQuestions)
|
||||
printfn "%f TrueAnswers correct" pctCorrect
|
||||
|
||||
let difficultyPosterior = engine.Infer<IReadOnlyList<Gaussian>>(difficulty)
|
||||
|
||||
|
||||
for q in 0..(System.Math.Min(nQuestions, 4)-1) do
|
||||
printfn "difficulty[%i] = %A (sampled from %f)" q difficultyPosterior.[q] trueDifficulty.[q]
|
||||
|
||||
let discriminationPosterior = engine.Infer<IReadOnlyList<Gamma>>(discrimination)
|
||||
for q in 0..(System.Math.Min(nQuestions, 4)-1) do
|
||||
printfn "discrimination[%i] = %A (sampled from %f)" q discriminationPosterior.[q] trueDiscrimination.[q]
|
||||
|
||||
let abilityPosterior = engine.Infer<IReadOnlyList<Gaussian>>(ability)
|
||||
for s in 0..(System.Math.Min(nQuestions, 4)-1) do
|
||||
printfn "ability[%i] = %A (sampled from %f)" s abilityPosterior.[s] trueAbility.[s]
|
||||
|
|
@ -4,6 +4,10 @@
|
|||
|
||||
#light
|
||||
open System
|
||||
open System.Diagnostics
|
||||
|
||||
let coreAssemblyInfo = FileVersionInfo.GetVersionInfo(typeof<Object>.Assembly.Location)
|
||||
printfn "%s .NET version %s mscorlib %s" (if Environment.Is64BitProcess then "64-bit" else "32-bit") (Environment.Version.ToString ()) coreAssemblyInfo.ProductVersion
|
||||
|
||||
//main Smoke Test .............................................
|
||||
|
||||
|
@ -13,5 +17,6 @@ let _ = GaussianRangesTutorial.ranges.rangesTestFunc()
|
|||
let _ = ClinicalTrialTutorial.clinical.clinicalTestFunc()
|
||||
let _ = BayesPointTutorial.bayes.bayesTestFunc()
|
||||
let _ = MixtureGaussiansTutorial.mixture.mixtureTestFunc()
|
||||
let _ = DifficultyAbilityExample.DifficultyAbility.main()
|
||||
|
||||
Console.ReadLine() |> ignore
|
||||
|
|
|
@ -7,6 +7,9 @@
|
|||
<PlatformTarget>AnyCPU</PlatformTarget>
|
||||
<Configurations>Debug;Release;DebugFull;DebugCore;ReleaseFull;ReleaseCore</Configurations>
|
||||
</PropertyGroup>
|
||||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='DebugFull|AnyCPU'">
|
||||
<Prefer32Bit>false</Prefer32Bit>
|
||||
</PropertyGroup>
|
||||
<Choose>
|
||||
<When Condition="'$(Configuration)'=='DebugFull' OR '$(Configuration)'=='ReleaseFull'">
|
||||
<PropertyGroup>
|
||||
|
@ -35,6 +38,7 @@
|
|||
<ProjectReference Include="..\..\src\FSharpWrapper\FSharpWrapper.fsproj" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<Compile Include="DifficultyAbility.fs" />
|
||||
<Compile Include="..\..\src\Shared\SharedAssemblyFileVersion.fs">
|
||||
<Link>SharedAssemblyFileVersion.fs</Link>
|
||||
</Compile>
|
||||
|
|
|
@ -0,0 +1,172 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using Microsoft.ML.Probabilistic.Math;
|
||||
using Microsoft.ML.Probabilistic.Serialization;
|
||||
using Xunit;
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Tests.Core
|
||||
{
|
||||
using Assert = Xunit.Assert;
|
||||
|
||||
public class MatlabSerializationTests
|
||||
{
|
||||
[Fact]
|
||||
//[DeploymentItem(@"Data\IRT2PL_10_250.mat", "Data")]
|
||||
public void MatlabReaderTest2()
|
||||
{
|
||||
Dictionary<string, object> dict = MatlabReader.Read(Path.Combine(
|
||||
#if NETCORE
|
||||
Path.GetDirectoryName(typeof(PsychTests).Assembly.Location), // work dir is not the one with Microsoft.ML.Probabilistic.Tests.dll on netcore and neither is .Location on netfull
|
||||
#endif
|
||||
"Data", "IRT2PL_10_250.mat"));
|
||||
Assert.Equal(5, dict.Count);
|
||||
Matrix m = (Matrix)dict["Y"];
|
||||
Assert.True(m.Rows == 250);
|
||||
Assert.True(m.Cols == 10);
|
||||
Assert.True(m[0, 1] == 0.0);
|
||||
Assert.True(m[1, 0] == 1.0);
|
||||
m = (Matrix)dict["difficulty"];
|
||||
Assert.True(m.Rows == 10);
|
||||
Assert.True(m.Cols == 1);
|
||||
Assert.True(MMath.AbsDiff(m[1], 0.7773) < 2e-4);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
////[DeploymentItem(@"Data\test.mat", "Data")]
|
||||
public void MatlabReaderTest()
|
||||
{
|
||||
MatlabReaderTester(Path.Combine(
|
||||
#if NETCORE
|
||||
Path.GetDirectoryName(typeof(PsychTests).Assembly.Location), // work dir is not the one with Microsoft.ML.Probabilistic.Tests.dll on netcore and neither is .Location on netfull
|
||||
#endif
|
||||
"Data", "test.mat"));
|
||||
}
|
||||
|
||||
private void MatlabReaderTester(string fileName)
|
||||
{
|
||||
Dictionary<string, object> dict = MatlabReader.Read(fileName);
|
||||
Assert.Equal(12, dict.Count);
|
||||
Matrix aScalar = (Matrix)dict["aScalar"];
|
||||
Assert.Equal(1, aScalar.Rows);
|
||||
Assert.Equal(1, aScalar.Cols);
|
||||
Assert.Equal(5.0, aScalar[0, 0]);
|
||||
Assert.Equal("string", (string)dict["aString"]);
|
||||
MatlabReader.ComplexMatrix aComplexScalar = (MatlabReader.ComplexMatrix)dict["aComplexScalar"];
|
||||
Assert.Equal(5.0, aComplexScalar.Real[0, 0]);
|
||||
Assert.Equal(3.0, aComplexScalar.Imaginary[0, 0]);
|
||||
MatlabReader.ComplexMatrix aComplexVector = (MatlabReader.ComplexMatrix)dict["aComplexVector"];
|
||||
Assert.Equal(1.0, aComplexVector.Real[0, 0]);
|
||||
Assert.Equal(2.0, aComplexVector.Imaginary[0, 0]);
|
||||
Assert.Equal(3.0, aComplexVector.Real[0, 1]);
|
||||
Assert.Equal(4.0, aComplexVector.Imaginary[0, 1]);
|
||||
var aStruct = (Dictionary<string, object>)dict["aStruct"];
|
||||
Assert.Equal(2, aStruct.Count);
|
||||
Assert.Equal(1.0, ((Matrix)aStruct["field1"])[0]);
|
||||
Assert.Equal("two", (string)aStruct["field2"]);
|
||||
object[,] aCell = (object[,])dict["aCell"];
|
||||
Assert.Equal(1.0, ((Matrix)aCell[0, 0])[0]);
|
||||
int[] intArray = (int[])dict["intArray"];
|
||||
Assert.Equal(1, intArray[0]);
|
||||
int[] uintArray = (int[])dict["uintArray"];
|
||||
Assert.Equal(1, uintArray[0]);
|
||||
bool[] aLogical = (bool[])dict["aLogical"];
|
||||
Assert.True(aLogical[0]);
|
||||
Assert.True(aLogical[1]);
|
||||
Assert.False(aLogical[2]);
|
||||
object[,,] aCell3D = (object[,,])dict["aCell3D"];
|
||||
Assert.Null(aCell3D[0, 0, 0]);
|
||||
Assert.Equal(7.0, ((Matrix)aCell3D[0, 0, 1])[0, 0]);
|
||||
Assert.Equal(6.0, ((Matrix)aCell3D[0, 1, 0])[0, 0]);
|
||||
double[,,,] array4D = (double[,,,])dict["array4D"];
|
||||
Assert.Equal(4.0, array4D[0, 0, 1, 0]);
|
||||
Assert.Equal(5.0, array4D[0, 0, 0, 1]);
|
||||
long[] aLong = (long[])dict["aLong"];
|
||||
Assert.Equal(1234567890123456789L, aLong[0]);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
//[DeploymentItem(@"Data\test.mat", "Data")]
|
||||
public void MatlabWriterTest()
|
||||
{
|
||||
Dictionary<string, object> dict = MatlabReader.Read(Path.Combine(
|
||||
#if NETCORE
|
||||
Path.GetDirectoryName(typeof(PsychTests).Assembly.Location), // work dir is not the one with Microsoft.ML.Probabilistic.Tests.dll on netcore and neither is .Location on netfull
|
||||
#endif
|
||||
"Data", "test.mat"));
|
||||
string fileName = $"{System.IO.Path.GetTempPath()}MatlabWriterTest{Environment.CurrentManagedThreadId}.mat";
|
||||
using (MatlabWriter writer = new MatlabWriter(fileName))
|
||||
{
|
||||
foreach (var entry in dict)
|
||||
{
|
||||
writer.Write(entry.Key, entry.Value);
|
||||
}
|
||||
}
|
||||
MatlabReaderTester(fileName);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void MatlabWriteStringDictionaryTest()
|
||||
{
|
||||
Dictionary<string, string> dictString = new Dictionary<string, string>();
|
||||
dictString["a"] = "a";
|
||||
dictString["b"] = "b";
|
||||
string fileName = $"{System.IO.Path.GetTempPath()}MatlabWriteStringDictionaryTest{Environment.CurrentManagedThreadId}.mat";
|
||||
using (MatlabWriter writer = new MatlabWriter(fileName))
|
||||
{
|
||||
writer.Write("dictString", dictString);
|
||||
}
|
||||
Dictionary<string, object> vars = MatlabReader.Read(fileName);
|
||||
Dictionary<string, object> dict = (Dictionary<string, object>)vars["dictString"];
|
||||
foreach (var entry in dictString)
|
||||
{
|
||||
Assert.Equal(dictString[entry.Key], dict[entry.Key]);
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void MatlabWriteStringListTest()
|
||||
{
|
||||
List<string> strings = new List<string>();
|
||||
strings.Add("a");
|
||||
strings.Add("b");
|
||||
string fileName = $"{System.IO.Path.GetTempPath()}MatlabWriteStringListTest{Environment.CurrentManagedThreadId}.mat";
|
||||
using (MatlabWriter writer = new MatlabWriter(fileName))
|
||||
{
|
||||
writer.Write("strings", strings);
|
||||
}
|
||||
Dictionary<string, object> vars = MatlabReader.Read(fileName);
|
||||
string[] array = (string[])vars["strings"];
|
||||
for (int i = 0; i < array.Length; i++)
|
||||
{
|
||||
Assert.Equal(strings[i], array[i]);
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void MatlabWriteEmptyArrayTest()
|
||||
{
|
||||
string fileName = $"{System.IO.Path.GetTempPath()}MatlabWriteEmptyArrayTest{Environment.CurrentManagedThreadId}.mat";
|
||||
using (MatlabWriter writer = new MatlabWriter(fileName))
|
||||
{
|
||||
writer.Write("ints", new int[0]);
|
||||
}
|
||||
Dictionary<string, object> vars = MatlabReader.Read(fileName);
|
||||
int[] ints = (int[])vars["ints"];
|
||||
Assert.Empty(ints);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void MatlabWriteNumericNameTest()
|
||||
{
|
||||
string fileName = $"{System.IO.Path.GetTempPath()}MatlabWriteNumericNameTest{Environment.CurrentManagedThreadId}.mat";
|
||||
using (MatlabWriter writer = new MatlabWriter(fileName))
|
||||
{
|
||||
writer.Write("24", new int[0]);
|
||||
}
|
||||
Dictionary<string, object> vars = MatlabReader.Read(fileName);
|
||||
int[] ints = (int[])vars["24"];
|
||||
Assert.Empty(ints);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -454,13 +454,14 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
// generate data from the model
|
||||
var hPrior = new Gaussian(hMean, hVariance);
|
||||
var hSample = Util.ArrayInit(n, i => hPrior.Sample());
|
||||
// When xMultiplier != 1, we have model mismatch so we want the learned xPrecision to decrease.
|
||||
double xMultiplier = 5;
|
||||
var xData = Util.ArrayInit(n, i => Gaussian.Sample(xMultiplier * hSample[i], xPrecisionTrue));
|
||||
var yData = Util.ArrayInit(n, i => Gaussian.Sample(hSample[i], yPrecisionTrue));
|
||||
x.ObservedValue = xData;
|
||||
y.ObservedValue = yData;
|
||||
|
||||
// N(x; ah, vx) N(h; mh, vh) = N(h; mh + k*(x - a*mh), (1-ka)vh)
|
||||
// N(x; a*h, vx) N(h; mh, vh) = N(h; mh + k*(x - a*mh), (1-ka)vh)
|
||||
// where k = vh*a/(a^2*vh + vx)
|
||||
// if x = a*x' then k(x - a*mh) = a*k(x' - mh)
|
||||
// a*k = vh/(vh + vx/a^2)
|
||||
|
|
|
@ -209,6 +209,18 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
//(new ModelTests()).CoinRunLengths();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void TraceAllMessagesTest()
|
||||
{
|
||||
Variable<double> x = Variable.GaussianFromMeanAndPrecision(0, 1);
|
||||
Variable.ConstrainPositive(x);
|
||||
Variable.ConstrainPositive(x);
|
||||
|
||||
InferenceEngine engine = new InferenceEngine();
|
||||
engine.Compiler.TraceAllMessages = true;
|
||||
engine.Infer(x);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void MarginalWrongDistributionError()
|
||||
{
|
||||
|
|
|
@ -20,10 +20,6 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
|
||||
public class PsychTests
|
||||
{
|
||||
#if SUPPRESS_UNREACHABLE_CODE_WARNINGS
|
||||
#pragma warning disable 162
|
||||
#endif
|
||||
|
||||
internal void LogisticIrtTest()
|
||||
{
|
||||
Variable<int> numStudents = Variable.New<int>().Named("numStudents");
|
||||
|
@ -40,7 +36,8 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
response[student, question] = Variable.BernoulliFromLogOdds(((ability[student] - difficulty[question]).Named("minus")*discrimination[question]).Named("product"));
|
||||
bool[,] data;
|
||||
double[] discriminationTrue = new double[0];
|
||||
if (false)
|
||||
bool useDummyData = false;
|
||||
if (useDummyData)
|
||||
{
|
||||
data = new bool[4,2];
|
||||
for (int i = 0; i < data.GetLength(0); i++)
|
||||
|
@ -70,10 +67,6 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
Console.WriteLine(StringUtil.JoinColumns(engine.Infer(discrimination), " should be ", StringUtil.ToString(discriminationTrue)));
|
||||
}
|
||||
|
||||
#if SUPPRESS_UNREACHABLE_CODE_WARNINGS
|
||||
#pragma warning restore 162
|
||||
#endif
|
||||
|
||||
public static bool[,] ConvertToBool(double[,] array)
|
||||
{
|
||||
int rows = array.GetLength(0);
|
||||
|
@ -89,168 +82,6 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
return result;
|
||||
}
|
||||
|
||||
[Fact]
|
||||
//[DeploymentItem(@"Data\IRT2PL_10_250.mat", "Data")]
|
||||
public void MatlabReaderTest2()
|
||||
{
|
||||
Dictionary<string, object> dict = MatlabReader.Read(Path.Combine(
|
||||
#if NETCORE
|
||||
Path.GetDirectoryName(typeof(PsychTests).Assembly.Location), // work dir is not the one with Microsoft.ML.Probabilistic.Tests.dll on netcore and neither is .Location on netfull
|
||||
#endif
|
||||
"Data", "IRT2PL_10_250.mat"));
|
||||
Assert.Equal(5, dict.Count);
|
||||
Matrix m = (Matrix) dict["Y"];
|
||||
Assert.True(m.Rows == 250);
|
||||
Assert.True(m.Cols == 10);
|
||||
Assert.True(m[0, 1] == 0.0);
|
||||
Assert.True(m[1, 0] == 1.0);
|
||||
m = (Matrix) dict["difficulty"];
|
||||
Assert.True(m.Rows == 10);
|
||||
Assert.True(m.Cols == 1);
|
||||
Assert.True(MMath.AbsDiff(m[1], 0.7773) < 2e-4);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
////[DeploymentItem(@"Data\test.mat", "Data")]
|
||||
public void MatlabReaderTest()
|
||||
{
|
||||
MatlabReaderTester(Path.Combine(
|
||||
#if NETCORE
|
||||
Path.GetDirectoryName(typeof(PsychTests).Assembly.Location), // work dir is not the one with Microsoft.ML.Probabilistic.Tests.dll on netcore and neither is .Location on netfull
|
||||
#endif
|
||||
"Data", "test.mat"));
|
||||
}
|
||||
|
||||
private void MatlabReaderTester(string fileName)
|
||||
{
|
||||
Dictionary<string, object> dict = MatlabReader.Read(fileName);
|
||||
Assert.Equal(12, dict.Count);
|
||||
Matrix aScalar = (Matrix) dict["aScalar"];
|
||||
Assert.Equal(1, aScalar.Rows);
|
||||
Assert.Equal(1, aScalar.Cols);
|
||||
Assert.Equal(5.0, aScalar[0, 0]);
|
||||
Assert.Equal("string", (string) dict["aString"]);
|
||||
MatlabReader.ComplexMatrix aComplexScalar = (MatlabReader.ComplexMatrix) dict["aComplexScalar"];
|
||||
Assert.Equal(5.0, aComplexScalar.Real[0, 0]);
|
||||
Assert.Equal(3.0, aComplexScalar.Imaginary[0, 0]);
|
||||
MatlabReader.ComplexMatrix aComplexVector = (MatlabReader.ComplexMatrix) dict["aComplexVector"];
|
||||
Assert.Equal(1.0, aComplexVector.Real[0, 0]);
|
||||
Assert.Equal(2.0, aComplexVector.Imaginary[0, 0]);
|
||||
Assert.Equal(3.0, aComplexVector.Real[0, 1]);
|
||||
Assert.Equal(4.0, aComplexVector.Imaginary[0, 1]);
|
||||
var aStruct = (Dictionary<string, object>) dict["aStruct"];
|
||||
Assert.Equal(2, aStruct.Count);
|
||||
Assert.Equal(1.0, ((Matrix) aStruct["field1"])[0]);
|
||||
Assert.Equal("two", (string) aStruct["field2"]);
|
||||
object[,] aCell = (object[,]) dict["aCell"];
|
||||
Assert.Equal(1.0, ((Matrix) aCell[0, 0])[0]);
|
||||
int[] intArray = (int[]) dict["intArray"];
|
||||
Assert.Equal(1, intArray[0]);
|
||||
int[] uintArray = (int[])dict["uintArray"];
|
||||
Assert.Equal(1, uintArray[0]);
|
||||
bool[] aLogical = (bool[]) dict["aLogical"];
|
||||
Assert.True(aLogical[0]);
|
||||
Assert.True(aLogical[1]);
|
||||
Assert.False(aLogical[2]);
|
||||
object[,,] aCell3D = (object[,,]) dict["aCell3D"];
|
||||
Assert.Null(aCell3D[0, 0, 0]);
|
||||
Assert.Equal(7.0, ((Matrix) aCell3D[0, 0, 1])[0, 0]);
|
||||
Assert.Equal(6.0, ((Matrix) aCell3D[0, 1, 0])[0, 0]);
|
||||
double[,,,] array4D = (double[,,,]) dict["array4D"];
|
||||
Assert.Equal(4.0, array4D[0, 0, 1, 0]);
|
||||
Assert.Equal(5.0, array4D[0, 0, 0, 1]);
|
||||
long[] aLong = (long[]) dict["aLong"];
|
||||
Assert.Equal(1234567890123456789L, aLong[0]);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
//[DeploymentItem(@"Data\test.mat", "Data")]
|
||||
public void MatlabWriterTest()
|
||||
{
|
||||
Dictionary<string, object> dict = MatlabReader.Read(Path.Combine(
|
||||
#if NETCORE
|
||||
Path.GetDirectoryName(typeof(PsychTests).Assembly.Location), // work dir is not the one with Microsoft.ML.Probabilistic.Tests.dll on netcore and neither is .Location on netfull
|
||||
#endif
|
||||
"Data", "test.mat"));
|
||||
string fileName = $"{System.IO.Path.GetTempPath()}MatlabWriterTest{Environment.CurrentManagedThreadId}.mat";
|
||||
using (MatlabWriter writer = new MatlabWriter(fileName))
|
||||
{
|
||||
foreach (var entry in dict)
|
||||
{
|
||||
writer.Write(entry.Key, entry.Value);
|
||||
}
|
||||
}
|
||||
MatlabReaderTester(fileName);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void MatlabWriteStringDictionaryTest()
|
||||
{
|
||||
Dictionary<string, string> dictString = new Dictionary<string, string>();
|
||||
dictString["a"] = "a";
|
||||
dictString["b"] = "b";
|
||||
string fileName = $"{System.IO.Path.GetTempPath()}MatlabWriteStringDictionaryTest{Environment.CurrentManagedThreadId}.mat";
|
||||
using (MatlabWriter writer = new MatlabWriter(fileName))
|
||||
{
|
||||
writer.Write("dictString", dictString);
|
||||
}
|
||||
Dictionary<string, object> vars = MatlabReader.Read(fileName);
|
||||
Dictionary<string, object> dict = (Dictionary<string, object>)vars["dictString"];
|
||||
foreach (var entry in dictString)
|
||||
{
|
||||
Assert.Equal(dictString[entry.Key], dict[entry.Key]);
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void MatlabWriteStringListTest()
|
||||
{
|
||||
List<string> strings = new List<string>();
|
||||
strings.Add("a");
|
||||
strings.Add("b");
|
||||
string fileName = $"{System.IO.Path.GetTempPath()}MatlabWriteStringListTest{Environment.CurrentManagedThreadId}.mat";
|
||||
using (MatlabWriter writer = new MatlabWriter(fileName))
|
||||
{
|
||||
writer.Write("strings", strings);
|
||||
}
|
||||
Dictionary<string, object> vars = MatlabReader.Read(fileName);
|
||||
string[] array = (string[])vars["strings"];
|
||||
for (int i = 0; i < array.Length; i++)
|
||||
{
|
||||
Assert.Equal(strings[i], array[i]);
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void MatlabWriteEmptyArrayTest()
|
||||
{
|
||||
string fileName = $"{System.IO.Path.GetTempPath()}MatlabWriteEmptyArrayTest{Environment.CurrentManagedThreadId}.mat";
|
||||
using (MatlabWriter writer = new MatlabWriter(fileName))
|
||||
{
|
||||
writer.Write("ints", new int[0]);
|
||||
}
|
||||
Dictionary<string, object> vars = MatlabReader.Read(fileName);
|
||||
int[] ints = (int[])vars["ints"];
|
||||
Assert.Empty(ints);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void MatlabWriteNumericNameTest()
|
||||
{
|
||||
string fileName = $"{System.IO.Path.GetTempPath()}MatlabWriteNumericNameTest{Environment.CurrentManagedThreadId}.mat";
|
||||
using (MatlabWriter writer = new MatlabWriter(fileName))
|
||||
{
|
||||
writer.Write("24", new int[0]);
|
||||
}
|
||||
Dictionary<string, object> vars = MatlabReader.Read(fileName);
|
||||
int[] ints = (int[])vars["24"];
|
||||
Assert.Empty(ints);
|
||||
}
|
||||
|
||||
#if SUPPRESS_UNREACHABLE_CODE_WARNINGS
|
||||
#pragma warning disable 162
|
||||
#endif
|
||||
|
||||
/// <summary>
|
||||
/// Nonconjugate VMP crashes with improper message on the first iteration.
|
||||
/// </summary>
|
||||
|
@ -280,7 +111,8 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
//response.AddAttribute(new MarginalPrototype(new Gaussian()));
|
||||
bool[,] data;
|
||||
double[] discriminationTrue = new double[0];
|
||||
if (false)
|
||||
bool useDummyData = false;
|
||||
if (useDummyData)
|
||||
{
|
||||
data = new bool[4,2];
|
||||
for (int i = 0; i < data.GetLength(0); i++)
|
||||
|
@ -315,10 +147,6 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
Console.WriteLine(marg[i].GetMean() + " \t " + discriminationTrue[i]);
|
||||
}
|
||||
|
||||
#if SUPPRESS_UNREACHABLE_CODE_WARNINGS
|
||||
#pragma warning restore 162
|
||||
#endif
|
||||
|
||||
internal void LogisticIrtProductExpTest()
|
||||
{
|
||||
int numStudents = 20;
|
||||
|
|
Загрузка…
Ссылка в новой задаче