Changed the BCC confusion matrix prior (#367)

Changed the BCC confusion matrix prior so that TrueLabels can be inferred when LabelCount==2.
Fixed serialization of BCC posteriors.  BCC posteriors now save to the Results folder.
Fixed serialization example code.
This commit is contained in:
Tom Minka 2021-09-29 07:09:38 +01:00 коммит произвёл GitHub
Родитель c90ad0bf1e
Коммит 76ec66dcca
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
6 изменённых файлов: 43 добавлений и 30 удалений

Просмотреть файл

@ -199,7 +199,7 @@ namespace Crowdsourcing
results.RunDawidSkene(subData, calculateAccuracy);
break;
default: // Run BCC models
results.RunBCC(modelName, subData, data, model, Results.RunMode.ClearResults, calculateAccuracy, communityCount, false);
results.RunBCC(resultsDir + modelName, subData, data, model, Results.RunMode.ClearResults, calculateAccuracy, communityCount, false);
break;
}
}

Просмотреть файл

@ -18,8 +18,13 @@ using Range = Microsoft.ML.Probabilistic.Models.Range;
namespace Crowdsourcing
{
/// <summary>
/// The BCC model class.
/// Implements statistical inference for (community-based and non-community-based) Bayesian classifier combination.
/// </summary>
/// <remarks>
/// References:
/// Matteo Venanzi, John Guiver, Gabriella Kazai, Pushmeet Kohli, and Milad Shokouhi. Community-Based Bayesian Aggregation Models for Crowdsourcing. In Proceedings of the 23rd International World Wide Web Conference, WWW2014, ACM, April 2014.
/// H.C. Kim and Z. Ghahramani. Bayesian classifier combination. International Conference on Artificial Intelligence and Statistics, pages 619-627, 2012.
/// </remarks>
public class BCC
{
/// <summary>
@ -188,14 +193,13 @@ namespace Crowdsourcing
/// <param name="priors">The priors.</param>
protected virtual void SetPriors(int workerCount, Posteriors priors)
{
int numClasses = c.SizeAsInt;
WorkerCount.ObservedValue = workerCount;
if (priors == null)
{
BackgroundLabelProbPrior.ObservedValue = Dirichlet.Uniform(numClasses);
BackgroundLabelProbPrior.ObservedValue = Dirichlet.Uniform(LabelCount);
var confusionMatrixPrior = GetConfusionMatrixPrior();
ConfusionMatrixPrior.ObservedValue = Util.ArrayInit(workerCount, worker => Util.ArrayInit(numClasses, lab => confusionMatrixPrior[lab]));
TrueLabelConstraint.ObservedValue = Util.ArrayInit(TaskCount, t => Discrete.Uniform(numClasses));
ConfusionMatrixPrior.ObservedValue = Util.ArrayInit(workerCount, worker => Util.ArrayInit(LabelCount, lab => confusionMatrixPrior[lab]));
TrueLabelConstraint.ObservedValue = Util.ArrayInit(TaskCount, t => Discrete.Uniform(LabelCount));
}
else
{
@ -283,7 +287,10 @@ namespace Crowdsourcing
var confusionMatrixPrior = new Dirichlet[LabelCount];
for (int d = 0; d < LabelCount; d++)
{
confusionMatrixPrior[d] = new Dirichlet(Util.ArrayInit(LabelCount, i => i == d ? (InitialWorkerBelief / (1 - InitialWorkerBelief)) * (LabelCount - 1) : 1.0));
// The prior prefers diagonal confusion matrices.
// The paper says "Each row of π(k)_c has a Dirichlet prior with pseudo counts 1 except for the diagonal count set to C - 1." but that does not work for C=2.
// Instead of following the paper, this code sets the diagonal to C.
confusionMatrixPrior[d] = new Dirichlet(Util.ArrayInit(LabelCount, i => (i == d) ? (InitialWorkerBelief / (1 - InitialWorkerBelief)) * LabelCount : 1.0));
}
return confusionMatrixPrior;

Просмотреть файл

@ -193,14 +193,14 @@ namespace Crowdsourcing
results.RunDawidSkene(data, true);
break;
default:
results.RunBCC(modelName, data, data, model, Results.RunMode.ClearResults, false, communityCount, false, false);
results.RunBCC(ResultsDir + modelName, data, data, model, Results.RunMode.ClearResults, false, communityCount, false, false);
break;
}
// Write the inference results on a csv file
using (StreamWriter writer = new StreamWriter(ResultsDir + "endpoints.csv", true))
{
writer.WriteLine("{0}:,{1:0.000},{2:0.0000}", modelName, results.Accuracy, results.NegativeLogProb);
writer.WriteLine("{0},{1:0.000},{2:0.0000}", modelName, results.Accuracy, results.NegativeLogProb);
}
return results;
}

Просмотреть файл

@ -6,8 +6,12 @@ using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.Serialization;
using System.Xml;
using Microsoft.ML.Probabilistic.Distributions;
using Microsoft.ML.Probabilistic.Math;
using Microsoft.ML.Probabilistic.Serialization;
using Microsoft.ML.Probabilistic.Utilities;
namespace Crowdsourcing
@ -354,6 +358,7 @@ namespace Crowdsourcing
{
CommunityModel communityModel = model as CommunityModel;
IsCommunityModel = communityModel != null;
string communityPriorsFileName = modelName + "CommunityPriors.xml";
if (this.Mapping == null)
{
@ -398,7 +403,7 @@ namespace Crowdsourcing
ClearResults();
if (mode == RunMode.LoadAndUseCommunityPriors && IsCommunityModel)
{
priors = DeserializeCommunityPosteriors(modelName, numCommunities);
priors = DeserializeCommunityPosteriors(communityPriorsFileName, numCommunities);
}
break;
}
@ -426,49 +431,51 @@ namespace Crowdsourcing
/// Serialize parameters
if (serialize)
{
using (FileStream stream = new FileStream(modelName + ".xml", FileMode.Create))
var type = IsCommunityModel ? typeof(CommunityModel.Posteriors) : typeof(BCC.Posteriors);
DataContractSerializer serializer = new DataContractSerializer(type, new DataContractSerializerSettings { DataContractResolver = new InferDataContractResolver() });
string posteriorsFileName = modelName + ".xml";
using (XmlDictionaryWriter writer = XmlDictionaryWriter.CreateTextWriter(new FileStream(posteriorsFileName, FileMode.Create)))
{
var serializer = new System.Xml.Serialization.XmlSerializer(IsCommunityModel ? typeof(CommunityModel.Posteriors) : typeof(BCC.Posteriors));
serializer.Serialize(stream, posteriors);
serializer.WriteObject(writer, posteriors);
}
}
if (serializeCommunityPosteriors && IsCommunityModel)
{
SerializeCommunityPosteriors(modelName);
SerializeCommunityPosteriors(communityPriorsFileName);
}
}
/// <summary>
/// Serializes the posteriors on an xml file.
/// </summary>
/// <param name="modelName">The model name.</param>
void SerializeCommunityPosteriors(string modelName)
/// <param name="fileName">The file name.</param>
void SerializeCommunityPosteriors(string fileName)
{
NonTaskWorkerParameters ntwp = new NonTaskWorkerParameters();
ntwp.BackgroundLabelProb = BackgroundLabelProb;
ntwp.CommunityProb = CommunityProb;
ntwp.CommunityScoreMatrix = CommunityScoreMatrix;
using (FileStream stream = new FileStream(modelName + "CommunityPriors.xml", FileMode.Create))
DataContractSerializer serializer = new DataContractSerializer(typeof(NonTaskWorkerParameters), new DataContractSerializerSettings { DataContractResolver = new InferDataContractResolver() });
using (XmlDictionaryWriter writer = XmlDictionaryWriter.CreateTextWriter(new FileStream(fileName, FileMode.Create)))
{
var serializer = new System.Xml.Serialization.XmlSerializer(typeof(NonTaskWorkerParameters));
serializer.Serialize(stream, ntwp);
serializer.WriteObject(writer, ntwp);
}
}
/// <summary>
/// Deserializes the parameters of CBCC from an xml file (used in the LoadAndUseCommunityPriors mode).
/// </summary>
/// <param name="modelName">The model name.</param>
/// <param name="fileName">The file name.</param>
/// <param name="numCommunities">The number of communities.</param>
/// <returns></returns>
CommunityModel.Posteriors DeserializeCommunityPosteriors(string modelName, int numCommunities)
CommunityModel.Posteriors DeserializeCommunityPosteriors(string fileName, int numCommunities)
{
CommunityModel.Posteriors cbccPriors = new CommunityModel.Posteriors();
using (FileStream stream = new FileStream(modelName + "CommunityPriors.xml", FileMode.Open))
DataContractSerializer serializer = new DataContractSerializer(typeof(NonTaskWorkerParameters), new DataContractSerializerSettings { DataContractResolver = new InferDataContractResolver() });
using (XmlDictionaryReader reader = XmlDictionaryReader.CreateTextReader(new FileStream(fileName, FileMode.Open), new XmlDictionaryReaderQuotas()))
{
var serializer = new System.Xml.Serialization.XmlSerializer(typeof(NonTaskWorkerParameters));
var ntwp = (NonTaskWorkerParameters)serializer.Deserialize(stream);
var ntwp = (NonTaskWorkerParameters)serializer.ReadObject(reader);
if (ntwp.BackgroundLabelProb.Dimension != Mapping.LabelCount)
{
@ -524,7 +531,7 @@ namespace Crowdsourcing
}
/// <summary>
/// Updates the results of with the new posteriors.
/// Updates the results with the new posteriors.
/// </summary>
/// <param name="posteriors">The posteriors.</param>
/// <param name="mode">The mode (for example training, prediction, etc.).</param>

Просмотреть файл

@ -65,7 +65,7 @@ namespace CrowdsourcingWithWords
var confusionMatrixPrior = new Dirichlet[LabelCount];
for (int d = 0; d < LabelCount; d++)
{
confusionMatrixPrior[d] = new Dirichlet(Util.ArrayInit(LabelCount, i => i == d ? (InitialWorkerBelief / (1 - InitialWorkerBelief)) * (LabelCount - 1) : 1.0));
confusionMatrixPrior[d] = new Dirichlet(Util.ArrayInit(LabelCount, i => i == d ? (InitialWorkerBelief / (1 - InitialWorkerBelief)) * LabelCount : 1.0));
}
return confusionMatrixPrior;

Просмотреть файл

@ -1474,17 +1474,16 @@ namespace Microsoft.ML.Probabilistic.Tests
internal void XmlSerializationExample()
{
Dirichlet d = new Dirichlet(3.0, 1.0, 2.0);
string fileName = "temp.xml";
DataContractSerializer serializer = new DataContractSerializer(typeof(Dirichlet), new DataContractSerializerSettings { DataContractResolver = new InferDataContractResolver() });
// write to disk
using (FileStream stream = new FileStream("temp.xml", FileMode.Create))
using (XmlDictionaryWriter writer = XmlDictionaryWriter.CreateTextWriter(new FileStream(fileName, FileMode.Create)))
{
XmlDictionaryWriter writer = XmlDictionaryWriter.CreateTextWriter(stream);
serializer.WriteObject(writer, d);
}
// read from disk
using (FileStream stream = new FileStream("temp.xml", FileMode.Open))
using (XmlDictionaryReader reader = XmlDictionaryReader.CreateTextReader(new FileStream(fileName, FileMode.Open), new XmlDictionaryReaderQuotas()))
{
XmlDictionaryReader reader = XmlDictionaryReader.CreateTextReader(stream, new XmlDictionaryReaderQuotas());
Dirichlet d2 = (Dirichlet)serializer.ReadObject(reader);
Console.WriteLine(d2);
}