Implement DataFrameColumn Apply and DropNulls methods (#7123)

* Refactor DropNulls method

* Refactor ArrowStringDataFrameColumn Clone()

* Refactor StringDataFrameColumn Clone()

* Refactor VBufferDataFrameColumn Clone and FillNulls methods

* Add Apply method

* Add DropNulls method to columns

* Add comments

* Fix warnings in tests

* Fix url to gpt2 ext vocabulary
This commit is contained in:
Aleksei Smirnov 2024-06-11 19:10:46 +03:00 коммит произвёл GitHub
Родитель b67107b157
Коммит 0c51cc6b90
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
13 изменённых файлов: 543 добавлений и 200 удалений

Просмотреть файл

@ -192,15 +192,15 @@ namespace Microsoft.Data.Analysis
/// </summary> /// </summary>
public DataFrame Clone() public DataFrame Clone()
{ {
return Clone(mapIndices: null, invertMapIndices: false); return Clone(mapIndices: null);
} }
private DataFrame Clone(DataFrameColumn mapIndices = null, bool invertMapIndices = false) private DataFrame Clone(DataFrameColumn mapIndices = null)
{ {
List<DataFrameColumn> newColumns = new List<DataFrameColumn>(Columns.Count); List<DataFrameColumn> newColumns = new List<DataFrameColumn>(Columns.Count);
for (int i = 0; i < Columns.Count; i++) for (int i = 0; i < Columns.Count; i++)
{ {
newColumns.Add(Columns[i].Clone(mapIndices, invertMapIndices)); newColumns.Add(Columns[i].Clone(mapIndices));
} }
return new DataFrame(newColumns); return new DataFrame(newColumns);
} }
@ -411,31 +411,46 @@ namespace Microsoft.Data.Analysis
/// <param name="options"></param> /// <param name="options"></param>
public DataFrame DropNulls(DropNullOptions options = DropNullOptions.Any) public DataFrame DropNulls(DropNullOptions options = DropNullOptions.Any)
{ {
DataFrame ret = new DataFrame(); var filter = new BooleanDataFrameColumn("Filter");
PrimitiveDataFrameColumn<bool> filter = new PrimitiveDataFrameColumn<bool>("Filter");
if (options == DropNullOptions.Any) if (options == DropNullOptions.Any)
{ {
filter.AppendMany(true, Rows.Count); filter.AppendMany(true, Rows.Count);
var buffers = filter.ColumnContainer.Buffers;
for (int i = 0; i < Columns.Count; i++) foreach (var column in Columns)
{ {
DataFrameColumn column = Columns[i]; long index = 0;
filter.ApplyElementwise((bool? value, long index) => for (int b = 0; b < buffers.Count; b++)
{ {
return value.Value && (column[index] == null ? false : true); var span = buffers.GetOrCreateMutable(b).Span;
});
for (int i = 0; i < span.Length; i++)
{
span[i] = span[i] && column.IsValid(index);
index++;
}
}
} }
} }
else else
{ {
filter.AppendMany(false, Rows.Count); filter.AppendMany(false, Rows.Count);
for (int i = 0; i < Columns.Count; i++) var buffers = filter.ColumnContainer.Buffers;
foreach (var column in Columns)
{ {
DataFrameColumn column = Columns[i]; long index = 0;
filter.ApplyElementwise((bool? value, long index) => for (int b = 0; b < buffers.Count; b++)
{ {
return value.Value || (column[index] == null ? false : true); var span = buffers.GetOrCreateMutable(b).Span;
});
for (int i = 0; i < span.Length; i++)
{
span[i] = span[i] || column.IsValid(index);
index++;
}
}
} }
} }
return this[filter]; return this[filter];

Просмотреть файл

@ -140,6 +140,13 @@ namespace Microsoft.Data.Analysis
[Obsolete] [Obsolete]
public void SetName(string newName, DataFrame dataFrame) => SetName(newName); public void SetName(string newName, DataFrame dataFrame) => SetName(newName);
/// <summary>
/// Indicates if the value at this <paramref name="index"/> is valid (not <see langword="null"/>).
/// </summary>
/// <param name="index">The index to look up.</param>
/// <returns>A boolean value indicating the validity at this <paramref name="index"/>.</returns>
public virtual bool IsValid(long index) => NullCount == 0 || this[index] != null;
/// <summary> /// <summary>
/// The type of data this column holds. /// The type of data this column holds.
/// </summary> /// </summary>
@ -300,7 +307,14 @@ namespace Microsoft.Data.Analysis
/// <param name="inPlace">Indicates if the operation should be performed in place</param> /// <param name="inPlace">Indicates if the operation should be performed in place</param>
public virtual DataFrameColumn FillNulls(object value, bool inPlace = false) => FillNullsImplementation(value, inPlace); public virtual DataFrameColumn FillNulls(object value, bool inPlace = false) => FillNullsImplementation(value, inPlace);
protected virtual DataFrameColumn FillNullsImplementation(object value, bool inPlace) => throw new NotImplementedException(); protected abstract DataFrameColumn FillNullsImplementation(object value, bool inPlace);
/// <summary>
/// Returns a <see cref="DataFrameColumn"/> with no missing values.
/// </summary>
public virtual DataFrameColumn DropNulls() => DropNullsImplementation();
protected abstract DataFrameColumn DropNullsImplementation();
// Arrow related APIs // Arrow related APIs
protected internal virtual Field GetArrowField() => throw new NotImplementedException(); protected internal virtual Field GetArrowField() => throw new NotImplementedException();

Просмотреть файл

@ -69,12 +69,8 @@ namespace Microsoft.Data.Analysis
/// <inheritdoc/> /// <inheritdoc/>
public override long NullCount => _nullCount; public override long NullCount => _nullCount;
/// <summary> /// <inheritdoc/>
/// Indicates if the value at this <paramref name="index"/> is <see langword="null" />. public override bool IsValid(long index) => NullCount == 0 || GetValidityBit(index);
/// </summary>
/// <param name="index">The index to look up.</param>
/// <returns>A boolean value indicating the validity at this <paramref name="index"/>.</returns>
public bool IsValid(long index) => NullCount == 0 || GetValidityBit(index);
private bool GetValidityBit(long index) private bool GetValidityBit(long index)
{ {
@ -435,24 +431,42 @@ namespace Microsoft.Data.Analysis
return ret; return ret;
} }
private ArrowStringDataFrameColumn CloneImplementation<U>(PrimitiveDataFrameColumn<U> mapIndices, bool invertMapIndices) private ArrowStringDataFrameColumn CloneImplementation(PrimitiveDataFrameColumn<int> mapIndices, bool invertMapIndices)
where U : unmanaged
{ {
ArrowStringDataFrameColumn ret = new ArrowStringDataFrameColumn(Name); ArrowStringDataFrameColumn ret = new ArrowStringDataFrameColumn(Name);
mapIndices.ApplyElementwise((U? mapIndex, long rowIndex) => for (long i = 0; i < mapIndices.Length; i++)
{ {
if (mapIndex == null) int? index = mapIndices[invertMapIndices ? mapIndices.Length - 1 - i : i];
if (index == null)
{ {
ret.Append(default); ret.Append(default);
return mapIndex; continue;
} }
long index = invertMapIndices ? mapIndices.Length - 1 - rowIndex : rowIndex; ret.Append(IsValid(index.Value) ? GetBytes(index.Value) : default(ReadOnlySpan<byte>));
ret.Append(IsValid(index) ? GetBytes(index) : default(ReadOnlySpan<byte>)); }
return mapIndex; return ret;
}); }
private ArrowStringDataFrameColumn CloneImplementation(PrimitiveDataFrameColumn<long> mapIndices, bool invertMapIndices)
{
ArrowStringDataFrameColumn ret = new ArrowStringDataFrameColumn(Name);
for (long i = 0; i < mapIndices.Length; i++)
{
long? index = mapIndices[invertMapIndices ? mapIndices.Length - 1 - i : i];
if (index == null)
{
ret.Append(default);
continue;
}
ret.Append(IsValid(index.Value) ? GetBytes(index.Value) : default(ReadOnlySpan<byte>));
}
return ret; return ret;
} }
@ -541,6 +555,25 @@ namespace Microsoft.Data.Analysis
} }
} }
/// <inheritdoc/>
public new ArrowStringDataFrameColumn DropNulls()
{
return (ArrowStringDataFrameColumn)DropNullsImplementation();
}
protected override DataFrameColumn DropNullsImplementation()
{
var ret = new ArrowStringDataFrameColumn(Name);
for (long i = 0; i < Length; i++)
{
if (IsValid(i))
ret.Append(GetBytes(i));
}
return ret;
}
public override DataFrameColumn Clamp<U>(U min, U max, bool inPlace = false) => throw new NotSupportedException(); public override DataFrameColumn Clamp<U>(U min, U max, bool inPlace = false) => throw new NotSupportedException();
public override DataFrameColumn Filter<U>(U min, U max) => throw new NotSupportedException(); public override DataFrameColumn Filter<U>(U min, U max) => throw new NotSupportedException();

