// machinelearning/test/Microsoft.ML.Tests/Scenarios/ClusteringTests.cs
//
// 94 lines
// 3.5 KiB
// C#

// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using Microsoft.ML.Data;
using Xunit;
namespace Microsoft.ML.Scenarios
{
public partial class ScenariosTests
{
/// <summary>
/// Output row type used with the prediction engine created in <see cref="PredictClusters"/>.
/// </summary>
public class ClusteringPrediction
{
/// <summary>
/// Value bound to the "PredictedLabel" column; the test treats it as the id of the cluster
/// assigned to the input point.
/// </summary>
[ColumnName("PredictedLabel")]
public uint SelectedClusterId;
/// <summary>
/// Values bound to the "Score" column. Presumably one distance per cluster, as the field name
/// suggests (TODO confirm against the KMeans trainer documentation).
/// </summary>
[ColumnName("Score")]
public float[] Distance;
}
/// <summary>
/// Input row type for the clustering scenario: a single point exposed as the "Features" column.
/// </summary>
public class ClusteringData
{
/// <summary>
/// Coordinates of the point, bound to the "Features" column and declared as a vector of size 2.
/// <see cref="PredictClusters"/> fills it with an (x, y) pair.
/// </summary>
[ColumnName("Features")]
[VectorType(2)]
public float[] Points;
}
/// <summary>
/// Trains a KMeans model on synthetic 2-D points generated around <c>k</c> centers placed on the
/// unit circle, then verifies that every center is predicted into a different cluster and that
/// the clustering metrics evaluated on the centers have the expected values.
/// </summary>
[Fact]
public void PredictClusters()
{
    int n = 1000;
    int k = 4;
    // Fixed seed keeps the generated data set identical from run to run.
    var rand = new Random(1);
    var clusters = new ClusteringData[k];
    var data = new ClusteringData[n];
    for (int i = 0; i < k; i++)
    {
        // Pick cluster centers as points on the unit circle at an angle of 360*i/k degrees to the X axis.
        var angle = Math.PI * i * 2 / k;
        clusters[i] = new ClusteringData { Points = new float[2] { (float)Math.Cos(angle), (float)Math.Sin(angle) } };
    }

    // Create data points by randomly picking a center and shifting the point slightly away from it.
    // The same shift is applied to both coordinates.
    for (int i = 0; i < n; i++)
    {
        var index = rand.Next(0, k);
        var shift = (rand.NextDouble() - 0.5) / 10;
        data[i] = new ClusteringData
        {
            Points = new float[2]
            {
                (float)(clusters[index].Points[0] + shift),
                (float)(clusters[index].Points[1] + shift)
            }
        };
    }

    var mlContext = new MLContext(seed: 1);

    // Turn the in-memory arrays into ML.NET data views:
    // the generated points are used for training, the centers for evaluation.
    var trainData = mlContext.Data.LoadFromEnumerable(data);
    var testData = mlContext.Data.LoadFromEnumerable(clusters);

    // Create the estimator.
    var pipe = mlContext.Clustering.Trainers.KMeans("Features", numberOfClusters: k);

    // Train the pipeline.
    var trainedModel = pipe.Fit(trainData);

    // Validate that the centers picked during data generation are assigned to different clusters.
    var labels = new HashSet<uint>();
    var predictFunction = mlContext.Model.CreatePredictionEngine<ClusteringData, ClusteringPrediction>(trainedModel);
    for (int i = 0; i < k; i++)
    {
        var prediction = predictFunction.Predict(clusters[i]);
        // HashSet<T>.Add returns false when the id is already present, so a single call
        // both records the id and checks that it has not been seen before.
        Assert.True(labels.Add(prediction.SelectedClusterId), "Two generated centers were assigned to the same cluster.");
    }

    // Evaluate the trained pipeline on the centers.
    var predicted = trainedModel.Transform(testData);
    var metrics = mlContext.Clustering.Evaluate(predicted);

    // Label is not specified, so NMI is expected to be NaN.
    Assert.Equal(double.NaN, metrics.NormalizedMutualInformation);
    // Calculating DBI is off by default, so DBI is expected to be 0.
    Assert.Equal(0d, metrics.DaviesBouldinIndex);
    // The evaluated points are the generating centers themselves; the average distance is expected to be (nearly) zero.
    Assert.Equal(0d, metrics.AverageDistance, 0.00001);
}
}
}