diff --git a/src/Microsoft.Data.Analysis/DataFrame.IO.cs b/src/Microsoft.Data.Analysis/DataFrame.IO.cs index f057ea7cf..dd0752a0a 100644 --- a/src/Microsoft.Data.Analysis/DataFrame.IO.cs +++ b/src/Microsoft.Data.Analysis/DataFrame.IO.cs @@ -198,14 +198,23 @@ namespace Microsoft.Data.Analysis Encoding encoding = null) { if (!csvStream.CanSeek) + { throw new ArgumentException(Strings.NonSeekableStream, nameof(csvStream)); + } + + if (dataTypes == null && guessRows <= 0) + { + throw new ArgumentException(string.Format(Strings.ExpectedEitherGuessRowsOrDataTypes, nameof(guessRows), nameof(dataTypes))); + } var linesForGuessType = new List(); long rowline = 0; int numberOfColumns = dataTypes?.Length ?? 0; if (header == true && numberOfRowsToRead != -1) + { numberOfRowsToRead++; + } List columns; long streamStart = csvStream.Position; @@ -213,40 +222,39 @@ namespace Microsoft.Data.Analysis using (var streamReader = new StreamReader(csvStream, encoding ?? Encoding.UTF8, detectEncodingFromByteOrderMarks: true, DefaultStreamReaderBufferSize, leaveOpen: true)) { string line = null; - if (dataTypes == null) + line = streamReader.ReadLine(); + while (line != null) { - line = streamReader.ReadLine(); - while (line != null) + if ((numberOfRowsToRead == -1) || rowline < numberOfRowsToRead) { - if ((numberOfRowsToRead == -1) || rowline < numberOfRowsToRead) + if (linesForGuessType.Count < guessRows || (header && rowline == 0)) { - if (linesForGuessType.Count < guessRows) + var spl = line.Split(separator); + if (header && rowline == 0) { - var spl = line.Split(separator); - if (header && rowline == 0) + if (columnNames == null) { - if (columnNames == null) - columnNames = spl; - } - else - { - linesForGuessType.Add(spl); - numberOfColumns = Math.Max(numberOfColumns, spl.Length); + columnNames = spl; } } + else + { + linesForGuessType.Add(spl); + numberOfColumns = Math.Max(numberOfColumns, spl.Length); + } } - ++rowline; - if (rowline == guessRows) - { - break; - } - line = streamReader.ReadLine(); } - - if (linesForGuessType.Count == 0) + ++rowline; + if (rowline == guessRows || guessRows == 0) { - throw new FormatException(Strings.EmptyFile); + break; } + line = streamReader.ReadLine(); + } + + if (rowline == 0) + { + throw new FormatException(Strings.EmptyFile); } columns = new List(numberOfColumns); diff --git a/src/Microsoft.Data.Analysis/strings.Designer.cs b/src/Microsoft.Data.Analysis/strings.Designer.cs index 6196e234a..92d9a8771 100644 --- a/src/Microsoft.Data.Analysis/strings.Designer.cs +++ b/src/Microsoft.Data.Analysis/strings.Designer.cs @@ -132,6 +132,15 @@ namespace Microsoft.Data { } } + /// + /// Looks up a localized string similar to Expected either {0} or {1} to be provided. + /// + internal static string ExpectedEitherGuessRowsOrDataTypes { + get { + return ResourceManager.GetString("ExpectedEitherGuessRowsOrDataTypes", resourceCulture); + } + } + /// /// Looks up a localized string similar to Column is immutable. /// diff --git a/src/Microsoft.Data.Analysis/strings.resx b/src/Microsoft.Data.Analysis/strings.resx index 446b3b3f3..ba2e3e0b5 100644 --- a/src/Microsoft.Data.Analysis/strings.resx +++ b/src/Microsoft.Data.Analysis/strings.resx @@ -141,6 +141,9 @@ Parameter.Count exceeds the number of columns({0}) in the DataFrame + + Expected either {0} or {1} to be provided + Column is immutable diff --git a/tests/Microsoft.Data.Analysis.Tests/DataFrame.IOTests.cs b/tests/Microsoft.Data.Analysis.Tests/DataFrame.IOTests.cs index 0decb3ec2..cd49f02b6 100644 --- a/tests/Microsoft.Data.Analysis.Tests/DataFrame.IOTests.cs +++ b/tests/Microsoft.Data.Analysis.Tests/DataFrame.IOTests.cs @@ -151,6 +151,99 @@ CMT,1,1,181,0.6,CSH,4.5"; VerifyColumnTypes(df); } + void VerifyDataFrameWithNamedColumnsAndDataTypes(DataFrame df, bool verifyColumnDataType, bool verifyNames) + { + Assert.Equal(4, df.Rows.Count); + Assert.Equal(7, df.Columns.Count); + + if (verifyColumnDataType) + { + Assert.True(typeof(string) == df.Columns[0].DataType); + Assert.True(typeof(short) == df.Columns[1].DataType); + Assert.True(typeof(int) == df.Columns[2].DataType); + Assert.True(typeof(long) == df.Columns[3].DataType); + Assert.True(typeof(float) == df.Columns[4].DataType); + Assert.True(typeof(string) == df.Columns[5].DataType); + Assert.True(typeof(double) == df.Columns[6].DataType); + } + + if (verifyNames) + { + Assert.Equal("vendor_id", df.Columns[0].Name); + Assert.Equal("rate_code", df.Columns[1].Name); + Assert.Equal("passenger_count", df.Columns[2].Name); + Assert.Equal("trip_time_in_secs", df.Columns[3].Name); + Assert.Equal("trip_distance", df.Columns[4].Name); + Assert.Equal("payment_type", df.Columns[5].Name); + Assert.Equal("fare_amount", df.Columns[6].Name); + } + + VerifyColumnTypes(df); + + foreach (var column in df.Columns) + { + Assert.Equal(0, column.NullCount); + } + } + + [Theory] + [InlineData(true, 0)] + [InlineData(false, 0)] + [InlineData(true, 10)] + [InlineData(false, 10)] + public void TestReadCsvWithTypesAndGuessRows(bool header, int guessRows) + { + /* Tests this matrix + * + header GuessRows DataTypes + True 0 NotNull + False 0 NotNull + True 10 NotNull + False 10 NotNull + True 0 Null -----> Throws an exception + False 0 Null -----> Throws an exception + True 10 Null + False 10 Null + * + */ + string headerLine = @"vendor_id,rate_code,passenger_count,trip_time_in_secs,trip_distance,payment_type,fare_amount +"; + string dataLines = +@"CMT,1,1,1271,3.8,CRD,17.5 +CMT,1,1,474,1.5,CRD,8 +CMT,1,1,637,1.4,CRD,8.5 +CMT,1,1,181,0.6,CSH,4.5"; + + Stream GetStream(string streamData) + { + return new MemoryStream(Encoding.Default.GetBytes(streamData)); + } + + string data = header ? headerLine + dataLines : dataLines; + DataFrame df = DataFrame.LoadCsv(GetStream(data), + header: header, + guessRows: guessRows, + dataTypes: new Type[] { typeof(string), typeof(short), typeof(int), typeof(long), typeof(float), typeof(string), typeof(double) } + ); + VerifyDataFrameWithNamedColumnsAndDataTypes(df, verifyColumnDataType: true, verifyNames: header); + + if (guessRows == 10) + { + df = DataFrame.LoadCsv(GetStream(data), + header: header, + guessRows: guessRows + ); + VerifyDataFrameWithNamedColumnsAndDataTypes(df, verifyColumnDataType: false, verifyNames: header); + } + else + { + Assert.ThrowsAny(() => DataFrame.LoadCsv(GetStream(data), + header: header, + guessRows: guessRows + )); + } + } + [Fact] public void TestReadCsvWithTypes() { @@ -176,6 +269,14 @@ CMT,1,1,181,0.6,CSH,4.5"; Assert.True(typeof(float) == df.Columns[4].DataType); Assert.True(typeof(string) == df.Columns[5].DataType); Assert.True(typeof(double) == df.Columns[6].DataType); + + Assert.Equal("vendor_id", df.Columns[0].Name); + Assert.Equal("rate_code", df.Columns[1].Name); + Assert.Equal("passenger_count", df.Columns[2].Name); + Assert.Equal("trip_time_in_secs", df.Columns[3].Name); + Assert.Equal("trip_distance", df.Columns[4].Name); + Assert.Equal("payment_type", df.Columns[5].Name); + Assert.Equal("fare_amount", df.Columns[6].Name); VerifyColumnTypes(df); foreach (var column in df.Columns)