diff --git a/src/Microsoft.Data.Analysis/DataFrame.IO.cs b/src/Microsoft.Data.Analysis/DataFrame.IO.cs
index a2a55612a..7fb56833d 100644
--- a/src/Microsoft.Data.Analysis/DataFrame.IO.cs
+++ b/src/Microsoft.Data.Analysis/DataFrame.IO.cs
@@ -3,6 +3,7 @@
 // See the LICENSE file in the project root for more information.
 
 using System;
+using System.Collections;
 using System.Collections.Generic;
 using System.Data;
 using System.Data.Common;
@@ -11,6 +12,7 @@ using System.IO;
 using System.Linq;
 using System.Text;
 using System.Threading.Tasks;
+using Microsoft.ML.Data;
 
 namespace Microsoft.Data.Analysis
 {
@@ -675,58 +677,7 @@ namespace Microsoft.Data.Analysis
 
                     foreach (var row in dataFrame.Rows)
                     {
-                        bool firstCell = true;
-                        foreach (var cell in row)
-                        {
-                            if (!firstCell)
-                            {
-                                record.Append(separator);
-                            }
-                            else
-                            {
-                                firstCell = false;
-                            }
-
-                            Type t = cell?.GetType();
-
-                            if (t == typeof(bool))
-                            {
-                                record.AppendFormat(cultureInfo, "{0}", cell);
-                                continue;
-                            }
-
-                            if (t == typeof(float))
-                            {
-                                record.AppendFormat(cultureInfo, "{0:G9}", cell);
-                                continue;
-                            }
-
-                            if (t == typeof(double))
-                            {
-                                record.AppendFormat(cultureInfo, "{0:G17}", cell);
-                                continue;
-                            }
-
-                            if (t == typeof(decimal))
-                            {
-                                record.AppendFormat(cultureInfo, "{0:G31}", cell);
-                                continue;
-                            }
-
-                            if (t == typeof(string))
-                            {
-                                string stringCell = (string)cell;
-                                if (NeedsQuotes(stringCell, separator))
-                                {
-                                    record.Append('\"');
-                                    record.Append(stringCell.Replace("\"", "\"\"")); // Quotations in CSV data must be escaped with another quotation
-                                    record.Append('\"');
-                                    continue;
-                                }
-                            }
-
-                            record.Append(cell);
-                        }
+                        AppendValuesToRecord(record, row, separator, cultureInfo);
 
                         csvFile.WriteLine(record);
 
@@ -736,6 +687,81 @@ namespace Microsoft.Data.Analysis
             }
         }
 
+        /// <summary>
+        /// Appends each item of <paramref name="values"/> to <paramref name="record"/> as one field,
+        /// placing <paramref name="separator"/> between consecutive fields.
+        /// </summary>
+        /// <param name="record">Buffer that receives the formatted fields.</param>
+        /// <param name="values">Cells of a row, or the elements of a nested enumerable cell.</param>
+        /// <param name="separator">Character written between fields.</param>
+        /// <param name="cultureInfo">Culture used to format bool, float, double and decimal values.</param>
+        /// <remarks>
+        /// A non-string enumerable cell is rendered as a single field of the form <c>(a b c)</c>. The whole group
+        /// goes through the same quoting rule as a string cell, so it stays one field even when the separator
+        /// is a space or an element contains the separator.
+        /// </remarks>
+        private static void AppendValuesToRecord(StringBuilder record, IEnumerable values, char separator, CultureInfo cultureInfo)
+        {
+            bool firstCell = true;
+            foreach (var value in values)
+            {
+                if (!firstCell)
+                {
+                    record.Append(separator);
+                }
+                else
+                {
+                    firstCell = false;
+                }
+
+                switch (value)
+                {
+                    case bool:
+                        record.AppendFormat(cultureInfo, "{0}", value);
+                        continue;
+                    case float:
+                        record.AppendFormat(cultureInfo, "{0:G9}", value);
+                        continue;
+                    case double:
+                        record.AppendFormat(cultureInfo, "{0:G17}", value);
+                        continue;
+                    case decimal:
+                        record.AppendFormat(cultureInfo, "{0:G31}", value);
+                        continue;
+                    case string stringCell:
+                        AppendQuotedIfNeeded(record, stringCell, separator);
+                        continue;
+                    case IEnumerable nestedValues:
+                        // Elements are always space-separated; build the group separately so it can be
+                        // quoted as a whole against the caller's separator.
+                        var nestedRecord = new StringBuilder("(");
+                        AppendValuesToRecord(nestedRecord, nestedValues, ' ', cultureInfo);
+                        nestedRecord.Append(')');
+                        AppendQuotedIfNeeded(record, nestedRecord.ToString(), separator);
+                        continue;
+                }
+
+                record.Append(value);
+            }
+        }
+
+        /// <summary>
+        /// Appends <paramref name="field"/> to <paramref name="record"/>, wrapping it in quotes when
+        /// <c>NeedsQuotes</c> reports that it requires them for <paramref name="separator"/>.
+        /// </summary>
+        private static void AppendQuotedIfNeeded(StringBuilder record, string field, char separator)
+        {
+            if (!NeedsQuotes(field, separator))
+            {
+                record.Append(field);
+                return;
+            }
+
+            record.Append('\"');
+            record.Append(field.Replace("\"", "\"\"")); // Quotations in CSV data must be escaped with another quotation
+            record.Append('\"');
+        }
+
         private static void SaveHeader(StreamWriter csvFile, IReadOnlyList<string> columnNames, char separator)
         {
             bool firstColumn = true;
diff --git a/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.cs b/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.cs
index 6fe5270de..54fc1744a 100644
--- a/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.cs
+++ b/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.cs
@@ -231,7 +231,7 @@ namespace Microsoft.Data.Analysis
             return new PrimitiveDataFrameColumn<T>(name, length);
         }
 
