Prevent DataFrame.Sample() method from returning duplicated rows (#2939)
* resolves #2806 * replace the for loop with ArraySegment<T> * reduce shuffle loop operations from O(Rows.Count) to O(numberOfRows)
This commit is contained in:
Родитель
54d3f56cfe
Коммит
f217d59432
|
@ -328,14 +328,28 @@ namespace Microsoft.Data.Analysis
|
|||
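/// <param name="numberOfRows">Number of rows in the returned DataFrame. Must not exceed the number of rows in this DataFrame; otherwise an <see cref="ArgumentException"/> is thrown.</param>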
|
||||
public DataFrame Sample(int numberOfRows)
|
||||
{
|
||||
Random rand = new Random();
|
||||
PrimitiveDataFrameColumn<long> indices = new PrimitiveDataFrameColumn<long>("Indices", numberOfRows);
|
||||
int randMaxValue = (int)Math.Min(Int32.MaxValue, Rows.Count);
|
||||
for (long i = 0; i < numberOfRows; i++)
|
||||
if (numberOfRows > Rows.Count)
|
||||
{
|
||||
indices[i] = rand.Next(randMaxValue);
|
||||
throw new ArgumentException(string.Format(Strings.ExceedsNumberOfRows, Rows.Count), nameof(numberOfRows));
|
||||
}
|
||||
|
||||
int shuffleLowerLimit = 0;
|
||||
int shuffleUpperLimit = (int)Math.Min(Int32.MaxValue, Rows.Count);
|
||||
|
||||
int[] shuffleArray = Enumerable.Range(0, shuffleUpperLimit).ToArray();
|
||||
Random rand = new Random();
|
||||
while (shuffleLowerLimit < numberOfRows)
|
||||
{
|
||||
int randomIndex = rand.Next(shuffleLowerLimit, shuffleUpperLimit);
|
||||
int temp = shuffleArray[shuffleLowerLimit];
|
||||
shuffleArray[shuffleLowerLimit] = shuffleArray[randomIndex];
|
||||
shuffleArray[randomIndex] = temp;
|
||||
shuffleLowerLimit++;
|
||||
}
|
||||
ArraySegment<int> segment = new ArraySegment<int>(shuffleArray, 0, shuffleLowerLimit);
|
||||
|
||||
PrimitiveDataFrameColumn<int> indices = new PrimitiveDataFrameColumn<int>("indices", segment);
|
||||
|
||||
return Clone(indices);
|
||||
}
|
||||
|
||||
|
|
|
@ -132,6 +132,15 @@ namespace Microsoft.Data {
|
|||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Looks up a localized string similar to Parameter.Count exceeds the number of rows({0}) in the DataFrame .
|
||||
/// </summary>
|
||||
internal static string ExceedsNumberOfRows {
|
||||
get {
|
||||
return ResourceManager.GetString("ExceedsNumberOfRows", resourceCulture);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Looks up a localized string similar to Expected either {0} or {1} to be provided.
|
||||
/// </summary>
|
||||
|
|
|
@ -141,6 +141,9 @@
|
|||
<data name="ExceedsNumberOfColumns" xml:space="preserve">
|
||||
<value>Parameter.Count exceeds the number of columns({0}) in the DataFrame </value>
|
||||
</data>
|
||||
<data name="ExceedsNumberOfRows" xml:space="preserve">
|
||||
<value>Parameter.Count exceeds the number of rows({0}) in the DataFrame </value>
|
||||
</data>
|
||||
<data name="ExpectedEitherGuessRowsOrDataTypes" xml:space="preserve">
|
||||
<value>Expected either {0} or {1} to be provided</value>
|
||||
</data>
|
||||
|
@ -186,4 +189,4 @@
|
|||
<data name="SpansMultipleBuffers" xml:space="preserve">
|
||||
<value>Cannot span multiple buffers</value>
|
||||
</data>
|
||||
</root>
|
||||
</root>
|
||||
|
|
|
@ -1561,9 +1561,20 @@ namespace Microsoft.Data.Analysis.Tests
|
|||
public void TestSample()
|
||||
{
|
||||
DataFrame df = MakeDataFrameWithAllColumnTypes(10);
|
||||
DataFrame sampled = df.Sample(3);
|
||||
Assert.Equal(3, sampled.Rows.Count);
|
||||
DataFrame sampled = df.Sample(7);
|
||||
Assert.Equal(7, sampled.Rows.Count);
|
||||
Assert.Equal(df.Columns.Count, sampled.Columns.Count);
|
||||
|
||||
// all sampled rows should be unique.
|
||||
HashSet<int?> uniqueRowValues = new HashSet<int?>();
|
||||
foreach(int? value in sampled.Columns["Int"])
|
||||
{
|
||||
uniqueRowValues.Add(value);
|
||||
}
|
||||
Assert.Equal(uniqueRowValues.Count, sampled.Rows.Count);
|
||||
|
||||
// should throw an ArgumentException because the requested sample size exceeds the number of rows in the DataFrame
|
||||
Assert.Throws<ArgumentException>(()=> df.Sample(13));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
|
|
Загрузка…
Ссылка в новой задаче