Просмотреть файл

@ -79,6 +79,27 @@ namespace Microsoft.Data.Analysis
Length++; Length++;
} }
/// <summary>
/// Applies a function to all values in the column, that are not null.
/// </summary>
/// <param name="func">The function to apply.</param>
/// /// <param name="inPlace">A boolean flag to indicate if the operation should be in place.</param>
/// <returns>A new <see cref="PrimitiveDataFrameColumn{T}"/> if <paramref name="inPlace"/> is not set. Returns this column otherwise.</returns>
public StringDataFrameColumn Apply(Func<string, string> func, bool inPlace = false)
{
var column = inPlace ? this : Clone();
for (long i = 0; i < column.Length; i++)
{
var value = column[i];
if (value != null)
column[i] = func(value);
}
return column;
}
private int GetBufferIndexContainingRowIndex(long rowIndex) private int GetBufferIndexContainingRowIndex(long rowIndex)
{ {
if (rowIndex >= Length) if (rowIndex >= Length)
@ -310,88 +331,51 @@ namespace Microsoft.Data.Analysis
for (long i = 0; i < boolColumn.Length; i++) for (long i = 0; i < boolColumn.Length; i++)
{ {
bool? value = boolColumn[i]; bool? value = boolColumn[i];
if (value.HasValue && value.Value == true) if (value.HasValue && value.Value)
ret.Append(this[i]); ret.Append(this[i]);
} }
return ret; return ret;
} }
private StringDataFrameColumn CloneImplementation<U>(PrimitiveDataFrameColumn<U> mapIndices, bool invertMapIndices = false) private StringDataFrameColumn CloneImplementation(PrimitiveDataFrameColumn<int> mapIndices, bool invertMapIndices = false)
where U : unmanaged
{ {
mapIndices = mapIndices ?? throw new ArgumentNullException(nameof(mapIndices)); mapIndices = mapIndices ?? throw new ArgumentNullException(nameof(mapIndices));
StringDataFrameColumn ret = new StringDataFrameColumn(Name, mapIndices.Length); var ret = new StringDataFrameColumn(Name, mapIndices.Length);
List<string> setBuffer = ret._stringBuffers[0]; long rowIndex = 0;
long setBufferMinRange = 0; for (int b = 0; b < mapIndices.ColumnContainer.Buffers.Count; b++)
long setBufferMaxRange = MaxCapacity;
List<string> getBuffer = _stringBuffers[0];
long getBufferMinRange = 0;
long getBufferMaxRange = MaxCapacity;
long maxCapacity = MaxCapacity;
if (mapIndices.DataType == typeof(long))
{ {
PrimitiveDataFrameColumn<long> longMapIndices = mapIndices as PrimitiveDataFrameColumn<long>; var span = mapIndices.ColumnContainer.Buffers[b].ReadOnlySpan;
longMapIndices.ApplyElementwise((long? mapIndex, long rowIndex) => var validitySpan = mapIndices.ColumnContainer.NullBitMapBuffers[b].ReadOnlySpan;
for (int i = 0; i < span.Length; i++)
{ {
long index = rowIndex; long index = invertMapIndices ? mapIndices.Length - 1 - rowIndex : rowIndex;
if (invertMapIndices) ret[index] = BitUtility.IsValid(validitySpan, i) ? this[span[i]] : null;
index = longMapIndices.Length - 1 - index; rowIndex++;
if (index < setBufferMinRange || index >= setBufferMaxRange) }
{
int bufferIndex = (int)(index / maxCapacity);
setBuffer = ret._stringBuffers[bufferIndex];
setBufferMinRange = bufferIndex * maxCapacity;
setBufferMaxRange = (bufferIndex + 1) * maxCapacity;
}
index -= setBufferMinRange;
if (mapIndex == null)
{
setBuffer[(int)index] = null;
return mapIndex;
}
if (mapIndex.Value < getBufferMinRange || mapIndex.Value >= getBufferMaxRange)
{
int bufferIndex = (int)(mapIndex.Value / maxCapacity);
getBuffer = _stringBuffers[bufferIndex];
getBufferMinRange = bufferIndex * maxCapacity;
getBufferMaxRange = (bufferIndex + 1) * maxCapacity;
}
int bufferLocalMapIndex = (int)(mapIndex - getBufferMinRange);
string value = getBuffer[bufferLocalMapIndex];
setBuffer[(int)index] = value;
if (value != null)
ret._nullCount--;
return mapIndex;
});
} }
else if (mapIndices.DataType == typeof(int))
return ret;
}
private StringDataFrameColumn CloneImplementation(PrimitiveDataFrameColumn<long> mapIndices, bool invertMapIndices = false)
{
mapIndices = mapIndices ?? throw new ArgumentNullException(nameof(mapIndices));
var ret = new StringDataFrameColumn(Name, mapIndices.Length);
long rowIndex = 0;
for (int b = 0; b < mapIndices.ColumnContainer.Buffers.Count; b++)
{ {
PrimitiveDataFrameColumn<int> intMapIndices = mapIndices as PrimitiveDataFrameColumn<int>; var span = mapIndices.ColumnContainer.Buffers[b].ReadOnlySpan;
intMapIndices.ApplyElementwise((int? mapIndex, long rowIndex) => var validitySpan = mapIndices.ColumnContainer.NullBitMapBuffers[b].ReadOnlySpan;
for (int i = 0; i < span.Length; i++)
{ {
long index = rowIndex; long index = invertMapIndices ? mapIndices.Length - 1 - rowIndex : rowIndex;
if (invertMapIndices) ret[index] = BitUtility.IsValid(validitySpan, i) ? this[span[i]] : null;
index = intMapIndices.Length - 1 - index; rowIndex++;
}
if (mapIndex == null)
{
setBuffer[(int)index] = null;
return mapIndex;
}
string value = getBuffer[mapIndex.Value];
setBuffer[(int)index] = value;
if (value != null)
ret._nullCount--;
return mapIndex;
});
}
else
{
Debug.Assert(false, nameof(mapIndices.DataType));
} }
return ret; return ret;
@ -455,6 +439,12 @@ namespace Microsoft.Data.Analysis
} }
} }
/// <summary>
/// Returns a new column with <see langword="null" /> elements replaced by <paramref name="value"/>.
/// </summary>
/// <remarks>Tries to convert value to the column's DataType</remarks>
/// <param name="value"></param>
/// <param name="inPlace">Indicates if the operation should be performed in place</param>
public StringDataFrameColumn FillNulls(string value, bool inPlace = false) public StringDataFrameColumn FillNulls(string value, bool inPlace = false)
{ {
if (value == null) if (value == null)
@ -477,6 +467,30 @@ namespace Microsoft.Data.Analysis
throw new ArgumentException(String.Format(Strings.MismatchedValueType, typeof(string)), nameof(value)); throw new ArgumentException(String.Format(Strings.MismatchedValueType, typeof(string)), nameof(value));
} }
/// <inheritdoc/>
public new StringDataFrameColumn DropNulls()
{
return (StringDataFrameColumn)DropNullsImplementation();
}
protected override DataFrameColumn DropNullsImplementation()
{
var ret = new StringDataFrameColumn(Name, Length - NullCount);
long j = 0;
for (long i = 0; i < Length; i++)
{
var value = this[i];
if (value != null)
{
ret[j++] = value;
}
}
return ret;
}
protected internal override void AddDataViewColumn(DataViewSchema.Builder builder) protected internal override void AddDataViewColumn(DataViewSchema.Builder builder)
{ {
builder.AddColumn(Name, TextDataViewType.Instance); builder.AddColumn(Name, TextDataViewType.Instance);

Просмотреть файл

@ -223,68 +223,45 @@ namespace Microsoft.Data.Analysis
return ret; return ret;
} }
private VBufferDataFrameColumn<T> CloneImplementation<U>(PrimitiveDataFrameColumn<U> mapIndices, bool invertMapIndices = false, long numberOfNullsToAppend = 0) private VBufferDataFrameColumn<T> CloneImplementation(PrimitiveDataFrameColumn<long> mapIndices, bool invertMapIndices = false)
where U : unmanaged
{ {
mapIndices = mapIndices ?? throw new ArgumentNullException(nameof(mapIndices)); mapIndices = mapIndices ?? throw new ArgumentNullException(nameof(mapIndices));
VBufferDataFrameColumn<T> ret = new VBufferDataFrameColumn<T>(Name, mapIndices.Length); var ret = new VBufferDataFrameColumn<T>(Name, mapIndices.Length);
List<VBuffer<T>> setBuffer = ret._vBuffers[0]; long rowIndex = 0;
long setBufferMinRange = 0; for (int b = 0; b < mapIndices.ColumnContainer.Buffers.Count; b++)
long setBufferMaxRange = MaxCapacity;
List<VBuffer<T>> getBuffer = _vBuffers[0];
long getBufferMinRange = 0;
long getBufferMaxRange = MaxCapacity;
long maxCapacity = MaxCapacity;
if (mapIndices.DataType == typeof(long))
{ {
PrimitiveDataFrameColumn<long> longMapIndices = mapIndices as PrimitiveDataFrameColumn<long>; var span = mapIndices.ColumnContainer.Buffers[b].ReadOnlySpan;
longMapIndices.ApplyElementwise((long? mapIndex, long rowIndex) => var validitySpan = mapIndices.ColumnContainer.NullBitMapBuffers[b].ReadOnlySpan;
for (int i = 0; i < span.Length; i++)
{ {
long index = rowIndex; long index = invertMapIndices ? mapIndices.Length - 1 - rowIndex : rowIndex;
if (invertMapIndices) ret[index] = BitUtility.IsValid(validitySpan, i) ? this[span[i]] : default;
index = longMapIndices.Length - 1 - index; rowIndex++;
if (index < setBufferMinRange || index >= setBufferMaxRange) }
{
int bufferIndex = (int)(index / maxCapacity);
setBuffer = ret._vBuffers[bufferIndex];
setBufferMinRange = bufferIndex * maxCapacity;
setBufferMaxRange = (bufferIndex + 1) * maxCapacity;
}
index -= setBufferMinRange;
if (mapIndex.Value < getBufferMinRange || mapIndex.Value >= getBufferMaxRange)
{
int bufferIndex = (int)(mapIndex.Value / maxCapacity);
getBuffer = _vBuffers[bufferIndex];
getBufferMinRange = bufferIndex * maxCapacity;
getBufferMaxRange = (bufferIndex + 1) * maxCapacity;
}
int bufferLocalMapIndex = (int)(mapIndex - getBufferMinRange);
VBuffer<T> value = getBuffer[bufferLocalMapIndex];
setBuffer[(int)index] = value;
return mapIndex;
});
} }
else if (mapIndices.DataType == typeof(int))
return ret;
}
private VBufferDataFrameColumn<T> CloneImplementation(PrimitiveDataFrameColumn<int> mapIndices, bool invertMapIndices = false)
{
mapIndices = mapIndices ?? throw new ArgumentNullException(nameof(mapIndices));
var ret = new VBufferDataFrameColumn<T>(Name, mapIndices.Length);
long rowIndex = 0;
for (int b = 0; b < mapIndices.ColumnContainer.Buffers.Count; b++)
{ {
PrimitiveDataFrameColumn<int> intMapIndices = mapIndices as PrimitiveDataFrameColumn<int>; var span = mapIndices.ColumnContainer.Buffers[b].ReadOnlySpan;
intMapIndices.ApplyElementwise((int? mapIndex, long rowIndex) => var validitySpan = mapIndices.ColumnContainer.NullBitMapBuffers[b].ReadOnlySpan;
for (int i = 0; i < span.Length; i++)
{ {
long index = rowIndex; long index = invertMapIndices ? mapIndices.Length - 1 - rowIndex : rowIndex;
if (invertMapIndices) ret[index] = BitUtility.IsValid(validitySpan, i) ? this[span[i]] : default;
index = intMapIndices.Length - 1 - index; rowIndex++;
}
VBuffer<T> value = getBuffer[mapIndex.Value];
setBuffer[(int)index] = value;
return mapIndex;
});
}
else
{
Debug.Assert(false, nameof(mapIndices.DataType));
} }
return ret; return ret;
@ -395,6 +372,18 @@ namespace Microsoft.Data.Analysis
throw new NotSupportedException(); throw new NotSupportedException();
} }
protected override DataFrameColumn FillNullsImplementation(object value, bool inPlace)
{
//Do nothing as VBufferColumn doesn't have null values
return inPlace ? this : Clone();
}
protected override DataFrameColumn DropNullsImplementation()
{
//Do nothing as VBufferColumn doesn't have null values
return Clone();
}
internal override PrimitiveDataFrameColumn<long> GetSortIndices(bool ascending, bool putNullValuesLast) => throw new NotImplementedException(); internal override PrimitiveDataFrameColumn<long> GetSortIndices(bool ascending, bool putNullValuesLast) => throw new NotImplementedException();
} }
} }

