Fix saving CSV with VBufferDataFrameColumn (#6860)

This commit is contained in:
Aleksei Smirnov 2023-10-14 08:23:43 +03:00 коммит произвёл GitHub
Родитель e3ec250d51
Коммит 9c183fc35b
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
4 изменённых файлов: 106 добавлений и 68 удалений

Просмотреть файл

@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.
using System;
using System.Collections;
using System.Collections.Generic;
using System.Data;
using System.Data.Common;
@ -11,6 +12,7 @@ using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.ML.Data;
namespace Microsoft.Data.Analysis
{
@ -675,58 +677,7 @@ namespace Microsoft.Data.Analysis
foreach (var row in dataFrame.Rows)
{
bool firstCell = true;
foreach (var cell in row)
{
if (!firstCell)
{
record.Append(separator);
}
else
{
firstCell = false;
}
Type t = cell?.GetType();
if (t == typeof(bool))
{
record.AppendFormat(cultureInfo, "{0}", cell);
continue;
}
if (t == typeof(float))
{
record.AppendFormat(cultureInfo, "{0:G9}", cell);
continue;
}
if (t == typeof(double))
{
record.AppendFormat(cultureInfo, "{0:G17}", cell);
continue;
}
if (t == typeof(decimal))
{
record.AppendFormat(cultureInfo, "{0:G31}", cell);
continue;
}
if (t == typeof(string))
{
string stringCell = (string)cell;
if (NeedsQuotes(stringCell, separator))
{
record.Append('\"');
record.Append(stringCell.Replace("\"", "\"\"")); // Quotations in CSV data must be escaped with another quotation
record.Append('\"');
continue;
}
}
record.Append(cell);
}
AppendValuesToRecord(record, row, separator, cultureInfo);
csvFile.WriteLine(record);
@ -736,6 +687,54 @@ namespace Microsoft.Data.Analysis
}
}
/// <summary>
/// Appends the textual form of each item in <paramref name="values"/> to <paramref name="record"/>,
/// placing <paramref name="separator"/> between consecutive items.
/// </summary>
/// <param name="record">The builder that receives the formatted values.</param>
/// <param name="values">The values to write; items may be <see langword="null"/>, in which case nothing is written for them.</param>
/// <param name="separator">The character written between values and used to decide whether a field needs quoting.</param>
/// <param name="cultureInfo">The culture used to format <see cref="bool"/>, <see cref="float"/>, <see cref="double"/> and <see cref="decimal"/> values.</param>
private static void AppendValuesToRecord(StringBuilder record, IEnumerable values, char separator, CultureInfo cultureInfo)
{
    bool firstCell = true;
    foreach (var value in values)
    {
        if (firstCell)
        {
            firstCell = false;
        }
        else
        {
            record.Append(separator);
        }

        switch (value)
        {
            case bool:
                record.AppendFormat(cultureInfo, "{0}", value);
                break;
            case float:
                // G9 is the precision .NET documents as sufficient to round-trip a Single.
                record.AppendFormat(cultureInfo, "{0:G9}", value);
                break;
            case double:
                // G17 is the precision .NET documents as sufficient to round-trip a Double.
                record.AppendFormat(cultureInfo, "{0:G17}", value);
                break;
            case decimal:
                record.AppendFormat(cultureInfo, "{0:G31}", value);
                break;
            case string stringCell:
                AppendQuotedIfNeeded(record, stringCell, separator);
                break;
            case IEnumerable nestedValues:
                // Render the nested sequence as "(v1 v2 ...)" into a scratch builder first, so the
                // finished text can be quoted and escaped as ONE field. Writing it straight into the
                // record would let a separator or quote inside an element corrupt the row.
                var nestedCell = new StringBuilder("(");
                AppendValuesToRecord(nestedCell, nestedValues, ' ', cultureInfo);
                nestedCell.Append(')');
                AppendQuotedIfNeeded(record, nestedCell.ToString(), separator);
                break;
            default:
                // Covers null (appends nothing) and every other type (appends its ToString()).
                record.Append(value);
                break;
        }
    }
}

/// <summary>
/// Appends <paramref name="field"/> to <paramref name="record"/>, wrapping it in double quotes and
/// doubling embedded double quotes when <c>NeedsQuotes</c> reports that quoting is required.
/// </summary>
private static void AppendQuotedIfNeeded(StringBuilder record, string field, char separator)
{
    if (NeedsQuotes(field, separator))
    {
        record.Append('\"');
        record.Append(field.Replace("\"", "\"\"")); // Quotations in CSV data must be escaped with another quotation
        record.Append('\"');
    }
    else
    {
        record.Append(field);
    }
}
private static void SaveHeader(StreamWriter csvFile, IReadOnlyList<string> columnNames, char separator)
{
bool firstColumn = true;

Просмотреть файл

@ -231,7 +231,7 @@ namespace Microsoft.Data.Analysis
return new PrimitiveDataFrameColumn<T>(name, length);
}
internal T? GetTypedValue(long rowIndex) => _columnContainer[rowIndex];
/// <summary>
/// Reads the element stored in the column container at <paramref name="rowIndex"/>.
/// </summary>
/// <param name="rowIndex">Index of the row to read.</param>
/// <returns>The container's value at that index, typed as <c>T?</c>.</returns>
protected T? GetTypedValue(long rowIndex)
{
    return _columnContainer[rowIndex];
}
/// <summary>
/// Returns the value at <paramref name="rowIndex"/> as an <see cref="object"/> by delegating to
/// <c>GetTypedValue</c>.
/// </summary>
/// <param name="rowIndex">Index of the row to read.</param>
protected override object GetValue(long rowIndex)
{
    return GetTypedValue(rowIndex);
}

Просмотреть файл

@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using Microsoft.ML.Internal.DataView;
@ -27,7 +28,7 @@ namespace Microsoft.ML.Data
/// a value is sufficient to make a completely independent copy of it. So, for example, this means that a buffer of
/// buffers is not possible. But, things like <see cref="int"/>, <see cref="float"/>, and <see
/// cref="ReadOnlyMemory{Char}"/>, are totally fine.</typeparam>
public readonly struct VBuffer<T>
public readonly struct VBuffer<T> : IEnumerable
{
/// <summary>
/// The internal re-usable array of values.
@ -403,6 +404,14 @@ namespace Microsoft.ML.Data
public override string ToString()
=> IsDense ? $"Dense vector of size {Length}" : $"Sparse vector of size {Length}, {_count} explicit values";
/// <summary>
/// Returns an enumerator that iterates through the explicitly stored values in VBuffer.
/// </summary>
/// <remarks>
/// Only the first <c>_count</c> entries of the backing array are produced. The backing array is
/// re-usable, so it must not be enumerated directly: it can be <see langword="null"/> (for example
/// in <c>default(VBuffer&lt;T&gt;)</c>), and it may hold slots beyond the explicit values.
/// For a sparse buffer this yields the explicit values only, not the implicit default entries.
/// </remarks>
/// <returns>A non-generic enumerator over the explicit values; each value is boxed as it is yielded.</returns>
public IEnumerator GetEnumerator()
{
    // Bounding by _count (rather than the array length) also makes the empty/null-array case a
    // no-op instead of a NullReferenceException.
    for (int i = 0; i < _count; i++)
    {
        yield return _values[i];
    }
}
internal VBufferEditor<T> GetEditor()
{
return GetEditor(Length, _count);

Просмотреть файл

@ -15,6 +15,7 @@ using System.Data.SQLite.EF6;
using Xunit;
using Microsoft.ML.TestFramework.Attributes;
using System.Threading;
using Microsoft.ML.Data;
namespace Microsoft.Data.Analysis.Tests
{
@ -273,7 +274,7 @@ MT";
[Theory]
[InlineData(false)]
[InlineData(true)]
public void TestReadCsvNoHeader(bool useQuotes)
public void TestLoadCsvNoHeader(bool useQuotes)
{
string CMT = useQuotes ? @"""C,MT""" : "CMT";
string verifyCMT = useQuotes ? "C,MT" : "CMT";
@ -349,7 +350,7 @@ MT";
[InlineData(false, 0)]
[InlineData(true, 10)]
[InlineData(false, 10)]
public void TestReadCsvWithTypesAndGuessRows(bool header, int guessRows)
public void TestLoadCsvWithTypesAndGuessRows(bool header, int guessRows)
{
/* Tests this matrix
*
@ -472,7 +473,7 @@ CMT,1,1,181,0.6,CSH,4.5";
}
[Fact]
public void TestReadCsvWithTypesDateTime()
public void TestLoadCsvWithTypesDateTime()
{
string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs,trip_distance,payment_type,fare_amount,date
CMT,1,1,1271,3.8,CRD,17.5,1-june-2020
@ -549,7 +550,7 @@ CMT,1,1,181,0.6,CSH,4.5,4-june-2020";
}
[Fact]
public void TestReadCsvWithPipeSeparator()
public void TestLoadCsvWithPipeSeparator()
{
string data = @"vendor_id|rate_code|passenger_count|trip_time_in_secs|trip_distance|payment_type|fare_amount
CMT|1|1|1271|3.8|CRD|17.5
@ -588,7 +589,7 @@ CMT|1|1|181|0.6|CSH|4.5";
}
[Fact]
public void TestReadCsvWithSemicolonSeparator()
public void TestLoadCsvWithSemicolonSeparator()
{
string data = @"vendor_id;rate_code;passenger_count;trip_time_in_secs;trip_distance;payment_type;fare_amount
CMT;1;1;1271;3.8;CRD;17.5
@ -627,7 +628,7 @@ CMT;1;1;181;0.6;CSH;4.5";
}
[Fact]
public void TestReadCsvWithExtraColumnInHeader()
public void TestLoadCsvWithExtraColumnInHeader()
{
string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs,trip_distance,payment_type,fare_amount,extra
CMT,1,1,1271,3.8,CRD,17.5
@ -656,7 +657,7 @@ CMT,1,1,181,0.6,CSH,4.5";
}
[Fact]
public void TestReadCsvWithMultipleEmptyColumnNameInHeaderWithoutGivenColumn()
public void TestLoadCsvWithMultipleEmptyColumnNameInHeaderWithoutGivenColumn()
{
string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs,,,,
CMT,1,1,1271,3.8,CRD,17.5,0
@ -671,7 +672,7 @@ CMT,1,1,181,0.6,CSH,4.5,0";
}
[Fact]
public void TestReadCsvWithMultipleEmptyColumnNameInHeaderWithGivenColumn()
public void TestLoadCsvWithMultipleEmptyColumnNameInHeaderWithGivenColumn()
{
string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs,,,,
CMT,1,1,1271,3.8,CRD,17.5,0
@ -713,7 +714,7 @@ CMT,1,1,181,0.6,CSH,4.5,0";
}
[Fact]
public void TestReadCsvWithExtraColumnInRow()
public void TestLoadCsvWithExtraColumnInRow()
{
string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs,trip_distance,payment_type,fare_amount
CMT,1,1,1271,3.8,CRD,17.5,0
@ -726,7 +727,7 @@ CMT,1,1,181,0.6,CSH,4.5,0";
}
[Fact]
public void TestReadCsvWithLessColumnsInRow()
public void TestLoadCsvWithLessColumnsInRow()
{
string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs,trip_distance,payment_type,fare_amount
CMT,1,1,1271,3.8,CRD
@ -755,7 +756,7 @@ CMT,1,1,181,0.6,CSH";
}
[Fact]
public void TestReadCsvWithAllNulls()
public void TestLoadCsvWithAllNulls()
{
string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs
null,null,null,null
@ -798,7 +799,7 @@ null,null,null,null";
}
[Fact]
public void TestReadCsvWithNullsAndDataTypes()
public void TestLoadCsvWithNullsAndDataTypes()
{
string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs
null,1,1,1271
@ -860,7 +861,7 @@ CMT,1,1,null";
}
[Fact]
public void TestReadCsvWithNulls()
public void TestLoadCsvWithNulls()
{
string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs
null,1,1,1271
@ -922,7 +923,36 @@ CMT,1,1,null";
}
[Fact]
public void TestWriteCsvWithHeader()
public void TestSaveCsvVBufferColumn()
{
    // Arrange: a frame with a single VBuffer<int> column holding three dense vectors.
    var column = new VBufferDataFrameColumn<int>("VBuffer", new[]
    {
        new VBuffer<int>(3, new int[] { 1, 2, 3 }),
        new VBuffer<int>(3, new int[] { 2, 3, 4 }),
        new VBuffer<int>(3, new int[] { 3, 4, 5 }),
    });
    var original = new DataFrame(column);

    // Act: save to an in-memory CSV, rewind, and load it back.
    using var stream = new MemoryStream();
    DataFrame.SaveCsv(original, stream);
    stream.Seek(0, SeekOrigin.Begin);
    DataFrame roundTripped = DataFrame.LoadCsv(stream);

    // Assert: the shape survives and each vector comes back as its "(v1 v2 v3)" text form.
    Assert.Equal(original.Rows.Count, roundTripped.Rows.Count);
    Assert.Equal(original.Columns.Count, roundTripped.Columns.Count);
    Assert.Equal(typeof(string), roundTripped.Columns[0].DataType);

    string[] expectedCells = { "(1 2 3)", "(2 3 4)", "(3 4 5)" };
    for (int row = 0; row < expectedCells.Length; row++)
    {
        Assert.Equal(expectedCells[row], roundTripped[row, 0]);
    }
}
[Fact]
public void TestSaveCsvWithHeader()
{
using MemoryStream csvStream = new MemoryStream();
DataFrame dataFrame = DataFrameTests.MakeDataFrameWithAllColumnTypes(10, true);