add random data point corruption ability, change to scatter

This commit is contained in:
Philip Ball 2019-08-28 17:17:50 +01:00
Родитель 6b84def27a
Коммит 3855935efd
1 изменённых файлов: 44 добавлений и 9 удалений

Просмотреть файл

@ -13,6 +13,12 @@ using System.Threading;
namespace Microsoft.ML.Probabilistic.Tutorials
{
/// <summary>
/// This script generates a dataset using the following pipeline:
/// 1) randomly sample a 1D function from a GP;
/// 2) pick a random subset of 'numData' points;
/// 3) pick a further random proportion 'propCorrupt' of the 'numData' points and corrupt each by adding noise drawn from a uniform distribution over the range -3 to 3
/// </summary>
[Example("Applications", "A Gaussian Process regression example")]
class GaussianProcessDataGenerator
{
@ -25,8 +31,14 @@ namespace Microsoft.ML.Probabilistic.Tutorials
return;
}
// Number of datapoints
int numData = 30;
// The proportion of points to randomly corrupt
double propCorrupt = 0.3;
// The points to evaluate
Vector[] inputs = this.VectorLinSpace(-5, 5, 51);
Vector[] inputs = this.VectorRange(-5, 5, numData, true);
// Set up the GP prior, a distribution over functions, which will be filled in later
Variable<SparseGP> prior = Variable.New<SparseGP>().Named("prior");
@ -42,7 +54,7 @@ namespace Microsoft.ML.Probabilistic.Tutorials
Variable<double> score = Variable.FunctionEvaluate(f, x[j]).Named("score");
// The basis
Vector[] basis = VectorLinSpace(-5, 5, 6);
Vector[] basis = this.VectorRange(-5, 5, 6, false);
// The kernel
IKernelFunction kf;
@ -57,18 +69,26 @@ namespace Microsoft.ML.Probabilistic.Tutorials
var randomFunc = sgp.Sample();
// plotting boilerplate
var p1 = new OxyPlot.Series.LineSeries
var p1 = new OxyPlot.Series.ScatterSeries
{
Title = "Random Function"
};
Console.WriteLine("");
Console.WriteLine("Random function evaluations:");
Random rng = new Random();
// get random data
for (int i = 0; i < inputs.Length; i++)
{
double post = randomFunc.Evaluate(inputs[i]);
Console.WriteLine("f({0}) = {1}", inputs[i], post);
p1.Points.Add(new DataPoint(inputs[i][0], post));
// corrupt the data point if we haven't yet exceeded the proportion we wish to corrupt
if (i < propCorrupt * numData)
{
double sign = rng.NextDouble() > 0.5 ? 1 : -1;
double distance = rng.NextDouble() * 3;
post = (sign * distance) + post;
}
p1.Points.Add(new OxyPlot.Series.ScatterPoint(inputs[i][0], post));
}
var model = new PlotModel();
@ -77,6 +97,8 @@ namespace Microsoft.ML.Probabilistic.Tutorials
Thread thread = new Thread(() => DisplayPNG(model));
thread.SetApartmentState(ApartmentState.STA);
thread.Start();
Console.WriteLine("Plotting complete: Generated {0} points with {1} corrupted", numData, (int)System.Math.Ceiling(numData * propCorrupt));
}
private void DisplayPNG(PlotModel model)
@ -86,13 +108,26 @@ namespace Microsoft.ML.Probabilistic.Tutorials
PngExporter.Export(model, outputToFile, 600, 400, OxyColors.White);
}
private Vector[] VectorLinSpace(int min, int max, int len)
/// <summary>
/// Generates an array of 'len' 1D vectors with values between min and max; the points are randomly distributed if 'random' is true, otherwise evenly spaced in order from min to max
/// </summary>
private Vector[] VectorRange(double min, double max, int len, bool random)
{
Vector[] inputs = new Vector[len];
Random rng = new Random();
for (int i = 0; i < len; i++)
{
double num = i / (double)(len - 1);
double num = new double();
if (random)
{
num = rng.NextDouble();
}
else
{
num = i / (double)(len - 1);
}
num = num * (max - min);
num += min;
inputs[i] = Vector.FromArray(new double[1] { num });