зеркало из https://github.com/dotnet/infer.git
Changed the BCC confusion matrix prior (#367)
Changed the BCC confusion matrix prior so that TrueLabels can be inferred when LabelCount==2. Fixed serialization of BCC posteriors. BCC posteriors now save to the Results folder. Fixed serialization example code.
This commit is contained in:
Родитель
c90ad0bf1e
Коммит
76ec66dcca
|
@ -199,7 +199,7 @@ namespace Crowdsourcing
|
|||
results.RunDawidSkene(subData, calculateAccuracy);
|
||||
break;
|
||||
default: // Run BCC models
|
||||
results.RunBCC(modelName, subData, data, model, Results.RunMode.ClearResults, calculateAccuracy, communityCount, false);
|
||||
results.RunBCC(resultsDir + modelName, subData, data, model, Results.RunMode.ClearResults, calculateAccuracy, communityCount, false);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,8 +18,13 @@ using Range = Microsoft.ML.Probabilistic.Models.Range;
|
|||
namespace Crowdsourcing
|
||||
{
|
||||
/// <summary>
|
||||
/// The BCC model class.
|
||||
/// Implements statistical inference for (community-based and non-community-based) Bayesian classifier combination.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// References:
|
||||
/// Matteo Venanzi, John Guiver, Gabriella Kazai, Pushmeet Kohli, and Milad Shokouhi. Community-Based Bayesian Aggregation Models for Crowdsourcing. In Proceedings of the 23rd International World Wide Web Conference, WWW2014, ACM, April 2014.
|
||||
/// H.C. Kim and Z. Ghahramani. Bayesian classifier combination. International Conference on Artificial Intelligence and Statistics, pages 619-627, 2012.
|
||||
/// </remarks>
|
||||
public class BCC
|
||||
{
|
||||
/// <summary>
|
||||
|
@ -188,14 +193,13 @@ namespace Crowdsourcing
|
|||
/// <param name="priors">The priors.</param>
|
||||
protected virtual void SetPriors(int workerCount, Posteriors priors)
|
||||
{
|
||||
int numClasses = c.SizeAsInt;
|
||||
WorkerCount.ObservedValue = workerCount;
|
||||
if (priors == null)
|
||||
{
|
||||
BackgroundLabelProbPrior.ObservedValue = Dirichlet.Uniform(numClasses);
|
||||
BackgroundLabelProbPrior.ObservedValue = Dirichlet.Uniform(LabelCount);
|
||||
var confusionMatrixPrior = GetConfusionMatrixPrior();
|
||||
ConfusionMatrixPrior.ObservedValue = Util.ArrayInit(workerCount, worker => Util.ArrayInit(numClasses, lab => confusionMatrixPrior[lab]));
|
||||
TrueLabelConstraint.ObservedValue = Util.ArrayInit(TaskCount, t => Discrete.Uniform(numClasses));
|
||||
ConfusionMatrixPrior.ObservedValue = Util.ArrayInit(workerCount, worker => Util.ArrayInit(LabelCount, lab => confusionMatrixPrior[lab]));
|
||||
TrueLabelConstraint.ObservedValue = Util.ArrayInit(TaskCount, t => Discrete.Uniform(LabelCount));
|
||||
}
|
||||
else
|
||||
{
|
||||
|
@ -283,7 +287,10 @@ namespace Crowdsourcing
|
|||
var confusionMatrixPrior = new Dirichlet[LabelCount];
|
||||
for (int d = 0; d < LabelCount; d++)
|
||||
{
|
||||
confusionMatrixPrior[d] = new Dirichlet(Util.ArrayInit(LabelCount, i => i == d ? (InitialWorkerBelief / (1 - InitialWorkerBelief)) * (LabelCount - 1) : 1.0));
|
||||
// The prior prefers diagonal confusion matrices.
|
||||
// The paper says "Each row of π(k)_c has a Dirichlet prior with pseudo counts 1 except for the diagonal count set to C − 1." but that does not work for C=2.
|
||||
// Instead of following the paper, this code sets the diagonal to C.
|
||||
confusionMatrixPrior[d] = new Dirichlet(Util.ArrayInit(LabelCount, i => (i == d) ? (InitialWorkerBelief / (1 - InitialWorkerBelief)) * LabelCount : 1.0));
|
||||
}
|
||||
|
||||
return confusionMatrixPrior;
|
||||
|
|
|
@ -193,14 +193,14 @@ namespace Crowdsourcing
|
|||
results.RunDawidSkene(data, true);
|
||||
break;
|
||||
default:
|
||||
results.RunBCC(modelName, data, data, model, Results.RunMode.ClearResults, false, communityCount, false, false);
|
||||
results.RunBCC(ResultsDir + modelName, data, data, model, Results.RunMode.ClearResults, false, communityCount, false, false);
|
||||
break;
|
||||
}
|
||||
|
||||
// Write the inference results to a csv file
|
||||
using (StreamWriter writer = new StreamWriter(ResultsDir + "endpoints.csv", true))
|
||||
{
|
||||
writer.WriteLine("{0}:,{1:0.000},{2:0.0000}", modelName, results.Accuracy, results.NegativeLogProb);
|
||||
writer.WriteLine("{0},{1:0.000},{2:0.0000}", modelName, results.Accuracy, results.NegativeLogProb);
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
|
|
@ -6,8 +6,12 @@ using System;
|
|||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Runtime.Serialization;
|
||||
using System.Xml;
|
||||
|
||||
using Microsoft.ML.Probabilistic.Distributions;
|
||||
using Microsoft.ML.Probabilistic.Math;
|
||||
using Microsoft.ML.Probabilistic.Serialization;
|
||||
using Microsoft.ML.Probabilistic.Utilities;
|
||||
|
||||
namespace Crowdsourcing
|
||||
|
@ -354,6 +358,7 @@ namespace Crowdsourcing
|
|||
{
|
||||
CommunityModel communityModel = model as CommunityModel;
|
||||
IsCommunityModel = communityModel != null;
|
||||
string communityPriorsFileName = modelName + "CommunityPriors.xml";
|
||||
|
||||
if (this.Mapping == null)
|
||||
{
|
||||
|
@ -398,7 +403,7 @@ namespace Crowdsourcing
|
|||
ClearResults();
|
||||
if (mode == RunMode.LoadAndUseCommunityPriors && IsCommunityModel)
|
||||
{
|
||||
priors = DeserializeCommunityPosteriors(modelName, numCommunities);
|
||||
priors = DeserializeCommunityPosteriors(communityPriorsFileName, numCommunities);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
@ -426,49 +431,51 @@ namespace Crowdsourcing
|
|||
/// Serialize parameters
|
||||
if (serialize)
|
||||
{
|
||||
using (FileStream stream = new FileStream(modelName + ".xml", FileMode.Create))
|
||||
var type = IsCommunityModel ? typeof(CommunityModel.Posteriors) : typeof(BCC.Posteriors);
|
||||
DataContractSerializer serializer = new DataContractSerializer(type, new DataContractSerializerSettings { DataContractResolver = new InferDataContractResolver() });
|
||||
string posteriorsFileName = modelName + ".xml";
|
||||
using (XmlDictionaryWriter writer = XmlDictionaryWriter.CreateTextWriter(new FileStream(posteriorsFileName, FileMode.Create)))
|
||||
{
|
||||
var serializer = new System.Xml.Serialization.XmlSerializer(IsCommunityModel ? typeof(CommunityModel.Posteriors) : typeof(BCC.Posteriors));
|
||||
serializer.Serialize(stream, posteriors);
|
||||
serializer.WriteObject(writer, posteriors);
|
||||
}
|
||||
}
|
||||
|
||||
if (serializeCommunityPosteriors && IsCommunityModel)
|
||||
{
|
||||
SerializeCommunityPosteriors(modelName);
|
||||
SerializeCommunityPosteriors(communityPriorsFileName);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Serializes the posteriors to an xml file.
|
||||
/// </summary>
|
||||
/// <param name="modelName">The model name.</param>
|
||||
void SerializeCommunityPosteriors(string modelName)
|
||||
/// <param name="fileName">The file name.</param>
|
||||
void SerializeCommunityPosteriors(string fileName)
|
||||
{
|
||||
NonTaskWorkerParameters ntwp = new NonTaskWorkerParameters();
|
||||
ntwp.BackgroundLabelProb = BackgroundLabelProb;
|
||||
ntwp.CommunityProb = CommunityProb;
|
||||
ntwp.CommunityScoreMatrix = CommunityScoreMatrix;
|
||||
using (FileStream stream = new FileStream(modelName + "CommunityPriors.xml", FileMode.Create))
|
||||
DataContractSerializer serializer = new DataContractSerializer(typeof(NonTaskWorkerParameters), new DataContractSerializerSettings { DataContractResolver = new InferDataContractResolver() });
|
||||
using (XmlDictionaryWriter writer = XmlDictionaryWriter.CreateTextWriter(new FileStream(fileName, FileMode.Create)))
|
||||
{
|
||||
var serializer = new System.Xml.Serialization.XmlSerializer(typeof(NonTaskWorkerParameters));
|
||||
serializer.Serialize(stream, ntwp);
|
||||
serializer.WriteObject(writer, ntwp);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Deserializes the parameters of CBCC from an xml file (used in the LoadAndUseCommunityPriors mode).
|
||||
/// </summary>
|
||||
/// <param name="modelName">The model name.</param>
|
||||
/// <param name="fileName">The file name.</param>
|
||||
/// <param name="numCommunities">The number of communities.</param>
|
||||
/// <returns></returns>
|
||||
CommunityModel.Posteriors DeserializeCommunityPosteriors(string modelName, int numCommunities)
|
||||
CommunityModel.Posteriors DeserializeCommunityPosteriors(string fileName, int numCommunities)
|
||||
{
|
||||
CommunityModel.Posteriors cbccPriors = new CommunityModel.Posteriors();
|
||||
using (FileStream stream = new FileStream(modelName + "CommunityPriors.xml", FileMode.Open))
|
||||
DataContractSerializer serializer = new DataContractSerializer(typeof(NonTaskWorkerParameters), new DataContractSerializerSettings { DataContractResolver = new InferDataContractResolver() });
|
||||
using (XmlDictionaryReader reader = XmlDictionaryReader.CreateTextReader(new FileStream(fileName, FileMode.Open), new XmlDictionaryReaderQuotas()))
|
||||
{
|
||||
var serializer = new System.Xml.Serialization.XmlSerializer(typeof(NonTaskWorkerParameters));
|
||||
var ntwp = (NonTaskWorkerParameters)serializer.Deserialize(stream);
|
||||
var ntwp = (NonTaskWorkerParameters)serializer.ReadObject(reader);
|
||||
|
||||
if (ntwp.BackgroundLabelProb.Dimension != Mapping.LabelCount)
|
||||
{
|
||||
|
@ -524,7 +531,7 @@ namespace Crowdsourcing
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// Updates the results of with the new posteriors.
|
||||
/// Updates the results with the new posteriors.
|
||||
/// </summary>
|
||||
/// <param name="posteriors">The posteriors.</param>
|
||||
/// <param name="mode">The mode (for example training, prediction, etc.).</param>
|
||||
|
|
|
@ -65,7 +65,7 @@ namespace CrowdsourcingWithWords
|
|||
var confusionMatrixPrior = new Dirichlet[LabelCount];
|
||||
for (int d = 0; d < LabelCount; d++)
|
||||
{
|
||||
confusionMatrixPrior[d] = new Dirichlet(Util.ArrayInit(LabelCount, i => i == d ? (InitialWorkerBelief / (1 - InitialWorkerBelief)) * (LabelCount - 1) : 1.0));
|
||||
confusionMatrixPrior[d] = new Dirichlet(Util.ArrayInit(LabelCount, i => i == d ? (InitialWorkerBelief / (1 - InitialWorkerBelief)) * LabelCount : 1.0));
|
||||
}
|
||||
|
||||
return confusionMatrixPrior;
|
||||
|
|
|
@ -1474,17 +1474,16 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
internal void XmlSerializationExample()
|
||||
{
|
||||
Dirichlet d = new Dirichlet(3.0, 1.0, 2.0);
|
||||
string fileName = "temp.xml";
|
||||
DataContractSerializer serializer = new DataContractSerializer(typeof(Dirichlet), new DataContractSerializerSettings { DataContractResolver = new InferDataContractResolver() });
|
||||
// write to disk
|
||||
using (FileStream stream = new FileStream("temp.xml", FileMode.Create))
|
||||
using (XmlDictionaryWriter writer = XmlDictionaryWriter.CreateTextWriter(new FileStream(fileName, FileMode.Create)))
|
||||
{
|
||||
XmlDictionaryWriter writer = XmlDictionaryWriter.CreateTextWriter(stream);
|
||||
serializer.WriteObject(writer, d);
|
||||
}
|
||||
// read from disk
|
||||
using (FileStream stream = new FileStream("temp.xml", FileMode.Open))
|
||||
using (XmlDictionaryReader reader = XmlDictionaryReader.CreateTextReader(new FileStream(fileName, FileMode.Open), new XmlDictionaryReaderQuotas()))
|
||||
{
|
||||
XmlDictionaryReader reader = XmlDictionaryReader.CreateTextReader(stream, new XmlDictionaryReaderQuotas());
|
||||
Dirichlet d2 = (Dirichlet)serializer.ReadObject(reader);
|
||||
Console.WriteLine(d2);
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче