// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
|
|
|
|
using System;
|
|
using System.Collections.Generic;
|
|
using Microsoft.ML.Data;
|
|
using Xunit;
|
|
|
|
namespace Microsoft.ML.Scenarios
|
|
{
|
|
public partial class ScenariosTests
|
|
{
|
|
/// <summary>
/// Output row read back from the trained clustering model.
/// </summary>
public class ClusteringPrediction
{
    /// <summary>
    /// Value of the "PredictedLabel" output column: the identifier of the cluster chosen for the input.
    /// </summary>
    [ColumnName("PredictedLabel")]
    public uint SelectedClusterId;

    /// <summary>
    /// Value of the "Score" output column.
    /// NOTE(review): presumably one distance per cluster centroid, as the field name suggests — confirm against the trainer's documentation.
    /// </summary>
    [ColumnName("Score")]
    public float[] Distance;
}
|
|
|
|
/// <summary>
/// Input row for the clustering test: a single point exposed to ML.NET as the "Features" column.
/// </summary>
public class ClusteringData
{
    /// <summary>
    /// Coordinates of the point, declared as a fixed-size vector of length 2 and mapped to the "Features" column.
    /// </summary>
    [ColumnName("Features"), VectorType(2)]
    public float[] Points;
}
|
|
|
|
/// <summary>
/// Trains a k-means model on synthetic 2-D points generated around <c>k</c> centers placed on the
/// unit circle, then verifies that every generating center is assigned to a different cluster and
/// that evaluating the model on those centers produces the expected metrics.
/// </summary>
[Fact]
public void PredictClusters()
{
    int n = 1000;
    int k = 4;
    var rand = new Random(1);
    var clusters = new ClusteringData[k];
    var data = new ClusteringData[n];

    // Pick the cluster centers as points on the unit circle whose angle to the X axis is 360 * i / k degrees.
    for (int i = 0; i < k; i++)
    {
        double angle = Math.PI * i * 2 / k;
        clusters[i] = new ClusteringData { Points = new float[2] { (float)Math.Cos(angle), (float)Math.Sin(angle) } };
    }

    // Create the training points by randomly picking a center and shifting it slightly.
    // The same shift, drawn from [-0.05, 0.05), is applied to both coordinates.
    for (int i = 0; i < n; i++)
    {
        // Keep this call order (Next, then NextDouble) so the seeded sequence stays reproducible.
        var index = rand.Next(0, k);
        var shift = (rand.NextDouble() - 0.5) / 10;
        data[i] = new ClusteringData
        {
            Points = new float[2]
            {
                (float)(clusters[index].Points[0] + shift),
                (float)(clusters[index].Points[1] + shift)
            }
        };
    }

    var mlContext = new MLContext(seed: 1);

    // Turn the in-memory arrays into ML.NET data views:
    // the shifted points are the training set, the generating centers are the test set.
    var trainData = mlContext.Data.LoadFromEnumerable(data);
    var testData = mlContext.Data.LoadFromEnumerable(clusters);

    // Create the estimator.
    var pipe = mlContext.Clustering.Trainers.KMeans("Features", numberOfClusters: k);

    // Train the pipeline.
    var trainedModel = pipe.Fit(trainData);

    // Validate that the points picked as cluster centers during data generation belong to different clusters.
    var labels = new HashSet<uint>();
    var predictFunction = mlContext.Model.CreatePredictionEngine<ClusteringData, ClusteringPrediction>(trainedModel);

    for (int i = 0; i < k; i++)
    {
        var prediction = predictFunction.Predict(clusters[i]);

        // HashSet<T>.Add returns false when the value is already present, so a single call both
        // records the cluster id and detects two centers that were assigned to the same cluster.
        Assert.True(labels.Add(prediction.SelectedClusterId),
            $"Center {i} was assigned to cluster {prediction.SelectedClusterId}, which is already used by another center.");
    }

    // Evaluate the trained pipeline on the generating centers.
    var predicted = trainedModel.Transform(testData);
    var metrics = mlContext.Clustering.Evaluate(predicted);

    // Label is not specified, so NMI would be equal to NaN.
    Assert.Equal(double.NaN, metrics.NormalizedMutualInformation);
    // Calculating DBI is off by default, so DBI would be 0.
    Assert.Equal(0d, metrics.DaviesBouldinIndex);
    Assert.Equal(0d, metrics.AverageDistance, 0.00001);
}
|
|
}
|
|
}
|