Просмотреть файл

@ -177,48 +177,86 @@ namespace Microsoft.Data.Analysis
public void ApplyElementwise(Func<T?, long, T?> func) public void ApplyElementwise(Func<T?, long, T?> func)
{ {
var bufferMaxCapacity = ReadOnlyDataFrameBuffer<T>.MaxCapacity; long curIndex = 0;
for (int b = 0; b < Buffers.Count; b++) for (int b = 0; b < Buffers.Count; b++)
{ {
long prevLength = checked(bufferMaxCapacity * b);
Span<T> mutableBuffer = Buffers.GetOrCreateMutable(b).Span; Span<T> mutableBuffer = Buffers.GetOrCreateMutable(b).Span;
Span<byte> mutableNullBitMapBuffer = NullBitMapBuffers.GetOrCreateMutable(b).Span; Span<byte> mutableNullBitMapBuffer = NullBitMapBuffers.GetOrCreateMutable(b).Span;
for (int i = 0; i < mutableBuffer.Length; i++) for (int i = 0; i < mutableBuffer.Length; i++)
{ {
long curIndex = i + prevLength;
bool isValid = BitUtility.IsValid(mutableNullBitMapBuffer, i); bool isValid = BitUtility.IsValid(mutableNullBitMapBuffer, i);
T? value = func(isValid ? mutableBuffer[i] : null, curIndex); T? value = func(isValid ? mutableBuffer[i] : null, curIndex);
mutableBuffer[i] = value.GetValueOrDefault(); mutableBuffer[i] = value.GetValueOrDefault();
SetValidityBit(mutableNullBitMapBuffer, i, value != null); SetValidityBit(mutableNullBitMapBuffer, i, value != null);
curIndex++;
} }
} }
} }
public void Apply(Func<T, T> func)
{
for (int b = 0; b < Buffers.Count; b++)
{
var span = Buffers.GetOrCreateMutable(b).Span;
var validitySpan = NullBitMapBuffers.GetOrCreateMutable(b).Span;
for (int i = 0; i < span.Length; i++)
{
if (NullCount == 0 || BitUtility.IsValid(validitySpan, i))
{
span[i] = func(span[i]);
}
}
}
}
[Obsolete]
public void Apply<TResult>(Func<T?, TResult?> func, PrimitiveColumnContainer<TResult> resultContainer) public void Apply<TResult>(Func<T?, TResult?> func, PrimitiveColumnContainer<TResult> resultContainer)
where TResult : unmanaged where TResult : unmanaged
{ {
var bufferMaxCapacity = ReadOnlyDataFrameBuffer<T>.MaxCapacity;
for (int b = 0; b < Buffers.Count; b++) for (int b = 0; b < Buffers.Count; b++)
{ {
long prevLength = checked(bufferMaxCapacity * b);
var sourceBuffer = Buffers[b]; var sourceBuffer = Buffers[b];
var sourceNullBitMap = NullBitMapBuffers[b].ReadOnlySpan; var sourceNullBitMap = NullBitMapBuffers[b].ReadOnlySpan;
Span<TResult> mutableResultBuffer = resultContainer.Buffers.GetOrCreateMutable(b).Span; Span<TResult> mutableResultBuffer = resultContainer.Buffers.GetOrCreateMutable(b).Span;
Span<byte> mutableResultNullBitMapBuffers = resultContainer.NullBitMapBuffers.GetOrCreateMutable(b).Span; Span<byte> mutableResultNullBitMapBuffer = resultContainer.NullBitMapBuffers.GetOrCreateMutable(b).Span;
for (int i = 0; i < sourceBuffer.Length; i++) for (int i = 0; i < sourceBuffer.Length; i++)
{ {
bool isValid = BitUtility.IsValid(sourceNullBitMap, i); bool isValid = BitUtility.IsValid(sourceNullBitMap, i);
TResult? value = func(isValid ? sourceBuffer[i] : null); TResult? value = func(isValid ? sourceBuffer[i] : null);
mutableResultBuffer[i] = value.GetValueOrDefault(); mutableResultBuffer[i] = value.GetValueOrDefault();
resultContainer.SetValidityBit(mutableResultNullBitMapBuffers, i, value != null); //Actually there is a bug in the previouse line. This code will not work correctly with containers having more than 1 buffers
//As buffer size for type T (sourceBuffer) is different from the size of buffer for type TResult (mutableResultBuffer) in case sizeof(T) not equal to sizeof(TResult)
//TODO fix (https://github.com/dotnet/machinelearning/issues/7122)
resultContainer.SetValidityBit(mutableResultNullBitMapBuffer, i, value != null);
} }
} }
} }
public void FillNulls(T value)
{
for (int b = 0; b < Buffers.Count; b++)
{
var span = Buffers.GetOrCreateMutable(b).Span;
var validitySpan = NullBitMapBuffers.GetOrCreateMutable(b).Span;
for (int i = 0; i < span.Length; i++)
{
if (BitUtility.IsValid(validitySpan, i))
continue;
span[i] = value;
BitUtility.SetBit(validitySpan, i, true);
}
}
NullCount = 0;
}
public bool IsValid(long index) => NullCount == 0 || GetValidityBit(index); public bool IsValid(long index) => NullCount == 0 || GetValidityBit(index);
private byte SetBit(byte curBitMap, int index, bool value) private byte SetBit(byte curBitMap, int index, bool value)

