Added ReceiverOperatingCharacteristic struct

This commit is contained in:
Vijay Sharma 2018-10-31 23:07:45 +00:00
Родитель cb65e18b79
Коммит 247199a84e
6 изменённых файлов: 93 добавлений и 21 удалений

Просмотреть файл

@ -393,7 +393,7 @@ namespace Microsoft.ML.Probabilistic.Learners
/// <param name="instanceSource">The instance source.</param>
/// <param name="predictions">The predictions.</param>
/// <returns>The computed receiver operating characteristic curve.</returns>
public IEnumerable<Pair<double, double>> ReceiverOperatingCharacteristicCurve(
public IEnumerable<ReceiverOperatingCharacteristic> ReceiverOperatingCharacteristicCurve(
TLabel positiveClassLabel,
TInstanceSource instanceSource,
IEnumerable<IDictionary<TLabel, double>> predictions)
@ -410,7 +410,7 @@ namespace Microsoft.ML.Probabilistic.Learners
/// <param name="labelSource">The label source.</param>
/// <param name="predictions">The predictions.</param>
/// <returns>The computed receiver operating characteristic curve.</returns>
public IEnumerable<Pair<double, double>> ReceiverOperatingCharacteristicCurve(
public IEnumerable<ReceiverOperatingCharacteristic> ReceiverOperatingCharacteristicCurve(
TLabel positiveClassLabel,
TInstanceSource instanceSource,
TLabelSource labelSource,
@ -429,7 +429,7 @@ namespace Microsoft.ML.Probabilistic.Learners
/// <param name="instanceSource">The instance source.</param>
/// <param name="predictions">The predictions.</param>
/// <returns>The computed receiver operating characteristic curve.</returns>
public IEnumerable<Pair<double, double>> ReceiverOperatingCharacteristicCurve(
public IEnumerable<ReceiverOperatingCharacteristic> ReceiverOperatingCharacteristicCurve(
TLabel positiveClassLabel,
TLabel negativeClassLabel,
TInstanceSource instanceSource,
@ -448,7 +448,7 @@ namespace Microsoft.ML.Probabilistic.Learners
/// <param name="labelSource">The label source.</param>
/// <param name="predictions">The predictions.</param>
/// <returns>The computed receiver operating characteristic curve.</returns>
public IEnumerable<Pair<double, double>> ReceiverOperatingCharacteristicCurve(
public IEnumerable<ReceiverOperatingCharacteristic> ReceiverOperatingCharacteristicCurve(
TLabel positiveClassLabel,
TLabel negativeClassLabel,
TInstanceSource instanceSource,

Просмотреть файл

@ -487,7 +487,7 @@ namespace Microsoft.ML.Probabilistic.Learners
/// <remarks>
/// All instances not contained in <paramref name="positiveInstances"/> are assumed to belong to the 'negative' class.
/// </remarks>
public static IEnumerable<Pair<double, double>> ReceiverOperatingCharacteristicCurve<TInstance>(
public static IEnumerable<ReceiverOperatingCharacteristic> ReceiverOperatingCharacteristicCurve<TInstance>(
IEnumerable<TInstance> positiveInstances, IEnumerable<KeyValuePair<TInstance, double>> instanceScores)
{
if (positiveInstances == null)
@ -523,7 +523,7 @@ namespace Microsoft.ML.Probabilistic.Learners
double falsePositiveRate;
double truePositiveRate;
double previousScore = double.NaN;
var rocCurve = new List<Pair<double, double>>();
var rocCurve = new List<ReceiverOperatingCharacteristic>();
foreach (var instance in sortedInstanceScores)
{
@ -532,7 +532,7 @@ namespace Microsoft.ML.Probabilistic.Learners
{
falsePositiveRate = falsePositivesCount / (double)negativesCount;
truePositiveRate = truePositivesCount / (double)positivesCount;
rocCurve.Add(Pair.Create(falsePositiveRate, truePositiveRate));
rocCurve.Add(new ReceiverOperatingCharacteristic(truePositiveRate, falsePositiveRate));
previousScore = score;
}
@ -549,7 +549,7 @@ namespace Microsoft.ML.Probabilistic.Learners
// Add point for (1,1)
falsePositiveRate = falsePositivesCount / (double)negativesCount;
truePositiveRate = truePositivesCount / (double)positivesCount;
rocCurve.Add(Pair.Create(falsePositiveRate, truePositiveRate));
rocCurve.Add(new ReceiverOperatingCharacteristic(truePositiveRate, falsePositiveRate));
return rocCurve;
}

Просмотреть файл

@ -0,0 +1,76 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Probabilistic.Learners
{
using Microsoft.ML.Probabilistic.Utilities;
/// <summary>
/// Struct which holds a single receiver operating characteristic (ROC)
/// </summary>
public struct ReceiverOperatingCharacteristic
{
/// <summary>
/// Initializes a new instance of the <see cref="ReceiverOperatingCharacteristic"/> struct.
/// </summary>
/// <param name="truePositiveRate">The true positive rate (TPR)</param>
/// /// <param name="falsePositiveRate">The false positive rate (FPR)</param>
public ReceiverOperatingCharacteristic(double truePositiveRate, double falsePositiveRate)
: this()
{
this.TruePositiveRate = truePositiveRate;
this.FalsePositiveRate = falsePositiveRate;
}
/// <summary>
/// Gets the <see cref="TruePositiveRate"/>
/// </summary>
public readonly double TruePositiveRate;
/// <summary>
/// Gets the <see cref="FalsePositiveRate"/>.
/// </summary>
public readonly double FalsePositiveRate;
/// <summary>
/// Gets the string representation of this <see cref="ReceiverOperatingCharacteristic"/>.
/// </summary>
/// <returns>The string representation of the <see cref="ReceiverOperatingCharacteristic"/>.</returns>
public override string ToString()
{
return $"{this.TruePositiveRate}, {this.FalsePositiveRate}";
}
/// <summary>
/// Checks if this object is equal to <paramref name="obj"/>.
/// </summary>
/// <param name="obj">The object to compare this object with.</param>
/// <returns>
/// <see langword="true"/> if this object is equal to <paramref name="obj"/>,
/// <see langword="false"/> otherwise.
/// </returns>
public override bool Equals(object obj)
{
if (obj is ReceiverOperatingCharacteristic receiverOperatingCharacteristic)
{
return object.Equals(this.TruePositiveRate, receiverOperatingCharacteristic.TruePositiveRate) && object.Equals(this.FalsePositiveRate, receiverOperatingCharacteristic.FalsePositiveRate);
}
return false;
}
/// <summary>
/// Computes the hash code of this object.
/// </summary>
/// <returns>The computed hash code.</returns>
public override int GetHashCode()
{
int result = Hash.Start;
result = Hash.Combine(result, this.TruePositiveRate.GetHashCode());
result = Hash.Combine(result, this.FalsePositiveRate.GetHashCode());
return result;
}
}
}

Просмотреть файл

@ -599,10 +599,10 @@ namespace Microsoft.ML.Probabilistic.Learners.Runners
writer.WriteLine("#");
writer.WriteLine("# Class '" + positiveClassLabel + "' (versus the rest)");
writer.WriteLine("#");
writer.WriteLine("# False positive rate (FPR), true positive rate (TPR)");
writer.WriteLine("# True positive rate (TPR), False positive rate (FPR)");
foreach (var point in rocCurve)
{
writer.WriteLine("{0}, {1}", point.First, point.Second);
writer.WriteLine(point);
}
}
}

Просмотреть файл

@ -176,17 +176,17 @@ namespace Microsoft.ML.Probabilistic.Learners.Tests
public void RocCurveTest()
{
// Curve for perfect predictions
var expected = new[] { Pair.Create(0.0, 0.0), Pair.Create(0.0, 1.0), Pair.Create(1.0, 1.0) };
var expected = new[] { new ReceiverOperatingCharacteristic(0.0, 0.0), new ReceiverOperatingCharacteristic(1.0, 0.0), new ReceiverOperatingCharacteristic(1.0, 1.0) };
var actual = this.evaluator.ReceiverOperatingCharacteristicCurve(LabelSet[0], this.groundTruth, this.groundTruth).ToArray();
Xunit.Assert.Equal(expected, actual);
// Curve for imperfect predictions (one-versus-rest)
expected = new[] { Pair.Create(0.0, 0.0), Pair.Create(0.5, 0.0), Pair.Create(0.5, 1 / 3.0), Pair.Create(0.5, 2 / 3.0), Pair.Create(1.0, 1.0) };
expected = new[] { new ReceiverOperatingCharacteristic(0.0, 0.0), new ReceiverOperatingCharacteristic(0.0, 0.5), new ReceiverOperatingCharacteristic(1 / 3.0, 0.5), new ReceiverOperatingCharacteristic(2 / 3.0, 0.5), new ReceiverOperatingCharacteristic(1.0, 1.0) };
actual = this.evaluator.ReceiverOperatingCharacteristicCurve(LabelSet[0], this.groundTruth, this.predictions).ToArray();
Xunit.Assert.Equal(expected, actual); // matches below AUC = 5/12
// Curve for imperfect predictions (one-versus-another)
expected = new[] { Pair.Create(0.0, 0.0), Pair.Create(0.0, 1 / 3.0), Pair.Create(0.0, 2 / 3.0), Pair.Create(1.0, 1.0) };
expected = new[] { new ReceiverOperatingCharacteristic(0.0, 0.0), new ReceiverOperatingCharacteristic(1 / 3.0, 0.0), new ReceiverOperatingCharacteristic(2 / 3.0, 0.0), new ReceiverOperatingCharacteristic(1.0, 1.0) };
actual = this.evaluator.ReceiverOperatingCharacteristicCurve(LabelSet[0], LabelSet[1], this.groundTruth, this.predictions).ToArray();
Xunit.Assert.Equal(expected, actual); // matches below AUC = 5/6

Просмотреть файл

@ -352,9 +352,9 @@ namespace Microsoft.ML.Probabilistic.Learners.Tests
// Duplicate instance scores, duplicate positive instances
var expectedCurve = new[]
{
Pair.Create<double, double>(0, 0),
Pair.Create<double, double>(0.5, 1),
Pair.Create<double, double>(1, 1)
new ReceiverOperatingCharacteristic(0, 0),
new ReceiverOperatingCharacteristic(1, 0.5),
new ReceiverOperatingCharacteristic(1, 1)
};
var computedCurve = Metrics.ReceiverOperatingCharacteristicCurve(new[] { 1, 1, 2 }, new Dictionary<int, double> { { 1, 0.5 }, { 2, 0.5 }, { 3, 0.5 }, { 4, 0 } }).ToArray();
foreach (var tuple in computedCurve)
@ -362,11 +362,7 @@ namespace Microsoft.ML.Probabilistic.Learners.Tests
Console.WriteLine(tuple);
}
Assert.Equal(expectedCurve.Length, computedCurve.Length);
for (int i = 0; i < expectedCurve.Length; i++)
{
Assert.Equal(expectedCurve[i], computedCurve[i]);
}
Xunit.Assert.Equal(expectedCurve, computedCurve);
// No positive instance scores
Assert.Throws<ArgumentException>(() => Metrics.ReceiverOperatingCharacteristicCurve(new int[] { }, new Dictionary<int, double> { { 1, 1 } }));