Improve DataFrame.OrderBy method by providing ability to define how null values should be positioned (#7118)

This commit is contained in:
Aleksei Smirnov 2024-06-06 09:15:07 +03:00 коммит произвёл GitHub
Родитель 4bc753a404
Коммит b67107b157
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
8 изменённых файлов: 133 добавлений и 66 удалений

Просмотреть файл

@ -260,19 +260,22 @@ namespace Microsoft.Data.Analysis
/// <summary>
/// Orders the data frame by a specified column.
/// </summary>
/// <param name="columnName">The column name to order by</param>
public DataFrame OrderBy(string columnName)
/// <param name="columnName">The column name to order by.</param>
/// <param name="ascending"><see langword="true"/> (the default) to sort in ascending order; <see langword="false"/> to sort in descending order.</param>
/// <param name="putNullValuesLast">If <see langword="true"/> (the default), rows whose value in <paramref name="columnName"/> is null are placed after all other rows regardless of <paramref name="ascending"/>; if <see langword="false"/>, they are placed before all other rows.</param>
/// <returns>The data frame produced by sorting on <paramref name="columnName"/>.</returns>
public DataFrame OrderBy(string columnName, bool ascending = true, bool putNullValuesLast = true)
{
return Sort(columnName, isAscending: true);
return Sort(columnName, ascending, putNullValuesLast);
}
/// <summary>
/// Orders the data frame by a specified column in descending order.
/// </summary>
/// <param name="columnName">The column name to order by</param>
public DataFrame OrderByDescending(string columnName)
/// <param name="columnName">The column name to order by.</param>
/// <param name="putNullValuesLast">If <see langword="true"/> (the default), rows whose value in <paramref name="columnName"/> is null are placed after all other rows; if <see langword="false"/>, they are placed before all other rows. The non-null rows are in descending order in both cases.</param>
/// <returns>The data frame produced by sorting on <paramref name="columnName"/> in descending order.</returns>
public DataFrame OrderByDescending(string columnName, bool putNullValuesLast = true)
{
return Sort(columnName, isAscending: false);
return Sort(columnName, false, putNullValuesLast);
}
/// <summary>
@ -657,19 +660,16 @@ namespace Microsoft.Data.Analysis
_schema = null;
}
private DataFrame Sort(string columnName, bool isAscending)
private DataFrame Sort(string columnName, bool ascending, bool putNullValuesLast)
{
DataFrameColumn column = Columns[columnName];
PrimitiveDataFrameColumn<long> sortIndices = column.GetAscendingSortIndices(out Int64DataFrameColumn nullIndices);
for (long i = 0; i < nullIndices.Length; i++)
{
sortIndices.Append(nullIndices[i]);
}
PrimitiveDataFrameColumn<long> sortIndices = column.GetSortIndices(ascending, putNullValuesLast);
List<DataFrameColumn> newColumns = new List<DataFrameColumn>(Columns.Count);
for (int i = 0; i < Columns.Count; i++)
{
DataFrameColumn oldColumn = Columns[i];
DataFrameColumn newColumn = oldColumn.Clone(sortIndices, !isAscending);
DataFrameColumn newColumn = oldColumn.Clone(sortIndices);
Debug.Assert(newColumn.NullCount == oldColumn.NullCount);
newColumns.Add(newColumn);
}

Просмотреть файл

@ -230,13 +230,14 @@ namespace Microsoft.Data.Analysis
protected abstract DataFrameColumn CloneImplementation(long numberOfNullsToAppend = 0);
/// <summary>
/// Returns a copy of this column sorted by its values
/// Returns a copy of this column sorted by its values.
/// </summary>
/// <param name="ascending"></param>
public virtual DataFrameColumn Sort(bool ascending = true)
/// <param name="ascending"><see langword="true"/> (the default) to sort in ascending order; <see langword="false"/> to sort in descending order.</param>
/// <param name="putNullValuesLast">If <see langword="true"/> (the default), null values are placed after all non-null values regardless of <paramref name="ascending"/>; if <see langword="false"/>, they are placed before them.</param>
/// <returns>The column obtained by cloning this column with the computed sort indices applied.</returns>
public DataFrameColumn Sort(bool ascending = true, bool putNullValuesLast = true)
{
PrimitiveDataFrameColumn<long> sortIndices = GetAscendingSortIndices(out Int64DataFrameColumn _);
return Clone(sortIndices, !ascending, NullCount);
// Direction and null placement are both encoded in the indices returned by GetSortIndices,
// so Clone only has to apply them; no invert flag or null count is passed any more.
PrimitiveDataFrameColumn<long> sortIndices = GetSortIndices(ascending, putNullValuesLast);
return Clone(sortIndices);
}
/// <summary>
@ -441,20 +442,30 @@ namespace Microsoft.Data.Analysis
}
/// <summary>
/// Returns the indices of non-null values that, when applied, result in this column being sorted in ascending order. Also returns the indices of null values in <paramref name="nullIndices"/>.
/// Returns the indices that, when applied, result in this column being sorted.
/// </summary>
/// <param name="nullIndices">Indices of values that are <see langword="null"/>.</param>
internal virtual PrimitiveDataFrameColumn<long> GetAscendingSortIndices(out Int64DataFrameColumn nullIndices) => throw new NotImplementedException();
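/// <param name="ascending"><see langword="true"/> to sort in ascending order; <see langword="false"/> to sort in descending order.</param>
/// <param name="putNullValuesLast">If <see langword="true"/>, the indices of null values are placed after the indices of all non-null values regardless of <paramref name="ascending"/>; if <see langword="false"/>, they are placed before them.</param>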
internal abstract PrimitiveDataFrameColumn<long> GetSortIndices(bool ascending, bool putNullValuesLast);
internal delegate long GetBufferSortIndex(int bufferIndex, int sortIndex);
internal delegate ValueTuple<T, int> GetValueAndBufferSortIndexAtBuffer<T>(int bufferIndex, int valueIndex);
internal delegate int GetBufferLengthAtIndex(int bufferIndex);
internal void PopulateColumnSortIndicesWithHeap<T>(SortedDictionary<T, List<ValueTuple<int, int>>> heapOfValueAndListOfTupleOfSortAndBufferIndex,
protected delegate long GetBufferSortIndex(int bufferIndex, int sortIndex);
protected delegate ValueTuple<T, int> GetValueAndBufferSortIndexAtBuffer<T>(int bufferIndex, int valueIndex);
protected delegate int GetBufferLengthAtIndex(int bufferIndex);
protected static void PopulateColumnSortIndicesWithHeap<T>(SortedDictionary<T, List<ValueTuple<int, int>>> heapOfValueAndListOfTupleOfSortAndBufferIndex,
PrimitiveDataFrameColumn<long> columnSortIndices,
PrimitiveDataFrameColumn<long> columnNullIndices,
bool ascending,
bool putNullValuesLast,
GetBufferSortIndex getBufferSortIndex,
GetValueAndBufferSortIndexAtBuffer<T> getValueAndBufferSortIndexAtBuffer,
GetBufferLengthAtIndex getBufferLengthAtIndex)
{
long i = ascending ? columnNullIndices.Length : columnSortIndices.Length - 1;
if (putNullValuesLast)
i -= columnNullIndices.Length;
while (heapOfValueAndListOfTupleOfSortAndBufferIndex.Count > 0)
{
KeyValuePair<T, List<ValueTuple<int, int>>> minElement = heapOfValueAndListOfTupleOfSortAndBufferIndex.ElementAt(0);
@ -473,7 +484,9 @@ namespace Microsoft.Data.Analysis
int sortIndex = sortAndBufferIndex.sortIndex;
int bufferIndex = sortAndBufferIndex.bufferIndex;
long bufferSortIndex = getBufferSortIndex(bufferIndex, sortIndex);
columnSortIndices.Append(bufferSortIndex);
columnSortIndices[ascending ? i++ : i--] = bufferSortIndex;
if (sortIndex + 1 < getBufferLengthAtIndex(bufferIndex))
{
int nextSortIndex = sortIndex + 1;
@ -486,6 +499,14 @@ namespace Microsoft.Data.Analysis
}
}
//Fill Nulls
var start = putNullValuesLast ? columnSortIndices.Length - columnNullIndices.Length : 0;
for (long j = 0; j < columnNullIndices.Length; j++)
{
columnSortIndices[start + j] = columnNullIndices[j];
}
}
internal static int FloorLog2PlusOne(int n)
{

Просмотреть файл

@ -367,8 +367,7 @@ namespace Microsoft.Data.Analysis
return new StringArray(numberOfRows, offsetsBuffer, dataBuffer, nullBuffer, nullCount, indexInBuffer);
}
// Sort indices cannot be computed for this column type: the call always throws NotSupportedException.
// DataFrameColumn.Sort obtains its ordering from GetSortIndices, so sorting this column throws as well.
/// <inheritdoc/>
public override DataFrameColumn Sort(bool ascending = true) => throw new NotSupportedException();
internal override PrimitiveDataFrameColumn<long> GetSortIndices(bool ascending, bool putNullValuesLast) => throw new NotSupportedException();
public new ArrowStringDataFrameColumn Clone(long numberOfNullsToAppend = 0)
{

Просмотреть файл

@ -176,22 +176,17 @@ namespace Microsoft.Data.Analysis
public override DataFrameColumn Filter<U>(U min, U max) => throw new NotSupportedException();
public new StringDataFrameColumn Sort(bool ascending = true)
/// <summary>
/// Returns a copy of this column sorted by its values, typed as <see cref="StringDataFrameColumn"/>.
/// </summary>
/// <param name="ascending"><see langword="true"/> (the default) to sort in ascending order; <see langword="false"/> to sort in descending order.</param>
/// <param name="putNullValuesLast">If <see langword="true"/> (the default), null values are placed after all non-null values regardless of <paramref name="ascending"/>; if <see langword="false"/>, they are placed before them.</param>
/// <returns>The result of the base class sort, cast to <see cref="StringDataFrameColumn"/>.</returns>
public new StringDataFrameColumn Sort(bool ascending = true, bool putNullValuesLast = true)
{
PrimitiveDataFrameColumn<long> columnSortIndices = GetAscendingSortIndices(out Int64DataFrameColumn _);
return Clone(columnSortIndices, !ascending, NullCount);
// This member hides the base Sort only to narrow the static return type; the sorting itself is done by the base class.
return (StringDataFrameColumn)base.Sort(ascending, putNullValuesLast);
}
internal override PrimitiveDataFrameColumn<long> GetAscendingSortIndices(out Int64DataFrameColumn nullIndices)
internal override PrimitiveDataFrameColumn<long> GetSortIndices(bool ascending, bool putNullValuesLast)
{
PrimitiveDataFrameColumn<long> columnSortIndices = GetSortIndices(Comparer<string>.Default, out nullIndices);
return columnSortIndices;
}
var comparer = Comparer<string>.Default;
private PrimitiveDataFrameColumn<long> GetSortIndices(Comparer<string> comparer, out Int64DataFrameColumn columnNullIndices)
{
List<int[]> bufferSortIndices = new List<int[]>(_stringBuffers.Count);
columnNullIndices = new Int64DataFrameColumn("NullIndices", NullCount);
var columnNullIndices = new Int64DataFrameColumn("NullIndices", NullCount);
long nullIndicesSlot = 0;
foreach (List<string> buffer in _stringBuffers)
{
@ -241,11 +236,21 @@ namespace Microsoft.Data.Analysis
heapOfValueAndListOfTupleOfSortAndBufferIndex.Add(valueAndBufferSortIndex.Item1, new List<ValueTuple<int, int>>() { (valueAndBufferSortIndex.Item2, i) });
}
}
PrimitiveDataFrameColumn<long> columnSortIndices = new PrimitiveDataFrameColumn<long>("SortIndices");
var columnSortIndices = new Int64DataFrameColumn("SortIndices", Length);
GetBufferSortIndex getBufferSortIndex = new GetBufferSortIndex((int bufferIndex, int sortIndex) => (bufferSortIndices[bufferIndex][sortIndex]) + bufferIndex * bufferSortIndices[0].Length);
GetValueAndBufferSortIndexAtBuffer<string> getValueAtBuffer = new GetValueAndBufferSortIndexAtBuffer<string>((int bufferIndex, int sortIndex) => GetFirstNonNullValueStartingAtIndex(bufferIndex, sortIndex));
GetBufferLengthAtIndex getBufferLengthAtIndex = new GetBufferLengthAtIndex((int bufferIndex) => bufferSortIndices[bufferIndex].Length);
PopulateColumnSortIndicesWithHeap(heapOfValueAndListOfTupleOfSortAndBufferIndex, columnSortIndices, getBufferSortIndex, getValueAtBuffer, getBufferLengthAtIndex);
PopulateColumnSortIndicesWithHeap(heapOfValueAndListOfTupleOfSortAndBufferIndex,
columnSortIndices,
columnNullIndices,
ascending,
putNullValuesLast,
getBufferSortIndex,
getValueAtBuffer,
getBufferLengthAtIndex);
return columnSortIndices;
}

Просмотреть файл

@ -394,5 +394,7 @@ namespace Microsoft.Data.Analysis
throw new NotSupportedException();
}
// Sort indices are not implemented for this column type: any call throws NotImplementedException.
internal override PrimitiveDataFrameColumn<long> GetSortIndices(bool ascending, bool putNullValuesLast) => throw new NotImplementedException();
}
}

Просмотреть файл

@ -12,22 +12,18 @@ namespace Microsoft.Data.Analysis
public partial class PrimitiveDataFrameColumn<T> : DataFrameColumn
where T : unmanaged
{
public new PrimitiveDataFrameColumn<T> Sort(bool ascending = true)
/// <inheritdoc/>
public new PrimitiveDataFrameColumn<T> Sort(bool ascending = true, bool putNullValuesLast = true)
{
PrimitiveDataFrameColumn<long> sortIndices = GetAscendingSortIndices(out Int64DataFrameColumn _);
return Clone(sortIndices, !ascending, NullCount);
// This member hides the base Sort only to narrow the static return type; the sorting itself is done by the base class.
// NOTE(review): the cast presumes the base implementation's Clone returns a PrimitiveDataFrameColumn<T> for this column - confirm.
return (PrimitiveDataFrameColumn<T>)base.Sort(ascending, putNullValuesLast);
}
internal override PrimitiveDataFrameColumn<long> GetAscendingSortIndices(out Int64DataFrameColumn nullIndices)
internal override PrimitiveDataFrameColumn<long> GetSortIndices(bool ascending = true, bool putNullValuesLast = true)
{
Int64DataFrameColumn sortIndices = GetSortIndices(Comparer<T>.Default, out nullIndices);
return sortIndices;
}
var comparer = Comparer<T>.Default;
private Int64DataFrameColumn GetSortIndices(IComparer<T> comparer, out Int64DataFrameColumn columnNullIndices)
{
List<List<int>> bufferSortIndices = new List<List<int>>(_columnContainer.Buffers.Count);
columnNullIndices = new Int64DataFrameColumn("NullIndices", NullCount);
var columnNullIndices = new Int64DataFrameColumn("NullIndices", NullCount);
long nullIndicesSlot = 0;
// Sort each buffer first
for (int b = 0; b < _columnContainer.Buffers.Count; b++)
@ -57,6 +53,7 @@ namespace Microsoft.Data.Analysis
}
bufferSortIndices.Add(nonNullSortIndices);
}
// Simple merge sort to build the full column's sort indices
ValueTuple<T, int> GetFirstNonNullValueAndBufferIndexStartingAtIndex(int bufferIndex, int startIndex)
{
@ -80,6 +77,7 @@ namespace Microsoft.Data.Analysis
}
return (value, startIndex);
}
SortedDictionary<T, List<ValueTuple<int, int>>> heapOfValueAndListOfTupleOfSortAndBufferIndex = new SortedDictionary<T, List<ValueTuple<int, int>>>(comparer);
IList<ReadOnlyDataFrameBuffer<T>> buffers = _columnContainer.Buffers;
for (int i = 0; i < buffers.Count; i++)
@ -100,11 +98,20 @@ namespace Microsoft.Data.Analysis
heapOfValueAndListOfTupleOfSortAndBufferIndex.Add(valueAndBufferIndex.Item1, new List<ValueTuple<int, int>>() { (valueAndBufferIndex.Item2, i) });
}
}
Int64DataFrameColumn columnSortIndices = new Int64DataFrameColumn("SortIndices");
Int64DataFrameColumn columnSortIndices = new Int64DataFrameColumn("SortIndices", Length);
GetBufferSortIndex getBufferSortIndex = new GetBufferSortIndex((int bufferIndex, int sortIndex) => (bufferSortIndices[bufferIndex][sortIndex]) + bufferIndex * bufferSortIndices[0].Count);
GetValueAndBufferSortIndexAtBuffer<T> getValueAndBufferSortIndexAtBuffer = new GetValueAndBufferSortIndexAtBuffer<T>((int bufferIndex, int sortIndex) => GetFirstNonNullValueAndBufferIndexStartingAtIndex(bufferIndex, sortIndex));
GetBufferLengthAtIndex getBufferLengthAtIndex = new GetBufferLengthAtIndex((int bufferIndex) => bufferSortIndices[bufferIndex].Count);
PopulateColumnSortIndicesWithHeap(heapOfValueAndListOfTupleOfSortAndBufferIndex, columnSortIndices, getBufferSortIndex, getValueAndBufferSortIndexAtBuffer, getBufferLengthAtIndex);
PopulateColumnSortIndicesWithHeap(heapOfValueAndListOfTupleOfSortAndBufferIndex,
columnSortIndices,
columnNullIndices,
ascending,
putNullValuesLast,
getBufferSortIndex,
getValueAndBufferSortIndexAtBuffer,
getBufferLengthAtIndex);
return columnSortIndices;
}

Просмотреть файл

@ -255,20 +255,20 @@ namespace Microsoft.Data.Analysis
/// <summary>
/// Computes the median of the non-null values of this column, converted to <see cref="double"/>.
/// </summary>
/// <returns>
/// 0 when the column has no non-null values; otherwise the middle value of the sorted non-null values,
/// or the average of the two middle values when their count is even.
/// </returns>
public override double Median()
{
// Not the most efficient implementation. Using a selection algorithm here would be O(n) instead of O(nLogn)
if (Length == 0)
// Null entries do not take part in the median, so only non-null values are counted.
var notNullValuesCount = Length - NullCount;
if (notNullValuesCount == 0)
return 0;
PrimitiveDataFrameColumn<long> sortIndices = GetAscendingSortIndices(out Int64DataFrameColumn _);
long middle = sortIndices.Length / 2;
// Called with its default arguments (presumably ascending with nulls last - confirm against the
// GetSortIndices overload in scope), so the first notNullValuesCount indices refer to non-null
// values in sorted order; the middle position is therefore derived from the non-null count,
// not from sortIndices.Length, which also includes the null indices.
PrimitiveDataFrameColumn<long> sortIndices = GetSortIndices();
long middle = notNullValuesCount / 2;
double middleValue = (double)Convert.ChangeType(this[sortIndices[middle].Value].Value, typeof(double));
if (sortIndices.Length % 2 == 0)
if (notNullValuesCount % 2 == 0)
{
// Even number of values: average the two central elements.
double otherMiddleValue = (double)Convert.ChangeType(this[sortIndices[middle - 1].Value].Value, typeof(double));
return (middleValue + otherMiddleValue) / 2;
}
else
{
return middleValue;
}
return middleValue;
}
public override double Mean()

Просмотреть файл

@ -238,29 +238,56 @@ namespace Microsoft.Data.Analysis.Tests
df.Columns["Int"][19] = -1;
df.Columns["Int"][5] = 2000;
// Sort by "Int" in ascending order
// Sort by "Int" in ascending order and nulls last
var sortedDf = df.OrderBy("Int");
Assert.Null(sortedDf.Columns["Int"][19]);
Assert.Equal(-1, sortedDf.Columns["Int"][0]);
Assert.Equal(100, sortedDf.Columns["Int"][17]);
Assert.Equal(2000, sortedDf.Columns["Int"][18]);
// Sort by "Int" in descending order
// Sort by "Int" in descending order and nulls last
sortedDf = df.OrderByDescending("Int");
Assert.Null(sortedDf.Columns["Int"][19]);
Assert.Equal(-1, sortedDf.Columns["Int"][18]);
Assert.Equal(100, sortedDf.Columns["Int"][1]);
Assert.Equal(2000, sortedDf.Columns["Int"][0]);
// Sort by "Int" in ascending order and nulls first
sortedDf = df.OrderBy("Int", putNullValuesLast: false);
Assert.Null(sortedDf.Columns["Int"][0]);
Assert.Equal(-1, sortedDf.Columns["Int"][1]);
Assert.Equal(100, sortedDf.Columns["Int"][18]);
Assert.Equal(2000, sortedDf.Columns["Int"][19]);
// Sort by "Int" in descending order and nulls first
sortedDf = df.OrderByDescending("Int", putNullValuesLast: false);
Assert.Null(sortedDf.Columns["Int"][0]);
Assert.Equal(-1, sortedDf.Columns["Int"][19]);
Assert.Equal(100, sortedDf.Columns["Int"][2]);
Assert.Equal(2000, sortedDf.Columns["Int"][1]);
// Sort by "String" in ascending order
// Sort by "String" in ascending order and nulls last
sortedDf = df.OrderBy("String");
Assert.Null(sortedDf.Columns["Int"][19]);
Assert.Equal(1, sortedDf.Columns["Int"][1]);
Assert.Equal(8, sortedDf.Columns["Int"][17]);
Assert.Equal(9, sortedDf.Columns["Int"][18]);
// Sort by "String" in descending order
// Sort by "String" in descending order and nulls last
sortedDf = df.OrderByDescending("String");
Assert.Null(sortedDf.Columns["Int"][19]);
Assert.Equal(8, sortedDf.Columns["Int"][1]);
Assert.Equal(9, sortedDf.Columns["Int"][0]);
// Sort by "String" in ascending order and nulls first
sortedDf = df.OrderBy("String", putNullValuesLast: false);
Assert.Null(sortedDf.Columns["Int"][0]);
Assert.Equal(1, sortedDf.Columns["Int"][2]);
Assert.Equal(8, sortedDf.Columns["Int"][18]);
Assert.Equal(9, sortedDf.Columns["Int"][19]);
// Sort by "String" in descending order and nulls first
sortedDf = df.OrderByDescending("String", putNullValuesLast: false);
Assert.Null(sortedDf.Columns["Int"][0]);
Assert.Equal(8, sortedDf.Columns["Int"][2]);
Assert.Equal(9, sortedDf.Columns["Int"][1]);
@ -1305,13 +1332,19 @@ namespace Microsoft.Data.Analysis.Tests
}
[Fact]
public void TestMeanMedian()
public void TestMean()
{
DataFrame df = MakeDataFrameWithNumericColumns(10, true, 0);
Assert.Equal(40.0 / 9.0, df["Decimal"].Mean());
Assert.Equal(4, df["Decimal"].Median());
}
[Fact]
public void TestMedian()
{
DataFrame df = MakeDataFrameWithNumericColumns(10, true, 0);
Assert.Equal(4, df["Decimal"].Median());
}
[Fact]