Просмотреть файл

@ -252,6 +252,7 @@ namespace Microsoft.Data.Analysis
set => _columnContainer[rowIndex] = value; set => _columnContainer[rowIndex] = value;
} }
/// <inheritdoc/>
public override double Median() public override double Median()
{ {
// Not the most efficient implementation. Using a selection algorithm here would be O(n) instead of O(nLogn) // Not the most efficient implementation. Using a selection algorithm here would be O(n) instead of O(nLogn)
@ -271,6 +272,7 @@ namespace Microsoft.Data.Analysis
return middleValue; return middleValue;
} }
/// <inheritdoc/>
public override double Mean() public override double Mean()
{ {
if (Length == 0) if (Length == 0)
@ -296,6 +298,7 @@ namespace Microsoft.Data.Analysis
Length += count; Length += count;
} }
/// <inheritdoc/>
public override long NullCount public override long NullCount
{ {
get get
@ -305,35 +308,40 @@ namespace Microsoft.Data.Analysis
} }
} }
public bool IsValid(long index) => _columnContainer.IsValid(index); /// <inheritdoc/>
public override bool IsValid(long index) => _columnContainer.IsValid(index);
public IEnumerator<T?> GetEnumerator() => _columnContainer.GetEnumerator(); public IEnumerator<T?> GetEnumerator() => _columnContainer.GetEnumerator();
protected override IEnumerator GetEnumeratorCore() => GetEnumerator(); protected override IEnumerator GetEnumeratorCore() => GetEnumerator();
/// <inheritdoc/>
public override bool IsNumericColumn() public override bool IsNumericColumn()
{ {
bool ret = true; var type = typeof(T);
if (typeof(T) == typeof(char) || typeof(T) == typeof(bool) || typeof(T) == typeof(DateTime))
ret = false; return type == typeof(byte)
return ret; || type == typeof(sbyte)
|| type == typeof(ushort)
|| type == typeof(short)
|| type == typeof(uint)
|| type == typeof(int)
|| type == typeof(ulong)
|| type == typeof(long)
|| type == typeof(float)
|| type == typeof(double)
|| type == typeof(decimal);
} }
/// <summary> /// <summary>
/// Returns a new column with nulls replaced by value /// Returns a new column with <see langword="null" /> elements replaced by <paramref name="value"/>.
/// </summary> /// </summary>
/// <param name="value"></param> /// <param name="value"></param>
/// <param name="inPlace">Indicates if the operation should be performed in place</param> /// <param name="inPlace">Indicates if the operation should be performed in place.</param>
public PrimitiveDataFrameColumn<T> FillNulls(T value, bool inPlace = false) public PrimitiveDataFrameColumn<T> FillNulls(T value, bool inPlace = false)
{ {
PrimitiveDataFrameColumn<T> column = inPlace ? this : Clone(); PrimitiveDataFrameColumn<T> column = inPlace ? this : Clone();
column.ApplyElementwise((T? columnValue, long index) => column.ColumnContainer.FillNulls(value);
{
if (columnValue.HasValue == false)
return value;
else
return columnValue.Value;
});
return column; return column;
} }
@ -343,6 +351,33 @@ namespace Microsoft.Data.Analysis
return FillNulls(convertedValue, inPlace); return FillNulls(convertedValue, inPlace);
} }
/// <inheritdoc/>
public new PrimitiveDataFrameColumn<T> DropNulls()
{
return (PrimitiveDataFrameColumn<T>)DropNullsImplementation();
}
protected override DataFrameColumn DropNullsImplementation()
{
var ret = CreateNewColumn(Name, Length - NullCount);
long j = 0;
for (int b = 0; b < ColumnContainer.NullBitMapBuffers.Count; b++)
{
var span = ColumnContainer.Buffers[b].ReadOnlySpan;
var validitySpan = ColumnContainer.NullBitMapBuffers[b].ReadOnlySpan;
for (int i = 0; i < span.Length; i++)
{
if (BitUtility.IsValid(validitySpan, i))
ret[j++] = span[i];
}
}
return ret;
}
/// <inheritdoc/>
public override DataFrame ValueCounts() public override DataFrame ValueCounts()
{ {
Dictionary<T, ICollection<long>> groupedValues = GroupColumnValues<T>(out HashSet<long> _); Dictionary<T, ICollection<long>> groupedValues = GroupColumnValues<T>(out HashSet<long> _);
@ -427,6 +462,7 @@ namespace Microsoft.Data.Analysis
if (boolColumn.Length > Length) if (boolColumn.Length > Length)
throw new ArgumentException(Strings.MapIndicesExceedsColumnLength, nameof(boolColumn)); throw new ArgumentException(Strings.MapIndicesExceedsColumnLength, nameof(boolColumn));
PrimitiveDataFrameColumn<T> ret = CreateNewColumn(Name); PrimitiveDataFrameColumn<T> ret = CreateNewColumn(Name);
for (long i = 0; i < boolColumn.Length; i++) for (long i = 0; i < boolColumn.Length; i++)
{ {
bool? value = boolColumn[i]; bool? value = boolColumn[i];
@ -615,14 +651,33 @@ namespace Microsoft.Data.Analysis
} }
} }
/// <summary>
/// Applies a function to all column values in place.
/// </summary>
/// <param name="func">The function to apply</param>
[Obsolete("Method is obsolete, use Apply(Func<T, T> func, bool inPlace = false) instead")]
public void ApplyElementwise(Func<T?, long, T?> func) => _columnContainer.ApplyElementwise(func); public void ApplyElementwise(Func<T?, long, T?> func) => _columnContainer.ApplyElementwise(func);
/// <summary> /// <summary>
/// Applies a function to all the values /// Applies a function to all values in the column, that are not null.
/// </summary>
/// <param name="func">The function to apply.</param>
/// /// <param name="inPlace">A boolean flag to indicate if the operation should be in place.</param>
/// <returns>A new <see cref="PrimitiveDataFrameColumn{T}"/> if <paramref name="inPlace"/> is not set. Returns this column otherwise.</returns>
public PrimitiveDataFrameColumn<T> Apply(Func<T, T> func, bool inPlace = false)
{
var column = inPlace ? this : this.Clone();
column.ColumnContainer.Apply(func);
return column;
}
/// <summary>
/// Applies a function to all column values.
/// </summary> /// </summary>
/// <typeparam name="TResult">The new column's type</typeparam> /// <typeparam name="TResult">The new column's type</typeparam>
/// <param name="func">The function to apply</param> /// <param name="func">The function to apply</param>
/// <returns>A new PrimitiveDataFrameColumn containing the new values</returns> /// <returns>A new PrimitiveDataFrameColumn containing the new values</returns>
[Obsolete("Method is obsolete, use Apply(Func<T, T> func, bool inPlace = false) instead")]
public PrimitiveDataFrameColumn<TResult> Apply<TResult>(Func<T?, TResult?> func) where TResult : unmanaged public PrimitiveDataFrameColumn<TResult> Apply<TResult>(Func<T?, TResult?> func) where TResult : unmanaged
{ {
var resultColumn = new PrimitiveDataFrameColumn<TResult>("Result", Length); var resultColumn = new PrimitiveDataFrameColumn<TResult>("Result", Length);

Просмотреть файл

@ -78,7 +78,7 @@
The following files are compressed using the DeflateStream and embedded as resources in the assembly. The following files are compressed using the DeflateStream and embedded as resources in the assembly.
The files are downloaded from the following sources and compressed to the Destination. The files are downloaded from the following sources and compressed to the Destination.
1. cl100k_base.tiktoken: https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken 1. cl100k_base.tiktoken: https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken
2. gpt2.tiktoken: https://pythia.blob.core.windows.net/public/encoding/gpt2.tiktoken 2. gpt2.tiktoken: https://fossies.org/linux/misc/whisper-20231117.tar.gz/whisper-20231117/whisper/assets/gpt2.tiktoken?m=b
3. p50k_base.tiktoken: https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken 3. p50k_base.tiktoken: https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken
4. r50k_base.tiktoken: https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken 4. r50k_base.tiktoken: https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken
5. o200k_base.tiktoken https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken 5. o200k_base.tiktoken https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken

Просмотреть файл

@ -1151,7 +1151,7 @@ namespace Microsoft.ML.Tokenizers
private const string Cl100kBaseVocabFile = "cl100k_base.tiktoken.deflate"; // "https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken" private const string Cl100kBaseVocabFile = "cl100k_base.tiktoken.deflate"; // "https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken"
private const string P50RanksFile = "p50k_base.tiktoken.deflate"; // "https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken" private const string P50RanksFile = "p50k_base.tiktoken.deflate"; // "https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken"
private const string R50RanksFile = "r50k_base.tiktoken.deflate"; // "https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken" private const string R50RanksFile = "r50k_base.tiktoken.deflate"; // "https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken"
private const string GPT2File = "gpt2.tiktoken.deflate"; // "https://pythia.blob.core.windows.net/public/encoding/gpt2.tiktoken" private const string GPT2File = "gpt2.tiktoken.deflate"; // "https://fossies.org/linux/misc/whisper-20231117.tar.gz/whisper-20231117/whisper/assets/gpt2.tiktoken?m=b"
private const string O200kBaseFile = "o200k_base.tiktoken.deflate"; // "https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken" private const string O200kBaseFile = "o200k_base.tiktoken.deflate"; // "https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken"
internal const string Cl100kBaseEncodingName = "cl100k_base"; internal const string Cl100kBaseEncodingName = "cl100k_base";

Просмотреть файл

@ -781,9 +781,12 @@ namespace Microsoft.Data.Analysis.Tests
[Fact] [Fact]
public void TestDropNulls() public void TestDropNulls()
{ {
//Create dataframe with 20 rows, where 1 row has only 1 null value and 1 row has all null values
DataFrame df = MakeDataFrameWithAllMutableColumnTypes(20); DataFrame df = MakeDataFrameWithAllMutableColumnTypes(20);
df[0, 0] = null;
DataFrame anyNulls = df.DropNulls(); DataFrame anyNulls = df.DropNulls();
Assert.Equal(19, anyNulls.Rows.Count); Assert.Equal(18, anyNulls.Rows.Count);
DataFrame allNulls = df.DropNulls(DropNullOptions.All); DataFrame allNulls = df.DropNulls(DropNullOptions.All);
Assert.Equal(19, allNulls.Rows.Count); Assert.Equal(19, allNulls.Rows.Count);
@ -859,6 +862,7 @@ namespace Microsoft.Data.Analysis.Tests
Assert.Equal((long)5, valueCounts.Columns["Counts"][1]); Assert.Equal((long)5, valueCounts.Columns["Counts"][1]);
} }
#pragma warning disable CS0612, CS0618 // Type or member is obsolete
[Fact] [Fact]
public void TestApplyElementwiseNullCount() public void TestApplyElementwiseNullCount()
{ {
@ -867,12 +871,14 @@ namespace Microsoft.Data.Analysis.Tests
Assert.Equal(1, column.NullCount); Assert.Equal(1, column.NullCount);
// Change all existing values to null // Change all existing values to null
column.ApplyElementwise((int? value, long rowIndex) => column.ApplyElementwise((int? value, long rowIndex) =>
{ {
if (!(value is null)) if (!(value is null))
return null; return null;
return value; return value;
}); });
Assert.Equal(column.Length, column.NullCount); Assert.Equal(column.Length, column.NullCount);
// Don't change null values // Don't change null values
@ -897,6 +903,7 @@ namespace Microsoft.Data.Analysis.Tests
Assert.Equal(0, column.NullCount); Assert.Equal(0, column.NullCount);
} }
#pragma warning restore CS0612, CS0618 // Type or member is obsolete
[Theory] [Theory]
[InlineData(10, 5)] [InlineData(10, 5)]
@ -1207,10 +1214,12 @@ namespace Microsoft.Data.Analysis.Tests
} }
[Fact] [Fact]
#pragma warning disable CS0612, CS0618 // Type or member is obsolete
public void TestApply() public void TestApply()
{ {
int[] values = { 1, 2, 3, 4, 5 }; int[] values = { 1, 2, 3, 4, 5 };
var col = new Int32DataFrameColumn("Ints", values); var col = new Int32DataFrameColumn("Ints", values);
PrimitiveDataFrameColumn<double> newCol = col.Apply(i => i + 0.5d); PrimitiveDataFrameColumn<double> newCol = col.Apply(i => i + 0.5d);
Assert.Equal(values.Length, newCol.Length); Assert.Equal(values.Length, newCol.Length);
@ -1221,6 +1230,7 @@ namespace Microsoft.Data.Analysis.Tests
Assert.Equal(newCol[i], values[i] + 0.5d); Assert.Equal(newCol[i], values[i] + 0.5d);
} }
} }
#pragma warning disable CS0612, CS0618 // Type or member is obsolete
[Fact] [Fact]
public void TestDataFrameCreate() public void TestDataFrameCreate()

Просмотреть файл

@ -3,10 +3,7 @@
// See the LICENSE file in the project root for more information. // See the LICENSE file in the project root for more information.
using System; using System;
using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Text;
using Apache.Arrow;
using Microsoft.ML.TestFramework.Attributes; using Microsoft.ML.TestFramework.Attributes;
using Xunit; using Xunit;
@ -269,21 +266,6 @@ namespace Microsoft.Data.Analysis.Tests
Assert.Equal(intColumn[i], copy[i]); Assert.Equal(intColumn[i], copy[i]);
} }
[Fact]
public void TestClone_StringColumn()
{
var strColumn = new StringDataFrameColumn("Str", ["str1", "str2", "srt3", null]);
var copy = strColumn.Clone();
Assert.Equal(strColumn.Name, copy.Name);
Assert.Equal(strColumn.Length, copy.Length);
Assert.Equal(strColumn.DataType, copy.DataType);
Assert.Equal(strColumn.NullCount, copy.NullCount);
for (int i = 0; i < strColumn.Length; i++)
Assert.Equal(strColumn[i], copy[i]);
}
[Fact] [Fact]
public void TestNotNullableColumnClone() public void TestNotNullableColumnClone()
{ {
@ -326,7 +308,7 @@ namespace Microsoft.Data.Analysis.Tests
} }
[Fact] [Fact]
public void TestNotNullableColumnCloneWithIndicesMap() public void TestNotNullableColumnClone_WithIntIndicesMap()
{ {
//Arrange //Arrange
var column = new Int32DataFrameColumn("Int column", values: new[] { 0, 5, 2, 4, 1, 3 }); var column = new Int32DataFrameColumn("Int column", values: new[] { 0, 5, 2, 4, 1, 3 });
@ -346,6 +328,69 @@ namespace Microsoft.Data.Analysis.Tests
Assert.Equal(column[indicesMap[i].Value], clonedColumn[i]); Assert.Equal(column[indicesMap[i].Value], clonedColumn[i]);
} }
[Fact]
public void TestNotNullableColumnClone_WithIntIndicesMap_Invert()
{
//Arrange
var column = new Int32DataFrameColumn("Int column", values: new int?[] { 0, 5, null, 4, 1, 3 });
var indicesMap = new Int32DataFrameColumn("Indices", new[] { 0, 1, 2, 2, 5, 3, 4 });
//Act
var clonedColumn = column.Clone(indicesMap, true);
//Assert
Assert.NotSame(column, clonedColumn);
Assert.Equal(column.Name, clonedColumn.Name);
Assert.Equal(column.DataType, clonedColumn.DataType);
Assert.Equal(2, clonedColumn.NullCount);
Assert.Equal(indicesMap.Length, clonedColumn.Length);
for (int i = 0; i < indicesMap.Length; i++)
Assert.Equal(column[indicesMap[indicesMap.Length - 1 - i].Value], clonedColumn[i]);
}
[Fact]
public void TestNotNullableColumnClone_WithLongIndicesMap()
{
//Arrange
var column = new Int32DataFrameColumn("Int column", values: new[] { 0, 5, 2, 4, 1, 3 });
var indicesMap = new Int64DataFrameColumn("Indices", new long[] { 0, 1, 2, 5, 3, 4 });
//Act
var clonedColumn = column.Clone(indicesMap);
//Assert
Assert.NotSame(column, clonedColumn);
Assert.Equal(column.Name, clonedColumn.Name);
Assert.Equal(column.DataType, clonedColumn.DataType);
Assert.Equal(column.NullCount, clonedColumn.NullCount);
Assert.Equal(indicesMap.Length, clonedColumn.Length);
for (int i = 0; i < indicesMap.Length; i++)
Assert.Equal(column[indicesMap[i].Value], clonedColumn[i]);
}
[Fact]
public void TestNotNullableColumnClone_WithLongIndicesMap_Invert()
{
//Arrange
var column = new Int32DataFrameColumn("Int column", values: new int?[] { 0, 5, null, 4, 1, 3 });
var indicesMap = new Int64DataFrameColumn("Indices", new long[] { 0, 1, 2, 5, 3, 4, 4, 2 });
//Act
var clonedColumn = column.Clone(indicesMap, true);
//Assert
Assert.NotSame(column, clonedColumn);
Assert.Equal(column.Name, clonedColumn.Name);
Assert.Equal(column.DataType, clonedColumn.DataType);
Assert.Equal(2, clonedColumn.NullCount);
Assert.Equal(indicesMap.Length, clonedColumn.Length);
for (int i = 0; i < indicesMap.Length; i++)
Assert.Equal(column[indicesMap[indicesMap.Length - 1 - i].Value], clonedColumn[i]);
}
[Fact] [Fact]
public void TestNotNullableColumnCloneWithIndicesMapAsEnumerableLong() public void TestNotNullableColumnCloneWithIndicesMapAsEnumerableLong()
{ {
@ -580,6 +625,41 @@ namespace Microsoft.Data.Analysis.Tests
Assert.Null(div[3]); // null / null Assert.Null(div[3]); // null / null
} }
[Fact]
public void TestApply_InPlace()
{
// Arrange
var column = new Int32DataFrameColumn("int", new int?[] { 0, 1, 2, null, null, 5 });
column.Apply(x => x * 2, true);
// Assert
Assert.Equal(0, column[0]);
Assert.Equal(2, column[1]);
Assert.Equal(4, column[2]);
Assert.Null(column[3]);
Assert.Null(column[4]);
Assert.Equal(10, column[5]);
}
[Fact]
public void TestDropNulls()
{
// Arrange
var column = new Int32DataFrameColumn("int", new int?[] { null, 0, 1, 2, null, null, 3, null });
var res = column.DropNulls();
// Assert
Assert.Equal(4, res.Length);
Assert.Equal(0, res.NullCount);
Assert.Equal(0, res[0]);
Assert.Equal(1, res[1]);
Assert.Equal(2, res[2]);
Assert.Equal(3, res[3]);
}
//#if !NETFRAMEWORK // https://github.com/dotnet/corefxlab/issues/2796 //#if !NETFRAMEWORK // https://github.com/dotnet/corefxlab/issues/2796
// [Fact] // [Fact]
// public void TestPrimitiveColumnGetReadOnlyBuffers() // public void TestPrimitiveColumnGetReadOnlyBuffers()

