Fix Saving csv with VBufferDataFrameColumn (#6860)

This commit is contained in:
Aleksei Smirnov 2023-10-14 08:23:43 +03:00 коммит произвёл GitHub
Родитель e3ec250d51
Коммит 9c183fc35b
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
4 изменённых файлов: 106 добавлений и 68 удалений

Просмотреть файл

@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information. // See the LICENSE file in the project root for more information.
using System; using System;
using System.Collections;
using System.Collections.Generic; using System.Collections.Generic;
using System.Data; using System.Data;
using System.Data.Common; using System.Data.Common;
@ -11,6 +12,7 @@ using System.IO;
using System.Linq; using System.Linq;
using System.Text; using System.Text;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.ML.Data;
namespace Microsoft.Data.Analysis namespace Microsoft.Data.Analysis
{ {
@ -675,58 +677,7 @@ namespace Microsoft.Data.Analysis
foreach (var row in dataFrame.Rows) foreach (var row in dataFrame.Rows)
{ {
bool firstCell = true; AppendValuesToRecord(record, row, separator, cultureInfo);
foreach (var cell in row)
{
if (!firstCell)
{
record.Append(separator);
}
else
{
firstCell = false;
}
Type t = cell?.GetType();
if (t == typeof(bool))
{
record.AppendFormat(cultureInfo, "{0}", cell);
continue;
}
if (t == typeof(float))
{
record.AppendFormat(cultureInfo, "{0:G9}", cell);
continue;
}
if (t == typeof(double))
{
record.AppendFormat(cultureInfo, "{0:G17}", cell);
continue;
}
if (t == typeof(decimal))
{
record.AppendFormat(cultureInfo, "{0:G31}", cell);
continue;
}
if (t == typeof(string))
{
string stringCell = (string)cell;
if (NeedsQuotes(stringCell, separator))
{
record.Append('\"');
record.Append(stringCell.Replace("\"", "\"\"")); // Quotations in CSV data must be escaped with another quotation
record.Append('\"');
continue;
}
}
record.Append(cell);
}
csvFile.WriteLine(record); csvFile.WriteLine(record);
@ -736,6 +687,54 @@ namespace Microsoft.Data.Analysis
} }
} }
/// <summary>
/// Appends the textual form of every item in <paramref name="values"/> to <paramref name="record"/>,
/// placing <paramref name="separator"/> between consecutive items.
/// </summary>
/// <param name="record">The builder that receives the formatted values.</param>
/// <param name="values">The values to write; items that are themselves enumerable (other than strings) are written recursively.</param>
/// <param name="separator">The character written between consecutive values.</param>
/// <param name="cultureInfo">The format provider applied to every formatted value.</param>
private static void AppendValuesToRecord(StringBuilder record, IEnumerable values, char separator, CultureInfo cultureInfo)
{
    bool firstCell = true;
    foreach (var value in values)
    {
        // The separator goes between values, never before the first one.
        if (firstCell)
        {
            firstCell = false;
        }
        else
        {
            record.Append(separator);
        }

        switch (value)
        {
            case float:
                record.AppendFormat(cultureInfo, "{0:G9}", value);
                break;
            case double:
                record.AppendFormat(cultureInfo, "{0:G17}", value);
                break;
            case decimal:
                record.AppendFormat(cultureInfo, "{0:G31}", value);
                break;
            // Strings are matched before IEnumerable because a string is itself enumerable (as chars).
            case string stringCell when NeedsQuotes(stringCell, separator):
                record.Append('\"');
                record.Append(stringCell.Replace("\"", "\"\"")); // Quotations in CSV data must be escaped with another quotation
                record.Append('\"');
                break;
            case string stringCell:
                record.Append(stringCell);
                break;
            case IEnumerable nestedValues:
                // Nested sequences are written as "(v1 v2 ...)" using a space as the inner separator.
                record.Append("(");
                AppendValuesToRecord(record, nestedValues, ' ', cultureInfo);
                record.Append(")");
                break;
            default:
                // Covers bool, integers, dates, null and any other type. Formatting through
                // cultureInfo (instead of StringBuilder.Append(object), which uses the current
                // culture) keeps these cells consistent with the float/double/decimal cells.
                record.AppendFormat(cultureInfo, "{0}", value);
                break;
        }
    }
}
private static void SaveHeader(StreamWriter csvFile, IReadOnlyList<string> columnNames, char separator) private static void SaveHeader(StreamWriter csvFile, IReadOnlyList<string> columnNames, char separator)
{ {
bool firstColumn = true; bool firstColumn = true;

Просмотреть файл

@ -231,7 +231,7 @@ namespace Microsoft.Data.Analysis
return new PrimitiveDataFrameColumn<T>(name, length); return new PrimitiveDataFrameColumn<T>(name, length);
} }
internal T? GetTypedValue(long rowIndex) => _columnContainer[rowIndex]; protected T? GetTypedValue(long rowIndex) => _columnContainer[rowIndex];
protected override object GetValue(long rowIndex) => GetTypedValue(rowIndex); protected override object GetValue(long rowIndex) => GetTypedValue(rowIndex);

Просмотреть файл

@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information. // See the LICENSE file in the project root for more information.
using System; using System;
using System.Collections;
using System.Collections.Generic; using System.Collections.Generic;
using System.Diagnostics; using System.Diagnostics;
using Microsoft.ML.Internal.DataView; using Microsoft.ML.Internal.DataView;
@ -27,7 +28,7 @@ namespace Microsoft.ML.Data
/// a value is sufficient to make a completely independent copy of it. So, for example, this means that a buffer of /// a value is sufficient to make a completely independent copy of it. So, for example, this means that a buffer of
/// buffers is not possible. But, things like <see cref="int"/>, <see cref="float"/>, and <see /// buffers is not possible. But, things like <see cref="int"/>, <see cref="float"/>, and <see
/// cref="ReadOnlyMemory{Char}"/>, are totally fine.</typeparam> /// cref="ReadOnlyMemory{Char}"/>, are totally fine.</typeparam>
public readonly struct VBuffer<T> public readonly struct VBuffer<T> : IEnumerable
{ {
/// <summary> /// <summary>
/// The internal re-usable array of values. /// The internal re-usable array of values.
@ -403,6 +404,14 @@ namespace Microsoft.ML.Data
public override string ToString() public override string ToString()
=> IsDense ? $"Dense vector of size {Length}" : $"Sparse vector of size {Length}, {_count} explicit values"; => IsDense ? $"Dense vector of size {Length}" : $"Sparse vector of size {Length}, {_count} explicit values";
/// <summary>
/// Returns an enumerator that iterates through the values in VBuffer.
/// </summary>
/// <remarks>
/// The enumeration goes through <c>DenseValues()</c> rather than the internal values array:
/// that array is re-usable, so enumerating it directly could expose slots that are not part
/// of this buffer, would skip the implicit entries of a sparse buffer, and would throw when
/// the array is null.
/// </remarks>
/// <returns>An enumerator over the values of this buffer.</returns>
public IEnumerator GetEnumerator()
{
    return DenseValues().GetEnumerator();
}
internal VBufferEditor<T> GetEditor() internal VBufferEditor<T> GetEditor()
{ {
return GetEditor(Length, _count); return GetEditor(Length, _count);

Просмотреть файл

@ -15,6 +15,7 @@ using System.Data.SQLite.EF6;
using Xunit; using Xunit;
using Microsoft.ML.TestFramework.Attributes; using Microsoft.ML.TestFramework.Attributes;
using System.Threading; using System.Threading;
using Microsoft.ML.Data;
namespace Microsoft.Data.Analysis.Tests namespace Microsoft.Data.Analysis.Tests
{ {
@ -273,7 +274,7 @@ MT";
[Theory] [Theory]
[InlineData(false)] [InlineData(false)]
[InlineData(true)] [InlineData(true)]
public void TestReadCsvNoHeader(bool useQuotes) public void TestLoadCsvNoHeader(bool useQuotes)
{ {
string CMT = useQuotes ? @"""C,MT""" : "CMT"; string CMT = useQuotes ? @"""C,MT""" : "CMT";
string verifyCMT = useQuotes ? "C,MT" : "CMT"; string verifyCMT = useQuotes ? "C,MT" : "CMT";
@ -349,7 +350,7 @@ MT";
[InlineData(false, 0)] [InlineData(false, 0)]
[InlineData(true, 10)] [InlineData(true, 10)]
[InlineData(false, 10)] [InlineData(false, 10)]
public void TestReadCsvWithTypesAndGuessRows(bool header, int guessRows) public void TestLoadCsvWithTypesAndGuessRows(bool header, int guessRows)
{ {
/* Tests this matrix /* Tests this matrix
* *
@ -472,7 +473,7 @@ CMT,1,1,181,0.6,CSH,4.5";
} }
[Fact] [Fact]
public void TestReadCsvWithTypesDateTime() public void TestLoadCsvWithTypesDateTime()
{ {
string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs,trip_distance,payment_type,fare_amount,date string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs,trip_distance,payment_type,fare_amount,date
CMT,1,1,1271,3.8,CRD,17.5,1-june-2020 CMT,1,1,1271,3.8,CRD,17.5,1-june-2020
@ -549,7 +550,7 @@ CMT,1,1,181,0.6,CSH,4.5,4-june-2020";
} }
[Fact] [Fact]
public void TestReadCsvWithPipeSeparator() public void TestLoadCsvWithPipeSeparator()
{ {
string data = @"vendor_id|rate_code|passenger_count|trip_time_in_secs|trip_distance|payment_type|fare_amount string data = @"vendor_id|rate_code|passenger_count|trip_time_in_secs|trip_distance|payment_type|fare_amount
CMT|1|1|1271|3.8|CRD|17.5 CMT|1|1|1271|3.8|CRD|17.5
@ -588,7 +589,7 @@ CMT|1|1|181|0.6|CSH|4.5";
} }
[Fact] [Fact]
public void TestReadCsvWithSemicolonSeparator() public void TestLoadCsvWithSemicolonSeparator()
{ {
string data = @"vendor_id;rate_code;passenger_count;trip_time_in_secs;trip_distance;payment_type;fare_amount string data = @"vendor_id;rate_code;passenger_count;trip_time_in_secs;trip_distance;payment_type;fare_amount
CMT;1;1;1271;3.8;CRD;17.5 CMT;1;1;1271;3.8;CRD;17.5
@ -627,7 +628,7 @@ CMT;1;1;181;0.6;CSH;4.5";
} }
[Fact] [Fact]
public void TestReadCsvWithExtraColumnInHeader() public void TestLoadCsvWithExtraColumnInHeader()
{ {
string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs,trip_distance,payment_type,fare_amount,extra string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs,trip_distance,payment_type,fare_amount,extra
CMT,1,1,1271,3.8,CRD,17.5 CMT,1,1,1271,3.8,CRD,17.5
@ -656,7 +657,7 @@ CMT,1,1,181,0.6,CSH,4.5";
} }
[Fact] [Fact]
public void TestReadCsvWithMultipleEmptyColumnNameInHeaderWithoutGivenColumn() public void TestLoadCsvWithMultipleEmptyColumnNameInHeaderWithoutGivenColumn()
{ {
string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs,,,, string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs,,,,
CMT,1,1,1271,3.8,CRD,17.5,0 CMT,1,1,1271,3.8,CRD,17.5,0
@ -671,7 +672,7 @@ CMT,1,1,181,0.6,CSH,4.5,0";
} }
[Fact] [Fact]
public void TestReadCsvWithMultipleEmptyColumnNameInHeaderWithGivenColumn() public void TestLoadCsvWithMultipleEmptyColumnNameInHeaderWithGivenColumn()
{ {
string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs,,,, string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs,,,,
CMT,1,1,1271,3.8,CRD,17.5,0 CMT,1,1,1271,3.8,CRD,17.5,0
@ -713,7 +714,7 @@ CMT,1,1,181,0.6,CSH,4.5,0";
} }
[Fact] [Fact]
public void TestReadCsvWithExtraColumnInRow() public void TestLoadCsvWithExtraColumnInRow()
{ {
string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs,trip_distance,payment_type,fare_amount string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs,trip_distance,payment_type,fare_amount
CMT,1,1,1271,3.8,CRD,17.5,0 CMT,1,1,1271,3.8,CRD,17.5,0
@ -726,7 +727,7 @@ CMT,1,1,181,0.6,CSH,4.5,0";
} }
[Fact] [Fact]
public void TestReadCsvWithLessColumnsInRow() public void TestLoadCsvWithLessColumnsInRow()
{ {
string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs,trip_distance,payment_type,fare_amount string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs,trip_distance,payment_type,fare_amount
CMT,1,1,1271,3.8,CRD CMT,1,1,1271,3.8,CRD
@ -755,7 +756,7 @@ CMT,1,1,181,0.6,CSH";
} }
[Fact] [Fact]
public void TestReadCsvWithAllNulls() public void TestLoadCsvWithAllNulls()
{ {
string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs
null,null,null,null null,null,null,null
@ -798,7 +799,7 @@ null,null,null,null";
} }
[Fact] [Fact]
public void TestReadCsvWithNullsAndDataTypes() public void TestLoadCsvWithNullsAndDataTypes()
{ {
string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs
null,1,1,1271 null,1,1,1271
@ -860,7 +861,7 @@ CMT,1,1,null";
} }
[Fact] [Fact]
public void TestReadCsvWithNulls() public void TestLoadCsvWithNulls()
{ {
string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs
null,1,1,1271 null,1,1,1271
@ -922,7 +923,36 @@ CMT,1,1,null";
} }
// Round-trips a DataFrame whose only column is a VBufferDataFrameColumn<int> through
// SaveCsv and LoadCsv, and checks that every vector comes back as a single string cell
// of the form "(a b c)".
[Fact]
public void TestWriteCsvWithHeader() public void TestSaveCsvVBufferColumn()
{
// Three VBuffer<int> instances, each constructed from (3, three-element int array).
var vBuffers = new[]
{
new VBuffer<int> (3, new int[] { 1, 2, 3 }),
new VBuffer<int> (3, new int[] { 2, 3, 4 }),
new VBuffer<int> (3, new int[] { 3, 4, 5 }),
};
var vBufferColumn = new VBufferDataFrameColumn<int>("VBuffer", vBuffers);
DataFrame dataFrame = new DataFrame(vBufferColumn);
// Write to an in-memory stream, then rewind it so the same bytes can be loaded back.
using MemoryStream csvStream = new MemoryStream();
DataFrame.SaveCsv(dataFrame, csvStream);
csvStream.Seek(0, SeekOrigin.Begin);
DataFrame readIn = DataFrame.LoadCsv(csvStream);
// Shape must survive the round trip.
Assert.Equal(dataFrame.Rows.Count, readIn.Rows.Count);
Assert.Equal(dataFrame.Columns.Count, readIn.Columns.Count);
// The loaded column is expected to be a string column holding the parenthesized,
// space-separated rendering of each vector.
Assert.Equal(typeof(string), readIn.Columns[0].DataType);
Assert.Equal("(1 2 3)", readIn[0, 0]);
Assert.Equal("(2 3 4)", readIn[1, 0]);
Assert.Equal("(3 4 5)", readIn[2, 0]);
}
[Fact]
public void TestSaveCsvWithHeader()
{ {
using MemoryStream csvStream = new MemoryStream(); using MemoryStream csvStream = new MemoryStream();
DataFrame dataFrame = DataFrameTests.MakeDataFrameWithAllColumnTypes(10, true); DataFrame dataFrame = DataFrameTests.MakeDataFrameWithAllColumnTypes(10, true);