-        internal T? GetTypedValue(long rowIndex) => _columnContainer[rowIndex];
+        protected T? GetTypedValue(long rowIndex) => _columnContainer[rowIndex];
 
         protected override object GetValue(long rowIndex) => GetTypedValue(rowIndex);
 
diff --git a/src/Microsoft.ML.DataView/VBuffer.cs b/src/Microsoft.ML.DataView/VBuffer.cs
index 6d2372637..326b762d0 100644
--- a/src/Microsoft.ML.DataView/VBuffer.cs
+++ b/src/Microsoft.ML.DataView/VBuffer.cs
@@ -3,6 +3,7 @@
 // See the LICENSE file in the project root for more information.
 
using System; +using System.Collections; using System.Collections.Generic; using System.Diagnostics; using Microsoft.ML.Internal.DataView; @@ -27,7 +28,7 @@ namespace Microsoft.ML.Data /// a value is sufficient to make a completely independent copy of it. So, for example, this means that a buffer of /// buffers is not possible. But, things like , , and , are totally fine. - public readonly struct VBuffer + public readonly struct VBuffer : IEnumerable { /// /// The internal re-usable array of values. @@ -403,6 +404,14 @@ namespace Microsoft.ML.Data public override string ToString() => IsDense ? $"Dense vector of size {Length}" : $"Sparse vector of size {Length}, {_count} explicit values"; + /// + /// Returns an enumerator that iterates through the values in VBuffer. + /// + public IEnumerator GetEnumerator() + { + return _values.GetEnumerator(); + } + internal VBufferEditor GetEditor() { return GetEditor(Length, _count); diff --git a/test/Microsoft.Data.Analysis.Tests/DataFrame.IOTests.cs b/test/Microsoft.Data.Analysis.Tests/DataFrame.IOTests.cs index 8fb4c89fd..5c1eb7aae 100644 --- a/test/Microsoft.Data.Analysis.Tests/DataFrame.IOTests.cs +++ b/test/Microsoft.Data.Analysis.Tests/DataFrame.IOTests.cs @@ -15,6 +15,7 @@ using System.Data.SQLite.EF6; using Xunit; using Microsoft.ML.TestFramework.Attributes; using System.Threading; +using Microsoft.ML.Data; namespace Microsoft.Data.Analysis.Tests { @@ -273,7 +274,7 @@ MT"; [Theory] [InlineData(false)] [InlineData(true)] - public void TestReadCsvNoHeader(bool useQuotes) + public void TestLoadCsvNoHeader(bool useQuotes) { string CMT = useQuotes ? @"""C,MT""" : "CMT"; string verifyCMT = useQuotes ? 
"C,MT" : "CMT"; @@ -349,7 +350,7 @@ MT"; [InlineData(false, 0)] [InlineData(true, 10)] [InlineData(false, 10)] - public void TestReadCsvWithTypesAndGuessRows(bool header, int guessRows) + public void TestLoadCsvWithTypesAndGuessRows(bool header, int guessRows) { /* Tests this matrix * @@ -472,7 +473,7 @@ CMT,1,1,181,0.6,CSH,4.5"; } [Fact] - public void TestReadCsvWithTypesDateTime() + public void TestLoadCsvWithTypesDateTime() { string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs,trip_distance,payment_type,fare_amount,date CMT,1,1,1271,3.8,CRD,17.5,1-june-2020 @@ -549,7 +550,7 @@ CMT,1,1,181,0.6,CSH,4.5,4-june-2020"; } [Fact] - public void TestReadCsvWithPipeSeparator() + public void TestLoadCsvWithPipeSeparator() { string data = @"vendor_id|rate_code|passenger_count|trip_time_in_secs|trip_distance|payment_type|fare_amount CMT|1|1|1271|3.8|CRD|17.5 @@ -588,7 +589,7 @@ CMT|1|1|181|0.6|CSH|4.5"; } [Fact] - public void TestReadCsvWithSemicolonSeparator() + public void TestLoadCsvWithSemicolonSeparator() { string data = @"vendor_id;rate_code;passenger_count;trip_time_in_secs;trip_distance;payment_type;fare_amount CMT;1;1;1271;3.8;CRD;17.5 @@ -627,7 +628,7 @@ CMT;1;1;181;0.6;CSH;4.5"; } [Fact] - public void TestReadCsvWithExtraColumnInHeader() + public void TestLoadCsvWithExtraColumnInHeader() { string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs,trip_distance,payment_type,fare_amount,extra CMT,1,1,1271,3.8,CRD,17.5 @@ -656,7 +657,7 @@ CMT,1,1,181,0.6,CSH,4.5"; } [Fact] - public void TestReadCsvWithMultipleEmptyColumnNameInHeaderWithoutGivenColumn() + public void TestLoadCsvWithMultipleEmptyColumnNameInHeaderWithoutGivenColumn() { string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs,,,, CMT,1,1,1271,3.8,CRD,17.5,0 @@ -671,7 +672,7 @@ CMT,1,1,181,0.6,CSH,4.5,0"; } [Fact] - public void TestReadCsvWithMultipleEmptyColumnNameInHeaderWithGivenColumn() + public void 
TestLoadCsvWithMultipleEmptyColumnNameInHeaderWithGivenColumn()
         {
             string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs,,,,
 CMT,1,1,1271,3.8,CRD,17.5,0
@@ -713,7 +714,7 @@ CMT,1,1,181,0.6,CSH,4.5,0";
         }
 
         [Fact]
-        public void TestReadCsvWithExtraColumnInRow()
+        public void TestLoadCsvWithExtraColumnInRow()
         {
             string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs,trip_distance,payment_type,fare_amount
 CMT,1,1,1271,3.8,CRD,17.5,0
@@ -726,7 +727,7 @@ CMT,1,1,181,0.6,CSH,4.5,0";
         }
 
         [Fact]
-        public void TestReadCsvWithLessColumnsInRow()
+        public void TestLoadCsvWithLessColumnsInRow()
         {
             string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs,trip_distance,payment_type,fare_amount
 CMT,1,1,1271,3.8,CRD
@@ -755,7 +756,7 @@ CMT,1,1,181,0.6,CSH";
         }
 
         [Fact]
-        public void TestReadCsvWithAllNulls()
+        public void TestLoadCsvWithAllNulls()
         {
             string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs
 null,null,null,null
@@ -798,7 +799,7 @@ null,null,null,null";
         }
 
         [Fact]
-        public void TestReadCsvWithNullsAndDataTypes()
+        public void TestLoadCsvWithNullsAndDataTypes()
         {
             string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs
 null,1,1,1271
@@ -860,7 +861,7 @@ CMT,1,1,null";
         }
 
         [Fact]
-        public void TestReadCsvWithNulls()
+        public void TestLoadCsvWithNulls()
         {
             string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs
 null,1,1,1271
@@ -922,7 +923,37 @@ CMT,1,1,null";
         }
 
         [Fact]
-        public void TestWriteCsvWithHeader()
+        public void TestSaveCsvVBufferColumn()
+        {
+            // Each vector cell is saved as one "(v1 v2 v3)" field, which loads back as a string column.
+            var vBuffers = new[]
+            {
+                new VBuffer<int>(3, new int[] { 1, 2, 3 }),
+                new VBuffer<int>(3, new int[] { 2, 3, 4 }),
+                new VBuffer<int>(3, new int[] { 3, 4, 5 }),
+            };
+
+            var vBufferColumn = new VBufferDataFrameColumn<int>("VBuffer", vBuffers);
+            DataFrame dataFrame = new DataFrame(vBufferColumn);
+
+            using MemoryStream csvStream = new MemoryStream();
+
+            DataFrame.SaveCsv(dataFrame, csvStream);
+
+            csvStream.Seek(0, SeekOrigin.Begin);
+            DataFrame readIn = DataFrame.LoadCsv(csvStream);
+
+            Assert.Equal(dataFrame.Rows.Count, readIn.Rows.Count);
+            Assert.Equal(dataFrame.Columns.Count, readIn.Columns.Count);
+
+            Assert.Equal(typeof(string), readIn.Columns[0].DataType);
+            Assert.Equal("(1 2 3)", readIn[0, 0]);
+            Assert.Equal("(2 3 4)", readIn[1, 0]);
+            Assert.Equal("(3 4 5)", readIn[2, 0]);
+        }
+
+        [Fact]
+        public void TestSaveCsvWithHeader()
         {
             using MemoryStream csvStream = new MemoryStream();
             DataFrame dataFrame = DataFrameTests.MakeDataFrameWithAllColumnTypes(10, true);