Просмотреть файл

@ -0,0 +1,95 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Xunit;
namespace Microsoft.Data.Analysis.Tests
{
public class StringDataFrameColumnTests
{
[Fact]
public void TestColumnClone()
{
var stringColumn = new StringDataFrameColumn("Test", new[] { "Zero", "One", "Two", null, "Four", "Five" });
var clonedColumn = stringColumn.Clone();
Assert.NotSame(stringColumn, clonedColumn);
Assert.Equal(stringColumn.Name, clonedColumn.Name);
Assert.Equal(stringColumn.Length, clonedColumn.Length);
Assert.Equal(stringColumn.NullCount, clonedColumn.NullCount);
for (int i = 0; i < stringColumn.Length; i++)
Assert.Equal(stringColumn[i], clonedColumn[i]);
}
[Fact]
public void TestColumnClone_WithIntMapIndices()
{
var mapIndices = new[] { 0, 1, 2, 2, 3, 4, 5 };
var stringColumn = new StringDataFrameColumn("Test", ["Zero", "One", null, "Three", "Four", "Five"]);
var clonedColumn = stringColumn.Clone(new Int32DataFrameColumn("Map Indices", mapIndices));
Assert.NotSame(stringColumn, clonedColumn);
Assert.Equal(stringColumn.Name, clonedColumn.Name);
Assert.Equal(mapIndices.Length, clonedColumn.Length);
Assert.Equal(2, clonedColumn.NullCount);
for (int i = 0; i < mapIndices.Length; i++)
Assert.Equal(stringColumn[mapIndices[i]], clonedColumn[i]);
}
[Fact]
public void TestColumnClone_WithIntMapIndices_InvertIndices()
{
var mapIndices = new[] { 0, 1, 2, 2, 3, 4, 5 };
var stringColumn = new StringDataFrameColumn("Test", ["Zero", "One", null, "Three", "Four", "Five"]);
var clonedColumn = stringColumn.Clone(new Int32DataFrameColumn("Map Indices", mapIndices), true);
Assert.NotSame(stringColumn, clonedColumn);
Assert.Equal(stringColumn.Name, clonedColumn.Name);
Assert.Equal(mapIndices.Length, clonedColumn.Length);
Assert.Equal(2, clonedColumn.NullCount);
for (int i = 0; i < mapIndices.Length; i++)
Assert.Equal(stringColumn[mapIndices[mapIndices.Length - 1 - i]], clonedColumn[i]);
}
[Fact]
public void TestColumnClone_WithLongMapIndices()
{
var mapIndices = new long[] { 0, 1, 2, 2, 3, 4, 5 };
var stringColumn = new StringDataFrameColumn("Test", ["Zero", "One", null, "Three", "Four", "Five"]);
var clonedColumn = stringColumn.Clone(new Int64DataFrameColumn("Map Indices", mapIndices));
Assert.NotSame(stringColumn, clonedColumn);
Assert.Equal(stringColumn.Name, clonedColumn.Name);
Assert.Equal(mapIndices.Length, clonedColumn.Length);
Assert.Equal(2, clonedColumn.NullCount);
for (int i = 0; i < mapIndices.Length; i++)
Assert.Equal(stringColumn[mapIndices[i]], clonedColumn[i]);
}
[Fact]
public void TestColumnClone_WithLongMapIndices_InvertIndices()
{
var mapIndices = new long[] { 0, 1, 2, 2, 3, 4, 5 };
var stringColumn = new StringDataFrameColumn("Test", ["Zero", "One", "Two", null, "Four", "Five"]);
var clonedColumn = stringColumn.Clone(new Int64DataFrameColumn("Map Indices", mapIndices), true);
Assert.Equal(1, clonedColumn.NullCount);
Assert.NotSame(stringColumn, clonedColumn);
Assert.Equal(stringColumn.Name, clonedColumn.Name);
Assert.Equal(mapIndices.Length, clonedColumn.Length);
for (int i = 0; i < mapIndices.Length; i++)
Assert.Equal(stringColumn[mapIndices[mapIndices.Length - 1 - i]], clonedColumn[i]);
}
}
}

Просмотреть файл

@ -96,7 +96,7 @@ namespace Microsoft.ML.Tokenizers.Tests
public static IEnumerable<object[]> ModelUrlData() public static IEnumerable<object[]> ModelUrlData()
{ {
yield return new object[] { GPT4, @"https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken" }; yield return new object[] { GPT4, @"https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken" };
yield return new object[] { GPT2, @"https://pythia.blob.core.windows.net/public/encoding/gpt2.tiktoken" }; yield return new object[] { GPT2, @"https://fossies.org/linux/misc/whisper-20231117.tar.gz/whisper-20231117/whisper/assets/gpt2.tiktoken?m=b" };
yield return new object[] { P50kBase, @"https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken" }; yield return new object[] { P50kBase, @"https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken" };
yield return new object[] { R50kBase, @"https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken" }; yield return new object[] { R50kBase, @"https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken" };
yield return new object[] { GPT4o, @"https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken" }; yield return new object[] { GPT4o, @"https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken" };